Flows

generax.distributions.flow_models.NormalizingFlow (ProbabilityDistribution)

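A normalizing flow represents a probability distribution by passing samples from a simple prior through a learned bijective transformation, so that densities can be evaluated exactly with the change-of-variables formula. See https://arxiv.org/pdf/1912.02762.pdf for an overview.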

Attributes:

  • transform: A BijectiveTransform object that maps a variable from the data space to the base space (and back with inverse=True) and also computes the change in log pdf.
      • transform(x) -> (z, log_det): Apply the transformation to the input.
  • prior: The prior probability distribution.

Methods:

  • to_base_space(x) -> z: Transform a point from the data space to the base space.
  • to_data_space(z) -> x: Transform a point from the base space to the data space.
  • sample_and_log_prob(key) -> (x, log_px): Sample from the distribution and compute the log probability.
  • sample(key) -> x: Pull a single sample from the model.
  • log_prob(x) -> log_px: Compute the log probability of a point under the model.
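To make the two directions and the sign convention concrete, here is a minimal sketch in plain JAX, assuming a toy elementwise affine bijection and a standard Gaussian prior. The names `transform`, `prior_log_prob`, `MU` and `LOG_SIGMA` are hypothetical and not generax API. It mirrors how `log_prob` adds the forward log-determinant while `sample_and_log_prob` subtracts the one returned by the inverse pass, and checks that both agree on a sampled point.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical toy bijection (not generax code): an elementwise affine map.
# As in NormalizingFlow, the forward direction goes data -> base and
# inverse=True goes base -> data.
MU = jnp.array([1.0, -2.0])
LOG_SIGMA = jnp.array([0.5, -0.3])


def transform(v, inverse=False):
    sigma = jnp.exp(LOG_SIGMA)
    if inverse:
        # base -> data: x = mu + sigma * z, log|det dx/dz| = sum(log sigma)
        return MU + sigma * v, jnp.sum(LOG_SIGMA)
    # data -> base: z = (x - mu) / sigma, log|det dz/dx| = -sum(log sigma)
    return (v - MU) / sigma, -jnp.sum(LOG_SIGMA)


def prior_log_prob(z):
    # Standard Gaussian prior over the base space
    return jnp.sum(norm.logpdf(z))


def log_prob(x):
    z, log_det = transform(x)
    return prior_log_prob(z) + log_det


def sample_and_log_prob(key):
    z = jax.random.normal(key, MU.shape)
    x, log_det = transform(z, inverse=True)
    # The inverse log-determinant enters with a negative sign
    return x, prior_log_prob(z) - log_det


x, log_px = sample_and_log_prob(jax.random.PRNGKey(0))
print(jnp.allclose(log_px, log_prob(x), atol=1e-5))  # True
```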
Source code in generax/distributions/flow_models.py
class NormalizingFlow(ProbabilityDistribution, ABC):
  """A normalizing flow is a model that we use to represent probability
  distributions.  See [this](https://arxiv.org/pdf/1912.02762.pdf) for an overview.

  **Attributes**:

  - `transform`: A `BijectiveTransform` object that maps a variable
                  from the data space to the base space (and back with
                  `inverse=True`) and also computes the change in log pdf.
    - `transform(x) -> (z,log_det)`: Apply the transformation to the input.
  - `prior`: The prior probability distribution.

  **Methods**:

  - `to_base_space(x) -> z`: Transform a point from the data space to the base space.
  - `to_data_space(z) -> x`: Transform a point from the base space to the data space.
  - `sample_and_log_prob(key) -> (x,log_px)`: Sample from the distribution and compute the log probability.
  - `sample(key) -> x`: Pull a single sample from the model.
  - `log_prob(x) -> log_px`: Compute the log probability of a point under the model.
  """
  transform: BijectiveTransform
  prior: ProbabilityDistribution

  def __init__(self,
               transform: BijectiveTransform,
               prior: ProbabilityDistribution,
               **kwargs):
    """**Arguments**:

    - `transform`: A bijective transformation
    - `prior`: The prior distribution
    """
    self.transform = transform
    self.prior = prior
    input_shape = self.transform.input_shape
    super().__init__(input_shape=input_shape, **kwargs)

  def to_base_space(self,
                    x: Array,
                    y: Optional[Array] = None,
                    **kwargs) -> Array:
    """**Arguments**:

    - `x`: A JAX array with shape `(dim,)`.
    - `y`: The conditioning information

    **Returns**:
    A JAX array with shape `(dim,)`.
    """
    return self.transform(x, y=y, **kwargs)[0]

  def to_data_space(self,
                    z: Array,
                    y: Optional[Array] = None,
                    **kwargs) -> Array:
    """**Arguments**:

    - `z`: A JAX array with shape `(dim,)`.
    - `y`: The conditioning information

    **Returns**:
    A JAX array with shape `(dim,)`.
    """
    return self.transform(z, y=y, inverse=True, **kwargs)[0]

  def sample_and_log_prob(self,
                          key: PRNGKeyArray,
                          y: Optional[Array] = None,
                          **kwargs) -> Array:
    """**Arguments**:

    - `key`: The random number generator key.
    - `y`: The conditioning information

    **Returns**:
    A tuple `(x, log_px)`: a single sample and its log probability.
    Use vmap to get more samples.
    """
    z, log_pz = self.prior.sample_and_log_prob(key)
    x, log_det = self.transform(z, y=y, inverse=True, **kwargs)
    # The log determinant of the inverse transform has a negative sign!
    return x, log_pz - log_det

  def log_prob(self,
               x: Array,
               y: Optional[Array] = None,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The point we want to compute logp(x) at.
    - `y`: The conditioning information

    **Returns**:
    The log likelihood of x under the model.
    """
    z, log_det = self.transform(x, y=y, **kwargs)
    log_pz = self.prior.log_prob(z)
    return log_pz + log_det

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the layer based on a batch of data.
    This is one of the few times that $x$ is expected to be batched.

    **Arguments**:

    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information
    - `key`: A `jax.random.PRNGKey` for initialization

    **Returns**:
    A new flow with the parameters initialized.
    """
    new_layer = self.transform.data_dependent_init(x, y=y, key=key)

    # Turn the new parameters into a new module
    get_transform = lambda tree: tree.transform
    return eqx.tree_at(get_transform, self, new_layer)
__init__(self, transform: BijectiveTransform, prior: ProbabilityDistribution, **kwargs)

Arguments:

  • transform: A bijective transformation
  • prior: The prior distribution
sample(self, key: PRNGKeyArray, y: Optional[Array] = None) -> Array

Inherited from generax.distributions.base.ProbabilityDistribution.sample.

Source code in generax/distributions/flow_models.py
def sample(self,
           key: PRNGKeyArray,
           y: Optional[Array] = None) -> Array:
  """
  **Arguments**:

  - `key`: The random number generator key.

  **Returns**:
  A single sample from the model

  Use eqx.filter_vmap to get more samples!  For example,
  ```python
  keys = random.split(key, n_samples)
  samples = eqx.filter_vmap(self.sample)(keys)
  ```
  """
  return self.sample_and_log_prob(key, y)[0]
sample_and_log_prob(self, key: PRNGKeyArray, y: Optional[Array] = None, **kwargs) -> Array

Arguments:

  • key: The random number generator key.
  • y: The conditioning information

Returns: A tuple (x, log_px): a single sample from the model and its log probability. Use vmap to get more samples.
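Because each method works on a single example, batching is done by vmapping over a split of keys, which batches the log probabilities along with the samples. A hedged sketch in plain JAX, with a hypothetical `sample_and_log_prob` standing in for the flow method (for an Equinox module, the `eqx.filter_vmap` shown in the `sample` docstring plays the role of `jax.vmap`):

```python
import jax
import jax.numpy as jnp

MU, SIGMA = 3.0, 2.0  # hypothetical toy flow: x = MU + SIGMA * z


def sample_and_log_prob(key):
    # Stand-in for flow.sample_and_log_prob: one sample and its log density
    z = jax.random.normal(key, (2,))
    log_pz = jnp.sum(-0.5 * z**2 - 0.5 * jnp.log(2 * jnp.pi))
    x = MU + SIGMA * z
    log_det_inv = 2 * jnp.log(SIGMA)  # log|det dx/dz| in 2 dimensions
    return x, log_pz - log_det_inv


keys = jax.random.split(jax.random.PRNGKey(0), 10_000)
xs, log_pxs = jax.vmap(sample_and_log_prob)(keys)
print(xs.shape, log_pxs.shape)  # (10000, 2) (10000,)
print(xs.mean(axis=0), xs.std(axis=0))  # close to 3.0 and 2.0
```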

log_prob(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Arguments:

  • x: The point we want to compute logp(x) at.
  • y: The conditioning information

Returns: The log likelihood of x under the model.
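The change-of-variables rule behind `log_prob`, log p(x) = log p_Z(f(x)) + log|det ∂f/∂x| with f mapping data to base, can be checked against a case with a closed form: pushing a standard Gaussian through x = mu + sigma * z must give N(mu, sigma^2). A small sketch, assuming a 1-D affine flow with made-up mu and sigma:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

mu, sigma = 1.5, 0.5  # made-up parameters of a 1-D affine flow x = mu + sigma * z
x = jnp.linspace(-1.0, 4.0, 11)

# Flow-style evaluation: map data to base, then add log|dz/dx| = -log(sigma)
z = (x - mu) / sigma
log_px_flow = norm.logpdf(z) - jnp.log(sigma)

# Closed form of the pushed-forward density, N(mu, sigma^2)
log_px_exact = norm.logpdf(x, loc=mu, scale=sigma)
print(jnp.allclose(log_px_flow, log_px_exact, atol=1e-5))  # True
```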

score(self, x: Array, y: Optional[Array] = None, key: Optional[PRNGKeyArray] = None) -> Array

Inherited from generax.distributions.base.ProbabilityDistribution.score.

Source code in generax/distributions/flow_models.py
def score(self,
          x: Array,
          y: Optional[Array] = None,
          key: Optional[PRNGKeyArray] = None) -> Array:
  """**Arguments**:

  - `x`: The point we want to compute grad logp(x) at.
  - `y`: The (optional) conditioning information.
  - `key`: The random number generator key.  Can be passed in the event
           that we're getting a stochastic estimate of the log prob.

  **Returns**:
  The gradient of the log likelihood with respect to x (the score).
  """
  return eqx.filter_grad(self.log_prob)(x, y=y, key=key)
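`score` is the gradient of `log_prob` with respect to x. The sketch below applies `jax.grad` to a hypothetical diagonal-Gaussian `log_prob` (standing in for a flow) so the result can be compared with the analytic score -(x - mu) / sigma^2. For an Equinox module, the source above uses `eqx.filter_grad` instead.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical stand-in for a flow: a diagonal Gaussian with a known score
mu = jnp.array([0.0, 1.0])
sigma = jnp.array([1.0, 2.0])


def log_prob(x):
    return jnp.sum(norm.logpdf(x, loc=mu, scale=sigma))


score = jax.grad(log_prob)  # same idea as eqx.filter_grad(self.log_prob)
x = jnp.array([1.0, 3.0])
print(score(x))               # autodiff score
print(-(x - mu) / sigma**2)   # analytic score: [-1.0, -0.5]
```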
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform

Initialize the parameters of the flow's transform based on a batch of data. This is one of the few times that x is expected to be batched.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new flow with the parameters initialized.
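One common use of data-dependent initialization is an ActNorm-style layer, whose shift and scale are chosen so that an initial batch comes out standardized. The sketch below assumes a hypothetical `ActNorm` NamedTuple (not a generax class). Like the `eqx.tree_at` call in the source, it returns a new object instead of mutating the old one, and it is the only step that sees a batched x.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple


class ActNorm(NamedTuple):
    """Hypothetical elementwise affine layer: z = (x - shift) * exp(-log_scale)."""
    shift: jnp.ndarray
    log_scale: jnp.ndarray

    def __call__(self, x):
        # Per-example forward pass (data -> base), returning (z, log|det dz/dx|)
        z = (x - self.shift) * jnp.exp(-self.log_scale)
        return z, -jnp.sum(self.log_scale)

    def data_dependent_init(self, x_batch):
        # x_batch has shape (batch, dim). Pick parameters so the batch maps to
        # zero mean and unit variance, and return a NEW layer.
        mean = x_batch.mean(axis=0)
        std = x_batch.std(axis=0) + 1e-6
        return self._replace(shift=mean, log_scale=jnp.log(std))


key = jax.random.PRNGKey(0)
x_batch = 5.0 + 3.0 * jax.random.normal(key, (512, 2))

layer = ActNorm(shift=jnp.zeros(2), log_scale=jnp.zeros(2))
initialized = layer.data_dependent_init(x_batch)

# After initialization the layer is applied per example, e.g. via vmap
z, log_dets = jax.vmap(lambda x: initialized(x))(x_batch)
print(z.mean(axis=0), z.std(axis=0))  # approximately 0 and 1
```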

generax.distributions.flow_models.RectangularFlow (NormalizingFlow)

A flow with an injective transformation, i.e. the base space has a lower dimensionality than the data space.
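For an injective map the Jacobian is not square, so the usual log|det J| is replaced by the volume change on the image, -1/2 log det(JᵀJ). This page does not show how `log_determiant` computes its term. The sketch below only illustrates the standard injective change-of-variables formula for a hypothetical linear generator x = A z, with A of shape (D, d) and D > d:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def injective_log_prob(x, A):
    """Log density of x = A @ z with z ~ N(0, I), for a hypothetical linear
    injective map A of shape (D, d). Assumes x lies in the image of A."""
    # Recover the base point; least squares is exact when x is in the image of A
    z, *_ = jnp.linalg.lstsq(A, x)
    log_pz = jnp.sum(norm.logpdf(z))
    # Injective change of variables: log p(x) = log p(z) - 0.5 * log det(A^T A)
    _, logdet_gram = jnp.linalg.slogdet(A.T @ A)
    return log_pz - 0.5 * logdet_gram


A = 2.0 * jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])  # (D, d) = (3, 2)
z_true = jnp.array([0.3, -0.7])
x = A @ z_true
print(injective_log_prob(x, A))
```

With A = 2·[I; 0], the correction term is -0.5 · log(16) = -2 log 2, since the map stretches each of the two latent directions by a factor of 2.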

Source code in generax/distributions/flow_models.py
class RectangularFlow(NormalizingFlow):
  """A flow with an injective transformation, i.e. the base space has a lower
  dimensionality than the data space.
  """

  transform: InjectiveTransform
  output_shape: Tuple[int]

  def __init__(self,
               transform: InjectiveTransform,
               prior: ProbabilityDistribution,
               **kwargs):
    """**Arguments**:

    - `transform`: An injective transformation
    - `prior`: The prior distribution
    """
    assert isinstance(transform, InjectiveTransform)
    self.output_shape = transform.output_shape
    super().__init__(transform=transform,
                     prior=prior,
                     **kwargs)

  def project(self,
              x: Array,
              y: Optional[Array] = None,
              **kwargs) -> Array:
    """Project a point onto the image of the transformation.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information

    **Returns**:
    z
    """
    return self.transform.project(x, y=y, **kwargs)

  def sample_and_log_prob(self,
                          key: PRNGKeyArray,
                          y: Optional[Array] = None,
                          **kwargs) -> Array:
    """**Arguments**:

    - `key`: The random number generator key.
    - `y`: The conditioning information

    **Returns**:
    A tuple `(x, log_px)`: a single sample and its log probability.
    Use vmap to get more samples.
    """
    z, log_pz = self.prior.sample_and_log_prob(key)
    x, _ = self.transform(z, y=y, inverse=True, **kwargs)
    log_det = self.transform.log_determiant(z, y=y, **kwargs)
    return x, log_pz + log_det

  def log_prob(self,
               x: Array,
               y: Optional[Array] = None,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The point we want to compute logp(x) at.
    - `y`: The conditioning information

    **Returns**:
    The log likelihood of x under the model.
    """
    z, _ = self.transform(x, y=y, **kwargs)
    log_pz = self.prior.log_prob(z)
    log_det = self.transform.log_determiant(z, y=y, **kwargs)
    return log_pz + log_det
sample(self, key: PRNGKeyArray, y: Optional[Array] = None) -> Array

Arguments:

  • key: The random number generator key.

Returns: A single sample from the model.

Use eqx.filter_vmap to get more samples! For example,

keys = random.split(key, n_samples)
samples = eqx.filter_vmap(self.sample)(keys)

score(self, x: Array, y: Optional[Array] = None, key: Optional[PRNGKeyArray] = None) -> Array

Arguments:

  • x: The point we want to compute grad logp(x) at.
  • y: The (optional) conditioning information.
  • key: The random number generator key. Can be passed in the event that we're getting a stochastic estimate of the log prob.

Returns: The gradient of the log likelihood with respect to x (the score).

data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform

Inherited from generax.distributions.flow_models.NormalizingFlow.data_dependent_init.

__init__(self, transform: InjectiveTransform, prior: ProbabilityDistribution, **kwargs)

Arguments:

  • transform: An injective transformation
  • prior: The prior distribution
project(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Project a point onto the image of the transformation.

Arguments:

  • x: The input to the transformation
  • y: The conditioning information

Returns: z
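For a linear injective map, projecting onto the image reduces to an orthogonal projection onto the column space. The sketch below is an illustration under that assumption only. The hypothetical `project_onto_image` helper returns both the projected point and its base coordinates; the actual `InjectiveTransform.project` may handle nonlinear maps differently and, per its docstring, returns z.

```python
import jax.numpy as jnp


def project_onto_image(x, A):
    """Orthogonal projection of x onto the column space of a hypothetical
    linear injective map A of shape (D, d). Returns (x_projected, z)."""
    # Normal equations: z = (A^T A)^{-1} A^T x minimizes ||A z - x||
    z = jnp.linalg.solve(A.T @ A, A.T @ x)
    return A @ z, z


A = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
x = jnp.array([1.0, 2.0, 3.0])
x_proj, z = project_onto_image(x, A)
print(x_proj, z)  # [1. 2. 0.] [1. 2.]
```

Projecting an already-projected point leaves it unchanged, which is a quick sanity check for any projection routine.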

sample_and_log_prob(self, key: PRNGKeyArray, y: Optional[Array] = None, **kwargs) -> Array

Arguments:

  • key: The random number generator key.
  • y: The conditioning information

Returns: A tuple (x, log_px): a single sample from the model and its log probability. Use vmap to get more samples.

log_prob(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Arguments:

  • x: The point we want to compute logp(x) at.
  • y: The conditioning information

Returns: The log likelihood of x under the model.

generax.distributions.flow_models.RealNVP (NormalizingFlow)

A RealNVP flow: a RealNVPTransform built from coupling layers, paired with a Gaussian prior over input_shape.

Source code in generax/distributions/flow_models.py
class RealNVP(NormalizingFlow):

  def __init__(self,
               input_shape: Tuple[int],
               n_flow_layers: int = 3,
               working_size: int = 16,
               hidden_size: int = 32,
               n_blocks: int = 4,
               cond_shape: Optional[Tuple[int]] = None,
               *,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The shape of the input data.
    - `n_flow_layers`: The number of layers in the flow.
    - `working_size`: The size of the working space.
    - `hidden_size`: The size of the hidden layers.
    - `n_blocks`: The number of blocks in the coupling layers.
    - `cond_shape`: The shape of the conditioning information.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    transform = RealNVPTransform(input_shape=input_shape,
                                 n_flow_layers=n_flow_layers,
                                 working_size=working_size,
                                 hidden_size=hidden_size,
                                 n_blocks=n_blocks,
                                 cond_shape=cond_shape,
                                 key=key)
    prior = Gaussian(input_shape=input_shape)
    super().__init__(transform=transform,
                     prior=prior,
                     **kwargs)
__init__(self, input_shape: Tuple[int], n_flow_layers: int = 3, working_size: int = 16, hidden_size: int = 32, n_blocks: int = 4, cond_shape: Optional[Tuple[int]] = None, *, key: PRNGKeyArray, **kwargs)

Arguments:

  • input_shape: The shape of the input data.
  • n_flow_layers: The number of layers in the flow.
  • working_size: The size of the working space.
  • hidden_size: The size of the hidden layers.
  • n_blocks: The number of blocks in the coupling layers.
  • cond_shape: The shape of the conditioning information.
  • key: A jax.random.PRNGKey for initialization
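RealNVP builds its transform from affine coupling layers. Half of the coordinates pass through unchanged and parameterize a scale and shift for the other half, which makes the Jacobian triangular and the log-determinant a simple sum of log-scales. The sketch below shows a single coupling step. It assumes a tiny hypothetical one-hidden-layer network in `params` and does not reproduce generax's `RealNVPTransform` (its `working_size`, `n_blocks`, or how it alternates the halves).

```python
import jax
import jax.numpy as jnp


def _scale_and_shift(x1, params):
    # Hypothetical one-hidden-layer conditioner network
    h = jnp.tanh(params["W1"] @ x1 + params["b1"])
    log_s = jnp.tanh(params["Ws"] @ h)  # bounded log-scale for stability
    t = params["Wt"] @ h
    return log_s, t


def coupling_forward(x, params):
    # Keep the first half, transform the second half conditioned on it
    x1, x2 = jnp.split(x, 2)
    log_s, t = _scale_and_shift(x1, params)
    z2 = x2 * jnp.exp(log_s) + t
    return jnp.concatenate([x1, z2]), jnp.sum(log_s)  # triangular Jacobian


def coupling_inverse(z, params):
    z1, z2 = jnp.split(z, 2)
    log_s, t = _scale_and_shift(z1, params)  # z1 == x1, so same scale/shift
    x2 = (z2 - t) * jnp.exp(-log_s)
    return jnp.concatenate([z1, x2]), -jnp.sum(log_s)


keys = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "W1": jax.random.normal(keys[0], (8, 2)),
    "b1": jax.random.normal(keys[1], (8,)),
    "Ws": 0.1 * jax.random.normal(keys[2], (2, 8)),
    "Wt": jax.random.normal(keys[3], (2, 8)),
}
x = jnp.array([0.5, -1.0, 2.0, 0.3])
z, log_det = coupling_forward(x, params)
x_rec, _ = coupling_inverse(z, params)

# Compare the cheap log-determinant with a dense Jacobian computation
jac = jax.jacfwd(lambda v: coupling_forward(v, params)[0])(x)
print(jnp.allclose(x_rec, x, atol=1e-4))
print(jnp.allclose(log_det, jnp.linalg.slogdet(jac)[1], atol=1e-4))
```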

generax.distributions.flow_models.NeuralSpline (NormalizingFlow)

A neural spline flow: a NeuralSplineTransform paired with a Gaussian prior over input_shape.

Source code in generax/distributions/flow_models.py
class NeuralSpline(NormalizingFlow):

  def __init__(self,
               input_shape: Tuple[int],
               n_flow_layers: int = 3,
               working_size: int = 16,
               hidden_size: int = 32,
               n_blocks: int = 4,
               n_spline_knots: int = 8,
               cond_shape: Optional[Tuple[int]] = None,
               *,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The shape of the input data.
    - `n_flow_layers`: The number of layers in the flow.
    - `working_size`: The size of the working space.
    - `hidden_size`: The size of the hidden layers.
    - `n_blocks`: The number of blocks in the coupling layers.
    - `cond_shape`: The shape of the conditioning information.
    - `n_spline_knots`: The number of knots in the spline.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    transform = NeuralSplineTransform(input_shape=input_shape,
                                      n_flow_layers=n_flow_layers,
                                      working_size=working_size,
                                      hidden_size=hidden_size,
                                      n_blocks=n_blocks,
                                      n_spline_knots=n_spline_knots,
                                      cond_shape=cond_shape,
                                      key=key)
    prior = Gaussian(input_shape=input_shape)
    super().__init__(transform=transform,
                     prior=prior,
                     **kwargs)
__init__(self, input_shape: Tuple[int], n_flow_layers: int = 3, working_size: int = 16, hidden_size: int = 32, n_blocks: int = 4, n_spline_knots: int = 8, cond_shape: Optional[Tuple[int]] = None, *, key: PRNGKeyArray, **kwargs)

Arguments:

  • input_shape: The shape of the input data.
  • n_flow_layers: The number of layers in the flow.
  • working_size: The size of the working space.
  • hidden_size: The size of the hidden layers.
  • n_blocks: The number of blocks in the coupling layers.
  • cond_shape: The shape of the conditioning information.
  • n_spline_knots: The number of knots in the spline.
  • key: A jax.random.PRNGKey for initialization
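This page does not show which spline family `NeuralSplineTransform` uses. To illustrate what a knot count controls, here is a simpler stand-in: a monotone piecewise-linear spline on [0, 1] whose K equal-width bins receive softmax-normalized mass. The helpers `spline_forward` and `spline_inverse` are hypothetical. Neural spline flows usually use rational-quadratic splines, which are smoother but follow the same find-the-bin, transform, invert pattern.

```python
import jax
import jax.numpy as jnp


def _bins(logits):
    probs = jax.nn.softmax(logits)  # mass per bin, sums to 1
    cdf_left = jnp.concatenate([jnp.zeros(1), jnp.cumsum(probs)[:-1]])
    return probs, cdf_left


def spline_forward(x, logits):
    # Monotone map [0, 1] -> [0, 1] with K = len(logits) equal-width bins
    K = logits.shape[0]
    probs, cdf_left = _bins(logits)
    b = jnp.clip(jnp.floor(x * K).astype(jnp.int32), 0, K - 1)
    alpha = x * K - b  # position inside bin b, in [0, 1]
    y = cdf_left[b] + alpha * probs[b]
    log_det = jnp.log(probs[b] * K)  # dy/dx is constant inside a bin
    return y, log_det


def spline_inverse(y, logits):
    K = logits.shape[0]
    probs, cdf_left = _bins(logits)
    b = jnp.clip(jnp.searchsorted(cdf_left, y, side="right") - 1, 0, K - 1)
    alpha = (y - cdf_left[b]) / probs[b]
    return (b + alpha) / K


logits = jnp.array([0.2, -1.0, 1.5, 0.0, 0.7])  # K = 5 bins
xs = jnp.linspace(0.01, 0.99, 25)
ys, log_dets = jax.vmap(lambda x: spline_forward(x, logits))(xs)
xs_rec = jax.vmap(lambda y: spline_inverse(y, logits))(ys)

# log_det should match log(dy/dx) computed by autodiff
grads = jax.vmap(jax.grad(lambda x: spline_forward(x, logits)[0]))(xs)
print(jnp.allclose(xs_rec, xs, atol=1e-5))
print(jnp.allclose(log_dets, jnp.log(grads), atol=1e-5))
```

More knots give the map more freedom to reshape the density, at the cost of more parameters for the conditioner network to predict.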

generax.distributions.flow_models.ContinuousNormalizingFlow (NormalizingFlow)

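This is FFJORD (https://arxiv.org/pdf/1810.01367.pdf), a continuous normalizing flow whose transformation is defined by integrating a neural ODE.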
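A continuous flow evolves x under dx/dt = f(t, x) and tracks the density with the instantaneous change of variables, d log p(x(t))/dt = -tr(∂f/∂x). The sketch below integrates both quantities with fixed-step RK4 for a hypothetical linear vector field, so the result can be checked against the matrix exponential. It computes the trace exactly with `jax.jacfwd`, whereas the FFJORD paper uses Hutchinson's stochastic trace estimator and adaptive solvers; generax instead exposes diffrax-style `adjoint` and controller-tolerance options.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Hypothetical linear vector field f(t, x) = A @ x (generax's vector field is
# a neural network with signature vector_field(t, x, y=y)).
A = jnp.array([[-0.5, 1.0], [-1.0, 0.2]])


def vector_field(t, x, y=None):
    return A @ x


def augmented_dynamics(t, state):
    # state = [x, delta_logp]; returns [dx/dt, d(delta_logp)/dt]
    x = state[:-1]
    dx = vector_field(t, x)
    div = jnp.trace(jax.jacfwd(lambda v: vector_field(t, v))(x))
    return jnp.concatenate([dx, -div[None]])


def integrate(x0, t1=1.0, n_steps=200):
    dt = t1 / n_steps

    def rk4_step(i, state):
        t = i * dt
        k1 = augmented_dynamics(t, state)
        k2 = augmented_dynamics(t + dt / 2, state + dt / 2 * k1)
        k3 = augmented_dynamics(t + dt / 2, state + dt / 2 * k2)
        k4 = augmented_dynamics(t + dt, state + dt * k3)
        return state + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

    state0 = jnp.concatenate([x0, jnp.zeros(1)])
    state1 = jax.lax.fori_loop(0, n_steps, rk4_step, state0)
    # x(t1) and log p_t1(x(t1)) - log p_0(x0)
    return state1[:-1], state1[-1]


x0 = jnp.array([1.0, -0.5])
x1, delta_logp = integrate(x0)
print(jnp.allclose(x1, expm(A) @ x0, atol=1e-4))
print(jnp.allclose(delta_logp, -jnp.trace(A), atol=1e-4))
```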

Source code in generax/distributions/flow_models.py
class ContinuousNormalizingFlow(NormalizingFlow):
  """This is [FFJORD](https://arxiv.org/pdf/1810.01367.pdf).
  """
  def __init__(self,
               input_shape: Tuple[int],
               net: eqx.Module = None,
               cond_shape: Optional[Tuple[int]] = None,
               *,
               controller_rtol: Optional[float] = 1e-3,
               controller_atol: Optional[float] = 1e-5,
               adjoint='recursive_checkpoint',
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The shape of the input data.
    - `net`: The neural network to use for the vector field.  If None, a default
              network will be used.  `net` should accept `net(t, x, y=y)`
    - `cond_shape`: The shape of the conditioning information.
    - `key`: A `jax.random.PRNGKey` for initialization
    - `controller_rtol`: The relative tolerance for the controller.
    - `controller_atol`: The absolute tolerance for the controller.
    - `adjoint`: The adjoint method to use.  See [this](https://docs.kidger.site/diffrax/api/adjoints/)
    """
    transform = FFJORDTransform(input_shape=input_shape,
                                net=net,
                                cond_shape=cond_shape,
                                key=key,
                                controller_rtol=controller_rtol,
                                controller_atol=controller_atol,
                                adjoint=adjoint,
                                **kwargs)
    prior = Gaussian(input_shape=input_shape)
    super().__init__(transform=transform,
                     prior=prior,
                     **kwargs)

  @property
  def neural_ode(self):
    return self.transform.neural_ode

  @property
  def vector_field(self):
    """Get the vector field function that samples evolve on as t changes.  This is
    an `eqx.Module` with the signature `vector_field(t, x, y=y) -> dx/dt`."""
    return self.transform.vector_field

  @property
  def net(self):
    """Same as `vector_field`"""
    return self.vector_field

  def sample(self,
             key: PRNGKeyArray,
             y: Optional[Array] = None,
             **kwargs) -> Array:
    """**Arguments**:

    - `key`: The random number generator key.
    - `y`: The conditioning information

    **Returns**:
    A single sample from the model
    """
    z = self.prior.sample(key)
    x, _ = self.transform(z,
                          y=y,
                          inverse=True,
                          log_likelihood=False,
                          **kwargs)
    return x
vector_field property readonly

Get the vector field function that samples evolve on as t changes. This is an eqx.Module with the signature vector_field(t, x, y=y) -> dx/dt.

net property readonly

Same as vector_field

__init__(self, input_shape: Tuple[int], net: Module = None, cond_shape: Optional[Tuple[int]] = None, *, controller_rtol: Optional[float] = 0.001, controller_atol: Optional[float] = 1e-05, adjoint = 'recursive_checkpoint', key: PRNGKeyArray, **kwargs)

Arguments:

  • input_shape: The shape of the input data.
  • net: The neural network to use for the vector field. If None, a default network will be used. net should accept net(t, x, y=y)
  • cond_shape: The shape of the conditioning information.
  • key: A jax.random.PRNGKey for initialization
  • controller_rtol: The relative tolerance for the controller.
  • controller_atol: The absolute tolerance for the controller.
  • adjoint: The adjoint method to use. See https://docs.kidger.site/diffrax/api/adjoints/