Skip to content

Affine¤

generax.flows.affine.Shift (BijectiveTransform) ¤

This represents a shift transformation. This is NICE https://arxiv.org/pdf/1410.8516.pdf when used in a coupling layer.

Attributes: - b: The shift parameter.

Source code in generax/flows/affine.py
class Shift(BijectiveTransform):
  """This represents a shift transformation.
  This is NICE https://arxiv.org/pdf/1410.8516.pdf when used
  in a coupling layer.

  The forward map is `z = x - b` and the inverse map is `z = x + b`.
  The Jacobian is the identity, so the log determinant is always 0.

  **Attributes**:
  - `b`: The shift parameter.
  """
  b: Array

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape.  Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)

    # Initialize the parameters randomly (one standard-normal value per
    # input element).
    self.b = random.normal(key, shape=input_shape)

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the layer based on the data.

    **Arguments**:

    - `x`: The data to initialize the parameters with.  Must be batched,
           i.e. `x.shape[1:] == input_shape`.
    - `y`: The conditioning information (unused)
    - `key`: A `jax.random.PRNGKey` for initialization (unused)

    **Returns**:
    A new layer whose `b` is the batch mean of `x`.
    """
    assert x.shape[1:] == self.input_shape, 'x must be batched'
    # A pure shift can only center the data, so the std is not needed.
    mean, _ = misc.mean_and_std(x, axis=0)

    # With b equal to the batch mean, z = x - b has zero mean over the
    # batch.  A shift alone cannot also give unit variance.
    return eqx.tree_at(lambda tree: tree.b, self, mean)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The input to the transformation (unbatched)
    - `y`: The conditioning information (unused)
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    `(z, log_det)`.  `log_det` is always 0.
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    if inverse == False:
      z = x - self.b
    else:
      z = x + self.b

    # Identity Jacobian -> zero log determinant.
    log_det = jnp.array(0.0)

    return z, log_det
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray, **kwargs) ¤

Arguments:

  • input_shape: The input shape. Output size is the same as shape.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/affine.py
def __init__(self,
             input_shape: Tuple[int],
             key: PRNGKeyArray,
             **kwargs):
  """Create the shift layer with a randomly drawn shift.

  **Arguments**:

  - `input_shape`: The input shape.  Output size is the same as shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  """
  super().__init__(input_shape=input_shape, **kwargs)

  # One standard-normal draw per input element.
  initial_shift = random.normal(key, shape=input_shape)
  self.b = initial_shift
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/flows/affine.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
  """Initialize the parameters of the layer based on the data.

  **Arguments**:

  - `x`: The data to initialize the parameters with.  Must be batched,
         i.e. `x.shape[1:] == input_shape`.
  - `y`: The conditioning information (unused)
  - `key`: A `jax.random.PRNGKey` for initialization (unused)

  **Returns**:
  A new layer whose `b` is the batch mean of `x`.
  """
  assert x.shape[1:] == self.input_shape, 'x must be batched'
  # A pure shift can only center the data, so the std is not needed.
  mean, _ = misc.mean_and_std(x, axis=0)

  # With b equal to the batch mean, z = x - b has zero mean over the
  # batch.  A shift alone cannot also give unit variance.
  return eqx.tree_at(lambda tree: tree.b, self, mean)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to inverse the transformation

Returns: (z, log_det)

Source code in generax/flows/affine.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """Shift the (unbatched) input by `b`.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information (unused)
  - `inverse`: Whether to inverse the transformation

  **Returns**:
  `(z, log_det)` with `z = x - b` going forward and `z = x + b` in
  reverse.  `log_det` is always 0 because the Jacobian is the identity.
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  forward = inverse == False
  shifted = x - self.b if forward else x + self.b
  return shifted, jnp.array(0.0)
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/affine.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Run the transformation backwards.

  This is shorthand for calling the layer with `inverse=True`.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, inverse=True, y=y, **kwargs)

generax.flows.affine.ShiftScale (BijectiveTransform) ¤

This represents a shift and scale transformation. This is RealNVP https://arxiv.org/pdf/1605.08803.pdf when used in a coupling layer.

Attributes: - s_unbounded: The unbounded scaling parameter. - b: The shift parameter.

