A Flowing Intro to Flow Matching

Foreword


Flow matching is a generative modelling technique developed in 2022 by Yaron Lipman et al. A paper is a highly processed artifact which often does not tell you the full story of how the authors discovered the results; rather, results are presented in a formal tone in the main text, and the appendix contains proofs which merely verify the results but don't develop them. In this blog post I try to smoothly develop the main ideas and results, ideally so that you feel like you could have come up with them yourself. This is the way I feel when I read Feynman's Lectures on Physics, and this introduction tries to approximate that. You can view this blog post as a middle ground between the paper and the very extensive "official" 83-page-long Flow Matching Guide and Code, also by Yaron Lipman and others. Plus, here and there I give more context than the paper (for example on the continuity equation) and provide 3D-interactive visualizations. Have fun!

Generative Modeling

The abstract mathematical goal of generative modeling is to draw samples from an unknown data distribution, unknown in the sense that we only have a finite number of samples, but no explicit representation like a formula for its density or an algorithm to obtain a ground truth sample from it.

On a high level, the basic idea of generative modeling is to:

  1. First take a sample $\boldsymbol{x}_0$ in latent space from some simple distribution which is easy to sample from, e.g., $\boldsymbol{x}_0 \sim \mathcal{N}(\boldsymbol{0}, \boldsymbol{I})$ via the Box-Muller transform.

  2. Then apply a (potentially stochastic) transformation $\phi$ derived from the data to obtain $\boldsymbol{x}_1 = \phi(\boldsymbol{x}_0)$.

Ideally, $\phi$ is chosen such that the distribution of the random variable $\boldsymbol{x}_1$ is as close as possible to the data distribution $q$, in symbols, $\boldsymbol{x}_1 \sim q$.
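To make the two-step recipe concrete, here is a minimal sketch in PyTorch. The Box-Muller step is real, but the transformation phi below is just a hand-picked affine map standing in for the learned transformation (the names box_muller and phi are purely illustrative):

import torch

def box_muller(n, d):
    # One Box-Muller branch: sqrt(-2 ln U1) * cos(2 pi U2) is standard normal for
    # independent U1, U2 ~ U(0, 1); drawing fresh uniforms per entry gives an
    # (n, d) matrix of i.i.d. N(0, 1) samples.
    u1 = torch.rand(n, d).clamp_min(1e-12)  # avoid log(0)
    u2 = torch.rand(n, d)
    return torch.sqrt(-2 * torch.log(u1)) * torch.cos(2 * torch.pi * u2)

x0 = box_muller(1024, 2)                            # step 1: latent samples x0 ~ N(0, I)
phi = lambda x: 0.5 * x + torch.tensor([2.0, 1.0])  # step 2: a hand-picked stand-in for the learned map
x1 = phi(x0)                                        # the distribution of x1 plays the role of q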

Flows and Density Paths

You can imagine that it is difficult to learn such a transformation $\phi$ directly via a KL divergence, Wasserstein distance or something similar. This motivates breaking the transformation into intermediate steps. But modelling the transformation with a finite number of composed learned transformations, such as in Normalizing Flows, comes with several fundamental drawbacks, most notably constraints on expressivity and issues with computational scalability in practice. So why not model the transformation as a continuous path? Why not learn a time-dependent transformation $\phi: [0,1] \times \mathbb{R}^d \to \mathbb{R}^d: (t, \boldsymbol{x}) \mapsto \phi_t(\boldsymbol{x})$ with $\phi_0 = \text{id}$ and $\phi_1 = \phi$? Well, learning separate parameters for every $t \in [0,1]$ is of course infeasible. And taking the time as an additional input to a parametric model doesn't by itself make the problem any easier.

But there is a remedy. To this end, let us think of $\phi$ as moving a sample $\boldsymbol{x}_0$ through the data space via the parametric curve $t \mapsto \phi_t(\boldsymbol{x}_0)$. If $\phi_t$ is smooth, for example diffeomorphic for every $t \in [0,1]$, then it is not far-fetched to think of $\phi$ as a flow which transports particles (samples) from their initial position to another place.

Definition: Flow

For our purposes, a flow is a time-dependent map $\phi: [0,1] \times \mathbb{R}^d \to \mathbb{R}^d$.

Together with an initial distribution $\boldsymbol{x}_0 \sim \mathcal{N}(\boldsymbol{0}, \boldsymbol{I}_d)$, a flow induces a continuous path of distributions, namely the distributions of the random variables $\boldsymbol{x}_t := \phi_t(\boldsymbol{x}_0)$. If every $\phi_t$ is diffeomorphic, every distribution on the path has a density and the corresponding probability density path $p_t: \mathbb{R}^d \to [0,\infty)$ can be computed via a change of variables as:

\begin{align*} p_t(\boldsymbol{x}_t) = p_0(\phi_t^{-1}(\boldsymbol{x}_t))\,|\det \nabla \phi_t^{-1}(\boldsymbol{x}_t)| = \frac{p_0(\boldsymbol{x}_0)}{|\det \nabla \phi_t(\boldsymbol{x}_0)|}. \end{align*}

Below you can see an animation of a probability density path which transforms a standard bivariate normal distribution $\boldsymbol{x}_0 \sim \mathcal{N}(\boldsymbol{0}, \boldsymbol{I}_2)$ into $\boldsymbol{x}_1 \sim \mathcal{N}\left(\begin{pmatrix}2 & 1\end{pmatrix}^T, \frac{1}{2}\boldsymbol{I}_2 \right)$ via linear interpolation of the means and standard deviations, i.e., $\phi_t(\boldsymbol{x}_0) = \boldsymbol{x}_t = \sigma_t\boldsymbol{x}_0 + \boldsymbol{\mu}_t$ and $p_t(\boldsymbol{x}) = \mathcal{N}(\boldsymbol{x};\,\boldsymbol{\mu}_t, \sigma_t^2\boldsymbol{I}_2)$ with

\begin{align*} \boldsymbol{\mu}_t = (1-t)\boldsymbol{0} + t\begin{pmatrix}2 & 1\end{pmatrix}^T \quad \text{and} \quad \sigma_t = (1-t)\cdot 1 + t\cdot\frac{1}{\sqrt{2}}. \end{align*}
(Interactive 3D animation of the probability density path over time t; drag to rotate, scroll to zoom.)
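If you want to verify the change-of-variables formula for this particular path on your machine, here is a minimal sketch (assuming the parameters from the animation above). Since $\phi_t$ is affine, $|\det \nabla \phi_t| = \sigma_t^2$ in two dimensions, and the result should coincide with the closed form $\mathcal{N}(\boldsymbol{x};\,\boldsymbol{\mu}_t, \sigma_t^2\boldsymbol{I}_2)$:

import torch
from torch.distributions import MultivariateNormal

# The flow behind the animation: phi_t(x0) = sigma_t * x0 + mu_t with linearly
# interpolated parameters (target mean (2, 1)^T, target std 1/sqrt(2)).
mu1 = torch.tensor([2.0, 1.0])
sigma1 = 2 ** -0.5

def mu(t): return t * mu1
def sigma(t): return (1 - t) * 1.0 + t * sigma1

t = 0.4
x0 = torch.randn(5, 2)
xt = sigma(t) * x0 + mu(t)  # push latent samples through phi_t

# Change of variables: p_t(x_t) = p_0(x_0) / |det grad phi_t(x_0)| = p_0(x_0) / sigma_t^2
p0 = MultivariateNormal(torch.zeros(2), torch.eye(2))
via_change_of_variables = p0.log_prob(x0).exp() / sigma(t) ** 2

# Closed form of the path: p_t(x) = N(x; mu_t, sigma_t^2 I_2)
pt = MultivariateNormal(mu(t), sigma(t) ** 2 * torch.eye(2))
print(torch.allclose(via_change_of_variables, pt.log_prob(xt).exp(), atol=1e-6))  # True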

In this context a flow is also called a continuous normalizing flow (CNF) as the associated distributions $\mathbb{P} \circ \boldsymbol{x}^{-1}_t$ are probability distributions, i.e., $(\mathbb{P} \circ \boldsymbol{x}^{-1}_t)(\mathbb{R}^d) = 1$. (In contrast, there are also more general objects such as signed measures, whose total mass can differ from one.)

A flow in turn can be induced by a time-dependent vector field $\boldsymbol{v}_t: \mathbb{R}^d \to \mathbb{R}^d$. Intuitively, the vector field specifies how to locally simulate a sample at a given time. Formally, we can define the flow to be the solution of a collection of initial value problems with ordinary differential equations

\begin{align*} \frac{\partial \phi_t(\boldsymbol{x}_0)}{\partial t} = \boldsymbol{v}_t(\phi_t(\boldsymbol{x}_0)) \end{align*}

and initial conditions $\phi_0(\boldsymbol{x}_0) = \boldsymbol{x}_0$. Under sufficient regularity conditions on the vector field, there exists exactly one flow satisfying these conditions. For example, according to the global version of the Picard-Lindelöf theorem, a sufficient condition is that $\boldsymbol{v}$ is continuous and satisfies a global Lipschitz condition in its second argument, i.e., there exists a global constant $L > 0$ such that

\begin{align*} \|\boldsymbol{v}_t(\boldsymbol{x}_1) - \boldsymbol{v}_t(\boldsymbol{x}_2)\| \leq L\|\boldsymbol{x}_1 - \boldsymbol{x}_2\| \end{align*}

for all $\boldsymbol{x}_1, \boldsymbol{x}_2 \in \mathbb{R}^d$ and $t \in [0,1]$. In such cases it also makes sense to speak of the probability path induced by some initial distribution and the vector field. The distributions along this probability path need not be absolutely continuous (with respect to the Lebesgue measure). However, if the vector field is sufficiently smooth in its second argument, then the flow is diffeomorphic at every time and hence we get a probability density path.

As we are going to see, a characterization of exactly when a given probability density path is induced by a given vector field is of central importance for developing a tractable training algorithm. Concretely, this characterization is the differential form of the continuity equation, a.k.a. the mass conservation equation:

\begin{align*} \frac{\partial p_t}{\partial t} + \nabla \cdot (\boldsymbol{v}_t p_t) = 0. \end{align*}

That means, at any point in space the rate of change in density is equal to the convergence (negative divergence) of the flux $\boldsymbol{v}_t p_t$ at that point. The flux of a quantity has dimension "amount of quantity flowing through a unit area per unit time." So it is a measure for how much quantity flows through an infinitesimal disk centered at a point and oriented orthogonal to the velocity vector in an infinitesimal time interval, normalized for the disk's area and the time interval length. And the convergence (divergence) of a velocity field is a measure for how much of a uniformly distributed quantity is destroyed (created) at a point per unit time. Given these interpretations, the continuity equation should roughly make sense. In the expandable box below you can find an informal derivation of the continuity equation using the even more intuitive integral form of the continuity equation and Gauss's divergence theorem.
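Before moving on, here is a tiny finite-difference sanity check of the continuity equation in one dimension, a toy example of my own: a standard normal density translated at unit speed, i.e., $p_t(x) = \mathcal{N}(x;\,t, 1)$ with the constant velocity field $v_t(x) = 1$, satisfies it.

import torch

# Finite-difference check of dp/dt + d(v p)/dx = 0 for a pure translation:
# p_t(x) = N(x; t, 1) moved by the constant velocity field v_t(x) = 1.
def p(t, x):
    return torch.exp(-0.5 * (x - t) ** 2) / (2 * torch.pi) ** 0.5

t = torch.tensor(0.3, dtype=torch.float64)
x = torch.tensor(0.7, dtype=torch.float64)
eps = 1e-5

dp_dt = (p(t + eps, x) - p(t - eps, x)) / (2 * eps)      # time derivative of the density
div_flux = (p(t, x + eps) - p(t, x - eps)) / (2 * eps)   # divergence of v_t * p_t (here v_t = 1)

print((dp_dt + div_flux).item())  # ~0, as the continuity equation requires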

Conditional → Marginal

The flow matching idea starts with the observation that for learning a flow it is enough to learn a suitable deep parametric model of a vector field. At inference time one can then simulate a sample via some numerical ODE method according to the vector field, e.g., Euler integration:

\begin{align*} \boldsymbol{x}_{(k+1)/N} = \boldsymbol{x}_{k/N} + \frac{1}{N}\,\boldsymbol{v}_{k/N}(\boldsymbol{x}_{k/N}) \end{align*}

where $N$ is the number of steps. If the numerical ODE solver is differentiable, and they usually are, one could set up a loss function measuring the difference between two distributions using two corresponding sample cohorts and then train the vector field's parameters by a first-order gradient-descent-based method. However, this involves simulation during training and is therefore not nearly as scalable as other generative modeling methods when using a fine simulation resolution, i.e., a large number of steps. Ideally, we would like to train a CNF simulation-free with a very simple and scalable objective, like directly regressing the vector field weighted by the distribution of the input variable. The following thought might seem nonsensical at first. But bear with me. If we already knew a vector field $\boldsymbol{v}_t$ inducing some $p_t$ with $p_0 = \mathcal{N}(\boldsymbol{0}, \boldsymbol{I}_d)$ and $p_1 = q$, we could use it to define an ideal loss function

\begin{align*} \mathcal{L}_{\text{FM}}(\boldsymbol{\theta}) = \mathbb{E}_{t\sim[0,1],\,\boldsymbol{x}_t\sim p_t} \left[ \|\boldsymbol{v}_t(\boldsymbol{x}_t;\,\boldsymbol{\theta}) - \boldsymbol{v}_t(\boldsymbol{x}_t)\|^2_2 \right]. \end{align*}

And now comes the key idea which sparked the development of the whole method. While we don't know what a good $\boldsymbol{v}_t$ looks like, we have an idea of what $\boldsymbol{v}_t$ and $p_t$ are supposed to do conditioned on knowing where $\boldsymbol{x}_t$ ends up, i.e., conditioned on $\boldsymbol{x}_1$.

Key Idea of Simulation-Free CNF Training

Define $p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)$ reasonably and find a $\boldsymbol{v}_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)$ which induces it. Then marginalize the conditional vector field such that it induces the marginalized density path.

For $p_t$ we want to define a $p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)$ with $p_0(\boldsymbol{x}\,|\,\boldsymbol{x}_1) = \mathcal{N}(\boldsymbol{x};\,\boldsymbol{0}, \boldsymbol{I}_d)$ and $p_1(\boldsymbol{x}\,|\,\boldsymbol{x}_1) = \delta(\boldsymbol{x} - \boldsymbol{x}_1)$. But let us treat the more general case of a normal distribution $p_1(\boldsymbol{x}\,|\,\boldsymbol{x}_1) = \mathcal{N}(\boldsymbol{x};\,\boldsymbol{x}_1, \sigma_{\text{min}}^2\boldsymbol{I}_d)$ for some (very small) standard deviation $\sigma_{\text{min}} \geq 0$ with $\mathcal{N}(\boldsymbol{x};\,\boldsymbol{x}_1, \boldsymbol{0}) := \delta(\boldsymbol{x} - \boldsymbol{x}_1)$.

Above we already saw one possible way to do that: linearly interpolating in parameter space. Let us slightly generalize it by using mean and standard deviation paths which may depend on $\boldsymbol{x}_1$:

\begin{align*} p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1) := \mathcal{N}(\boldsymbol{x};\,\boldsymbol{\mu}_t(\boldsymbol{x}_1), \sigma_t(\boldsymbol{x}_1)^2\boldsymbol{I}_d). \end{align*}

By marginalizing over $\boldsymbol{x}_1$ we then get $p_t(\boldsymbol{x}) = \int p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)\,p_1(\boldsymbol{x}_1)\,d\boldsymbol{x}_1$, which we can interpret as a continuous Gaussian mixture model. Actually, if the dataset is an empirical distribution with a finite number of samples, it is exactly a classical Gaussian mixture model:

\begin{align*} p_t(\boldsymbol{x}) = \int p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1) \left(\frac{1}{N} \sum_{i=1}^N \delta(\boldsymbol{x}_1 - \boldsymbol{x}^{(i)}_1) \right)\,d\boldsymbol{x}_1 = \frac{1}{N} \sum_{i=1}^N p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1^{(i)}). \end{align*}

Below you can see an animation of such a path where the data samples are drawn uniformly from an annulus with inner radius 2 and outer radius 3 and $\sigma_{\text{min}} = 0.2$. I encourage you to play around with the slider to see how the path changes when changing the number of data points. Especially for 1000 data points the dataset shape becomes obvious!

(Interactive 3D animation of the marginal density path for the annulus dataset over time t; drag to rotate, scroll to zoom, slider for the number of dataset points.)
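For reference, here is a minimal sketch of this marginal density as a finite Gaussian mixture, using the linear mean/std interpolation from the animation above (the same schedule we adopt later); the helper marginal_log_density and the ring-shaped stand-in dataset are just for illustration:

import torch

def marginal_log_density(x, t, data, sigma_min=0.2):
    # p_t(x) = (1/N) sum_i N(x; mu_t(x1_i), sigma_t^2 I) with the linear schedules
    # mu_t(x1) = t * x1 and sigma_t = (1 - t) + t * sigma_min.
    mu = t * data                              # (N, d) component means
    sigma = (1 - t) + t * sigma_min            # shared scalar standard deviation
    d = data.shape[1]
    log_norm = -0.5 * d * torch.log(torch.tensor(2 * torch.pi * sigma**2))
    log_components = log_norm - 0.5 * ((x - mu) ** 2).sum(dim=-1) / sigma**2  # (N,)
    return torch.logsumexp(log_components, dim=0) - torch.log(torch.tensor(float(data.shape[0])))

# Stand-in "dataset": points on a ring of radius 2.5 (roughly the annulus above).
angles = torch.rand(1000) * 2 * torch.pi
data = 2.5 * torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)
print(marginal_log_density(torch.tensor([2.5, 0.0]), t=0.9, data=data))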

But what is a $\boldsymbol{v}_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)$ (or equivalently $\phi_t(\boldsymbol{x}_0\,|\,\boldsymbol{x}_1)$) which induces $p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)$? The continuity equation tells us that if a solution exists, there are infinitely many: we can take any divergence-free vector field, divide it by $p_t$, and add it to the one we found. Since the divergence operator is linear, the new vector field still satisfies the continuity equation. One concrete option for our Gaussian path above is $\phi_t(\boldsymbol{x}_0\,|\,\boldsymbol{x}_1) := \sigma_t(\boldsymbol{x}_1)\boldsymbol{x}_0 + \boldsymbol{\mu}_t(\boldsymbol{x}_1)$. And since

\begin{align*} \frac{\partial}{\partial t}\phi_t(\boldsymbol{x}_0\,|\,\boldsymbol{x}_1) = \dot{\sigma}_t(\boldsymbol{x}_1)\boldsymbol{x}_0 + \dot{\boldsymbol{\mu}}_t(\boldsymbol{x}_1) = \frac{\dot{\sigma}_t(\boldsymbol{x}_1)}{\sigma_t(\boldsymbol{x}_1)}(\phi_t(\boldsymbol{x}_0\,|\,\boldsymbol{x}_1) - \boldsymbol{\mu}_t(\boldsymbol{x}_1)) + \dot{\boldsymbol{\mu}}_t(\boldsymbol{x}_1), \end{align*}

we obtain that $\boldsymbol{v}_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1) = \frac{\dot{\sigma}_t(\boldsymbol{x}_1)}{\sigma_t(\boldsymbol{x}_1)} (\boldsymbol{x} - \boldsymbol{\mu}_t(\boldsymbol{x}_1)) + \dot{\boldsymbol{\mu}}_t(\boldsymbol{x}_1)$ induces $p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)$.
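As a small numerical check of this derivation, the following sketch verifies $\frac{\partial}{\partial t}\phi_t(\boldsymbol{x}_0\,|\,\boldsymbol{x}_1) = \boldsymbol{v}_t(\phi_t(\boldsymbol{x}_0\,|\,\boldsymbol{x}_1)\,|\,\boldsymbol{x}_1)$ by finite differences, using for concreteness the linear schedules $\boldsymbol{\mu}_t(\boldsymbol{x}_1) = t\boldsymbol{x}_1$ and $\sigma_t(\boldsymbol{x}_1) = (1-t) + t\,\sigma_{\text{min}}$ (formally introduced further below):

import torch

# Check that the conditional flow phi_t(x0 | x1) = sigma_t(x1) x0 + mu_t(x1) is
# generated by v_t(x | x1) = sigma_dot/sigma (x - mu_t) + mu_dot.
sigma_min = 0.2
x0 = torch.randn(2, dtype=torch.float64)
x1 = torch.tensor([2.0, 1.0], dtype=torch.float64)

def sigma(t): return (1 - t) + t * sigma_min
def mu(t):    return t * x1
def phi(t):   return sigma(t) * x0 + mu(t)
def v(t, x):  return (sigma_min - 1) / sigma(t) * (x - mu(t)) + x1  # sigma_dot = sigma_min - 1, mu_dot = x1

t, eps = 0.4, 1e-6
dphi_dt = (phi(t + eps) - phi(t - eps)) / (2 * eps)   # finite-difference time derivative
print(torch.allclose(dphi_dt, v(t, phi(t))))          # True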

So far, so good. But how do we "marginalize" the conditional vector field such that it induces the desired marginal density path? It seems plausible that an ansatz like

\begin{align*} \boldsymbol{v}_t(\boldsymbol{x}) := \int \boldsymbol{v}_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)\,f(\boldsymbol{x},\boldsymbol{x}_1)\,d\boldsymbol{x}_1 \end{align*}

for some function $f$ might work. So let us try to find an $f$ which makes the continuity equation work. Recall that

\begin{align*} \frac{\partial p_t}{\partial t} + \nabla \cdot (\boldsymbol{v}_t p_t) = 0 \end{align*}

is necessary and sufficient for $\boldsymbol{v}_t$ to induce $p_t$. If we substitute the definition of $p_t$ and the ansatz for $\boldsymbol{v}_t$ into it, and differentiate under the integral sign with Leibniz's rule, we want to have for every $\boldsymbol{x}$ and $t$ that

\begin{align*} \int \left(\frac{\partial}{\partial t}p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)p_1(\boldsymbol{x}_1) + \nabla \cdot (p_t(\boldsymbol{x})\, \boldsymbol{v}_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)f(\boldsymbol{x},\boldsymbol{x}_1))\right)\,d\boldsymbol{x}_1 \overset{!}{=} 0. \end{align*}

Now the idea in the design of $f$ is to leverage the fact that $\boldsymbol{v}_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)$ and $p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)$ fulfill the continuity equation. If $f$ gets rid of $p_t(\boldsymbol{x})$ and contributes the factor $p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)p_1(\boldsymbol{x}_1)$, the integrand can be factored such that one factor is our conditional continuity equation! So when setting

\begin{align*} f(\boldsymbol{x},\boldsymbol{x}_1) = \frac{p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)p_1(\boldsymbol{x}_1)}{p_t(\boldsymbol{x})}, \end{align*}

the left side is nothing but

\begin{align*} \int \underbrace{\left(\frac{\partial}{\partial t}p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1) + \nabla \cdot (p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)\boldsymbol{v}_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1))\right)}_{=0}\,p_1(\boldsymbol{x}_1)\,d\boldsymbol{x}_1 = 0, \end{align*}

which is exactly what we want. One caveat here is that $p_t(\boldsymbol{x})$ must have full support. But requiring that the conditional densities have full support is sufficient for this to be true. Let us summarize the result:

How to aggregate conditional densities and vector fields

If a conditional density path $p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1) > 0$ is induced by a vector field $\boldsymbol{v}_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)$, then

\begin{align*} \boldsymbol{v}_t(\boldsymbol{x}) = \int \boldsymbol{v}_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1) \frac{p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)p_1(\boldsymbol{x}_1)}{p_t(\boldsymbol{x})}\,d\boldsymbol{x}_1 \end{align*}

induces

\begin{align*} p_t(\boldsymbol{x}) = \int p_t(\boldsymbol{x}\,|\,\boldsymbol{x}_1)p_1(\boldsymbol{x}_1)\,d\boldsymbol{x}_1. \end{align*}

In hindsight, we can give a nice interpretation of the way the conditional vector field is to be marginalized: by Bayes' theorem, the weighting factor is exactly the conditional density of $\boldsymbol{x}_1$ given the random variable $\boldsymbol{x}_t$. So we may also write

\begin{align*} \boldsymbol{v}_t(\boldsymbol{x}_t) = \int \boldsymbol{v}_t(\boldsymbol{x}_t\,|\,\boldsymbol{x}_1)\, p(\boldsymbol{x}_1\,|\,\boldsymbol{x}_t)\,d\boldsymbol{x}_1. \end{align*}

This makes a lot of sense! Intuitively, when simulating a particle and deciding where to move next, we not only have to take into account knowledge of the final distribution, but also the current observation, since the time that has already passed and the path taken so far have rendered some final destinations infeasible (conditional density of 0).
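For a finite dataset this marginalization can be written down explicitly: the marginal vector field is a posterior-weighted (softmax) average of the conditional vector fields. Here is an illustrative sketch, again with the linear schedules from above; the helper marginal_velocity and the stand-in dataset are just for illustration:

import torch

def marginal_velocity(x, t, data, sigma_min=0.2):
    # v_t(x) = sum_i v_t(x | x1_i) * p(x1_i | x_t = x) for a finite dataset with uniform
    # p1; posterior weights are proportional to p_t(x | x1_i) since all components
    # share the same sigma_t and the same prior weight 1/N.
    mu = t * data                                            # (N, d) conditional means
    sigma = (1 - t) + t * sigma_min
    v_cond = (sigma_min - 1) / sigma * (x - mu) + data       # (N, d) conditional velocities
    log_w = -0.5 * ((x - mu) ** 2).sum(dim=-1) / sigma**2    # unnormalized log posterior
    w = torch.softmax(log_w, dim=0)                          # p(x1_i | x_t = x)
    return (w[:, None] * v_cond).sum(dim=0)

angles = torch.rand(1000) * 2 * torch.pi
data = 2.5 * torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)
print(marginal_velocity(torch.tensor([0.5, 0.0]), t=0.3, data=data))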

Conditional Objective

Let us now return to the ideal flow matching objective

\begin{align*} \mathcal{L}_{\text{FM}}(\boldsymbol{\theta}) = \mathbb{E}_{t\sim[0,1],\,\boldsymbol{x}_t\sim p_t} \left[ \|\boldsymbol{v}_t(\boldsymbol{x}_t;\,\boldsymbol{\theta}) - \boldsymbol{v}_t(\boldsymbol{x}_t)\|^2_2 \right]. \end{align*}

It is still intractable as we don't know how to sample from $p_t$, let alone how to compute $\boldsymbol{v}_t(\boldsymbol{x}_t)$. But instead we can sample conditionally and then try to match the conditional vector field:

\begin{align*} \mathcal{L}_{\text{CFM}}(\boldsymbol{\theta}) = \mathbb{E}_{t\sim[0,1],\,\boldsymbol{x}_1\sim p_1,\,\boldsymbol{x}_t\sim p_t(\boldsymbol{x}_t\,|\,\boldsymbol{x}_1)} \left[ \|\boldsymbol{v}_t(\boldsymbol{x}_t;\,\boldsymbol{\theta}) - \boldsymbol{v}_t(\boldsymbol{x}_t\,|\,\boldsymbol{x}_1)\|^2_2 \right]. \end{align*}

This is finally a tractable and scalable objective for training a CNF. But how does it compare to the ideal flow matching objective? Does it give us the same set of parameters when minimized? First note that both terms inside the (linear) expectation operators are of the form $(\boldsymbol{a}-\boldsymbol{b})^\top (\boldsymbol{a}-\boldsymbol{b}) = \boldsymbol{a}^\top \boldsymbol{a} - 2\boldsymbol{a}^\top \boldsymbol{b} + \boldsymbol{b}^\top \boldsymbol{b}$. Let us compare them one by one, ignoring the expectation over $t$ as it is identical for both objectives.

The $\boldsymbol{a}^\top \boldsymbol{a}$ terms are identical since

\begin{align*} \int \|\boldsymbol{v}_t(\boldsymbol{x}_t;\,\boldsymbol{\theta})\|^2_2\,p_t(\boldsymbol{x}_t\,|\,\boldsymbol{x}_1)p_1(\boldsymbol{x}_1)\,d\boldsymbol{x}_t\,d\boldsymbol{x}_1 = \int \|\boldsymbol{v}_t(\boldsymbol{x}_t;\,\boldsymbol{\theta})\|^2_2 \underbrace{\int p(\boldsymbol{x}_t, \boldsymbol{x}_1)\,d\boldsymbol{x}_1}_{=\,p_t(\boldsymbol{x}_t)}\,d\boldsymbol{x}_t. \end{align*}

Note that we assumed that changing the integration order is fine. According to Fubini's theorem, Lebesgue integrability of the integrand is sufficient for this. It is fair to say that for virtually all practical purposes $p_t(\boldsymbol{x}_t\,|\,\boldsymbol{x}_1)p_1(\boldsymbol{x}_1)$ vanishes sufficiently quickly for the integral to be finite, e.g., when $p_t(\boldsymbol{x}_t\,|\,\boldsymbol{x}_1)$ is Gaussian. Nevertheless, one should keep this in mind when modeling crazy data distributions.

The mixed terms are also identical, as a change of integration order yields

\begin{align*} \int \boldsymbol{v}_t(\boldsymbol{x}_t;\,\boldsymbol{\theta})^\top\boldsymbol{v}_t(\boldsymbol{x}_t\,|\,\boldsymbol{x}_1)\,p_t(\boldsymbol{x}_t\,|\,\boldsymbol{x}_1)p_1(\boldsymbol{x}_1)\,d\boldsymbol{x}_t\,d\boldsymbol{x}_1 = \int \boldsymbol{v}_t(\boldsymbol{x}_t;\,\boldsymbol{\theta})^\top \underbrace{\left(\int \boldsymbol{v}_t(\boldsymbol{x}_t\,|\,\boldsymbol{x}_1) \frac{p(\boldsymbol{x}_t, \boldsymbol{x}_1)}{p_t(\boldsymbol{x}_t)}\,d\boldsymbol{x}_1\right)}_{=\,\boldsymbol{v}_t(\boldsymbol{x}_t)} p_t(\boldsymbol{x}_t)\,d\boldsymbol{x}_t. \end{align*}

And what about the $\boldsymbol{b}^\top \boldsymbol{b}$ terms? Well, the clue is that they don't depend on $\boldsymbol{\theta}$. So the objectives are identical up to an additive constant that doesn't depend on $\boldsymbol{\theta}$, which means they have the same set of minimizers.

Equivalence of the conditional and naive flow matching objectives

The conditional flow matching objective and the naive one are identical up to an additive constant, or equivalently, they have the same gradient.

So first-order methods like Adam behave the same if we choose $\mathcal{L}_{\text{CFM}}$ as a surrogate for $\mathcal{L}_{\text{FM}}$.

Concrete Conditional Path

Now that we have a tractable and scalable objective for training a CNF, we can start to think about concrete conditional probability paths.

Arguably, the most natural choice is the straightforward mean/std linear interpolation, i.e.,

\begin{align*} \boldsymbol{\mu}_t(\boldsymbol{x}_1) = t\boldsymbol{x}_1 \quad \text{and} \quad \sigma_t(\boldsymbol{x}_1) = (1-t)\cdot 1 + t\,\sigma_{\text{min}}. \end{align*}

Their derivatives are $\dot{\boldsymbol{\mu}}_t(\boldsymbol{x}_1) = \boldsymbol{x}_1$ and $\dot{\sigma}_t(\boldsymbol{x}_1) = \sigma_{\text{min}} - 1$. So according to our formula derived above,

\begin{align*} \boldsymbol{v}_t(\boldsymbol{x}_t\,|\,\boldsymbol{x}_1) = \frac{\dot{\sigma}_t(\boldsymbol{x}_1)}{\sigma_t(\boldsymbol{x}_1)}(\boldsymbol{x}_t - \boldsymbol{\mu}_t(\boldsymbol{x}_1)) + \dot{\boldsymbol{\mu}}_t(\boldsymbol{x}_1) = \frac{\sigma_{\text{min}} - 1}{1 - (1 - \sigma_{\text{min}})t}(\boldsymbol{x}_t - t\boldsymbol{x}_1) + \boldsymbol{x}_1. \end{align*}

In the conditional flow matching objective, $\boldsymbol{x}_t$ is sampled according to $\boldsymbol{x}_t \sim p_t(\boldsymbol{x}_t\,|\,\boldsymbol{x}_1)$. Equivalently, we can sample $\boldsymbol{x}_0 \sim \mathcal{N}(\boldsymbol{0}, \boldsymbol{I}_d)$ and then take $\boldsymbol{x}_t := \phi_t(\boldsymbol{x}_0\,|\,\boldsymbol{x}_1) = \sigma_t(\boldsymbol{x}_1)\boldsymbol{x}_0 + \boldsymbol{\mu}_t(\boldsymbol{x}_1) = (1-(1-\sigma_{\text{min}})t)\boldsymbol{x}_0 + t\boldsymbol{x}_1$. Plugging this into the above formula for the conditional vector field, terms cancel out and we get

\begin{align*} \boldsymbol{v}_t(\boldsymbol{x}_t\,|\,\boldsymbol{x}_1) = (\sigma_\text{min} - 1)\boldsymbol{x}_0 + \boldsymbol{x}_1. \end{align*}

In the limit $\sigma_{\text{min}} \to 0$ this simplifies even further to $\lim_{\sigma_{\text{min}} \to 0} \boldsymbol{v}_t(\boldsymbol{x}_t\,|\,\boldsymbol{x}_1) = \boldsymbol{x}_1 - \boldsymbol{x}_0$. This yields the (for its simplicity) most popular and much celebrated flow matching objective of

\begin{align*} \mathcal{L}_{\text{CFM}}(\boldsymbol{\theta}) = \mathbb{E}_{t\sim[0,1],\,\boldsymbol{x}_1\sim p_1,\,\boldsymbol{x}_0 \sim \mathcal{N}(\boldsymbol{0}, \boldsymbol{I}_d)} \left[ \|\boldsymbol{x}_1 - \boldsymbol{x}_0 - \boldsymbol{v}_t(\boldsymbol{x}_t;\,\boldsymbol{\theta})\|^2_2 \right], \end{align*}

where $\boldsymbol{x}_t = (1 - t)\boldsymbol{x}_0 + t\boldsymbol{x}_1$. Intuitively, this trains the vector field to match the expected direction from "source to target" given an intermediate "simulation step".

Implementation in Practice

Let us now train a small flow matching model on the above mentioned annulus toy dataset!

import torch
import torch.nn as nn


class FlowNet(nn.Module):
    """Small MLP modelling v_t(x; theta); takes time t and position x as input."""

    def __init__(self, d, h=64):
        super().__init__()
        self.d = d
        self.net = nn.Sequential(
            nn.Linear(1 + d, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, d),
        )

    def forward(self, t, x):
        return self.net(torch.cat([t, x], dim=-1))


def train(model, device, train_data_loader, num_epochs=1):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    for epoch in range(num_epochs):
        for x1 in train_data_loader:
            t = torch.rand(x1.shape[0], 1, device=device)           # t ~ U[0, 1]
            x0 = torch.randn(x1.shape[0], model.d, device=device)   # x0 ~ N(0, I)
            x1 = x1.to(device)
            xt = (1 - t) * x0 + t * x1   # x_t on the conditional path (sigma_min = 0)
            target = x1 - x0             # conditional vector field v_t(x_t | x_1)
            loss = ((model(t, xt) - target) ** 2).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


class AnnulusDataset(torch.utils.data.Dataset):
    """2D points with uniform angle and radius uniform in [2, 3]."""

    def __init__(self, num_samples):
        self.num_samples = num_samples
        angles = torch.rand(num_samples) * 2 * torch.pi
        radii = torch.rand(num_samples) * 1 + 2
        self.data = torch.stack(
            [radii * torch.cos(angles), radii * torch.sin(angles)], dim=-1
        )

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx]


device = "cuda" if torch.cuda.is_available() else "cpu"
model = FlowNet(d=2).to(device)
batch_size = 256
data = AnnulusDataset(num_samples=batch_size * 10_000)
train_data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)
train(model, device, train_data_loader, num_epochs=1)
torch.save(model.state_dict(), "flow_matching_on_annulus.pth")

After running the above code we can check that the model has learned the annulus dataset by sampling from it and plotting the samples. Below we use Euler integration to "walk along the vector field" (technically, solving for the flow) because it is very simple, but any numerical ODE solver will do.

import matplotlib.pyplot as plt
import torch

# (FlowNet from the training script above is assumed to be in scope.)
model = FlowNet(d=2)
model.load_state_dict(torch.load("flow_matching_on_annulus.pth"))
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


@torch.no_grad()
def sample(model, device, num_samples, num_steps=50):
    # Euler integration of dx/dt = v_t(x) from t = 0 to t = 1.
    step_delta = 1 / num_steps
    eval_time_points = torch.linspace(0, 1, num_steps + 1)[:-1]  # left endpoints of the steps
    x0 = torch.randn(num_samples, model.d, device=device)
    xt = x0.clone()
    for t in eval_time_points:
        xt += (
            model.forward(t * torch.ones(num_samples, 1, device=device), xt)
            * step_delta
        )
    return xt


model.eval()
samples = sample(model, device, num_samples=10_000, num_steps=50).cpu().numpy()

plt.figure(figsize=(6, 6))
plt.scatter(samples[:, 0], samples[:, 1], s=1)
plt.savefig("flow_matching_toy.png")
(Figure: samples generated by the trained flow matching model on the annulus dataset.)

It works, nice!