Models
generax.flows.models.GeneralTransform (Repeat)

GeneralTransform(*args, **kwargs)
Source code in generax/flows/models.py
class GeneralTransform(Repeat):
    """A `Repeat` of `n_flow_layers` identical flow blocks.

    Each block is `Sequential(Coupling, PLUAffine, ShiftScale)`. The coupling
    network defaults to a `ResNet` and can be replaced via `create_net`.
    """

    def __init__(self,
                 TransformType: type,
                 input_shape: Tuple[int],
                 n_flow_layers: int = 3,
                 working_size: int = 16,
                 hidden_size: int = 32,
                 n_blocks: int = 4,
                 filter_shape: Optional[Tuple[int]] = (3, 3),
                 cond_shape: Optional[Tuple[int]] = None,
                 coupling_split_dim: Optional[int] = None,
                 reverse_conditioning: Optional[bool] = False,
                 create_net: Optional[Callable[[Tuple[int], int, PRNGKeyArray], Any]] = None,
                 *,
                 key: PRNGKeyArray,
                 **kwargs):
        """**Arguments**:

        - `TransformType`: Callable that builds the coupling's transform; it is
          called with `input_shape`, `cond_shape` and `key` keywords.
        - `input_shape`: The shape of the input.
        - `n_flow_layers`: Number of blocks, passed to `Repeat`.
        - `working_size`, `hidden_size`, `n_blocks`, `filter_shape`: Settings
          for the default `ResNet` coupling network. Unused when `create_net`
          is given.
        - `cond_shape`: The shape of the conditioning information, or `None`.
        - `coupling_split_dim`: Forwarded to `Coupling` as `split_dim`.
        - `reverse_conditioning`: Forwarded to `Coupling`.
        - `create_net`: Optional factory called as
          `create_net(net_input_shape, net_output_size, key)` that replaces the
          default `ResNet`.
        - `key`: A `jax.random.PRNGKey` for initialization.
        - `**kwargs`: Forwarded to the `Sequential` wrapping each block.
        """
        def init_transform(transform_input_shape, key):
            return TransformType(input_shape=transform_input_shape,
                                 cond_shape=cond_shape,
                                 key=key)

        def _create_net(net_input_shape, net_output_size, key):
            return ResNet(input_shape=net_input_shape,
                          working_size=working_size,
                          hidden_size=hidden_size,
                          out_size=net_output_size,
                          n_blocks=n_blocks,
                          filter_shape=filter_shape,
                          cond_shape=cond_shape,
                          key=key)

        # A user-supplied factory must share the
        # (net_input_shape, net_output_size, key) signature of `_create_net`.
        create_net = create_net if create_net is not None else _create_net

        def make_single_flow_layer(key: PRNGKeyArray) -> Sequential:
            k1, k2, k3 = random.split(key, 3)
            layers = [
                Coupling(init_transform,
                         create_net,
                         input_shape=input_shape,
                         cond_shape=cond_shape,
                         split_dim=coupling_split_dim,
                         reverse_conditioning=reverse_conditioning,
                         key=k1),
                PLUAffine(input_shape=input_shape,
                          cond_shape=cond_shape,
                          key=k2),
                ShiftScale(input_shape=input_shape,
                           cond_shape=cond_shape,
                           key=k3),
            ]
            return Sequential(*layers, **kwargs)

        super().__init__(make_single_flow_layer, n_flow_layers, key=key)
