Skip to content

Time conditioner

generax.nn.time_condition.GaussianFourierProjection

GaussianFourierProjection(*args, **kwargs)

Source code in generax/nn/time_condition.py
class GaussianFourierProjection(eqx.Module):
  """Sine/cosine features of a scalar time value.

  The scalar `t` is scaled by `2*pi`, then passed through a bias-free
  `eqx.nn.Linear(1, embedding_size)`. The sine and cosine of the result are
  concatenated into a vector of length `2*embedding_size`.

  NOTE(review): `W` uses `eqx.nn.Linear`'s default initializer (there is no
  explicit Gaussian draw). It is an ordinary (non-static, not
  stop-gradiented) field of the module. Confirm this is intended given the
  class name.
  """

  embedding_size: int = eqx.field(static=True)
  W: eqx.nn.Linear

  def __init__(self,
               embedding_size: int = 16,
               *,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `embedding_size`: The size of the embedding (must be an integer). The
      output of `__call__` has `2*embedding_size` entries.
    - `key`: A JAX PRNG key used to initialize the projection weights `W`.
    """
    super().__init__(**kwargs)

    self.embedding_size = embedding_size
    self.W = eqx.nn.Linear(in_features=1,
                           out_features=embedding_size,
                           use_bias=False,
                           key=key)

  def __call__(self, t: Array) -> Array:
    """**Arguments:**

    - `t`: A scalar time value: a JAX array with shape `()`, or a Python /
      NumPy scalar.

    **Returns:**

    A JAX array with shape `(2*embedding_size,)`: the sine features followed
    by the cosine features.
    """
    # Normalize first so plain Python/NumPy scalars (which lack `.shape`)
    # are accepted as well as JAX arrays.
    t = jnp.asarray(t)
    assert t.shape == (), f"Expected a scalar time value, got shape {t.shape}"
    # eqx.nn.Linear expects a 1-D input, so lift the scalar to shape (1,).
    t = jnp.expand_dims(t, axis=-1)
    t_proj = self.W(t*2*jnp.pi)
    return jnp.concatenate([jnp.sin(t_proj), jnp.cos(t_proj)], axis=-1)
__init__(self, embedding_size: Optional[int] = 16, *, key: PRNGKeyArray, **kwargs)

Arguments:

  • embedding_size: The size of the embedding.
Source code in generax/nn/time_condition.py
def __init__(self,
             embedding_size: int = 16,
             *,
             key: PRNGKeyArray,
             **kwargs):
  """**Arguments**:

  - `embedding_size`: The size of the embedding (must be an integer; it is
    passed directly as `out_features` of the projection layer). The output
    of `__call__` has `2*embedding_size` entries.
  - `key`: A JAX PRNG key used to initialize the projection weights `W`.
  """
  super().__init__(**kwargs)

  self.embedding_size = embedding_size
  # Bias-free 1 -> embedding_size projection of the (scaled) time value.
  self.W = eqx.nn.Linear(in_features=1,
                         out_features=embedding_size,
                         use_bias=False,
                         key=key)
__call__(self, t: Array) -> Array

Arguments:

  • t: A JAX array with shape ().

Returns:

A JAX array with shape (2*embedding_size,).

Source code in generax/nn/time_condition.py
def __call__(self, t: Array) -> Array:
  """**Arguments:**

  - `t`: A scalar time value: a JAX array with shape `()`, or a Python /
    NumPy scalar.

  **Returns:**

  A JAX array with shape `(2*embedding_size,)`: the sine features followed
  by the cosine features.
  """
  # Normalize first so plain Python/NumPy scalars (which lack `.shape`)
  # are accepted as well as JAX arrays.
  t = jnp.asarray(t)
  assert t.shape == (), f"Expected a scalar time value, got shape {t.shape}"
  # eqx.nn.Linear expects a 1-D input, so lift the scalar to shape (1,).
  t = jnp.expand_dims(t, axis=-1)
  t_proj = self.W(t*2*jnp.pi)
  return jnp.concatenate([jnp.sin(t_proj), jnp.cos(t_proj)], axis=-1)

generax.nn.time_condition.TimeFeatures

TimeFeatures(*args, **kwargs)

Source code in generax/nn/time_condition.py
class TimeFeatures(eqx.Module):
  """Learned features of a scalar time value.

  Embeds `t` with a `GaussianFourierProjection`, producing `2*embedding_size`
  features. These pass through a two-layer MLP
  (`Linear -> activation -> Linear`) to give `out_features` outputs.
  """

  out_features: int = eqx.field(static=True)
  projection: GaussianFourierProjection
  # Both layers are eqx.nn.Linear modules (not raw weight arrays).
  W1: eqx.nn.Linear
  W2: eqx.nn.Linear
  activation: Callable

  def __init__(self,
               embedding_size: int = 16,
               out_features: int = 8,
               activation: Callable = jax.nn.gelu,
               *,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `embedding_size`: The size of the Fourier embedding (must be an
      integer; the hidden layer has `4*embedding_size` units).
    - `out_features`: The number of output features.
    - `activation`: The activation function applied between the two layers.
    - `key`: A JAX PRNG key, split to initialize the projection and both
      layers.
    """
    super().__init__(**kwargs)
    self.out_features = out_features

    # Independent keys for the projection and the two MLP layers.
    k1, k2, k3 = random.split(key, 3)
    self.projection = GaussianFourierProjection(embedding_size=embedding_size,
                                                key=k1)
    # The projection emits sin and cos features, hence 2*embedding_size inputs.
    self.W1 = eqx.nn.Linear(in_features=2*embedding_size,
                            out_features=4*embedding_size,
                            key=k2)
    self.activation = activation
    self.W2 = eqx.nn.Linear(in_features=4*embedding_size,
                            out_features=self.out_features,
                            key=k3)

  def __call__(self, t: Array) -> Array:
    """**Arguments:**

    - `t`: A scalar time value: a JAX array with shape `()`, or a Python /
      NumPy scalar.

    **Returns:**

    A JAX array with shape `(out_features,)`.
    """
    # Normalize first so plain Python/NumPy scalars (which lack `.shape`)
    # are accepted as well as JAX arrays.
    t = jnp.asarray(t)
    assert t.shape == (), f"Expected a scalar time value, got shape {t.shape}"
    x = self.projection(t)
    x = self.W1(x)
    x = self.activation(x)
    return self.W2(x)
__init__(self, embedding_size: Optional[int] = 16, out_features: int = 8, activation: Callable = jax.nn.gelu, *, key: PRNGKeyArray, **kwargs)

Arguments:

  • embedding_size: The size of the embedding.
  • out_features: The number of output features.
  • activation: The activation function.
Source code in generax/nn/time_condition.py
def __init__(self,
             embedding_size: int = 16,
             out_features: int = 8,
             activation: Callable = jax.nn.gelu,
             *,
             key: PRNGKeyArray,
             **kwargs):
  """**Arguments**:

  - `embedding_size`: The size of the Fourier embedding (must be an integer;
    it is multiplied to size the layers, and the hidden layer has
    `4*embedding_size` units).
  - `out_features`: The number of output features.
  - `activation`: The activation function applied between the two layers.
  - `key`: A JAX PRNG key, split to initialize the projection and both
    layers.
  """
  super().__init__(**kwargs)
  self.out_features = out_features

  # Independent keys for the projection and the two MLP layers.
  k1, k2, k3 = random.split(key, 3)
  self.projection = GaussianFourierProjection(embedding_size=embedding_size,
                                              key=k1)
  # The projection emits sin and cos features, hence 2*embedding_size inputs.
  self.W1 = eqx.nn.Linear(in_features=2*embedding_size,
                          out_features=4*embedding_size,
                          key=k2)
  self.activation = activation
  self.W2 = eqx.nn.Linear(in_features=4*embedding_size,
                          out_features=self.out_features,
                          key=k3)
__call__(self, t: Array) -> Array

Arguments:

  • t: A JAX array with shape ().

Returns:

A JAX array with shape (out_features,).

Source code in generax/nn/time_condition.py
def __call__(self, t: Array) -> Array:
  """**Arguments:**

  - `t`: A scalar time value: a JAX array with shape `()`, or a Python /
    NumPy scalar.

  **Returns:**

  A JAX array with shape `(out_features,)`.
  """
  # Normalize first so plain Python/NumPy scalars (which lack `.shape`)
  # are accepted as well as JAX arrays.
  t = jnp.asarray(t)
  assert t.shape == (), f"Expected a scalar time value, got shape {t.shape}"
  x = self.projection(t)
  x = self.W1(x)
  x = self.activation(x)
  return self.W2(x)