Skip to content

Neural network layers¤

generax.nn.layers.WeightNormDense ¤

Weight normalization parametrized linear layer https://arxiv.org/pdf/1602.07868.pdf

Source code in generax/nn/layers.py
class WeightNormDense(eqx.Module):
  """Weight normalization parametrized linear layer
  https://arxiv.org/pdf/1602.07868.pdf

  The forward pass computes `g * (W_hat @ x) + b`, where each row of `W_hat`
  is the matching row of `W` rescaled to unit L2 norm.
  """

  in_size: int = eqx.field(static=True)
  out_size: int = eqx.field(static=True)
  W: Array  # raw (unnormalized) weight, shape (out_size, in_size)
  b: Array  # bias, shape (out_size,)
  g: Array  # gain: scalar after __init__, shape (out_size,) after data_dependent_init

  def __init__(self,
               in_size: int,
               out_size: int,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `in_size`: Length of the input vector.
    - `out_size`: Length of the output vector.
    - `key`: A `jax.random.PRNGKey` used to initialize `W`.
    """
    super().__init__(**kwargs)

    self.in_size = in_size
    self.out_size = out_size

    # NOTE(review): W is laid out (out_size, in_size), so in_axis=-2 makes
    # he_uniform treat out_size as the fan-in. The row normalization in the
    # forward pass cancels W's overall scale; only the raw magnitude of W
    # (and therefore its gradient scale) depends on this choice.
    w_init = jax.nn.initializers.he_uniform(in_axis=-2, out_axis=-1)
    self.W = w_init(key, shape=(out_size, in_size))
    self.g = jnp.array(1.0)
    self.b = jnp.zeros(out_size)

  def _normalized_weight(self) -> Array:
    """Return `W` with every row rescaled to unit L2 norm."""
    return self.W*jax.lax.rsqrt((self.W**2).sum(axis=1))[:, None]

  def data_dependent_init(self,
                          x: Array,
                          key: PRNGKeyArray = None,
                          before_square_plus: Optional[bool] = False) -> eqx.Module:
    """Initialize the parameters of the layer based on the data.

    Chooses `g` and `b` so that, on `x`, every output unit has zero mean and
    (up to the 1e-5 stabilizer) unit standard deviation.

    **Arguments**:

    - `x`: The data to initialize the parameters with, of shape
      `(..., in_size)` with at least one leading batch axis.
    - `key`: A `jax.random.PRNGKey` for initialization (unused).
    - `before_square_plus`: In case we want the activations after square plus to be gaussian

    **Returns**:
    A new layer with the parameters initialized.
    """
    assert x.ndim >= 2 and x.shape[-1] == self.in_size, 'Only works on batched data'

    # Project with the normalized weight. '...' lets any number of leading
    # batch axes through; they are flattened into one axis right after.
    x = jnp.einsum('ij,...j->...i', self._normalized_weight(), x)
    x = x.reshape((-1, x.shape[-1]))

    std = jnp.std(x, axis=0) + 1e-5

    if before_square_plus:
      # NOTE(review): for std <= 1 this is <= 0, which makes g below
      # infinite or negative.
      std = std - 1/std

    g = 1/std

    # The bias cancels the per-unit mean of the rescaled activations.
    b = -jnp.mean(x*g, axis=0)

    # Turn the new parameters into a new module
    updated_layer = eqx.tree_at(lambda tree: tree.g, self, g)
    updated_layer = eqx.tree_at(lambda tree: tree.b, updated_layer, b)
    return updated_layer

  def __call__(self, x: Array, y: Array = None) -> Array:
    """Apply the layer to `x`, expected to be an unbatched vector of length
    `in_size`. `y` is accepted for interface compatibility and ignored."""
    return self.g*(self._normalized_weight()@x) + self.b
__init__(self, in_size: int, out_size: int, key: PRNGKeyArray, **kwargs) ¤
Source code in generax/nn/layers.py
def __init__(self,
             in_size: int,
             out_size: int,
             key: PRNGKeyArray,
             **kwargs):
  """Create the layer with He-uniform `W`, unit gain and zero bias.

  **Arguments**:

  - `in_size`: Length of the input vector.
  - `out_size`: Length of the output vector.
  - `key`: A `jax.random.PRNGKey` used to draw `W`.
  """
  super().__init__(**kwargs)
  initializer = jax.nn.initializers.he_uniform(in_axis=-2, out_axis=-1)

  # Static sizes.
  self.in_size, self.out_size = in_size, out_size

  # Trainable parameters.
  self.W = initializer(key, shape=(out_size, in_size))
  self.b = jnp.zeros(out_size)
  self.g = jnp.array(1.0)
data_dependent_init(self, x: Array, key: PRNGKeyArray = None, before_square_plus: Optional[bool] = False) -> Module ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • key: A jax.random.PRNGKey for initialization
  • before_square_plus: In case we want the activations after square plus to be gaussian

Returns: A new layer with the parameters initialized.

Source code in generax/nn/layers.py
def data_dependent_init(self,
                        x: Array,
                        key: PRNGKeyArray = None,
                        before_square_plus: Optional[bool] = False) -> eqx.Module:
  """Initialize the parameters of the layer based on the data.

  Chooses `g` and `b` so that, on `x`, every output unit has zero mean and
  (up to the 1e-5 stabilizer) unit standard deviation.

  **Arguments**:

  - `x`: The data to initialize the parameters with, of shape
    `(..., in_size)` with at least one leading batch axis.
  - `key`: A `jax.random.PRNGKey` for initialization (unused).
  - `before_square_plus`: In case we want the activations after square plus to be gaussian

  **Returns**:
  A new layer with the parameters initialized.
  """
  assert x.ndim >= 2 and x.shape[-1] == self.in_size, 'Only works on batched data'

  # Normalize each row of W to unit L2 norm, as the forward pass does.
  W = self.W*jax.lax.rsqrt((self.W**2).sum(axis=1))[:, None]
  # '...' accepts any number of leading batch axes; flatten them afterwards.
  x = jnp.einsum('ij,...j->...i', W, x)
  x = x.reshape((-1, x.shape[-1]))

  std = jnp.std(x, axis=0) + 1e-5

  if before_square_plus:
    # NOTE(review): for std <= 1 this is <= 0, which makes g below
    # infinite or negative.
    std = std - 1/std

  g = 1/std

  # The bias cancels the per-unit mean of the rescaled activations.
  b = -jnp.mean(x*g, axis=0)

  # Turn the new parameters into a new module
  updated_layer = eqx.tree_at(lambda tree: tree.g, self, g)
  updated_layer = eqx.tree_at(lambda tree: tree.b, updated_layer, b)
  return updated_layer
__call__(self, x: Array, y: Array = None) -> Array ¤

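Apply the layer to a single input vector: `g * (W_hat @ x) + b`, where each row of `W_hat` is the corresponding row of `W` scaled to unit L2 norm.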

Source code in generax/nn/layers.py
def __call__(self, x: Array, y: Array = None) -> Array:
  """Return `g * (W_hat @ x) + b`, where each row of `W_hat` is the
  corresponding row of `W` scaled to unit L2 norm. `y` is ignored."""
  inv_row_norms = jax.lax.rsqrt((self.W**2).sum(axis=1))
  W_hat = self.W*inv_row_norms[:, None]
  return self.g*(W_hat@x) + self.b

generax.nn.layers.WeightNormConv ¤

Weight normalization parametrized convolutional layer https://arxiv.org/pdf/1602.07868.pdf

Source code in generax/nn/layers.py
class WeightNormConv(eqx.Module):
  """Weight normalization parametrized convolutional layer
  https://arxiv.org/pdf/1602.07868.pdf

  The forward pass computes `g * conv(W_hat, x) + b`, where each output
  channel's filter in `W_hat` is the matching filter of `W` rescaled to unit
  L2 norm.
  """

  input_shape: Tuple[int] = eqx.field(static=True)  # (H, W, C) of one unbatched input
  out_size: int = eqx.field(static=True)
  filter_shape: Tuple[int] = eqx.field(static=True)
  padding: Union[int,str] = eqx.field(static=True)
  stride: int = eqx.field(static=True)
  W: Array  # shape filter_shape + (C, out_size)
  b: Array  # shape (out_size,)
  g: Array  # scalar after __init__, shape (out_size,) after data_dependent_init

  def __init__(self,
               input_shape: Tuple[int], # (H, W, C)
               filter_shape: Tuple[int],
               out_size: int,
               *,
               key: PRNGKeyArray,
               padding: Union[int,str] = 'SAME',
               stride: int = 1,
               **kwargs):
    """**Arguments**:

    - `input_shape`: `(H, W, C)` shape of a single unbatched input.
    - `filter_shape`: Spatial filter size, e.g. `(3, 3)`.
    - `out_size`: Number of output channels.
    - `key`: A `jax.random.PRNGKey` used to initialize `W`.
    - `padding`, `stride`: Forwarded to `util.conv`.
    """
    super().__init__(**kwargs)
    # Only the channel count is needed; unpacking also checks len(input_shape) == 3.
    _, _, C = input_shape

    self.input_shape = input_shape
    self.filter_shape = filter_shape
    self.out_size = out_size
    self.padding = padding
    self.stride = stride
    w_init = jax.nn.initializers.he_uniform(in_axis=-2, out_axis=-1)
    self.W = w_init(key, shape=self.filter_shape + (C, out_size))
    self.g = jnp.array(1.0)
    self.b = jnp.zeros(out_size)

  def _normalized_weight(self) -> Array:
    """Return `W` with every output-channel filter rescaled to unit L2 norm.

    Reducing over axes (0, 1, 2) assumes a 4-D weight, i.e. 2-D filters.
    """
    return self.W*jax.lax.rsqrt((self.W**2).sum(axis=(0, 1, 2)))[None,None,None,:]

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None,
                          before_square_plus: Optional[bool] = False) -> eqx.Module:
    """Initialize the parameters of the layer based on the data.

    Chooses `g` and `b` so that, on `x`, every output channel has zero mean
    and (up to the 1e-5 stabilizer) unit standard deviation.

    **Arguments**:

    - `x`: The data to initialize the parameters with, shape `(B,) + input_shape`.
    - `y`: Unused; accepted for interface compatibility.
    - `key`: A `jax.random.PRNGKey` for initialization (unused).
    - `before_square_plus`: In case we want the activations after square plus to be gaussian

    **Returns**:
    A new layer with the parameters initialized.
    """
    assert x.shape[1:] == self.input_shape, 'Only works on batched data'

    x = util.conv(self._normalized_weight(), x, stride=self.stride, padding=self.padding)
    x = x.reshape((-1, x.shape[-1]))

    std = jnp.std(x, axis=0) + 1e-5

    if before_square_plus:
      # NOTE(review): for std <= 1 this is <= 0, which makes g below
      # infinite or negative.
      std = std - 1/std

    g = 1/std

    # The bias cancels the per-channel mean of the rescaled activations.
    b = -jnp.mean(x*g, axis=0)

    # Turn the new parameters into a new module
    updated_layer = eqx.tree_at(lambda tree: tree.g, self, g)
    updated_layer = eqx.tree_at(lambda tree: tree.b, updated_layer, b)
    return updated_layer

  def __call__(self, x: Array, y: Array = None) -> Array:
    """Apply the layer to one unbatched input of shape `input_shape`.
    `y` is accepted for interface compatibility and ignored."""
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    conv_out = util.conv(self._normalized_weight(), x, stride=self.stride, padding=self.padding)
    return self.g*conv_out + self.b
__init__(self, input_shape: Tuple[int], filter_shape: Tuple[int], out_size: int, *, key: PRNGKeyArray, padding: Union[int, str] = 'SAME', stride: int = 1, **kwargs) ¤
Source code in generax/nn/layers.py
def __init__(self,
             input_shape: Tuple[int], # (H, W, C)
             filter_shape: Tuple[int],
             out_size: int,
             *,
             key: PRNGKeyArray,
             padding: Union[int,str] = 'SAME',
             stride: int = 1,
             **kwargs):
  """Create a weight-normalized convolution with He-uniform filters,
  unit gain and zero bias.

  **Arguments**:

  - `input_shape`: `(H, W, C)` shape of a single unbatched input.
  - `filter_shape`: Spatial filter size, e.g. `(3, 3)`.
  - `out_size`: Number of output channels.
  - `key`: A `jax.random.PRNGKey` used to draw `W`.
  - `padding`, `stride`: Forwarded to `util.conv`.
  """
  super().__init__(**kwargs)
  _, _, in_channels = input_shape

  # Static configuration.
  self.input_shape = input_shape
  self.filter_shape = filter_shape
  self.out_size = out_size
  self.padding = padding
  self.stride = stride

  # Trainable parameters.
  kernel_shape = filter_shape + (in_channels, out_size)
  initializer = jax.nn.initializers.he_uniform(in_axis=-2, out_axis=-1)
  self.W = initializer(key, shape=kernel_shape)
  self.g = jnp.array(1.0)
  self.b = jnp.zeros(out_size)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None, before_square_plus: Optional[bool] = False) -> Module ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • key: A jax.random.PRNGKey for initialization
  • before_square_plus: In case we want the activations after square plus to be gaussian

Returns: A new layer with the parameters initialized.

Source code in generax/nn/layers.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None,
                        before_square_plus: Optional[bool] = False) -> eqx.Module:
  """Return a copy of the layer whose `g` and `b` standardize its outputs on `x`.

  **Arguments**:

  - `x`: Batched data of shape `(B,) + input_shape`.
  - `y`: Unused; accepted for interface compatibility.
  - `key`: A `jax.random.PRNGKey` (unused).
  - `before_square_plus`: In case we want the activations after square plus to be gaussian

  **Returns**:
  A new layer with `g` and `b` replaced.
  """
  assert x.shape[1:] == self.input_shape, 'Only works on batched data'

  # Unit-norm filters per output channel, as used in the forward pass.
  inv_norms = jax.lax.rsqrt((self.W**2).sum(axis=(0, 1, 2)))
  W_hat = self.W*inv_norms[None, None, None, :]
  activations = util.conv(W_hat, x, stride=self.stride, padding=self.padding)
  per_channel = activations.reshape((-1, activations.shape[-1]))

  std = jnp.std(per_channel, axis=0) + 1e-5
  if before_square_plus:
    std = std - 1/std
  g = 1/std

  # Bias that removes the per-channel mean after scaling by g.
  b = -jnp.mean(per_channel*g, axis=0)

  with_gain = eqx.tree_at(lambda layer: layer.g, self, g)
  return eqx.tree_at(lambda layer: layer.b, with_gain, b)
__call__(self, x: Array, y: Array = None) -> Array ¤

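Apply the layer to a single unbatched input of shape `input_shape`: `g * conv(W_hat, x) + b`, where each output-channel filter of `W_hat` has unit L2 norm.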

Source code in generax/nn/layers.py
def __call__(self, x: Array, y: Array = None) -> Array:
  """Return `g * conv(W_hat, x) + b` for one unbatched input, where each
  output-channel filter of `W_hat` has unit L2 norm. `y` is ignored."""
  assert x.shape == self.input_shape, 'Only works on unbatched data'
  inv_norms = jax.lax.rsqrt((self.W**2).sum(axis=(0, 1, 2)))
  unit_filters = self.W*inv_norms[None,None,None,:]
  conv_out = util.conv(unit_filters, x, stride=self.stride, padding=self.padding)
  return self.g*conv_out + self.b

generax.nn.layers.WeightStandardizedConv ¤

Weight standardized parametrized convolutional layer https://arxiv.org/pdf/1903.10520.pdf

Source code in generax/nn/layers.py
class WeightStandardizedConv(eqx.Module):
  """Weight standardized parametrized convolutional layer
  https://arxiv.org/pdf/1903.10520.pdf

  Before every convolution each output-channel filter is shifted to zero mean
  and scaled by `1/sqrt(fan_in*var + 1e-5)`, where `fan_in = kh*kw*C_in`.
  """

  input_shape: Tuple[int] = eqx.field(static=True)  # (H, W, C) of one unbatched input
  out_size: int = eqx.field(static=True)
  filter_shape: Tuple[int] = eqx.field(static=True)
  padding: Union[int,str] = eqx.field(static=True)
  stride: int = eqx.field(static=True)
  W: Array  # shape filter_shape + (C, out_size)
  b: Array  # shape (out_size,)

  def __init__(self,
               input_shape: Tuple[int], # (H, W, C)
               filter_shape: Tuple[int],
               out_size: int,
               *,
               key: PRNGKeyArray,
               padding: Union[int,str] = 'SAME',
               stride: int = 1,
               **kwargs):
    """**Arguments**:

    - `input_shape`: `(H, W, C)` shape of a single unbatched input.
    - `filter_shape`: Spatial filter size, e.g. `(3, 3)`.
    - `out_size`: Number of output channels.
    - `key`: A `jax.random.PRNGKey` used to initialize `W`.
    - `padding`, `stride`: Forwarded to `util.conv`.
    """
    super().__init__(**kwargs)
    # Only the channel count is needed; unpacking also checks len(input_shape) == 3.
    _, _, C = input_shape

    self.input_shape = input_shape
    self.filter_shape = filter_shape
    self.out_size = out_size
    self.padding = padding
    self.stride = stride

    w_init = jax.nn.initializers.he_uniform(in_axis=-2, out_axis=-1)
    self.W = w_init(key, shape=self.filter_shape + (C, out_size))
    self.b = jnp.zeros(out_size)

  def _standardized_weight(self) -> Array:
    """Return `W` standardized per output channel.

    Every filter is shifted to zero mean and scaled by
    `1/sqrt(fan_in*var + 1e-5)`. The 4-way unpacking assumes 2-D filters.
    """
    axes = (0, 1, 2)
    mean = jnp.mean(self.W, axis=axes, keepdims=True)
    var = jnp.var(self.W, axis=axes, keepdims=True)

    kh, kw, C_in, _ = self.W.shape
    fan_in = kh*kw*C_in
    return (self.W - mean)*jax.lax.rsqrt(fan_in*var + 1e-5)

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> eqx.Module:
    """Initialize the parameters of the layer based on the data.

    Sets `b` so that each output channel has zero mean on `x`. This uses the
    exact same standardized weight as `__call__`, so the bias matches the
    activations the layer actually produces.

    **Arguments**:

    - `x`: The data to initialize the parameters with (batched).
    - `y`: Unused; accepted for interface compatibility.
    - `key`: A `jax.random.PRNGKey` for initialization (unused).

    **Returns**:
    A new layer with the parameters initialized.
    """
    x = util.conv(self._standardized_weight(), x, stride=self.stride, padding=self.padding)

    # Initialize b.
    mean = jnp.mean(x.reshape((-1, x.shape[-1])), axis=0)
    b = -mean

    # Turn the new parameters into a new module
    return eqx.tree_at(lambda tree: tree.b, self, b)

  def __call__(self, x: Array, y: Array = None) -> Array:
    """Apply the layer to one unbatched input of shape `input_shape`.
    `y` is accepted for interface compatibility and ignored."""
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    return util.conv(self._standardized_weight(), x, stride=self.stride, padding=self.padding) + self.b
__init__(self, input_shape: Tuple[int], filter_shape: Tuple[int], out_size: int, *, key: PRNGKeyArray, padding: Union[int, str] = 'SAME', stride: int = 1, **kwargs) ¤
Source code in generax/nn/layers.py
def __init__(self,
             input_shape: Tuple[int], # (H, W, C)
             filter_shape: Tuple[int],
             out_size: int,
             *,
             key: PRNGKeyArray,
             padding: Union[int,str] = 'SAME',
             stride: int = 1,
             **kwargs):
  """Create a weight-standardized convolution with He-uniform filters and
  zero bias.

  **Arguments**:

  - `input_shape`: `(H, W, C)` shape of a single unbatched input.
  - `filter_shape`: Spatial filter size, e.g. `(3, 3)`.
  - `out_size`: Number of output channels.
  - `key`: A `jax.random.PRNGKey` used to draw `W`.
  - `padding`, `stride`: Forwarded to `util.conv`.
  """
  super().__init__(**kwargs)
  _, _, in_channels = input_shape

  # Static configuration.
  self.input_shape = input_shape
  self.filter_shape = filter_shape
  self.out_size = out_size
  self.padding = padding
  self.stride = stride

  # Trainable parameters.
  initializer = jax.nn.initializers.he_uniform(in_axis=-2, out_axis=-1)
  self.W = initializer(key, shape=filter_shape + (in_channels, out_size))
  self.b = jnp.zeros(out_size)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> Module ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/nn/layers.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> eqx.Module:
  """Initialize the parameters of the layer based on the data.

  Sets `b` so that each output channel has zero mean on `x`.

  **Arguments**:

  - `x`: The data to initialize the parameters with (batched).
  - `y`: Unused; accepted for interface compatibility.
  - `key`: A `jax.random.PRNGKey` for initialization (unused).

  **Returns**:
  A new layer with the parameters initialized.
  """
  # Standardize the filters exactly as `__call__` does, including the fan_in
  # factor. Without it the activations here would be ~sqrt(fan_in) times
  # larger than at run time, and the resulting bias would be wrong.
  axes = (0, 1, 2)
  mean = jnp.mean(self.W, axis=axes, keepdims=True)
  var = jnp.var(self.W, axis=axes, keepdims=True)

  kh, kw, C_in, _ = self.W.shape
  fan_in = kh*kw*C_in
  W_hat = (self.W - mean)*jax.lax.rsqrt(fan_in*var + 1e-5)
  x = util.conv(W_hat, x, stride=self.stride, padding=self.padding)

  # Initialize b so the per-channel output mean is zero.
  b = -jnp.mean(x.reshape((-1, x.shape[-1])), axis=0)

  # Turn the new parameters into a new module
  return eqx.tree_at(lambda tree: tree.b, self, b)
__call__(self, x: Array, y: Array = None) -> Array ¤

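Apply the layer to a single unbatched input of shape `input_shape`. Before the convolution, each output-channel filter is shifted to zero mean and scaled by `1/sqrt(fan_in*var + 1e-5)`; the bias `b` is then added.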

Source code in generax/nn/layers.py
def __call__(self, x: Array, y: Array = None) -> Array:
  """Convolve one unbatched input with per-output-channel standardized
  filters and add the bias. `y` is ignored."""
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  kh, kw, c_in, _ = self.W.shape
  reduce_axes = (0, 1, 2)
  centered = self.W - jnp.mean(self.W, axis=reduce_axes, keepdims=True)
  variance = jnp.var(self.W, axis=reduce_axes, keepdims=True)
  # Scale by 1/sqrt(fan_in*var + eps), with fan_in = kh*kw*c_in.
  W_hat = centered*jax.lax.rsqrt(kh*kw*c_in*variance + 1e-5)
  return util.conv(W_hat, x, stride=self.stride, padding=self.padding) + self.b

generax.nn.layers.ChannelConvention ¤

ChannelConvention(*args, **kwargs)

Source code in generax/nn/layers.py
class ChannelConvention(eqx.Module):
  """Let a channels-first module operate on channels-last (H, W, C) arrays.

  The input is transposed to (C, H, W), passed through `module`, and the
  module's (C, H, W) output is transposed back to (H, W, C).
  """
  module: eqx.Module  # wrapped module; consumes and returns rank-3 channels-first arrays

  def __init__(self, module: eqx.Module):
    super().__init__()
    self.module = module

  def __call__(self, x):
    channels_first = einops.rearrange(x, 'H W C -> C H W')
    result = self.module(channels_first)
    return einops.rearrange(result, 'C H W -> H W C')
__init__(self, module: Module) ¤
Source code in generax/nn/layers.py
def __init__(self, module: eqx.Module):
  """Store `module`, which is applied to channels-first (C, H, W) arrays."""
  super().__init__()
  self.module = module
__call__(self, x) ¤

Transpose the (H, W, C) input to (C, H, W), apply the wrapped module, and transpose the result back to (H, W, C).

Source code in generax/nn/layers.py
def __call__(self, x):
  """Transpose (H, W, C) -> (C, H, W), apply the module, transpose back."""
  to_chw, to_hwc = 'H W C -> C H W', 'C H W -> H W C'
  return einops.rearrange(self.module(einops.rearrange(x, to_chw)), to_hwc)

generax.nn.layers.ConvAndGroupNorm ¤

Weight-standardized convolution followed by group normalization.

Source code in generax/nn/layers.py
class ConvAndGroupNorm(eqx.Module):
  """Weight standardized conv + group norm

  A `WeightStandardizedConv` followed by `eqx.nn.GroupNorm`, which is wrapped
  in `ChannelConvention` so it can be applied to the channels-last conv output.
  """
  input_shape: Tuple[int] = eqx.field(static=True)  # (H, W, C) of one unbatched input
  conv: WeightStandardizedConv
  norm: ChannelConvention  # wraps eqx.nn.GroupNorm

  def __init__(self,
               input_shape: Tuple[int], # (H, W, C)
               filter_shape: Tuple[int],
               out_size: int,
               groups: int,
               *,
               key: PRNGKeyArray,
               padding: Union[int,str] = 'SAME',
               stride: int = 1,
               **kwargs):
    """**Arguments**:

    - `input_shape`: `(H, W, C)` shape of a single unbatched input.
    - `filter_shape`: Spatial filter size of the conv.
    - `out_size`: Number of conv output channels.
    - `groups`: Number of GroupNorm groups; must divide `out_size`.
    - `key`: A `jax.random.PRNGKey` for the conv.
    - `padding`, `stride`: Forwarded to the conv.

    **Raises**:
    `ValueError` if `groups` does not divide `out_size`.
    """
    super().__init__(**kwargs)

    if out_size%groups != 0:
      raise ValueError("The number of groups must divide the number of channels.")

    self.conv = WeightStandardizedConv(input_shape=input_shape,
                                        filter_shape=filter_shape,
                                        out_size=out_size,
                                        key=key,
                                        padding=padding,
                                        stride=stride)
    self.norm = ChannelConvention(eqx.nn.GroupNorm(groups=groups, channels=out_size))
    self.input_shape = self.conv.input_shape

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          shift_scale: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> eqx.Module:
    """Initialize the parameters of the layer based on the data.

    Only the conv sub-layer is initialized; the group norm is left unchanged.

    **Arguments**:

    - `x`: The data to initialize the parameters with.
    - `y`: Forwarded to the conv's `data_dependent_init`.
    - `shift_scale`: Currently unused; accepted for interface compatibility.
    - `key`: A `jax.random.PRNGKey` for initialization.

    **Returns**:
    A new layer with the parameters initialized.
    """
    new_conv = self.conv.data_dependent_init(x, y, key=key)
    return eqx.tree_at(lambda tree: tree.conv, self, new_conv)

  def __call__(self,
               x: Array,
               y: Array = None) -> Array:
    """Apply conv then group norm to one unbatched input; `y` is ignored."""
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    x = self.conv(x)
    x = self.norm(x)
    return x
__init__(self, input_shape: Tuple[int], filter_shape: Tuple[int], out_size: int, groups: int, *, key: PRNGKeyArray, padding: Union[int, str] = 'SAME', stride: int = 1, **kwargs) ¤
Source code in generax/nn/layers.py
def __init__(self,
             input_shape: Tuple[int], # (H, W, C)
             filter_shape: Tuple[int],
             out_size: int,
             groups: int,
             *,
             key: PRNGKeyArray,
             padding: Union[int,str] = 'SAME',
             stride: int = 1,
             **kwargs):
  """Build the weight-standardized conv and the channels-last group norm.

  Raises `ValueError` when `groups` does not divide `out_size`.
  """
  super().__init__(**kwargs)

  divisible = out_size % groups == 0
  if not divisible:
    raise ValueError("The number of groups must divide the number of channels.")

  conv = WeightStandardizedConv(input_shape=input_shape,
                                filter_shape=filter_shape,
                                out_size=out_size,
                                key=key,
                                padding=padding,
                                stride=stride)
  self.conv = conv
  self.norm = ChannelConvention(eqx.nn.GroupNorm(groups=groups, channels=out_size))
  self.input_shape = conv.input_shape
data_dependent_init(self, x: Array, y: Optional[Array] = None, shift_scale: Optional[Array] = None, key: PRNGKeyArray = None) -> Module ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/nn/layers.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        shift_scale: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> eqx.Module:
  """Return a copy of the layer whose conv has been data-initialized on `x`.

  `y` and `key` are forwarded to the conv; `shift_scale` is unused and the
  group norm is left untouched.
  """
  initialized_conv = self.conv.data_dependent_init(x, y, key=key)
  return eqx.tree_at(lambda layer: layer.conv, self, initialized_conv)
__call__(self, x: Array, y: Array = None) -> Array ¤

Apply the weight-standardized convolution and then group normalization to a single unbatched input of shape `input_shape`.

Source code in generax/nn/layers.py
def __call__(self,
             x: Array,
             y: Array = None) -> Array:
  """Group-normalize the conv output of one unbatched input; `y` is ignored."""
  assert x.shape == self.input_shape, 'Only works on unbatched data'
  return self.norm(self.conv(x))

generax.nn.layers.Upsample ¤

Sub-pixel upsampling layer that doubles the spatial resolution of an (H, W, C) input (see https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf).

Source code in generax/nn/layers.py
class Upsample(eqx.Module):
  """https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf

  Doubles the spatial resolution of an unbatched (H, W, C) input. A 3x3
  weight-standardized conv produces `4*out_size` channels, SiLU is applied,
  and each group of 4 channels is rearranged into a 2x2 spatial block.
  """

  input_shape: Tuple[int] = eqx.field(static=True)  # (H, W, C)
  out_size: int = eqx.field(static=True)
  conv: WeightStandardizedConv

  def __init__(self,
               input_shape: Tuple[int],
               out_size: Optional[int] = None,
               *,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: `(H, W, C)` shape of a single unbatched input.
    - `out_size`: Output channels; defaults to `C`.
    - `key`: A `jax.random.PRNGKey` for the conv.
    """
    super().__init__(**kwargs)
    H, W, C = input_shape
    self.input_shape = input_shape
    self.out_size = out_size if out_size is not None else C
    self.conv = WeightStandardizedConv(input_shape=(H, W, C),
                                       filter_shape=(3, 3),
                                       out_size=4*self.out_size,
                                       key=key)

  def data_dependent_init(self, *args, **kwargs) -> eqx.Module:
    """No data-dependent initialization; the layer is returned unchanged."""
    return self

  def __call__(self, x: Array, y: Array = None) -> Array:
    """Return the (2H, 2W, out_size) upsampled input; `y` is ignored."""
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    H, W, C = x.shape
    x = self.conv(x)
    x = jax.nn.silu(x)
    # Depth-to-space: every 4 channels become one 2x2 spatial block.
    x = einops.rearrange(x, 'h w (c k1 k2) -> (h k1) (w k2) c', k1=2, k2=2)
    assert x.shape == (H*2, W*2, self.out_size)
    return x
__init__(self, input_shape: Tuple[int], out_size: Optional[int] = None, *, key: PRNGKeyArray, **kwargs) ¤
Source code in generax/nn/layers.py
def __init__(self,
             input_shape: Tuple[int],
             out_size: Optional[int] = None,
             *,
             key: PRNGKeyArray,
             **kwargs):
  """Build the 3x3 conv that feeds the depth-to-space rearrangement.

  `out_size` defaults to the input channel count `C`.
  """
  super().__init__(**kwargs)
  H, W, C = input_shape
  self.input_shape = input_shape
  if out_size is None:
    out_size = C
  self.out_size = out_size
  # 4x channels: each group of 4 becomes a 2x2 spatial block in __call__.
  self.conv = WeightStandardizedConv(input_shape=(H, W, C),
                                     filter_shape=(3, 3),
                                     out_size=4*out_size,
                                     key=key)
data_dependent_init(self, *args, **kwargs) -> Module ¤
Source code in generax/nn/layers.py
def data_dependent_init(self, *args, **kwargs) -> eqx.Module:
  """No data-dependent initialization: all arguments are ignored and the
  layer is returned unchanged."""
  del args, kwargs  # accepted only for interface compatibility
  return self
__call__(self, x: Array, y: Array = None) -> Array ¤

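Upsample a single unbatched (H, W, C) input to (2H, 2W, out_size). A 3×3 weight-standardized convolution to 4·out_size channels is followed by SiLU and a depth-to-space rearrangement.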

Source code in generax/nn/layers.py
def __call__(self, x: Array, y: Array = None) -> Array:
  """Upsample one unbatched (H, W, C) input to (2H, 2W, out_size); `y` is ignored."""
  assert x.shape == self.input_shape, 'Only works on unbatched data'
  H, W, _ = x.shape
  features = jax.nn.silu(self.conv(x))
  # Depth-to-space: split each group of 4 channels into a 2x2 spatial block.
  upsampled = einops.rearrange(features, 'h w (c k1 k2) -> (h k1) (w k2) c', k1=2, k2=2)
  assert upsampled.shape == (H*2, W*2, self.out_size)
  return upsampled

generax.nn.layers.Downsample ¤

Downsample(*args, **kwargs)

Source code in generax/nn/layers.py
class Downsample(eqx.Module):
  """Halve the spatial resolution of an unbatched (H, W, C) input.

  Each 2x2 spatial patch is folded into the channel axis (space-to-depth,
  giving 4*C channels) and a 3x3 weight-standardized conv maps the result to
  `out_size` channels. H and W must be even for the rearrangement to succeed.
  """

  input_shape: Tuple[int] = eqx.field(static=True)  # (H, W, C)
  out_size: int = eqx.field(static=True)
  conv: WeightStandardizedConv

  def __init__(self,
               input_shape: Tuple[int],
               out_size: Optional[int] = None,
               *,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: `(H, W, C)` shape of a single unbatched input.
    - `out_size`: Output channels; defaults to `C`.
    - `key`: A `jax.random.PRNGKey` for the conv.
    """
    super().__init__(**kwargs)
    H, W, C = input_shape
    self.input_shape = input_shape
    self.out_size = out_size if out_size is not None else C
    self.conv = WeightStandardizedConv(input_shape=(H//2, W//2, C*4),
                                       filter_shape=(3, 3),
                                       out_size=self.out_size,
                                       key=key)

  def data_dependent_init(self, *args, **kwargs) -> eqx.Module:
    """No data-dependent initialization; the layer is returned unchanged."""
    return self

  def __call__(self, x: Array, y: Array = None) -> Array:
    """Return the (H/2, W/2, out_size) downsampled input; `y` is ignored."""
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    H, W, C = x.shape
    # Space-to-depth: fold each 2x2 patch into the channel axis.
    x = einops.rearrange(x, '(h k1) (w k2) c -> h w (c k1 k2)', k1=2, k2=2)
    x = self.conv(x)
    assert x.shape == (H//2, W//2, self.out_size)
    return x
__init__(self, input_shape: Tuple[int], out_size: Optional[int] = None, *, key: PRNGKeyArray, **kwargs) ¤
Source code in generax/nn/layers.py
def __init__(self,
             input_shape: Tuple[int],
             out_size: Optional[int] = None,
             *,
             key: PRNGKeyArray,
             **kwargs):
  """Build the 3x3 conv applied after the space-to-depth rearrangement.

  `out_size` defaults to the input channel count `C`.
  """
  super().__init__(**kwargs)
  H, W, C = input_shape
  self.input_shape = input_shape
  if out_size is None:
    out_size = C
  self.out_size = out_size
  # The conv sees half the resolution and 4x the channels.
  folded_shape = (H//2, W//2, C*4)
  self.conv = WeightStandardizedConv(input_shape=folded_shape,
                                     filter_shape=(3, 3),
                                     out_size=out_size,
                                     key=key)
data_dependent_init(self, *args, **kwargs) -> Module ¤
Source code in generax/nn/layers.py
def data_dependent_init(self, *args, **kwargs) -> eqx.Module:
  """No data-dependent initialization: all arguments are ignored and the
  layer is returned unchanged."""
  del args, kwargs  # accepted only for interface compatibility
  return self
__call__(self, x: Array, y: Array = None) -> Array ¤

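Downsample a single unbatched (H, W, C) input to (H/2, W/2, out_size). Each 2×2 patch is folded into 4·C channels, followed by a 3×3 weight-standardized convolution.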

Source code in generax/nn/layers.py
def __call__(self, x: Array, y: Array = None) -> Array:
  """Downsample one unbatched (H, W, C) input to (H/2, W/2, out_size); `y` is ignored."""
  assert x.shape == self.input_shape, 'Only works on unbatched data'
  H, W, _ = x.shape
  # Space-to-depth: fold each 2x2 patch into the channel axis (C -> 4C).
  folded = einops.rearrange(x, '(h k1) (w k2) c -> h w (c k1 k2)', k1=2, k2=2)
  out = self.conv(folded)
  assert out.shape == (H//2, W//2, self.out_size)
  return out

generax.nn.layers.GatedGlobalContext ¤

Modified version of https://arxiv.org/pdf/1904.11492.pdf used in imagen https://github.com/lucidrains/imagen-pytorch/

Source code in generax/nn/layers.py
class GatedGlobalContext(eqx.Module):
  """Modified version of https://arxiv.org/pdf/1904.11492.pdf used in imagen https://github.com/lucidrains/imagen-pytorch/

  A 1x1 conv gives one logit per pixel. A softmax over all pixels turns those
  logits into weights that pool the input into a single C-vector. That vector
  goes through Dense -> SiLU -> Dense -> sigmoid to give per-channel gates,
  which multiply the input.
  """

  input_shape: Tuple[int] = eqx.field(static=True)  # (H, W, C)
  linear1: WeightNormDense
  linear2: WeightNormDense
  context_conv: WeightNormConv

  def __init__(self,
               input_shape: Tuple[int],
               *,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: `(H, W, C)` shape of a single unbatched input.
    - `key`: A `jax.random.PRNGKey` split across the three sub-layers.
    """
    super().__init__(**kwargs)
    H, W, C = input_shape
    self.input_shape = input_shape
    out_size = C

    # Bottleneck width: half the channels, but never fewer than 3.
    hidden_dim = max(3, out_size//2)
    k1, k2, k3 = random.split(key, 3)
    self.linear1 = WeightNormDense(in_size=C,
                                   out_size=hidden_dim,
                                   key=k1)

    self.linear2 = WeightNormDense(in_size=hidden_dim,
                                   out_size=out_size,
                                   key=k2)

    self.context_conv = WeightNormConv(input_shape=input_shape,
                                       filter_shape=(1, 1),
                                       out_size=1,
                                       key=k3)

  def data_dependent_init(self, *args, **kwargs) -> eqx.Module:
    """No data-dependent initialization; the layer is returned unchanged."""
    return self

  def __call__(self, x: Array, y: Array = None) -> Array:
    """Gate the channels of one unbatched input; `y` is ignored."""
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    x_in = x
    H, W, C = x.shape

    # Reduce channels to (H, W, 1)
    context = self.context_conv(x)

    # Flatten
    c_flat = einops.rearrange(context, 'h w c -> (h w) c')
    x_flat = einops.rearrange(x, 'h w c -> (h w) c')

    # Context over the pixels
    c_sm = jax.nn.softmax(c_flat, axis=0)

    # Reweight the channels
    out = jnp.einsum('tu,tv->uv', c_sm, x_flat)
    assert out.shape == (1, C)
    out = out[0]

    out = self.linear1(out)
    out = jax.nn.silu(out)
    out = self.linear2(out)
    out = jax.nn.sigmoid(out)
    return x_in*out[None,None,:]
__init__(self, input_shape: Tuple[int], *, key: PRNGKeyArray, **kwargs) ¤
Source code in generax/nn/layers.py
def __init__(self,
             input_shape: Tuple[int],
             *,
             key: PRNGKeyArray,
             **kwargs):
  """Build the context conv and the two-layer gating MLP.

  `input_shape` is the `(H, W, C)` shape of one unbatched input.
  """
  super().__init__(**kwargs)
  _, _, C = input_shape
  self.input_shape = input_shape

  # Bottleneck width: half the channels, but never fewer than 3.
  hidden_dim = max(3, C//2)
  key_dense1, key_dense2, key_context = random.split(key, 3)

  self.linear1 = WeightNormDense(in_size=C, out_size=hidden_dim, key=key_dense1)
  self.linear2 = WeightNormDense(in_size=hidden_dim, out_size=C, key=key_dense2)
  # 1x1 conv giving one context logit per pixel.
  self.context_conv = WeightNormConv(input_shape=input_shape,
                                     filter_shape=(1, 1),
                                     out_size=1,
                                     key=key_context)
data_dependent_init(self, *args, **kwargs) -> Module ¤
Source code in generax/nn/layers.py
def data_dependent_init(self, *args, **kwargs) -> eqx.Module:
  """No data-dependent initialization: all arguments are ignored and the
  layer is returned unchanged."""
  del args, kwargs  # accepted only for interface compatibility
  return self
__call__(self, x: Array, y: Array = None) -> Array ¤

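Gate the channels of a single unbatched (H, W, C) input. A softmax over per-pixel context logits pools the input into a C-vector, which Dense → SiLU → Dense → sigmoid maps to per-channel gates that multiply the input.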

Source code in generax/nn/layers.py
def __call__(self, x: Array, y: Array = None) -> Array:
  """Scale each channel of one unbatched (H, W, C) input by a learned gate
  computed from a softmax-weighted pooling over pixels. `y` is ignored."""
  assert x.shape == self.input_shape, 'Only works on unbatched data'
  H, W, C = x.shape

  # One logit per pixel, normalized over all H*W pixels.
  logits = einops.rearrange(self.context_conv(x), 'h w c -> (h w) c')
  weights = jax.nn.softmax(logits, axis=0)

  # Pool the pixels into a single C-vector using those weights.
  pixels = einops.rearrange(x, 'h w c -> (h w) c')
  pooled = jnp.einsum('tu,tv->uv', weights, pixels)
  assert pooled.shape == (1, C)

  # Bottleneck MLP -> per-channel gates in (0, 1).
  hidden = jax.nn.silu(self.linear1(pooled[0]))
  gates = jax.nn.sigmoid(self.linear2(hidden))
  return x*gates[None,None,:]

generax.nn.layers.Attention ¤

Attention(*args, **kwargs)

Source code in generax/nn/layers.py
class Attention(eqx.Module):
  """Multi-head self-attention over the spatial positions of an (H, W, C) input.

  Queries and keys are L2-normalized along each head's `dim_head` features,
  so the attention logits are cosine similarities multiplied by `scale`.
  """

  input_shape: Tuple[int] = eqx.field(static=True)  # (H, W, C)
  heads: int = eqx.field(static=True)
  dim_head: int = eqx.field(static=True)
  scale: float = eqx.field(static=True)

  # 1x1 eqx.nn.Conv2d layers wrapped so they act on (H, W, C) arrays.
  conv_in: ChannelConvention
  conv_out: ChannelConvention

  def __init__(self,
               input_shape: Tuple[int],
               heads: int = 4,
               dim_head: int = 32,
               scale: float = 10,
               *,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: `(H, W, C)` shape of a single unbatched input.
    - `heads`: Number of attention heads.
    - `dim_head`: Feature size per head.
    - `scale`: Multiplier applied to the cosine-similarity logits.
    - `key`: A `jax.random.PRNGKey` for the two 1x1 convs.
    """
    super().__init__(**kwargs)
    H, W, C = input_shape
    self.input_shape = input_shape
    self.heads = heads
    self.dim_head = dim_head
    self.scale = scale

    k1, k2 = random.split(key, 2)
    dim = self.dim_head*self.heads
    self.conv_in = ChannelConvention(eqx.nn.Conv2d(in_channels=C,
                                                   out_channels=3*dim,
                                                   kernel_size=1,
                                                   use_bias=False,
                                                   key=k1))
    self.conv_out = ChannelConvention(eqx.nn.Conv2d(in_channels=dim,
                                                    out_channels=C,
                                                    kernel_size=1,
                                                    use_bias=True,
                                                    key=k2))

  def data_dependent_init(self, *args, **kwargs) -> eqx.Module:
    """No data-dependent initialization; the layer is returned unchanged."""
    return self

  def __call__(self, x: Array, y: Array = None) -> Array:
    """Attend across the H*W positions of one unbatched input; `y` is ignored."""
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    H, W, C = x.shape
    qkv = self.conv_in(x) # (H, W, 3*heads*dim_head)
    qkv = einops.rearrange(qkv, 'H W (u h d) -> (H W) h d u', h=self.heads, d=self.dim_head, u=3)
    q, k, v = jnp.split(qkv, 3, axis=-1)
    q, k, v = q[...,0], k[...,0], v[...,0]
    assert q.shape == k.shape == v.shape == (H*W, self.heads, self.dim_head)

    def normalize(t):
      # L2-normalize every (position, head) vector over its dim_head
      # features (axis -1). Normalizing over axis 0 (the H*W positions), as
      # before, shrank the logits toward uniform attention as images grew.
      return t/jnp.clip(jnp.linalg.norm(t, axis=-1, keepdims=True), 1e-8)
    q, k = normalize(q), normalize(k)

    sim = jnp.einsum('ihd,jhd->hij', q, k)*self.scale
    attn = jax.nn.softmax(sim, axis=-1)
    assert attn.shape == (self.heads, H*W, H*W)

    out = jnp.einsum('hij,jhd->hid', attn, v)
    out = einops.rearrange(out, 'h (H W) d -> H W (h d)', H=H, W=W, h=self.heads, d=self.dim_head)
    assert out.shape == (H, W, self.dim_head*self.heads)

    out = self.conv_out(out)
    return out
__init__(self, input_shape: Tuple[int], heads: int = 4, dim_head: int = 32, scale: float = 10, *, key: PRNGKeyArray, **kwargs) ¤
Source code in generax/nn/layers.py
def __init__(self,
             input_shape: Tuple[int],
             heads: int = 4,
             dim_head: int = 32,
             scale: float = 10,
             *,
             key: PRNGKeyArray,
             **kwargs):
  """Create the fused q/k/v projection and the output projection.

  **Arguments**:

  - `input_shape`: The unbatched input shape, unpacked as `(H, W, C)`.
  - `heads`: Number of attention heads.
  - `dim_head`: Feature size of each head.
  - `scale`: Multiplier applied to the query-key similarities.
  - `key`: A `jax.random.PRNGKey` used to initialize both projections.
  """
  super().__init__(**kwargs)
  _, _, channels = input_shape
  self.input_shape = input_shape
  self.heads = heads
  self.dim_head = dim_head
  self.scale = scale

  key_in, key_out = random.split(key, 2)
  inner_dim = heads*dim_head

  # 1x1 convolutions: C -> 3*inner_dim (q, k, v fused, no bias) and back to C.
  qkv_conv = eqx.nn.Conv2d(in_channels=channels,
                           out_channels=3*inner_dim,
                           kernel_size=1,
                           use_bias=False,
                           key=key_in)
  self.conv_in = ChannelConvention(qkv_conv)

  out_conv = eqx.nn.Conv2d(in_channels=inner_dim,
                           out_channels=channels,
                           kernel_size=1,
                           use_bias=True,
                           key=key_out)
  self.conv_out = ChannelConvention(out_conv)
data_dependent_init(self, *args, **kwargs) -> Module ¤
Source code in generax/nn/layers.py
def data_dependent_init(self, *args, **kwargs) -> eqx.Module:
  """Return this module as-is; no data-dependent initialization is done.

  Any positional or keyword arguments are ignored.
  """
  unchanged_layer = self
  return unchanged_layer
__call__(self, x: Array, y: Array = None) -> Array ¤

Call self as a function.

Source code in generax/nn/layers.py
def __call__(self, x: Array, y: Array = None) -> Array:
  """Apply multi-head cosine-similarity self-attention over spatial positions.

  **Arguments**:

  - `x`: A single (unbatched) input with shape `self.input_shape`, unpacked
    as `(H, W, C)`.
  - `y`: Ignored; accepted for interface compatibility with other layers.

  **Returns**:
  The output of `self.conv_out` applied to the attended features.
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'
  H, W, C = x.shape
  qkv = self.conv_in(x) # (H, W, heads*dim_head*3)
  # Split channels into (q/k/v selector, head, head feature) and flatten the
  # spatial grid into a sequence of H*W tokens.
  qkv = einops.rearrange(qkv, 'H W (u h d) -> (H W) h d u', h=self.heads, d=self.dim_head, u=3)
  q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
  assert q.shape == k.shape == v.shape == (H*W, self.heads, self.dim_head)

  def normalize(t):
    # L2-normalize each token's per-head feature vector (the last axis) so
    # that q.k below is a cosine similarity. Normalizing over axis 0 would
    # instead divide by a norm taken across all H*W positions.
    return t/jnp.clip(jnp.linalg.norm(t, axis=-1, keepdims=True), 1e-8)
  q, k = normalize(q), normalize(k)

  # Cosine similarities lie in [-1, 1]; `scale` sharpens the softmax.
  sim = jnp.einsum('ihd,jhd->hij', q, k)*self.scale
  attn = jax.nn.softmax(sim, axis=-1)
  assert attn.shape == (self.heads, H*W, H*W)

  out = jnp.einsum('hij,jhd->hid', attn, v)
  out = einops.rearrange(out, 'h (H W) d -> H W (h d)', H=H, W=W, h=self.heads, d=self.dim_head)
  assert out.shape == (H, W, self.dim_head*self.heads)

  return self.conv_out(out)

generax.nn.layers.LinearAttention ¤

LinearAttention(*args, **kwargs)

Source code in generax/nn/layers.py
class LinearAttention(eqx.Module):
  """Multi-head linear attention over the spatial positions of an image.

  Queries are softmax-normalized over their feature axis and keys over the
  H*W positions. This lets the key/value product be summarized in a per-head
  `(dim_head, dim_head)` context matrix instead of an `(H*W, H*W)` attention
  matrix. The result is projected back to `C` channels and passed through a
  bias-free `LayerNorm`.
  """

  input_shape: Tuple[int, ...] = eqx.field(static=True)
  heads: int = eqx.field(static=True)
  dim_head: int = eqx.field(static=True)

  # Both projections are 1x1 `eqx.nn.Conv2d` layers wrapped in
  # `ChannelConvention` (see `__init__`).
  conv_in: eqx.Module
  conv_out: eqx.Module
  norm: eqx.nn.LayerNorm

  def __init__(self,
               input_shape: Tuple[int],
               heads: int = 4,
               dim_head: int = 32,
               *,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The unbatched input shape, unpacked as `(H, W, C)`.
    - `heads`: Number of attention heads.
    - `dim_head`: Feature size of each head.
    - `key`: A `jax.random.PRNGKey` used to initialize both projections.
    """
    super().__init__(**kwargs)
    H, W, C = input_shape
    self.input_shape = input_shape
    self.heads = heads
    self.dim_head = dim_head

    k1, k2 = random.split(key, 2)
    dim = self.dim_head*self.heads
    # C -> 3*dim channels holding q, k and v (no bias).
    self.conv_in = ChannelConvention(eqx.nn.Conv2d(in_channels=C,
                                                   out_channels=3*dim,
                                                   kernel_size=1,
                                                   use_bias=False,
                                                   key=k1))
    # dim -> C channels.
    self.conv_out = ChannelConvention(eqx.nn.Conv2d(in_channels=dim,
                                                    out_channels=C,
                                                    kernel_size=1,
                                                    use_bias=True,
                                                    key=k2))
    self.norm = eqx.nn.LayerNorm(shape=(C,), use_bias=False)

  def data_dependent_init(self, *args, **kwargs) -> eqx.Module:
    """No data-dependent initialization; returns this module unchanged."""
    return self

  def __call__(self, x: Array, y: Array = None) -> Array:
    """Apply linear attention to a single (unbatched) input.

    **Arguments**:

    - `x`: An array with shape `self.input_shape`, unpacked as `(H, W, C)`.
    - `y`: Ignored; accepted for interface compatibility with other layers.

    **Returns**:
    An array with shape `(H, W, C)`.
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    H, W, C = x.shape
    qkv = self.conv_in(x) # (H, W, heads*dim_head*3)
    qkv = einops.rearrange(qkv, 'H W (u h d) -> (H W) h d u', h=self.heads, d=self.dim_head, u=3)
    q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
    assert q.shape == k.shape == v.shape == (H*W, self.heads, self.dim_head)

    # Queries: softmax over each head's features (last axis).
    # Keys: softmax over the H*W positions (axis -3).
    q = jax.nn.softmax(q, axis=-1)
    k = jax.nn.softmax(k, axis=-3)

    q = q/jnp.sqrt(self.dim_head)
    v = v/(H*W)

    # Contract keys with values into a (dim_head, dim_head) context per head,
    # then read it out with each query; no (H*W, H*W) matrix is formed.
    context = jnp.einsum("n h d, n h e -> h d e", k, v)
    out = jnp.einsum("h d e, n h d -> h e n", context, q)
    out = einops.rearrange(out, "h e (x y) -> x y (h e)", x=H)
    assert out.shape == (H, W, self.dim_head*self.heads)

    out = self.conv_out(out)
    # Bias-free LayerNorm over the channel axis at every spatial position.
    out = eqx.filter_vmap(eqx.filter_vmap(self.norm))(out)
    return out
__init__(self, input_shape: Tuple[int], heads: int = 4, dim_head: int = 32, *, key: PRNGKeyArray, **kwargs) ¤
Source code in generax/nn/layers.py
def __init__(self,
             input_shape: Tuple[int],
             heads: int = 4,
             dim_head: int = 32,
             *,
             key: PRNGKeyArray,
             **kwargs):
  """Create the q/k/v projection, the output projection and the output norm.

  **Arguments**:

  - `input_shape`: The unbatched input shape, unpacked as `(H, W, C)`.
  - `heads`: Number of attention heads.
  - `dim_head`: Feature size of each head.
  - `key`: A `jax.random.PRNGKey` used to initialize both projections.
  """
  super().__init__(**kwargs)
  _, _, channels = input_shape
  self.input_shape = input_shape
  self.heads = heads
  self.dim_head = dim_head

  key_in, key_out = random.split(key, 2)
  inner_dim = heads*dim_head

  qkv_conv = eqx.nn.Conv2d(in_channels=channels,
                           out_channels=3*inner_dim,
                           kernel_size=1,
                           use_bias=False,
                           key=key_in)
  self.conv_in = ChannelConvention(qkv_conv)

  out_conv = eqx.nn.Conv2d(in_channels=inner_dim,
                           out_channels=channels,
                           kernel_size=1,
                           use_bias=True,
                           key=key_out)
  self.conv_out = ChannelConvention(out_conv)

  self.norm = eqx.nn.LayerNorm(shape=(channels,), use_bias=False)
data_dependent_init(self, *args, **kwargs) -> Module ¤
Source code in generax/nn/layers.py
def data_dependent_init(self, *args, **kwargs) -> eqx.Module:
  """Linear attention needs no data-dependent setup; return self unchanged.

  Positional and keyword arguments are ignored.
  """
  unchanged_layer = self
  return unchanged_layer
__call__(self, x: Array, y: Array = None) -> Array ¤

Call self as a function.

Source code in generax/nn/layers.py
def __call__(self, x: Array, y: Array = None) -> Array:
  """Apply linear attention to a single (unbatched) input.

  **Arguments**:

  - `x`: An array with shape `self.input_shape`, unpacked as `(H, W, C)`.
  - `y`: Ignored; accepted for interface compatibility with other layers.

  **Returns**:
  The normalized output of `self.conv_out`.
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'
  height, width, _ = x.shape
  num_tokens = height*width

  projected = self.conv_in(x)
  tokens = einops.rearrange(projected, 'H W (u h d) -> (H W) h d u', h=self.heads, d=self.dim_head, u=3)
  query, key_, value = (tokens[..., idx] for idx in range(3))
  assert query.shape == key_.shape == value.shape == (num_tokens, self.heads, self.dim_head)

  # Feature-wise softmax for queries, position-wise (axis -3) for keys.
  query = jax.nn.softmax(query, axis=-1)/jnp.sqrt(self.dim_head)
  key_ = jax.nn.softmax(key_, axis=-3)
  value = value/num_tokens

  context = jnp.einsum("n h d, n h e -> h d e", key_, value)
  mixed = jnp.einsum("h d e, n h d -> h e n", context, query)
  grid = einops.rearrange(mixed, "h e (x y) -> x y (h e)", x=height)
  assert grid.shape == (height, width, self.dim_head*self.heads)

  projected_out = self.conv_out(grid)
  per_position_norm = eqx.filter_vmap(eqx.filter_vmap(self.norm))
  return per_position_norm(projected_out)

generax.nn.layers.AttentionBlock ¤

AttentionBlock(*args, **kwargs)

Source code in generax/nn/layers.py
class AttentionBlock(eqx.Module):
  """Pre-norm residual attention block computing `x + attn(norm(x))`.

  `norm` is a bias-free `LayerNorm` over the channel axis, applied separately
  at every spatial position. `attn` is `LinearAttention` (the default) or
  full `Attention`.
  """

  input_shape: Tuple[int, ...] = eqx.field(static=True)
  attn: Union[Attention, LinearAttention]
  norm: eqx.nn.LayerNorm

  def __init__(self,
               input_shape: Tuple[int],
               heads: int = 4,
               dim_head: int = 32,
               *,
               key: PRNGKeyArray,
               use_linear_attention: bool = True,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The unbatched input shape, unpacked as `(H, W, C)`.
    - `heads`: Number of attention heads.
    - `dim_head`: Feature size of each head.
    - `key`: A `jax.random.PRNGKey` used to initialize the attention layer.
    - `use_linear_attention`: Use `LinearAttention` if True, else `Attention`.
    """
    super().__init__(**kwargs)

    if use_linear_attention:
      self.attn = LinearAttention(input_shape=input_shape,
                                  heads=heads,
                                  dim_head=dim_head,
                                  key=key)
    else:
      self.attn = Attention(input_shape=input_shape,
                            heads=heads,
                            dim_head=dim_head,
                            key=key)
    self.input_shape = self.attn.input_shape
    _, _, C = input_shape
    self.norm = eqx.nn.LayerNorm(shape=(C,), use_bias=False)

  def data_dependent_init(self, *args, **kwargs) -> eqx.Module:
    """No data-dependent initialization; returns this module unchanged."""
    return self

  def __call__(self, x: Array, y: Array = None) -> Array:
    """Apply the residual attention block to a single (unbatched) input.

    **Arguments**:

    - `x`: An array with shape `self.input_shape`.
    - `y`: Ignored; accepted for interface compatibility with other layers.

    **Returns**:
    `x + attn(norm(x))`.
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    # Normalize the channel vector at every (row, column) position.
    normed_x = eqx.filter_vmap(eqx.filter_vmap(self.norm))(x)
    out = self.attn(normed_x)
    return out + x
__init__(self, input_shape: Tuple[int], heads: int = 4, dim_head: int = 32, *, key: PRNGKeyArray, use_linear_attention: bool = True, **kwargs) ¤
Source code in generax/nn/layers.py
def __init__(self,
             input_shape: Tuple[int],
             heads: int = 4,
             dim_head: int = 32,
             *,
             key: PRNGKeyArray,
             use_linear_attention: bool = True,
             **kwargs):
  """Build the attention layer and the pre-attention channel norm.

  **Arguments**:

  - `input_shape`: The unbatched input shape, unpacked as `(H, W, C)`.
  - `heads`: Number of attention heads.
  - `dim_head`: Feature size of each head.
  - `key`: A `jax.random.PRNGKey` used to initialize the attention layer.
  - `use_linear_attention`: Select `LinearAttention` (True) or `Attention`.
  """
  super().__init__(**kwargs)

  attention_cls = LinearAttention if use_linear_attention else Attention
  self.attn = attention_cls(input_shape=input_shape,
                            heads=heads,
                            dim_head=dim_head,
                            key=key)
  self.input_shape = self.attn.input_shape

  _, _, channels = input_shape
  self.norm = eqx.nn.LayerNorm(shape=(channels,), use_bias=False)
data_dependent_init(self, *args, **kwargs) -> Module ¤
Source code in generax/nn/layers.py
def data_dependent_init(self, *args, **kwargs) -> eqx.Module:
  """Skip data-dependent initialization and hand back this block unchanged.

  Positional and keyword arguments are ignored.
  """
  unchanged_block = self
  return unchanged_block
__call__(self, x: Array, y: Array = None) -> Array ¤

Call self as a function.

Source code in generax/nn/layers.py
def __call__(self, x: Array, y: Array = None) -> Array:
  """Residual attention: `x + attn(norm(x))`.

  The `LayerNorm` is vmapped over the two leading axes, so it is applied to
  the channel vector at each spatial position independently.

  **Arguments**:

  - `x`: A single (unbatched) input with shape `self.input_shape`.
  - `y`: Ignored; accepted for interface compatibility with other layers.
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'
  per_position_norm = eqx.filter_vmap(eqx.filter_vmap(self.norm))
  residual = self.attn(per_position_norm(x))
  return residual + x

generax.nn.grad_wrapper.GradWrapper ¤

A wrapper that turns a scalar-valued network into a function returning the gradient of that scalar with respect to the input x.

Source code in generax/nn/grad_wrapper.py
class GradWrapper(eqx.Module):
  """Wrap a scalar-valued network so that calling the wrapper returns the
  gradient of the network's output with respect to its input `x`.

  For a single (unbatched) input, the wrapped network must return an array
  of shape `(1,)`. The network itself is exposed through `energy`.
  """

  net: eqx.Module
  input_shape: Tuple[int, ...]

  def __init__(self,
               net: eqx.Module):
    """**Arguments**:

    - `net`: The scalar-valued network to differentiate. It must expose an
      `input_shape` attribute.
    """
    self.net = net
    self.input_shape = net.input_shape

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> eqx.Module:
    """Initialize the parameters of the wrapped network based on the data.

    **Arguments**:

    - `x`: A batch of data with shape `(batch_size,) + input_shape`.
    - `y`: The conditioning information
    - `key`: A `jax.random.PRNGKey` for initialization

    **Returns**:
    A new layer with the parameters initialized.
    """
    assert x.shape[1:] == self.input_shape, 'Only works on batched data'
    # NOTE(review): assumes the wrapped net implements
    # `data_dependent_init(x, y=..., key=...)` like the other layers here.
    net = self.net.data_dependent_init(x, y=y, key=key)
    # Equinox modules are immutable: build a copy with only `net` replaced.
    return eqx.tree_at(lambda tree: tree.net, self, net)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               **kwargs) -> Array:
    """**Arguments:**

    - `x`: A JAX array with shape `(input_shape,)`.
    - `y`: A JAX array with shape `(cond_shape,)`.
    - Extra keyword arguments are forwarded to the wrapped network.

    **Returns:**

    A JAX array with shape `(input_shape,)`: the gradient of the network's
    scalar output with respect to `x`.
    """
    assert x.shape == self.input_shape

    def net(x):
      net_out = self.net(x, y=y, **kwargs)
      if net_out.shape != (1,):
        raise ValueError(f'Expected net to return a scalar, but got {net_out.shape}')
      # `filter_grad` needs a true scalar (shape `()`); `ravel()` would keep
      # shape `(1,)` and make the gradient computation fail.
      return net_out[0]

    return eqx.filter_grad(net)(x)

  @property
  def energy(self):
    """The wrapped scalar-valued network."""
    return self.net
energy property readonly ¤
__init__(self, net: Module) ¤
Source code in generax/nn/grad_wrapper.py
def __init__(self,
             net: eqx.Module):
  """Store `net` and mirror its `input_shape` on the wrapper.

  **Arguments**:

  - `net`: The network whose scalar output is differentiated. It must expose
    an `input_shape` attribute.
  """
  wrapped = net
  self.net = wrapped
  self.input_shape = wrapped.input_shape
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> Module ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/nn/grad_wrapper.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> eqx.Module:
  """Initialize the parameters of the wrapped network based on the data.

  **Arguments**:

  - `x`: A batch of data with shape `(batch_size,) + input_shape`.
  - `y`: The conditioning information
  - `key`: A `jax.random.PRNGKey` for initialization

  **Returns**:
  A new layer with the parameters initialized.
  """
  assert x.shape[1:] == self.input_shape, 'Only works on batched data'
  # NOTE(review): assumes the wrapped net implements
  # `data_dependent_init(x, y=..., key=...)` like the other layers here.
  net = self.net.data_dependent_init(x, y=y, key=key)
  # Equinox modules are immutable: build a copy with only `net` replaced.
  return eqx.tree_at(lambda tree: tree.net, self, net)
__call__(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Arguments:

  • x: A JAX array with shape (input_shape,).
  • y: A JAX array with shape (cond_shape,).

Returns:

A JAX array with shape (input_shape,).

Source code in generax/nn/grad_wrapper.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             **kwargs) -> Array:
  """**Arguments:**

  - `x`: A JAX array with shape `(input_shape,)`.
  - `y`: A JAX array with shape `(cond_shape,)`.
  - Extra keyword arguments are forwarded to the wrapped network.

  **Returns:**

  A JAX array with shape `(input_shape,)`: the gradient of the network's
  scalar output with respect to `x`.
  """
  assert x.shape == self.input_shape

  def net(x):
    net_out = self.net(x, y=y, **kwargs)
    if net_out.shape != (1,):
      raise ValueError(f'Expected net to return a scalar, but got {net_out.shape}')
    # `filter_grad` needs a true scalar (shape `()`); `ravel()` would keep
    # shape `(1,)` and make the gradient computation fail.
    return net_out[0]

  return eqx.filter_grad(net)(x)

generax.nn.grad_wrapper.TimeDependentGradWrapper (GradWrapper) ¤

A wrapper that turns a time-dependent scalar-valued network net(t, x, y) into a function returning the gradient of that scalar with respect to the input x.

Source code in generax/nn/grad_wrapper.py
class TimeDependentGradWrapper(GradWrapper):
  """Wrap a time-dependent scalar-valued network `net(t, x, y=...)` so that
  calling the wrapper returns the gradient of its output with respect to `x`.
  """

  def data_dependent_init(self,
                          t: Array,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> eqx.Module:
    """Initialize the parameters of the wrapped network based on the data.

    **Arguments**:

    - `t`: The time to initialize the parameters with.
    - `x`: A batch of data with shape `(batch_size,) + input_shape`.
    - `y`: The conditioning information
    - `key`: A `jax.random.PRNGKey` for initialization

    **Returns**:
    A new layer with the parameters initialized.
    """
    assert x.shape[1:] == self.input_shape, 'Only works on batched data'
    # NOTE(review): assumes the wrapped net implements
    # `data_dependent_init(t, x, y=..., key=...)`.
    net = self.net.data_dependent_init(t, x, y=y, key=key)
    # Equinox modules are immutable: build a copy with only `net` replaced.
    return eqx.tree_at(lambda tree: tree.net, self, net)

  def __call__(self,
               t: Array,
               x: Array,
               y: Optional[Array] = None,
               **kwargs) -> Array:
    """**Arguments:**

    - `t`: A JAX array with shape `()`.
    - `x`: A JAX array with shape `(input_shape,)`.
    - `y`: A JAX array with shape `(cond_shape,)`.

    **Returns:**

    A JAX array with shape `(input_shape,)`: the gradient of the network's
    scalar output with respect to `x` (not `t`).
    """
    assert x.shape == self.input_shape

    def net(x):
      net_out = self.net(t, x, y=y, **kwargs)
      if net_out.shape != (1,):
        raise ValueError(f'Expected net to return a scalar, but got {net_out.shape}')
      # Index out the single element so `filter_grad` sees a scalar output.
      return net_out[0]
    return eqx.filter_grad(net)(x)

  @property
  def energy(self):
    """The wrapped scalar-valued network."""
    return self.net
energy property readonly ¤
__init__(self, net: Module) ¤
Source code in generax/nn/grad_wrapper.py
def __init__(self,
             net: eqx.Module):
  """Keep a reference to `net` and copy its `input_shape` onto the wrapper.

  **Arguments**:

  - `net`: The (time-dependent) network to differentiate. It must expose an
    `input_shape` attribute.
  """
  wrapped = net
  self.net = wrapped
  self.input_shape = wrapped.input_shape
data_dependent_init(self, t: Array, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> Module ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • t: The time to initialize the parameters with.
  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/nn/grad_wrapper.py
def data_dependent_init(self,
                        t: Array,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> eqx.Module:
  """Initialize the parameters of the wrapped network based on the data.

  **Arguments**:

  - `t`: The time to initialize the parameters with.
  - `x`: A batch of data with shape `(batch_size,) + input_shape`.
  - `y`: The conditioning information
  - `key`: A `jax.random.PRNGKey` for initialization

  **Returns**:
  A new layer with the parameters initialized.
  """
  assert x.shape[1:] == self.input_shape, 'Only works on batched data'
  # NOTE(review): assumes the wrapped net implements
  # `data_dependent_init(t, x, y=..., key=...)`.
  net = self.net.data_dependent_init(t, x, y=y, key=key)
  # Equinox modules are immutable: build a copy with only `net` replaced.
  return eqx.tree_at(lambda tree: tree.net, self, net)
__call__(self, t: Array, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Arguments:

  • t: A JAX array with shape ().
  • x: A JAX array with shape (input_shape,).
  • y: A JAX array with shape (cond_shape,).

Returns:

A JAX array with shape (input_shape,).

Source code in generax/nn/grad_wrapper.py
def __call__(self,
             t: Array,
             x: Array,
             y: Optional[Array] = None,
             **kwargs) -> Array:
  """Gradient of the time-dependent scalar network with respect to `x`.

  **Arguments:**

  - `t`: A JAX array with shape `()`.
  - `x`: A JAX array with shape `(input_shape,)`.
  - `y`: A JAX array with shape `(cond_shape,)`.

  **Returns:**

  A JAX array with shape `(input_shape,)`.
  """
  assert x.shape == self.input_shape

  def scalar_energy(x_in):
    net_out = self.net(t, x_in, y=y, **kwargs)
    if net_out.shape == (1,):
      return net_out[0]
    raise ValueError(f'Expected net to return a scalar, but got {net_out.shape}')

  grad_fn = eqx.filter_grad(scalar_energy)
  return grad_fn(x)