__init__(self, TransformType: type, input_shape: Tuple[int], n_flow_layers: int = 3, working_size: int = 16, hidden_size: int = 32, n_blocks: int = 4, filter_shape: Optional[Tuple[int]] = (3, 3), cond_shape: Optional[Tuple[int]] = None, coupling_split_dim: Optional[int] = None, reverse_conditioning: Optional[bool] = False, create_net: Optional[Callable[[PRNGKeyArray], Any]] = None, *, key: PRNGKeyArray, **kwargs)
¤
Source code in generax/flows/models.py
def __init__(self,
             TransformType: type,
             input_shape: Tuple[int],
             n_flow_layers: int = 3,
             working_size: int = 16,
             hidden_size: int = 32,
             n_blocks: int = 4,
             filter_shape: Optional[Tuple[int]] = (3, 3),
             cond_shape: Optional[Tuple[int]] = None,
             coupling_split_dim: Optional[int] = None,
             reverse_conditioning: Optional[bool] = False,
             create_net: Optional[Callable[[Tuple[int], int, PRNGKeyArray], Any]] = None,
             *,
             key: PRNGKeyArray,
             **kwargs):
    """Build `n_flow_layers` copies of a `Coupling -> PLUAffine -> ShiftScale` block.

    **Arguments**:

    - `TransformType`: Callable that builds the coupling's transform; it is
      called with `input_shape`, `cond_shape` and `key` keywords.
    - `input_shape`: The shape of the input.
    - `n_flow_layers`: Number of blocks, passed to `Repeat`.
    - `working_size`, `hidden_size`, `n_blocks`, `filter_shape`: Settings for
      the default `ResNet` coupling network. Unused when `create_net` is given.
    - `cond_shape`: The shape of the conditioning information, or `None`.
    - `coupling_split_dim`: Forwarded to `Coupling` as `split_dim`.
    - `reverse_conditioning`: Forwarded to `Coupling`.
    - `create_net`: Optional factory called as
      `create_net(net_input_shape, net_output_size, key)` that replaces the
      default `ResNet`.
    - `key`: A `jax.random.PRNGKey` for initialization.
    - `**kwargs`: Forwarded to the `Sequential` wrapping each block.
    """
    def init_transform(transform_input_shape, key):
        return TransformType(input_shape=transform_input_shape,
                             cond_shape=cond_shape,
                             key=key)

    def _create_net(net_input_shape, net_output_size, key):
        return ResNet(input_shape=net_input_shape,
                      working_size=working_size,
                      hidden_size=hidden_size,
                      out_size=net_output_size,
                      n_blocks=n_blocks,
                      filter_shape=filter_shape,
                      cond_shape=cond_shape,
                      key=key)

    # A user-supplied factory must share the
    # (net_input_shape, net_output_size, key) signature of `_create_net`.
    create_net = create_net if create_net is not None else _create_net

    def make_single_flow_layer(key: PRNGKeyArray) -> Sequential:
        k1, k2, k3 = random.split(key, 3)
        layers = [
            Coupling(init_transform,
                     create_net,
                     input_shape=input_shape,
                     cond_shape=cond_shape,
                     split_dim=coupling_split_dim,
                     reverse_conditioning=reverse_conditioning,
                     key=k1),
            PLUAffine(input_shape=input_shape,
                      cond_shape=cond_shape,
                      key=k2),
            ShiftScale(input_shape=input_shape,
                       cond_shape=cond_shape,
                       key=k3),
        ]
        return Sequential(*layers, **kwargs)

    super().__init__(make_single_flow_layer, n_flow_layers, key=key)
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Apply the inverse transformation.
Arguments:
x: The input to the transformation.
y: The conditioning information.
Returns: (z, log_det)
Source code in generax/flows/models.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Tuple[Array, Array]:
    """Apply the inverse transformation.

    Equivalent to `self(x, y=y, inverse=True, **kwargs)`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information

    **Returns**:
    (z, log_det)
    """
    return self(x, y=y, inverse=True, **kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform

Inherited from generax.flows.base.Repeat.data_dependent_init.
Source code in generax/flows/models.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the repeated layers from data.

    The stacked layers are unrolled with `to_sequential()` so each one runs its
    own data-dependent initialization. The resulting parameters are then
    re-stacked along a new leading axis.

    **Arguments**:

    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information.
    - `key`: A `jax.random.PRNGKey` for initialization.

    **Returns**:
    A new module whose `layers` hold the initialized parameters.
    """
    seq = self.to_sequential()
    # Apply the data dependent initialization.
    out_seq_layers = seq.data_dependent_init(x, y=y, key=key)

    # Keep only the array leaves of each initialized layer.
    all_params = [eqx.partition(layer, eqx.is_array)[0]
                  for layer in out_seq_layers]

    # Stack the per-layer leaves along a new leading axis, the layout that
    # `self.layers` is scanned over, then reattach the static structure.
    params = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *all_params)
    _, static = eqx.partition(self.layers, eqx.is_array)
    layers = eqx.combine(params, static)
    return eqx.tree_at(lambda tree: tree.layers, self, layers)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array

Inherited from generax.flows.base.Repeat.__call__.
Source code in generax/flows/models.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
    """Apply every repeated layer to `x` with a single `jax.lax.scan`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    (z, log_det)
    """
    # Array leaves are scanned over; the static structure is shared by
    # every step.
    dynamic, static = eqx.partition(self.layers, eqx.is_array)

    def scan_body(x, params):
        block = eqx.combine(params, static)
        x, log_det = block(x, y=y, inverse=inverse, **kwargs)
        return x, log_det

    # The inverse visits the layers in the opposite order. `reverse` must be
    # a Python bool, so `inverse` cannot be a traced value.
    x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
    # Total log-determinant is the sum of the per-layer contributions.
    return x, log_dets.sum()
generax.flows.models.NICETransform (GeneralTransform)

NICETransform(*args, **kwargs)
Source code in generax/flows/models.py
class NICETransform(GeneralTransform):
    """`GeneralTransform` whose coupling transform is an additive `Shift`."""

    def __init__(self,
                 *args,
                 **kwargs):
        """Forward all arguments to `GeneralTransform` with `TransformType=Shift`."""
        # `Shift` is passed positionally, ahead of `*args`. Python binds
        # `*args` before any keyword argument, so writing
        # `TransformType=Shift, *args` would route the first positional
        # argument into `TransformType` too and raise
        # "got multiple values for argument 'TransformType'".
        super().__init__(Shift, *args, **kwargs)
__init__(self, *args, **kwargs)
¤
Source code in generax/flows/models.py
def __init__(self,
             *args,
             **kwargs):
    """Forward all arguments to `GeneralTransform` with `TransformType=Shift`."""
    # `Shift` is passed positionally, ahead of `*args`. Python binds `*args`
    # before any keyword argument, so writing `TransformType=Shift, *args`
    # would route the first positional argument into `TransformType` too and
    # raise "got multiple values for argument 'TransformType'".
    super().__init__(Shift, *args, **kwargs)
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Apply the inverse transformation.
Arguments:
x: The input to the transformation.
y: The conditioning information.
Returns: (z, log_det)
Source code in generax/flows/models.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Tuple[Array, Array]:
    """Apply the inverse transformation.

    Equivalent to `self(x, y=y, inverse=True, **kwargs)`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information

    **Returns**:
    (z, log_det)
    """
    return self(x, y=y, inverse=True, **kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform

Initialize the parameters of the layer based on the data.
Arguments:
x: The data to initialize the parameters with.
y: The conditioning information.
key: A jax.random.PRNGKey for initialization.
Returns: A new layer with the parameters initialized.
Source code in generax/flows/models.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the repeated layers from data.

    The stacked layers are unrolled with `to_sequential()` so each one runs its
    own data-dependent initialization. The resulting parameters are then
    re-stacked along a new leading axis.

    **Arguments**:

    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information.
    - `key`: A `jax.random.PRNGKey` for initialization.

    **Returns**:
    A new module whose `layers` hold the initialized parameters.
    """
    seq = self.to_sequential()
    # Apply the data dependent initialization.
    out_seq_layers = seq.data_dependent_init(x, y=y, key=key)

    # Keep only the array leaves of each initialized layer.
    all_params = [eqx.partition(layer, eqx.is_array)[0]
                  for layer in out_seq_layers]

    # Stack the per-layer leaves along a new leading axis, the layout that
    # `self.layers` is scanned over, then reattach the static structure.
    params = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *all_params)
    _, static = eqx.partition(self.layers, eqx.is_array)
    layers = eqx.combine(params, static)
    return eqx.tree_at(lambda tree: tree.layers, self, layers)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array

