A Flowing Intro to Flow Matching
Foreword
Flow matching is a generative modeling technique developed in 2022 by Yaron Lipman et al. A paper is a highly processed artifact which often does not tell you the full story of how the authors discovered their results; rather, results are presented in a formal tone in the main text, and the appendix contains proofs which verify the results but don't develop them. In this blog post I try to smoothly develop the main ideas and results, ideally so that you feel like you could have come up with them yourself. This is how I feel when I read Feynman's lectures on physics, and this introduction tries to approximate that. You can view this blog post as a middle ground between the paper and the very extensive "official" 83-page Flow Matching Guide and Code, also by Yaron Lipman and others. Plus, here and there I give more context than the paper (for example on the continuity equation) and provide interactive 3D visualizations. Have fun!
Generative Modeling
The abstract mathematical goal of generative modeling is to draw samples from an unknown data distribution $q$, unknown in the sense that we only have a finite number of samples, but no explicit representation like a formula for its density or an algorithm to obtain a ground truth sample from it.
At a high level, the basic idea of generative modeling is to:
- First take a sample $x_0$ in latent space from some simple distribution $p_0$ which is easy to sample from, e.g., a standard normal distribution via the Box-Muller transform.
- Then apply a (potentially stochastic) transformation $\phi$ derived from the data to obtain $x_1 = \phi(x_0)$.
Ideally, $\phi$ is chosen such that the distribution $p_1$ of the random variable $X_1 = \phi(X_0)$ is as close as possible to the data distribution $q$, in symbols, $p_1 \approx q$.
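To make the two steps concrete, here is a minimal sketch in plain Python. The Box-Muller sampler is standard, but the transformation `phi` is a hand-picked affine map that I made up purely for illustration; nothing is learned from data here.

```python
import math
import random


def box_muller(rng):
    """Return one standard normal sample built from two uniform samples."""
    u1 = 1.0 - rng.random()  # lies in (0, 1], which avoids log(0)
    u2 = rng.random()
    return math.sqrt(-2.0 * math.log(u1)) * math.cos(2.0 * math.pi * u2)


def phi(x0):
    """Hypothetical transformation: scale by 0.5, then shift by 3."""
    return 0.5 * x0 + 3.0


rng = random.Random(0)
latent = [box_muller(rng) for _ in range(20_000)]  # step 1: latent samples
samples = [phi(x0) for x0 in latent]               # step 2: transformed samples

mean = sum(samples) / len(samples)
std = math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))
print(f"mean={mean:.2f}, std={std:.2f}")  # close to 3 and 0.5
```

The transformed samples follow a normal distribution with mean 3 and standard deviation 0.5. The whole difficulty of generative modeling is, of course, finding a transformation that produces the data distribution instead.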
Flows and Density Paths
You can imagine that it is difficult to learn such a transformation $\phi$ directly via a KL-divergence, Wasserstein distance or something similar. This motivates doing it with intermediate steps. But modeling the transformation with a finite number of composed learned transformations, such as in Normalizing Flows, comes with several fundamental drawbacks, most notably constraints on expressivity and issues with computational scalability in practice. So why not model the transformation as a continuous path? Why not learn a time-dependent transformation $\phi_t$ with $\phi_0 = \mathrm{id}$ and $\phi_1 = \phi$? Well, learning parameters for every $t \in [0, 1]$ is of course infeasible. And taking the time $t$ as an input to a parametric model doesn't make the problem any easier.
But there is a remedy. To this end, let us think of $\phi_t$ as moving a sample $x_0$ through the data space via the parametric curve $t \mapsto \phi_t(x_0)$. If $\phi_t$ is smooth, for example diffeomorphic for every $t$, then it is not far-fetched to think of $\phi_t$ as a flow which moves particles (samples) from their initial position to another place.
Definition: Flow
For our purposes, a flow is a time-dependent map $\phi \colon [0, 1] \times \mathbb{R}^d \to \mathbb{R}^d$, $(t, x) \mapsto \phi_t(x)$.
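Together with an initial distribution $p_0$, a flow induces a continuous path of distributions, namely the distributions of the random variables $X_t = \phi_t(X_0)$ with $X_0 \sim p_0$. If every $\phi_t$ is diffeomorphic and $p_0$ has a density, every distribution on the path has a density and the corresponding probability density path $p_t$ can be computed via a change of variables as:
$$p_t(x) = p_0\big(\phi_t^{-1}(x)\big) \left| \det \frac{\partial \phi_t^{-1}}{\partial x}(x) \right|$$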
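As a quick sanity check of the change of variables formula, here is a small 1D sketch. I assume an affine flow $\phi_t(x) = \sigma_t x + \mu_t$ with a standard normal $p_0$; the target values `MU_1` and `SIGMA_1` are arbitrary choices of mine. For this flow the pushed-forward density should be the density of a normal distribution with mean $\mu_t$ and standard deviation $\sigma_t$.

```python
import math

MU_1, SIGMA_1 = 2.0, 0.5  # assumed target mean and standard deviation


def normal_pdf(x, mean=0.0, std=1.0):
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def mu(t):
    return t * MU_1


def sigma(t):
    return (1.0 - t) + t * SIGMA_1


def flow_inverse(t, y):
    """Inverse of the affine flow phi_t(x) = sigma(t) * x + mu(t)."""
    return (y - mu(t)) / sigma(t)


def density_via_change_of_variables(t, y):
    # In 1D the Jacobian determinant of the inverse flow is 1 / sigma(t).
    return normal_pdf(flow_inverse(t, y)) * abs(1.0 / sigma(t))


t, y = 0.7, 1.0
lhs = density_via_change_of_variables(t, y)
rhs = normal_pdf(y, mean=mu(t), std=sigma(t))
print(lhs, rhs)  # the two values agree up to floating point error
```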
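Below you can see an animation of a probability density path which transforms a standard bivariate normal distribution $p_0 = \mathcal{N}(0, I)$ into $p_1 = \mathcal{N}(\mu_1, \sigma_1^2 I)$ via linear interpolation of the means and standard deviations, i.e., $\mu_t = t \mu_1$ and $\sigma_t = (1 - t) + t \sigma_1$ with $p_t = \mathcal{N}(\mu_t, \sigma_t^2 I)$.
In this context a flow is also called a continuous normalizing flow (CNF), as the associated distributions are probability distributions, i.e., $\int p_t(x)\, dx = 1$ for every $t$. (There are also more general measures, such as unnormalized or signed ones, whose total mass differs from one.)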
A flow in turn can be induced by a time-dependent vector field $u \colon [0, 1] \times \mathbb{R}^d \to \mathbb{R}^d$, $(t, x) \mapsto u_t(x)$. Intuitively, the vector field specifies how to locally simulate a sample at a given time. Formally, we can define the flow to be the solution of a collection of initial value problems with ordinary differential equations
$$\frac{d}{dt} \phi_t(x) = u_t\big(\phi_t(x)\big)$$
and initial conditions $\phi_0(x) = x$. Under sufficient regularity conditions on the vector field, there exists exactly one flow satisfying these conditions. For example, according to the global version of the Picard-Lindelöf theorem a sufficient condition is that $u$ is continuous and satisfies a global Lipschitz condition in its second argument, i.e., there exists a global constant $L > 0$ such that
$$\| u_t(x) - u_t(y) \| \le L\, \| x - y \|$$
for all $t \in [0, 1]$ and $x, y \in \mathbb{R}^d$. In such cases it also makes sense to speak of the probability path induced by some initial distribution and the vector field. The distributions of this probability path need not be absolutely continuous (with respect to the Lebesgue measure). However, if the vector field is sufficiently smooth in its second argument, then the flow is diffeomorphic at every time and hence we get a probability density path.
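To see a vector field inducing a flow in the simplest possible setting, here is a sketch that solves the initial value problem numerically with explicit Euler steps. The vector field $u_t(x) = -x$ is my own toy choice (it is globally Lipschitz with constant 1), and its exact flow is $\phi_t(x) = x e^{-t}$, so we can compare against it.

```python
import math


def vector_field(t, x):
    """Toy vector field u_t(x) = -x, chosen only for illustration."""
    return -x


def euler_flow(x0, num_steps):
    """Approximate phi_1(x0) with explicit Euler steps of size 1 / num_steps."""
    x, dt = x0, 1.0 / num_steps
    for k in range(num_steps):
        t = k * dt  # left endpoint of the current step
        x = x + dt * vector_field(t, x)
    return x


x0 = 1.5
exact = x0 * math.exp(-1.0)  # exact flow phi_1(x0) = x0 * exp(-1)
errors = {n: abs(euler_flow(x0, n) - exact) for n in (10, 100, 1000)}
for n, err in errors.items():
    print(n, err)  # the error shrinks roughly linearly in 1 / n
```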
As we are going to see, a particular characterization of when exactly a given probability density path $p_t$ is induced by a given vector field $u_t$ is of central importance for developing a tractable training algorithm. Concretely, this characterization is given by the differential form of the continuity equation, a.k.a. the mass conservation equation:
$$\frac{\partial}{\partial t} p_t(x) = -\nabla \cdot \big( p_t(x)\, u_t(x) \big)$$
That means, at any point in space the rate of change in density is equal to the convergence (negative divergence) of the flux $p_t u_t$ at that point. The flux of a quantity has dimension "amount of quantity flowing through a unit area per unit time." So it is a measure for how much quantity flows through an infinitesimal disk centered at a point and oriented orthogonal to the velocity vector in an infinitesimal time interval, normalized for the disk's area and the time interval length. And the convergence (divergence) of a velocity field is a measure for how much of a uniformly distributed quantity is destroyed (created) at a point per unit time. Given these interpretations, the continuity equation should roughly make sense. It can be derived informally from the even more intuitive integral form of the continuity equation together with Gauss's divergence theorem.
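If you prefer to see the continuity equation hold numerically, here is a small 1D finite-difference sketch. I again assume the toy vector field $u_t(x) = -x$ with a standard normal $p_0$; its flow $\phi_t(x) = x e^{-t}$ shrinks the standard deviation, so $p_t$ is a centered normal density with standard deviation $e^{-t}$. The step size `h` is an arbitrary choice.

```python
import math


def density(t, x):
    """p_t for the toy flow: centered normal with standard deviation exp(-t)."""
    std = math.exp(-t)
    return math.exp(-0.5 * (x / std) ** 2) / (std * math.sqrt(2.0 * math.pi))


def vector_field(t, x):
    return -x


def flux(t, x):
    return density(t, x) * vector_field(t, x)


def continuity_residual(t, x, h=1e-4):
    """Central-difference estimate of d/dt p_t(x) + d/dx (p_t(x) u_t(x))."""
    dp_dt = (density(t + h, x) - density(t - h, x)) / (2.0 * h)
    div_flux = (flux(t, x + h) - flux(t, x - h)) / (2.0 * h)
    return dp_dt + div_flux


residuals = [abs(continuity_residual(0.3, x)) for x in (-1.5, -0.2, 0.0, 0.7, 2.0)]
print(max(residuals))  # zero up to finite-difference error
```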
Conditional --> Marginal
The flow matching idea starts with the observation that for learning a flow it is enough to learn a suitable deep parametric model $v_t^\theta$ of a vector field. At inference time one can then simulate a sample via some numerical ODE method according to the vector field, e.g., Euler integration:
$$x_{t + \frac{1}{N}} = x_t + \frac{1}{N}\, v_t^\theta(x_t), \qquad t = 0, \tfrac{1}{N}, \dots, \tfrac{N-1}{N},$$
where $N$ is the number of steps. If the numerical ODE solver is differentiable, and they usually are, one could set up a loss function measuring the difference between two distributions using two corresponding sample cohorts and then train the vector field's parameters $\theta$ by a first-order gradient-descent-based method. However, this involves simulation during training and is therefore not nearly as scalable as other generative modeling methods when using a fine simulation resolution, i.e., a large number of steps. Ideally, we would like to train a CNF simulation-free with a very simple and scalable objective, like directly regressing the vector field weighted by the distribution of the input variable. The following thought might seem nonsensical at first. But bear with me. If we already knew a vector field $u_t$ inducing some $p_t$ with $p_0 = \mathcal{N}(0, I)$ and $p_1 \approx q$, we could use it to define an ideal loss function
$$\mathcal{L}_{\mathrm{FM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1],\, x \sim p_t} \big\| v_t^\theta(x) - u_t(x) \big\|^2.$$
And now comes the key idea which sparked the development of the whole method. While we don't know what a good $u_t$ looks like, we have an idea of what $p_t$ and $u_t$ are supposed to do conditioned on knowing where $x_0$ ends up, i.e., conditioned on $x_1$.
Key Idea of Simulation Free CNF Training
Define $p_t(x \mid x_1)$ reasonably and find a $u_t(x \mid x_1)$ which induces it. Then marginalize the conditional vector field such that it induces the marginalized density path.
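For $x_1 \sim q$ we want to define a $p_t(\cdot \mid x_1)$ with $p_0(\cdot \mid x_1) = \mathcal{N}(0, I)$ and $p_1(\cdot \mid x_1) = \delta_{x_1}$. But let us treat the more general case of a normal distribution $p_1(\cdot \mid x_1) = \mathcal{N}(x_1, \sigma_{\min}^2 I)$ for some (very small) standard deviation $\sigma_{\min}$ with $\sigma_{\min} > 0$.
Above we already saw one possible way to do that: linearly interpolating in parameter space. Let us slightly generalize it by using mean and standard deviation paths which might depend on $x_1$:
$$p_t(x \mid x_1) = \mathcal{N}\big(x;\, \mu_t(x_1),\, \sigma_t(x_1)^2 I\big), \qquad \mu_0(x_1) = 0,\ \sigma_0(x_1) = 1,\ \mu_1(x_1) = x_1,\ \sigma_1(x_1) = \sigma_{\min}.$$
By marginalizing over $x_1$ we then get $p_t(x) = \int p_t(x \mid x_1)\, q(x_1)\, dx_1$, which we can interpret as a continuous Gaussian mixture model. Actually, if the dataset is an empirical distribution with a finite number of samples $x_1^{(1)}, \dots, x_1^{(n)}$, it is exactly a classical Gaussian mixture model:
$$p_t(x) = \frac{1}{n} \sum_{i=1}^{n} \mathcal{N}\big(x;\, \mu_t(x_1^{(i)}),\, \sigma_t(x_1^{(i)})^2 I\big)$$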
Below you can see an animation of such a path where the data samples are drawn uniformly from an annulus with inner radius 2 and outer radius 3 and a small fixed $\sigma_{\min}$. I encourage you to play around with the slider to see how the path changes when changing the number of data points. Especially for 1000 data points the dataset shape becomes obvious!
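Here is a minimal 1D sketch of such a mixture path. The four data points and the value of `SIGMA_MIN` are made up, and I assume the linear interpolation $\mu_t(x_1) = t x_1$, $\sigma_t = 1 - (1 - \sigma_{\min}) t$ for the conditional paths. The sketch checks that the marginal is a proper density and that it starts at the standard normal.

```python
import math

SIGMA_MIN = 0.1                # assumed value, for illustration only
DATA = [-2.0, -1.0, 1.5, 2.5]  # hypothetical 1D data set


def normal_pdf(x, mean, std):
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def conditional_density(t, x, x1):
    """p_t(x | x1) with linearly interpolated mean and standard deviation."""
    return normal_pdf(x, mean=t * x1, std=1.0 - (1.0 - SIGMA_MIN) * t)


def marginal_density(t, x):
    """Gaussian mixture with one equally weighted component per data point."""
    return sum(conditional_density(t, x, x1) for x1 in DATA) / len(DATA)


def integrate(f, lo=-10.0, hi=10.0, n=20_000):
    """Midpoint rule on a fixed grid."""
    h = (hi - lo) / n
    return sum(f(lo + (i + 0.5) * h) for i in range(n)) * h


total_mass = integrate(lambda x: marginal_density(0.5, x))
print(total_mass)  # close to 1
```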
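But what is a $u_t(x \mid x_1)$ (or equivalently a flow $\phi_t(x \mid x_1)$) which induces $p_t(x \mid x_1)$? The continuity equation tells us that if a solution exists, there are infinitely many, as we can take any divergence-free vector field, divide it by $p_t(x \mid x_1)$, and add it to the one we found. Since the divergence operator is linear, the new vector field still satisfies the continuity equation. One concrete option for our Gaussian path above is obviously given by $\phi_t(x \mid x_1) = \sigma_t(x_1)\, x + \mu_t(x_1)$. And since, with primes denoting derivatives with respect to $t$,
$$\frac{d}{dt} \phi_t(x \mid x_1) = \sigma_t'(x_1)\, x + \mu_t'(x_1) = \frac{\sigma_t'(x_1)}{\sigma_t(x_1)} \big( \phi_t(x \mid x_1) - \mu_t(x_1) \big) + \mu_t'(x_1),$$
we obtain that
$$u_t(x \mid x_1) = \frac{\sigma_t'(x_1)}{\sigma_t(x_1)} \big( x - \mu_t(x_1) \big) + \mu_t'(x_1)$$
induces $p_t(x \mid x_1)$.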
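So far, so good. But how do we "marginalize" the conditional vector field such that it induces the desired marginal density path? It seems plausible that an ansatz like
$$u_t(x) = \int u_t(x \mid x_1)\, w_t(x, x_1)\, dx_1$$
for some function $w_t$ might work. So let us try to find a $w_t$ which makes the continuity equation work. Recall that
$$\frac{\partial}{\partial t} p_t(x) = -\nabla \cdot \big( p_t(x)\, u_t(x) \big)$$
is necessary and sufficient for $u_t$ to induce $p_t$. If we substitute the first $p_t$ with its definition $p_t(x) = \int p_t(x \mid x_1)\, q(x_1)\, dx_1$ and $u_t$ with the ansatz, and differentiate under the integral sign with Leibniz's rule, we want to have for every $x$ that
$$\int \Big[ \frac{\partial}{\partial t} p_t(x \mid x_1)\, q(x_1) + \nabla \cdot \big( p_t(x)\, u_t(x \mid x_1)\, w_t(x, x_1) \big) \Big]\, dx_1 = 0.$$
Now the idea in the design of $w_t$ is to leverage the fact that $p_t(x \mid x_1)$ and $u_t(x \mid x_1)$ fulfill the continuity equation. If $w_t$ gets rid of $p_t(x)$ and contributes the factor $p_t(x \mid x_1)\, q(x_1)$, we can see that the integrand can be factored such that one factor is our conditional continuity equation! So when setting
$$w_t(x, x_1) = \frac{p_t(x \mid x_1)\, q(x_1)}{p_t(x)}$$
the left side is nothing but
$$\int \Big[ \frac{\partial}{\partial t} p_t(x \mid x_1) + \nabla \cdot \big( p_t(x \mid x_1)\, u_t(x \mid x_1) \big) \Big]\, q(x_1)\, dx_1 = \int 0 \cdot q(x_1)\, dx_1 = 0,$$
which is exactly what we want. One caveat here is that $p_t$ must have full support, i.e., $p_t(x) > 0$ everywhere, since we divide by it. But requiring that the conditional densities have full support is sufficient for this to be true. Let us summarize the result: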
How to aggregate conditional densities and vector fields
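If a conditional density path $p_t(x \mid x_1)$ is induced by a vector field $u_t(x \mid x_1)$, then
$$u_t(x) = \int u_t(x \mid x_1)\, \frac{p_t(x \mid x_1)\, q(x_1)}{p_t(x)}\, dx_1$$
induces
$$p_t(x) = \int p_t(x \mid x_1)\, q(x_1)\, dx_1.$$
In hindsight, there is a nice interpretation of the way the conditional vector field is marginalized: by Bayes' theorem the weighting factor $p_t(x \mid x_1)\, q(x_1) / p_t(x)$ is exactly the conditional density of $X_1$ given that the random variable $X_t$ takes the value $x$. So we may also write
$$u_t(x) = \mathbb{E}\big[ u_t(x \mid X_1) \,\big|\, X_t = x \big].$$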
This makes a lot of sense! Intuitively, when simulating a particle and deciding where to move next, we have to take into account not only knowledge of the final distribution, but also the current observation, since the time already passed and the path already taken have rendered some final destinations practically infeasible (conditional density close to 0).
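The aggregation result can also be checked numerically. Below is a 1D sketch under my own assumptions: a made-up three-point data set with uniform weights, `SIGMA_MIN = 0.1`, and the linearly interpolated conditional paths $\mu_t(x_1) = t x_1$, $\sigma_t = 1 - (1 - \sigma_{\min}) t$. The marginal vector field is computed as the posterior-weighted average of the conditional ones, and finite differences confirm that it satisfies the continuity equation together with the mixture density.

```python
import math

SIGMA_MIN = 0.1          # assumed value, for illustration only
DATA = [-2.0, 1.0, 2.5]  # hypothetical 1D data set, uniform weights


def sigma(t):
    return 1.0 - (1.0 - SIGMA_MIN) * t


def conditional_density(t, x, x1):
    z = (x - t * x1) / sigma(t)
    return math.exp(-0.5 * z * z) / (sigma(t) * math.sqrt(2.0 * math.pi))


def conditional_vector_field(t, x, x1):
    # sigma_t' / sigma_t * (x - mu_t) + mu_t' for the linear paths above
    return -(1.0 - SIGMA_MIN) / sigma(t) * (x - t * x1) + x1


def marginal_density(t, x):
    return sum(conditional_density(t, x, x1) for x1 in DATA) / len(DATA)


def marginal_vector_field(t, x):
    # Posterior weights of each data point given X_t = x (Bayes' theorem).
    joint = [conditional_density(t, x, x1) / len(DATA) for x1 in DATA]
    total = sum(joint)
    return sum(w / total * conditional_vector_field(t, x, x1)
               for w, x1 in zip(joint, DATA))


def continuity_residual(t, x, h=1e-4):
    """Central-difference estimate of d/dt p_t(x) + d/dx (p_t(x) u_t(x))."""
    def flux(y):
        return marginal_density(t, y) * marginal_vector_field(t, y)

    dp_dt = (marginal_density(t + h, x) - marginal_density(t - h, x)) / (2.0 * h)
    return dp_dt + (flux(x + h) - flux(x - h)) / (2.0 * h)


residuals = [abs(continuity_residual(0.4, x)) for x in (-1.0, 0.0, 0.5, 1.2)]
print(max(residuals))  # zero up to finite-difference error
```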
Conditional Objective
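Let us now return to the ideal flow matching objective
$$\mathcal{L}_{\mathrm{FM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1],\, x \sim p_t} \big\| v_t^\theta(x) - u_t(x) \big\|^2.$$
It is still intractable as we don't know how to sample from $p_t$ directly, let alone how to compute $u_t$. But instead we can sample conditionally and then try to match the conditional vector field:
$$\mathcal{L}_{\mathrm{CFM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1],\, x_1 \sim q,\, x \sim p_t(\cdot \mid x_1)} \big\| v_t^\theta(x) - u_t(x \mid x_1) \big\|^2.$$
This is finally a tractable and scalable objective for training a CNF. But how does it compare to the ideal flow matching objective? Does it give us the same set of parameters when minimized? First note that both terms inside the (linear) expectation operators are of the form $\| a - b \|^2 = \| a \|^2 - 2 \langle a, b \rangle + \| b \|^2$. Let us compare them one by one, but let us forget about the marginalization over $t$ as this is identical for both objectives.
The $\| a \|^2$ terms are identical since
$$\mathbb{E}_{x \sim p_t} \big\| v_t^\theta(x) \big\|^2 = \int \big\| v_t^\theta(x) \big\|^2 p_t(x)\, dx = \int\!\!\int \big\| v_t^\theta(x) \big\|^2 p_t(x \mid x_1)\, q(x_1)\, dx_1\, dx = \mathbb{E}_{x_1 \sim q,\, x \sim p_t(\cdot \mid x_1)} \big\| v_t^\theta(x) \big\|^2.$$
Note that we assumed that changing the integration order is fine. According to Fubini's theorem the Lebesgue integrability of the integrand is sufficient for this to be true. It is fair to say that for virtually all practical purposes $q$ vanishes sufficiently quickly for the integral to be finite, e.g., when $q$ is Gaussian. Nevertheless, one should keep this in mind when modeling crazy data distributions.
The mixed terms $\langle a, b \rangle$ are also identical, for a change of integration order yields
$$\mathbb{E}_{x \sim p_t} \big\langle v_t^\theta(x), u_t(x) \big\rangle = \int \Big\langle v_t^\theta(x), \int u_t(x \mid x_1)\, \frac{p_t(x \mid x_1)\, q(x_1)}{p_t(x)}\, dx_1 \Big\rangle\, p_t(x)\, dx = \int\!\!\int \big\langle v_t^\theta(x), u_t(x \mid x_1) \big\rangle\, p_t(x \mid x_1)\, q(x_1)\, dx\, dx_1 = \mathbb{E}_{x_1 \sim q,\, x \sim p_t(\cdot \mid x_1)} \big\langle v_t^\theta(x), u_t(x \mid x_1) \big\rangle.$$
And what about the $\| b \|^2$ terms? Well, the key point is that they don't depend on $\theta$. So the objectives are identical up to an additive constant which doesn't depend on $\theta$, which means they have the same set of minimizers.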
Equivalence of the conditional and naive flow matching objectives
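The conditional flow matching objective $\mathcal{L}_{\mathrm{CFM}}$ and the naive one $\mathcal{L}_{\mathrm{FM}}$ are identical up to an additive constant which doesn't depend on $\theta$, or equivalently, they have the same gradient with respect to $\theta$.
So first-order methods like Adam behave the same if we choose $\mathcal{L}_{\mathrm{CFM}}$ as a surrogate for $\mathcal{L}_{\mathrm{FM}}$.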
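The equivalence is easy to verify numerically in a toy setting. In the sketch below everything is an assumption of mine: a fixed time `T`, a made-up 1D data set, the linearly interpolated Gaussian conditional paths, and a one-parameter "model" $v^\theta(x) = \theta x$. Both objectives are evaluated by numerical integration for two different values of $\theta$, and the gap between them turns out to be the same positive constant.

```python
import math

SIGMA_MIN, T = 0.1, 0.4           # assumed values, for illustration only
DATA = [-2.0, 1.0, 2.5]           # hypothetical 1D data set
STD = 1.0 - (1.0 - SIGMA_MIN) * T


def cond_density(x, x1):
    z = (x - T * x1) / STD
    return math.exp(-0.5 * z * z) / (STD * math.sqrt(2.0 * math.pi))


def cond_field(x, x1):
    return -(1.0 - SIGMA_MIN) / STD * (x - T * x1) + x1


def marg_density(x):
    return sum(cond_density(x, x1) for x1 in DATA) / len(DATA)


def marg_field(x):
    weighted = sum(cond_field(x, x1) * cond_density(x, x1) for x1 in DATA)
    return weighted / (len(DATA) * marg_density(x))


def integrate(f, lo=-8.0, hi=8.0, n=4000):
    """Midpoint rule on a fixed grid."""
    h = (hi - lo) / n
    return sum(f(lo + (i + 0.5) * h) for i in range(n)) * h


def fm_loss(theta):
    return integrate(lambda x: (theta * x - marg_field(x)) ** 2 * marg_density(x))


def cfm_loss(theta):
    per_point = [
        integrate(lambda x, x1=x1: (theta * x - cond_field(x, x1)) ** 2
                  * cond_density(x, x1))
        for x1 in DATA
    ]
    return sum(per_point) / len(DATA)


gap_a = cfm_loss(0.3) - fm_loss(0.3)
gap_b = cfm_loss(-1.2) - fm_loss(-1.2)
print(gap_a, gap_b)  # the same positive constant for both values of theta
```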
Concrete Conditional Path
Now that we have a tractable and scalable objective for training a CNF, we can start to think about concrete conditional probability paths.
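Arguably, the most natural choice is the straightforward linear interpolation of mean and standard deviation, i.e.,
$$\mu_t(x_1) = t\, x_1, \qquad \sigma_t(x_1) = 1 - (1 - \sigma_{\min})\, t.$$
Their derivatives are $\mu_t'(x_1) = x_1$ and $\sigma_t'(x_1) = -(1 - \sigma_{\min})$. So according to our derived formula above
$$u_t(x \mid x_1) = \frac{-(1 - \sigma_{\min})}{1 - (1 - \sigma_{\min})\, t} \big( x - t\, x_1 \big) + x_1 = \frac{x_1 - (1 - \sigma_{\min})\, x}{1 - (1 - \sigma_{\min})\, t}.$$
In the conditional flow matching objective $x$ is sampled according to $p_t(\cdot \mid x_1)$. Equivalently, we can sample $x_0 \sim \mathcal{N}(0, I)$ and then take $x = \phi_t(x_0 \mid x_1) = \big(1 - (1 - \sigma_{\min})\, t\big)\, x_0 + t\, x_1$. Plugging this into the above formula for the conditional vector field, terms cancel out and we get
$$u_t\big( \phi_t(x_0 \mid x_1) \,\big|\, x_1 \big) = x_1 - (1 - \sigma_{\min})\, x_0.$$
In the limit of $\sigma_{\min} \to 0$ this simplifies even further to $x_1 - x_0$. This yields the (for its simplicity) most popular and much celebrated flow matching objective of
$$\mathcal{L}_{\mathrm{CFM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1],\, x_0 \sim \mathcal{N}(0, I),\, x_1 \sim q} \big\| v_t^\theta(x_t) - (x_1 - x_0) \big\|^2$$
where $x_t = (1 - t)\, x_0 + t\, x_1$. Intuitively, this trains the vector field to match the expected direction from "source to target" given an intermediate "simulation step".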
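If you don't trust the cancellation, here is a tiny 1D sketch that checks it on random inputs. The value of `SIGMA_MIN` and the sampling ranges are arbitrary choices of mine.

```python
import random

SIGMA_MIN = 0.1  # assumed value, for illustration only


def interpolate(t, x0, x1):
    """x_t = sigma_t * x0 + mu_t for the linearly interpolated path."""
    return (1.0 - (1.0 - SIGMA_MIN) * t) * x0 + t * x1


def conditional_vector_field(t, x, x1):
    return (x1 - (1.0 - SIGMA_MIN) * x) / (1.0 - (1.0 - SIGMA_MIN) * t)


def simplified_target(x0, x1):
    return x1 - (1.0 - SIGMA_MIN) * x0


rng = random.Random(0)
max_gap = 0.0
for _ in range(1000):
    t, x0, x1 = rng.random(), rng.gauss(0.0, 1.0), rng.uniform(-3.0, 3.0)
    xt = interpolate(t, x0, x1)
    gap = abs(conditional_vector_field(t, xt, x1) - simplified_target(x0, x1))
    max_gap = max(max_gap, gap)

print(max_gap)  # zero up to floating point error
```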
Implementation in Practice
Let us now train a small flow matching model on the above mentioned annulus toy dataset!
import torch
import torch.nn as nn


class FlowNet(nn.Module):
    """MLP that models the time-dependent vector field v(t, x)."""

    def __init__(self, d, h=64):
        super().__init__()
        self.d = d
        self.net = nn.Sequential(
            nn.Linear(1 + d, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, d),
        )

    def forward(self, t, x):
        # t has shape (batch, 1) and x has shape (batch, d).
        return self.net(torch.cat([t, x], dim=-1))


def train(model, device, train_data_loader, num_epochs=1):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    for epoch in range(num_epochs):
        for step, x1 in enumerate(train_data_loader):
            x1 = x1.to(device)
            t = torch.rand(x1.shape[0], 1, device=device)  # t ~ U[0, 1]
            x0 = torch.randn(x1.shape[0], model.d, device=device)  # x0 ~ N(0, I)
            xt = (1 - t) * x0 + t * x1  # point on the conditional path
            target = x1 - x0  # conditional vector field for sigma_min -> 0
            loss = ((model(t, xt) - target) ** 2).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


class AnnulusDataset(torch.utils.data.Dataset):
    """Points with radius drawn from [2, 3) and angle drawn from [0, 2 * pi)."""

    def __init__(self, num_samples):
        self.num_samples = num_samples
        angles = torch.rand(num_samples) * 2 * torch.pi
        radii = torch.rand(num_samples) * 1 + 2
        self.data = torch.stack(
            [radii * torch.cos(angles), radii * torch.sin(angles)], dim=-1
        )

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx]


device = "cuda" if torch.cuda.is_available() else "cpu"
model = FlowNet(d=2).to(device)
batch_size = 256
data = AnnulusDataset(num_samples=batch_size * 10_000)
train_data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)
train(model, device, train_data_loader, num_epochs=1)
torch.save(model.state_dict(), "flow_matching_on_annulus.pth")

After running the above code we can check that the model has learned the annulus dataset by sampling from it and plotting the samples. Below we use Euler integration to "walk along the vector field" (technically, solving for the flow) because it is very simple, but any numerical ODE solver will do.
import matplotlib.pyplot as plt
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = FlowNet(d=2)
model.load_state_dict(torch.load("flow_matching_on_annulus.pth", map_location=device))
model.to(device)


@torch.no_grad()  # sampling needs no gradients
def sample(model, device, num_samples, num_steps=50):
    # Left endpoints t_k = k / num_steps of the Euler steps (t = 1 is excluded).
    eval_time_points = torch.arange(num_steps, device=device) / num_steps
    step_delta = 1 / num_steps
    xt = torch.randn(num_samples, model.d, device=device)  # x0 ~ N(0, I)
    for t in eval_time_points:
        t_batch = t * torch.ones(num_samples, 1, device=device)
        xt = xt + model(t_batch, xt) * step_delta  # one explicit Euler step
    return xt


model.eval()
samples = sample(model, device, num_samples=10_000, num_steps=50).cpu().numpy()

plt.figure(figsize=(6, 6))
plt.scatter(samples[:, 0], samples[:, 1], s=1)
plt.savefig("flow_matching_toy.png")
It works, nice!