Source code in generax/flows/affine.py
class ShiftScale(BijectiveTransform):
  """This represents a shift and scale transformation.
  This is RealNVP https://arxiv.org/pdf/1605.08803.pdf when used
  in a coupling layer.

  Forward: `z = (x - b)/s`.  Inverse: `z = x*s + b`.  Here
  `s = misc.square_plus(s_unbounded, gamma=1.0)` is the scale, which is
  intended to be strictly positive.

  **Attributes**:
  - `s_unbounded`: The unbounded scaling parameter.
  - `b`: The shift parameter.
  """

  s_unbounded: Array
  b: Array

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape.  Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape, **kwargs)

    # A single normal draw of shape (2, *input_shape).  Slice 0 becomes
    # s_unbounded and slice 1 becomes b.
    draws = random.normal(key, shape=(2,) + input_shape)
    self.s_unbounded = draws[0]
    self.b = draws[1]

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> BijectiveTransform:
    """Set `b` and `s_unbounded` from batch statistics of `x`.

    **Arguments**:

    - `x`: The data to initialize the parameters with (batched).
    - `y`: The conditioning information (unused)
    - `key`: A `jax.random.PRNGKey` for initialization (unused)

    **Returns**:
    A new layer with the parameters initialized.
    """
    assert x.shape[1:] == self.input_shape, 'x must be batched'
    mean, std = misc.mean_and_std(x, axis=0)
    std = std + 1e-4  # keep the target scale away from zero

    # Aim for z = (x - b)/s to be roughly standardized: b = mean, and
    # s_unbounded = std - 1/std is meant to make square_plus(., gamma=1.0)
    # return std.  NOTE(review): confirm against misc.square_plus.
    with_shift = eqx.tree_at(lambda tree: tree.b, self, mean)
    return eqx.tree_at(lambda tree: tree.s_unbounded, with_shift, std - 1/std)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """Apply the elementwise affine map or its inverse.

    **Arguments**:

    - `x`: The input to the transformation (unbatched)
    - `y`: The conditioning information (unused)
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    `(z, log_det)` with `log_det = -sum(log s)` forward and
    `+sum(log s)` in reverse.
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    # s must be strictly positive
    scale = misc.square_plus(self.s_unbounded, gamma=1.0)
    sum_log_scale = jnp.log(scale).sum()

    if inverse == False:
      return (x - self.b)/scale, -sum_log_scale
    return x*scale + self.b, sum_log_scale
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray, **kwargs) ¤

Arguments:

  • input_shape: The input shape. Output size is the same as shape.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/affine.py
def __init__(self,
             input_shape: Tuple[int],
             key: PRNGKeyArray,
             **kwargs):
  """Create the layer with random `s_unbounded` and `b`.

  **Arguments**:

  - `input_shape`: The input shape.  Output size is the same as shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  """
  super().__init__(input_shape=input_shape, **kwargs)

  # Draw both parameters at once: slice 0 -> s_unbounded, slice 1 -> b.
  draws = random.normal(key, shape=(2,) + input_shape)
  self.s_unbounded = draws[0]
  self.b = draws[1]
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/affine.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the layer in the reverse direction.

  Forwards everything to `self(..., inverse=True)`.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, inverse=True, y=y, **kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/flows/affine.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
  """Set `b` and `s_unbounded` from batch statistics of `x`.

  **Arguments**:

  - `x`: The data to initialize the parameters with (batched).
  - `y`: The conditioning information (unused)
  - `key`: A `jax.random.PRNGKey` for initialization (unused)

  **Returns**:
  A new layer with the parameters initialized.
  """
  assert x.shape[1:] == self.input_shape, 'x must be batched'
  mean, std = misc.mean_and_std(x, axis=0)
  std = std + 1e-4  # keep the target scale away from zero

  # Aim for z = (x - b)/s to be roughly standardized: b = mean, and
  # s_unbounded = std - 1/std is meant to make square_plus(., gamma=1.0)
  # return std.  NOTE(review): confirm against misc.square_plus.
  with_shift = eqx.tree_at(lambda tree: tree.b, self, mean)
  return eqx.tree_at(lambda tree: tree.s_unbounded, with_shift, std - 1/std)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to inverse the transformation

Returns: (z, log_det)

Source code in generax/flows/affine.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """Apply the elementwise affine map or its inverse.

  **Arguments**:

  - `x`: The input to the transformation (unbatched)
  - `y`: The conditioning information (unused)
  - `inverse`: Whether to inverse the transformation

  **Returns**:
  `(z, log_det)`.  Forward gives `z = (x - b)/s` with `log_det = -sum(log s)`.
  Reverse gives `z = x*s + b` with `log_det = +sum(log s)`.
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  # s must be strictly positive
  scale = misc.square_plus(self.s_unbounded, gamma=1.0)
  sum_log_scale = jnp.log(scale).sum()

  if inverse == False:
    return (x - self.b)/scale, -sum_log_scale
  return x*scale + self.b, sum_log_scale

generax.flows.affine.DenseLinear (BijectiveTransform) ¤

Multiply the last axis by a dense matrix. When applied to images, this is GLOW https://arxiv.org/pdf/1807.03039.pdf

Attributes: - W: The weight matrix

Source code in generax/flows/affine.py
class DenseLinear(BijectiveTransform):
  """Multiply the last axis by a dense matrix.  When applied to images,
  this is GLOW https://arxiv.org/pdf/1807.03039.pdf

  **Attributes**:
  - `W`: The weight matrix (shape `(dim, dim)` with `dim = input_shape[-1]`)
  """

  W: Array

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape.  Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape, **kwargs)

    # Random normal matrix, post-processed by misc.whiten
    # (NOTE(review): presumably to make it well conditioned -- confirm).
    n = self.input_shape[-1]
    self.W = misc.whiten(random.normal(key, shape=(n, n)))

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Array:
    """Multiply every vector along the last axis by `W` (or `W^{-1}`).

    **Arguments**:

    - `x`: The input to the transformation (unbatched)
    - `y`: The conditioning information (unused)
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    (z, log_det)
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    matrix = self.W if inverse == False else jnp.linalg.inv(self.W)
    z = jnp.einsum('ij,...j->...i', matrix, x)

    # The same matrix acts at every position of the leading axes, so its
    # log|det| is counted once per position.
    repeats = np.prod(self.input_shape[:-1]) if len(self.input_shape) > 1 else 1
    log_det = jnp.linalg.slogdet(self.W)[1]*repeats

    if inverse:
      log_det = -log_det
    return z, log_det
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) ¤

Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.

Source code in generax/flows/affine.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Data-dependent initialization that leaves the layer unchanged.

  None of the arguments are used; the layer itself is returned.

  **Arguments**:

  - `x`: The data to initialize the parameters with.
  - `y`: The conditioning information
  - `key`: A `jax.random.PRNGKey` for initialization

  **Returns**:
  The same layer object, unmodified.
  """
  unchanged_layer = self
  return unchanged_layer
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/affine.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Evaluate the inverse transformation.

  Equivalent to `self(x, y=y, inverse=True, **kwargs)`.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, inverse=True, y=y, **kwargs)
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray, **kwargs) ¤