Arguments:
x: The input to the transformation.
y: The conditioning information.
inverse: Whether to invert the transformation.
Returns: (z, log_det)
Source code in generax/flows/models.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
    """Apply every repeated layer to `x` with a single `jax.lax.scan`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    (z, log_det)
    """
    # Array leaves are scanned over; the static structure is shared by
    # every step.
    dynamic, static = eqx.partition(self.layers, eqx.is_array)

    def scan_body(x, params):
        block = eqx.combine(params, static)
        x, log_det = block(x, y=y, inverse=inverse, **kwargs)
        return x, log_det

    # The inverse visits the layers in the opposite order. `reverse` must be
    # a Python bool, so `inverse` cannot be a traced value.
    x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
    # Total log-determinant is the sum of the per-layer contributions.
    return x, log_dets.sum()
generax.flows.models.RealNVPTransform (GeneralTransform)

RealNVPTransform(*args, **kwargs)
Source code in generax/flows/models.py
class RealNVPTransform(GeneralTransform):
    """`GeneralTransform` whose coupling transform is an affine `ShiftScale`."""

    def __init__(self,
                 *args,
                 **kwargs):
        """Forward all arguments to `GeneralTransform` with `TransformType=ShiftScale`."""
        # `ShiftScale` is passed positionally, ahead of `*args`. Python binds
        # `*args` before any keyword argument, so writing
        # `TransformType=ShiftScale, *args` would route the first positional
        # argument into `TransformType` too and raise
        # "got multiple values for argument 'TransformType'".
        super().__init__(ShiftScale, *args, **kwargs)
