Base
generax.flows.base.BijectiveTransform
This represents a bijective transformation.
Attributes:
input_shape: The input shape. Output shape will have the same dimensionality as the input.
cond_shape: The shape of the conditioning information. If there is no conditioning information, this is None.
Source code in generax/flows/base.py
class BijectiveTransform(eqx.Module, ABC):
"""This represents a bijective transformation.
**Attributes**:
- `input_shape`: The input shape. Output shape will have the same dimensionality as the input.
- `cond_shape`: The shape of the conditioning information. If there is no
conditioning information, this is None.
"""
input_shape: Tuple[int] = eqx.field(static=True)
cond_shape: Union[None, Tuple[int]] = eqx.field(static=True)
def __init__(self,
*_,
input_shape: Tuple[int],
cond_shape: Union[None, Tuple[int]] = None,
**kwargs):
"""**Arguments**:
- `input_shape`: The input shape. Output size is the same as shape.
- `cond_shape`: The shape of the conditioning information. If there is no
"""
super().__init__(**kwargs)
assert isinstance(input_shape, tuple) or isinstance(input_shape, list)
self.input_shape = tuple(input_shape)
if cond_shape is not None:
assert isinstance(cond_shape, tuple) or isinstance(cond_shape, list)
self.cond_shape = tuple(cond_shape)
else:
self.cond_shape = None
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None):
"""Initialize the parameters of the layer based on the data.
**Arguments**:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new layer with the parameters initialized.
"""
return self
@abstractmethod
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool=False,
**kwargs) -> Array:
"""**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
(z, log_det)
"""
pass
def to_base_space(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
z
"""
return self(x, y=y, **kwargs)[0]
def to_data_space(self,
z: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `z`: The input to the transformation
- `y`: The conditioning information
**Returns**:
x
"""
return self(z, y=y, inverse=True, **kwargs)[0]
def inverse(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
(z, log_det)
"""
return self(x, y=y, inverse=True, **kwargs)
def get_inverse(self) -> 'BijectiveTransform':
"""Get a new `BijectiveTransform` that is the inverse of this one.
**Returns**:
The inverse transformation.
"""
class Wrapper(eqx.Module):
transform: BijectiveTransform
input_shape: Tuple[int]
cond_shape: Union[None, Tuple[int]]
def __init__(self, transform):
self.transform = transform
self.input_shape = transform.input_shape
self.cond_shape = transform.cond_shape
def __call__(self, x, y=None, inverse=False, **kwargs):
return self.transform(x, y=y, inverse=not inverse, **kwargs)
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None):
# Invert first
def apply_fun(x):
return self(x, y=y)[0]
z = eqx.filter_vmap(apply_fun)(x)
# Regular data dependent init
new_layer = self.transform.data_dependent_init(z, y=y, key=key)
return new_layer.get_inverse()
@property
def __wrapped__(self):
return self.transform
return eqx.module_update_wrapper(Wrapper(self))
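To make the contract concrete, here is a minimal sketch of a subclass. The Scale layer below is hypothetical (it is not part of generax; see concrete layers such as ShiftScale for real implementations), and the sign convention shown for log_det (the log determinant of the map actually applied) is an illustrative assumption.
```python
# Hypothetical elementwise Scale layer: illustrates the (output, log_det)
# return convention and the forward/inverse round trip.
import equinox as eqx
import jax.numpy as jnp

class Scale(BijectiveTransform):
  scale: float = eqx.field(static=True)

  def __init__(self, scale: float, *, input_shape, **kwargs):
    super().__init__(input_shape=input_shape, **kwargs)
    self.scale = scale

  def __call__(self, x, y=None, inverse=False, **kwargs):
    dim = jnp.prod(jnp.array(self.input_shape))
    if not inverse:
      return x/self.scale, -dim*jnp.log(self.scale)   # data -> base
    else:
      return x*self.scale, dim*jnp.log(self.scale)    # base -> data

layer = Scale(2.0, input_shape=(3,))
x = jnp.ones((3,))
z, log_det = layer(x)                 # forward
x_recon, _ = layer(z, inverse=True)   # inverse round trip
assert jnp.allclose(x, x_recon)
```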
__init__(self, *_, input_shape: Tuple[int], cond_shape: Optional[Tuple[int]] = None, **kwargs)
Arguments:
input_shape: The input shape. The output shape is the same as the input shape.
cond_shape: The shape of the conditioning information. If there is no conditioning information, this is None.
Source code in generax/flows/base.py
def __init__(self,
*_,
input_shape: Tuple[int],
cond_shape: Union[None, Tuple[int]] = None,
**kwargs):
"""**Arguments**:
- `input_shape`: The input shape. Output size is the same as shape.
- `cond_shape`: The shape of the conditioning information. If there is no
"""
super().__init__(**kwargs)
assert isinstance(input_shape, tuple) or isinstance(input_shape, list)
self.input_shape = tuple(input_shape)
if cond_shape is not None:
assert isinstance(cond_shape, tuple) or isinstance(cond_shape, list)
self.cond_shape = tuple(cond_shape)
else:
self.cond_shape = None
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
Initialize the parameters of the layer based on the data.
Arguments:
x: The data to initialize the parameters with.
y: The conditioning information
key: A jax.random.PRNGKey for initialization
Returns: A new layer with the parameters initialized.
Source code in generax/flows/base.py
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None):
"""Initialize the parameters of the layer based on the data.
**Arguments**:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new layer with the parameters initialized.
"""
return self
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
abstractmethod
Arguments:
x: The input to the transformation
y: The conditioning information
inverse: Whether to invert the transformation
Returns: (z, log_det)
Source code in generax/flows/base.py
@abstractmethod
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool=False,
**kwargs) -> Array:
"""**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
(z, log_det)
"""
pass
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Apply the inverse transformation.
Arguments:
x: The input to the transformation
y: The conditioning information
Returns: (z, log_det)
Source code in generax/flows/base.py
def inverse(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
(z, log_det)
"""
return self(x, y=y, inverse=True, **kwargs)
get_inverse(self) -> BijectiveTransform
Get a new BijectiveTransform that is the inverse of this one.
Returns: The inverse transformation.
Source code in generax/flows/base.py
def get_inverse(self) -> 'BijectiveTransform':
"""Get a new `BijectiveTransform` that is the inverse of this one.
**Returns**:
The inverse transformation.
"""
class Wrapper(eqx.Module):
transform: BijectiveTransform
input_shape: Tuple[int]
cond_shape: Union[None, Tuple[int]]
def __init__(self, transform):
self.transform = transform
self.input_shape = transform.input_shape
self.cond_shape = transform.cond_shape
def __call__(self, x, y=None, inverse=False, **kwargs):
return self.transform(x, y=y, inverse=not inverse, **kwargs)
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None):
# Invert first
def apply_fun(x):
return self(x, y=y)[0]
z = eqx.filter_vmap(apply_fun)(x)
# Regular data dependent init
new_layer = self.transform.data_dependent_init(z, y=y, key=key)
return new_layer.get_inverse()
@property
def __wrapped__(self):
return self.transform
return eqx.module_update_wrapper(Wrapper(self))
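A short usage sketch, reusing the hypothetical Scale layer from the example above: the module returned by get_inverse simply swaps the forward and inverse directions while remaining a valid pytree.
```python
flipped = layer.get_inverse()
z, _ = layer(x)          # forward of the original layer
x_back, _ = flipped(z)   # forward of the wrapper is the inverse of `layer`
assert jnp.allclose(x, x_back)
```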
generax.flows.base.TimeDependentBijectiveTransform (BijectiveTransform)
Time dependent bijective transform. This will help us build simple probability paths. Non-inverse mode goes t -> 0 while inverse mode goes t -> 1.
Attributes:
input_shape: The input shape. Output shape will have the same dimensionality as the input.
cond_shape: The shape of the conditioning information. If there is no conditioning information, this is None.
Source code in generax/flows/base.py
class TimeDependentBijectiveTransform(BijectiveTransform):
"""Time dependent bijective transform. This will help us build simple probability paths.
Non-inverse mode goes t -> 0 while inverse mode goes t -> 1.
**Attributes**:
- `input_shape`: The input shape. Output shape will have the same dimensionality
as the input.
- `cond_shape`: The shape of the conditioning information. If there is no
conditioning information, this is None.
"""
def data_dependent_init(self,
t: Array,
xt: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None):
"""Initialize the parameters of the layer based on the data.
**Arguments**:
- `t`: Time.
- `xt`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new layer with the parameters initialized.
"""
return self
@abstractmethod
def __call__(self,
t: Array,
xt: Array,
y: Optional[Array] = None,
inverse: bool=False,
**kwargs) -> Array:
"""**Arguments**:
- `xt`: The input to the transformation. If inverse=True, then should be x0
- `y`: The conditioning information
- `inverse`: Whether to inverse the transformation
**Returns**:
(x0, log_det)
"""
pass
def inverse(self,
t: Array,
x0: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `t`: Time.
- `x0`: The input to the transformation
- `y`: The conditioning information
**Returns**:
(xt, log_det)
"""
return self(t, x0, y=y, inverse=True, **kwargs)
def to_base_space(self,
t: Array,
xt: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
z
"""
return self(t, xt, y=y, **kwargs)[0]
def to_data_space(self,
t: Array,
x0: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `z`: The input to the transformation
- `y`: The conditioning information
**Returns**:
x
"""
return self(t, x0, y=y, inverse=True, **kwargs)[0]
def vector_field(self,
t: Array,
xt: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""The vector field that samples evolve on as t changes
**Arguments**:
- `t`: Time.
- `xt`: A point in the data space.
- `y`: The conditioning information.
**Returns**:
`vt`
"""
x0 = self.to_base_space(t, xt, y=y, **kwargs)
def ft(t):
return self.to_data_space(t, x0, y=y, **kwargs)
return jax.jvp(ft, (t,), (jnp.ones_like(t),))[1]
__init__(self, *_, input_shape: Tuple[int], cond_shape: Optional[Tuple[int]] = None, **kwargs)
Arguments:
input_shape: The input shape. The output shape is the same as the input shape.
cond_shape: The shape of the conditioning information. If there is no conditioning information, this is None.
Source code in generax/flows/base.py
def __init__(self,
*_,
input_shape: Tuple[int],
cond_shape: Union[None, Tuple[int]] = None,
**kwargs):
"""**Arguments**:
- `input_shape`: The input shape. Output size is the same as shape.
- `cond_shape`: The shape of the conditioning information. If there is no
"""
super().__init__(**kwargs)
assert isinstance(input_shape, tuple) or isinstance(input_shape, list)
self.input_shape = tuple(input_shape)
if cond_shape is not None:
assert isinstance(cond_shape, tuple) or isinstance(cond_shape, list)
self.cond_shape = tuple(cond_shape)
else:
self.cond_shape = None
data_dependent_init(self, t: Array, xt: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
Initialize the parameters of the layer based on the data.
Arguments:
t: Time.
xt: The data to initialize the parameters with.
y: The conditioning information
key: A jax.random.PRNGKey for initialization
Returns: A new layer with the parameters initialized.
Source code in generax/flows/base.py
def data_dependent_init(self,
t: Array,
xt: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None):
"""Initialize the parameters of the layer based on the data.
**Arguments**:
- `t`: Time.
- `xt`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new layer with the parameters initialized.
"""
return self
__call__(self, t: Array, xt: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
abstractmethod
Arguments:
t: Time.
xt: The input to the transformation. If inverse=True, then this should be x0.
y: The conditioning information
inverse: Whether to invert the transformation
Returns: (x0, log_det)
Source code in generax/flows/base.py
@abstractmethod
def __call__(self,
t: Array,
xt: Array,
y: Optional[Array] = None,
inverse: bool=False,
**kwargs) -> Array:
"""**Arguments**:
- `xt`: The input to the transformation. If inverse=True, then should be x0
- `y`: The conditioning information
- `inverse`: Whether to inverse the transformation
**Returns**:
(x0, log_det)
"""
pass
inverse(self, t: Array, x0: Array, y: Optional[Array] = None, **kwargs) -> Array
Apply the inverse transformation.
Arguments:
t: Time.
x0: The input to the transformation
y: The conditioning information
Returns: (xt, log_det)
Source code in generax/flows/base.py
def inverse(self,
t: Array,
x0: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `t`: Time.
- `x0`: The input to the transformation
- `y`: The conditioning information
**Returns**:
(xt, log_det)
"""
return self(t, x0, y=y, inverse=True, **kwargs)
vector_field(self, t: Array, xt: Array, y: Optional[Array] = None, **kwargs) -> Array
The vector field that samples evolve along as t changes.
Arguments:
t: Time.
xt: A point in the data space.
y: The conditioning information.
Returns: vt
Source code in generax/flows/base.py
def vector_field(self,
t: Array,
xt: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""The vector field that samples evolve on as t changes
**Arguments**:
- `t`: Time.
- `xt`: A point in the data space.
- `y`: The conditioning information.
**Returns**:
`vt`
"""
x0 = self.to_base_space(t, xt, y=y, **kwargs)
def ft(t):
return self.to_data_space(t, x0, y=y, **kwargs)
return jax.jvp(ft, (t,), (jnp.ones_like(t),))[1]
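A toy sanity check of the jax.jvp construction above. The LinearPath layer is a hypothetical example (not part of generax): it realizes the path xt = (1+t)*x0, so the velocity returned by vector_field should be d/dt xt = x0 = xt/(1+t). The log-determinant values are likewise illustrative.
```python
import jax.numpy as jnp

class LinearPath(TimeDependentBijectiveTransform):
  def __call__(self, t, xt, y=None, inverse=False, **kwargs):
    dim = jnp.prod(jnp.array(self.input_shape))
    if not inverse:
      return xt/(1 + t), -dim*jnp.log(1 + t)   # data -> base: x0 = xt/(1+t)
    else:
      # base -> data: the second argument plays the role of x0 here
      return (1 + t)*xt, dim*jnp.log(1 + t)

path = LinearPath(input_shape=(2,))
t, xt = jnp.array(0.5), jnp.ones((2,))
vt = path.vector_field(t, xt)
assert jnp.allclose(vt, xt/(1 + t))   # d/dt [(1+t)*x0] = x0
```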
generax.flows.base.Repeat (BijectiveTransform)
A repeated bijective transformation whose copies are vmapped together. The input to this class should be an initializer function for a transform. For example:
def make_layer(key):
  return ShiftScale(input_shape=x_shape, key=key)
layer = Repeat(make_layer, n_repeats=3, key=key)
Attributes:
- layers: A vmapped layer in the composition
Source code in generax/flows/base.py
class Repeat(BijectiveTransform):
"""A repeated bijective transformations that is vmapped together. The input
to this function should be an initializer function for a transform. For example:
```python
def make_layer(key):
return ShiftScale(input_shape=x_shape, key=key)
layer = Repeat(make_layer, n_repeats=3, key=key)
```
**Attributes**:
- `layers`: A vmapped layer in the composition
"""
n_repeats: int = eqx.field(static=True)
layers: BijectiveTransform
def __init__(self,
layer_init: Callable[[PRNGKeyArray], BijectiveTransform],
n_repeats: int,
*,
key: PRNGKeyArray,
**kwargs):
"""**Arguments**:
- `layers`: A sequence of `BijectiveTransform`.
"""
self.n_repeats = n_repeats
keys = random.split(key, n_repeats)
self.layers = eqx.filter_vmap(layer_init)(keys)
super().__init__(input_shape=self.layers.input_shape,
cond_shape=self.layers.cond_shape,
**kwargs)
def to_sequential(self) -> Sequential:
"""Convert this to a sequential composition.
"""
params, static = eqx.partition(self.layers, eqx.is_array)
def make_layer(single_parameters: PyTree):
return eqx.combine(single_parameters, static)
layers = []
for i in range(self.n_repeats):
layer = make_layer(jax.tree_util.tree_map(lambda x: x[i], params))
layers.append(layer)
return Sequential(*layers)
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None) -> BijectiveTransform:
seq = self.to_sequential()
# Apply the data dependent initialization
out_seq_layers = seq.data_dependent_init(x, y=y, key=key)
# Turn the sequential layers into a repeat layer
all_params = []
for i, layer in enumerate(out_seq_layers):
params, _ = eqx.partition(layer, eqx.is_array)
all_params.append(params)
# Combine the parameters back into a single layer
params = jax.tree_util.tree_map(lambda *args: jnp.array(args), *all_params)
_, static = eqx.partition(self.layers, eqx.is_array)
layers = eqx.combine(params, static)
get_layers = lambda tree: tree.layers
updated_module = eqx.tree_at(get_layers, self, layers)
return updated_module
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool = False,
**kwargs) -> Array:
"""**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
(z, log_det)
"""
dynamic, static = eqx.partition(self.layers, eqx.is_array)
def scan_body(x, params):
block = eqx.combine(params, static)
x, log_det = block(x, y=y, inverse=inverse, **kwargs)
return x, log_det
x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
return x, log_dets.sum()
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.BijectiveTransform.inverse.
Source code in generax/flows/base.py
def inverse(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
(z, log_det)
"""
return self(x, y=y, inverse=True, **kwargs)
__init__(self, layer_init: Callable[[PRNGKeyArray], BijectiveTransform], n_repeats: int, *, key: PRNGKeyArray, **kwargs)
Arguments:
layer_init: A function that initializes a BijectiveTransform from a PRNG key.
n_repeats: The number of times to repeat the layer.
Source code in generax/flows/base.py
def __init__(self,
layer_init: Callable[[PRNGKeyArray], BijectiveTransform],
n_repeats: int,
*,
key: PRNGKeyArray,
**kwargs):
"""**Arguments**:
- `layers`: A sequence of `BijectiveTransform`.
"""
self.n_repeats = n_repeats
keys = random.split(key, n_repeats)
self.layers = eqx.filter_vmap(layer_init)(keys)
super().__init__(input_shape=self.layers.input_shape,
cond_shape=self.layers.cond_shape,
**kwargs)
to_sequential(self) -> Sequential
Convert this to a sequential composition.
Source code in generax/flows/base.py
def to_sequential(self) -> Sequential:
"""Convert this to a sequential composition.
"""
params, static = eqx.partition(self.layers, eqx.is_array)
def make_layer(single_parameters: PyTree):
return eqx.combine(single_parameters, static)
layers = []
for i in range(self.n_repeats):
layer = make_layer(jax.tree_util.tree_map(lambda x: x[i], params))
layers.append(layer)
return Sequential(*layers)
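A sketch of the equivalence between the two representations, assuming the make_layer and x_shape definitions from the class example above: a Repeat and its unrolled Sequential should produce the same output and log determinant.
```python
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)
layer = Repeat(make_layer, n_repeats=3, key=key)
seq = layer.to_sequential()   # three unstacked copies, applied in order

x = jnp.ones(x_shape)
z_rep, log_det_rep = layer(x)   # scan over stacked parameters
z_seq, log_det_seq = seq(x)     # explicit Python loop over layers
assert jnp.allclose(z_rep, z_seq)
assert jnp.allclose(log_det_rep, log_det_seq)
```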
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform
Initialize the parameters of the layer based on the data.
Arguments:
x: The data to initialize the parameters with.
y: The conditioning information
key: A jax.random.PRNGKey for initialization
Returns: A new layer with the parameters initialized.
Source code in generax/flows/base.py
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None) -> BijectiveTransform:
seq = self.to_sequential()
# Apply the data dependent initialization
out_seq_layers = seq.data_dependent_init(x, y=y, key=key)
# Turn the sequential layers into a repeat layer
all_params = []
for i, layer in enumerate(out_seq_layers):
params, _ = eqx.partition(layer, eqx.is_array)
all_params.append(params)
# Combine the parameters back into a single layer
params = jax.tree_util.tree_map(lambda *args: jnp.array(args), *all_params)
_, static = eqx.partition(self.layers, eqx.is_array)
layers = eqx.combine(params, static)
get_layers = lambda tree: tree.layers
updated_module = eqx.tree_at(get_layers, self, layers)
return updated_module
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Implements generax.flows.base.BijectiveTransform.__call__.
Source code in generax/flows/base.py
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool = False,
**kwargs) -> Array:
"""**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
(z, log_det)
"""
dynamic, static = eqx.partition(self.layers, eqx.is_array)
def scan_body(x, params):
block = eqx.combine(params, static)
x, log_det = block(x, y=y, inverse=inverse, **kwargs)
return x, log_det
x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
return x, log_dets.sum()
generax.flows.base.TimeDependentRepeat (TimeDependentBijectiveTransform)
A time-dependent repeated bijective transformation whose copies are vmapped together. The input to this class should be an initializer function for a transform. For example:
def make_layer(key):
  return ShiftScale(input_shape=x_shape, key=key)
layer = TimeDependentRepeat(make_layer, n_repeats=3, key=key)
Attributes:
- layers: A vmapped layer in the composition
Source code in generax/flows/base.py
class TimeDependentRepeat(TimeDependentBijectiveTransform):
"""A time dependent repeated bijective transformations that is vmapped together. The input
to this function should be an initializer function for a transform. For example:
```python
def make_layer(key):
return ShiftScale(input_shape=x_shape, key=key)
layer = TimeDependentRepeat(make_layer, n_repeats=3, key=key)
```
**Attributes**:
- `layers`: A vmapped layer in the composition
"""
n_repeats: int = eqx.field(static=True)
layers: BijectiveTransform
def __init__(self,
layer_init: Callable[[PRNGKeyArray], BijectiveTransform],
n_repeats: int,
*,
key: PRNGKeyArray,
**kwargs):
"""**Arguments**:
- `layers`: A sequence of `BijectiveTransform`.
"""
self.n_repeats = n_repeats
keys = random.split(key, n_repeats)
self.layers = eqx.filter_vmap(layer_init)(keys)
super().__init__(input_shape=self.layers.input_shape,
cond_shape=self.layers.cond_shape,
**kwargs)
def to_sequential(self) -> TimeDependentSequential:
"""Convert this to a sequential composition.
"""
params, static = eqx.partition(self.layers, eqx.is_array)
def make_layer(single_parameters: PyTree):
return eqx.combine(single_parameters, static)
layers = []
for i in range(self.n_repeats):
layer = make_layer(jax.tree_util.tree_map(lambda x: x[i], params))
layers.append(layer)
return TimeDependentSequential(*layers)
def data_dependent_init(self,
t: Array,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None) -> BijectiveTransform:
seq = self.to_sequential()
# Apply the data dependent initialization
out_seq_layers = seq.data_dependent_init(t, x, y=y, key=key)
# Turn the sequential layers into a repeat layer
all_params = []
for i, layer in enumerate(out_seq_layers):
params, _ = eqx.partition(layer, eqx.is_array)
all_params.append(params)
# Combine the parameters back into a single layer
params = jax.tree_util.tree_map(lambda *args: jnp.array(args), *all_params)
_, static = eqx.partition(self.layers, eqx.is_array)
layers = eqx.combine(params, static)
get_layers = lambda tree: tree.layers
updated_module = eqx.tree_at(get_layers, self, layers)
return updated_module
def __call__(self,
t: Array,
x: Array,
y: Optional[Array] = None,
inverse: bool = False,
**kwargs) -> Array:
"""**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to inverse the transformation
**Returns**:
(z, log_det)
"""
dynamic, static = eqx.partition(self.layers, eqx.is_array)
def scan_body(x, params):
block = eqx.combine(params, static)
x, log_det = block(t, x, y=y, inverse=inverse, **kwargs)
return x, log_det
x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
return x, log_dets.sum()
inverse(self, t: Array, x0: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.TimeDependentBijectiveTransform.inverse.
Source code in generax/flows/base.py
def inverse(self,
t: Array,
x0: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `t`: Time.
- `x0`: The input to the transformation
- `y`: The conditioning information
**Returns**:
(xt, log_det)
"""
return self(t, x0, y=y, inverse=True, **kwargs)
vector_field(self, t: Array, xt: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.TimeDependentBijectiveTransform.vector_field.
Source code in generax/flows/base.py
def vector_field(self,
t: Array,
xt: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""The vector field that samples evolve on as t changes
**Arguments**:
- `t`: Time.
- `xt`: A point in the data space.
- `y`: The conditioning information.
**Returns**:
`vt`
"""
x0 = self.to_base_space(t, xt, y=y, **kwargs)
def ft(t):
return self.to_data_space(t, x0, y=y, **kwargs)
return jax.jvp(ft, (t,), (jnp.ones_like(t),))[1]
__init__(self, layer_init: Callable[[PRNGKeyArray], BijectiveTransform], n_repeats: int, *, key: PRNGKeyArray, **kwargs)
Arguments:
layer_init: A function that initializes a BijectiveTransform from a PRNG key.
n_repeats: The number of times to repeat the layer.
Source code in generax/flows/base.py
def __init__(self,
layer_init: Callable[[PRNGKeyArray], BijectiveTransform],
n_repeats: int,
*,
key: PRNGKeyArray,
**kwargs):
"""**Arguments**:
- `layers`: A sequence of `BijectiveTransform`.
"""
self.n_repeats = n_repeats
keys = random.split(key, n_repeats)
self.layers = eqx.filter_vmap(layer_init)(keys)
super().__init__(input_shape=self.layers.input_shape,
cond_shape=self.layers.cond_shape,
**kwargs)
to_sequential(self) -> TimeDependentSequential
Convert this to a sequential composition.
Source code in generax/flows/base.py
def to_sequential(self) -> TimeDependentSequential:
"""Convert this to a sequential composition.
"""
params, static = eqx.partition(self.layers, eqx.is_array)
def make_layer(single_parameters: PyTree):
return eqx.combine(single_parameters, static)
layers = []
for i in range(self.n_repeats):
layer = make_layer(jax.tree_util.tree_map(lambda x: x[i], params))
layers.append(layer)
return TimeDependentSequential(*layers)
data_dependent_init(self, t: Array, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform
Initialize the parameters of the layer based on the data.
Arguments:
t: Time.
x: The data to initialize the parameters with.
y: The conditioning information
key: A jax.random.PRNGKey for initialization
Returns: A new layer with the parameters initialized.
Source code in generax/flows/base.py
def data_dependent_init(self,
t: Array,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None) -> BijectiveTransform:
seq = self.to_sequential()
# Apply the data dependent initialization
out_seq_layers = seq.data_dependent_init(t, x, y=y, key=key)
# Turn the sequential layers into a repeat layer
all_params = []
for i, layer in enumerate(out_seq_layers):
params, _ = eqx.partition(layer, eqx.is_array)
all_params.append(params)
# Combine the parameters back into a single layer
params = jax.tree_util.tree_map(lambda *args: jnp.array(args), *all_params)
_, static = eqx.partition(self.layers, eqx.is_array)
layers = eqx.combine(params, static)
get_layers = lambda tree: tree.layers
updated_module = eqx.tree_at(get_layers, self, layers)
return updated_module
__call__(self, t: Array, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Arguments:
t: Time.
x: The input to the transformation
y: The conditioning information
inverse: Whether to invert the transformation
Returns: (z, log_det)
Source code in generax/flows/base.py
def __call__(self,
t: Array,
x: Array,
y: Optional[Array] = None,
inverse: bool = False,
**kwargs) -> Array:
"""**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to inverse the transformation
**Returns**:
(z, log_det)
"""
dynamic, static = eqx.partition(self.layers, eqx.is_array)
def scan_body(x, params):
block = eqx.combine(params, static)
x, log_det = block(t, x, y=y, inverse=inverse, **kwargs)
return x, log_det
x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
return x, log_dets.sum()
generax.flows.base.Sequential (BijectiveTransform)
A sequence of bijective transformations. Accepts a sequence of BijectiveTransform layers.
# Intended usage:
layer1 = MyTransform(...)
layer2 = MyTransform(...)
transform = Sequential(layer1, layer2)
Attributes:
- n_layers: The number of layers in the composition
- layers: A tuple of the layers in the composition
Source code in generax/flows/base.py
class Sequential(BijectiveTransform):
"""A sequence of bijective transformations. Accepts a sequence
of `BijectiveTransform` initializers.
```python
# Intended usage:
layer1 = MyTransform(...)
layer2 = MyTransform(...)
transform = Sequential(layer1, layer2)
```
**Attributes**:
- `n_layers`: The number of layers in the composition
- `layers`: A tuple of the layers in the composition
"""
n_layers: int = eqx.field(static=True)
layers: Tuple[BijectiveTransform]
def __init__(self,
*layers: Sequence[BijectiveTransform],
**kwargs):
"""**Arguments**:
- `layers`: A sequence of `BijectiveTransform`.
"""
input_shape = layers[0].input_shape
cond_shape = layers[0].cond_shape
# Check that all of the layers have the same cond shape
for layer in layers:
assert layer.cond_shape == cond_shape
super().__init__(input_shape=input_shape,
cond_shape=cond_shape,
**kwargs)
self.layers = tuple(layers)
self.n_layers = len(layers)
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None) -> BijectiveTransform:
"""Initialize the parameters of the layer based on the data.
**Arguments**:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new layer with the parameters initialized.
"""
assert x.shape[1:] == self.input_shape, 'Only works on batched data'
# We need to initialize each of the layers
keys = random.split(key, self.n_layers)
new_layers = []
for i, (layer, key) in enumerate(zip(self.layers, keys)):
new_layer = layer.data_dependent_init(x=x, y=y, key=key)
new_layers.append(new_layer)
x, _ = eqx.filter_vmap(new_layer)(x)
new_layers = tuple(new_layers)
# Turn the new parameters into a new module
get_layers = lambda tree: tree.layers
updated_layer = eqx.tree_at(get_layers, self, new_layers)
return updated_layer
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool=False,
**kwargs) -> Array:
"""**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
(z, log_det)
"""
delta_logpx = 0.0
layers = reversed(self.layers) if inverse else self.layers
for layer in layers:
x, log_det = layer(x, y=y, inverse=inverse, **kwargs)
delta_logpx += log_det
return x, delta_logpx
def __getitem__(self, i: Union[int, slice]) -> Callable:
if isinstance(i, int):
return self.layers[i]
elif isinstance(i, slice):
return Sequential(*self.layers[i])
else:
raise TypeError(f"Indexing with type {type(i)} is not supported")
def __iter__(self):
yield from self.layers
def __len__(self):
return len(self.layers)
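A usage sketch, reusing the hypothetical Scale layer from the BijectiveTransform example above: the composed log determinant is the sum over layers, the inverse applies the layers in reverse order, and the container supports len, indexing, and iteration.
```python
import jax.numpy as jnp

transform = Sequential(Scale(2.0, input_shape=(3,)),
                       Scale(4.0, input_shape=(3,)))
x = jnp.ones((3,))
z, log_det = transform(x)                 # log_det accumulates over layers
x_recon, _ = transform(z, inverse=True)   # layers applied in reverse
assert jnp.allclose(x, x_recon)
assert len(transform) == 2 and transform[0].scale == 2.0
```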
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.BijectiveTransform.inverse.
Source code in generax/flows/base.py
def inverse(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
(z, log_det)
"""
return self(x, y=y, inverse=True, **kwargs)
__init__(self, *layers: Sequence[BijectiveTransform], **kwargs)
Arguments:
layers: A sequence of BijectiveTransform.
Source code in generax/flows/base.py
def __init__(self,
*layers: Sequence[BijectiveTransform],
**kwargs):
"""**Arguments**:
- `layers`: A sequence of `BijectiveTransform`.
"""
input_shape = layers[0].input_shape
cond_shape = layers[0].cond_shape
# Check that all of the layers have the same cond shape
for layer in layers:
assert layer.cond_shape == cond_shape
super().__init__(input_shape=input_shape,
cond_shape=cond_shape,
**kwargs)
self.layers = tuple(layers)
self.n_layers = len(layers)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform
Initialize the parameters of the layer based on the data.
Arguments:
x: The data to initialize the parameters with.
y: The conditioning information
key: A jax.random.PRNGKey for initialization
Returns: A new layer with the parameters initialized.
Source code in generax/flows/base.py
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None) -> BijectiveTransform:
"""Initialize the parameters of the layer based on the data.
**Arguments**:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new layer with the parameters initialized.
"""
assert x.shape[1:] == self.input_shape, 'Only works on batched data'
# We need to initialize each of the layers
keys = random.split(key, self.n_layers)
new_layers = []
for i, (layer, key) in enumerate(zip(self.layers, keys)):
new_layer = layer.data_dependent_init(x=x, y=y, key=key)
new_layers.append(new_layer)
x, _ = eqx.filter_vmap(new_layer)(x)
new_layers = tuple(new_layers)
# Turn the new parameters into a new module
get_layers = lambda tree: tree.layers
updated_layer = eqx.tree_at(get_layers, self, new_layers)
return updated_layer
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Implements generax.flows.base.BijectiveTransform.__call__.
Source code in generax/flows/base.py
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool=False,
**kwargs) -> Array:
"""**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
(z, log_det)
"""
delta_logpx = 0.0
layers = reversed(self.layers) if inverse else self.layers
for layer in layers:
x, log_det = layer(x, y=y, inverse=inverse, **kwargs)
delta_logpx += log_det
return x, delta_logpx
generax.flows.base.TimeDependentSequential (TimeDependentBijectiveTransform)
A sequence of time-dependent bijective transformations. Accepts a sequence of TimeDependentBijectiveTransform layers.
# Intended usage:
layer1 = MyTransform(...)
layer2 = MyTransform(...)
transform = TimeDependentSequential(layer1, layer2)
Attributes:
- n_layers: The number of layers in the composition
- layers: A tuple of the layers in the composition
Source code in generax/flows/base.py
class TimeDependentSequential(TimeDependentBijectiveTransform):
"""A sequence of bijective transformations. Accepts a sequence
of `BijectiveTransform` initializers.
```python
# Intended usage:
layer1 = MyTransform(...)
layer2 = MyTransform(...)
transform = TimeDependentSequential(layer1, layer2)
```
**Attributes**:
- `n_layers`: The number of layers in the composition
- `layers`: A tuple of the layers in the composition
"""
n_layers: int = eqx.field(static=True)
layers: Tuple[BijectiveTransform]
def __init__(self,
*layers: Sequence[BijectiveTransform],
**kwargs):
"""**Arguments**:
- `layers`: A sequence of `BijectiveTransform`.
"""
input_shape = layers[0].input_shape
cond_shape = layers[0].cond_shape
# Check that all of the layers have the same cond shape
for layer in layers:
assert layer.cond_shape == cond_shape
super().__init__(input_shape=input_shape,
cond_shape=cond_shape,
**kwargs)
self.layers = tuple(layers)
self.n_layers = len(layers)
def data_dependent_init(self,
t: Array,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None) -> BijectiveTransform:
"""Initialize the parameters of the layer based on the data.
**Arguments**:
- `t`: Time.
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new layer with the parameters initialized.
"""
assert x.shape[1:] == self.input_shape, 'Only works on batched data'
# We need to initialize each of the layers
keys = random.split(key, self.n_layers)
new_layers = []
for i, (layer, key) in enumerate(zip(self.layers, keys)):
new_layer = layer.data_dependent_init(t, x=x, y=y, key=key)
new_layers.append(new_layer)
x, _ = eqx.filter_vmap(new_layer)(t, x)
new_layers = tuple(new_layers)
# Turn the new parameters into a new module
get_layers = lambda tree: tree.layers
updated_layer = eqx.tree_at(get_layers, self, new_layers)
return updated_layer
def __call__(self,
t: Array,
x: Array,
y: Optional[Array] = None,
inverse: bool=False,
**kwargs) -> Array:
"""**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to inverse the transformation
**Returns**:
(z, log_det)
"""
delta_logpx = 0.0
layers = reversed(self.layers) if inverse else self.layers
for layer in layers:
x, log_det = layer(t, x, y=y, inverse=inverse, **kwargs)
delta_logpx += log_det
return x, delta_logpx
def __getitem__(self, i: Union[int, slice]) -> Callable:
if isinstance(i, int):
return self.layers[i]
elif isinstance(i, slice):
return TimeDependentSequential(*self.layers[i])
else:
raise TypeError(f"Indexing with type {type(i)} is not supported")
def __iter__(self):
yield from self.layers
def __len__(self):
return len(self.layers)
inverse(self, t: Array, x0: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.TimeDependentBijectiveTransform.inverse.
Source code in generax/flows/base.py
def inverse(self,
t: Array,
x0: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `t`: Time.
- `x0`: The input to the transformation
- `y`: The conditioning information
**Returns**:
(xt, log_det)
"""
return self(t, x0, y=y, inverse=True, **kwargs)
vector_field(self, t: Array, xt: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.TimeDependentBijectiveTransform.vector_field.
Source code in generax/flows/base.py
def vector_field(self,
t: Array,
xt: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""The vector field that samples evolve on as t changes
**Arguments**:
- `t`: Time.
- `xt`: A point in the data space.
- `y`: The conditioning information.
**Returns**:
`vt`
"""
x0 = self.to_base_space(t, xt, y=y, **kwargs)
def ft(t):
return self.to_data_space(t, x0, y=y, **kwargs)
return jax.jvp(ft, (t,), (jnp.ones_like(t),))[1]
__init__(self, *layers: Sequence[BijectiveTransform], **kwargs)
Arguments:
layers: A sequence of BijectiveTransform.
Source code in generax/flows/base.py
def __init__(self,
*layers: Sequence[BijectiveTransform],
**kwargs):
"""**Arguments**:
- `layers`: A sequence of `BijectiveTransform`.
"""
input_shape = layers[0].input_shape
cond_shape = layers[0].cond_shape
# Check that all of the layers have the same cond shape
for layer in layers:
assert layer.cond_shape == cond_shape
super().__init__(input_shape=input_shape,
cond_shape=cond_shape,
**kwargs)
self.layers = tuple(layers)
self.n_layers = len(layers)
data_dependent_init(self, t: Array, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform
Initialize the parameters of the layer based on the data.
Arguments:
t: Time.
x: The data to initialize the parameters with.
y: The conditioning information
key: A jax.random.PRNGKey for initialization
Returns: A new layer with the parameters initialized.
Source code in generax/flows/base.py
def data_dependent_init(self,
t: Array,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None) -> BijectiveTransform:
"""Initialize the parameters of the layer based on the data.
**Arguments**:
- `t`: Time.
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new layer with the parameters initialized.
"""
assert x.shape[1:] == self.input_shape, 'Only works on batched data'
# We need to initialize each of the layers
keys = random.split(key, self.n_layers)
new_layers = []
for i, (layer, key) in enumerate(zip(self.layers, keys)):
new_layer = layer.data_dependent_init(t, x=x, y=y, key=key)
new_layers.append(new_layer)
x, _ = eqx.filter_vmap(new_layer)(t, x)
new_layers = tuple(new_layers)
# Turn the new parameters into a new module
get_layers = lambda tree: tree.layers
updated_layer = eqx.tree_at(get_layers, self, new_layers)
return updated_layer
__call__(self, t: Array, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Arguments:
t: Time.
x: The input to the transformation
y: The conditioning information
inverse: Whether to invert the transformation
Returns: (z, log_det)
Source code in generax/flows/base.py
def __call__(self,
t: Array,
x: Array,
y: Optional[Array] = None,
inverse: bool=False,
**kwargs) -> Array:
"""**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to inverse the transformation
**Returns**:
(z, log_det)
"""
delta_logpx = 0.0
layers = reversed(self.layers) if inverse else self.layers
for layer in layers:
x, log_det = layer(t, x, y=y, inverse=inverse, **kwargs)
delta_logpx += log_det
return x, delta_logpx
generax.flows.base.InjectiveTransform (BijectiveTransform)
This represents an injective transformation: a map that is bijective onto its image.
Attributes:
input_shape: The input shape.
output_shape: The output shape.
cond_shape: The shape of the conditioning information. If there is no conditioning information, this is None.
Source code in generax/flows/base.py
class InjectiveTransform(BijectiveTransform, ABC):
"""This represents an injective transformation. This is a special case of a bijective
transformation.
**Attributes**:
- `input_shape`: The input shape.
- `output_shape`: The output shape.
- `cond_shape`: The shape of the conditioning information. If there is no
conditioning information, this is None.
"""
output_shape: Tuple[int] = eqx.field(static=True)
def __init__(self,
*_,
input_shape: Tuple[int],
output_shape: Tuple[int],
cond_shape: Union[None, Tuple[int]] = None,
**kwargs):
"""**Arguments**:
- `input_shape`: The input shape.
- `output_shape`: The output shape.
- `cond_shape`: The shape of the conditioning information. If there is no
conditioning information, this is None.
"""
super().__init__(input_shape=input_shape,
cond_shape=cond_shape,
**kwargs)
assert isinstance(output_shape, tuple) or isinstance(output_shape, list)
self.output_shape = tuple(output_shape)
def project(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Project a point onto the image of the transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
x_proj
"""
z, _ = self(x, y=y, **kwargs)
x_proj, _ = self(z, y=y, inverse=True, **kwargs)
return x_proj
def log_determinant(self,
z: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Compute -0.5*log(det(J^TJ))
**Arguments**:
- `z`: An element of the base space
**Returns**:
The log determinant of (J^TJ)^-0.5
"""
def jvp(v_flat):
v = v_flat.reshape(z.shape)
_, (Jv) = jax.jvp(self.to_data_space, (z,), (v,))
return Jv.ravel()
z_dim = util.list_prod(z.shape)
eye = jnp.eye(z_dim)
J = jax.vmap(jvp, in_axes=1, out_axes=1)(eye)
return -0.5*jnp.linalg.slogdet(J.T@J)[1]
def log_determinant_surrogate(z: Array,
transform: eqx.Module,
method: str = 'brute_force',
key: PRNGKeyArray = None,
**kwargs) -> Array:
"""Compute a term that has the same expected gradient as `-0.5*log_det(J^TJ))`.
If `method='brute_force'`, then this is just -0.5*log(det(J^TJ)).
If `method='iterative'`, then this is a term that has the same expected gradient.
**Arguments**:
- `z`: An element of the base space
- `method`: How to compute the log determinant. Options are:
- `brute_force`: Compute the entire Jacobian
- `iterative`: Use conjugate gradient (https://arxiv.org/pdf/2106.01413.pdf)
- `key`: A `jax.random.PRNGKey` for initialization. Needed for some methods
**Returns**:
The value of -0.5*log(det(J^TJ)), or a term that has the same expected gradient
"""
def jvp(v_flat):
v = v_flat.reshape(z.shape)
_, (Jv) = jax.jvp(transform, (z,), (v,))
return Jv.ravel()
if method == 'brute_force':
z_dim = util.list_prod(z.shape)
eye = jnp.eye(z_dim)
J = jax.vmap(jvp, in_axes=1, out_axes=1)(eye)
return -0.5*jnp.linalg.slogdet(J.T@J)[1]
elif method == 'iterative':
def vjp(v_flat):
x, vjp = jax.vjp(transform, z)
v = v_flat.reshape(x.shape)
return vjp(v)[0].ravel()
def vjp_jvp(v_flat):
return vjp(jvp(v_flat))
v = random.normal(key, shape=z.shape)
operator = lx.FunctionLinearOperator(vjp_jvp, v, tags=lx.positive_semidefinite_tag)
solver = lx.CG(rtol=1e-3, atol=1e-6)
JTJinv_v = lx.linear_solve(operator, v, solver).value
JTJ_v = vjp_jvp(v)
return -0.5*jnp.vdot(jax.lax.stop_gradient(JTJinv_v), JTJ_v)
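For context: if z = f(x) maps to a lower-dimensional base space and J is the Jacobian of the map back to the data space (to_data_space), the bijective change-of-variables formula is replaced by the injective version log p(x) = log p(z) - 0.5*log(det(J^T J)). log_determinant computes exactly this correction term, and log_determinant_surrogate provides a cheaper estimator of its gradient (see https://arxiv.org/pdf/2106.01413.pdf).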
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.
Source code in generax/flows/base.py
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None):
"""Initialize the parameters of the layer based on the data.
**Arguments**:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new layer with the parameters initialized.
"""
return self
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
abstractmethod
Implements generax.flows.base.BijectiveTransform.__call__.
Source code in generax/flows/base.py
@abstractmethod
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool=False,
**kwargs) -> Array:
"""**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
(z, log_det)
"""
pass
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.BijectiveTransform.inverse.
Source code in generax/flows/base.py
def inverse(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
(z, log_det)
"""
return self(x, y=y, inverse=True, **kwargs)
__init__(self, *_, input_shape: Tuple[int], output_shape: Tuple[int], cond_shape: Optional[Tuple[int]] = None, **kwargs)
Arguments:
input_shape: The input shape.
output_shape: The output shape.
cond_shape: The shape of the conditioning information. If there is no conditioning information, this is None.
Source code in generax/flows/base.py
def __init__(self,
*_,
input_shape: Tuple[int],
output_shape: Tuple[int],
cond_shape: Union[None, Tuple[int]] = None,
**kwargs):
"""**Arguments**:
- `input_shape`: The input shape.
- `output_shape`: The output shape.
- `cond_shape`: The shape of the conditioning information. If there is no
conditioning information, this is None.
"""
super().__init__(input_shape=input_shape,
cond_shape=cond_shape,
**kwargs)
assert isinstance(output_shape, tuple) or isinstance(output_shape, list)
self.output_shape = tuple(output_shape)
project(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Project a point onto the image of the transformation.
Arguments:
x: The input to the transformation
y: The conditioning information
Returns: x_proj
Source code in generax/flows/base.py
def project(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Project a point onto the image of the transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
x_proj
"""
z, _ = self(x, y=y, **kwargs)
x_proj, _ = self(z, y=y, inverse=True, **kwargs)
return x_proj
log_determinant(self, z: Array, y: Optional[Array] = None, **kwargs) -> Array
Compute -0.5*log(det(J^TJ)).
Arguments:
z: An element of the base space
Returns: The log determinant of (J^TJ)^-0.5
Source code in generax/flows/base.py
def log_determinant(self,
z: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Compute -0.5*log(det(J^TJ))
**Arguments**:
- `z`: An element of the base space
**Returns**:
The log determinant of (J^TJ)^-0.5
"""
def jvp(v_flat):
v = v_flat.reshape(z.shape)
_, (Jv) = jax.jvp(self.to_data_space, (z,), (v,))
return Jv.ravel()
z_dim = util.list_prod(z.shape)
eye = jnp.eye(z_dim)
J = jax.vmap(jvp, in_axes=1, out_axes=1)(eye)
return -0.5*jnp.linalg.slogdet(J.T@J)[1]
log_determinant_surrogate(z: Array, transform: Module, method: str = 'brute_force', key: PRNGKeyArray = None, **kwargs) -> Array
Compute a term that has the same expected gradient as -0.5*log(det(J^TJ)).
If method='brute_force', then this is just -0.5*log(det(J^TJ)).
If method='iterative', then this is a term that has the same expected gradient.
Arguments:
z: An element of the base space
method: How to compute the log determinant. Options are:
- brute_force: Compute the entire Jacobian
- iterative: Use conjugate gradient (https://arxiv.org/pdf/2106.01413.pdf)
key: A jax.random.PRNGKey for initialization. Needed for some methods
Returns: The value of -0.5*log(det(J^TJ)), or a term that has the same expected gradient
Source code in generax/flows/base.py
def log_determinant_surrogate(z: Array,
transform: eqx.Module,
method: str = 'brute_force',
key: PRNGKeyArray = None,
**kwargs) -> Array:
"""Compute a term that has the same expected gradient as `-0.5*log_det(J^TJ))`.
If `method='brute_force'`, then this is just -0.5*log(det(J^TJ)).
If `method='iterative'`, then this is a term that has the same expected gradient.
**Arguments**:
- `z`: An element of the base space
- `method`: How to compute the log determinant. Options are:
- `brute_force`: Compute the entire Jacobian
- `iterative`: Use conjugate gradient (https://arxiv.org/pdf/2106.01413.pdf)
- `key`: A `jax.random.PRNGKey` for initialization. Needed for some methods
**Returns**:
The value of -0.5*log(det(J^TJ)), or a term that has the same expected gradient
"""
def jvp(v_flat):
v = v_flat.reshape(z.shape)
_, (Jv) = jax.jvp(transform, (z,), (v,))
return Jv.ravel()
if method == 'brute_force':
z_dim = util.list_prod(z.shape)
eye = jnp.eye(z_dim)
J = jax.vmap(jvp, in_axes=1, out_axes=1)(eye)
return -0.5*jnp.linalg.slogdet(J.T@J)[1]
elif method == 'iterative':
def vjp(v_flat):
x, vjp = jax.vjp(transform, z)
v = v_flat.reshape(x.shape)
return vjp(v)[0].ravel()
def vjp_jvp(v_flat):
return vjp(jvp(v_flat))
v = random.normal(key, shape=z.shape)
operator = lx.FunctionLinearOperator(vjp_jvp, v, tags=lx.positive_semidefinite_tag)
solver = lx.CG(rtol=1e-3, atol=1e-6)
JTJinv_v = lx.linear_solve(operator, v, solver).value
JTJ_v = vjp_jvp(v)
return -0.5*jnp.vdot(jax.lax.stop_gradient(JTJinv_v), JTJ_v)
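Why the iterative surrogate has the right gradient, briefly: for v ~ N(0, I), the returned value is -0.5*&lt;stop_gradient((J^T J)^-1 v), (J^T J) v&gt;. Because of the stop_gradient, gradients only flow through the (J^T J)v factor, giving -0.5*v^T (J^T J)^-1 grad(J^T J) v, whose expectation over v is -0.5*tr((J^T J)^-1 grad(J^T J)); by Jacobi's formula this is exactly the gradient of -0.5*log(det(J^T J)). The conjugate gradient solve only needs matrix-vector products with J^T J (one jvp followed by one vjp each), so the dense Jacobian is never formed.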
generax.flows.base.InjectiveSequential (Sequential, InjectiveTransform)
A sequence of injective or bijective transformations.
Source code in generax/flows/base.py
class InjectiveSequential(Sequential, InjectiveTransform):
"""A sequence of injective or bijective transformations.
"""
def __init__(self,
*layers: Sequence[BijectiveTransform],
**kwargs):
"""**Arguments**:
- `layers`: A sequence of `BijectiveTransform`.
"""
input_shape = layers[0].input_shape
cond_shape = layers[0].cond_shape
# Check that all of the layers have the same cond shape
# and that the output shape of each layer matches the input shape of the next layer
layer_iter = iter(zip(layers[:-1], layers[1:]))
for l1, l2 in layer_iter:
assert l1.cond_shape == cond_shape
if isinstance(l1, InjectiveTransform):
assert l1.output_shape == l2.input_shape
assert l2.cond_shape == cond_shape
if isinstance(l2, InjectiveTransform):
output_shape = l2.output_shape
else:
output_shape = l2.input_shape
InjectiveTransform.__init__(self,
input_shape=input_shape,
output_shape=output_shape,
cond_shape=cond_shape,
**kwargs)
self.layers = tuple(layers)
self.n_layers = len(layers)
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Apply the inverse transformation.
Arguments:
x: The input to the transformation
y: The conditioning information
Returns: (z, log_det)
Source code in generax/flows/base.py
def inverse(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
(z, log_det)
"""
return self(x, y=y, inverse=True, **kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform
Inherited from generax.flows.base.Sequential.data_dependent_init.
Source code in generax/flows/base.py
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None) -> BijectiveTransform:
"""Initialize the parameters of the layer based on the data.
**Arguments**:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new layer with the parameters initialized.
"""
assert x.shape[1:] == self.input_shape, 'Only works on batched data'
# We need to initialize each of the layers
keys = random.split(key, self.n_layers)
new_layers = []
for i, (layer, key) in enumerate(zip(self.layers, keys)):
new_layer = layer.data_dependent_init(x=x, y=y, key=key)
new_layers.append(new_layer)
x, _ = eqx.filter_vmap(new_layer)(x)
new_layers = tuple(new_layers)
# Turn the new parameters into a new module
get_layers = lambda tree: tree.layers
updated_layer = eqx.tree_at(get_layers, self, new_layers)
return updated_layer
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Inherited from generax.flows.base.Sequential.__call__.
Source code in generax/flows/base.py
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool=False,
**kwargs) -> Array:
"""**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
(z, log_det)
"""
delta_logpx = 0.0
layers = reversed(self.layers) if inverse else self.layers
for layer in layers:
x, log_det = layer(x, y=y, inverse=inverse, **kwargs)
delta_logpx += log_det
return x, delta_logpx
__init__(self, *layers: Sequence[BijectiveTransform], **kwargs)
Arguments:
layers: A sequence of BijectiveTransform.
Source code in generax/flows/base.py
def __init__(self,
*layers: Sequence[BijectiveTransform],
**kwargs):
"""**Arguments**:
- `layers`: A sequence of `BijectiveTransform`.
"""
input_shape = layers[0].input_shape
cond_shape = layers[0].cond_shape
# Check that all of the layers have the same cond shape
# and that the output shape of each layer matches the input shape of the next layer
layer_iter = iter(zip(layers[:-1], layers[1:]))
for l1, l2 in layer_iter:
assert l1.cond_shape == cond_shape
if isinstance(l1, InjectiveTransform):
assert l1.output_shape == l2.input_shape
assert l2.cond_shape == cond_shape
if isinstance(l2, InjectiveTransform):
output_shape = l2.output_shape
else:
output_shape = l2.input_shape
InjectiveTransform.__init__(self,
input_shape=input_shape,
output_shape=output_shape,
cond_shape=cond_shape,
**kwargs)
self.layers = tuple(layers)
self.n_layers = len(layers)