Arguments:

  • input_shape: The input shape. Output size is the same as shape.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/affine.py
def __init__(self,
             input_shape: Tuple[int],
             key: PRNGKeyArray,
             **kwargs):
  """Create the layer with a random `(dim, dim)` weight matrix.

  **Arguments**:

  - `input_shape`: The input shape.  Output size is the same as shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  """
  super().__init__(input_shape=input_shape, **kwargs)

  # Random normal matrix, post-processed by misc.whiten
  # (NOTE(review): presumably to make it well conditioned -- confirm).
  n = self.input_shape[-1]
  self.W = misc.whiten(random.normal(key, shape=(n, n)))
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Implements generax.flows.base.BijectiveTransform.__call__.

Source code in generax/flows/affine.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """Multiply every vector along the last axis by `W` (or `W^{-1}`).

  **Arguments**:

  - `x`: The input to the transformation (unbatched)
  - `y`: The conditioning information (unused)
  - `inverse`: Whether to inverse the transformation

  **Returns**:
  (z, log_det)
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  matrix = self.W if inverse == False else jnp.linalg.inv(self.W)
  z = jnp.einsum('ij,...j->...i', matrix, x)

  # The same matrix acts at every position of the leading axes, so its
  # log|det| is counted once per position.
  repeats = np.prod(self.input_shape[:-1]) if len(self.input_shape) > 1 else 1
  log_det = jnp.linalg.slogdet(self.W)[1]*repeats

  if inverse:
    log_det = -log_det
  return z, log_det

generax.flows.affine.DenseAffine (BijectiveTransform) ¤

Add a bias and then multiply the last axis by a dense matrix. When applied to images, this is GLOW https://arxiv.org/pdf/1807.03039.pdf

Attributes: - W: The weight matrix - b: The bias vector

Source code in generax/flows/affine.py
class DenseAffine(BijectiveTransform):
  """Add a bias, then multiply the last axis by a dense matrix.  When
  applied to images, this is GLOW https://arxiv.org/pdf/1807.03039.pdf

  Forward: `z = W (x + b)`.  Inverse: `z = W^{-1} x - b`.  The log
  determinant comes entirely from the wrapped `DenseLinear`.

  **Attributes**:
  - `W`: The weight matrix (a `DenseLinear` layer)
  - `b`: The bias vector
  """

  W: DenseLinear
  b: Array

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape.  Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape, **kwargs)

    # The linear part is delegated; the bias starts at zero.
    linear = DenseLinear(input_shape=input_shape, key=key, **kwargs)
    self.W = linear
    self.b = jnp.zeros(input_shape)

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> BijectiveTransform:
    """Center the data: set `b` to minus the batch mean of `x`.

    **Arguments**:

    - `x`: The data to initialize the parameters with (batched).
    - `y`: The conditioning information (unused)
    - `key`: A `jax.random.PRNGKey` for initialization (unused)

    **Returns**:
    A new layer with the parameters initialized.
    """
    assert x.shape[1:] == self.input_shape, 'x must be batched'
    # x + b then has zero batch mean before W is applied.
    negative_mean = -jnp.mean(x, axis=0)
    return eqx.tree_at(lambda tree: tree.b, self, negative_mean)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """Apply `z = W (x + b)` or its inverse.

    **Arguments**:

    - `x`: The input to the transformation (unbatched)
    - `y`: The conditioning information (forwarded to `W`)
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    (z, log_det)
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    forward = inverse == False
    if not forward:
      # Undo the linear map first, then remove the bias.
      unmixed, log_det = self.W(x, y=y, inverse=True)
      return unmixed - self.b, log_det

    # Add the bias first, then apply the linear map.
    z, log_det = self.W(x + self.b, y=y, inverse=False)
    return z, log_det
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/affine.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the layer backwards (calls the layer with `inverse=True`).

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, inverse=True, y=y, **kwargs)
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray, **kwargs) ¤

Arguments:

  • input_shape: The input shape. Output size is the same as shape.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/affine.py
def __init__(self,
             input_shape: Tuple[int],
             key: PRNGKeyArray,
             **kwargs):
  """Create the layer: a `DenseLinear` for `W` and a zero bias.

  **Arguments**:

  - `input_shape`: The input shape.  Output size is the same as shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  """
  super().__init__(input_shape=input_shape, **kwargs)

  # The linear part is delegated; the bias starts at zero.
  linear = DenseLinear(input_shape=input_shape, key=key, **kwargs)
  self.W = linear
  self.b = jnp.zeros(input_shape)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/flows/affine.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
  """Center the data: set `b` to minus the batch mean of `x`.

  **Arguments**:

  - `x`: The data to initialize the parameters with (batched).
  - `y`: The conditioning information (unused)
  - `key`: A `jax.random.PRNGKey` for initialization (unused)

  **Returns**:
  A new layer with the parameters initialized.
  """
  assert x.shape[1:] == self.input_shape, 'x must be batched'
  # x + b then has zero batch mean before W is applied.
  negative_mean = -jnp.mean(x, axis=0)
  return eqx.tree_at(lambda tree: tree.b, self, negative_mean)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Implements generax.flows.base.BijectiveTransform.__call__.