__init__(self, *args, **kwargs)
¤
Source code in generax/flows/models.py
def __init__(self,
             *args,
             **kwargs):
    """Forward all arguments to `GeneralTransform` with `TransformType=ShiftScale`."""
    # `ShiftScale` is passed positionally, ahead of `*args`. Python binds
    # `*args` before any keyword argument, so writing
    # `TransformType=ShiftScale, *args` would route the first positional
    # argument into `TransformType` too and raise
    # "got multiple values for argument 'TransformType'".
    super().__init__(ShiftScale, *args, **kwargs)
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Apply the inverse transformation.
Arguments:
x: The input to the transformation.
y: The conditioning information.
Returns: (z, log_det)
Source code in generax/flows/models.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Tuple[Array, Array]:
    """Apply the inverse transformation.

    Equivalent to `self(x, y=y, inverse=True, **kwargs)`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information

    **Returns**:
    (z, log_det)
    """
    return self(x, y=y, inverse=True, **kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform

Initialize the parameters of the layer based on the data.
Arguments:
x: The data to initialize the parameters with.
y: The conditioning information.
key: A jax.random.PRNGKey for initialization.
Returns: A new layer with the parameters initialized.
Source code in generax/flows/models.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the repeated layers from data.

    The stacked layers are unrolled with `to_sequential()` so each one runs its
    own data-dependent initialization. The resulting parameters are then
    re-stacked along a new leading axis.

    **Arguments**:

    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information.
    - `key`: A `jax.random.PRNGKey` for initialization.

    **Returns**:
    A new module whose `layers` hold the initialized parameters.
    """
    seq = self.to_sequential()
    # Apply the data dependent initialization.
    out_seq_layers = seq.data_dependent_init(x, y=y, key=key)

    # Keep only the array leaves of each initialized layer.
    all_params = [eqx.partition(layer, eqx.is_array)[0]
                  for layer in out_seq_layers]

    # Stack the per-layer leaves along a new leading axis, the layout that
    # `self.layers` is scanned over, then reattach the static structure.
    params = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *all_params)
    _, static = eqx.partition(self.layers, eqx.is_array)
    layers = eqx.combine(params, static)
    return eqx.tree_at(lambda tree: tree.layers, self, layers)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array

Arguments:
x: The input to the transformation.
y: The conditioning information.
inverse: Whether to invert the transformation.
Returns: (z, log_det)
Source code in generax/flows/models.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
    """Apply every repeated layer to `x` with a single `jax.lax.scan`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    (z, log_det)
    """
    # Array leaves are scanned over; the static structure is shared by
    # every step.
    dynamic, static = eqx.partition(self.layers, eqx.is_array)

    def scan_body(x, params):
        block = eqx.combine(params, static)
        x, log_det = block(x, y=y, inverse=inverse, **kwargs)
        return x, log_det

    # The inverse visits the layers in the opposite order. `reverse` must be
    # a Python bool, so `inverse` cannot be a traced value.
    x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
    # Total log-determinant is the sum of the per-layer contributions.
    return x, log_dets.sum()
generax.flows.models.NeuralSplineTransform (GeneralTransform)

NeuralSplineTransform(*args, **kwargs)
Source code in generax/flows/models.py
class NeuralSplineTransform(GeneralTransform):
    """`GeneralTransform` whose coupling transform is a `RationalQuadraticSpline`."""

    def __init__(self,
                 *args,
                 n_spline_knots: int = 8,
                 **kwargs):
        """Forward all arguments to `GeneralTransform`.

        `TransformType` is set to `RationalQuadraticSpline` with
        `K=n_spline_knots`.
        """
        # The transform type is passed positionally, ahead of `*args`. Python
        # binds `*args` before any keyword argument, so writing
        # `TransformType=..., *args` would route the first positional argument
        # into `TransformType` too and raise
        # "got multiple values for argument 'TransformType'".
        super().__init__(partial(RationalQuadraticSpline, K=n_spline_knots),
                         *args,
                         **kwargs)
__init__(self, *args, n_spline_knots: int = 8, **kwargs)

Source code in generax/flows/models.py
def __init__(self,
             *args,
             n_spline_knots: int = 8,
             **kwargs):
    """Forward all arguments to `GeneralTransform`.

    `TransformType` is set to `RationalQuadraticSpline` with `K=n_spline_knots`.
    """
    # The transform type is passed positionally, ahead of `*args`. Python binds
    # `*args` before any keyword argument, so writing `TransformType=..., *args`
    # would route the first positional argument into `TransformType` too and
    # raise "got multiple values for argument 'TransformType'".
    super().__init__(partial(RationalQuadraticSpline, K=n_spline_knots),
                     *args,
                     **kwargs)
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Apply the inverse transformation.
Arguments:
x: The input to the transformation.
y: The conditioning information.
Returns: (z, log_det)
Source code in generax/flows/models.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Tuple[Array, Array]:
    """Apply the inverse transformation.

    Equivalent to `self(x, y=y, inverse=True, **kwargs)`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information

    **Returns**:
    (z, log_det)
    """
    return self(x, y=y, inverse=True, **kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform

Initialize the parameters of the layer based on the data.
Arguments:
x: The data to initialize the parameters with.
y: The conditioning information.
key: A jax.random.PRNGKey for initialization.
Returns: A new layer with the parameters initialized.
Source code in generax/flows/models.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the repeated layers from data.

    The stacked layers are unrolled with `to_sequential()` so each one runs its
    own data-dependent initialization. The resulting parameters are then
    re-stacked along a new leading axis.

    **Arguments**:

    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information.
    - `key`: A `jax.random.PRNGKey` for initialization.

    **Returns**:
    A new module whose `layers` hold the initialized parameters.
    """
    seq = self.to_sequential()
    # Apply the data dependent initialization.
    out_seq_layers = seq.data_dependent_init(x, y=y, key=key)

    # Keep only the array leaves of each initialized layer.
    all_params = [eqx.partition(layer, eqx.is_array)[0]
                  for layer in out_seq_layers]

    # Stack the per-layer leaves along a new leading axis, the layout that
    # `self.layers` is scanned over, then reattach the static structure.
    params = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *all_params)
    _, static = eqx.partition(self.layers, eqx.is_array)
    layers = eqx.combine(params, static)
    return eqx.tree_at(lambda tree: tree.layers, self, layers)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array

Arguments:
x: The input to the transformation.
y: The conditioning information.
inverse: Whether to invert the transformation.
Returns: (z, log_det)
Source code in generax/flows/models.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
    """Apply every repeated layer to `x` with a single `jax.lax.scan`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    (z, log_det)
    """
    # Array leaves are scanned over; the static structure is shared by
    # every step.
    dynamic, static = eqx.partition(self.layers, eqx.is_array)

    def scan_body(x, params):
        block = eqx.combine(params, static)
        x, log_det = block(x, y=y, inverse=inverse, **kwargs)
        return x, log_det

    # The inverse visits the layers in the opposite order. `reverse` must be
    # a Python bool, so `inverse` cannot be a traced value.
    x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
    # Total log-determinant is the sum of the per-layer contributions.
    return x, log_dets.sum()
generax.flows.models.GeneralImageTransform (Repeat)

