

Adjoint Sensitivity Method

Compute gradients through an ODE solver by integrating a backward adjoint ODE, trading the O(N) activation memory of backprop through N solver steps for O(1) memory in depth, at the cost of a second integration.


Why This Matters

A Neural ODE integrated with a Runge-Kutta solver in N steps holds an activation tensor at every step. Backpropagating through the unrolled solver costs O(N \cdot d) memory, the same as a residual network of depth N. For an adaptive solver that takes thousands of steps on a stiff trajectory, this is fatal.

The adjoint sensitivity method computes the same gradient by integrating a second ODE backward in time, recovering activations on the fly. Memory cost drops to O(d) in depth, independent of N. The price is a second solve of the dynamics and the loss of bit-exact gradients (the recovered trajectory differs from the original by solver error).

This trade is not new to deep learning. Pontryagin derived the adjoint equation in 1962 for optimal control. The control community has used it for sixty years to differentiate through dynamical systems. Chen, Rubanova, Bettencourt, and Duvenaud (2018) imported it into deep learning and made Neural ODEs trainable at scale.

Setup

Let z(t) \in \mathbb{R}^d evolve under \dot z = f_\theta(z(t), t) from t = t_0 to t = t_1, with initial state z(t_0). Let \mathcal{L} = L(z(t_1)) be a scalar loss depending on the final state. We want gradients with respect to:

  • the initial state z(t_0),
  • the parameters \theta, and
  • the time bounds t_0 and t_1.

The naive approach unrolls the solver and applies reverse-mode autodiff. The adjoint approach defines a(t) = \partial \mathcal{L} / \partial z(t) and derives an ODE for a(t) that runs backward in time.

The Adjoint Equation

Theorem

Adjoint Sensitivity for ODEs

Statement

Define the adjoint a(t) = \partial \mathcal{L} / \partial z(t). Then a(t) satisfies the backward ODE

\frac{d a(t)}{dt} = -a(t)^\top \, \frac{\partial f_\theta(z(t), t)}{\partial z}

with terminal condition a(t_1) = \partial L / \partial z(t_1). The parameter gradient is recovered by

\frac{d \mathcal{L}}{d \theta} = -\int_{t_1}^{t_0} a(t)^\top \, \frac{\partial f_\theta(z(t), t)}{\partial \theta} \, dt.

Both the adjoint and the parameter gradient are obtained by integrating an augmented system from t_1 back to t_0, alongside the reconstructed state z(t).

Intuition

Think of a(t) as the sensitivity of the loss to a perturbation of the state at time t. A perturbation at time t propagates forward through the dynamics and ends up shifting z(t_1) by the linearized flow map. The chain rule says a(t)^\top = a(t_1)^\top \cdot (\text{Jacobian of flow from } t \text{ to } t_1). Differentiating this in t gives the adjoint ODE. The minus sign is because a is defined with the loss at the end of the interval, so it accumulates as t decreases.

Proof Sketch

Form the Lagrangian \mathcal{J}(\theta) = L(z(t_1)) + \int_{t_0}^{t_1} a(t)^\top (f_\theta(z, t) - \dot z) \, dt, where a(t) is a Lagrange multiplier; this sign convention makes the multiplier coincide with the adjoint defined above. The integral is zero on the trajectory. Take the variation in \theta, integrate the a^\top \delta\dot z term by parts using \delta z(t_0) = 0, and collect terms in \delta z(t). Setting the coefficient of \delta z(t) to zero gives the adjoint ODE. The boundary term at t_1 gives the terminal condition. The remaining \int a^\top (\partial f / \partial \theta) \, dt \, \delta\theta gives the parameter gradient.
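The variation can be written out explicitly. The following steps use the multiplier convention \mathcal{J} = L + \int a^\top (f_\theta - \dot z)\,dt, chosen so that the multiplier matches the terminal condition in the theorem:

```latex
\begin{aligned}
\delta \mathcal{J}
&= \frac{\partial L}{\partial z(t_1)}\,\delta z(t_1)
 + \int_{t_0}^{t_1} a^\top\!\left(
     \frac{\partial f_\theta}{\partial z}\,\delta z
   + \frac{\partial f_\theta}{\partial \theta}\,\delta\theta
   - \delta \dot z \right) dt \\
% integrate a^\top \delta\dot z by parts;
% \delta z(t_0) = 0 kills the lower boundary term
&= \left( \frac{\partial L}{\partial z(t_1)} - a(t_1)^\top \right) \delta z(t_1)
 + \int_{t_0}^{t_1} \left( \dot a^\top
   + a^\top \frac{\partial f_\theta}{\partial z} \right) \delta z \, dt
 + \int_{t_0}^{t_1} a^\top \frac{\partial f_\theta}{\partial \theta}\, dt \;\delta\theta .
\end{aligned}
```

Choosing a(t_1) = \partial L/\partial z(t_1) and \dot a^\top = -a^\top \, \partial f_\theta/\partial z annihilates the first two terms, leaving d\mathcal{L}/d\theta = \int_{t_0}^{t_1} a^\top (\partial f_\theta/\partial \theta)\, dt, which is the theorem's integral after flipping the orientation to run from t_1 to t_0.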

Why It Matters

Three quantities flow backward from t_1 to t_0 in a single augmented ODE: the reconstructed state z(t), the adjoint a(t), and the running parameter gradient \partial \mathcal{L}/\partial \theta. Total memory: three vectors, independent of how many solver steps are taken. A solver that takes 10,000 adaptive steps still uses constant memory in depth.

Failure Mode

The reconstructed z(t) from the backward solve is not the same trajectory as the forward solve, because adaptive step sizes and floating-point error diverge. The adjoint gradient is therefore an approximation of the discretize-then-differentiate gradient, and the two can disagree noticeably for stiff systems. See "Discretize vs Optimize" below.

The Augmented Backward ODE

The full backward integration runs three ODEs jointly from t_1 to t_0:

\begin{aligned} \dot z &= f_\theta(z, t) \\ \dot a &= -a^\top \, \frac{\partial f_\theta(z, t)}{\partial z} \\ \dot g &= -a^\top \, \frac{\partial f_\theta(z, t)}{\partial \theta} \end{aligned}

with terminal conditions z(t_1) from the forward solve, a(t_1) = \partial L/\partial z(t_1), and g(t_1) = 0. At t = t_0, g(t_0) = \partial \mathcal{L}/\partial \theta and a(t_0) = \partial \mathcal{L}/\partial z(t_0).

Each backward solver step requires evaluating \partial f_\theta / \partial z and \partial f_\theta / \partial \theta. Both are vector-Jacobian products done with a single reverse-mode pass through the network defining f_\theta.
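As a concrete sketch (not from the references), the augmented system can be implemented in a few lines for the linear vector field f_\theta(z) = A z, where the "parameters" are the entries of A and the loss is L = ½‖z(t_1)‖². For this f, the two vector-Jacobian products are a^\top \partial f/\partial z = a^\top A and a^\top \partial f/\partial A = a z^\top:

```python
import numpy as np

def rk4_step(f, y, t, h):
    """One classical Runge-Kutta step; h may be negative for backward time."""
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def adjoint_gradient(A, z0, T, n_steps):
    """dL/dA for dz/dt = A z with loss L = 0.5 * ||z(T)||^2,
    computed via the augmented backward ODE (state, adjoint, gradient)."""
    d, h = len(z0), T / n_steps

    # Forward solve: only the final state is kept.
    z = z0.copy()
    for k in range(n_steps):
        z = rk4_step(lambda t, z: A @ z, z, k * h, h)
    zT = z

    # Augmented system y = (z, a, g.ravel()); all three flow backward jointly.
    def f_aug(t, y):
        z, a = y[:d], y[d:2 * d]
        return np.concatenate([
            A @ z,                    # reconstruct the state
            -A.T @ a,                 # adjoint ODE: da/dt = -(df/dz)^T a
            -np.outer(a, z).ravel(),  # gradient ODE: dg/dt = -a z^T
        ])

    y = np.concatenate([zT, zT, np.zeros(d * d)])  # a(t_1) = dL/dz(t_1) = z(T)
    for k in range(n_steps):
        y = rk4_step(f_aug, y, T - k * h, -h)      # integrate t_1 -> t_0
    return y[2 * d:].reshape(d, d)                 # g(t_0) = dL/dA
```

A central finite-difference check on a small random A agrees with this gradient to several digits at moderate step counts; with coarse steps the two drift apart, which is exactly the discretize-vs-optimize gap discussed below.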

Discretize vs Optimize

Two philosophies clash here, and they give different gradients.

Optimize-then-discretize (the adjoint method). Derive the continuous adjoint ODE first, then discretize it with a numerical solver. This gives the O(1) memory benefit but produces gradients that are not the exact gradient of any particular discrete computation. They are the discretized version of the true continuous gradient.

Discretize-then-optimize (backprop through the solver). Pick a discrete solver (RK4, Dormand-Prince), unroll it, and apply reverse-mode autodiff to the unrolled graph. This gives the exact gradient of the discrete forward computation but costs O(N \cdot d) memory.

The two gradients agree only in the limit of vanishing step size. For adaptive solvers on stiff problems they can disagree by enough to destabilize training. Gholaminejad et al. (2019) documented cases where the adjoint method's gradients are wrong by tens of percent. The fix they proposed (ANODE) checkpoints the trajectory and recomputes activations at the same step locations as the forward pass, recovering bit-exact agreement with discretize-then-optimize at the cost of storing solver state.

In practice, well-conditioned Neural ODEs (small T, Lipschitz f_\theta) have agreement to several decimal places and the adjoint method is the default. For stiff or chaotic dynamics, prefer discretize-then-optimize with gradient checkpointing.
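The gap is easy to see on the scalar example \dot z = -\alpha z with loss L = ½ z(T)², where both gradients have closed forms. The formulas below are hand-derived for explicit Euler and purely illustrative, not taken from the references:

```python
import numpy as np

# dz/dt = -alpha * z, z(0) = z0, L = 0.5 * z(T)^2,
# explicit Euler with N steps of size h = T/N:
#   z_N = z0 * (1 - alpha*h)^N

def grad_discretize_then_optimize(alpha, z0, T, N):
    """Exact gradient of the *discrete* loss: chain rule through the
    unrolled Euler recursion z_{k+1} = (1 - alpha*h) * z_k."""
    h = T / N
    zN = z0 * (1 - alpha * h) ** N
    dzN_dalpha = -h * N * z0 * (1 - alpha * h) ** (N - 1)
    return zN * dzN_dalpha

def grad_optimize_then_discretize(alpha, z0, T):
    """The *continuous* gradient the adjoint method targets:
    d/dalpha of 0.5 * (z0 * exp(-alpha*T))^2 = -T * z0^2 * exp(-2*alpha*T).
    An adjoint backward solve converges to this as its steps shrink."""
    return -T * z0**2 * np.exp(-2 * alpha * T)

alpha, z0, T = 1.0, 1.0, 2.0
g_cont = grad_optimize_then_discretize(alpha, z0, T)
for N in (4, 16, 256):
    g_disc = grad_discretize_then_optimize(alpha, z0, T, N)
    print(f"N={N:4d}  discrete {g_disc:+.6f}  continuous {g_cont:+.6f}")
```

At N = 4 the two disagree by more than 50%; by N = 256 they agree to within 1%. Neither gradient is "wrong": they are gradients of different objects that coincide only in the vanishing-step limit.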

Pontryagin's Maximum Principle

The adjoint equation is half of Pontryagin's maximum principle (Pontryagin, Boltyanskii, Gamkrelidze, Mishchenko, 1962). The other half is an optimality condition on the control:

u^*(t) = \arg\max_u \, H(z(t), a(t), u, t)

where H = a^\top f(z, u, t) - L(z, u, t) is the Hamiltonian. In a Neural ODE, the "control" is the parameter vector \theta and we use gradient descent rather than the maximum principle to update it. So Neural ODE training uses the adjoint equation but not the maximum principle in full.

The maximum principle is the optimality condition for the constrained optimization problem of optimal control. The adjoint equation alone is the gradient computation. They are bundled in textbooks but separable in practice, and Neural ODEs use only the gradient half.

Common Confusions

Watch Out

The adjoint method is not just backprop

Standard backprop on an unrolled solver is discretize-then-optimize: it gives the exact gradient of the discrete forward pass. The adjoint method is optimize-then-discretize: it gives a discretized approximation of the continuous gradient. The two are different gradients of different objects, and they happen to agree in the limit of vanishing step size.

Watch Out

O(1) memory is in depth, not in width

The adjoint method uses memory proportional to the state dimension d, not proportional to the number of solver steps N. A solver taking 10,000 steps on a d = 1024 state still uses O(d) memory. But the state itself can be arbitrarily large, so "O(1) memory" is shorthand for "no dependence on solver step count," not "constant total bytes."
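Plugging in illustrative numbers (float32, hypothetical sizes) makes the distinction concrete:

```python
N, d = 10_000, 1024     # solver steps, state dimension (illustrative)
bytes_f32 = 4

# Unrolled backprop stores one activation vector per solver step.
unrolled = N * d * bytes_f32       # 40,960,000 bytes, about 39 MiB

# The adjoint backward pass holds z(t) and a(t), plus a parameter-sized
# gradient buffer whose size does not depend on N.
adjoint_state = 2 * d * bytes_f32  # 8,192 bytes = 8 KiB

print(unrolled // 2**20, "MiB vs", adjoint_state // 2**10, "KiB")
```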

Watch Out

Reconstructing z(t) backward is not free

The backward solve of \dot z = f_\theta(z, t) runs the same dynamics in reverse. For dissipative systems (contraction in time), the forward dynamics shrink errors and the backward dynamics amplify them. Numerical reconstruction diverges from the true trajectory exponentially fast. This is why checkpointing, which saves z at a few intermediate times, is often more robust than full backward reconstruction.
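A minimal numerical illustration (hand-rolled RK4 on \dot z = -\lambda z; the parameter values are arbitrary): integrate forward, then run the same dynamics backward from z(T) and compare the reconstructed initial state to the stored one:

```python
import numpy as np

def rk4_solve(f, z, t0, t1, n):
    """Classical RK4 from t0 to t1; the step is negative when t1 < t0."""
    h = (t1 - t0) / n
    for i in range(n):
        t = t0 + i * h
        k1 = f(t, z)
        k2 = f(t + h / 2, z + h / 2 * k1)
        k3 = f(t + h / 2, z + h / 2 * k2)
        k4 = f(t + h, z + h * k3)
        z = z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return z

def reconstruction_error(lam, T=1.0, n=400):
    f = lambda t, z: -lam * z              # dissipative for lam > 0
    z0 = 1.0
    zT = rk4_solve(f, z0, 0.0, T, n)       # forward pass
    z0_rec = rk4_solve(f, zT, T, 0.0, n)   # backward reconstruction
    return abs(z0_rec - z0) / abs(z0)

print(reconstruction_error(1.0))    # gentle decay: error near machine precision
print(reconstruction_error(200.0))  # strong dissipation: error of order 10%
```

The forward solve is perfectly stable in both cases; only the backward reconstruction degrades, and it degrades more the more strongly the dynamics contract. Checkpointing sidesteps this by restarting the backward pass from stored states.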

Summary

  • The adjoint sensitivity method computes ODE gradients by integrating a backward adjoint ODE, costing O(d) memory in depth instead of O(N \cdot d) for solver-unrolled backprop.
  • The adjoint a(t) = \partial \mathcal{L}/\partial z(t) satisfies \dot a = -a^\top \, \partial f_\theta / \partial z with terminal condition a(t_1) = \partial L / \partial z(t_1).
  • Three ODEs run jointly backward: state, adjoint, and accumulating parameter gradient.
  • Optimize-then-discretize (adjoint) gives a discretized continuous gradient; discretize-then-optimize (backprop) gives an exact discrete gradient. They differ by solver error and can disagree noticeably on stiff systems.
  • Pontryagin (1962) derived the adjoint equation for optimal control; Chen et al. (2018) imported it into deep learning for Neural ODEs.

Exercises

ExerciseCore

Problem

Consider the linear scalar ODE \dot z = -\alpha z with z(0) = z_0, scalar parameter \alpha > 0, integration interval [0, T], and loss \mathcal{L} = \tfrac{1}{2} z(T)^2. Write down the adjoint equation, solve it analytically, and compute \partial \mathcal{L}/\partial \alpha. Verify against direct differentiation of the closed-form solution z(T) = z_0 e^{-\alpha T}.

ExerciseAdvanced

Problem

Suppose a Neural ODE on a d = 64 state is integrated with an adaptive Dormand-Prince solver that takes 500 steps on average. Compare the memory required for (a) standard backprop through the unrolled solver, (b) gradient checkpointing with \sqrt{N} checkpoints, and (c) the adjoint method. Then explain why a researcher training on chaotic Lorenz dynamics might prefer (a) or (b) over (c) despite the higher memory cost.

References

Canonical:

  • Pontryagin, Boltyanskii, Gamkrelidze, Mishchenko, "The Mathematical Theory of Optimal Processes" (1962), Chapter 1 derives the maximum principle and the adjoint equation
  • Chen, Rubanova, Bettencourt, Duvenaud, "Neural Ordinary Differential Equations" (NeurIPS 2018, arXiv:1806.07366), Section 2 and Appendix B

Critical analysis:

  • Gholaminejad, Keutzer, Biros, "ANODE: Unconditionally Accurate Memory-Efficient Gradients for Neural ODEs" (IJCAI 2019, arXiv:1902.10298) — documents cases where the adjoint method's gradients diverge from discretize-then-optimize and proposes a fix
  • Onken, Ruthotto, "Discretize-Optimize vs Optimize-Discretize for Time-Series Regression and Continuous Normalizing Flows" (arXiv:2005.13420)

Continuous control background:

  • Bryson, Ho, "Applied Optimal Control: Optimization, Estimation, and Control" (1975), Chapter 2 covers continuous-time gradient computation via adjoints
  • Lewis, Vrabie, Syrmos, "Optimal Control" (3rd edition 2012), Chapter 3

Neural-ODE-specific:

  • Kidger, "On Neural Differential Equations" (PhD thesis 2022, arXiv:2202.02435), Chapter 3 covers gradient computation in detail
  • Massaroli, Poli, Park, Yamashita, Asama, "Dissecting Neural ODEs" (NeurIPS 2020, arXiv:2002.08071), Section 3 compares gradient methods


Last reviewed: April 17, 2026