Source code in generax/flows/affine.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """Apply `z = W (x + b)` or its inverse `z = W^{-1} x - b`.

  **Arguments**:

  - `x`: The input to the transformation (unbatched)
  - `y`: The conditioning information (forwarded to `W`)
  - `inverse`: Whether to inverse the transformation

  **Returns**:
  (z, log_det)
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  forward = inverse == False
  if not forward:
    # Undo the linear map first, then remove the bias.
    unmixed, log_det = self.W(x, y=y, inverse=True)
    return unmixed - self.b, log_det

  # Add the bias first, then apply the linear map.
  z, log_det = self.W(x + self.b, y=y, inverse=False)
  return z, log_det

generax.flows.affine.CaleyOrthogonalMVP (BijectiveTransform) ¤

Cayley transform parametrization of an orthogonal matrix. This performs a matrix vector product with an orthogonal matrix.

Attributes: - W: The unconstrained weight matrix (only its skew-symmetric part W - W^T is used) - b: The bias vector

Source code in generax/flows/affine.py
class CaleyOrthogonalMVP(BijectiveTransform):
  """Cayley transform parametrization of an orthogonal matrix.  This adds a
  bias and then performs a matrix vector product (along the last axis) with
  an orthogonal matrix.

  With the skew-symmetric matrix `A = W - W^T`, the matrix applied is
  `Q = (I - A)(I + A)^{-1}`.  `I + A` is always invertible for a
  skew-symmetric `A`, and `Q` is orthogonal, so the log determinant is 0.

  **Attributes**:
  - `W`: The unconstrained weight matrix; only `W - W^T` is used.
  - `b`: The bias vector
  """

  W: Array
  b: Array

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape.  Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)

    dim = self.input_shape[-1]
    self.W = random.normal(key, shape=(dim, dim))
    self.b = jnp.zeros(input_shape)

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the bias so that `x + b` is centered over the batch.

    **Arguments**:

    - `x`: The data to initialize the parameters with (batched).
    - `y`: The conditioning information (unused)
    - `key`: A `jax.random.PRNGKey` for initialization (unused)

    **Returns**:
    A new layer with `b` set to minus the batch mean of `x`.
    """
    assert x.shape[1:] == self.input_shape, 'x must be batched'
    b = -jnp.mean(x, axis=0)
    return eqx.tree_at(lambda tree: tree.b, self, b)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The input to the transformation (unbatched)
    - `y`: The conditioning information (unused)
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    `(z, log_det)` where `log_det` is a scalar 0.
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    A = self.W - self.W.T  # skew-symmetric
    eye = jnp.eye(self.input_shape[-1])

    # So that we can multiply with channel dim of images
    @partial(jnp.vectorize, signature='(i,j),(j)->(i)')
    def matmul(M, v):
      return M@v

    # `w` holds the intermediate vector.  It used to be named `y`, which
    # shadowed the conditioning argument.
    if inverse == False:
      # Out-of-place add so a caller's NumPy input is never mutated.
      x = x + self.b
      w = matmul(jnp.linalg.inv(eye + A), x)
      z = w - matmul(A, w)            # (I - A)(I + A)^{-1} x
    else:
      w = matmul(jnp.linalg.inv(eye - A), x)
      z = w + matmul(A, w)            # (I + A)(I - A)^{-1} x
      z = z - self.b

    # Q is orthogonal, so |det Q| = 1.  Return a scalar like the other
    # layers in this module, not a shape-(1,) array.
    log_det = jnp.array(0.0)
    return z, log_det
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/affine.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Compute the inverse transformation by calling the layer with
  `inverse=True`.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, inverse=True, y=y, **kwargs)
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray, **kwargs) ¤

Arguments:

  • input_shape: The input shape. Output size is the same as shape.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/affine.py
def __init__(self,
             input_shape: Tuple[int],
             key: PRNGKeyArray,
             **kwargs):
  """Create the layer with a random unconstrained `W` and a zero bias.

  **Arguments**:

  - `input_shape`: The input shape.  Output size is the same as shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  """
  super().__init__(input_shape=input_shape, **kwargs)

  n = self.input_shape[-1]
  self.W = random.normal(key, shape=(n, n))
  self.b = jnp.zeros(input_shape)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/flows/affine.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
  """Set `b` to minus the batch mean of `x` so that `x + b` is centered.

  **Arguments**:

  - `x`: The data to initialize the parameters with (batched).
  - `y`: The conditioning information (unused)
  - `key`: A `jax.random.PRNGKey` for initialization (unused)

  **Returns**:
  A new layer with the bias initialized.
  """
  assert x.shape[1:] == self.input_shape, 'x must be batched'
  negative_mean = -jnp.mean(x, axis=0)
  return eqx.tree_at(lambda tree: tree.b, self, negative_mean)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Implements generax.flows.base.BijectiveTransform.__call__.