GeneralImageTransform(*args, **kwargs)
Source code in generax/flows/models.py
class GeneralImageTransform(Repeat):
    """A `Repeat` of `n_flow_layers` identical flow blocks for image inputs.

    Each block is `Sequential(Coupling, OneByOneConv, ShiftScale)`. The
    coupling network is a `UNet` when `unet` is true, otherwise an `Encoder`.
    """

    def __init__(self,
                 TransformType: type,
                 input_shape: Tuple[int],
                 n_flow_layers: int = 3,
                 cond_shape: Optional[Tuple[int]] = None,
                 unet: Optional[bool] = True,
                 coupling_split_dim: Optional[int] = None,
                 reverse_conditioning: Optional[bool] = False,
                 *,
                 key: PRNGKeyArray,
                 **kwargs):
        """**Arguments**:

        - `TransformType`: Callable that builds the coupling's transform; it is
          called with `input_shape`, `cond_shape` and `key` keywords.
        - `input_shape`: The shape of the input, unpacked as `(H, W, C)` by
          the UNet network factory.
        - `n_flow_layers`: Number of blocks, passed to `Repeat`.
        - `cond_shape`: The shape of the conditioning information, or `None`.
        - `unet`: Use a `UNet` coupling network if true, else an `Encoder`.
        - `coupling_split_dim`: Forwarded to `Coupling` as `split_dim`.
        - `reverse_conditioning`: Forwarded to `Coupling`.
        - `key`: A `jax.random.PRNGKey` for initialization.
        - `**kwargs`: The network settings `dim` (default 32), `dim_mults`
          (default `(1, 2, 4)`), `resnet_block_groups` (default 8),
          `attn_heads` (default 4) and `attn_dim_head` (default 32) are
          consumed here. Everything else is forwarded to the `Sequential`
          wrapping each block.
        """
        # Consume the network settings once, before any network is built.
        # Popping them inside `create_net` would mutate the shared `kwargs` on
        # every call, so any call after the first would silently fall back to
        # the defaults. Removing them here also keeps them out of
        # `Sequential(**kwargs)` below.
        net_kwargs = dict(dim=kwargs.pop('dim', 32),
                          dim_mults=kwargs.pop('dim_mults', (1, 2, 4)),
                          resnet_block_groups=kwargs.pop('resnet_block_groups', 8),
                          attn_heads=kwargs.pop('attn_heads', 4),
                          attn_dim_head=kwargs.pop('attn_dim_head', 32))

        def init_transform(transform_input_shape, key):
            return TransformType(input_shape=transform_input_shape,
                                 cond_shape=cond_shape,
                                 key=key)

        if unet:
            def create_net(net_input_shape, net_output_size, key):
                H, W, _ = net_input_shape
                # The flat output size is expressed as channels per pixel.
                # This presumes the UNet keeps the H x W spatial size and that
                # net_output_size is divisible by H*W.
                return UNet(input_shape=net_input_shape,
                            out_channels=net_output_size // (H * W),
                            cond_shape=cond_shape,
                            time_dependent=False,
                            key=key,
                            **net_kwargs)
        else:
            def create_net(net_input_shape, net_output_size, key):
                return Encoder(input_shape=net_input_shape,
                               out_size=net_output_size,
                               cond_shape=cond_shape,
                               key=key,
                               **net_kwargs)

        def make_single_flow_layer(key: PRNGKeyArray) -> Sequential:
            k1, k2, k3 = random.split(key, 3)
            layers = [
                Coupling(init_transform,
                         create_net,
                         input_shape=input_shape,
                         cond_shape=cond_shape,
                         split_dim=coupling_split_dim,
                         reverse_conditioning=reverse_conditioning,
                         key=k1),
                OneByOneConv(input_shape=input_shape,
                             key=k2),
                ShiftScale(input_shape=input_shape,
                           cond_shape=cond_shape,
                           key=k3),
            ]
            return Sequential(*layers, **kwargs)

        super().__init__(make_single_flow_layer, n_flow_layers, key=key)
