Adjoint matching and stochastic optimal control for diffusion models
Suppose we have a pretrained flow-based generative model that produces samples from a distribution \(p(x)\), and we want to control this model so that it instead generates from a target distribution \(p^*(x)\). For example, we might want to finetune the model so that its samples maximize a reward function \(r(x)\), targeting the distribution proportional to \(p(x)e^{r(x)}\). Adjoint matching solves this problem by formulating it as a stochastic optimal control (SOC) problem and providing an efficient algorithm to solve it. In this post, we will derive the necessary control theory from scratch and then show how it connects to the finetuning problem.
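The target \(p(x)e^{r(x)}\) is only defined up to a normalizing constant. As a quick illustration of what the tilt does (not part of adjoint matching itself), here is a minimal NumPy sketch that normalizes \(p(x)e^{r(x)}\) on a 1D grid. It assumes a standard Gaussian base and a hypothetical linear reward \(r(x) = a x\), in which case the tilted distribution is exactly \(\mathcal{N}(a, 1)\). The helper `tilt_on_grid` is a hypothetical name.

```python
import numpy as np

def tilt_on_grid(log_p, r, x):
    """Normalize p(x) * exp(r(x)) on a uniform 1D grid (Riemann-sum quadrature)."""
    log_w = log_p(x) + r(x)
    w = np.exp(log_w - log_w.max())      # subtract the max for numerical stability
    dx = x[1] - x[0]
    return w / (w.sum() * dx)

x = np.linspace(-12.0, 12.0, 4801)
a = 1.5                                   # hypothetical reward slope
log_p = lambda x: -0.5 * x**2 - 0.5 * np.log(2 * np.pi)   # base p = N(0, 1)
r = lambda x: a * x                       # hypothetical linear reward

q = tilt_on_grid(log_p, r, x)
dx = x[1] - x[0]
mean = np.sum(x * q) * dx
var = np.sum((x - mean) ** 2 * q) * dx
print(f"tilted mean = {mean:.4f}, tilted variance = {var:.4f}")  # close to a and 1
```

For richer base models the normalizing constant is intractable, which is why the post turns to a control formulation instead of sampling the tilted density directly.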
Part 1: Deterministic optimal control and the HJB equation
We start with the deterministic case because it is more transparent and the stochastic extension follows naturally. Let \(x_t\) be a curve on a Riemannian manifold \((\mathcal{M}, g)\) with velocity \(\dot{x}_t = b_t(x_t) + L_t(x_t)u_t(x_t)\) where \(b_t\) is a base drift and \(u_t\) is a control function. We define the cost functional:
\[J_0[u](x_0) = \int_{0}^{1} \left(\frac{1}{2}\|u_t(x_t)\|^2 + f_t(x_t)\right) dt + \Phi(x_1)\]subject to \(\dot{x}_t = b_t(x_t) + L_t(x_t)u_t(x_t)\) with \(x_{t=0} = x_0\). Here \(f_t\) is a running state cost and \(\Phi\) a terminal cost. Our goal is to find the optimal control \(u_t^*\) that minimizes this cost, along with the value function \(V_0^*(x_0) = \min_u J_0[u](x_0)\).
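To make the cost functional concrete, here is a minimal sketch that evaluates \(J_0[u](x_0)\) for a given feedback control by forward-Euler integration. It assumes \(\mathcal{M} = \mathbb{R}^d\) with the Euclidean metric (so \(\|\cdot\|\) is the ordinary norm); `cost_functional` and its argument conventions are hypothetical.

```python
import numpy as np

def cost_functional(x0, b, L, u, f, Phi, n_steps=1000):
    """Forward-Euler estimate of J_0[u](x0), assuming M = R^d with the Euclidean metric.

    Expected shapes: b(t, x) -> (d,), L(t, x) -> (d, m), u(t, x) -> (m,),
    f(t, x) -> scalar, Phi(x) -> scalar.
    """
    dt = 1.0 / n_steps
    x = np.asarray(x0, dtype=float)
    running = 0.0
    for k in range(n_steps):
        t = k * dt
        u_k = u(t, x)
        running += (0.5 * u_k @ u_k + f(t, x)) * dt   # left-endpoint rule for the running cost
        x = x + (b(t, x) + L(t, x) @ u_k) * dt      # Euler step of x' = b + L u
    return running + Phi(x)

# Usage: b = 0, L = I, constant control u = -0.5, f = 0, Phi(x) = |x|^2 / 2.
# Then x_1 = x_0 - 0.5, so for x_0 = 1 the cost is 0.5 * 0.25 + 0.5 * 0.25 = 0.25.
J = cost_functional(
    x0=[1.0],
    b=lambda t, x: np.zeros(1),
    L=lambda t, x: np.eye(1),
    u=lambda t, x: np.array([-0.5]),
    f=lambda t, x: 0.0,
    Phi=lambda x: 0.5 * x @ x,
)
print(J)
```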
Deriving the optimality conditions via Lagrange multipliers
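To use Lagrange multipliers, we track a density \(\rho_t\) of states moved by the dynamics rather than a single trajectory. The constraint is then the continuity equation, which relates \(\rho_t\) to the velocity \(\dot{x}_t\):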
\[\frac{\partial \rho_t}{\partial t} + \text{Div}(\rho_t \dot{x}_t) = 0\]We write the expected cost and enforce the constraint using a Lagrange multiplier \(s_t: \mathcal{M} \to \mathbb{R}\):
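\[\min_{u}\max_{s}\;\int_{0}^{1} \int_{\mathcal{M}} \rho_t \left(\frac{1}{2}\|u_t\|^2 + f_t\right) dV_g \, dt + \int_{\mathcal{M}} \rho_1 \Phi \, dV_g + \int_{0}^{1} \int_{\mathcal{M}} s_t\left(\frac{\partial \rho_t}{\partial t} + \text{Div}(\rho_t \dot{x}_t)\right) dV_g \, dt\]and the optimal control \(u^*\) is the minimizer at the saddle point. Integrating the divergence term by parts, \(\int_{\mathcal{M}} s_t\, \text{Div}(\rho_t \dot{x}_t)\, dV_g = -\int_{\mathcal{M}} \rho_t \langle \nabla s_t, b_t + L_t u_t \rangle\, dV_g\), and isolating the dependence on \(u\), we find: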
\[\int_{0}^{1} \int_{\mathcal{M}} \rho_t \left\langle \frac{1}{2}u_t - L_t^\top \nabla s_t, u_t \right\rangle dV_g \, dt\]The first-order condition gives:
\[u_t^*(x_t) = L_t(x_t)^\top \nabla s_t(x_t)\]Substituting \(u^*\) back and requiring the Lagrangian to be stationary with respect to \(\rho_t\) (a second integration by parts, this time in \(t\), moves the time derivative from \(\rho_t\) onto \(s_t\), and the boundary term at \(t = 1\) fixes \(s_1 = -\Phi\)) gives an equation for \(s_t\). Writing the value function as \(V_t^* = -s_t\), we arrive at the Hamilton-Jacobi-Bellman (HJB) equation:
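\[\frac{d V_t^*(x_t)}{dt} = -\frac{1}{2}\|L_t(x_t)^\top \nabla V_t^*(x_t)\|^2 - f_t(x_t), \quad V_1^*(x_1) = \Phi(x_1)\]where \(\frac{d}{dt}\) is the total derivative along the optimal trajectory \(\dot{x}_t = b_t(x_t) + L_t(x_t)u_t^*(x_t)\). Equivalently, in partial-derivative form, \(\frac{\partial V_t^*}{\partial t} + \langle \nabla V_t^*, b_t \rangle - \frac{1}{2}\|L_t^\top \nabla V_t^*\|^2 + f_t = 0\). The optimal control is \(u_t^*(x_t) = -L_t(x_t)^\top \nabla V_t^*(x_t)\).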
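To see the HJB equation and the optimal control in action, here is a minimal sketch on an assumed one-dimensional linear-quadratic problem: \(\mathcal{M} = \mathbb{R}\), \(b_t = 0\), \(L_t = 1\), \(f_t = 0\), \(\Phi(x) = \tfrac{c}{2}x^2\). Trying \(V_t^*(x) = \tfrac12 k_t x^2\) in the partial-derivative form of the HJB equation gives the Riccati equation \(\dot k_t = k_t^2\) with \(k_1 = c\), so \(k_t = c/(1 + c(1-t))\) and \(u_t^*(x) = -k_t x\). The code simulates this feedback control and compares its cost with \(V_0^*(x_0)\) and with an arbitrary suboptimal feedback; all names are illustrative.

```python
import numpy as np

c = 2.0  # terminal weight, Phi(x) = c x^2 / 2 (illustrative choice)

def k(t):
    """Riccati solution of k' = k^2 with k(1) = c, so that V_t(x) = k(t) x^2 / 2."""
    return c / (1.0 + c * (1.0 - t))

def value(t, x):
    return 0.5 * k(t) * x**2

def cost(x0, gain, n_steps=20_000):
    """Euler estimate of J_0[u](x0) for the feedback u_t(x) = -gain(t) x (b = 0, L = 1, f = 0)."""
    dt = 1.0 / n_steps
    x, running = float(x0), 0.0
    for i in range(n_steps):
        u = -gain(i * dt) * x
        running += 0.5 * u**2 * dt
        x += u * dt
    return running + 0.5 * c * x**2

x0 = 1.5
J_opt = cost(x0, k)                   # u* = -L^T dV/dx = -k(t) x
J_alt = cost(x0, lambda t: 0.5)       # a constant, suboptimal gain for comparison
print(J_opt, value(0.0, x0), J_alt)   # J_opt should match V_0(x0) up to O(dt); J_alt is larger
```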
Part 2: Stochastic optimal control
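The stochastic extension replaces the deterministic dynamics with an SDE driven by a Brownian motion \(W_s\), with the diffusion coefficient \(L_s\) now depending only on time and the control parameterized as \(u_s(x;\theta)\) (for example, a neural network with weights \(\theta\)):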
\[dx_s = \left(b_s(x_s) + L_s u_s(x_s;\theta)\right)ds + L_s \, dW_s\]The cost functional becomes:
\[J_t[u_\theta](x_t) = \mathbb{E}\left[ \int_t^1 \left(\frac{1}{2}\|u_s(x_s;\theta)\|^2 + f_s(x_s)\right) ds + g(x_1) \;\Big|\; x_t \right]\]where the terminal cost \(g\) (not to be confused with the metric in Part 1) plays the role of \(\Phi\). The derivation follows the same Lagrangian approach, but the Kolmogorov forward (Fokker-Planck) equation replaces the continuity equation, and the resulting HJB equation acquires a second-order diffusion term:
\[\frac{\partial V_t}{\partial t} + \langle \nabla V_t, b_t \rangle + \frac{1}{2}\text{Div}(L_tL_t^T \nabla V_t) - \frac{1}{2}\|L_t^T \nabla V_t\|^2 + f_t = 0, \quad V_1 = g\]with optimal control \(u_t^*(x) = -L_t^T \nabla V_t(x)\), exactly as in the deterministic case.
Part 3: The adjoint matching algorithm
Solving the HJB equation directly is intractable in high dimensions. Adjoint matching avoids this by working with the first-order optimality condition instead: the gradient of the cost with respect to the control's parameters \(\theta\) vanishes at the optimal control, and this condition can be turned into a simple regression objective.
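In low dimensions, however, some problems can be solved exactly, which gives a useful ground truth for testing an implementation. The minimal NumPy sketch below assumes a one-dimensional problem with \(b_s = 0\), constant \(L_s = \sigma\), \(f_s = 0\) and \(g(x) = \tfrac{c}{2}x^2\); the helper names are hypothetical. Substituting \(V_t(x) = \tfrac12 k_t x^2 + m_t\) into the stochastic HJB equation gives \(\dot k_t = \sigma^2 k_t^2\) and \(\dot m_t = -\tfrac12 \sigma^2 k_t\) with \(k_1 = c\), \(m_1 = 0\), hence \(k_t = c/(1 + c\sigma^2(1-t))\), \(m_t = \tfrac12\log(1 + c\sigma^2(1-t))\) and \(u_t^*(x) = -\sigma k_t x\). A Monte Carlo estimate of the cost under \(u^*\) should match \(V_0(x_0)\) up to sampling and discretization error.

```python
import numpy as np

def value(t, x, c, sigma):
    """Closed-form V_t(x) for b = 0, L = sigma, f = 0, g(x) = c x^2 / 2."""
    k = c / (1.0 + c * sigma**2 * (1.0 - t))
    return 0.5 * k * x**2 + 0.5 * np.log(1.0 + c * sigma**2 * (1.0 - t))

def mc_cost(x0, c, sigma, n_paths=20_000, n_steps=200, seed=0):
    """Euler-Maruyama Monte Carlo estimate of J_0[u*](x0) under the closed-form optimal control."""
    rng = np.random.default_rng(seed)
    dt = 1.0 / n_steps
    x = np.full(n_paths, float(x0))
    running = np.zeros(n_paths)
    for i in range(n_steps):
        k = c / (1.0 + c * sigma**2 * (1.0 - i * dt))
        u = -sigma * k * x                                  # u* = -L^T grad V
        running += 0.5 * u**2 * dt
        x = x + sigma * u * dt + sigma * np.sqrt(dt) * rng.standard_normal(n_paths)
    return float(np.mean(running + 0.5 * c * x**2))

c, sigma, x0 = 1.0, 1.0, 1.0
estimate = mc_cost(x0, c, sigma)
exact = value(0.0, x0, c, sigma)   # = 0.25 + 0.5 * log(2) for these parameters
print(estimate, exact)
```

The \(m_t\) term comes from the diffusion term \(\tfrac12\text{Div}(L_tL_t^T\nabla V_t)\); the deterministic value function \(\tfrac12 k_t x^2\) alone would underestimate the expected cost by \(\tfrac12\log 2\) for these parameters.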
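The gradient of the cost with respect to \(\theta\) can be computed with the adjoint method, which introduces an adjoint process \(a_s\) (the gradient of the pathwise cost-to-go with respect to \(x_s\)). For a fixed trajectory sampled from the SDE, it satisfies: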
\[\frac{da_s}{ds} = -\nabla b_s(x_s)^T a_s - \nabla f_s(x_s) - \nabla u_s(x_s)^T \left(L_s^T a_s + u_s(x_s)\right), \quad a_1 = \nabla g(x_1)\]where \(\nabla b_s\) and \(\nabla u_s\) denote Jacobians with respect to \(x\). The gradient of the cost with respect to \(\theta\) is then:
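\[\frac{\partial J_t[u_\theta](x_t)}{\partial \theta} = \mathbb{E}\left[ \int_t^1 \left\langle \frac{\partial u_s(x_s;\theta)}{\partial \theta}, L_s^T a_s + u_s(x_s;\theta)\right\rangle ds \;\Big|\; x_t \right]\]This gradient vanishes when \(L_s^T a_s + u_s(x_s;\theta) = 0\) in conditional expectation given \(x_s\), i.e. \(u_s(x;\theta) = -L_s^T\,\mathbb{E}[a_s \mid x_s = x]\), which is the condition satisfied by the optimal control. At that point the last term of the adjoint ODE, \(\nabla u_s(x_s)^T (L_s^T a_s + u_s(x_s))\), also vanishes in expectation. This motivates the lean adjoint \(\tilde{a}_s\), which drops that term: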
\[\frac{d\tilde{a}_s}{ds} = -\nabla b_s(x_s)^T \tilde{a}_s - \nabla f_s(x_s), \quad \tilde{a}_1 = \nabla g(x_1)\]The adjoint matching loss is:
\[\mathcal{L}(\theta) = \mathbb{E}\left[ \int_0^1 \|L_s^T \tilde{a}_s + u_s(x_s;\theta)\|^2 \, ds \right]\]The training procedure for each gradient step is:
- Discretize the time interval \((t_1,\dots,t_N) \subset [0,1]\)
- Simulate the controlled SDE forward and save \((x_{t_1},\dots,x_{t_N})\)
- Solve the lean adjoint ODE backward to obtain \((\tilde{a}_{t_1},\dots,\tilde{a}_{t_N})\)
- Stop gradients through \(x_{t_{1:N}}\) and \(\tilde{a}_{t_{1:N}}\)
- Compute the loss \(\mathcal{L}(\theta)\) by averaging the integrand over time steps and sampled trajectories, then take a gradient step on \(\theta\)
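The steps above can be sketched end to end on a toy one-dimensional problem with \(b_s = 0\), constant \(L_s = \sigma\), \(f_s = 0\) and \(g(x) = \tfrac{c}{2}x^2\), whose optimal control follows from the stochastic HJB equation with the ansatz \(V_t(x) = \tfrac12 k_t x^2 + m_t\): \(u_s^*(x) = -\sigma k_s x\) with \(k_s = c/(1 + c\sigma^2(1-s))\). Because \(\nabla b_s = 0\) and \(\nabla f_s = 0\), the lean adjoint is constant in time, \(\tilde a_s = c\,x_1\). The sketch assumes a hypothetical parameterization with one linear gain per time step, \(u(x, t_k;\theta) = \theta_k x\), standard-normal initial states and plain gradient descent. Since each \(\theta_k\) enters only its own term, the gradient is written out by hand instead of using an autodiff library, and the per-step losses are summed rather than averaged, which only rescales \(\mathcal{L}(\theta)\).

```python
import numpy as np

def adjoint_matching_lq(c=1.0, sigma=1.0, n_steps=50, n_paths=4_000,
                        n_iters=200, lr=0.25, seed=0):
    """Toy adjoint matching for b = 0, L = sigma, f = 0, g(x) = c x^2 / 2.

    Hypothetical control parameterization: u(x, t_k; theta) = theta[k] * x.
    """
    rng = np.random.default_rng(seed)
    dt = 1.0 / n_steps
    theta = np.zeros(n_steps)
    for _ in range(n_iters):
        # Simulate the controlled SDE forward and store x_{t_0}, ..., x_{t_N}.
        x = np.empty((n_steps + 1, n_paths))
        x[0] = rng.standard_normal(n_paths)
        for k in range(n_steps):
            drift = sigma * theta[k] * x[k]               # b + L u with b = 0
            noise = sigma * np.sqrt(dt) * rng.standard_normal(n_paths)
            x[k + 1] = x[k] + drift * dt + noise
        # Lean adjoint: d a/ds = -grad b^T a - grad f = 0, so a_s = grad g(x_1) = c x_1.
        a_lean = np.broadcast_to(c * x[-1], (n_steps, n_paths))
        # x and a_lean are plain arrays, i.e. constants w.r.t. theta ("stop gradient").
        resid = sigma * a_lean + theta[:, None] * x[:-1]   # L^T a + u at each t_k
        grad = 2.0 * np.mean(resid * x[:-1], axis=1)      # d/dtheta_k of sum_k mean(resid_k^2)
        theta = theta - lr * grad
    t = np.arange(n_steps) * dt
    return t, theta

t, theta = adjoint_matching_lq()
c, sigma = 1.0, 1.0
theta_star = -sigma * c / (1.0 + c * sigma**2 * (1.0 - t))   # closed-form optimal gains
print(np.max(np.abs(theta - theta_star)))
```

The learned gains should approach the closed-form gains up to Monte Carlo noise and \(O(\Delta t)\) discretization error, illustrating that the regression target \(-L_s^T\tilde a_s\) has the optimal control as its fixed point in this example.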
Part 4: Connecting SOC to finetuning diffusion models
The final piece is showing that the specific finetuning problem for diffusion models can be cast as an SOC problem of the form above.
Memoryless SDEs
Most flow-based generative models built from linear stochastic interpolants can be reparameterized as memoryless SDEs:
\[dx_t = (-K_t x_t - l_t + L_tL_t^T \nabla \log p_t(x_t)) \, dt + L_t \, dW_t\]where \(K_t\), \(l_t\) and \(L_t\) are determined by the interpolant and the chosen noise schedule. This process is called memoryless because \(x_0\) and \(x_1\) are independent: its time reversal is a linear SDE that reaches the Gaussian prior at \(t = 0\) from any starting point \(x_1\), so the conditional law of \(x_0\) given \(x_1\) does not depend on \(x_1\).
Similarly, a memoryless SDE targeting the finetuned distribution \(p_t^*(x_t)\) has drift \(-K_t x_t - l_t + L_tL_t^T \nabla \log p_t^*(x_t)\).
The value function satisfies the HJB equation
Both the base and finetuned marginals satisfy the Fokker-Planck equations of their respective memoryless SDEs. Dividing the Fokker-Planck equation by \(p_t\) gives its log-density form for the base model:
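\[\frac{\partial \log p_t}{\partial t} + \langle \nabla \log p_t, -K_t x_t - l_t\rangle - \text{Tr}(K_t) + \frac{1}{2}\text{Div}(L_tL_t^T \nabla \log p_t) + \frac{1}{2}\|L_t^T \nabla \log p_t\|^2 = 0\]The same holds for \(\log p_t^*\), with \(p_t^*\) in place of \(p_t\) everywhere, including in the drift. Defining \(V_t := \log p_t - \log p_t^*\), subtracting the two equations, and rewriting the quadratic terms with \(\frac{1}{2}\|A\|^2 - \frac{1}{2}\|B\|^2 = \langle A - B, A\rangle - \frac{1}{2}\|A - B\|^2\) for \(A = L_t^T \nabla \log p_t\) and \(B = L_t^T \nabla \log p_t^*\) yields: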
\[\frac{\partial V_t}{\partial t} + \langle \nabla V_t, -K_tx_t - l_t + L_tL_t^T \nabla \log p_t \rangle + \frac{1}{2}\text{Div}(L_tL_t^T \nabla V_t) - \frac{1}{2}\|L_t^T \nabla V_t\|^2 = 0\]This is exactly the HJB equation with \(f_t = 0\), base drift \(b_t = -K_t x_t - l_t + L_tL_t^T \nabla \log p_t\), and terminal cost \(g(x_1) = \log p_1(x_1) - \log p_1^*(x_1)\). Therefore, the finetuning problem is an SOC problem:
\[J_t[u](x_t) = \mathbb{E}\left[ \int_t^1 \frac{1}{2}\|u_s(x_s)\|^2 \, ds + \log \frac{p_1(x_1)}{p_1^*(x_1)} \;\Big|\; x_t \right]\]The optimal control is \(u_t^*(x_t) = -L_t^T(\nabla \log p_t - \nabla \log p_t^*)\), and the controlled drift becomes \(-K_t x_t - l_t + L_tL_t^T \nabla \log p_t^*\), which is exactly the drift of the finetuned model.
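Both identities in this derivation can be sanity-checked numerically. The minimal sketch below assumes a one-dimensional SDE with constant scalar coefficients \(K_t = K\), \(l_t = l\), \(L_t = \sigma\) (an illustrative choice, not a memoryless schedule; the two identities only use the Fokker-Planck equation) and Gaussian marginals \(\mathcal{N}(m_t, v_t)\), which this linear drift preserves with \(\dot m_t = -K m_t - l\) and \(\dot v_t = -2K v_t - \sigma^2\). It evaluates the residual of the log-density Fokker-Planck equation for \(p_t\) and the residual of the HJB equation for \(V_t = \log p_t - \log p_t^*\), where \(p_t^*\) is a second Gaussian solution with different initial moments. Both residuals should vanish up to finite-difference error.

```python
import numpy as np

K, l, sigma = 0.5, 0.2, 1.0      # illustrative constant coefficients

def moments(t, m0, v0):
    """Solutions of m' = -K m - l and v' = -2 K v - sigma^2."""
    m = (m0 + l / K) * np.exp(-K * t) - l / K
    v = (v0 + sigma**2 / (2 * K)) * np.exp(-2 * K * t) - sigma**2 / (2 * K)
    return m, v

def log_p(t, x, m0, v0):
    m, v = moments(t, m0, v0)
    return -0.5 * (x - m) ** 2 / v - 0.5 * np.log(2 * np.pi * v)

def score(t, x, m0, v0):         # d/dx log p_t(x)
    m, v = moments(t, m0, v0)
    return -(x - m) / v

def laplacian(t, m0, v0):        # d^2/dx^2 log p_t(x), constant in x for a Gaussian
    _, v = moments(t, m0, v0)
    return -1.0 / v

def d_dt(fun, t, h=1e-5):        # central finite difference in time
    return (fun(t + h) - fun(t - h)) / (2 * h)

base, target = (0.0, 3.0), (1.0, 2.5)   # initial (mean, variance) of p_t and p_t^*
t, x = np.meshgrid(np.linspace(0.1, 0.9, 9), np.linspace(-3.0, 3.0, 13))

# Residual of the log-density Fokker-Planck equation for p_t (Tr K = K in 1D).
s = score(t, x, *base)
fp_res = (d_dt(lambda u: log_p(u, x, *base), t) + s * (-K * x - l) - K
          + 0.5 * sigma**2 * laplacian(t, *base) + 0.5 * sigma**2 * s**2)

# Residual of the HJB equation for V_t = log p_t - log p_t^* with base drift b_t.
V = lambda u: log_p(u, x, *base) - log_p(u, x, *target)
dV = s - score(t, x, *target)
lapV = laplacian(t, *base) - laplacian(t, *target)
b = -K * x - l + sigma**2 * s
hjb_res = d_dt(V, t) + dV * b + 0.5 * sigma**2 * lapV - 0.5 * sigma**2 * dV**2

print(np.abs(fp_res).max(), np.abs(hjb_res).max())   # both near zero (finite-difference error)
```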
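In particular, for reward finetuning with \(p_1^* \propto p_1 e^{r}\), the terminal cost is \(g = -r + \log Z\) for a constant \(Z\), so the lean adjoint only needs \(\nabla g = -\nabla r\) and never evaluates \(p_1\). This means we can apply the adjoint matching algorithm from Part 3 to efficiently finetune any flow-based generative model that admits a memoryless SDE representation.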