Skip to content

Coupling¤

generax.flows.coupling.RavelParameters ¤

Flatten and concatenate the parameters of a eqx.Module

Source code in generax/flows/coupling.py
class RavelParameters(eqx.Module):
  """Flatten and concatenate the parameters of a eqx.Module
  """

  shapes_and_sizes: Sequence[Tuple[Tuple[int], int]] = eqx.field(static=True)
  flat_params_size: Tuple[int] = eqx.field(static=True)
  static: Any = eqx.field(static=True)
  treedef: Any = eqx.field(static=True)
  indices: np.ndarray = eqx.field(static=True)

  def __init__(self, module):

    # Split the parameters into dynamic and static
    params, self.static = eqx.partition(module, eqx.is_array)

    # Flatten the parameters so that we can extract its sizes
    leaves, self.treedef = jax.tree_util.tree_flatten(params)

    # Get the shape and size of each leaf
    self.shapes_and_sizes = [(leaf.shape, leaf.size) for leaf in leaves]

    # Flatten the parameters
    flat_params = jnp.concatenate([leaf.ravel() for leaf in leaves])

    # Keep track of the size of the flattened parameters
    self.flat_params_size = flat_params.size

    # Keep track of the split points for each paramter in the flattened array
    self.indices = np.cumsum(np.array([0] + [size for _, size in self.shapes_and_sizes]))

  def flatten_params(self, module: eqx.Module) -> Array:
    # Split the parameters into dynamic and static
    params, _ = eqx.partition(module, eqx.is_array)

    # Flatten the parameters so that we can extract its sizes
    leaves, _ = jax.tree_util.tree_flatten(params)

    # Flatten the parameters
    flat_params = jnp.concatenate([leaf.ravel() for leaf in leaves])
    return flat_params

  def __call__(self, flat_params: Array) -> eqx.Module:
    flat_params = flat_params.ravel() # Flatten the parameters completely
    leaves = []
    for i, (shape, size) in enumerate(self.shapes_and_sizes):

      # Extract each leaf from the flattened parameters and reshape it
      buffer = flat_params[self.indices[i]: self.indices[i + 1]]
      if buffer.size != misc.list_prod(shape):
        raise ValueError(f'Expected total size of {misc.list_prod(shape)} but got {buffer.size}')
      leaf = buffer.reshape(shape)
      leaves.append(leaf)

    # Turn the leaves back into a tree
    params = jax.tree_util.tree_unflatten(self.treedef, leaves)

    return eqx.combine(params, self.static)
__init__(self, module) ¤
Source code in generax/flows/coupling.py
def __init__(self, module):

  # Split the parameters into dynamic and static
  params, self.static = eqx.partition(module, eqx.is_array)

  # Flatten the parameters so that we can extract its sizes
  leaves, self.treedef = jax.tree_util.tree_flatten(params)

  # Get the shape and size of each leaf
  self.shapes_and_sizes = [(leaf.shape, leaf.size) for leaf in leaves]

  # Flatten the parameters
  flat_params = jnp.concatenate([leaf.ravel() for leaf in leaves])

  # Keep track of the size of the flattened parameters
  self.flat_params_size = flat_params.size

  # Keep track of the split points for each paramter in the flattened array
  self.indices = np.cumsum(np.array([0] + [size for _, size in self.shapes_and_sizes]))
__call__(self, flat_params: Array) -> Module ¤

Call self as a function.

Source code in generax/flows/coupling.py
def __call__(self, flat_params: Array) -> eqx.Module:
  flat_params = flat_params.ravel() # Flatten the parameters completely
  leaves = []
  for i, (shape, size) in enumerate(self.shapes_and_sizes):

    # Extract each leaf from the flattened parameters and reshape it
    buffer = flat_params[self.indices[i]: self.indices[i + 1]]
    if buffer.size != misc.list_prod(shape):
      raise ValueError(f'Expected total size of {misc.list_prod(shape)} but got {buffer.size}')
    leaf = buffer.reshape(shape)
    leaves.append(leaf)

  # Turn the leaves back into a tree
  params = jax.tree_util.tree_unflatten(self.treedef, leaves)

  return eqx.combine(params, self.static)

generax.flows.coupling.Coupling (BijectiveTransform) ¤

Parametrize a flow over half of the inputs using the other half. The conditioning network will be fixed

# Example of intended usage:

def initialize_scale(transform_input_shape, key):
  return ShiftScale(input_shape=transform_input_shape,
                    key=key,
                    **kwargs)

def initialize_network(net_input_shape, net_output_size, key):
  return ResNet(input_shape=net_input_shape,
                out_size=net_output_size,
                key=key,
                **kwargs)

layer = Coupling(transform_init=initialize_scale,
                 net_init=initialize_network,
                 input_shape=input_shape,
                 cond_shape=cond_shape,
                 key=key,
                 reverse_conditioning=True,
                 split_dim=1)

z, log_det = layer(x, y)

Attributes: - params_to_transform: A module that turns an array of parameters into an eqx.Module. - scale: A scalar that we'll use to start with small parameter values - net: The neural network to use.

Source code in generax/flows/coupling.py
class Coupling(BijectiveTransform):
  """Parametrize a flow over half of the inputs using the other half.
  The conditioning network will be fixed

  ```python
  # Example of intended usage:

  def initialize_scale(transform_input_shape, key):
    return ShiftScale(input_shape=transform_input_shape,
                      key=key,
                      **kwargs)

  def initialize_network(net_input_shape, net_output_size, key):
    return ResNet(input_shape=net_input_shape,
                  out_size=net_output_size,
                  key=key,
                  **kwargs)

  layer = Coupling(transform_init=initialize_scale,
                   net_init=initialize_network,
                   input_shape=input_shape,
                   cond_shape=cond_shape,
                   key=key,
                   reverse_conditioning=True,
                   split_dim=1)

  z, log_det = layer(x, y)
  ```

  **Attributes**:
  - `params_to_transform`: A module that turns an array of parameters into an eqx.Module.
  - `scale`: A scalar that we'll use to start with small parameter values
  - `net`: The neural network to use.
  """

  net: eqx.Module
  scale: Array
  params_to_transform: RavelParameters

  split_dim: Optional[int] = eqx.field(static=True)
  reverse_conditioning: bool = eqx.field(static=True)

  def __init__(self,
               transform_init: Callable[[Tuple[int]],BijectiveTransform],
               net_init: Callable[[Tuple[int],int],eqx.Module],
               input_shape: Tuple[int],
               cond_shape: Optional[Tuple[int]] = None,
               split_dim: Optional[int] = None,
               reverse_conditioning: Optional[bool] = False,
               *,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `transform`: The bijective transformation to use.
    - `net`: The neural network to generate the transform parameters.
    - `input_shape`: The shape of the input
    - `cond_shape`: The shape of the conditioning information
    - `split_dim`: The number of dimension to split the last axis on.  If `None`, defaults to `dim//2`.
    - `reverse_conditioning`: If `True`, condition on the first part of the input instead of the second part.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape,
                     cond_shape=cond_shape,
                     **kwargs)

    k1, k2 = random.split(key, 2)

    self.split_dim = split_dim if split_dim is not None else input_shape[-1]//2
    self.reverse_conditioning = reverse_conditioning

    # Get the shapes of the input to the transform and the network
    transform_input_shape, net_input_shape = self.get_split_shapes(input_shape)
    transform = transform_init(transform_input_shape, key=k1)

    net_output_size = self.get_net_output_shapes(input_shape, transform)
    net = net_init(net_input_shape, net_output_size, key=k2)

    self.net = net

    # Use this to turn an eqx module into an array and vice-versa
    self.params_to_transform = RavelParameters(transform)

    # Also initialize the parameters to be close to 0
    self.scale = random.normal(key, (1,))*0.01

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the layer based on the data.

    **Arguments**:

    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information
    - `key`: A `jax.random.PRNGKey` for initialization

    **Returns**:
    A new layer with the parameters initialized.
    """
    assert x.shape[1:] == self.input_shape, 'Only works on batched data'
    x1, x2 = self.split(x)
    net = self.net.data_dependent_init(x2, y=y, key=key)

    # Turn the new parameters into a new module
    get_net = lambda tree: tree.net
    updated_layer = eqx.tree_at(get_net, self, net)
    return updated_layer

  def split(self, x: Array) -> Tuple[Array, Array]:
    """Split the input into two halves."""
    x1, x2 = x[..., :self.split_dim], x[..., self.split_dim:]
    if self.reverse_conditioning:
      return x2, x1
    return x1, x2

  def combine(self, x1: Array, x2: Array) -> Array:
    """Combine the two halves of the input."""
    if self.reverse_conditioning:
      return jnp.concatenate([x2, x1], axis=-1)
    return jnp.concatenate([x1, x2], axis=-1)

  def get_split_shapes(self,
                       input_shape: Tuple[int]) -> Tuple[Tuple[int]]:
    x1_dim, x2_dim = self.split_dim, input_shape[-1] - self.split_dim
    x1_shape = input_shape[:-1] + (x1_dim,)
    x2_shape = input_shape[:-1] + (x2_dim,)
    if self.reverse_conditioning:
      return x2_shape, x1_shape
    return x1_shape, x2_shape

  def get_net_output_shapes(self,
                            input_shape: Tuple[int],
                            transform: BijectiveTransform) -> Tuple[Tuple[int],int]:
    """
    **Arguments**:
    - `input_shape`: The shape of the input
    - `transform`: The bijective transformation to use.

    **Returns**:
    - `net_output_size`: The size of the output of the neural network.  This is a single integer
                         because the network is expected to produce a single vector.
    """
    x1_shape, x2_shape = self.get_split_shapes(input_shape)
    if x1_shape != transform.input_shape:
      raise ValueError(f'The transform {transform} needs to have an input shape equal to {x1_shape}.  Use `get_input_shapes` to get this shape.')
    params_to_transform = RavelParameters(transform)
    net_output_size = params_to_transform.flat_params_size
    return net_output_size

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    (z, log_det)
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    # Split the input into two halves
    x1, x2 = self.split(x)
    params = self.net(x2, y=y, **kwargs)
    params *= self.scale
    assert params.size == self.params_to_transform.flat_params_size

    # Apply the transformation to x1 given x2
    transform = self.params_to_transform(params)
    z1, log_det = transform(x1, y=y, inverse=inverse, **kwargs)

    z = self.combine(z1, x2)
    return z, log_det
__init__(self, transform_init: Callable[[Tuple[int]], BijectiveTransform], net_init: Callable[[Tuple[int], int], equinox._module.Module], input_shape: Tuple[int], cond_shape: Optional[Tuple[int]] = None, split_dim: Optional[int] = None, reverse_conditioning: Optional[bool] = False, *, key: PRNGKeyArray, **kwargs) ¤

Arguments:

  • transform: The bijective transformation to use.
  • net: The neural network to generate the transform parameters.
  • input_shape: The shape of the input
  • cond_shape: The shape of the conditioning information
  • split_dim: The number of dimension to split the last axis on. If None, defaults to dim//2.
  • reverse_conditioning: If True, condition on the first part of the input instead of the second part.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/coupling.py
def __init__(self,
             transform_init: Callable[[Tuple[int]],BijectiveTransform],
             net_init: Callable[[Tuple[int],int],eqx.Module],
             input_shape: Tuple[int],
             cond_shape: Optional[Tuple[int]] = None,
             split_dim: Optional[int] = None,
             reverse_conditioning: Optional[bool] = False,
             *,
             key: PRNGKeyArray,
             **kwargs):
  """**Arguments**:

  - `transform`: The bijective transformation to use.
  - `net`: The neural network to generate the transform parameters.
  - `input_shape`: The shape of the input
  - `cond_shape`: The shape of the conditioning information
  - `split_dim`: The number of dimension to split the last axis on.  If `None`, defaults to `dim//2`.
  - `reverse_conditioning`: If `True`, condition on the first part of the input instead of the second part.
  - `key`: A `jax.random.PRNGKey` for initialization
  """
  super().__init__(input_shape=input_shape,
                   cond_shape=cond_shape,
                   **kwargs)

  k1, k2 = random.split(key, 2)

  self.split_dim = split_dim if split_dim is not None else input_shape[-1]//2
  self.reverse_conditioning = reverse_conditioning

  # Get the shapes of the input to the transform and the network
  transform_input_shape, net_input_shape = self.get_split_shapes(input_shape)
  transform = transform_init(transform_input_shape, key=k1)

  net_output_size = self.get_net_output_shapes(input_shape, transform)
  net = net_init(net_input_shape, net_output_size, key=k2)

  self.net = net

  # Use this to turn an eqx module into an array and vice-versa
  self.params_to_transform = RavelParameters(transform)

  # Also initialize the parameters to be close to 0
  self.scale = random.normal(key, (1,))*0.01
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/coupling.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/flows/coupling.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
  """Initialize the parameters of the layer based on the data.

  **Arguments**:

  - `x`: The data to initialize the parameters with.
  - `y`: The conditioning information
  - `key`: A `jax.random.PRNGKey` for initialization

  **Returns**:
  A new layer with the parameters initialized.
  """
  assert x.shape[1:] == self.input_shape, 'Only works on batched data'
  x1, x2 = self.split(x)
  net = self.net.data_dependent_init(x2, y=y, key=key)

  # Turn the new parameters into a new module
  get_net = lambda tree: tree.net
  updated_layer = eqx.tree_at(get_net, self, net)
  return updated_layer
split(self, x: Array) -> Tuple[Array, Array] ¤

Split the input into two halves.

Source code in generax/flows/coupling.py
def split(self, x: Array) -> Tuple[Array, Array]:
  """Split the input into two halves."""
  x1, x2 = x[..., :self.split_dim], x[..., self.split_dim:]
  if self.reverse_conditioning:
    return x2, x1
  return x1, x2
get_split_shapes(self, input_shape: Tuple[int]) -> Tuple[Tuple[int]] ¤
Source code in generax/flows/coupling.py
def get_split_shapes(self,
                     input_shape: Tuple[int]) -> Tuple[Tuple[int]]:
  x1_dim, x2_dim = self.split_dim, input_shape[-1] - self.split_dim
  x1_shape = input_shape[:-1] + (x1_dim,)
  x2_shape = input_shape[:-1] + (x2_dim,)
  if self.reverse_conditioning:
    return x2_shape, x1_shape
  return x1_shape, x2_shape
get_net_output_shapes(self, input_shape: Tuple[int], transform: BijectiveTransform) -> Tuple[Tuple[int], int] ¤

Arguments: - input_shape: The shape of the input - transform: The bijective transformation to use.

Returns: - net_output_size: The size of the output of the neural network. This is a single integer because the network is expected to produce a single vector.

Source code in generax/flows/coupling.py
def get_net_output_shapes(self,
                          input_shape: Tuple[int],
                          transform: BijectiveTransform) -> Tuple[Tuple[int],int]:
  """
  **Arguments**:
  - `input_shape`: The shape of the input
  - `transform`: The bijective transformation to use.

  **Returns**:
  - `net_output_size`: The size of the output of the neural network.  This is a single integer
                       because the network is expected to produce a single vector.
  """
  x1_shape, x2_shape = self.get_split_shapes(input_shape)
  if x1_shape != transform.input_shape:
    raise ValueError(f'The transform {transform} needs to have an input shape equal to {x1_shape}.  Use `get_input_shapes` to get this shape.')
  params_to_transform = RavelParameters(transform)
  net_output_size = params_to_transform.flat_params_size
  return net_output_size
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Implements generax.flows.base.BijectiveTransform.__call__.

Source code in generax/flows/coupling.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to inverse the transformation

  **Returns**:
  (z, log_det)
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  # Split the input into two halves
  x1, x2 = self.split(x)
  params = self.net(x2, y=y, **kwargs)
  params *= self.scale
  assert params.size == self.params_to_transform.flat_params_size

  # Apply the transformation to x1 given x2
  transform = self.params_to_transform(params)
  z1, log_det = transform(x1, y=y, inverse=inverse, **kwargs)

  z = self.combine(z1, x2)
  return z, log_det

generax.flows.coupling.TimeDependentCoupling (Coupling, TimeDependentBijectiveTransform) ¤

Time dependent coupling transform. At t=0, this will pass parameters of 0s to the transform. ```

Attributes: - params_to_transform: A module that turns an array of parameters into an eqx.Module. - scale: A scalar that we'll use to start with small parameter values - net: The neural network to use.

Source code in generax/flows/coupling.py
class TimeDependentCoupling(Coupling, TimeDependentBijectiveTransform):
  """Time dependent coupling transform.  At t=0, this will pass parameters of 0s
  to the transform.
  ```

  **Attributes**:
  - `params_to_transform`: A module that turns an array of parameters into an eqx.Module.
  - `scale`: A scalar that we'll use to start with small parameter values
  - `net`: The neural network to use.
  """

  def data_dependent_init(self,
                          t: Array,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the layer based on the data.

    **Arguments**:

    - `t`: The time to initialize the parameters with.
    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information
    - `key`: A `jax.random.PRNGKey` for initialization

    **Returns**:
    A new layer with the parameters initialized.
    """
    assert x.shape[1:] == self.input_shape, 'Only works on batched data'
    x1, x2 = self.split(x)
    net = self.net.data_dependent_init(t, x2, y=y, key=key)

    # Turn the new parameters into a new module
    def get_net(tree): return tree.net
    updated_layer = eqx.tree_at(get_net, self, net)
    return updated_layer

  def __call__(self,
               t: Array,
               xt: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """**Arguments**:

    - `xt`: The input to the transformation.  If inverse=True, then should be x0
    - `y`: The conditioning information
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    (x0, log_det)
    """
    assert xt.shape == self.input_shape, 'Only works on unbatched data'

    # Split the input into two halves
    x1, x2 = self.split(xt)
    params = self.net(t, x2, y=y, **kwargs)
    params *= self.scale*t
    assert params.size == self.params_to_transform.flat_params_size

    # Apply the transformation to x1 given x2
    transform = self.params_to_transform(params)
    z1, log_det = transform(x1, y=y, inverse=inverse, **kwargs)

    z = self.combine(z1, x2)
    return z, log_det
__init__(self, transform_init: Callable[[Tuple[int]], BijectiveTransform], net_init: Callable[[Tuple[int], int], equinox._module.Module], input_shape: Tuple[int], cond_shape: Optional[Tuple[int]] = None, split_dim: Optional[int] = None, reverse_conditioning: Optional[bool] = False, *, key: PRNGKeyArray, **kwargs) ¤

Arguments:

  • transform: The bijective transformation to use.
  • net: The neural network to generate the transform parameters.
  • input_shape: The shape of the input
  • cond_shape: The shape of the conditioning information
  • split_dim: The number of dimension to split the last axis on. If None, defaults to dim//2.
  • reverse_conditioning: If True, condition on the first part of the input instead of the second part.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/coupling.py
def __init__(self,
             transform_init: Callable[[Tuple[int]],BijectiveTransform],
             net_init: Callable[[Tuple[int],int],eqx.Module],
             input_shape: Tuple[int],
             cond_shape: Optional[Tuple[int]] = None,
             split_dim: Optional[int] = None,
             reverse_conditioning: Optional[bool] = False,
             *,
             key: PRNGKeyArray,
             **kwargs):
  """**Arguments**:

  - `transform`: The bijective transformation to use.
  - `net`: The neural network to generate the transform parameters.
  - `input_shape`: The shape of the input
  - `cond_shape`: The shape of the conditioning information
  - `split_dim`: The number of dimension to split the last axis on.  If `None`, defaults to `dim//2`.
  - `reverse_conditioning`: If `True`, condition on the first part of the input instead of the second part.
  - `key`: A `jax.random.PRNGKey` for initialization
  """
  super().__init__(input_shape=input_shape,
                   cond_shape=cond_shape,
                   **kwargs)

  k1, k2 = random.split(key, 2)

  self.split_dim = split_dim if split_dim is not None else input_shape[-1]//2
  self.reverse_conditioning = reverse_conditioning

  # Get the shapes of the input to the transform and the network
  transform_input_shape, net_input_shape = self.get_split_shapes(input_shape)
  transform = transform_init(transform_input_shape, key=k1)

  net_output_size = self.get_net_output_shapes(input_shape, transform)
  net = net_init(net_input_shape, net_output_size, key=k2)

  self.net = net

  # Use this to turn an eqx module into an array and vice-versa
  self.params_to_transform = RavelParameters(transform)

  # Also initialize the parameters to be close to 0
  self.scale = random.normal(key, (1,))*0.01
split(self, x: Array) -> Tuple[Array, Array] ¤

Inherited from generax.flows.coupling.Coupling.split.

Source code in generax/flows/coupling.py
def split(self, x: Array) -> Tuple[Array, Array]:
  """Split the input into two halves."""
  x1, x2 = x[..., :self.split_dim], x[..., self.split_dim:]
  if self.reverse_conditioning:
    return x2, x1
  return x1, x2
get_split_shapes(self, input_shape: Tuple[int]) -> Tuple[Tuple[int]] ¤

Inherited from generax.flows.coupling.Coupling.get_split_shapes.

Source code in generax/flows/coupling.py
def get_split_shapes(self,
                     input_shape: Tuple[int]) -> Tuple[Tuple[int]]:
  x1_dim, x2_dim = self.split_dim, input_shape[-1] - self.split_dim
  x1_shape = input_shape[:-1] + (x1_dim,)
  x2_shape = input_shape[:-1] + (x2_dim,)
  if self.reverse_conditioning:
    return x2_shape, x1_shape
  return x1_shape, x2_shape
get_net_output_shapes(self, input_shape: Tuple[int], transform: BijectiveTransform) -> Tuple[Tuple[int], int] ¤

Inherited from generax.flows.coupling.Coupling.get_net_output_shapes.

Source code in generax/flows/coupling.py
def get_net_output_shapes(self,
                          input_shape: Tuple[int],
                          transform: BijectiveTransform) -> Tuple[Tuple[int],int]:
  """
  **Arguments**:
  - `input_shape`: The shape of the input
  - `transform`: The bijective transformation to use.

  **Returns**:
  - `net_output_size`: The size of the output of the neural network.  This is a single integer
                       because the network is expected to produce a single vector.
  """
  x1_shape, x2_shape = self.get_split_shapes(input_shape)
  if x1_shape != transform.input_shape:
    raise ValueError(f'The transform {transform} needs to have an input shape equal to {x1_shape}.  Use `get_input_shapes` to get this shape.')
  params_to_transform = RavelParameters(transform)
  net_output_size = params_to_transform.flat_params_size
  return net_output_size
data_dependent_init(self, t: Array, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • t: The time to initialize the parameters with.
  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/flows/coupling.py
def data_dependent_init(self,
                        t: Array,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
  """Initialize the parameters of the layer based on the data.

  **Arguments**:

  - `t`: The time to initialize the parameters with.
  - `x`: The data to initialize the parameters with.
  - `y`: The conditioning information
  - `key`: A `jax.random.PRNGKey` for initialization

  **Returns**:
  A new layer with the parameters initialized.
  """
  assert x.shape[1:] == self.input_shape, 'Only works on batched data'
  x1, x2 = self.split(x)
  net = self.net.data_dependent_init(t, x2, y=y, key=key)

  # Turn the new parameters into a new module
  def get_net(tree): return tree.net
  updated_layer = eqx.tree_at(get_net, self, net)
  return updated_layer
__call__(self, t: Array, xt: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Implements generax.flows.base.TimeDependentBijectiveTransform.__call__.

Source code in generax/flows/coupling.py
def __call__(self,
             t: Array,
             xt: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """**Arguments**:

  - `xt`: The input to the transformation.  If inverse=True, then should be x0
  - `y`: The conditioning information
  - `inverse`: Whether to inverse the transformation

  **Returns**:
  (x0, log_det)
  """
  assert xt.shape == self.input_shape, 'Only works on unbatched data'

  # Split the input into two halves
  x1, x2 = self.split(xt)
  params = self.net(t, x2, y=y, **kwargs)
  params *= self.scale*t
  assert params.size == self.params_to_transform.flat_params_size

  # Apply the transformation to x1 given x2
  transform = self.params_to_transform(params)
  z1, log_det = transform(x1, y=y, inverse=inverse, **kwargs)

  z = self.combine(z1, x2)
  return z, log_det
inverse(self, t: Array, x0: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.TimeDependentBijectiveTransform.inverse.

Source code in generax/flows/coupling.py
def inverse(self,
            t: Array,
            x0: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (xt, log_det)
  """
  return self(t, x0, y=y, inverse=True, **kwargs)
vector_field(self, t: Array, xt: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.TimeDependentBijectiveTransform.vector_field.

Source code in generax/flows/coupling.py
def vector_field(self,
                 t: Array,
                 xt: Array,
                 y: Optional[Array] = None,
                 **kwargs) -> Array:
  """The vector field that samples evolve on as t changes

  **Arguments**:

  - `t`: Time.
  - `xt`: A point in the data space.
  - `y`: The conditioning information.

  **Returns**:
  `return vt`
  """
  x0 = self.to_base_space(t, xt, y=y, **kwargs)
  def ft(t):
    return self.to_data_space(t, x0, y=y, **kwargs)
  return jax.jvp(ft, (t,), (jnp.ones_like(t),))[1]