
Convolution

generax.flows.conv.CircularConv (BijectiveTransform)

Circular convolution. Equivalent to a regular convolution with circular padding. https://papers.nips.cc/paper/2019/file/b1f62fa99de9f27a048344d55c5ef7a6-Paper.pdf

Source code in generax/flows/conv.py
class CircularConv(BijectiveTransform):
  """Circular convolution.  Equivalent to a regular convolution with circular padding.
        https://papers.nips.cc/paper/2019/file/b1f62fa99de9f27a048344d55c5ef7a6-Paper.pdf
  """
  filter_shape: Tuple[int] = eqx.field(static=True)
  w: Array

  def __init__(self,
               input_shape: Tuple[int],
               filter_shape: Tuple[int]=(3, 3),
               *,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape.  The output shape is the same as the input shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    - `filter_shape`: Height and width for the convolutional filter, (Kx, Ky).  The full
                      kernel will have shape (Kx, Ky, C, C)
    """
    assert len(filter_shape) == 2
    self.filter_shape = filter_shape
    super().__init__(input_shape=input_shape,
                     **kwargs)

    H, W, C = input_shape

    w = random.normal(key, shape=self.filter_shape + (C, C))
    self.w = jax.vmap(jax.vmap(util.whiten))(w)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Array:
    """
    See http://developer.download.nvidia.com/compute/cuda/2_2/sdk/website/projects/convolutionFFT2D/doc/convolutionFFT2D.pdf
    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    H, W, C = x.shape

    Kx, Ky, _, _ = self.w.shape

    # See how much we need to roll the filter
    W_x = (Kx - 1) // 2
    W_y = (Ky - 1) // 2

    # Pad the filter to match the fft size and roll it so that its center is at (0,0)
    W_padded = jnp.pad(self.w[::-1,::-1,:,:], ((0, H - Kx), (0, W - Ky), (0, 0), (0, 0)))
    W_padded = jnp.roll(W_padded, (-W_x, -W_y), axis=(0, 1))

    # Apply the FFT to get the convolution
    image_fft = fft_channel_vmap(x)
    W_fft = fft_double_channel_vmap(W_padded)

    if inverse:
      # The inverse direction applies the convolution itself
      z_fft = jnp.einsum("abij,abj->abi", W_fft, image_fft)
      z = ifft_channel_vmap(z_fft).real
    else:
      # The forward direction applies the inverse convolution, so invert
      # W over the channel dims at each frequency
      W_fft_inv = inv_height_width_vmap(W_fft)
      z_fft = jnp.einsum("abij,abj->abi", W_fft_inv, image_fft)
      z = ifft_channel_vmap(z_fft).real

    # The log determinant is the sum over frequencies of the log det of the
    # CxC matrix at each frequency; negative because the forward direction
    # applies the inverse convolution
    log_det = -slogdet_height_width_vmap(W_fft).sum()

    if inverse:
      log_det = -log_det

    return z, log_det
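
A minimal usage sketch (shapes are hypothetical; assumes only the API shown on this page):

import jax.random as random
from generax.flows.conv import CircularConv

key = random.PRNGKey(0)
layer = CircularConv(input_shape=(8, 8, 3), filter_shape=(3, 3), key=key)

x = random.normal(key, (8, 8, 3))
z, log_det = layer(x)                  # forward direction (applies the inverse convolution)
x_rec, log_det_inv = layer.inverse(z)  # x_rec is approximately x, log_det_inv == -log_det
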
__init__(self, input_shape: Tuple[int], filter_shape: Tuple[int] = (3, 3), *, key: PRNGKeyArray, **kwargs)

Arguments:

  • input_shape: The input shape. The output shape is the same as the input shape.
  • key: A jax.random.PRNGKey for initialization
  • filter_shape: Height and width for the convolutional filter, (Kx, Ky). The full kernel will have shape (Kx, Ky, C, C)
Source code in generax/flows/conv.py
def __init__(self,
             input_shape: Tuple[int],
             filter_shape: Tuple[int]=(3, 3),
             *,
             key: PRNGKeyArray,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape.  The output shape is the same as the input shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  - `filter_shape`: Height and width for the convolutional filter, (Kx, Ky).  The full
                    kernel will have shape (Kx, Ky, C, C)
  """
  assert len(filter_shape) == 2
  self.filter_shape = filter_shape
  super().__init__(input_shape=input_shape,
                   **kwargs)

  H, W, C = input_shape

  w = random.normal(key, shape=self.filter_shape + (C, C))
  self.w = jax.vmap(jax.vmap(util.whiten))(w)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)

Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.

Source code in generax/flows/conv.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Initialize the parameters of the layer based on the data.

  **Arguments**:

  - `x`: The data to initialize the parameters with.
  - `y`: The conditioning information
  - `key`: A `jax.random.PRNGKey` for initialization

  **Returns**:
  A new layer with the parameters initialized.
  """
  return self
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array

See http://developer.download.nvidia.com/compute/cuda/2_2/sdk/website/projects/convolutionFFT2D/doc/convolutionFFT2D.pdf

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/conv.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """
  See http://developer.download.nvidia.com/compute/cuda/2_2/sdk/website/projects/convolutionFFT2D/doc/convolutionFFT2D.pdf
  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'
  H, W, C = x.shape

  Kx, Ky, _, _ = self.w.shape

  # See how much we need to roll the filter
  W_x = (Kx - 1) // 2
  W_y = (Ky - 1) // 2

  # Pad the filter to match the fft size and roll it so that its center is at (0,0)
  W_padded = jnp.pad(self.w[::-1,::-1,:,:], ((0, H - Kx), (0, W - Ky), (0, 0), (0, 0)))
  W_padded = jnp.roll(W_padded, (-W_x, -W_y), axis=(0, 1))

  # Apply the FFT to get the convolution
  image_fft = fft_channel_vmap(x)
  W_fft = fft_double_channel_vmap(W_padded)

  if inverse:
    # The inverse direction applies the convolution itself
    z_fft = jnp.einsum("abij,abj->abi", W_fft, image_fft)
    z = ifft_channel_vmap(z_fft).real
  else:
    # The forward direction applies the inverse convolution, so invert
    # W over the channel dims at each frequency
    W_fft_inv = inv_height_width_vmap(W_fft)
    z_fft = jnp.einsum("abij,abj->abi", W_fft_inv, image_fft)
    z = ifft_channel_vmap(z_fft).real

  # The log determinant is the sum over frequencies of the log det of the
  # CxC matrix at each frequency; negative because the forward direction
  # applies the inverse convolution
  log_det = -slogdet_height_width_vmap(W_fft).sum()

  if inverse:
    log_det = -log_det

  return z, log_det
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/conv.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)

generax.flows.conv.CaleyOrthogonalConv (BijectiveTransform)

Cayley parametrization of an orthogonal convolution. https://arxiv.org/pdf/2104.07167.pdf

Source code in generax/flows/conv.py
class CaleyOrthogonalConv(BijectiveTransform):
  """Caley parametrization of an orthogonal convolution.
        https://arxiv.org/pdf/2104.07167.pdf
  """
  filter_shape: Tuple[int] = eqx.field(static=True)
  v: Array
  g: Array

  def __init__(self,
               input_shape: Tuple[int],
               filter_shape: Tuple[int] = (3, 3),
               *,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape.  The output shape is the same as the input shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    - `filter_shape`: Height and width for the convolutional filter, (Kx, Ky).  The full
                      kernel will have shape (Kx, Ky, C, C)
    """
    assert len(filter_shape) == 2
    self.filter_shape = filter_shape
    super().__init__(input_shape=input_shape,
                     **kwargs)

    H, W, C = input_shape

    k1, k2 = random.split(key, 2)
    self.v = random.normal(k1, shape=self.filter_shape + (C, C))
    self.g = random.normal(k2, shape=())

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Array:
    """
    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    H, W, C = x.shape

    w = self.g*self.v/jnp.linalg.norm(self.v)

    Kx, Ky, _, _ = w.shape

    # See how much we need to roll the filter
    W_x = (Kx - 1) // 2
    W_y = (Ky - 1) // 2

    # Pad the filter to match the fft size and roll it so that its center is at (0,0)
    W_padded = jnp.pad(w[::-1,::-1,:,:], ((0, H - Kx), (0, W - Ky), (0, 0), (0, 0)))
    W_padded = jnp.roll(W_padded, (-W_x, -W_y), axis=(0, 1))

    # Apply the FFT to get the convolution
    image_fft = fft_channel_vmap(x)
    W_fft = fft_double_channel_vmap(W_padded)

    A_fft = W_fft - W_fft.conj().transpose((0, 1, 3, 2))
    I = jnp.eye(W_fft.shape[-1])

    if inverse:
      # z = (I - A)(I + A)^{-1} x at each frequency
      IpA_inv = inv_height_width_vmap(I[None,None] + A_fft)
      y_fft = jnp.einsum("abij,abj->abi", IpA_inv, image_fft)
      z_fft = y_fft - jnp.einsum("abij,abj->abi", A_fft, y_fft)
      z = ifft_channel_vmap(z_fft).real
    else:
      # z = (I + A)(I - A)^{-1} x at each frequency
      ImA_inv = inv_height_width_vmap(I[None,None] - A_fft)
      y_fft = jnp.einsum("abij,abj->abi", ImA_inv, image_fft)
      z_fft = y_fft + jnp.einsum("abij,abj->abi", A_fft, y_fft)
      z = ifft_channel_vmap(z_fft).real

    # The Cayley transform of a skew-Hermitian matrix is unitary, so the
    # log determinant is zero
    log_det = jnp.array(0.0)
    return z, log_det
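
The log determinant is exactly zero because the Cayley transform maps the skew-Hermitian matrix A at each frequency to a unitary matrix (I + A)(I - A)^{-1}. A quick standalone check of this fact in plain jax.numpy (independent of generax):

import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)
M = random.normal(key, (4, 4))
A = M - M.T                                  # skew-symmetric (real skew-Hermitian)

I = jnp.eye(4)
Q = (I + A) @ jnp.linalg.inv(I - A)          # Cayley transform
print(jnp.allclose(Q.T @ Q, I, atol=1e-5))   # True: Q is orthogonal
print(jnp.linalg.slogdet(Q)[1])              # ~0.0
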
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)

Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.

Source code in generax/flows/conv.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Initialize the parameters of the layer based on the data.

  **Arguments**:

  - `x`: The data to initialize the parameters with.
  - `y`: The conditioning information
  - `key`: A `jax.random.PRNGKey` for initialization

  **Returns**:
  A new layer with the parameters initialized.
  """
  return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/conv.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)
__init__(self, input_shape: Tuple[int], filter_shape: Tuple[int] = (3, 3), *, key: PRNGKeyArray, **kwargs)

Arguments:

  • input_shape: The input shape. The output shape is the same as the input shape.
  • key: A jax.random.PRNGKey for initialization
  • filter_shape: Height and width for the convolutional filter, (Kx, Ky). The full kernel will have shape (Kx, Ky, C, C)
Source code in generax/flows/conv.py
def __init__(self,
             input_shape: Tuple[int],
             filter_shape: Tuple[int] = (3, 3),
             *,
             key: PRNGKeyArray,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape.  The output shape is the same as the input shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  - `filter_shape`: Height and width for the convolutional filter, (Kx, Ky).  The full
                    kernel will have shape (Kx, Ky, C, C)
  """
  assert len(filter_shape) == 2
  self.filter_shape = filter_shape
  super().__init__(input_shape=input_shape,
                   **kwargs)

  H, W, C = input_shape

  k1, k2 = random.split(key, 2)
  self.v = random.normal(k1, shape=self.filter_shape + (C, C))
  self.g = random.normal(k2, shape=())
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/conv.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """
  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'
  H, W, C = x.shape

  w = self.g*self.v/jnp.linalg.norm(self.v)

  Kx, Ky, _, _ = w.shape

  # See how much we need to roll the filter
  W_x = (Kx - 1) // 2
  W_y = (Ky - 1) // 2

  # Pad the filter to match the fft size and roll it so that its center is at (0,0)
  W_padded = jnp.pad(w[::-1,::-1,:,:], ((0, H - Kx), (0, W - Ky), (0, 0), (0, 0)))
  W_padded = jnp.roll(W_padded, (-W_x, -W_y), axis=(0, 1))

  # Apply the FFT to get the convolution
  image_fft = fft_channel_vmap(x)
  W_fft = fft_double_channel_vmap(W_padded)

  A_fft = W_fft - W_fft.conj().transpose((0, 1, 3, 2))
  I = jnp.eye(W_fft.shape[-1])

  if inverse:
    # z = (I - A)(I + A)^{-1} x at each frequency
    IpA_inv = inv_height_width_vmap(I[None,None] + A_fft)
    y_fft = jnp.einsum("abij,abj->abi", IpA_inv, image_fft)
    z_fft = y_fft - jnp.einsum("abij,abj->abi", A_fft, y_fft)
    z = ifft_channel_vmap(z_fft).real
  else:
    # z = (I + A)(I - A)^{-1} x at each frequency
    ImA_inv = inv_height_width_vmap(I[None,None] - A_fft)
    y_fft = jnp.einsum("abij,abj->abi", ImA_inv, image_fft)
    z_fft = y_fft + jnp.einsum("abij,abj->abi", A_fft, y_fft)
    z = ifft_channel_vmap(z_fft).real

  # The Cayley transform of a skew-Hermitian matrix is unitary, so the
  # log determinant is zero
  log_det = jnp.array(0.0)
  return z, log_det
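
A minimal usage sketch (shapes are hypothetical):

import jax.random as random
from generax.flows.conv import CaleyOrthogonalConv

key = random.PRNGKey(0)
layer = CaleyOrthogonalConv(input_shape=(8, 8, 3), filter_shape=(3, 3), key=key)

x = random.normal(key, (8, 8, 3))
z, log_det = layer(x)        # log_det is exactly 0.0: the transform is orthogonal
x_rec, _ = layer.inverse(z)  # x_rec is approximately x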

generax.flows.conv.OneByOneConv (BijectiveTransform)

1x1 convolution. Uses a dense parametrization because the channel dimension will probably never be that big. Costs O(C^3). Used in GLOW https://arxiv.org/pdf/1807.03039.pdf

Source code in generax/flows/conv.py
class OneByOneConv(BijectiveTransform):
  """ 1x1 convolution.  Uses a dense parametrization because the channel dimension will probably
      never be that big.  Costs O(C^3).  Used in GLOW https://arxiv.org/pdf/1807.03039.pdf
  """
  w: Array

  def __init__(self,
               input_shape: Tuple[int],
               *,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape.  The output shape is the same as the input shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)

    H, W, C = input_shape
    w = random.normal(key, shape=(C, C))
    self.w = util.whiten(w)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Array:
    """
    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    H, W, C = x.shape

    # Using lax.conv instead of matrix multiplication over the channel dimension
    # is faster and also more numerically stable for some reason.

    # Run the flow
    if not inverse:
      z = util.conv(self.w[None,None,:,:], x)
    else:
      w_inv = jnp.linalg.inv(self.w)
      z = util.conv(w_inv[None,None,:,:], x)

    log_det = jnp.linalg.slogdet(self.w)[1]*H*W

    if inverse:
      log_det = -log_det

    return z, log_det
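
Why log_det is H*W times slogdet(w): a 1x1 convolution applies the same C x C matrix independently at every pixel, so the full Jacobian is block diagonal with H*W identical blocks. A standalone sanity check in plain JAX (tiny hypothetical shapes so the dense Jacobian is cheap to build):

import jax
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)
H, W, C = 2, 2, 3
w = random.normal(key, (C, C))
x = random.normal(key, (H, W, C))

f = lambda v: jnp.einsum('ij,hwj->hwi', w, v.reshape(H, W, C)).ravel()
J = jax.jacobian(f)(x.ravel())       # full (H*W*C, H*W*C) Jacobian
print(jnp.linalg.slogdet(J)[1])      # matches the line below
print(H * W * jnp.linalg.slogdet(w)[1])
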
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)

Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.

Source code in generax/flows/conv.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Initialize the parameters of the layer based on the data.

  **Arguments**:

  - `x`: The data to initialize the parameters with.
  - `y`: The conditioning information
  - `key`: A `jax.random.PRNGKey` for initialization

  **Returns**:
  A new layer with the parameters initialized.
  """
  return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/conv.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)
__init__(self, input_shape: Tuple[int], *, key: PRNGKeyArray, **kwargs)

Arguments:

  • input_shape: The input shape. The output shape is the same as the input shape.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/conv.py
def __init__(self,
             input_shape: Tuple[int],
             *,
             key: PRNGKeyArray,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape.  The output shape is the same as the input shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  """
  super().__init__(input_shape=input_shape,
                   **kwargs)

  H, W, C = input_shape
  w = random.normal(key, shape=(C, C))
  self.w = util.whiten(w)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/conv.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """
  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'
  H, W, C = x.shape

  # Using lax.conv instead of matrix multiplication over the channel dimension
  # is faster and also more numerically stable for some reason.

  # Run the flow
  if not inverse:
    z = util.conv(self.w[None,None,:,:], x)
  else:
    w_inv = jnp.linalg.inv(self.w)
    z = util.conv(w_inv[None,None,:,:], x)

  log_det = jnp.linalg.slogdet(self.w)[1]*H*W

  if inverse:
    log_det = -log_det

  return z, log_det
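
A minimal usage sketch (shapes are hypothetical):

import jax.random as random
from generax.flows.conv import OneByOneConv

key = random.PRNGKey(0)
layer = OneByOneConv(input_shape=(8, 8, 3), key=key)

x = random.normal(key, (8, 8, 3))
z, log_det = layer(x)
x_rec, _ = layer.inverse(z)  # x_rec is approximately x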

generax.flows.conv.HaarWavelet (BijectiveTransform)

Wavelet flow https://arxiv.org/pdf/2010.13821.pdf

Source code in generax/flows/conv.py
class HaarWavelet(BijectiveTransform):
  """Wavelet flow https://arxiv.org/pdf/2010.13821.pdf"""

  W: Array = eqx.field(static=True)
  output_shape: Tuple[int] = eqx.field(static=True)

  def __init__(self,
               input_shape: Tuple[int],
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape.  The output shape is `(H//2, W//2, 4*C)`.
    """
    H, W, C = input_shape
    if H%2 != 0:
      raise ValueError('Height must be even')
    if W%2 != 0:
      raise ValueError('Width must be even')
    super().__init__(input_shape=input_shape,
                     **kwargs)

    self.output_shape = (H//2, W//2, 4*C)

    # Construct the filter
    p, n = 0.5, -0.5
    W = np.array([[[p, p],
                   [p, p]],
                  [[p, n],
                   [p, n]],
                  [[p, p],
                   [n, n]],
                  [[p, n],
                   [n, p]]])
    W = W.transpose((1, 2, 0)) # (H, W, O)
    W = W[:,:,None,:] # (H, W, I, O).  We'll be applying this channelwise
    self.W = W

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Array:
    """
    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """

    def haar_conv(x):
      """(H, W) -> (H/2, W/2, 4))"""
      return util.conv(self.W, x[:,:,None], stride=2)

    if not inverse:
      H, W, C = x.shape
      z = jax.vmap(haar_conv, in_axes=-1, out_axes=-1)(x)

      # Rescale the lowpass so that it is the mean of each 2x2 block
      z = z.at[:,:,0].mul(0.5)

      z = einops.rearrange(z, 'H W D C -> H W (C D)', D=4)
    else:
      h = einops.rearrange(x, 'H W (C D) -> H W C D', D=4)

      # Undo the lowpass rescaling
      h = h.at[:,:,:,0].mul(2.0)

      h = einops.rearrange(h, 'H W C (M N) -> (H M) (W N) C', M=2, N=2)
      h = jax.vmap(haar_conv, in_axes=-1, out_axes=-1)(h)
      z = einops.rearrange(h, 'H W (M N) C -> (H M) (W N) C', M=2, N=2)

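    # The unrescaled Haar matrix is orthogonal, so only the 0.5 lowpass scaling
    # contributes to the Jacobian: log(0.5) for each of the total_dim/4 lowpass
    # coefficients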
    total_dim = util.list_prod(x.shape)
    log_det = jnp.log(0.5)*total_dim/4

    if inverse:
      log_det = -log_det
    return z, log_det
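
A minimal usage sketch (shapes are hypothetical), showing the shape change and the round trip:

import jax.random as random
from generax.flows.conv import HaarWavelet

layer = HaarWavelet(input_shape=(8, 8, 3))
x = random.normal(random.PRNGKey(0), (8, 8, 3))

z, log_det = layer(x)        # z.shape == (4, 4, 12)
x_rec, _ = layer.inverse(z)  # x_rec is approximately x
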
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)

Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.

Source code in generax/flows/conv.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Initialize the parameters of the layer based on the data.

  **Arguments**:

  - `x`: The data to initialize the parameters with.
  - `y`: The conditioning information
  - `key`: A `jax.random.PRNGKey` for initialization

  **Returns**:
  A new layer with the parameters initialized.
  """
  return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/conv.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)
__init__(self, input_shape: Tuple[int], **kwargs)

Arguments:

  • input_shape: The input shape. The output shape is (H//2, W//2, 4*C).
Source code in generax/flows/conv.py
def __init__(self,
             input_shape: Tuple[int],
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape.  The output shape is `(H//2, W//2, 4*C)`.
  """
  H, W, C = input_shape
  if H%2 != 0:
    raise ValueError('Height must be even')
  if W%2 != 0:
    raise ValueError('Width must be even')
  super().__init__(input_shape=input_shape,
                   **kwargs)

  self.output_shape = (H//2, W//2, 4*C)

  # Construct the filter
  p, n = 0.5, -0.5
  W = np.array([[[p, p],
                 [p, p]],
                [[p, n],
                 [p, n]],
                [[p, p],
                 [n, n]],
                [[p, n],
                 [n, p]]])
  W = W.transpose((1, 2, 0)) # (H, W, O)
  W = W[:,:,None,:] # (H, W, I, O).  We'll be applying this channelwise
  self.W = W
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/conv.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """
  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """

  def haar_conv(x):
    """(H, W) -> (H/2, W/2, 4))"""
    return util.conv(self.W, x[:,:,None], stride=2)

  if not inverse:
    H, W, C = x.shape
    z = jax.vmap(haar_conv, in_axes=-1, out_axes=-1)(x)

    # Rescale the lowpass so that it is the mean of each 2x2 block
    z = z.at[:,:,0].mul(0.5)

    z = einops.rearrange(z, 'H W D C -> H W (C D)', D=4)
  else:
    h = einops.rearrange(x, 'H W (C D) -> H W C D', D=4)

    # Undo the lowpass rescaling
    h = h.at[:,:,:,0].mul(2.0)

    h = einops.rearrange(h, 'H W C (M N) -> (H M) (W N) C', M=2, N=2)
    h = jax.vmap(haar_conv, in_axes=-1, out_axes=-1)(h)
    z = einops.rearrange(h, 'H W (M N) C -> (H M) (W N) C', M=2, N=2)

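  # The unrescaled Haar matrix is orthogonal, so only the 0.5 lowpass scaling
  # contributes to the Jacobian: log(0.5) for each of the total_dim/4 lowpass
  # coefficients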
  total_dim = util.list_prod(x.shape)
  log_det = jnp.log(0.5)*total_dim/4

  if inverse:
    log_det = -log_det
  return z, log_det

generax.flows.pac_flow.EmergingConv (PACFlow)

Emerging convolutions (https://arxiv.org/pdf/1901.11137.pdf). This is a special case of PAC flows.

Source code in generax/flows/pac_flow.py
class EmergingConv(PACFlow):
  """Emerging convolutions https://arxiv.org/pdf/1901.11137.pdf
  This is a special case of PAC flows
  """

  def __init__(self,
               input_shape: Tuple[int],
               kernel_size: int = 5,
               order_type: str = "s_curve",
               zero_init: bool = True,
               *,
               key: PRNGKeyArray,
               **kwargs):
    super().__init__(input_shape=input_shape,
                     feature_dim=None,
                     kernel_size=kernel_size,
                     order_type=order_type,
                     pixel_adaptive=False,
                     zero_init=zero_init,
                     key=key,
                     **kwargs)
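
A minimal usage sketch (shapes are hypothetical):

import jax.random as random
from generax.flows.pac_flow import EmergingConv

key = random.PRNGKey(0)
layer = EmergingConv(input_shape=(8, 8, 3), kernel_size=5, key=key)

x = random.normal(key, (8, 8, 3))
z, log_det = layer(x)
x_rec, _ = layer.inverse(z)  # x_rec is approximately x
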
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/flows/pac_flow.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Initialize the parameters of the layer based on the data.

  **Arguments**:

  - `x`: The data to initialize the parameters with.
  - `y`: The conditioning information
  - `key`: A `jax.random.PRNGKey` for initialization

  **Returns**:
  A new layer with the parameters initialized.
  """
  return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Apply the inverse transformation.

Arguments:

  • x: The input to the transformation
  • y: The conditioning information

Returns: (z, log_det)

Source code in generax/flows/pac_flow.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array

Inherited from generax.flows.pac_flow.PACFlow.__call__.

Source code in generax/flows/pac_flow.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """
  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'
  H, W, C = x.shape

  # Apply the linear function
  z, diag_jacobian = pac_ldu_mvp(x,
                                 self.theta,
                                 self.w,
                                 self.order,
                                 inverse=inverse,
                                 **self.im2col_kwargs)

  # Get the log det
  if self.theta is not None:
    flat_diag = diag_jacobian.reshape(self.theta.shape[:-3] + (-1,))
  else:
    flat_diag = diag_jacobian.reshape(x.shape[:1] + (-1,))

  log_det = jnp.log(jnp.abs(flat_diag)).sum()

  if inverse:
    log_det = -log_det

  assert z.shape == self.input_shape
  assert log_det.shape == ()

  return z, log_det
__init__(self, input_shape: Tuple[int], kernel_size: int = 5, order_type: str = 's_curve', zero_init: bool = True, *, key: PRNGKeyArray, **kwargs)
Source code in generax/flows/pac_flow.py
def __init__(self,
             input_shape: Tuple[int],
             kernel_size: int = 5,
             order_type: str = "s_curve",
             zero_init: bool = True,
             *,
             key: PRNGKeyArray,
             **kwargs):
  super().__init__(input_shape=input_shape,
                   feature_dim=None,
                   kernel_size=kernel_size,
                   order_type=order_type,
                   pixel_adaptive=False,
                   zero_init=zero_init,
                   key=key,
                   **kwargs)

generax.flows.pac_flow.PACFlow (BijectiveTransform)

Pixel adaptive convolutions. This becomes too numerically unstable to use in practice. https://eddiecunningham.github.io/pdfs/PAC_Flow.pdf

Source code in generax/flows/pac_flow.py
class PACFlow(BijectiveTransform):
  """Pixel adaptive convolutions.  Gets too numerically unstable to use in practice...
  https://eddiecunningham.github.io/pdfs/PAC_Flow.pdf
  """
  kernel_shape: Tuple[int] = eqx.field(static=True)
  feature_dim: int = eqx.field(static=True)
  order_type: str = eqx.field(static=True)
  pixel_adaptive: bool = eqx.field(static=True)
  im2col_kwargs: Any = eqx.field(static=True)

  order: Array = eqx.field(static=True)

  w: Array
  theta: Union[Array, None]

  def __init__(self,
               input_shape: Tuple[int],
               feature_dim: int = 8,
               kernel_size: int=5,
               order_type: str="s_curve",
               pixel_adaptive: bool=True,
               zero_init: bool = True,
               *,
               key: PRNGKeyArray,
               **kwargs):
    """
    **Arguments**:

    - `input_shape`: The input shape.  The output shape is the same as the input shape.
    - `feature_dim`: The dimension of the features
    - `kernel_size`: Side length of the square convolutional filter.  The full
                     kernel will have shape (kernel_size, kernel_size, C, C)
    - `order_type`: The order to convolve in.  Either "raster" or "s_curve"
    - `pixel_adaptive`: Whether to use pixel adaptive convolutions
    - `zero_init`: Whether to initialize the weights to zero
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)

    assert kernel_size%2 == 1
    self.kernel_shape   = (kernel_size, kernel_size)
    self.feature_dim    = feature_dim
    self.order_type     = order_type
    self.pixel_adaptive = pixel_adaptive

    H, W, C = input_shape

    # Extract the im2col kwargs
    Kx, Ky = self.kernel_shape
    pad = Kx//2
    self.im2col_kwargs = dict(filter_shape=self.kernel_shape,
                              stride=(1, 1),
                              padding=((pad, Kx - pad - 1),
                                       (pad, Ky - pad - 1)),
                              lhs_dilation=(1, 1),
                              rhs_dilation=(1, 1),
                              dimension_numbers=("NHWC", "HWIO", "NHWC"))

    # Determine the order to convolve
    order_shape = H, W, 1

    if self.order_type == "raster":
      order = np.arange(1, 1 + util.list_prod(order_shape)).reshape(order_shape)
    elif self.order_type == "s_curve":
      order = np.arange(1, 1 + util.list_prod(order_shape)).reshape(order_shape)
      order[::2, :, :] = order[::2, :, :][:, ::-1]
    order = order*1.0 # Turn into a float

    self.order = order

    # Initialize the weights
    k1, k2 = random.split(key, 2)
    w = random.normal(k1, shape=self.kernel_shape + (C, C))
    if zero_init:
      pad = Kx//2
      w = w.at[pad,pad,jnp.arange(C),jnp.arange(C)].set(1.0)
    self.w = w

    if self.pixel_adaptive:
      self.theta = random.normal(k2, shape=(H, W, 2*C + self.feature_dim))
    else:
      self.theta = None

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Array:
    """
    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    H, W, C = x.shape

    # Apply the linear function
    z, diag_jacobian = pac_ldu_mvp(x,
                                   self.theta,
                                   self.w,
                                   self.order,
                                   inverse=inverse,
                                   **self.im2col_kwargs)

    # Get the log det
    if self.theta is not None:
      flat_diag = diag_jacobian.reshape(self.theta.shape[:-3] + (-1,))
    else:
      flat_diag = diag_jacobian.reshape(x.shape[:1] + (-1,))

    log_det = jnp.log(jnp.abs(flat_diag)).sum()

    if inverse:
      log_det = -log_det

    assert z.shape == self.input_shape
    assert log_det.shape == ()

    return z, log_det
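
A minimal usage sketch (shapes are hypothetical; note the numerical-stability caveat above):

import jax.random as random
from generax.flows.pac_flow import PACFlow

key = random.PRNGKey(0)
layer = PACFlow(input_shape=(8, 8, 3), feature_dim=8, kernel_size=5, key=key)

x = random.normal(key, (8, 8, 3))
z, log_det = layer(x)
x_rec, _ = layer.inverse(z)
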
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)

Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.

Source code in generax/flows/pac_flow.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Initialize the parameters of the layer based on the data.

  **Arguments**:

  - `x`: The data to initialize the parameters with.
  - `y`: The conditioning information
  - `key`: A `jax.random.PRNGKey` for initialization

  **Returns**:
  A new layer with the parameters initialized.
  """
  return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/pac_flow.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)
__init__(self, input_shape: Tuple[int], feature_dim: int = 8, kernel_size: int = 5, order_type: str = 's_curve', pixel_adaptive: bool = True, zero_init: bool = True, *, key: PRNGKeyArray, **kwargs)

Arguments:

  • input_shape: The input shape. The output shape is the same as the input shape.
  • feature_dim: The dimension of the features
  • kernel_size: Side length of the square convolutional filter. The full kernel will have shape (kernel_size, kernel_size, C, C)
  • order_type: The order to convolve in. Either "raster" or "s_curve"
  • pixel_adaptive: Whether to use pixel adaptive convolutions
  • zero_init: Whether to initialize the weights to zero
Source code in generax/flows/pac_flow.py
def __init__(self,
             input_shape: Tuple[int],
             feature_dim: int = 8,
             kernel_size: int=5,
             order_type: str="s_curve",
             pixel_adaptive: bool=True,
             zero_init: bool = True,
             *,
             key: PRNGKeyArray,
             **kwargs):
  """
  **Arguments**:

  - `input_shape`: The input shape.  The output shape is the same as the input shape.
  - `feature_dim`: The dimension of the features
  - `kernel_size`: Side length of the square convolutional filter.  The full
                   kernel will have shape (kernel_size, kernel_size, C, C)
  - `order_type`: The order to convolve in.  Either "raster" or "s_curve"
  - `pixel_adaptive`: Whether to use pixel adaptive convolutions
  - `zero_init`: Whether to initialize the weights to zero
  """
  super().__init__(input_shape=input_shape,
                   **kwargs)

  assert kernel_size%2 == 1
  self.kernel_shape   = (kernel_size, kernel_size)
  self.feature_dim    = feature_dim
  self.order_type     = order_type
  self.pixel_adaptive = pixel_adaptive

  H, W, C = input_shape

  # Extract the im2col kwargs
  Kx, Ky = self.kernel_shape
  pad = Kx//2
  self.im2col_kwargs = dict(filter_shape=self.kernel_shape,
                            stride=(1, 1),
                            padding=((pad, Kx - pad - 1),
                                     (pad, Ky - pad - 1)),
                            lhs_dilation=(1, 1),
                            rhs_dilation=(1, 1),
                            dimension_numbers=("NHWC", "HWIO", "NHWC"))

  # Determine the order to convolve
  order_shape = H, W, 1

  if self.order_type == "raster":
    order = np.arange(1, 1 + util.list_prod(order_shape)).reshape(order_shape)
  elif self.order_type == "s_curve":
    order = np.arange(1, 1 + util.list_prod(order_shape)).reshape(order_shape)
    order[::2, :, :] = order[::2, :, :][:, ::-1]
  order = order*1.0 # Turn into a float

  self.order = order

  # Initialize the weights
  k1, k2 = random.split(key, 2)
  w = random.normal(k1, shape=self.kernel_shape + (C, C))
  if zero_init:
    pad = Kx//2
    w = w.at[pad,pad,jnp.arange(C),jnp.arange(C)].set(1.0)
  self.w = w

  if self.pixel_adaptive:
    self.theta = random.normal(k2, shape=(H, W, 2*C + self.feature_dim))
  else:
    self.theta = None
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/pac_flow.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """
  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'
  H, W, C = x.shape

  # Apply the linear function
  z, diag_jacobian = pac_ldu_mvp(x,
                                 self.theta,
                                 self.w,
                                 self.order,
                                 inverse=inverse,
                                 **self.im2col_kwargs)

  # Get the log det
  if self.theta is not None:
    flat_diag = diag_jacobian.reshape(self.theta.shape[:-3] + (-1,))
  else:
    flat_diag = diag_jacobian.reshape(x.shape[:1] + (-1,))

  log_det = jnp.log(jnp.abs(flat_diag)).sum()

  if inverse:
    log_det = -log_det

  assert z.shape == self.input_shape
  assert log_det.shape == ()

  return z, log_det