Source code in generax/flows/affine.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """Apply the Cayley-parametrized orthogonal map (plus bias) or its inverse.

  With `A = W - W^T` (skew-symmetric), the forward map is
  `z = (I - A)(I + A)^{-1} (x + b)`.  The inverse map is
  `z = (I + A)(I - A)^{-1} x - b`.

  **Arguments**:

  - `x`: The input to the transformation (unbatched)
  - `y`: The conditioning information (unused)
  - `inverse`: Whether to inverse the transformation

  **Returns**:
  `(z, log_det)` where `log_det` is a scalar 0 (the map is orthogonal).
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  A = self.W - self.W.T  # skew-symmetric
  eye = jnp.eye(self.input_shape[-1])

  # So that we can multiply with channel dim of images
  @partial(jnp.vectorize, signature='(i,j),(j)->(i)')
  def matmul(M, v):
    return M@v

  # `w` holds the intermediate vector.  It used to be named `y`, which
  # shadowed the conditioning argument.
  if inverse == False:
    # Out-of-place add so a caller's NumPy input is never mutated.
    x = x + self.b
    w = matmul(jnp.linalg.inv(eye + A), x)
    z = w - matmul(A, w)
  else:
    w = matmul(jnp.linalg.inv(eye - A), x)
    z = w + matmul(A, w)
    z = z - self.b

  # Orthogonal map: |det| = 1.  Scalar, consistent with the other layers.
  log_det = jnp.array(0.0)
  return z, log_det

generax.flows.affine.PLUAffine (BijectiveTransform) ¤

Multiply the last axis by a matrix that is parametrized using the LU decomposition. This is more efficient than the dense parametrization.

Attributes: - A: The weight matrix components. The upper triangle (including the diagonal) holds the upper triangular matrix U, and the strictly lower triangle holds the lower triangular matrix, whose diagonal is implicitly one. - b: The bias vector

Source code in generax/flows/affine.py
class PLUAffine(BijectiveTransform):
  """Multiply the last axis by a matrix that is parametrized using the LU decomposition.  This is more efficient
  than the dense parametrization.

  The forward map is `z = (I + L) U (x + b)`.  `U` is the upper triangle of
  `A` with the diagonal included.  `L` is the strictly lower triangle of
  `A`, so `I + L` has a unit diagonal.  Hence `log|det| = sum log|diag(A)|`
  per vector.  No permutation is applied in this implementation.

  **Attributes**:
  - `A`: The weight matrix components.  The upper triangle (including the
          diagonal) holds `U`; the strictly lower triangle holds `L`.
  - `b`: The bias vector
  """

  A: Array
  b: Array

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape.  Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)

    # Initialize so that this will be approximately the identity matrix:
    # small off-diagonal noise with a unit diagonal.
    dim = input_shape[-1]
    self.A = random.normal(key, shape=(dim, dim))*0.01
    self.A = self.A.at[jnp.arange(dim),jnp.arange(dim)].set(1.0)

    self.b = jnp.zeros(input_shape)

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the layer based on the data.

    **Arguments**:

    - `x`: The data to initialize the parameters with (batched).
    - `y`: The conditioning information (unused)
    - `key`: A `jax.random.PRNGKey` for initialization (unused)

    **Returns**:
    A new layer with `b` set to minus the batch mean, so `x + b` is centered.
    """
    assert x.shape[1:] == self.input_shape, 'x must be batched'
    b = -jnp.mean(x, axis=0)
    return eqx.tree_at(lambda tree: tree.b, self, b)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The input to the transformation (unbatched)
    - `y`: The conditioning information (unused)
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    (z, log_det)
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    dim = x.shape[-1]
    mask = jnp.ones((dim, dim), dtype=bool)
    upper_mask = jnp.triu(mask)          # diagonal belongs to U
    lower_mask = jnp.tril(mask, k=-1)    # strictly lower: unit-diagonal L

    if inverse == False:
      # Out-of-place add so a caller's NumPy input is never mutated.
      x = x + self.b
      z = jnp.einsum("ij,...j->...i", self.A*upper_mask, x)
      # Applies (I + L): L @ z + z
      z = jnp.einsum("ij,...j->...i", self.A*lower_mask, z) + z
    else:
      # vmap once per leading axis in order to handle images
      L_solve_vmap = L_solve
      U_solve_vmap = U_solve_with_diag
      for _ in x.shape[:-1]:
        L_solve_vmap = jax.vmap(L_solve_vmap, in_axes=(None, 0))
        U_solve_vmap = jax.vmap(U_solve_vmap, in_axes=(None, 0))
      z = L_solve_vmap(self.A*lower_mask, x)
      z = U_solve_vmap(self.A*upper_mask, z)
      z = z - self.b

    # det(I + L) = 1 and det(U) = prod(diag(A)); the matrix is applied once
    # per position of the leading axes.
    log_det = jnp.log(jnp.abs(jnp.diag(self.A))).sum()*misc.list_prod(x.shape[:-1])
    if inverse:
      log_det *= -1
    return z, log_det
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/affine.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Invert the transformation by delegating to `self(..., inverse=True)`.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, inverse=True, y=y, **kwargs)
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray, **kwargs) ¤

Arguments:

  • input_shape: The input shape. Output size is the same as shape.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/affine.py
def __init__(self,
             input_shape: Tuple[int],
             key: PRNGKeyArray,
             **kwargs):
  """Create the layer with a near-identity matrix and a zero shift.

  **Arguments**:

  - `input_shape`: The input shape.  Output size is the same as shape.
  - `key`: A `jax.random.PRNGKey` used to draw the off-diagonal entries.
  """
  super().__init__(input_shape=input_shape, **kwargs)

  # Small random entries plus a unit diagonal keep the layer close to the
  # identity at initialization.
  n = input_shape[-1]
  diag_idx = jnp.arange(n)
  noise = 0.01*random.normal(key, shape=(n, n))
  self.A = noise.at[diag_idx, diag_idx].set(1.0)

  self.b = jnp.zeros(input_shape)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/flows/affine.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
  """Set the shift `b` from a batch of data.

  `b` becomes the negated per-feature batch mean. The forward pass adds `b`
  before the matrix product, so it then sees mean-centred inputs.

  **Arguments**:

  - `x`: A batch of data, shape `(batch,) + self.input_shape`.
  - `y`: The conditioning information (unused)
  - `key`: A `jax.random.PRNGKey` (unused)

  **Returns**:
  A new layer whose `b` has been replaced; `self` is left untouched.
  """
  assert x.shape[1:] == self.input_shape, 'x must be batched'
  centring_shift = -jnp.mean(x, axis=0)
  return eqx.tree_at(lambda module: module.b, self, centring_shift)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Implements generax.flows.base.BijectiveTransform.__call__.

