Couplings

generax.distributions.coupling.JointCoupling

Given two batches of samples, one from each of two distributions, this class computes a discrete joint distribution over sample pairs, as done in multisample flow matching:

\[q(x_0,x_1) = \sum_{i,j}\pi_{i,j}\delta(x_0 - x_0^i)\delta(x_1 - x_1^j)\]
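
As a worked step, conditioning this joint distribution on a particular sample \(x_1^j\) normalizes column \(j\) of \(\pi\). This is the conditional that sample_x0_given_x1 draws from when it samples along axis 0 of the logits. The index \(k\) is only a summation variable introduced here for the normalizer:

\[q(x_0 \mid x_1 = x_1^j) = \sum_{i}\frac{\pi_{i,j}}{\sum_{k}\pi_{k,j}}\,\delta(x_0 - x_0^i)\]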
Source code in generax/distributions/coupling.py
class JointCoupling(eqx.Module, ABC):
  """Given two batches of samples from two distributions, this
  will compute a discrete distribution as done in [multisample flow matching](https://arxiv.org/pdf/2304.14772.pdf)

  $$q(x_0,x_1) = \sum_{i,j}\pi_{i,j}\delta(x_0 - x_0^i)\delta(x_1 - x_1^j)$$

  """
  batch_size: int
  x0: Array
  x1: Array
  logits: Array

  def __init__(self, x0: Array, x1: Array):
    """Initialize the coupling

    **Arguments**:
      - x0: A batch of samples from p(x_0)
      - x1: A batch of samples from p(x_1)
    """
    self.batch_size = x0.shape[0]
    self.x0 = x0
    self.x1 = x1
    self.logits = self.compute_logits()
    assert self.logits.shape == (self.batch_size, self.batch_size)

  @abstractmethod
  def compute_logits(self):
    """Compute $\log \pi_{i,j}$"""
    pass

  def sample_x0_given_x1(self, rng_key: PRNGKeyArray) -> Array:
    """Resample from the coupling

    **Arguments**:
      - rng_key: The random number generator key

    **Returns**:
      A sample from q(x_0|x_1)
    """
    idx = jax.random.categorical(rng_key, self.logits, axis=0)
    return self.x0[idx]
__init__(self, x0: Array, x1: Array)

Initialize the coupling.

Arguments:

- x0: A batch of samples from p(x_0)
- x1: A batch of samples from p(x_1)

Source code in generax/distributions/coupling.py
def __init__(self, x0: Array, x1: Array):
  """Initialize the coupling

  **Arguments**:
    - x0: A batch of samples from p(x_0)
    - x1: A batch of samples from p(x_1)
  """
  self.batch_size = x0.shape[0]
  self.x0 = x0
  self.x1 = x1
  self.logits = self.compute_logits()
  assert self.logits.shape == (self.batch_size, self.batch_size)
compute_logits(self) abstractmethod

Compute \(\log \pi_{i,j}\)

Source code in generax/distributions/coupling.py
@abstractmethod
def compute_logits(self):
  """Compute $\log \pi_{i,j}$"""
  pass
sample_x0_given_x1(self, rng_key: PRNGKeyArray) -> Array

Resample from the coupling.

Arguments:

- rng_key: The random number generator key

Returns:

A sample from q(x_0|x_1)

Source code in generax/distributions/coupling.py
def sample_x0_given_x1(self, rng_key: PRNGKeyArray) -> Array:
  """Resample from the coupling

  **Arguments**:
    - rng_key: The random number generator key

  **Returns**:
    A sample from q(x_0|x_1)
  """
  idx = jax.random.categorical(rng_key, self.logits, axis=0)
  return self.x0[idx]
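
The sampling step above can be sketched in isolation. This is a minimal illustration, not part of generax: the toy arrays, the permutation `perm`, and the value `-1e9` used for "impossible" pairs are all assumptions chosen so the result is easy to read. Rows of `logits` index x0 samples and columns index x1 samples, so `axis=0` draws one x0 index per x1 sample.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy batch: three scalar samples standing in for x0.
x0 = jnp.array([10.0, 20.0, 30.0])

# Hypothetical coupling: x1^0 pairs with x0^2, x1^1 with x0^0, x1^2 with x0^1.
perm = jnp.array([2, 0, 1])
probs = jnp.zeros((3, 3)).at[perm, jnp.arange(3)].set(1.0)

# Very negative logits for pairs the coupling rules out (assumed value).
logits = jnp.where(probs > 0, 0.0, -1e9)

# One categorical draw per column j, i.e. an index i ~ q(x0^i | x1^j).
idx = jax.random.categorical(jax.random.PRNGKey(0), logits, axis=0)
resampled_x0 = x0[idx]
```

Because each column here puts essentially all of its mass on one row, `idx` matches `perm` and `resampled_x0` is `x0` reordered to line up with the x1 batch. With a softer coupling the draw would be random per column.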

generax.distributions.coupling.UniformCoupling (JointCoupling)

A uniform coupling between two distributions: every pair of samples is equally likely, so sample_x0_given_x1 returns the x0 batch unchanged.

Source code in generax/distributions/coupling.py
class UniformCoupling(JointCoupling):
  """This is a uniform coupling between two distributions"""

  def compute_logits(self) -> Array:
    """Compute the logits for the coupling"""
    return jnp.ones((self.batch_size, self.batch_size))/self.batch_size

  def sample_x0_given_x1(self, rng_key: PRNGKeyArray) -> Array:
    return self.x0
compute_logits(self) -> Array

Compute the logits for the coupling

Source code in generax/distributions/coupling.py
def compute_logits(self) -> Array:
  """Compute the logits for the coupling"""
  return jnp.ones((self.batch_size, self.batch_size))/self.batch_size
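
The matrix returned above has every entry equal to 1/batch_size rather than a log-probability. A categorical distribution depends on its logits only up to an additive constant, so any constant matrix still encodes a uniform conditional. The short check below illustrates this, assuming only jax; `batch_size = 4` is an arbitrary illustrative value.

```python
import jax
import jax.numpy as jnp

batch_size = 4  # hypothetical batch size for illustration

# Same expression as in UniformCoupling.compute_logits.
logits = jnp.ones((batch_size, batch_size)) / batch_size

# Normalizing each column (axis 0 indexes x0) gives q(x0^i | x1^j).
cond_probs = jax.nn.softmax(logits, axis=0)
```

Every entry of `cond_probs` is 1/batch_size, which is the uniform conditional over the x0 batch.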

generax.distributions.coupling.OTTCoupling (JointCoupling)

Optimal transport coupling using the ott library. This class uses the library's Sinkhorn solver to compute an entropy-regularized optimal transport coupling between the two batches.

Source code in generax/distributions/coupling.py
class OTTCoupling(JointCoupling):
  """Optimal transport coupling using the [ott library](https://ott-jax.readthedocs.io/en/latest/).
  This class uses the sinkhorn solver to compute the optimal transport coupling.

  """
  def compute_logits(self) -> Array:
    """Solve for the optimal transport couplings"""

    # Create a point cloud object
    geom = pointcloud.PointCloud(self.x0, self.x1)

    # Define the loss function
    ot_prob = linear_problem.LinearProblem(geom)

    # Create a sinkhorn solver
    solver = sinkhorn.Sinkhorn()

    # Solve the OT problem
    ot = solver(ot_prob)

    # Return the coupling
    mat = ot.matrix
    return jnp.log(mat + 1e-8)
compute_logits(self) -> Array

Solve for the optimal transport couplings

Source code in generax/distributions/coupling.py
def compute_logits(self) -> Array:
  """Solve for the optimal transport couplings"""

  # Create a point cloud object
  geom = pointcloud.PointCloud(self.x0, self.x1)

  # Define the loss function
  ot_prob = linear_problem.LinearProblem(geom)

  # Create a sinkhorn solver
  solver = sinkhorn.Sinkhorn()

  # Solve the OT problem
  ot = solver(ot_prob)

  # Return the coupling
  mat = ot.matrix
  return jnp.log(mat + 1e-8)
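
To make the role of the solver concrete, here is a rough sketch of log-domain Sinkhorn iterations written directly in JAX. It is not the ott implementation and does not reproduce its defaults: the helper name `sinkhorn_log_coupling`, the squared Euclidean cost, the uniform marginals, and the values of `epsilon` and `num_iters` are all assumptions made for illustration.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_log_coupling(x0, x1, epsilon=0.5, num_iters=100):
    """Hypothetical helper: log of an entropic OT coupling between two batches."""
    n, m = x0.shape[0], x1.shape[0]
    # Assumed cost: squared Euclidean distance between every pair of samples.
    cost = jnp.sum((x0[:, None, :] - x1[None, :, :]) ** 2, axis=-1)
    # Assumed marginals: uniform weights on both batches.
    log_a = jnp.full((n,), -jnp.log(n))
    log_b = jnp.full((m,), -jnp.log(m))
    f = jnp.zeros(n)  # dual potential for x0
    g = jnp.zeros(m)  # dual potential for x1
    for _ in range(num_iters):
        # Rescale so the row marginals match a, then the column marginals match b.
        f = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
    # log pi_{i,j} = (f_i + g_j - C_{i,j}) / epsilon
    return (f[:, None] + g[None, :] - cost) / epsilon


# Toy data: x1 is a permuted, slightly shifted copy of x0.
x0 = jnp.array([[0.0, 0.0], [5.0, 0.0], [0.0, 5.0]])
perm = jnp.array([2, 0, 1])
x1 = x0[perm] + 0.1

log_coupling = sinkhorn_log_coupling(x0, x1)
best_match = jnp.argmax(log_coupling, axis=0)
```

On this well-separated toy data the coupling concentrates on the permutation, so `best_match` recovers `perm`. Working in the log domain also avoids taking the log of an exact zero, which is the job the `1e-8` offset does in the source above.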