Flows

generax.distributions.flow_models.NormalizingFlow (ProbabilityDistribution)
A normalizing flow is a model that we use to represent probability distributions. See [this](https://arxiv.org/pdf/1912.02762.pdf) for an overview.
Attributes:

- `transform`: A `BijectiveTransform` object that transforms a variable from the base space to the data space and also computes the change in log pdf.
  - `transform(x) -> (z, log_det)`: Apply the transformation to the input.
- `prior`: The prior probability distribution.
Methods:

- `to_base_space(x) -> z`: Transform a point from the data space to the base space.
- `sample_and_log_prob(key) -> (x, log_px)`: Sample from the distribution and compute the log probability.
- `sample(key) -> x`: Pull a single sample from the model.
- `log_prob(x) -> log_px`: Compute the log probability of a point under the model.
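A minimal usage sketch with the `RealNVP` subclass documented later on this page; the 2-D `input_shape` is an illustrative assumption. All methods act on single, unbatched inputs, so batches go through `eqx.filter_vmap`:

```python
import equinox as eqx
import jax.random as random

from generax.distributions.flow_models import RealNVP

key = random.PRNGKey(0)
flow = RealNVP(input_shape=(2,), key=key)   # RealNVPTransform + Gaussian prior

x, log_px = flow.sample_and_log_prob(key)   # one sample and its log prob

# Batch of samples: vmap over a batch of PRNG keys.
keys = random.split(key, 8)
samples = eqx.filter_vmap(flow.sample)(keys)          # shape (8, 2)
log_probs = eqx.filter_vmap(flow.log_prob)(samples)   # shape (8,)
```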
Source code in generax/distributions/flow_models.py
class NormalizingFlow(ProbabilityDistribution, ABC):
"""A normalizing flow is a model that we use to represent probability
distributions. See [this](https://arxiv.org/pdf/1912.02762.pdf) for an overview.
**Attributes**:
- `transform`: A `BijectiveTransform` object that transforms a variable
from the base space to the data space and also computes
the change in log pdf.
- `transform(x) -> (z,log_det)`: Apply the transformation to the input.
- `prior`: The prior probability distribution.
**Methods**:
- `to_base_space(x) -> z`: Transform a point from the data space to the base space.
- `sample_and_log_prob(key) -> (x,log_px)`: Sample from the distribution and compute the log probability.
- `sample(key) -> x`: Pull a single sample from the model.
- `log_prob(x) -> log_px`: Compute the log probability of a point under the model.
"""
transform: BijectiveTransform
prior: ProbabilityDistribution
def __init__(self,
transform: BijectiveTransform,
prior: ProbabilityDistribution,
**kwargs):
"""**Arguments**:
- `transform`: A bijective transformation
- `prior`: The prior distribution
"""
self.transform = transform
self.prior = prior
input_shape = self.transform.input_shape
super().__init__(input_shape=input_shape, **kwargs)
def to_base_space(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""**Arguments**:
- `x`: A JAX array with shape `(dim,)`.
- `y`: The conditioning information
**Returns**:
A JAX array with shape `(dim,)`.
"""
return self.transform(x, y=y, **kwargs)[0]
def to_data_space(self,
z: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""**Arguments**:
- `z`: A JAX array with shape `(dim,)`.
- `y`: The conditioning information
**Returns**:
A JAX array with shape `(dim,)`.
"""
return self.transform(z, y=y, inverse=True, **kwargs)[0]
def sample_and_log_prob(self,
key: PRNGKeyArray,
y: Optional[Array] = None,
**kwargs) -> Array:
"""**Arguments**:
- `key`: The random number generator key.
- `y`: The conditioning information
**Returns**:
A single sample from the model along with its log probability. Use vmap to get more samples.
"""
z, log_pz = self.prior.sample_and_log_prob(key)
x, log_det = self.transform(z, y=y, inverse=True, **kwargs)
# The log determinant of the inverse transform has a negative sign!
return x, log_pz - log_det
def log_prob(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""**Arguments**:
- `x`: The point we want to compute logp(x) at.
- `y`: The conditioning information
**Returns**:
The log likelihood of x under the model.
"""
z, log_det = self.transform(x, y=y, **kwargs)
log_pz = self.prior.log_prob(z)
return log_pz + log_det
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None) -> BijectiveTransform:
"""Initialize the parameters of the layer based on a batch of data.
This is one of the few times that $x$ is expected to be batched.
**Arguments**:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new flow with the parameters initialized.
"""
new_layer = self.transform.data_dependent_init(x, y=y, key=key)
# Turn the new parameters into a new module
get_transform = lambda tree: tree.transform
return eqx.tree_at(get_transform, self, new_layer)
__init__(self, transform: BijectiveTransform, prior: ProbabilityDistribution, **kwargs)

Arguments:

- `transform`: A bijective transformation
- `prior`: The prior distribution
Source code in generax/distributions/flow_models.py
def __init__(self,
transform: BijectiveTransform,
prior: ProbabilityDistribution,
**kwargs):
"""**Arguments**:
- `transform`: A bijective transformation
- `prior`: The prior distribution
"""
self.transform = transform
self.prior = prior
input_shape = self.transform.input_shape
super().__init__(input_shape=input_shape, **kwargs)
sample(self, key: PRNGKeyArray, y: Optional[Array] = None) -> Array

Inherited from `generax.distributions.base.ProbabilityDistribution.sample`.
Source code in generax/distributions/flow_models.py
def sample(self,
key: PRNGKeyArray,
y: Optional[Array] = None) -> Array:
"""
**Arguments**:
- `key`: The random number generator key.
**Returns**:
Samples from the model
Use eqx.filter_vmap to get more samples! For example,
```python
keys = random.split(key, n_samples)
samples = eqx.filter_vmap(self.sample)(keys)
```
"""
return self.sample_and_log_prob(key, y)[0]
sample_and_log_prob(self, key: PRNGKeyArray, y: Optional[Array] = None, **kwargs) -> Array

Arguments:

- `key`: The random number generator key.
- `y`: The conditioning information

Returns: A single sample from the model along with its log probability. Use vmap to get more samples.
Source code in generax/distributions/flow_models.py
def sample_and_log_prob(self,
key: PRNGKeyArray,
y: Optional[Array] = None,
**kwargs) -> Array:
"""**Arguments**:
- `key`: The random number generator key.
- `y`: The conditioning information
**Returns**:
A single sample from the model along with its log probability. Use vmap to get more samples.
"""
z, log_pz = self.prior.sample_and_log_prob(key)
x, log_det = self.transform(z, y=y, inverse=True, **kwargs)
# The log determinant of the inverse transform has a negative sign!
return x, log_pz - log_det
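For reference, both sign conventions above follow from the change of variables formula. With the forward transform $z = f(x)$ mapping data to base space,

$$
\log p_x(x) = \log p_z(f(x)) + \log\left|\det\frac{\partial f(x)}{\partial x}\right|,
$$

which is exactly what `log_prob` computes below. When sampling, the `inverse=True` call returns the log determinant of the inverse map $f^{-1}$, which is the negative of the forward term, hence the `log_pz - log_det` here.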
log_prob(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Arguments:

- `x`: The point we want to compute logp(x) at.
- `y`: The conditioning information

Returns: The log likelihood of x under the model.
Source code in generax/distributions/flow_models.py
def log_prob(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""**Arguments**:
- `x`: The point we want to compute logp(x) at.
- `y`: The conditioning information
**Returns**:
The log likelihood of x under the model.
"""
z, log_det = self.transform(x, y=y, **kwargs)
log_pz = self.prior.log_prob(z)
return log_pz + log_det
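A hedged sketch of conditional evaluation; the flow is assumed to have been built with `cond_shape=(4,)`, and the array shapes are illustrative:

```python
import equinox as eqx

# Assumed shapes: x is (2,), y is (4,); x_batch/y_batch add a leading axis.
log_px = flow.log_prob(x, y=y)

# Batched evaluation: vmap over the leading axis of both arguments.
log_px_batch = eqx.filter_vmap(flow.log_prob)(x_batch, y_batch)
```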
score(self, x: Array, y: Optional[Array] = None, key: Optional[PRNGKeyArray] = None) -> Array

Inherited from `generax.distributions.base.ProbabilityDistribution.score`.
Source code in generax/distributions/flow_models.py
def score(self,
x: Array,
y: Optional[Array] = None,
key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments**:
- `x`: The point we want to compute grad logp(x) at.
- `y`: The (optional) conditioning information.
- `key`: The random number generator key. Can be passed in the event
that we're getting a stochastic estimate of the log prob.
**Returns**:
The gradient of the log likelihood of x under the model.
"""
return eqx.filter_grad(self.log_prob)(x, y=y, key=key)
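Since `score` returns $\nabla_x \log p(x)$, it plugs directly into score-based procedures. A hedged sketch of one unadjusted Langevin step; `flow`, `x`, and `key` are assumed from context, and the step size is illustrative:

```python
import jax.numpy as jnp
import jax.random as random

eps = 1e-2
noise = random.normal(key, x.shape)
# One unadjusted Langevin step: x + (eps/2) * grad log p(x) + sqrt(eps) * noise
x_next = x + 0.5 * eps * flow.score(x) + jnp.sqrt(eps) * noise
```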
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform

Initialize the parameters of the layer based on a batch of data. This is one of the few times that \(x\) is expected to be batched.

Arguments:

- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization

Returns: A new flow with the parameters initialized.
Source code in generax/distributions/flow_models.py
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None) -> BijectiveTransform:
"""Initialize the parameters of the layer based on a batch of data.
This is one of the few times that $x$ is expected to be batched.
**Arguments**:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new flow with the parameters initialized.
"""
new_layer = self.transform.data_dependent_init(x, y=y, key=key)
# Turn the new parameters into a new module
get_transform = lambda tree: tree.transform
return eqx.tree_at(get_transform, self, new_layer)
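Because `eqx.tree_at` builds a new module rather than mutating in place, the result must be re-bound. A minimal sketch, assuming `x_batch` is a `(batch, dim)` array of training data:

```python
import jax.random as random

key = random.PRNGKey(1)
# Note the re-binding: the original `flow` is left untouched.
flow = flow.data_dependent_init(x_batch, key=key)
```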
generax.distributions.flow_models.RectangularFlow (NormalizingFlow)

A flow with an injective transformation, i.e. the base space has a lower dimensionality than the data space.
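For intuition, the injective analogue of the change-of-variables formula: if $g$ maps the base space into the data space with Jacobian $J = \partial g(z) / \partial z$, then for points $x = g(z)$ on the image

$$
\log p_x(x) = \log p_z(z) - \tfrac{1}{2}\log\det\left(J^\top J\right).
$$

The `log_determiant` calls in the source below supply this correction term (up to the library's sign convention), in place of the square-Jacobian determinant used by `NormalizingFlow`.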
Source code in generax/distributions/flow_models.py
class RectangularFlow(NormalizingFlow):
"""A flow with an injective transformation, i.e. the base space has a lower
dimensionality than the data space.
"""
transform: InjectiveTransform
output_shape: Tuple[int]
def __init__(self,
transform: InjectiveTransform,
prior: ProbabilityDistribution,
**kwargs):
"""**Arguments**:
- `transform`: An injective transformation
- `prior`: The prior distribution
"""
assert isinstance(transform, InjectiveTransform)
self.output_shape = transform.output_shape
super().__init__(transform=transform,
prior=prior,
**kwargs)
def project(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Project a point onto the image of the transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
z
"""
return self.transform.project(x, y=y, **kwargs)
def sample_and_log_prob(self,
key: PRNGKeyArray,
y: Optional[Array] = None,
**kwargs) -> Array:
"""**Arguments**:
- `key`: The random number generator key.
- `y`: The conditioning information
**Returns**:
A single sample from the model along with its log probability. Use vmap to get more samples.
"""
z, log_pz = self.prior.sample_and_log_prob(key)
x, _ = self.transform(z, y=y, inverse=True, **kwargs)
log_det = self.transform.log_determiant(z, y=y, **kwargs)
return x, log_pz + log_det
def log_prob(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""**Arguments**:
- `x`: The point we want to compute logp(x) at.
- `y`: The conditioning information
**Returns**:
The log likelihood of x under the model.
"""
z, _ = self.transform(x, y=y, **kwargs)
log_pz = self.prior.log_prob(z)
log_det = self.transform.log_determiant(z, y=y, **kwargs)
return log_pz + log_det
sample(self, key: PRNGKeyArray, y: Optional[Array] = None) -> Array

Arguments:

- `key`: The random number generator key.

Returns: Samples from the model

Use eqx.filter_vmap to get more samples! For example,

```python
keys = random.split(key, n_samples)
samples = eqx.filter_vmap(self.sample)(keys)
```
Source code in generax/distributions/flow_models.py
def sample(self,
key: PRNGKeyArray,
y: Optional[Array] = None) -> Array:
"""
**Arguments**:
- `key`: The random number generator key.
**Returns**:
Samples from the model
Use eqx.filter_vmap to get more samples! For example,
```python
keys = random.split(key, n_samples)
samples = eqx.filter_vmap(self.sample)(keys)
```
"""
return self.sample_and_log_prob(key, y)[0]
score(self, x: Array, y: Optional[Array] = None, key: Optional[PRNGKeyArray] = None) -> Array

Arguments:

- `x`: The point we want to compute grad logp(x) at.
- `y`: The (optional) conditioning information.
- `key`: The random number generator key. Can be passed in the event that we're getting a stochastic estimate of the log prob.

Returns: The gradient of the log likelihood of x under the model.
Source code in generax/distributions/flow_models.py
def score(self,
x: Array,
y: Optional[Array] = None,
key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments**:
- `x`: The point we want to compute grad logp(x) at.
- `y`: The (optional) conditioning information.
- `key`: The random number generator key. Can be passed in the event
that we're getting a stochastic estimate of the log prob.
**Returns**:
The gradient of the log likelihood of x under the model.
"""
return eqx.filter_grad(self.log_prob)(x, y=y, key=key)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform

Inherited from `generax.distributions.flow_models.NormalizingFlow.data_dependent_init`.
Source code in generax/distributions/flow_models.py
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None) -> BijectiveTransform:
"""Initialize the parameters of the layer based on a batch of data.
This is one of the few times that $x$ is expected to be batched.
**Arguments**:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new flow with the parameters initialized.
"""
new_layer = self.transform.data_dependent_init(x, y=y, key=key)
# Turn the new parameters into a new module
get_transform = lambda tree: tree.transform
return eqx.tree_at(get_transform, self, new_layer)
__init__(self, transform: InjectiveTransform, prior: ProbabilityDistribution, **kwargs)

Arguments:

- `transform`: An injective transformation
- `prior`: The prior distribution
Source code in generax/distributions/flow_models.py
def __init__(self,
transform: InjectiveTransform,
prior: ProbabilityDistribution,
**kwargs):
"""**Arguments**:
- `transform`: An injective transformation
- `prior`: The prior distribution
"""
assert isinstance(transform, InjectiveTransform)
self.output_shape = transform.output_shape
super().__init__(transform=transform,
prior=prior,
**kwargs)
project(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Project a point onto the image of the transformation.

Arguments:

- `x`: The input to the transformation
- `y`: The conditioning information

Returns: z
Source code in generax/distributions/flow_models.py
def project(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Project a point onto the image of the transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
z
"""
return self.transform.project(x, y=y, **kwargs)
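A short sketch of `project`, which maps an arbitrary data-space point to one on the image of the transform; the idempotence check below is a general property of projections, not a documented generax guarantee, and the shapes are assumed from context:

```python
import jax.numpy as jnp

# Assume `flow` is a RectangularFlow and `x` an arbitrary data-space point.
x_on_image = flow.project(x)

# A projection should be (approximately) idempotent:
assert jnp.allclose(flow.project(x_on_image), x_on_image, atol=1e-4)
```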
sample_and_log_prob(self, key: PRNGKeyArray, y: Optional[Array] = None, **kwargs) -> Array

Arguments:

- `key`: The random number generator key.
- `y`: The conditioning information

Returns: A single sample from the model along with its log probability. Use vmap to get more samples.
Source code in generax/distributions/flow_models.py
def sample_and_log_prob(self,
key: PRNGKeyArray,
y: Optional[Array] = None,
**kwargs) -> Array:
"""**Arguments**:
- `key`: The random number generator key.
- `y`: The conditioning information
**Returns**:
A single sample from the model along with its log probability. Use vmap to get more samples.
"""
z, log_pz = self.prior.sample_and_log_prob(key)
x, _ = self.transform(z, y=y, inverse=True, **kwargs)
log_det = self.transform.log_determiant(z, y=y, **kwargs)
return x, log_pz + log_det
log_prob(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Arguments:

- `x`: The point we want to compute logp(x) at.
- `y`: The conditioning information

Returns: The log likelihood of x under the model.
Source code in generax/distributions/flow_models.py
def log_prob(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""**Arguments**:
- `x`: The point we want to compute logp(x) at.
- `y`: The conditioning information
**Returns**:
The log likelihood of x under the model.
"""
z, _ = self.transform(x, y=y, **kwargs)
log_pz = self.prior.log_prob(z)
log_det = self.transform.log_determiant(z, y=y, **kwargs)
return log_pz + log_det
generax.distributions.flow_models.RealNVP (NormalizingFlow)

RealNVP(*args, **kwargs)
Source code in generax/distributions/flow_models.py
class RealNVP(NormalizingFlow):
def __init__(self,
input_shape: Tuple[int],
n_flow_layers: int = 3,
working_size: int = 16,
hidden_size: int = 32,
n_blocks: int = 4,
cond_shape: Optional[Tuple[int]] = None,
*,
key: PRNGKeyArray,
**kwargs):
"""**Arguments**:
- `input_shape`: The shape of the input data.
- `n_flow_layers`: The number of layers in the flow.
- `working_size`: The size of the working space.
- `hidden_size`: The size of the hidden layers.
- `n_blocks`: The number of blocks in the coupling layers.
- `cond_shape`: The shape of the conditioning information.
- `key`: A `jax.random.PRNGKey` for initialization
"""
transform = RealNVPTransform(input_shape=input_shape,
n_flow_layers=n_flow_layers,
working_size=working_size,
hidden_size=hidden_size,
n_blocks=n_blocks,
cond_shape=cond_shape,
key=key)
prior = Gaussian(input_shape=input_shape)
super().__init__(transform=transform,
prior=prior,
**kwargs)
__init__(self, input_shape: Tuple[int], n_flow_layers: int = 3, working_size: int = 16, hidden_size: int = 32, n_blocks: int = 4, cond_shape: Optional[Tuple[int]] = None, *, key: PRNGKeyArray, **kwargs)

Arguments:

- `input_shape`: The shape of the input data.
- `n_flow_layers`: The number of layers in the flow.
- `working_size`: The size of the working space.
- `hidden_size`: The size of the hidden layers.
- `n_blocks`: The number of blocks in the coupling layers.
- `cond_shape`: The shape of the conditioning information.
- `key`: A `jax.random.PRNGKey` for initialization
Source code in generax/distributions/flow_models.py
def __init__(self,
input_shape: Tuple[int],
n_flow_layers: int = 3,
working_size: int = 16,
hidden_size: int = 32,
n_blocks: int = 4,
cond_shape: Optional[Tuple[int]] = None,
*,
key: PRNGKeyArray,
**kwargs):
"""**Arguments**:
- `input_shape`: The shape of the input data.
- `n_flow_layers`: The number of layers in the flow.
- `working_size`: The size of the working space.
- `hidden_size`: The size of the hidden layers.
- `n_blocks`: The number of blocks in the coupling layers.
- `cond_shape`: The shape of the conditioning information.
- `key`: A `jax.random.PRNGKey` for initialization
"""
transform = RealNVPTransform(input_shape=input_shape,
n_flow_layers=n_flow_layers,
working_size=working_size,
hidden_size=hidden_size,
n_blocks=n_blocks,
cond_shape=cond_shape,
key=key)
prior = Gaussian(input_shape=input_shape)
super().__init__(transform=transform,
prior=prior,
**kwargs)
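A hedged training sketch: maximum-likelihood fitting of a `RealNVP` flow with `optax`. The optimizer choice, data shapes, and random stand-in data are assumptions, not part of generax:

```python
import jax.numpy as jnp
import jax.random as random
import equinox as eqx
import optax

from generax.distributions.flow_models import RealNVP

key = random.PRNGKey(0)
flow = RealNVP(input_shape=(2,), n_flow_layers=4, key=key)

optim = optax.adam(1e-3)
opt_state = optim.init(eqx.filter(flow, eqx.is_inexact_array))

@eqx.filter_value_and_grad
def loss_fn(flow, x_batch):
    # Negative mean log-likelihood over the batch.
    return -jnp.mean(eqx.filter_vmap(flow.log_prob)(x_batch))

@eqx.filter_jit
def train_step(flow, opt_state, x_batch):
    loss, grads = loss_fn(flow, x_batch)
    updates, opt_state = optim.update(grads, opt_state)
    flow = eqx.apply_updates(flow, updates)
    return flow, opt_state, loss

x_batch = random.normal(key, (128, 2))  # stand-in for real training data
flow, opt_state, loss = train_step(flow, opt_state, x_batch)
```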
generax.distributions.flow_models.NeuralSpline (NormalizingFlow)

NeuralSpline(*args, **kwargs)
Source code in generax/distributions/flow_models.py
class NeuralSpline(NormalizingFlow):
def __init__(self,
input_shape: Tuple[int],
n_flow_layers: int = 3,
working_size: int = 16,
hidden_size: int = 32,
n_blocks: int = 4,
n_spline_knots: int = 8,
cond_shape: Optional[Tuple[int]] = None,
*,
key: PRNGKeyArray,
**kwargs):
"""**Arguments**:
- `input_shape`: The shape of the input data.
- `n_flow_layers`: The number of layers in the flow.
- `working_size`: The size of the working space.
- `hidden_size`: The size of the hidden layers.
- `n_blocks`: The number of blocks in the coupling layers.
- `cond_shape`: The shape of the conditioning information.
- `n_spline_knots`: The number of knots in the spline.
- `key`: A `jax.random.PRNGKey` for initialization
"""
transform = NeuralSplineTransform(input_shape=input_shape,
n_flow_layers=n_flow_layers,
working_size=working_size,
hidden_size=hidden_size,
n_blocks=n_blocks,
n_spline_knots=n_spline_knots,
cond_shape=cond_shape,
key=key)
prior = Gaussian(input_shape=input_shape)
super().__init__(transform=transform,
prior=prior,
**kwargs)
__init__(self, input_shape: Tuple[int], n_flow_layers: int = 3, working_size: int = 16, hidden_size: int = 32, n_blocks: int = 4, n_spline_knots: int = 8, cond_shape: Optional[Tuple[int]] = None, *, key: PRNGKeyArray, **kwargs)

Arguments:

- `input_shape`: The shape of the input data.
- `n_flow_layers`: The number of layers in the flow.
- `working_size`: The size of the working space.
- `hidden_size`: The size of the hidden layers.
- `n_blocks`: The number of blocks in the coupling layers.
- `cond_shape`: The shape of the conditioning information.
- `n_spline_knots`: The number of knots in the spline.
- `key`: A `jax.random.PRNGKey` for initialization
Source code in generax/distributions/flow_models.py
def __init__(self,
input_shape: Tuple[int],
n_flow_layers: int = 3,
working_size: int = 16,
hidden_size: int = 32,
n_blocks: int = 4,
n_spline_knots: int = 8,
cond_shape: Optional[Tuple[int]] = None,
*,
key: PRNGKeyArray,
**kwargs):
"""**Arguments**:
- `input_shape`: The shape of the input data.
- `n_flow_layers`: The number of layers in the flow.
- `working_size`: The size of the working space.
- `hidden_size`: The size of the hidden layers.
- `n_blocks`: The number of blocks in the coupling layers.
- `cond_shape`: The shape of the conditioning information.
- `n_spline_knots`: The number of knots in the spline.
- `key`: A `jax.random.PRNGKey` for initialization
"""
transform = NeuralSplineTransform(input_shape=input_shape,
n_flow_layers=n_flow_layers,
working_size=working_size,
hidden_size=hidden_size,
n_blocks=n_blocks,
n_spline_knots=n_spline_knots,
cond_shape=cond_shape,
key=key)
prior = Gaussian(input_shape=input_shape)
super().__init__(transform=transform,
prior=prior,
**kwargs)
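A short conditional-sampling sketch; the 2-D data and 4-D conditioning vector are illustrative assumptions:

```python
import jax.random as random

from generax.distributions.flow_models import NeuralSpline

key = random.PRNGKey(0)
flow = NeuralSpline(input_shape=(2,),
                    cond_shape=(4,),
                    n_spline_knots=8,
                    key=key)

y = random.normal(key, (4,))            # hypothetical conditioning vector
x, log_px = flow.sample_and_log_prob(key, y=y)
log_px_again = flow.log_prob(x, y=y)    # should match log_px
```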
generax.distributions.flow_models.ContinuousNormalizingFlow (NormalizingFlow)

This is [FFJORD](https://arxiv.org/pdf/1810.01367.pdf).
Source code in generax/distributions/flow_models.py
class ContinuousNormalizingFlow(NormalizingFlow):
"""This is [FFJORD](https://arxiv.org/pdf/1810.01367.pdf).
"""
def __init__(self,
input_shape: Tuple[int],
net: eqx.Module = None,
cond_shape: Optional[Tuple[int]] = None,
*,
controller_rtol: Optional[float] = 1e-3,
controller_atol: Optional[float] = 1e-5,
adjoint='recursive_checkpoint',
key: PRNGKeyArray,
**kwargs):
"""**Arguments**:
- `input_shape`: The shape of the input data.
- `net`: The neural network to use for the vector field. If None, a default
network will be used. `net` should accept `net(t, x, y=y)`
- `cond_shape`: The shape of the conditioning information.
- `key`: A `jax.random.PRNGKey` for initialization
- `controller_rtol`: The relative tolerance for the controller.
- `controller_atol`: The absolute tolerance for the controller.
- `adjoint`: The adjoint method to use. See [this](https://docs.kidger.site/diffrax/api/adjoints/)
"""
transform = FFJORDTransform(input_shape=input_shape,
net=net,
cond_shape=cond_shape,
key=key,
controller_rtol=controller_rtol,
controller_atol=controller_atol,
adjoint=adjoint,
**kwargs)
prior = Gaussian(input_shape=input_shape)
super().__init__(transform=transform,
prior=prior,
**kwargs)
@property
def neural_ode(self):
return self.transform.neural_ode
@property
def vector_field(self):
"""Get the vector field function that samples evolve on as t changes. This is
an `eqx.Module` with the signature `vector_field(t, x, y=y) -> dx/dt`."""
return self.transform.vector_field
@property
def net(self):
"""Same as `vector_field`"""
return self.vector_field
def sample(self,
key: PRNGKeyArray,
y: Optional[Array] = None,
**kwargs) -> Array:
"""**Arguments**:
- `key`: The random number generator key.
**Returns**:
Samples from the model
"""
z = self.prior.sample(key)
x, _ = self.transform(z,
y=y,
inverse=True,
log_likelihood=False,
**kwargs)
return x
vector_field property readonly

Get the vector field function that samples evolve on as t changes. This is an `eqx.Module` with the signature `vector_field(t, x, y=y) -> dx/dt`.
net property readonly

Same as `vector_field`.
__init__(self, input_shape: Tuple[int], net: Module = None, cond_shape: Optional[Tuple[int]] = None, *, controller_rtol: Optional[float] = 0.001, controller_atol: Optional[float] = 1e-05, adjoint = 'recursive_checkpoint', key: PRNGKeyArray, **kwargs)

Arguments:

- `input_shape`: The shape of the input data.
- `net`: The neural network to use for the vector field. If None, a default network will be used. `net` should accept `net(t, x, y=y)`
- `cond_shape`: The shape of the conditioning information.
- `key`: A `jax.random.PRNGKey` for initialization
- `controller_rtol`: The relative tolerance for the controller.
- `controller_atol`: The absolute tolerance for the controller.
- `adjoint`: The adjoint method to use. See [this](https://docs.kidger.site/diffrax/api/adjoints/)
Source code in generax/distributions/flow_models.py
def __init__(self,
input_shape: Tuple[int],
net: eqx.Module = None,
cond_shape: Optional[Tuple[int]] = None,
*,
controller_rtol: Optional[float] = 1e-3,
controller_atol: Optional[float] = 1e-5,
adjoint='recursive_checkpoint',
key: PRNGKeyArray,
**kwargs):
"""**Arguments**:
- `input_shape`: The shape of the input data.
- `net`: The neural network to use for the vector field. If None, a default
network will be used. `net` should accept `net(t, x, y=y)`
- `cond_shape`: The shape of the conditioning information.
- `key`: A `jax.random.PRNGKey` for initialization
- `controller_rtol`: The relative tolerance for the controller.
- `controller_atol`: The absolute tolerance for the controller.
- `adjoint`: The adjoint method to use. See [this](https://docs.kidger.site/diffrax/api/adjoints/)
"""
transform = FFJORDTransform(input_shape=input_shape,
net=net,
cond_shape=cond_shape,
key=key,
controller_rtol=controller_rtol,
controller_atol=controller_atol,
adjoint=adjoint,
**kwargs)
prior = Gaussian(input_shape=input_shape)
super().__init__(transform=transform,
prior=prior,
**kwargs)
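A final sketch for the continuous flow. As the `sample` override above shows, sampling passes `log_likelihood=False`, so it only integrates the ODE for the state, while `log_prob` also integrates the likelihood term; the shapes and the `t=0` evaluation below are illustrative assumptions:

```python
import jax.random as random

from generax.distributions.flow_models import ContinuousNormalizingFlow

key = random.PRNGKey(0)
cnf = ContinuousNormalizingFlow(input_shape=(2,),
                                controller_rtol=1e-3,
                                controller_atol=1e-5,
                                key=key)

x = cnf.sample(key)        # ODE solve for the state only
log_px = cnf.log_prob(x)   # ODE solve including the likelihood term

# The vector_field property exposes the learned dynamics, per its docstring.
dxdt = cnf.vector_field(0.0, x)
```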