Source code in generax/flows/affine.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """Apply the affine map z = (I + L) U (x + b), or its inverse.

  `self.A` stores both triangular factors: its strictly-lower part is `L`
  (the unit diagonal of the lower factor is implicit) and its upper part,
  including the diagonal, is `U`.  The matrix acts on the last axis of `x`;
  any leading axes are handled position by position.

  **Arguments**:

  - `x`: The input to the transformation.  Must have shape
    `self.input_shape` (unbatched).
  - `y`: The conditioning information (not used by this layer).
  - `inverse`: Whether to inverse the transformation

  **Returns**:
  `(z, log_det)`, where `log_det` is the log absolute Jacobian determinant
  of the map that was applied (negated in inverse mode).
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  dim = x.shape[-1]
  mask = jnp.ones((dim, dim), dtype=bool)
  upper_mask = jnp.triu(mask)        # upper triangle incl. diagonal -> U
  lower_mask = jnp.tril(mask, k=-1)  # strictly lower triangle -> L

  if not inverse:
    # Rebind instead of `x += ...` so a caller-owned NumPy array is never
    # modified in place.
    x = x + self.b
    z = jnp.einsum("ij,...j->...i", self.A*upper_mask, x)
    # (I + L) z: adding z back supplies the implicit unit diagonal.
    z = jnp.einsum("ij,...j->...i", self.A*lower_mask, z) + z
  else:
    # Add one vmap per leading axis so the solvers act on the last axis
    # only (e.g. per pixel for images).
    L_solve_vmap = L_solve
    U_solve_vmap = U_solve_with_diag
    for _ in x.shape[:-1]:
      L_solve_vmap = jax.vmap(L_solve_vmap, in_axes=(None, 0))
      U_solve_vmap = jax.vmap(U_solve_vmap, in_axes=(None, 0))
    # NOTE(review): assumes L_solve treats the diagonal as implicit ones,
    # mirroring the forward pass -- confirm against its definition.
    z = L_solve_vmap(self.A*lower_mask, x)
    z = U_solve_vmap(self.A*upper_mask, z)
    z = z - self.b

  # det((I + L) U) = prod(diag(A)); the same matrix is applied once per
  # position along the leading axes.
  log_det = jnp.log(jnp.abs(jnp.diag(self.A))).sum()*misc.list_prod(x.shape[:-1])
  if inverse:
    log_det = -log_det
  return z, log_det

generax.flows.affine.ConditionalOptionalTransport (TimeDependentBijectiveTransform) ¤

Given $x_1$, compute $f(t, x_0) = t\,x_1 + (1-t)\,x_0$. This is the optimal transport map between the two points. Used in flow matching https://arxiv.org/pdf/2210.02747.pdf

Inverse mode maps $x_0$ to $x_t$ ($0 \to t$), while non-inverse mode maps $x_t$ back to $x_0$ ($t \to 0$).

Attributes: none (this transform has no parameters).

Source code in generax/flows/affine.py
class ConditionalOptionalTransport(TimeDependentBijectiveTransform):
  """Given x1, compute f(t, x0) = t*x1 + (1-t)*x0.  This is the optimal transport
  map between the two points.  Used in flow matching https://arxiv.org/pdf/2210.02747.pdf

  The endpoint `x1` is supplied as the conditioning input `y`.  Inverse mode
  maps `x0` to `xt` (0 -> t); non-inverse mode maps `xt` back to `x0`
  (t -> 0).  Both directions divide by, or take the log of, `1 - t`, so `t`
  must be strictly less than 1.

  **Attributes**:
  None; this transform has no parameters.
  """

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape.  Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey`.  Not used by this constructor.
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)

  def __call__(self,
               t: Array,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """**Arguments**:

    - `t`: The time point.
    - `x`: The input to the transformation: `x0` when `inverse=True`,
      otherwise `xt`.
    - `y`: The conditioning information, i.e. the endpoint `x1`.  Required,
      and must have the same shape as `x`.
    - `inverse`: Whether to invert the transformation (0 -> t)

    **Returns**:
    `(z, log_det)`, where `log_det` is the log absolute Jacobian determinant
    of the map that was applied, summed over every element of `x`.

    **Raises**:
    `ValueError` if `y` is missing or its shape differs from `x`.
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    if y is None:
      raise ValueError('Expected a conditional input')
    if y.shape != x.shape:
      raise ValueError(f'Expected y.shape ({y.shape}) to match x.shape ({x.shape})')

    x1 = y
    # The Jacobian is (1 - t) * I over all x.size coordinates, so the
    # log-determinant is x.size * log(1 - t), not a single log(1 - t).
    log_scale = x.size*jnp.log(1 - t)
    if inverse:
      x0 = x
      xt = (1 - t)*x0 + t*x1
      return xt, log_scale
    else:
      xt = x
      x0 = (xt - t*x1)/(1 - t)
      return x0, -log_scale

  def vector_field(self,
                   t: Array,
                   xt: Array,
                   y: Optional[Array] = None,
                   **kwargs) -> Array:
    """The vector field that samples evolve on as t changes.

    Along the path xt = (1-t)*x0 + t*x1 the velocity is x1 - x0.  This is
    written in terms of the current point via x0 = (xt - t*x1)/(1-t), which
    gives (x1 - xt)/(1 - t).

    **Arguments**:

    - `t`: Time (strictly less than 1).
    - `xt`: The point on the path at time `t`.
    - `y`: The conditioning information, i.e. the endpoint `x1`.

    **Returns**:
    The vector field that samples evolve on at (t, xt).

    **Raises**:
    `ValueError` if `y` is missing or its shape differs from `xt`.
    """
    assert xt.shape == self.input_shape, 'Only works on unbatched data'
    if y is None:
      raise ValueError('Expected a conditional input')
    if y.shape != xt.shape:
      raise ValueError(f'Expected y.shape ({y.shape}) to match x.shape ({xt.shape})')
    x1 = y
    return (x1 - xt)/(1 - t)
