Skip to content

Resnet¤

generax.nn.resnet.ResNet ¤

ResNet for 1d data

Source code in generax/nn/resnet.py
class ResNet(eqx.Module):
  """ResNet for 1d data"""

  n_blocks: int = eqx.field(static=True)
  blocks: tuple[GatedResBlock, ...]
  in_projection: eqx.nn.Linear
  out_projection: eqx.nn.Linear

  input_shape: int = eqx.field(static=True)
  cond_shape: int = eqx.field(static=True)
  working_size: int = eqx.field(static=True)
  hidden_size: int = eqx.field(static=True)
  out_size: int = eqx.field(static=True)
  filter_shape: Union[Tuple[int],None] = eqx.field(static=True)

  def __init__(self,
               input_shape: Tuple[int],
               working_size: int,
               hidden_size: int,
               out_size: int,
               n_blocks: int,
               filter_shape: Optional[Tuple[int]] = (3, 3),
               cond_shape: Optional[Tuple[int]] = None,
               activation: Callable = jax.nn.swish,
               *,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input size.  Output size is the same as input_shape.
    - `working_size`: The size (channels for images) of each hidden layer.
    - `hidden_size`: The size (channels for images) of each hidden layer.
    - `out_size`: The output size.  For images, this is the number of output
                  channels.
    - `n_blocks`: The number of residual blocks.
    - `cond_shape`: The size of the conditioning information.
    - `activation`: The activation function in each residual block.
    - `key`: A `jax.random.PRNGKey` for initialization.
    """
    super().__init__(**kwargs)

    if len(input_shape) not in [1, 3]:
      raise ValueError(f'Expected 1d or 3d input shape')

    image = False
    if len(input_shape) == 3:
      H, W, C = input_shape
      image = True
      assert filter_shape is not None, 'Must pass in filter shape when processing images'

    self.n_blocks = n_blocks
    self.working_size = working_size
    self.hidden_size = hidden_size
    self.filter_shape = filter_shape
    self.out_size = out_size

    k1, k2, k3 = random.split(key, 3)

    if isinstance(input_shape, int):
      input_shape = (input_shape,)
    self.input_shape = input_shape
    self.cond_shape = cond_shape

    if image == False:
      self.in_projection = WeightNormDense(in_size=input_shape[0],
                                          out_size=working_size,
                                          key=k1)
      working_shape = (working_size,)
    else:
      self.in_projection = ConvAndGroupNorm(input_shape=input_shape,
                                        out_size=working_size,
                                        filter_shape=filter_shape,
                                        groups=1,
                                        key=k1)
      working_shape = (H, W, working_size)

    def make_resblock(k):
      return GatedResBlock(input_shape=working_shape,
                          hidden_size=hidden_size,
                          cond_shape=cond_shape,
                          activation=activation,
                          filter_shape=filter_shape,
                          key=k)

    keys = random.split(k2, n_blocks)
    self.blocks = eqx.filter_vmap(make_resblock)(keys)

    if image == False:
      self.out_projection = WeightNormDense(in_size=working_size,
                                            out_size=out_size,
                                            key=k3)
    else:
      self.out_projection = ConvAndGroupNorm(input_shape=working_shape,
                                           out_size=out_size,
                                           filter_shape=filter_shape,
                                           groups=1,
                                           key=k3)

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> eqx.Module:
    """Initialize the parameters of the layer based on the data.

    **Arguments**:

    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information
    - `key`: A `jax.random.PRNGKey` for initialization

    **Returns**:
    A new layer with the parameters initialized.
    """
    assert x.shape[1:] == self.input_shape, 'Only works on batched data'

    k1, k2, k3 = random.split(key, 3)

    # Input projection
    in_proj = self.in_projection.data_dependent_init(x, key=k1)
    x = eqx.filter_vmap(in_proj)(x)

    # Scan over the vmapped blocks
    params, state = eqx.partition(self.blocks, eqx.is_array)
    def scan_body(x, inputs):
      key, block_params = inputs
      block = eqx.combine(block_params, state)
      new_block = block.data_dependent_init(x, y, key=key)
      new_x = eqx.filter_vmap(new_block)(x, y)
      new_params, _ = eqx.partition(block, eqx.is_array)
      return new_x, new_params

    keys = random.split(k2, self.n_blocks)
    x, params = jax.lax.scan(scan_body, x, (keys, params))
    blocks = eqx.combine(params, state)

    out_proj = self.out_projection.data_dependent_init(x, key=k3)

    # Turn the new parameters into a new module
    get_in_proj = lambda tree: tree.in_projection
    get_blocks = lambda tree: tree.blocks
    get_out_proj = lambda tree: tree.out_projection

    updated_layer = eqx.tree_at(get_in_proj, self, in_proj)
    updated_layer = eqx.tree_at(get_blocks, updated_layer, blocks)
    updated_layer = eqx.tree_at(get_out_proj, updated_layer, out_proj)

    return updated_layer

  def __call__(self,
                x: Array,
                y: Optional[Array] = None,
                **kwargs) -> Array:
    """**Arguments:**

    - `t`: A JAX array with shape `()`.
    - `x`: A JAX array with shape `(input_shape,)`.
    - `y`: A JAX array with shape `(cond_shape,)`.

    **Returns:**

    A JAX array with shape `(input_shape,)`.
    """
    assert x.shape == self.input_shape

    # Input projection
    x = self.in_projection(x)

    # Resnet blocks
    dynamic, static = eqx.partition(self.blocks, eqx.is_array)

    def f(x, params):
        block = eqx.combine(params, static)
        return block(x, y), None

    out, _ = jax.lax.scan(f, x, dynamic)

    # Output projection
    out = self.out_projection(out)
    return out
__init__(self, input_shape: Tuple[int], working_size: int, hidden_size: int, out_size: int, n_blocks: int, filter_shape: Optional[Tuple[int]] = (3, 3), cond_shape: Optional[Tuple[int]] = None, activation: Callable = <PjitFunction of <function silu at 0x7fd99f02dfc0>>, *, key: PRNGKeyArray, **kwargs) ¤

Arguments:

  • input_shape: The input size. Output size is the same as input_shape.
  • working_size: The size (channels for images) of each hidden layer.
  • hidden_size: The size (channels for images) of each hidden layer.
  • out_size: The output size. For images, this is the number of output channels.
  • n_blocks: The number of residual blocks.
  • cond_shape: The size of the conditioning information.
  • activation: The activation function in each residual block.
  • key: A jax.random.PRNGKey for initialization.
Source code in generax/nn/resnet.py
def __init__(self,
             input_shape: Tuple[int],
             working_size: int,
             hidden_size: int,
             out_size: int,
             n_blocks: int,
             filter_shape: Optional[Tuple[int]] = (3, 3),
             cond_shape: Optional[Tuple[int]] = None,
             activation: Callable = jax.nn.swish,
             *,
             key: PRNGKeyArray,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input size.  Output size is the same as input_shape.
  - `working_size`: The size (channels for images) of each hidden layer.
  - `hidden_size`: The size (channels for images) of each hidden layer.
  - `out_size`: The output size.  For images, this is the number of output
                channels.
  - `n_blocks`: The number of residual blocks.
  - `cond_shape`: The size of the conditioning information.
  - `activation`: The activation function in each residual block.
  - `key`: A `jax.random.PRNGKey` for initialization.
  """
  super().__init__(**kwargs)

  if len(input_shape) not in [1, 3]:
    raise ValueError(f'Expected 1d or 3d input shape')

  image = False
  if len(input_shape) == 3:
    H, W, C = input_shape
    image = True
    assert filter_shape is not None, 'Must pass in filter shape when processing images'

  self.n_blocks = n_blocks
  self.working_size = working_size
  self.hidden_size = hidden_size
  self.filter_shape = filter_shape
  self.out_size = out_size

  k1, k2, k3 = random.split(key, 3)

  if isinstance(input_shape, int):
    input_shape = (input_shape,)
  self.input_shape = input_shape
  self.cond_shape = cond_shape

  if image == False:
    self.in_projection = WeightNormDense(in_size=input_shape[0],
                                        out_size=working_size,
                                        key=k1)
    working_shape = (working_size,)
  else:
    self.in_projection = ConvAndGroupNorm(input_shape=input_shape,
                                      out_size=working_size,
                                      filter_shape=filter_shape,
                                      groups=1,
                                      key=k1)
    working_shape = (H, W, working_size)

  def make_resblock(k):
    return GatedResBlock(input_shape=working_shape,
                        hidden_size=hidden_size,
                        cond_shape=cond_shape,
                        activation=activation,
                        filter_shape=filter_shape,
                        key=k)

  keys = random.split(k2, n_blocks)
  self.blocks = eqx.filter_vmap(make_resblock)(keys)

  if image == False:
    self.out_projection = WeightNormDense(in_size=working_size,
                                          out_size=out_size,
                                          key=k3)
  else:
    self.out_projection = ConvAndGroupNorm(input_shape=working_shape,
                                         out_size=out_size,
                                         filter_shape=filter_shape,
                                         groups=1,
                                         key=k3)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> Module ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/nn/resnet.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> eqx.Module:
  """Initialize the parameters of the layer based on the data.

  **Arguments**:

  - `x`: The data to initialize the parameters with.
  - `y`: The conditioning information
  - `key`: A `jax.random.PRNGKey` for initialization

  **Returns**:
  A new layer with the parameters initialized.
  """
  assert x.shape[1:] == self.input_shape, 'Only works on batched data'

  k1, k2, k3 = random.split(key, 3)

  # Input projection
  in_proj = self.in_projection.data_dependent_init(x, key=k1)
  x = eqx.filter_vmap(in_proj)(x)

  # Scan over the vmapped blocks
  params, state = eqx.partition(self.blocks, eqx.is_array)
  def scan_body(x, inputs):
    key, block_params = inputs
    block = eqx.combine(block_params, state)
    new_block = block.data_dependent_init(x, y, key=key)
    new_x = eqx.filter_vmap(new_block)(x, y)
    new_params, _ = eqx.partition(block, eqx.is_array)
    return new_x, new_params

  keys = random.split(k2, self.n_blocks)
  x, params = jax.lax.scan(scan_body, x, (keys, params))
  blocks = eqx.combine(params, state)

  out_proj = self.out_projection.data_dependent_init(x, key=k3)

  # Turn the new parameters into a new module
  get_in_proj = lambda tree: tree.in_projection
  get_blocks = lambda tree: tree.blocks
  get_out_proj = lambda tree: tree.out_projection

  updated_layer = eqx.tree_at(get_in_proj, self, in_proj)
  updated_layer = eqx.tree_at(get_blocks, updated_layer, blocks)
  updated_layer = eqx.tree_at(get_out_proj, updated_layer, out_proj)

  return updated_layer
__call__(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Arguments:

  • t: A JAX array with shape ().
  • x: A JAX array with shape (input_shape,).
  • y: A JAX array with shape (cond_shape,).

Returns:

A JAX array with shape (input_shape,).

Source code in generax/nn/resnet.py
def __call__(self,
              x: Array,
              y: Optional[Array] = None,
              **kwargs) -> Array:
  """**Arguments:**

  - `t`: A JAX array with shape `()`.
  - `x`: A JAX array with shape `(input_shape,)`.
  - `y`: A JAX array with shape `(cond_shape,)`.

  **Returns:**

  A JAX array with shape `(input_shape,)`.
  """
  assert x.shape == self.input_shape

  # Input projection
  x = self.in_projection(x)

  # Resnet blocks
  dynamic, static = eqx.partition(self.blocks, eqx.is_array)

  def f(x, params):
      block = eqx.combine(params, static)
      return block(x, y), None

  out, _ = jax.lax.scan(f, x, dynamic)

  # Output projection
  out = self.out_projection(out)
  return out

generax.nn.resnet.TimeDependentResNet (ResNet) ¤

A time dependent version of a 1d resnet

Source code in generax/nn/resnet.py
class TimeDependentResNet(ResNet):
  """A time dependent version of a 1d resnet
  """

  time_features: TimeFeatures

  def __init__(self,
               input_shape: Tuple[int],
               working_size: int,
               hidden_size: int,
               out_size: int,
               n_blocks: int,
               filter_shape: Optional[Tuple[int]] = (3, 3),
               cond_shape: Optional[Tuple[int]] = None,
               activation: Callable = jax.nn.swish,
               embedding_size: Optional[int] = 16,
               out_features: int=8,
               *,
               key: PRNGKeyArray,
               **kwargs):
    k1, k2 = random.split(key, 2)
    self.time_features = TimeFeatures(embedding_size=embedding_size,
                                      out_features=out_features,
                                      key=k1,
                                      **kwargs)

    total_cond_size = out_features
    if cond_shape is not None:
      if len(cond_shape) != 1:
        raise ValueError(f'Expected 1d conditional input.')
      total_cond_size += cond_shape[0]

    super().__init__(input_shape=input_shape,
                     working_size=working_size,
                     hidden_size=hidden_size,
                     out_size=out_size,
                     n_blocks=n_blocks,
                     filter_shape=filter_shape,
                     cond_shape=(total_cond_size,),
                     activation=activation,
                     key=k2,
                     **kwargs)

  def data_dependent_init(self,
                          t: Array,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> eqx.Module:
    """Initialize the parameters of the layer based on the data.

    **Arguments**:

    - `t`: The time to initialize the parameters with.
    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information
    - `key`: A `jax.random.PRNGKey` for initialization

    **Returns**:
    A new layer with the parameters initialized.
    """
    assert t.ndim == 1
    h = eqx.filter_vmap(self.time_features)(t)
    if y is not None:
      h = jnp.concatenate([h, y], axis=-1)
    return super().data_dependent_init(x, y=h, key=key)

  def __call__(self,
               t: Array,
               x: Array,
               y: Optional[Array] = None,
               **kwargs) -> Array:
    assert t.shape == ()

    h = self.time_features(t)
    if y is not None:
      h = jnp.concatenate([h, y], axis=-1)

    return super().__call__(x, y=h)
__init__(self, input_shape: Tuple[int], working_size: int, hidden_size: int, out_size: int, n_blocks: int, filter_shape: Optional[Tuple[int]] = (3, 3), cond_shape: Optional[Tuple[int]] = None, activation: Callable = <PjitFunction of <function silu at 0x7fd99f02dfc0>>, embedding_size: Optional[int] = 16, out_features: int = 8, *, key: PRNGKeyArray, **kwargs) ¤
Source code in generax/nn/resnet.py
def __init__(self,
             input_shape: Tuple[int],
             working_size: int,
             hidden_size: int,
             out_size: int,
             n_blocks: int,
             filter_shape: Optional[Tuple[int]] = (3, 3),
             cond_shape: Optional[Tuple[int]] = None,
             activation: Callable = jax.nn.swish,
             embedding_size: Optional[int] = 16,
             out_features: int=8,
             *,
             key: PRNGKeyArray,
             **kwargs):
  k1, k2 = random.split(key, 2)
  self.time_features = TimeFeatures(embedding_size=embedding_size,
                                    out_features=out_features,
                                    key=k1,
                                    **kwargs)

  total_cond_size = out_features
  if cond_shape is not None:
    if len(cond_shape) != 1:
      raise ValueError(f'Expected 1d conditional input.')
    total_cond_size += cond_shape[0]

  super().__init__(input_shape=input_shape,
                   working_size=working_size,
                   hidden_size=hidden_size,
                   out_size=out_size,
                   n_blocks=n_blocks,
                   filter_shape=filter_shape,
                   cond_shape=(total_cond_size,),
                   activation=activation,
                   key=k2,
                   **kwargs)
data_dependent_init(self, t: Array, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> Module ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • t: The time to initialize the parameters with.
  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/nn/resnet.py
def data_dependent_init(self,
                        t: Array,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> eqx.Module:
  """Initialize the parameters of the layer based on the data.

  **Arguments**:

  - `t`: The time to initialize the parameters with.
  - `x`: The data to initialize the parameters with.
  - `y`: The conditioning information
  - `key`: A `jax.random.PRNGKey` for initialization

  **Returns**:
  A new layer with the parameters initialized.
  """
  assert t.ndim == 1
  h = eqx.filter_vmap(self.time_features)(t)
  if y is not None:
    h = jnp.concatenate([h, y], axis=-1)
  return super().data_dependent_init(x, y=h, key=key)
__call__(self, t: Array, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Arguments:

  • t: A JAX array with shape ().
  • x: A JAX array with shape (input_shape,).
  • y: A JAX array with shape (cond_shape,).

Returns:

A JAX array with shape (input_shape,).

Source code in generax/nn/resnet.py
def __call__(self,
             t: Array,
             x: Array,
             y: Optional[Array] = None,
             **kwargs) -> Array:
  assert t.shape == ()

  h = self.time_features(t)
  if y is not None:
    h = jnp.concatenate([h, y], axis=-1)

  return super().__call__(x, y=h)