Skip to content

Reshape¤

generax.flows.reshape.Flatten (BijectiveTransform) ¤

Flatten

Source code in generax/flows/reshape.py
class Flatten(BijectiveTransform):
  """Flatten an input of shape `input_shape` into a 1-D vector.

  The forward pass returns `x.ravel()`; the inverse pass reshapes a vector
  back to `input_shape`.  This is a pure reshape, so the log-determinant is
  always 0.
  """

  def __init__(self,
               input_shape: Tuple[int],
               **kwargs):
    """**Arguments**:

    - `input_shape`: The (unflattened) input shape.  The output is a 1-D
      vector with the same number of elements.
    - `key`: A `jax.random.PRNGKey` for initialization (forwarded to the
      base class through `**kwargs`).
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Tuple[Array, Array]:
    """**Arguments**:

    - `x`: The input to the transformation.  In the forward direction it must
      have shape `input_shape`; in the inverse direction it is reshaped to
      `input_shape` (so it must have the same number of elements).
    - `y`: The conditioning information (unused).
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    `(z, log_det)` where `log_det` is always 0.
    """
    log_det = jnp.array(0.0)
    if not inverse:
      assert x.shape == self.input_shape
      return x.ravel(), log_det
    return x.reshape(self.input_shape), log_det
__init__(self, input_shape: Tuple[int], **kwargs) ¤

Arguments:

  • input_shape: The input shape. The output is a 1-D vector with the same number of elements.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/reshape.py
def __init__(self,
             input_shape: Tuple[int],
             **kwargs):
  """**Arguments**:

  - `input_shape`: The shape of the unflattened input.  The flattened output
    is 1-D with the same number of elements.
  - `key`: A `jax.random.PRNGKey` for initialization (passed through
    `**kwargs`).
  """
  base_kwargs = dict(input_shape=input_shape, **kwargs)
  super().__init__(**base_kwargs)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to inverse the transformation

Returns: The transformed input and 0

Source code in generax/flows/reshape.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
  """**Arguments**:

  - `x`: The input to the transformation.  In the forward direction it must
    have shape `input_shape`; in the inverse direction it is reshaped to
    `input_shape` (so it must have the same number of elements).
  - `y`: The conditioning information (unused).
  - `inverse`: Whether to inverse the transformation

  **Returns**:
  `(z, log_det)` where `log_det` is always 0.
  """
  log_det = jnp.array(0.0)
  if not inverse:
    assert x.shape == self.input_shape
    return x.ravel(), log_det
  return x.reshape(self.input_shape), log_det
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) ¤

Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.

Source code in generax/flows/reshape.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Data-dependent initialization hook.

  Nothing is fitted to the data here; the layer is handed back unchanged.

  **Arguments**:

  - `x`: Example data (ignored).
  - `y`: The conditioning information (ignored).
  - `key`: A `jax.random.PRNGKey` (ignored).

  **Returns**:
  `self`, unmodified.
  """
  del x, y, key  # nothing to initialize
  return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/reshape.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Run the transformation backwards.

  Equivalent to calling the layer with `inverse=True`.

  **Arguments**:

  - `x`: The input to the inverse transformation
  - `y`: The conditioning information

  **Returns**:
  Whatever the layer's `__call__` returns, i.e. `(z, log_det)`.
  """
  call_options = {'y': y, 'inverse': True}
  return self(x, **call_options, **kwargs)

generax.flows.reshape.Reverse (BijectiveTransform) ¤

Reverse an input along its last axis.

Source code in generax/flows/reshape.py
class Reverse(BijectiveTransform):
  """Reverse an input along its last axis.

  Reversal is its own inverse, so the forward and inverse passes perform the
  same operation.  The log-determinant is always 0.
  """

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray = None,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape.  Output shape is the same.
    - `key`: Accepted for API compatibility but unused (reversal has nothing
      to initialize), so it may be omitted.
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Tuple[Array, Array]:
    """**Arguments**:

    - `x`: The input to the transformation; must have shape `input_shape`.
    - `y`: The conditioning information (unused).
    - `inverse`: Whether to inverse the transformation.  Reversal is an
      involution, so this flag does not change the result.

    **Returns**:
    `(z, log_det)` where `log_det` is always 0.
    """
    assert x.shape == self.input_shape
    # x[..., ::-1] applied twice is the identity, so `inverse` is ignored.
    z = x[..., ::-1]
    log_det = jnp.array(0.0)
    return z, log_det
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) ¤

Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.

Source code in generax/flows/reshape.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Data-dependent initialization hook.

  Nothing is fitted to the data here; the layer is handed back unchanged.

  **Arguments**:

  - `x`: Example data (ignored).
  - `y`: The conditioning information (ignored).
  - `key`: A `jax.random.PRNGKey` (ignored).

  **Returns**:
  `self`, unmodified.
  """
  del x, y, key  # nothing to initialize
  return self
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray, **kwargs) ¤