__init__(self, TransformType: type, input_shape: Tuple[int], n_flow_layers: int = 3, cond_shape: Optional[Tuple[int]] = None, unet: Optional[bool] = True, coupling_split_dim: Optional[int] = None, reverse_conditioning: Optional[bool] = False, *, key: PRNGKeyArray, **kwargs)
¤
Source code in generax/flows/models.py
def __init__(self,
             TransformType: type,
             input_shape: Tuple[int],
             n_flow_layers: int = 3,
             cond_shape: Optional[Tuple[int]] = None,
             unet: Optional[bool] = True,
             coupling_split_dim: Optional[int] = None,
             reverse_conditioning: Optional[bool] = False,
             *,
             key: PRNGKeyArray,
             **kwargs):
    """Build `n_flow_layers` copies of a `Coupling -> OneByOneConv -> ShiftScale` block.

    **Arguments**:

    - `TransformType`: Callable that builds the coupling's transform; it is
      called with `input_shape`, `cond_shape` and `key` keywords.
    - `input_shape`: The shape of the input, unpacked as `(H, W, C)` by the
      UNet network factory.
    - `n_flow_layers`: Number of blocks, passed to `Repeat`.
    - `cond_shape`: The shape of the conditioning information, or `None`.
    - `unet`: Use a `UNet` coupling network if true, else an `Encoder`.
    - `coupling_split_dim`: Forwarded to `Coupling` as `split_dim`.
    - `reverse_conditioning`: Forwarded to `Coupling`.
    - `key`: A `jax.random.PRNGKey` for initialization.
    - `**kwargs`: The network settings `dim` (default 32), `dim_mults`
      (default `(1, 2, 4)`), `resnet_block_groups` (default 8), `attn_heads`
      (default 4) and `attn_dim_head` (default 32) are consumed here.
      Everything else is forwarded to the `Sequential` wrapping each block.
    """
    # Consume the network settings once, before any network is built.
    # Popping them inside `create_net` would mutate the shared `kwargs` on
    # every call, so any call after the first would silently fall back to the
    # defaults. Removing them here also keeps them out of
    # `Sequential(**kwargs)` below.
    net_kwargs = dict(dim=kwargs.pop('dim', 32),
                      dim_mults=kwargs.pop('dim_mults', (1, 2, 4)),
                      resnet_block_groups=kwargs.pop('resnet_block_groups', 8),
                      attn_heads=kwargs.pop('attn_heads', 4),
                      attn_dim_head=kwargs.pop('attn_dim_head', 32))

    def init_transform(transform_input_shape, key):
        return TransformType(input_shape=transform_input_shape,
                             cond_shape=cond_shape,
                             key=key)

    if unet:
        def create_net(net_input_shape, net_output_size, key):
            H, W, _ = net_input_shape
            # The flat output size is expressed as channels per pixel. This
            # presumes the UNet keeps the H x W spatial size and that
            # net_output_size is divisible by H*W.
            return UNet(input_shape=net_input_shape,
                        out_channels=net_output_size // (H * W),
                        cond_shape=cond_shape,
                        time_dependent=False,
                        key=key,
                        **net_kwargs)
    else:
        def create_net(net_input_shape, net_output_size, key):
            return Encoder(input_shape=net_input_shape,
                           out_size=net_output_size,
                           cond_shape=cond_shape,
                           key=key,
                           **net_kwargs)

    def make_single_flow_layer(key: PRNGKeyArray) -> Sequential:
        k1, k2, k3 = random.split(key, 3)
        layers = [
            Coupling(init_transform,
                     create_net,
                     input_shape=input_shape,
                     cond_shape=cond_shape,
                     split_dim=coupling_split_dim,
                     reverse_conditioning=reverse_conditioning,
                     key=k1),
            OneByOneConv(input_shape=input_shape,
                         key=k2),
            ShiftScale(input_shape=input_shape,
                       cond_shape=cond_shape,
                       key=k3),
        ]
        return Sequential(*layers, **kwargs)

    super().__init__(make_single_flow_layer, n_flow_layers, key=key)
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Apply the inverse transformation.
Arguments:
x: The input to the transformation.
y: The conditioning information.
Returns: (z, log_det)
Source code in generax/flows/models.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Tuple[Array, Array]:
    """Apply the inverse transformation.

    Equivalent to `self(x, y=y, inverse=True, **kwargs)`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information

    **Returns**:
    (z, log_det)
    """
    return self(x, y=y, inverse=True, **kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform

Inherited from generax.flows.base.Repeat.data_dependent_init.
Source code in generax/flows/models.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the repeated layers from data.

    The stacked layers are unrolled with `to_sequential()` so each one runs its
    own data-dependent initialization. The resulting parameters are then
    re-stacked along a new leading axis.

    **Arguments**:

    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information.
    - `key`: A `jax.random.PRNGKey` for initialization.

    **Returns**:
    A new module whose `layers` hold the initialized parameters.
    """
    seq = self.to_sequential()
    # Apply the data dependent initialization.
    out_seq_layers = seq.data_dependent_init(x, y=y, key=key)

    # Keep only the array leaves of each initialized layer.
    all_params = [eqx.partition(layer, eqx.is_array)[0]
                  for layer in out_seq_layers]

    # Stack the per-layer leaves along a new leading axis, the layout that
    # `self.layers` is scanned over, then reattach the static structure.
    params = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *all_params)
    _, static = eqx.partition(self.layers, eqx.is_array)
    layers = eqx.combine(params, static)
    return eqx.tree_at(lambda tree: tree.layers, self, layers)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array

Inherited from generax.flows.base.Repeat.__call__.
Source code in generax/flows/models.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
    """Apply every repeated layer to `x` with a single `jax.lax.scan`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    (z, log_det)
    """
    # Array leaves are scanned over; the static structure is shared by
    # every step.
    dynamic, static = eqx.partition(self.layers, eqx.is_array)

    def scan_body(x, params):
        block = eqx.combine(params, static)
        x, log_det = block(x, y=y, inverse=inverse, **kwargs)
        return x, log_det

    # The inverse visits the layers in the opposite order. `reverse` must be
    # a Python bool, so `inverse` cannot be a traced value.
    x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
    # Total log-determinant is the sum of the per-layer contributions.
    return x, log_dets.sum()
generax.flows.models.NICEImageTransform (GeneralImageTransform)

NICEImageTransform(*args, **kwargs)
Source code in generax/flows/models.py
class NICEImageTransform(GeneralImageTransform):
    """`GeneralImageTransform` whose coupling transform is an additive `Shift`."""

    def __init__(self,
                 *args,
                 **kwargs):
        """Forward all arguments to `GeneralImageTransform` with `TransformType=Shift`."""
        # `Shift` is passed positionally, ahead of `*args`. Python binds
        # `*args` before any keyword argument, so writing
        # `TransformType=Shift, *args` would route the first positional
        # argument into `TransformType` too and raise
        # "got multiple values for argument 'TransformType'".
        super().__init__(Shift, *args, **kwargs)
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Apply the inverse transformation.
Arguments:
x: The input to the transformation.
y: The conditioning information.
Returns: (z, log_det)
Source code in generax/flows/models.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Tuple[Array, Array]:
    """Apply the inverse transformation.

    Equivalent to `self(x, y=y, inverse=True, **kwargs)`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information

    **Returns**:
    (z, log_det)
    """
    return self(x, y=y, inverse=True, **kwargs)
__init__(self, *args, **kwargs)
¤
Source code in generax/flows/models.py
def __init__(self,
             *args,
             **kwargs):
    """Forward all arguments to `GeneralImageTransform` with `TransformType=Shift`."""
    # `Shift` is passed positionally, ahead of `*args`. Python binds `*args`
    # before any keyword argument, so writing `TransformType=Shift, *args`
    # would route the first positional argument into `TransformType` too and
    # raise "got multiple values for argument 'TransformType'".
    super().__init__(Shift, *args, **kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform

Initialize the parameters of the layer based on the data.
Arguments:
x: The data to initialize the parameters with.
y: The conditioning information.
key: A jax.random.PRNGKey for initialization.
Returns: A new layer with the parameters initialized.
Source code in generax/flows/models.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the repeated layers from data.

    The stacked layers are unrolled with `to_sequential()` so each one runs its
    own data-dependent initialization. The resulting parameters are then
    re-stacked along a new leading axis.

    **Arguments**:

    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information.
    - `key`: A `jax.random.PRNGKey` for initialization.

    **Returns**:
    A new module whose `layers` hold the initialized parameters.
    """
    seq = self.to_sequential()
    # Apply the data dependent initialization.
    out_seq_layers = seq.data_dependent_init(x, y=y, key=key)

    # Keep only the array leaves of each initialized layer.
    all_params = [eqx.partition(layer, eqx.is_array)[0]
                  for layer in out_seq_layers]

    # Stack the per-layer leaves along a new leading axis, the layout that
    # `self.layers` is scanned over, then reattach the static structure.
    params = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *all_params)
    _, static = eqx.partition(self.layers, eqx.is_array)
    layers = eqx.combine(params, static)
    return eqx.tree_at(lambda tree: tree.layers, self, layers)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array

