Blog 3: Approaching Flow Matching Mathematically

Introduction

In the world of computational structural biology you might have heard of diffusion models as the current big thing in generative modelling. Diffusion models are great primarily because they look cool when you visualise the denoising process generating a protein structure (check out the RFdiffusion Colab notebook), but also because they are state of the art at generating diverse and designable protein backbone structures.

Diffusion models originally emerged from computer vision, and a lot of work has since been built up around their application to macromolecules - especially exciting is their harmonious union with geometric deep learning in the case of SE(3) equivariance (see FrameDiff). I don’t know about you, but I get particularly excited about geometric deep learning, mostly because it involves objectively dope words like “manifold” and “Riemannian”, better yet “Riemannian manifolds” - woah! (See Bronstein et al.’s Geometric Deep Learning for more fun vocabulary to add to your vernacular, like “geodesic”.)

But we’re getting sidetracked. Diffusion is a square-to-rectangle special case of score-based generative models, with the clause that diffusion refers explicitly to a time-dependent score function that is typically learned via a denoising process. Check out Jakub Tomczak’s blog for more on diffusion and score-based generative models. Flow matching, although technically different from score-based generative models, also makes use of transformations to a Gaussian, but is generally faster and not constrained to discrete time steps (or even Gaussian priors). So the big question is, how does one flow match?

This question is particularly personal to me, as my current DPhil focuses heavily on utilising flow matching to solve some particularly exciting problems in biology. Despite hours of nose-in-book literature and dreams of electric maths, my understanding of the deeper theory is sketchy at best, so I’m using this blog as a way to further that understanding whilst hopefully being somewhat educationally useful for others by giving very much a noob’s walkthrough of Flow Matching. Specifically, because maths looks cool and doing it might help me get cool internships, this blog will focus on the underpinning mathematics of it all, coming from the perspective of a traditionally trained biochemist and mathematical amateur like myself.

Flow Matching is a powerful generalisation of diffusion in that it steps away from the discrete-time denoising approach to a broader definition of a “flow” over a continuous timescale - imagine some time-dependent field shaping/morphing the source distribution into our target data distribution. Let’s say we have a large number of images as examples that represent some higher-level data distribution encompassing the set of all images; much of the effort of deep learning in recent years has gone into improving generative models that can reliably sample from this distribution (with or without learning it). Diffusion does this by learning the inverse of a sequential noising process that transforms the data into Gaussian noise. In flow matching we similarly use a Gaussian as a source distribution, but increase our abstraction by viewing each intermediate step as a probability distribution related to the source and target by a diffeomorphism, whose transformation from the source is described by a time-dependent vector field in continuous time rather than a discrete set of steps. We can then sample from the model simply by integrating the learned vector field, which happens to be significantly faster than the denoising process.

Flow Mat(c)h

Much of the math below I’ve recounted from this fantastic resource from Meta and the original flow matching paper, and tried to wrap in a more accessible description. This is also very much the simplest formulation of flow matching, so if you’re interested in reading further, particularly for non-Euclidean approaches, I refer you back to Meta’s paper.

In flow matching we aim to learn the parameters $\theta$ of a velocity field $v_t$ that acts on a flow $\psi_t(x)$, which describes the change over time of some sample $x$ with initial condition $\psi_0(x) = x$, and is formulated as an ODE: $\frac{d}{dt}\psi_t(x) = v_t(\psi_t(x))$. The true velocity field $v_t$ is complex and generally intractable, so we learn a parameterised approximation $v_t^{\theta}$ with a neural network; we will see how this is made tractable later. $v_t$ provides deterministic trajectories morphing a source distribution $p_0$ into a distribution $p_t$ at time point $t \in [0,1]$, transporting distributions forward in time. Thus, $p_t$ represents a probability path that is realised as a probability distribution at any $t$. The velocity field is learned from a set of training data $X_1 \sim q$, functioning as an empirical approximation of an unknown underlying distribution $q$ where we assign $p_1 = q$, and a Gaussian source distribution $p_0 = \mathcal{N}(x | 0,I)$. Learning $v_t^{\theta}$ allows generative transport of any sample $x_0 = X_0 \sim p_0$ towards $x_1 = X_1 \sim p_1$ through any time point $t$, so that the distribution of generated samples satisfies $p_1 \approx q$. In other words, to generatively sample from $q$ we can simply sample from the Gaussian source $p_0$ and then integrate to $t=1$ (or to an earlier time point to achieve results similar to partial diffusion).
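To make the “integrate the learned velocity field” step concrete, here is a minimal sketch (in PyTorch) of sampling by forward Euler integration. The callable `v_theta` and the function name `sample_with_euler` are illustrative assumptions rather than part of any referenced codebase.

```python
import torch

def sample_with_euler(v_theta, x0, n_steps=100):
    """Integrate dx/dt = v_theta(x, t) from t=0 to t=1 with forward Euler.

    v_theta is assumed to be any callable (e.g. a trained neural network)
    that takes the current state x of shape (batch, dim) and a time tensor t
    of shape (batch, 1) and returns a velocity of the same shape as x.
    """
    x = x0
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = torch.full((x.shape[0], 1), i * dt)  # current time, broadcast over the batch
        x = x + dt * v_theta(x, t)               # Euler step along the velocity field
    return x

# Usage sketch: x0 = torch.randn(64, 2); samples = sample_with_euler(model, x0)
```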

A quick note on definitions to avoid some of the confusion I encountered:

  • $q$ is the underlying unknown probability distribution we aim to model, and from which we have some samples of data which make up our empirical training set approximation $p_1 \approx q$.
  • $X_1$ is the random variable $X_1 \sim p_1$ which represents performing a single draw from the training set.
  • $x_1$ is the realisation or actual value of the draw $X_1$; in the case of the MNIST dataset it would be a single image in the training set.
  • Similarly, $X_0$ is the random variable $X_0 \sim p_0$ where $p_0$ is our source distribution which we choose to be something we can easily sample from - a Gaussian in this case.
  • So $p_t$ is a path of probability distributions (probability path) which is a function that returns a probability distribution for a given timepoint $t$.

The flow $\psi_t(x)$ and velocity field $v_t$ are a bit more abstract. The flow $\psi_t(x)$ describes the transport of a sample through time, which I interpret as a mapping function that simply maps a sample $x_0$ at $t=0$ to its respective position at time $t$. This is the determinism of flow matching. The flow is often also described as a push-forward operator, but that is perhaps more useful when discussing non-Euclidean spaces. $v_t$ can be seen as a force field that gives every point $\psi_t(x)$ in the space of each distribution in $p_t$ a direction and speed, telling it where it goes next.

The core approach of Flow Matching, specifically Conditional Flow Matching (CFM) as introduced by Lipman et al., is how we construct the probability path $p_t$ for learning the vector field $v_t^{\theta}$. Specifically, we can formulate $p_t$ as an average over the set of conditional probability paths $p_{t|1}(\cdot|x_1)$, one for each data endpoint (training point) $X_1 = x_1$. With a Gaussian source distribution we get the following conditional path for each separate datapoint $x_1$: $$p_{t|1}(x|x_1) = \mathcal{N}(x|tx_1, (1-t)^2I)$$ This is useful as it makes things actually tractable. We can recover the full marginal probability path $p_t(x)$ over all $x_1$ by essentially taking an average of the conditionals over the empirical training set - something known as a mixture of conditionals: $$p_t(x) = \int p_{t|1}(x|x_1)q(x_1)dx_1$$ Note that this is specifically a marginalisation of $X_1$ with respect to $q$: we are averaging the conditional path over the data distribution, not marginalising a joint distribution over multiple variables (which would be the states at other time points), hence why the above is a simple average and does not make use of the product rule. Apologies if this was obvious, but it was a point of terminological confusion for me.
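As a concrete (and hedged) illustration of the conditional path, the snippet below draws $x \sim p_{t|1}(x|x_1)$ by reparameterisation; the function name is my own and the tensor shapes are just an assumption for the example.

```python
import torch

def sample_conditional_path(x1, t):
    """Draw x ~ p_{t|1}(x | x1) = N(x | t * x1, (1 - t)^2 I).

    x1: a batch of data points, shape (batch, dim)
    t:  times in [0, 1], shape (batch, 1)
    """
    eps = torch.randn_like(x1)       # standard Gaussian noise (this is x0 in disguise)
    return t * x1 + (1.0 - t) * eps  # mean t * x1, standard deviation (1 - t)
```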

In the case of $t=0$ we see how the above resolves into the source distribution: $$p_{0|1}(x|x_1) = \mathcal{N}(x|0 \cdot x_1, (1-0)^2I)$$

Where the mean of the Normal becomes $0 \times x_1 = 0$: $$p_0(x) = \int \mathcal{N}(x|0,I)q(x_1) dx_1 = \mathcal{N}(x|0,I)$$

For $t=1$ we see how with mean $tx_1 = x_1$ and variance $(1-t)^2I = 0$ we get $p_{1|1}(x|x_1) = \mathcal{N}(x|x_1,0)$ which becomes the Dirac delta measure $\delta(x-x_1)$ which is essentially assigning all probability mass to $x=x_1$, therefore: $$p_1(x) = \int \delta(x - x_1)q(x_1)dx_1 = q(x)$$

Having recovered our source and target distributions with this formulation, we have shown how $p_t(x)$ follows a conditional optimal-transport (also known as linear) path, which allows us to define the random variable $X_t \sim p_t$ as a linear combination of $X_0 \sim p_0$ and $X_1 \sim q$: $$X_t = tX_1 + (1-t)X_0 \sim p_t$$

Flow matching is constructed in this way because it enables the above closed-form and tractable formulation of a probability path $p_t$ with smooth interpolation between the Gaussian source and the target distribution, allowing generation of samples from any intermediate distribution by a linear combination of a Gaussian and the data. In other words, we are defining an interpolation between pairs of samples from $X_0$ and $X_1$ that is deterministic and linear in time from start to end, corresponding to the straight line (or geodesic) in Euclidean space.
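To see the “straight line in Euclidean space” claim in code, here is a tiny sketch tracing the deterministic path between one fixed source/target pair at a grid of times (the 2D toy values are made up purely for illustration):

```python
import torch

x0 = torch.randn(2)            # a draw from the Gaussian source
x1 = torch.tensor([3.0, 1.0])  # a made-up data point

ts = torch.linspace(0.0, 1.0, 5).unsqueeze(1)  # times 0, 0.25, 0.5, 0.75, 1
path = ts * x1 + (1.0 - ts) * x0               # X_t = t*x1 + (1-t)*x0, one row per t

print(path)  # rows move along a straight line from x0 (t=0) to x1 (t=1)
```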

This linear path is associated with the true velocity field $v_t$, which we can now learn by randomly sampling timepoints $t \sim \mathcal{U}[0,1]$ and determining $x_t$ by the above linear interpolation, allowing us to learn a parameterised $v_t^{\theta}(x)$ with a neural network using the following Mean Squared Error (MSE) loss:

$$ \mathcal{L}_{\text{FM}}(\theta) = \mathbb{E}_{t,X_t} ||v_t^{\theta}(X_t) - v_t(X_t) ||^2 $$

When conditioning on a single randomly selected target example $X_1 = x_1$, we reduce the complexity of the joint over two high-dimensional distributions $p(X_t, X_1)$ to the simpler conditional $p(X_t | X_1 = x_1)$ by fixing the endpoint, which allows tractable sampling. Thus, we adjust the linear combination to the conditional case:

$$ X_{t|1} = tx_1 + (1-t)X_0 \quad \sim \quad p_{t|1}(\cdot|x_1) = \mathcal{N}(\cdot | tx_1, (1-t)^2I) $$

By differentiating the above with respect to $t$ we can get the rate of change of $X_{t|1}$ which is its instantaneous velocity given the fixed endpoint $x_1$:

$$ \frac{d}{dt} X_{t|1} = x_1 - X_0 $$

So, for the conditional process we simply get the constant vector $x_1 - X_0$. Assuming we have a randomly selected timepoint $t \sim \mathcal{U}[0,1]$ and have drawn a sample $x = X_{t|1} \sim p_{t|1}(\cdot|x_1)$ at that time point, then we also have $x = tx_1 + (1-t)X_0$, and in the case of $t \neq 1$ (as $t=1$ would lead to division by zero) we can rearrange to get $X_0$: $$ X_0 = \frac{x - tx_1}{1-t} $$

Plugging this into our conditional velocity gives the instantaneous velocity at $t$ given the current point $x$ and the corresponding datapoint $x_1$: $$ v_t(x|x_1) = \frac{x_1 - x}{1 - t} $$
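As a quick sanity check on the two equivalent forms of the conditional velocity, the sketch below confirms numerically that $\frac{x_1 - x_t}{1-t}$ equals $x_1 - x_0$ whenever $x_t = t x_1 + (1-t) x_0$ (toy tensors, names of my own choosing):

```python
import torch

x0 = torch.randn(4, 2)             # source samples
x1 = torch.randn(4, 2)             # data samples
t = torch.rand(4, 1) * 0.99        # times in [0, 0.99) to avoid dividing by zero at t=1

xt = t * x1 + (1.0 - t) * x0       # conditional interpolant X_{t|1}

v_from_xt = (x1 - xt) / (1.0 - t)  # v_t(x | x1) = (x1 - x) / (1 - t)
v_constant = x1 - x0               # the constant conditional velocity

print(torch.allclose(v_from_xt, v_constant, atol=1e-5))  # True (up to floating point error)
```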

This velocity field generates the conditional probability path $p_{t|1}(x_t|x_1)$. This simple conditional velocity field can now be used to express the marginal velocity field as a conditional expectation over the posterior of possible endpoints $X_1$ given the current state $X_t = x$:

$$ v_t(x) = \mathbb{E}[v_t(x|X_1) | X_t = x] = \frac{\mathbb{E}[X_1 | X_t=x] - x}{1-t} $$

This shows how the instantaneous velocity of $x$ is the posterior expectation $\mathbb{E}[X_1 | X_t=x]$ of the training data, minus the current position, divided by the time remaining - simple enough! It can also be written as a weighted average of the conditional velocities, with weights given by the posterior over endpoints:

$$ v_t(x) = \int v_t(x|x_1)\frac{p_{t|1}(x|x_1)q(x_1)}{p_t(x)}dx_1 $$

This similarly gives the marginal vector field. So we are left with a simple training recipe: sample a random timepoint $t$, a source sample $X_0$, and a data sample $x_1 = X_1$, interpolate $X_t = tX_1 + (1-t)X_0$, and regress the learnable vector field $v_t^{\theta}$ onto the conditional velocity vector $\frac{x_1 - x_t}{1-t}$; averaging over many such training-source pairs implicitly targets the marginal field (see the algorithm at the end).
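As an aside, the weighted-average formula above can be estimated directly by Monte Carlo over the training set. Here is a hedged sketch of that estimate (function and variable names are my own, and this is purely to build intuition; in practice we never compute this explicitly):

```python
import torch

def marginal_velocity(x, t, data):
    """Estimate the marginal velocity v_t(x) as a weighted average of the
    conditional velocities, weighted by the posterior over endpoints x1.

    x:    a single point, shape (dim,)
    t:    a scalar time in [0, 1)
    data: the empirical training set standing in for q, shape (n, dim)
    """
    diff = x - t * data                                    # x minus the mean t*x1 of each conditional
    log_w = -(diff ** 2).sum(dim=-1) / (2 * (1 - t) ** 2)  # log N(x | t*x1, (1-t)^2 I) up to a constant
    w = torch.softmax(log_w, dim=0).unsqueeze(1)           # normalised posterior weights over x1
    cond_v = (data - x) / (1 - t)                          # v_t(x | x1) for every training point
    return (w * cond_v).sum(dim=0)                         # posterior-weighted average
```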

A final powerful observation, whose derivation I leave to the pros in the referenced papers, is that the marginal and conditional velocity field loss functions have the same gradients for learning $v_t^{\theta}$:

$$ \nabla_{\theta}\mathcal{L}_{\text{FM}}(\theta) = \nabla_{\theta}\mathcal{L}_{\text{CFM}}(\theta) $$

This represents one of the most desirable properties of Flow Matching (specifically Conditional Flow Matching - CFM) which allows one to train just using per-sample conditionals. In other words, we Flow Match on individual samples. Thus, the simplest implementation of Flow Matching using a Gaussian source distribution with training data in the Euclidean space and exploiting the conditional loss gives the following final form:

$$ \mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t,X_0,X_1}||v_t^{\theta}(X_t) - (X_1 - X_0) ||^2 $$
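The conditional loss translates almost line-for-line into code. Below is a minimal sketch of a CFM loss, assuming `model(x, t)` is a neural network returning a velocity with the same shape as `x`; the function and argument names are my own.

```python
import torch

def cfm_loss(model, x0, x1):
    """Conditional Flow Matching loss: E || v_theta(X_t, t) - (X_1 - X_0) ||^2."""
    t = torch.rand(x0.shape[0], 1)        # t ~ U[0, 1], one per sample
    xt = t * x1 + (1.0 - t) * x0          # X_t = t*X_1 + (1-t)*X_0
    target = x1 - x0                      # constant conditional velocity
    pred = model(xt, t)                   # v_theta(X_t, t)
    return ((pred - target) ** 2).mean()  # MSE over batch and dimensions
```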

We can also define the more general realisation, which applies to non-Euclidean spaces when considering $\psi_t(x)$ as the appropriate push-forward:

$$ \mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t,X_0,X_1}||v_t^{\theta}(\psi_t(X_0)) - \frac{d}{dt}\psi_t(X_0)||^2 $$

In Euclidean space the derivative term simplifies to $X_1 - X_0$, recovering the loss above. Things get even more interesting when applying this generalised definition to the Riemannian world, but this blog is long enough and has depleted my IQ reserves. So, to answer the question of how to flow match? Do it on individual samples!

Vs Diffusion

For reference, here are the pros of Flow Matching against typical Diffusion:

  • Faster. More efficient sampling as we only need to solve an ODE.
  • Sample from any timepoint without needing to employ tricks like consistency models.
  • Deterministic sampling trajectories.
  • Not restricted to Gaussian source distributions.

Everyone Loves an Algorithm

Given a source distribution $p_0 = \mathcal{N}(0,I)$ and a target distribution $p_1 = q$ defined by a set of training data examples $x_1 = X_1 \sim q$, learn a vector field $v_t^{\theta}$ with parameters $\theta$ using a neural network that takes as input $(x,t)$. Train by repeating for $N$ epochs (a code sketch follows the list):

  • Draw data example $x_1 \sim q$ (or minibatch of data points)
  • Sample $x_0 \sim p_0$ (or minibatch of samples)
  • Sample a timepoint $t \sim \mathcal{U}[0,1], \quad t\neq 1$ for each sample (or minibatch of samples)
  • Determine $x_t$ by linear interpolation: $x_t = tx_1 + (1-t)x_0$
  • Compute the per-sample target conditional velocity: $v_t(x_t|x_1) = \frac{x_1 - x_t}{1 - t} = x_1 - x_0$
  • Predict $v_t^{\theta}(x_t)$ with the neural network for each sample
  • Compute loss: $\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t,x_0,x_1} || v_t^{\theta}(x_t) - (x_1 - x_0)||^2$
  • After training, to sample we simply draw $x_0 \sim p_0$ and integrate the ODE $\frac{dx}{dt} = v_t^{\theta}(x)$ over $t \in [0,1]$ with a standard ODE solver.
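Putting the recipe together, here is a hedged end-to-end sketch on a toy 2D dataset: a small MLP for $v_t^{\theta}$, the training loop above, and Euler integration for sampling. The architecture, dataset, and hyperparameters are illustrative assumptions, not the setup of any referenced paper.

```python
import torch
import torch.nn as nn

class VelocityNet(nn.Module):
    """A small MLP v_theta(x, t); the architecture is an arbitrary choice."""
    def __init__(self, dim=2, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x, t):
        return self.net(torch.cat([x, t], dim=-1))  # condition on time by concatenation

def toy_data(n):
    """Stand-in for q: points on a noisy circle (purely illustrative)."""
    angles = torch.rand(n, 1) * 2 * torch.pi
    return 3 * torch.cat([angles.cos(), angles.sin()], dim=-1) + 0.1 * torch.randn(n, 2)

model = VelocityNet()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training: regress v_theta onto the per-sample conditional velocity x1 - x0.
for step in range(2000):
    x1 = toy_data(256)                   # draw a minibatch of data, x1 ~ q
    x0 = torch.randn_like(x1)            # x0 ~ N(0, I)
    t = torch.rand(x1.shape[0], 1)       # t ~ U[0, 1]
    xt = t * x1 + (1.0 - t) * x0         # linear interpolation
    loss = ((model(xt, t) - (x1 - x0)) ** 2).mean()  # CFM loss
    opt.zero_grad()
    loss.backward()
    opt.step()

# Sampling: integrate the learned ODE from t=0 to t=1 with forward Euler.
with torch.no_grad():
    x = torch.randn(1000, 2)             # start from the Gaussian source
    n_steps = 100
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = torch.full((x.shape[0], 1), i * dt)
        x = x + dt * model(x, t)         # x should now resemble the noisy circle
```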

Footnote: $p_{t|1}$ is notational shorthand for $p_t(x|X_1 = x_1)$, i.e. the probability distribution at time $t$ conditioned on the single endpoint sample $x_1$ - NOT “time $t$ given 1” or “given $t=1$”. For example, $x_1$ can be a single image from the MNIST set, where $X_1$ is the random variable representing a draw from the training set: when we sample $X_1$, we obtain a particular instance, a single example image $x_1$.
