Couplings
generax.distributions.coupling.JointCoupling
¤
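Given two batches of samples from two distributions, this computes a discrete joint distribution over pairs of samples, as done in multisample flow matching (https://arxiv.org/pdf/2304.14772.pdf): $$q(x_0,x_1) = \sum_{i,j}\pi_{i,j}\delta(x_0 - x_0^i)\delta(x_1 - x_1^j)$$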
Source code in generax/distributions/coupling.py
class JointCoupling(eqx.Module, ABC):
    r"""Given two batches of samples from two distributions, this
    will compute a discrete distribution as done in [multisample flow matching](https://arxiv.org/pdf/2304.14772.pdf)

    $$q(x_0,x_1) = \sum_{i,j}\pi_{i,j}\delta(x_0 - x_0^i)\delta(x_1 - x_1^j)$$

    Subclasses implement `compute_logits`, which must return a
    `(batch_size, batch_size)` array whose `[i, j]` entry is the logit for
    pairing `x0[i]` with `x1[j]` (row index -> `x0`, column index -> `x1`).
    """

    batch_size: int  # taken from x0.shape[0]
    x0: Array
    x1: Array
    logits: Array  # (batch_size, batch_size) matrix produced by compute_logits

    def __init__(self, x0: Array, x1: Array):
        """Initialize the coupling

        **Arguments**:

        - x0: A batch of samples from p(x_0)
        - x1: A batch of samples from p(x_1)
        """
        self.batch_size = x0.shape[0]
        self.x0 = x0
        self.x1 = x1
        # Must run after the fields above are set: subclass implementations
        # read batch_size / x0 / x1.
        self.logits = self.compute_logits()
        # Subclasses must produce one logit per (x0[i], x1[j]) pair.
        assert self.logits.shape == (self.batch_size, self.batch_size)

    @abstractmethod
    def compute_logits(self):
        r"""Compute $\log \pi_{i,j}$"""
        pass

    def sample_x0_given_x1(self, rng_key: PRNGKeyArray) -> Array:
        """Resample from the coupling

        **Arguments**:

        - rng_key: The random number generator key

        **Returns**:
        A sample from q(x_0|x_1)
        """
        # axis=0: each column j of the logits is a separate categorical
        # distribution over the x0 (row) index, giving one index per x1.
        idx = jax.random.categorical(rng_key, self.logits, axis=0)
        return self.x0[idx]
__init__(self, x0: Array, x1: Array)
¤
Initialize the coupling
Arguments: x0 — a batch of samples from \(p(x_0)\); x1 — a batch of samples from \(p(x_1)\).
Source code in generax/distributions/coupling.py
def __init__(self, x0: Array, x1: Array):
    """Store both sample batches and build the coupling's logit matrix.

    **Arguments**:

    - x0: A batch of samples from p(x_0)
    - x1: A batch of samples from p(x_1)
    """
    self.x0 = x0
    self.x1 = x1
    n = x0.shape[0]
    self.batch_size = n
    # compute_logits implementations read the fields set above, so this
    # call has to come after them.
    self.logits = self.compute_logits()
    expected_shape = (n, n)
    assert self.logits.shape == expected_shape
compute_logits(self)
abstractmethod
¤
Compute \(\log \pi_{i,j}\)
Source code in generax/distributions/coupling.py
@abstractmethod
def compute_logits(self):
    r"""Compute $\log \pi_{i,j}$

    Subclasses return a `(batch_size, batch_size)` array whose `[i, j]`
    entry is the logit for pairing `x0[i]` with `x1[j]`.
    """
    pass
sample_x0_given_x1(self, rng_key: PRNGKeyArray) -> Array
¤
Resample from the coupling
Arguments: rng_key — the JAX random number generator key.
Returns: A sample from \(q(x_0 \mid x_1)\): one \(x_0\) drawn for each \(x_1\) in the batch.
Source code in generax/distributions/coupling.py
def sample_x0_given_x1(self, rng_key: PRNGKeyArray) -> Array:
    """Draw an x0 partner for every x1 in the batch from the coupling.

    **Arguments**:

    - rng_key: The random number generator key

    **Returns**:
    A sample from q(x_0|x_1): the j-th entry is an element of `self.x0`
    chosen with probabilities proportional to exp(self.logits[:, j]).
    """
    # Each column of the logit matrix is an independent categorical
    # distribution over the row (x0) index.
    row_choice = jax.random.categorical(key=rng_key, logits=self.logits, axis=0)
    return self.x0[row_choice]
generax.distributions.coupling.UniformCoupling (JointCoupling)
¤
This is a uniform coupling between two distributions: every pair of samples \((x_0^i, x_1^j)\) is equally likely.
Source code in generax/distributions/coupling.py
class UniformCoupling(JointCoupling):
    """This is a uniform coupling between two distributions.

    Every pair (x0[i], x1[j]) receives the same probability 1 / batch_size**2.
    """

    def compute_logits(self) -> Array:
        """Compute the logits for the coupling.

        Returns a (batch_size, batch_size) array filled with
        log(1 / batch_size**2): the log-probabilities of the uniform joint
        distribution, so exp(logits) sums to one, as expected of log pi_{i,j}.
        """
        log_prob = -2.0 * jnp.log(self.batch_size)
        return jnp.full((self.batch_size, self.batch_size), log_prob)

    def sample_x0_given_x1(self, rng_key: PRNGKeyArray) -> Array:
        """Return the x0 batch unchanged.

        Every x0 is equally likely for every x1 under this coupling, so the
        batch is used in its original order instead of being resampled;
        `rng_key` is not used.
        """
        return self.x0
__init__(self, x0: Array, x1: Array)
¤
Initialize the coupling
Arguments: x0 — a batch of samples from \(p(x_0)\); x1 — a batch of samples from \(p(x_1)\).
Source code in generax/distributions/coupling.py
def __init__(self, x0: Array, x1: Array):
    """Store both sample batches and build the coupling's logit matrix.

    **Arguments**:

    - x0: A batch of samples from p(x_0)
    - x1: A batch of samples from p(x_1)
    """
    self.x0 = x0
    self.x1 = x1
    n = x0.shape[0]
    self.batch_size = n
    # compute_logits implementations read the fields set above, so this
    # call has to come after them.
    self.logits = self.compute_logits()
    expected_shape = (n, n)
    assert self.logits.shape == expected_shape
compute_logits(self) -> Array
¤
Compute the logits for the coupling
Source code in generax/distributions/coupling.py
def compute_logits(self) -> Array:
    """Compute the logits for the coupling.

    Returns a (batch_size, batch_size) array filled with
    log(1 / batch_size**2): the log-probabilities of the uniform joint
    distribution, so exp(logits) sums to one, as expected of log pi_{i,j}.
    """
    log_prob = -2.0 * jnp.log(self.batch_size)
    return jnp.full((self.batch_size, self.batch_size), log_prob)
generax.distributions.coupling.OTTCoupling (JointCoupling)
¤
Optimal transport coupling using the ott library (https://ott-jax.readthedocs.io/en/latest/). This class uses the Sinkhorn solver to compute the optimal transport coupling.
Source code in generax/distributions/coupling.py
class OTTCoupling(JointCoupling):
    """Optimal transport coupling using the [ott library](https://ott-jax.readthedocs.io/en/latest/).
    This class uses the sinkhorn solver to compute the optimal transport coupling.
    """

    def compute_logits(self) -> Array:
        """Solve for the optimal transport couplings.

        Samples with more than one trailing dimension (e.g. images) or with
        none (scalars) are reshaped to `(num_samples, dim)` vectors first,
        since the point-cloud geometry operates on 2-D arrays.

        **Returns**:
        The elementwise log of the Sinkhorn transport matrix, with 1e-8
        added before the log so that zero entries give finite logits.
        """
        # PointCloud expects (num_points, dim) inputs; this is a no-op for
        # data that is already 2-D.
        x0_flat = self.x0.reshape(self.x0.shape[0], -1)
        x1_flat = self.x1.reshape(self.x1.shape[0], -1)
        # Create a point cloud object
        geom = pointcloud.PointCloud(x0_flat, x1_flat)
        # Define the loss function
        ot_prob = linear_problem.LinearProblem(geom)
        # Create a sinkhorn solver
        solver = sinkhorn.Sinkhorn()
        # Solve the OT problem
        ot = solver(ot_prob)
        # Return the coupling
        mat = ot.matrix
        return jnp.log(mat + 1e-8)
__init__(self, x0: Array, x1: Array)
¤
Initialize the coupling
Arguments: x0 — a batch of samples from \(p(x_0)\); x1 — a batch of samples from \(p(x_1)\).
Source code in generax/distributions/coupling.py
def __init__(self, x0: Array, x1: Array):
    """Store both sample batches and build the coupling's logit matrix.

    **Arguments**:

    - x0: A batch of samples from p(x_0)
    - x1: A batch of samples from p(x_1)
    """
    self.x0 = x0
    self.x1 = x1
    n = x0.shape[0]
    self.batch_size = n
    # compute_logits implementations read the fields set above, so this
    # call has to come after them.
    self.logits = self.compute_logits()
    expected_shape = (n, n)
    assert self.logits.shape == expected_shape
compute_logits(self) -> Array
¤
Solve for the optimal transport couplings
Source code in generax/distributions/coupling.py
def compute_logits(self) -> Array:
    """Solve for the optimal transport couplings.

    Samples with more than one trailing dimension (e.g. images) or with
    none (scalars) are reshaped to `(num_samples, dim)` vectors first,
    since the point-cloud geometry operates on 2-D arrays.

    **Returns**:
    The elementwise log of the Sinkhorn transport matrix, with 1e-8
    added before the log so that zero entries give finite logits.
    """
    # PointCloud expects (num_points, dim) inputs; this is a no-op for
    # data that is already 2-D.
    x0_flat = self.x0.reshape(self.x0.shape[0], -1)
    x1_flat = self.x1.reshape(self.x1.shape[0], -1)
    # Create a point cloud object
    geom = pointcloud.PointCloud(x0_flat, x1_flat)
    # Define the loss function
    ot_prob = linear_problem.LinearProblem(geom)
    # Create a sinkhorn solver
    solver = sinkhorn.Sinkhorn()
    # Solve the OT problem
    ot = solver(ot_prob)
    # Return the coupling
    mat = ot.matrix
    return jnp.log(mat + 1e-8)