Arguments:
x: The input to the transformation.
y: The conditioning information.
inverse: Whether to invert the transformation.
Returns: (z, log_det)
Source code in generax/flows/models.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
    """Apply every repeated layer to `x` with a single `jax.lax.scan`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    (z, log_det)
    """
    # Array leaves are scanned over; the static structure is shared by
    # every step.
    dynamic, static = eqx.partition(self.layers, eqx.is_array)

    def scan_body(x, params):
        block = eqx.combine(params, static)
        x, log_det = block(x, y=y, inverse=inverse, **kwargs)
        return x, log_det

    # The inverse visits the layers in the opposite order. `reverse` must be
    # a Python bool, so `inverse` cannot be a traced value.
    x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
    # Total log-determinant is the sum of the per-layer contributions.
    return x, log_dets.sum()
generax.flows.models.RealNVPImageTransform (GeneralImageTransform)

RealNVPImageTransform(*args, **kwargs)
Source code in generax/flows/models.py
class RealNVPImageTransform(GeneralImageTransform):
    """`GeneralImageTransform` whose coupling transform is an affine `ShiftScale`."""

    def __init__(self,
                 *args,
                 **kwargs):
        """Forward all arguments to `GeneralImageTransform` with `TransformType=ShiftScale`."""
        # `ShiftScale` is passed positionally, ahead of `*args`. Python binds
        # `*args` before any keyword argument, so writing
        # `TransformType=ShiftScale, *args` would route the first positional
        # argument into `TransformType` too and raise
        # "got multiple values for argument 'TransformType'".
        super().__init__(ShiftScale, *args, **kwargs)
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Apply the inverse transformation.
Arguments:
x: The input to the transformation.
y: The conditioning information.
Returns: (z, log_det)
Source code in generax/flows/models.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Tuple[Array, Array]:
    """Apply the inverse transformation.

    Equivalent to `self(x, y=y, inverse=True, **kwargs)`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information

    **Returns**:
    (z, log_det)
    """
    return self(x, y=y, inverse=True, **kwargs)
