local_coordinates

A JAX framework for differential geometry computations on Riemannian manifolds.

local_coordinates provides a type-safe, coordinate-aware system for performing differential geometry computations. It leverages JAX’s automatic differentiation to compute gradients and Hessians, enabling second-order geometric computations like curvature tensors and geodesics.
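A jet bundles the value, first derivatives, and second derivatives of a function at a single point. The sketch below shows what that triple looks like for the matrix-valued metric from the quick start. It builds the triple with jax.jacfwd and jax.hessian. SimpleJet and simple_function_to_jet are hypothetical stand-ins, not the library's Jet and function_to_jet, whose fields and index layout may differ. Here the derivative indices come last.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SimpleJet(NamedTuple):
    """Hypothetical stand-in for a second-order jet: f(p), Df(p), D^2 f(p)."""
    value: jax.Array
    gradient: jax.Array
    hessian: jax.Array


def simple_function_to_jet(f, p):
    # Forward-mode Jacobian: gradient[i, j, k] = d g_ij / d x_k (derivative index last).
    gradient = jax.jacfwd(f)(p)
    # jax.hessian accepts array-valued functions; both derivative indices come last.
    hessian = jax.hessian(f)(p)
    return SimpleJet(f(p), gradient, hessian)


def metric_components(x):
    # Same diagonal metric as in the quick start.
    return jnp.array([
        [1 + 0.1 * x[0] ** 2, 0.0],
        [0.0, 1 + 0.1 * x[1] ** 2],
    ])


p = jnp.array([1.0, 1.0])
jet = simple_function_to_jet(metric_components, p)
print(jet.value.shape, jet.gradient.shape, jet.hessian.shape)
```

For a 2x2 metric on a 2D chart, the gradient carries one extra index of size 2 and the Hessian carries two.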

Key features

  • Jets for second-order automatic differentiation, storing value, gradient, and Hessian
  • Riemannian metrics with inner products and index raising and lowering
  • Levi-Civita connection and Christoffel symbol computation
  • Curvature tensors including the Riemann curvature tensor, Ricci tensor, and scalar curvature
  • Coordinate transformations with pullback metrics under smooth maps
  • Riemann normal coordinates, in which the metric equals the identity at the base point and its first derivatives vanish there
  • Geodesics via exponential and logarithmic maps using Taylor approximation or ODE integration
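The Christoffel symbols of the Levi-Civita connection follow from the metric and its first derivatives, Γ^k_ij = ½ g^kl (∂_i g_jl + ∂_j g_il − ∂_l g_ij). The sketch below evaluates that formula directly with plain JAX. christoffel_symbols and polar_metric are hypothetical helpers, not the library's API; connection.py provides get_levi_civita_connection, whose index layout may differ. The example checks the result on the Euclidean plane in polar coordinates, where Γ^r_θθ = −r and Γ^θ_rθ = Γ^θ_θr = 1/r.

```python
import jax
import jax.numpy as jnp


def christoffel_symbols(metric_fn, x):
    """Gamma[k, i, j] = 1/2 g^{kl} (d_i g_jl + d_j g_il - d_l g_ij) at the point x."""
    g_inv = jnp.linalg.inv(metric_fn(x))
    dg = jax.jacfwd(metric_fn)(x)  # dg[a, b, c] = d_c g_ab
    term = (
        jnp.einsum("jli->lij", dg)    # d_i g_jl
        + jnp.einsum("ilj->lij", dg)  # d_j g_il
        - jnp.einsum("ijl->lij", dg)  # d_l g_ij
    )
    return 0.5 * jnp.einsum("kl,lij->kij", g_inv, term)


def polar_metric(x):
    # Euclidean plane in polar coordinates x = (r, theta): ds^2 = dr^2 + r^2 dtheta^2
    r = x[0]
    return jnp.array([[1.0, 0.0], [0.0, r ** 2]])


gamma = christoffel_symbols(polar_metric, jnp.array([2.0, 0.3]))
print(gamma[0, 1, 1], gamma[1, 0, 1])  # expected -r = -2.0 and 1/r = 0.5 at r = 2
```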

Quick start

import jax.numpy as jnp
from local_coordinates.jet import Jet, function_to_jet
from local_coordinates.basis import get_standard_basis
from local_coordinates.metric import RiemannianMetric
from local_coordinates.connection import get_levi_civita_connection
from local_coordinates.riemann import get_riemann_curvature_tensor, get_ricci_tensor

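# Metric components g_ij(x) of a diagonal 2D metric, as a function of the coordinates x
def metric_components(x):
    return jnp.array([
        [1 + 0.1*x[0]**2, 0.0],
        [0.0, 1 + 0.1*x[1]**2]
    ])

p = jnp.array([1.0, 1.0])                            # base point
basis = get_standard_basis(p)                        # coordinate basis of the tangent space at p
metric_jet = function_to_jet(metric_components, p)   # value, gradient, and Hessian of g at p
metric = RiemannianMetric(basis=basis, components=metric_jet)

connection = get_levi_civita_connection(metric)      # Levi-Civita connection (Christoffel symbols)
riemann = get_riemann_curvature_tensor(connection)
ricci = get_ricci_tensor(connection, R=riemann)      # pass the Riemann tensor computed above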
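The quick-start metric is separable: g_11 depends only on x^1 and g_22 only on x^2. The substitution u_i = ∫ sqrt(1 + 0.1 x_i^2) dx_i therefore turns it into the Euclidean metric, so its Riemann tensor should vanish. That makes it a convenient sanity check. The sketch below is an independent plain-JAX computation, not the library's implementation, and all helper names in it are hypothetical. It uses the convention R^ρ_σμν = ∂_μ Γ^ρ_νσ − ∂_ν Γ^ρ_μσ + Γ^ρ_μλ Γ^λ_νσ − Γ^ρ_νλ Γ^λ_μσ with R_σν = R^ρ_σρν, which may not match the index ordering in riemann.py. The unit sphere, whose Ricci tensor equals its metric and whose scalar curvature is 2, serves as a non-flat check.

```python
import jax
import jax.numpy as jnp


def christoffel_symbols(metric_fn, x):
    # Gamma[k, i, j] = 1/2 g^{kl} (d_i g_jl + d_j g_il - d_l g_ij)
    g_inv = jnp.linalg.inv(metric_fn(x))
    dg = jax.jacfwd(metric_fn)(x)  # dg[a, b, c] = d_c g_ab
    term = (jnp.einsum("jli->lij", dg) + jnp.einsum("ilj->lij", dg)
            - jnp.einsum("ijl->lij", dg))
    return 0.5 * jnp.einsum("kl,lij->kij", g_inv, term)


def curvature(metric_fn, x):
    """Riemann R[r, s, m, n] = R^r_{smn}, Ricci R_{sn} = R^r_{srn}, and scalar curvature."""
    gamma = christoffel_symbols(metric_fn, x)
    # d_gamma[r, a, b, c] = d_c Gamma^r_ab
    d_gamma = jax.jacfwd(lambda y: christoffel_symbols(metric_fn, y))(x)
    riemann = (
        jnp.einsum("rnsm->rsmn", d_gamma)            # d_m Gamma^r_ns
        - jnp.einsum("rmsn->rsmn", d_gamma)          # d_n Gamma^r_ms
        + jnp.einsum("rml,lns->rsmn", gamma, gamma)  # Gamma^r_ml Gamma^l_ns
        - jnp.einsum("rnl,lms->rsmn", gamma, gamma)  # Gamma^r_nl Gamma^l_ms
    )
    ricci = jnp.trace(riemann, axis1=0, axis2=2)  # contract the upper index r with m
    scalar = jnp.einsum("sn,sn->", jnp.linalg.inv(metric_fn(x)), ricci)
    return riemann, ricci, scalar


def quick_start_metric(x):
    return jnp.array([[1 + 0.1 * x[0] ** 2, 0.0], [0.0, 1 + 0.1 * x[1] ** 2]])


def sphere_metric(x):
    # Unit 2-sphere, x = (theta, phi): ds^2 = dtheta^2 + sin(theta)^2 dphi^2
    return jnp.array([[1.0, 0.0], [0.0, jnp.sin(x[0]) ** 2]])


R_flat, _, _ = curvature(quick_start_metric, jnp.array([1.0, 1.0]))
_, ric_sphere, s_sphere = curvature(sphere_metric, jnp.array([1.0, 0.5]))
print(jnp.max(jnp.abs(R_flat)), s_sphere)  # roughly 0 for the flat metric, roughly 2 for the sphere
```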

Architecture

local_coordinates/
├── jet.py              # Second-order Taylor data
├── jacobian.py         # Coordinate transformations
├── basis.py            # Tangent space bases
├── frame.py            # Bases with Lie brackets
├── tangent.py          # Tangent vectors
├── tensor.py           # Generic (k,l) tensors
├── metric.py           # Riemannian metrics
├── connection.py       # Christoffel symbols, covariant derivatives
├── riemann.py          # Curvature tensors
├── normal_coords.py    # Riemann normal coordinates
└── exponential_map.py  # Exponential and logarithmic maps
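Both approaches listed for geodesics, Taylor approximation and ODE integration, start from the geodesic equation ẍ^k = −Γ^k_ij ẋ^i ẋ^j. Truncating its Taylor series at second order gives exp_p(v) ≈ p + v − ½ Γ^k_ij(p) v^i v^j. That is accurate only for short v, whereas integrating the ODE from t = 0 to t = 1 also handles longer vectors. The sketch below is a hypothetical plain-JAX illustration, not the contents of exponential_map.py, and it uses fixed-step RK4 for simplicity. exp_taylor, exp_ode, and the other helpers are illustrative names. Geodesics of the plane are straight lines, so in polar coordinates the geodesic from (r, θ) = (1, 0) with velocity (0, 1) ends at (√2, π/4).

```python
import jax
import jax.numpy as jnp


def christoffel_symbols(metric_fn, x):
    # Gamma[k, i, j] = 1/2 g^{kl} (d_i g_jl + d_j g_il - d_l g_ij)
    g_inv = jnp.linalg.inv(metric_fn(x))
    dg = jax.jacfwd(metric_fn)(x)  # dg[a, b, c] = d_c g_ab
    term = (jnp.einsum("jli->lij", dg) + jnp.einsum("ilj->lij", dg)
            - jnp.einsum("ijl->lij", dg))
    return 0.5 * jnp.einsum("kl,lij->kij", g_inv, term)


def exp_taylor(metric_fn, p, v):
    # Second-order Taylor expansion of the geodesic, evaluated at t = 1.
    gamma = christoffel_symbols(metric_fn, p)
    return p + v - 0.5 * jnp.einsum("kij,i,j->k", gamma, v, v)


def exp_ode(metric_fn, p, v, num_steps=200):
    # Integrate x'' = -Gamma(x)(x', x') from t = 0 to t = 1 with classical RK4.
    def rhs(state):
        x, xdot = state
        gamma = christoffel_symbols(metric_fn, x)
        return (xdot, -jnp.einsum("kij,i,j->k", gamma, xdot, xdot))

    h = 1.0 / num_steps

    def step(_, state):
        k1 = rhs(state)
        k2 = rhs(jax.tree_util.tree_map(lambda s, k: s + 0.5 * h * k, state, k1))
        k3 = rhs(jax.tree_util.tree_map(lambda s, k: s + 0.5 * h * k, state, k2))
        k4 = rhs(jax.tree_util.tree_map(lambda s, k: s + h * k, state, k3))
        return jax.tree_util.tree_map(
            lambda s, a, b, c, d: s + h / 6.0 * (a + 2 * b + 2 * c + d),
            state, k1, k2, k3, k4)

    x_final, _ = jax.lax.fori_loop(0, num_steps, step, (p, v))
    return x_final


def polar_metric(x):
    # Euclidean plane in polar coordinates x = (r, theta)
    return jnp.array([[1.0, 0.0], [0.0, x[0] ** 2]])


p = jnp.array([1.0, 0.0])
v = jnp.array([0.0, 1.0])
print(exp_ode(polar_metric, p, v))           # close to (sqrt(2), pi/4)
print(exp_taylor(polar_metric, p, 0.1 * v))  # close to (sqrt(1.01), arctan(0.1))
```

With the short vector 0.1 v, the Taylor step lands within about 3e-4 of the exact endpoint. With v itself, it returns (1.5, 1.0), far from (√2, π/4). This is why an ODE solver is the safer choice for long tangent vectors.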