Arguments:

  • input_shape: The input shape. Output size is the same as shape.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/reshape.py
def __init__(self,
             input_shape: Tuple[int],
             key: PRNGKeyArray = None,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape.  Output shape is the same.
  - `key`: Accepted for API compatibility but unused (reversal has nothing
    to initialize), so it may be omitted.
  """
  super().__init__(input_shape=input_shape,
                   **kwargs)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to inverse the transformation

Returns: The transformed input and 0

Source code in generax/flows/reshape.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
  """**Arguments**:

  - `x`: The input to the transformation; must have shape `input_shape`.
  - `y`: The conditioning information (unused).
  - `inverse`: Whether to inverse the transformation.  Reversal is an
    involution, so this flag does not change the result.

  **Returns**:
  `(z, log_det)` where `log_det` is always 0.
  """
  assert x.shape == self.input_shape
  # x[..., ::-1] applied twice is the identity, so `inverse` is ignored.
  z = x[..., ::-1]
  log_det = jnp.array(0.0)
  return z, log_det
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/reshape.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Run the transformation backwards.

  Equivalent to calling the layer with `inverse=True`.

  **Arguments**:

  - `x`: The input to the inverse transformation
  - `y`: The conditioning information

  **Returns**:
  Whatever the layer's `__call__` returns, i.e. `(z, log_det)`.
  """
  call_options = {'y': y, 'inverse': True}
  return self(x, **call_options, **kwargs)

generax.flows.reshape.Checkerboard (BijectiveTransform) ¤

Checkerboard pattern from https://arxiv.org/pdf/1605.08803.pdf

Source code in generax/flows/reshape.py
class Checkerboard(BijectiveTransform):
  """Checkerboard pattern from https://arxiv.org/pdf/1605.08803.pdf

  Maps an `(H, W, C)` input to shape `(H, W//2, 2*C)` by moving pairs of
  neighbouring columns into the channel axis: input column `2*w + k`
  (`k` in {0, 1}), channel `c` becomes output column `w`, channel `2*c + k`.
  This only permutes entries, so the log-determinant is always 0.
  """

  output_shape: Tuple[int] = eqx.field(static=True)

  def __init__(self,
               input_shape: Tuple[int],
               **kwargs):
    """**Arguments**:

    - `input_shape`: The `(H, W, C)` input shape.  `W` must be even.  The
      output shape is `(H, W//2, 2*C)`.
    """
    H, W, C = input_shape
    assert W % 2 == 0, 'Need even width'
    super().__init__(input_shape=input_shape,
                     **kwargs)
    self.output_shape = (H, W//2, C*2)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Tuple[Array, Array]:
    """**Arguments**:

    - `x`: The input to the transformation.  Must have shape `input_shape`
      (forward) or `output_shape` (inverse); batched inputs are rejected.
    - `y`: The conditioning information (unused).
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    `(z, log_det)` where `log_det` is always 0.
    """
    if not inverse:
      assert x.shape == self.input_shape
      z = einops.rearrange(x, 'h (w k) c -> h w (c k)', k=2)
    else:
      assert x.shape == self.output_shape
      z = einops.rearrange(x, 'h w (c k) -> h (w k) c', k=2)

    # A pure permutation of entries has |det J| = 1.
    log_det = jnp.array(0.0)
    return z, log_det
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) ¤

Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.

Source code in generax/flows/reshape.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Data-dependent initialization hook.

  Nothing is fitted to the data here; the layer is handed back unchanged.

  **Arguments**:

  - `x`: Example data (ignored).
  - `y`: The conditioning information (ignored).
  - `key`: A `jax.random.PRNGKey` (ignored).

  **Returns**:
  `self`, unmodified.
  """
  del x, y, key  # nothing to initialize
  return self
__init__(self, input_shape: Tuple[int], **kwargs) ¤

Arguments:

  • input_shape: The (H, W, C) input shape; W must be even. The output shape is (H, W//2, 2*C).
Source code in generax/flows/reshape.py
def __init__(self,
             input_shape: Tuple[int],
             **kwargs):
  """**Arguments**:

  - `input_shape`: The `(H, W, C)` input shape.  `W` must be even.  The
    output shape is `(H, W//2, 2*C)`.
  """
  height, width, channels = input_shape
  assert width % 2 == 0, 'Need even width'
  super().__init__(input_shape=input_shape, **kwargs)
  half_width = width // 2
  self.output_shape = (height, half_width, channels * 2)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to inverse the transformation

Returns: The transformed input and 0

Source code in generax/flows/reshape.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
  """**Arguments**:

  - `x`: The input to the transformation.  Must have shape `input_shape`
    (forward) or `output_shape` (inverse); batched inputs are rejected.
  - `y`: The conditioning information (unused).
  - `inverse`: Whether to inverse the transformation

  **Returns**:
  `(z, log_det)` where `log_det` is always 0.
  """
  if not inverse:
    assert x.shape == self.input_shape
    # Column 2*w + k, channel c -> column w, channel 2*c + k.
    z = einops.rearrange(x, 'h (w k) c -> h w (c k)', k=2)
  else:
    assert x.shape == self.output_shape
    z = einops.rearrange(x, 'h w (c k) -> h (w k) c', k=2)

  # A pure permutation of entries has |det J| = 1.
  log_det = jnp.array(0.0)
  return z, log_det
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/reshape.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Run the transformation backwards.

  Equivalent to calling the layer with `inverse=True`.

  **Arguments**:

  - `x`: The input to the inverse transformation
  - `y`: The conditioning information

  **Returns**:
  Whatever the layer's `__call__` returns, i.e. `(z, log_det)`.
  """
  call_options = {'y': y, 'inverse': True}
  return self(x, **call_options, **kwargs)

generax.flows.reshape.Squeeze (BijectiveTransform) ¤

Space to depth. (H, W, C) -> (H//2, W//2, C*4)

Source code in generax/flows/reshape.py
class Squeeze(BijectiveTransform):
  """Space to depth.  (H, W, C) -> (H//2, W//2, C*4)

  Each 2x2 spatial patch is moved into the channel axis: input pixel
  `(2*h + m, 2*w + n)`, channel `c` becomes output pixel `(h, w)`, channel
  `4*c + 2*m + n`.  This only permutes entries, so the log-determinant is
  always 0.
  """

  output_shape: Tuple[int] = eqx.field(static=True)

  def __init__(self,
               input_shape: Tuple[int],
               **kwargs):
    """**Arguments**:

    - `input_shape`: The `(H, W, C)` input shape.  `H` and `W` must be even.
      The output shape is `(H//2, W//2, 4*C)`.
    - `key`: A `jax.random.PRNGKey` for initialization (forwarded through
      `**kwargs`).
    """
    H, W, C = input_shape
    assert H % 2 == 0, 'Need even height'
    assert W % 2 == 0, 'Need even width'
    super().__init__(input_shape=input_shape,
                     **kwargs)
    self.output_shape = (H//2, W//2, C*4)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Tuple[Array, Array]:
    """**Arguments**:

    - `x`: The input to the transformation.  Must have shape `input_shape`
      (forward) or `output_shape` (inverse); batched inputs are rejected.
    - `y`: The conditioning information (unused).
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    `(z, log_det)` where `log_det` is always 0.
    """
    if not inverse:
      assert x.shape == self.input_shape
      z = einops.rearrange(x, '(h m) (w n) c -> h w (c m n)', m=2, n=2)
    else:
      assert x.shape == self.output_shape
      z = einops.rearrange(x, 'h w (c m n) -> (h m) (w n) c', m=2, n=2)

    # A pure permutation of entries has |det J| = 1.
    log_det = jnp.array(0.0)
    return z, log_det
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) ¤

Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.

Source code in generax/flows/reshape.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Data-dependent initialization hook.

  Nothing is fitted to the data here; the layer is handed back unchanged.

  **Arguments**:

  - `x`: Example data (ignored).
  - `y`: The conditioning information (ignored).
  - `key`: A `jax.random.PRNGKey` (ignored).

  **Returns**:
  `self`, unmodified.
  """
  del x, y, key  # nothing to initialize
  return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/reshape.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Run the transformation backwards.

  Equivalent to calling the layer with `inverse=True`.

  **Arguments**:

  - `x`: The input to the inverse transformation
  - `y`: The conditioning information

  **Returns**:
  Whatever the layer's `__call__` returns, i.e. `(z, log_det)`.
  """
  call_options = {'y': y, 'inverse': True}
  return self(x, **call_options, **kwargs)
__init__(self, input_shape: Tuple[int], **kwargs) ¤

Arguments:

  • input_shape: The (H, W, C) input shape; H and W must be even. The output shape is (H//2, W//2, 4*C).
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/reshape.py
def __init__(self,
             input_shape: Tuple[int],
             **kwargs):
  """**Arguments**:

  - `input_shape`: The `(H, W, C)` input shape.  `H` and `W` must be even.
    The output shape is `(H//2, W//2, 4*C)`.
  - `key`: A `jax.random.PRNGKey` for initialization (forwarded through
    `**kwargs`).
  """
  height, width, channels = input_shape
  assert height % 2 == 0, 'Need even height'
  assert width % 2 == 0, 'Need even width'
  super().__init__(input_shape=input_shape, **kwargs)
  self.output_shape = (height // 2, width // 2, channels * 4)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to inverse the transformation

Returns: The transformed input and 0

Source code in generax/flows/reshape.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
  """**Arguments**:

  - `x`: The input to the transformation.  Must have shape `input_shape`
    (forward) or `output_shape` (inverse); batched inputs are rejected.
  - `y`: The conditioning information (unused).
  - `inverse`: Whether to inverse the transformation

  **Returns**:
  `(z, log_det)` where `log_det` is always 0.
  """
  if not inverse:
    assert x.shape == self.input_shape
    # Pixel (2*h + m, 2*w + n), channel c -> pixel (h, w), channel 4*c + 2*m + n.
    z = einops.rearrange(x, '(h m) (w n) c -> h w (c m n)', m=2, n=2)
  else:
    assert x.shape == self.output_shape
    z = einops.rearrange(x, 'h w (c m n) -> (h m) (w n) c', m=2, n=2)

  # A pure permutation of entries has |det J| = 1.
  log_det = jnp.array(0.0)
  return z, log_det

generax.flows.reshape.Slice (InjectiveTransform) ¤

Slice an input to reduce the dimension

Source code in generax/flows/reshape.py
class Slice(InjectiveTransform):
  """Slice an input along its last axis to reduce the dimension.

  The forward pass keeps the first `output_shape[-1]` entries of the last
  axis.  The inverse pass zero-pads the result back to `input_shape`.  The
  zero-padding embedding satisfies J^T J = I, so the log-determinant is 0.
  """

  def __init__(self,
               input_shape: Tuple[int],
               output_shape: Tuple[int],
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape.
    - `output_shape`: The output shape.  Must equal `input_shape` on every
      axis but the last, and its last axis must not exceed
      `input_shape[-1]`.
    - `key`: A `jax.random.PRNGKey` for initialization (forwarded through
      `**kwargs`).
    """
    assert input_shape[:-1] == output_shape[:-1], 'Need to keep the same spatial dimensions'
    # Slicing can only drop entries; a larger output would make the inverse
    # pass pad with a negative size.
    assert output_shape[-1] <= input_shape[-1], 'Can only reduce the last dimension'
    super().__init__(input_shape=input_shape,
                     output_shape=output_shape,
                     **kwargs)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Tuple[Array, Array]:
    """**Arguments**:

    - `x`: The input to the transformation.  Must have shape `input_shape`
      (forward) or `output_shape` (inverse); batched inputs are rejected.
    - `y`: The conditioning information (unused).
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    `(z, log_det)` where `log_det` is always 0.
    """
    if not inverse:
      assert x.shape == self.input_shape, 'Only works on unbatched data'
      z = x[..., :self.output_shape[-1]]
    else:
      assert x.shape == self.output_shape, 'Only works on unbatched data'
      pad_shape = self.input_shape[:-1] + (self.input_shape[-1] - self.output_shape[-1],)
      # Pad with x's dtype so the concatenation does not promote e.g.
      # float16/bfloat16/int inputs to the default float dtype.
      z = jnp.concatenate([x, jnp.zeros(pad_shape, dtype=x.dtype)], axis=-1)

    log_det = jnp.array(0.0)
    return z, log_det

  def log_determinant(self,
                      z: Array,
                      **kwargs) -> Array:
    """Compute the log-determinant term of this layer at `z`.

    For the zero-padding embedding J^T J = I, so log det((J^T J)^0.5) is
    always 0.

    **Arguments**:

    - `z`: An element of the base space

    **Returns**:
    The log determinant of (J^TJ)^0.5, i.e. 0.
    """
    return jnp.array(0.0)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/flows/reshape.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Data-dependent initialization hook.

  Nothing is fitted to the data here; the layer is handed back unchanged.

  **Arguments**:

  - `x`: Example data (ignored).
  - `y`: The conditioning information (ignored).
  - `key`: A `jax.random.PRNGKey` (ignored).

  **Returns**:
  `self`, unmodified.
  """
  del x, y, key  # nothing to initialize
  return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Apply the inverse transformation.

Arguments:

  • x: The input to the transformation
  • y: The conditioning information

Returns: (z, log_det)

Source code in generax/flows/reshape.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Run the transformation backwards.

  Equivalent to calling the layer with `inverse=True`.

  **Arguments**:

  - `x`: The input to the inverse transformation
  - `y`: The conditioning information

  **Returns**:
  Whatever the layer's `__call__` returns, i.e. `(z, log_det)`.
  """
  call_options = {'y': y, 'inverse': True}
  return self(x, **call_options, **kwargs)
__init__(self, input_shape: Tuple[int], output_shape: Tuple[int], **kwargs) ¤

Arguments:

  • input_shape: The input shape.
  • output_shape: The output shape. Must match input_shape on every axis except the last.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/reshape.py
def __init__(self,
             input_shape: Tuple[int],
             output_shape: Tuple[int],
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape.
  - `output_shape`: The output shape.  Must equal `input_shape` on every
    axis but the last, and its last axis must not exceed `input_shape[-1]`.
  - `key`: A `jax.random.PRNGKey` for initialization (forwarded through
    `**kwargs`).
  """
  assert input_shape[:-1] == output_shape[:-1], 'Need to keep the same spatial dimensions'
  # Slicing can only drop entries; a larger output would make the inverse
  # pass pad with a negative size.
  assert output_shape[-1] <= input_shape[-1], 'Can only reduce the last dimension'
  super().__init__(input_shape=input_shape,
                   output_shape=output_shape,
                   **kwargs)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to inverse the transformation

Returns: (z, log_det)

Source code in generax/flows/reshape.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
  """**Arguments**:

  - `x`: The input to the transformation.  Must have shape `input_shape`
    (forward) or `output_shape` (inverse); batched inputs are rejected.
  - `y`: The conditioning information (unused).
  - `inverse`: Whether to inverse the transformation

  **Returns**:
  `(z, log_det)` where `log_det` is always 0.
  """
  if not inverse:
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    z = x[..., :self.output_shape[-1]]
  else:
    assert x.shape == self.output_shape, 'Only works on unbatched data'
    pad_shape = self.input_shape[:-1] + (self.input_shape[-1] - self.output_shape[-1],)
    # Pad with x's dtype so the concatenation does not promote e.g.
    # float16/bfloat16/int inputs to the default float dtype.
    z = jnp.concatenate([x, jnp.zeros(pad_shape, dtype=x.dtype)], axis=-1)

  log_det = jnp.array(0.0)
  return z, log_det
project(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.InjectiveTransform.project.

Source code in generax/flows/reshape.py
def project(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Project a point onto the image of the transformation.

  Maps `x` forward, then maps the result back through the inverse pass.
  Both log-determinants are discarded.

  **Arguments**:

  - `x`: The point to project
  - `y`: The conditioning information

  **Returns**:
  The round-tripped point.
  """
  latent, _forward_log_det = self(x, y=y, **kwargs)
  projected, _inverse_log_det = self(latent, y=y, inverse=True, **kwargs)
  return projected