data_dependent_init(self, t: Array, xt: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) ¤

Inherited from generax.flows.base.TimeDependentBijectiveTransform.data_dependent_init.

Source code in generax/flows/affine.py
def data_dependent_init(self,
                        t: Array,
                        xt: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Data-dependent initialization hook; a no-op here.

  Nothing is fitted from the data: the layer is returned as-is.

  **Arguments**:

  - `t`: Time.
  - `xt`: The data to initialize the parameters with (unused).
  - `y`: The conditioning information (unused)
  - `key`: A `jax.random.PRNGKey` (unused)

  **Returns**:
  The same layer, unmodified.
  """
  unchanged_layer = self
  return unchanged_layer
inverse(self, t: Array, x0: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.TimeDependentBijectiveTransform.inverse.

Source code in generax/flows/affine.py
def inverse(self,
            t: Array,
            x0: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Map a time-0 point forward to time `t`.

  Shorthand for calling the layer with `inverse=True`; extra keyword
  arguments are forwarded unchanged.

  **Arguments**:

  - `t`: The time point.
  - `x0`: The input to the transformation (a point at time 0).
  - `y`: The conditioning information

  **Returns**:
  (xt, log_det)
  """
  outputs = self(t, x0, y=y, inverse=True, **kwargs)
  return outputs
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray, **kwargs) ¤

Arguments:

  • input_shape: The input shape. Output size is the same as shape.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/affine.py
def __init__(self,
             input_shape: Tuple[int],
             key: PRNGKeyArray,
             **kwargs):
  """Construct the transform; only the base-class setup is performed.

  **Arguments**:

  - `input_shape`: The input shape; the output has the same shape.
  - `key`: A `jax.random.PRNGKey`.  Not used by this constructor.
  """
  super().__init__(input_shape=input_shape, **kwargs)