__init__(self, *args, **kwargs)
¤
Source code in generax/flows/models.py
def __init__(self,
             *args,
             **kwargs):
    """Forward all arguments to `GeneralImageTransform` with `TransformType=ShiftScale`."""
    # `ShiftScale` is passed positionally, ahead of `*args`. Python binds
    # `*args` before any keyword argument, so writing
    # `TransformType=ShiftScale, *args` would route the first positional
    # argument into `TransformType` too and raise
    # "got multiple values for argument 'TransformType'".
    super().__init__(ShiftScale, *args, **kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform

Initialize the parameters of the layer based on the data.
Arguments:
x: The data to initialize the parameters with.
y: The conditioning information.
key: A jax.random.PRNGKey for initialization.
Returns: A new layer with the parameters initialized.
Source code in generax/flows/models.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the repeated layers from data.

    The stacked layers are unrolled with `to_sequential()` so each one runs its
    own data-dependent initialization. The resulting parameters are then
    re-stacked along a new leading axis.

    **Arguments**:

    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information.
    - `key`: A `jax.random.PRNGKey` for initialization.

    **Returns**:
    A new module whose `layers` hold the initialized parameters.
    """
    seq = self.to_sequential()
    # Apply the data dependent initialization.
    out_seq_layers = seq.data_dependent_init(x, y=y, key=key)

    # Keep only the array leaves of each initialized layer.
    all_params = [eqx.partition(layer, eqx.is_array)[0]
                  for layer in out_seq_layers]

    # Stack the per-layer leaves along a new leading axis, the layout that
    # `self.layers` is scanned over, then reattach the static structure.
    params = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *all_params)
    _, static = eqx.partition(self.layers, eqx.is_array)
    layers = eqx.combine(params, static)
    return eqx.tree_at(lambda tree: tree.layers, self, layers)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array

Arguments:
x: The input to the transformation.
y: The conditioning information.
inverse: Whether to invert the transformation.
Returns: (z, log_det)
Source code in generax/flows/models.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
    """Apply every repeated layer to `x` with a single `jax.lax.scan`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    (z, log_det)
    """
    # Array leaves are scanned over; the static structure is shared by
    # every step.
    dynamic, static = eqx.partition(self.layers, eqx.is_array)

    def scan_body(x, params):
        block = eqx.combine(params, static)
        x, log_det = block(x, y=y, inverse=inverse, **kwargs)
        return x, log_det

    # The inverse visits the layers in the opposite order. `reverse` must be
    # a Python bool, so `inverse` cannot be a traced value.
    x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
    # Total log-determinant is the sum of the per-layer contributions.
    return x, log_dets.sum()
generax.flows.models.NeuralSplineImageTransform (GeneralImageTransform)

NeuralSplineImageTransform(*args, **kwargs)
Source code in generax/flows/models.py
class NeuralSplineImageTransform(GeneralImageTransform):
    """`GeneralImageTransform` whose coupling transform is a `RationalQuadraticSpline`."""

    def __init__(self,
                 *args,
                 n_spline_knots: int = 8,
                 **kwargs):
        """Forward all arguments to `GeneralImageTransform`.

        `TransformType` is set to `RationalQuadraticSpline` with
        `K=n_spline_knots`.
        """
        # The transform type is passed positionally, ahead of `*args`. Python
        # binds `*args` before any keyword argument, so writing
        # `TransformType=..., *args` would route the first positional argument
        # into `TransformType` too and raise
        # "got multiple values for argument 'TransformType'".
        super().__init__(partial(RationalQuadraticSpline, K=n_spline_knots),
                         *args,
                         **kwargs)
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array

Apply the inverse transformation.
Arguments:
x: The input to the transformation.
y: The conditioning information.
Returns: (z, log_det)
Source code in generax/flows/models.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Tuple[Array, Array]:
    """Apply the inverse transformation.

    Equivalent to `self(x, y=y, inverse=True, **kwargs)`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information

    **Returns**:
    (z, log_det)
    """
    return self(x, y=y, inverse=True, **kwargs)
__init__(self, *args, n_spline_knots: int = 8, **kwargs)

Source code in generax/flows/models.py
def __init__(self,
             *args,
             n_spline_knots: int = 8,
             **kwargs):
    """Forward all arguments to `GeneralImageTransform`.

    `TransformType` is set to `RationalQuadraticSpline` with `K=n_spline_knots`.
    """
    # The transform type is passed positionally, ahead of `*args`. Python binds
    # `*args` before any keyword argument, so writing `TransformType=..., *args`
    # would route the first positional argument into `TransformType` too and
    # raise "got multiple values for argument 'TransformType'".
    super().__init__(partial(RationalQuadraticSpline, K=n_spline_knots),
                     *args,
                     **kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform

Initialize the parameters of the layer based on the data.
Arguments:
x: The data to initialize the parameters with.
y: The conditioning information.
key: A jax.random.PRNGKey for initialization.
Returns: A new layer with the parameters initialized.
Source code in generax/flows/models.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the repeated layers from data.

    The stacked layers are unrolled with `to_sequential()` so each one runs its
    own data-dependent initialization. The resulting parameters are then
    re-stacked along a new leading axis.

    **Arguments**:

    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information.
    - `key`: A `jax.random.PRNGKey` for initialization.

    **Returns**:
    A new module whose `layers` hold the initialized parameters.
    """
    seq = self.to_sequential()
    # Apply the data dependent initialization.
    out_seq_layers = seq.data_dependent_init(x, y=y, key=key)

    # Keep only the array leaves of each initialized layer.
    all_params = [eqx.partition(layer, eqx.is_array)[0]
                  for layer in out_seq_layers]

    # Stack the per-layer leaves along a new leading axis, the layout that
    # `self.layers` is scanned over, then reattach the static structure.
    params = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *all_params)
    _, static = eqx.partition(self.layers, eqx.is_array)
    layers = eqx.combine(params, static)
    return eqx.tree_at(lambda tree: tree.layers, self, layers)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array

Arguments:
x: The input to the transformation.
y: The conditioning information.
inverse: Whether to invert the transformation.
Returns: (z, log_det)
Source code in generax/flows/models.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
    """Apply every repeated layer to `x` with a single `jax.lax.scan`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    (z, log_det)
    """
    # Array leaves are scanned over; the static structure is shared by
    # every step.
    dynamic, static = eqx.partition(self.layers, eqx.is_array)

    def scan_body(x, params):
        block = eqx.combine(params, static)
        x, log_det = block(x, y=y, inverse=inverse, **kwargs)
        return x, log_det

    # The inverse visits the layers in the opposite order. `reverse` must be
    # a Python bool, so `inverse` cannot be a traced value.
    x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
    # Total log-determinant is the sum of the per-layer contributions.
    return x, log_dets.sum()