A differential geometry toolkit for machine learning
Many problems in machine learning have a geometric structure that standard Euclidean methods cannot exploit. When data lies on or near a manifold, the right geometric tools can reveal structure that is invisible in ambient coordinates. In this post, I will walk through the core constructions of Riemannian geometry that appear repeatedly in my work on representation learning and generative models: metrics, connections, and curvature. For each concept, I will describe the mathematical definition and highlight why it matters for machine learning.
Riemannian metrics and pullback geometry
A Riemannian metric \(g\) on a manifold \(\mathcal{M}\) assigns an inner product to each tangent space, allowing us to measure lengths and angles. In coordinates \(\{x^i\}\), the metric is represented by a symmetric positive definite matrix \(g_{ij}(x)\).
For machine learning, the relevant situation is usually a learned map \(F_\theta: \mathcal{Z} \to \mathcal{X}\) from a latent space \(\mathcal{Z}\) to a data space \(\mathcal{X}\). If \(\mathcal{X}\) has a metric \(g\), then the pullback metric on \(\mathcal{Z}\) is:
\[(F^*g)_{ij}(z) = \sum_{a,b} \frac{\partial F^a}{\partial z^i} g_{ab}(F(z)) \frac{\partial F^b}{\partial z^j}\]When \(\mathcal{X} = \mathbb{R}^n\) with the Euclidean metric, this simplifies to \(G = J^\top J\) where \(J\) is the Jacobian of \(F\). The pullback metric encodes how distances in data space look from the perspective of the latent space. It captures the local geometry of the learned representation.
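As a concrete sketch, the Euclidean pullback metric \(G = J^\top J\) can be computed with JAX's `jacfwd`. The map `F` below is a toy embedding I chose for illustration, not a learned model:

```python
import jax
import jax.numpy as jnp

def pullback_metric(F, z):
    """Pullback of the Euclidean metric through F: G(z) = J(z)^T J(z)."""
    J = jax.jacfwd(F)(z)   # Jacobian dF^a/dz^i, shape (n, d)
    return J.T @ J         # symmetric positive semidefinite, shape (d, d)

# Toy example: embed the plane in R^3 as a paraboloid.
F = lambda z: jnp.array([z[0], z[1], z[0] ** 2 + z[1] ** 2])
G = pullback_metric(F, jnp.array([1.0, 0.0]))
# At z = (1, 0) the Jacobian is [[1, 0], [0, 1], [2, 0]], so G = [[5, 0], [0, 1]].
```

The eigenvalues of `G` then quantify exactly the stretching and compression discussed below.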
For a normalizing flow trained via maximum likelihood, the map \(F_\theta\) is a diffeomorphism between the latent and data spaces. The pullback metric on the latent space then tells us how the flow distorts space. Regions where the metric has large eigenvalues correspond to directions in latent space that are stretched by the flow, while small eigenvalues correspond to compressed directions.
Connections and the covariant derivative
A metric alone lets us measure lengths, but to compare vectors at different points we need a connection. The Levi-Civita connection is the unique connection that is compatible with the metric (preserving inner products under parallel transport) and is torsion-free. It is specified by the Christoffel symbols:
\[\Gamma^k_{ij} = \frac{1}{2} g^{kl}\left(\frac{\partial g_{il}}{\partial x^j} + \frac{\partial g_{jl}}{\partial x^i} - \frac{\partial g_{ij}}{\partial x^l}\right)\]The covariant derivative \(\nabla_{E_i} V\) of a vector field \(V = V^j E_j\) along a direction \(E_i\) is:
\[\nabla_{E_i} V = \left(\frac{\partial V^k}{\partial x^i} + \Gamma^k_{ij} V^j\right) E_k\]There is a useful interpretation of the covariant derivative in terms of normalizing flows. Suppose \(F_\theta: \mathcal{Z} \to \mathcal{X}\) is a parameterized bijection and \(W_\theta\) is a vector field on \(\mathcal{Z}\). We can push \(W_\theta\) forward to \(\mathcal{X}\) using the Jacobian of \(F_\theta\), vary \(\theta\), and then pull the result back. The covariant derivative measures how \(W_\theta\) changes after accounting for the fact that the coordinate system itself (given by \(F_\theta\)) is also changing:
\[\nabla_{\frac{d}{d\theta}} W_\theta = (F_\theta^{-1})_* \frac{d}{d\theta} \left((F_\theta)_* W_\theta\right)\]In components, this becomes:
\[\left(\nabla_{\frac{d}{d\theta}} W_\theta\right)^i = \frac{\partial W^i_\theta}{\partial \theta} + \left(\frac{\partial x}{\partial z}\right)^{-1}_{ij} \frac{\partial^2 x^j_\theta}{\partial \theta \, \partial z^k} W^k_\theta\]The second term is the correction that accounts for the change of coordinates, and it is exactly the Christoffel symbol contribution.
Curvature
The Riemann curvature tensor measures the failure of parallel transport to commute. Given three vector fields \(E_i, E_j, E_k\), it is defined as:
\[R(E_i, E_j)E_k = \nabla_{E_i} \nabla_{E_j} E_k - \nabla_{E_j} \nabla_{E_i} E_k - \nabla_{[E_i, E_j]} E_k\]Expanding in a local frame where \(\nabla_{E_j} E_k = \Gamma^l_{jk} E_l\) and the Lie bracket is \([E_i, E_j] = c^l_{ij} E_l\), the components are:
\[{R_{ijk}}^m = E_i(\Gamma^m_{jk}) - E_j(\Gamma^m_{ik}) + \Gamma^l_{jk}\Gamma^m_{il} - \Gamma^l_{ik}\Gamma^m_{jl} - c^l_{ij}\Gamma^m_{lk}\]In a coordinate basis where the Lie brackets vanish, the last term drops out and we recover the standard textbook formula.
The Ricci tensor is the contraction \(R_{ij} = {R_{kij}}^k\), contracting the upper index against the first lower index (with this sign convention the round sphere has positive curvature), and the scalar curvature is \(R = g^{ij} R_{ij}\). These give progressively coarser summaries of the curvature.
Why curvature matters for representation learning
Curvature has direct implications for finding good coordinate systems. A fundamental question in representation learning is whether a learned representation can be decomposed into independent factors. In the geometric language, this asks whether the pullback metric can be diagonalized by a coordinate transformation.
A classical result in Riemannian geometry states that a metric can be brought to diagonal form by a coordinate change only if certain components of the Riemann curvature tensor vanish (roughly, the curvature components that mix the would-be factors must be zero). This places a hard geometric constraint on when independent factorization of a representation is possible. When the curvature obstruction is nonzero, no coordinate change can make the factors independent, and any factored representation will necessarily lose information.
Riemann normal coordinates
At any point \(p \in \mathcal{M}\), we can construct Riemann normal coordinates in which the metric is Euclidean to first order around \(p\): the metric equals the identity and the Christoffel symbols vanish at that point:
\[g_{ij}(p) = \delta_{ij}, \quad \Gamma^k_{ij}(p) = 0\]These coordinates are constructed using the exponential map, which sends tangent vectors at \(p\) to points on the manifold by following geodesics. The logarithmic map inverts this.
Normal coordinates are useful computationally because they simplify geometric expressions at the point of interest. In the local_coordinates library, this is implemented as a coordinate transformation that takes any Riemannian metric and produces the normal coordinate representation, allowing geometric computations to be performed in a frame where first-order metric effects vanish.
Geodesics
A geodesic is a curve that locally minimizes length. It satisfies the geodesic equation:
\[\ddot{x}^k + \Gamma^k_{ij} \dot{x}^i \dot{x}^j = 0\]The exponential map \(\exp_p(v)\) sends a tangent vector \(v \in T_p\mathcal{M}\) to the point reached by following the geodesic starting at \(p\) with initial velocity \(v\) for unit time. The logarithmic map \(\log_p(q)\) returns the initial velocity needed to reach \(q\) from \(p\) along a geodesic.
For metrics arising from normalizing flows, geodesics trace out the shortest paths through data space as measured by the learned geometry. Computing these geodesics requires the Christoffel symbols, which in turn require the metric and its derivatives. This is one of the motivations for building a library that can compute all of these quantities automatically using JAX's automatic differentiation.
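As a minimal sketch of that pipeline (not the library's actual API), here is a fixed-step Euler integrator for the geodesic equation on top of autodiff Christoffel symbols, which doubles as an approximate exponential map:

```python
import jax
import jax.numpy as jnp

def christoffel(metric, x):
    """Γ^k_{ij} from a metric function g: R^d -> (d, d), via autodiff."""
    g_inv = jnp.linalg.inv(metric(x))
    dg = jax.jacfwd(metric)(x)                   # dg[i, l, j] = ∂g_{il}/∂x^j
    A = jnp.transpose(dg, (0, 2, 1)) + jnp.transpose(dg, (1, 2, 0)) - dg
    return 0.5 * jnp.einsum('kl,ijl->kij', g_inv, A)

def exp_map(metric, p, v, steps=200):
    """Approximate exp_p(v): integrate ẍ^k = -Γ^k_{ij} ẋ^i ẋ^j for unit time."""
    dt = 1.0 / steps
    def step(carry, _):
        x, vel = carry
        a = -jnp.einsum('kij,i,j->k', christoffel(metric, x), vel, vel)
        return (x + dt * vel, vel + dt * a), None
    (x, _), _ = jax.lax.scan(step, (p, v), None, length=steps)
    return x

# Sanity check: with the Euclidean metric, geodesics are straight lines,
# so exp_p(v) should land at p + v.
euclidean = lambda x: jnp.eye(2)
q = exp_map(euclidean, jnp.array([0.0, 0.0]), jnp.array([1.0, -2.0]))
```

A higher-order integrator (e.g. RK4) would be preferable for curved metrics; Euler is used here only to keep the sketch short.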