__call__(self, t: Array, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • t: The time point.
  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation (0 -> t)

Returns: (z, log_det)

Source code in generax/flows/affine.py
def __call__(self,
             t: Array,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """**Arguments**:

  - `t`: The time point (strictly less than 1).
  - `x`: The input to the transformation: `x0` when `inverse=True`,
    otherwise `xt`.
  - `y`: The conditioning information, i.e. the endpoint `x1`.  Required,
    and must have the same shape as `x`.
  - `inverse`: Whether to invert the transformation (0 -> t)

  **Returns**:
  `(z, log_det)`, where `log_det` is the log absolute Jacobian determinant
  of the map that was applied, summed over every element of `x`.

  **Raises**:
  `ValueError` if `y` is missing or its shape differs from `x`.
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'
  if y is None:
    raise ValueError('Expected a conditional input')
  if y.shape != x.shape:
    raise ValueError(f'Expected y.shape ({y.shape}) to match x.shape ({x.shape})')

  x1 = y
  # The Jacobian is (1 - t) * I over all x.size coordinates, so the
  # log-determinant is x.size * log(1 - t), not a single log(1 - t).
  log_scale = x.size*jnp.log(1 - t)
  if inverse:
    x0 = x
    xt = (1 - t)*x0 + t*x1
    return xt, log_scale
  else:
    xt = x
    x0 = (xt - t*x1)/(1 - t)
    return x0, -log_scale
vector_field(self, t: Array, xt: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

The vector field that samples evolve on as t changes

Arguments:

  • t: Time.
  • xt: The point on the path at time t.
  • y: The conditioning information.

Returns: The vector field that samples evolve on at (t, x).

Source code in generax/flows/affine.py
def vector_field(self,
                 t: Array,
                 xt: Array,
                 y: Optional[Array] = None,
                 **kwargs) -> Array:
  """The vector field that samples evolve on as t changes.

  Along the path xt = (1-t)*x0 + t*x1 the velocity is x1 - x0.  This is
  written in terms of the current point via x0 = (xt - t*x1)/(1-t), which
  gives (x1 - xt)/(1 - t).

  **Arguments**:

  - `t`: Time (strictly less than 1).
  - `xt`: The point on the path at time `t`.
  - `y`: The conditioning information, i.e. the endpoint `x1`.

  **Returns**:
  The vector field that samples evolve on at (t, xt).

  **Raises**:
  `ValueError` if `y` is missing or its shape differs from `xt`.
  """
  # The original body referenced an undefined name `x`; the argument is `xt`.
  assert xt.shape == self.input_shape, 'Only works on unbatched data'
  if y is None:
    raise ValueError('Expected a conditional input')
  if y.shape != xt.shape:
    raise ValueError(f'Expected y.shape ({y.shape}) to match x.shape ({xt.shape})')
  x1 = y
  return (x1 - xt)/(1 - t)

generax.flows.affine.TallDenseLinear (InjectiveTransform) ¤

Matrix vector product with a tall matrix.

Attributes:

  • W: The weight matrix

Source code in generax/flows/affine.py
class TallDenseLinear(InjectiveTransform):
  """Matrix vector product with a tall matrix.

  Inverse mode maps `z` (shape `output_shape`) to `x = W z` (shape
  `input_shape`).  Forward mode maps `x` back with the pseudo-inverse,
  `z = pinv(W) x`.

  **Attributes**:
  - `W`: The weight matrix, of shape `(input_shape[-1], output_shape[-1])`.
  """

  W: Array

  def __init__(self,
               input_shape: Tuple[int],
               output_shape: Tuple[int],
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape.  Must be 1d, i.e. `(dim_in,)`.
    - `output_shape`: The output shape; its last entry is the number of
      columns of `W`.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    assert len(input_shape) == 1, 'Only implemented for 1d data'
    super().__init__(input_shape=input_shape,
                     output_shape=output_shape,
                     **kwargs)

    dim_in = self.input_shape[-1]
    dim_out = self.output_shape[-1]
    # NOTE(review): "tall" suggests dim_in >= dim_out; this is not checked.
    self.W = random.normal(key, shape=(dim_in, dim_out))
    self.W = misc.whiten(self.W)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The input to the transformation (unbatched).  It has shape
      `input_shape` in forward mode and `output_shape` in inverse mode.
    - `y`: The conditioning information (not used by this layer).
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    `(z, log_det)`, with `log_det = -0.5*log|det(W^T W)|` in forward mode
    and its negation in inverse mode.
    """
    expected_shape = self.output_shape if inverse else self.input_shape
    assert x.shape == expected_shape, 'Only works on unbatched data'

    if inverse:
      # Embed the low-dimensional point: x = W z.
      z = jnp.einsum('ij,...j->...i', self.W, x)
    else:
      # Moore-Penrose pseudo-inverse; a left inverse of W when W has full
      # column rank.
      W_pinv = jnp.linalg.pinv(self.W)
      z = jnp.einsum('ij,...j->...i', W_pinv, x)

    # The volume term depends only on W, not on the input point.
    log_det = -0.5*jnp.linalg.slogdet(self.W.T@self.W)[1]
    if inverse:
      log_det = -log_det
    return z, log_det

  def log_determinant(self,
                      z: Array,
                      **kwargs) -> Array:
    """Compute -0.5*log|det(W^T W)|.

    This is the same log-determinant term that forward mode (`x -> z`)
    returns.  It does not depend on the point.

    **Arguments**:

    - `z`: An element of the base space (unused; the value is constant).

    **Returns**:
    -0.5*log|det(W^T W)|, i.e. the negated log of the volume factor
    det(W^T W)^0.5 of the map `z -> W z`.
    """
    log_det = -0.5*jnp.linalg.slogdet(self.W.T@self.W)[1]
    return log_det
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/flows/affine.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Data-dependent initialization hook; a no-op here.

  Nothing is fitted from the data: the layer is returned as-is.

  **Arguments**:

  - `x`: The data to initialize the parameters with (unused).
  - `y`: The conditioning information (unused)
  - `key`: A `jax.random.PRNGKey` (unused)

  **Returns**:
  The same layer, unmodified.
  """
  unchanged_layer = self
  return unchanged_layer
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Apply the inverse transformation.

Arguments:

  • x: The input to the transformation
  • y: The conditioning information

Returns: (z, log_det)

Source code in generax/flows/affine.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Run the layer backwards.

  Shorthand for calling the layer with `inverse=True`; any extra keyword
  arguments are forwarded unchanged.

  **Arguments**:

  - `x`: The input to the inverse transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  outputs = self(x, y=y, inverse=True, **kwargs)
  return outputs
project(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.InjectiveTransform.project.

Source code in generax/flows/affine.py
def project(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Project a point onto the image of the transformation.

  `x` is sent to the base space and then mapped back with the inverse pass.
  The log-determinants of both passes are discarded.

  **Arguments**:

  - `x`: The point to project
  - `y`: The conditioning information

  **Returns**:
  The projected point.
  """
  latent, _forward_log_det = self(x, y=y, **kwargs)
  projected, _inverse_log_det = self(latent, y=y, inverse=True, **kwargs)
  return projected
__init__(self, input_shape: Tuple[int], output_shape: Tuple[int], key: PRNGKeyArray, **kwargs) ¤

Arguments:

  • input_shape: The input shape (must be 1d).
  • output_shape: The output shape.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/affine.py
def __init__(self,
             input_shape: Tuple[int],
             output_shape: Tuple[int],
             key: PRNGKeyArray,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape.  Must be 1d, i.e. `(dim_in,)`.
  - `output_shape`: The output shape; its last entry is the number of
    columns of `W`.
  - `key`: A `jax.random.PRNGKey` for initialization
  """
  assert len(input_shape) == 1, 'Only implemented for 1d data'
  super().__init__(input_shape=input_shape,
                   output_shape=output_shape,
                   **kwargs)

  dim_in = self.input_shape[-1]
  dim_out = self.output_shape[-1]
  # NOTE(review): "tall" suggests dim_in >= dim_out; this is not checked.
  self.W = random.normal(key, shape=(dim_in, dim_out))
  self.W = misc.whiten(self.W)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to inverse the transformation

Returns: (z, log_det)

Source code in generax/flows/affine.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation (unbatched).  It has shape
    `input_shape` in forward mode and `output_shape` in inverse mode.
  - `y`: The conditioning information (not used by this layer).
  - `inverse`: Whether to inverse the transformation

  **Returns**:
  `(z, log_det)`, with `log_det = -0.5*log|det(W^T W)|` in forward mode
  and its negation in inverse mode.
  """
  expected_shape = self.output_shape if inverse else self.input_shape
  assert x.shape == expected_shape, 'Only works on unbatched data'

  if inverse:
    # Embed the low-dimensional point: x = W z.
    z = jnp.einsum('ij,...j->...i', self.W, x)
  else:
    # Moore-Penrose pseudo-inverse; a left inverse of W when W has full
    # column rank.
    W_pinv = jnp.linalg.pinv(self.W)
    z = jnp.einsum('ij,...j->...i', W_pinv, x)

  # The volume term depends only on W, not on the input point.
  log_det = -0.5*jnp.linalg.slogdet(self.W.T@self.W)[1]
  if inverse:
    log_det = -log_det
  return z, log_det