Reshape¤
generax.flows.reshape.Flatten (BijectiveTransform)
¤
Flatten
Source code in generax/flows/reshape.py
class Flatten(BijectiveTransform):
    """Flatten an input of shape `input_shape` into a 1-D vector.

    The inverse reshapes a vector back to `input_shape`. A reshape only
    rearranges entries, so the reported log-determinant is always 0.
    """

    def __init__(self,
                 input_shape: Tuple[int],
                 **kwargs):
        """**Arguments**:

        - `input_shape`: Shape of a single input. The forward output is a
          vector with the same number of elements (`prod(input_shape)`).
        - `**kwargs`: Forwarded unchanged to the base-class constructor
          (e.g. `key`, a `jax.random.PRNGKey`, if the base class uses one).
        """
        super().__init__(input_shape=input_shape,
                         **kwargs)

    def __call__(self,
                 x: Array,
                 y: Optional[Array] = None,
                 inverse: bool = False,
                 **kwargs) -> Tuple[Array, Array]:
        """**Arguments**:

        - `x`: The input to the transformation. Forward: must have shape
          `input_shape`. Inverse: must be reshapeable to `input_shape`.
        - `y`: The conditioning information (not used by this layer).
        - `inverse`: Whether to apply the inverse transformation.

        **Returns**:

        `(z, log_det)` where `log_det` is always `0.0`.
        """
        log_det = jnp.array(0.0)
        if not inverse:
            assert x.shape == self.input_shape
            return x.ravel(), log_det
        return x.reshape(self.input_shape), log_det
__init__(self, input_shape: Tuple[int], **kwargs)
¤
Arguments:
input_shape
: The input shape. Output size is the same as shape.
key
: A `jax.random.PRNGKey` for initialization
Source code in generax/flows/reshape.py
def __init__(self,
             input_shape: Tuple[int],
             **kwargs):
    """**Arguments**:

    - `input_shape`: Shape of a single input. The forward output is a
      vector with the same number of elements (`prod(input_shape)`).
    - `**kwargs`: Forwarded unchanged to the base-class constructor
      (e.g. `key`, a `jax.random.PRNGKey`, if the base class uses one).
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
¤
Arguments:
x
: The input to the transformation
y
: The conditioning information
inverse
: Whether to inverse the transformation
Returns: The transformed input and 0
Source code in generax/flows/reshape.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
    """Flatten `x`, or restore its shape when `inverse=True`.

    **Arguments**:

    - `x`: The input to the transformation. Forward: must have shape
      `self.input_shape`. Inverse: must be reshapeable to `self.input_shape`.
    - `y`: The conditioning information (not used by this layer).
    - `inverse`: Whether to apply the inverse transformation.

    **Returns**:

    `(z, log_det)` where `log_det` is always `0.0`.
    """
    log_det = jnp.array(0.0)
    if not inverse:
        assert x.shape == self.input_shape
        return x.ravel(), log_det
    return x.reshape(self.input_shape), log_det
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
¤
Inherited from `generax.flows.base.BijectiveTransform.data_dependent_init`.
Source code in generax/flows/reshape.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
    """Data-dependent initialization hook; a no-op here.

    **Arguments**:

    - `x`: Sample data that would be used for initialization (ignored).
    - `y`: The conditioning information (ignored).
    - `key`: A `jax.random.PRNGKey` (ignored).

    **Returns**:

    The same layer object, unchanged.
    """
    del x, y, key  # nothing in this implementation depends on the data
    return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
¤
Inherited from `generax.flows.base.BijectiveTransform.inverse`.
Source code in generax/flows/reshape.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
    """Run the layer backwards by calling it with `inverse=True`.

    **Arguments**:

    - `x`: Value to map back through the layer.
    - `y`: The conditioning information, forwarded unchanged.
    - `**kwargs`: Extra keyword arguments forwarded to `__call__`.

    **Returns**:

    Whatever `self(x, y=y, inverse=True, ...)` returns, i.e. `(z, log_det)`.
    """
    backward = dict(y=y, inverse=True)
    return self(x, **backward, **kwargs)
generax.flows.reshape.Reverse (BijectiveTransform)
¤
Reverse an input along its last axis.
Source code in generax/flows/reshape.py
class Reverse(BijectiveTransform):
    """Reverse an input along its last axis.

    Flipping twice restores the input, so the same computation serves as
    both the forward and the inverse pass.
    """

    def __init__(self,
                 input_shape: Tuple[int],
                 key: PRNGKeyArray,
                 **kwargs):
        """**Arguments**:

        - `input_shape`: Shape of a single input; the output has the same shape.
        - `key`: A `jax.random.PRNGKey`. Required by the signature but not
          used or forwarded by this constructor.
        """
        super().__init__(input_shape=input_shape, **kwargs)

    def __call__(self,
                 x: Array,
                 y: Optional[Array] = None,
                 inverse: bool=False,
                 **kwargs) -> Array:
        """Flip `x` along its last axis.

        **Arguments**:

        - `x`: Array of shape `self.input_shape`.
        - `y`: The conditioning information (not consulted).
        - `inverse`: Whether to apply the inverse transformation; it has no
          effect because reversal is its own inverse.

        **Returns**:

        `(z, log_det)` with `z` the flipped input and `log_det` equal to `0.0`.
        """
        assert x.shape == self.input_shape
        zero = jnp.array(0.0)
        return x[..., ::-1], zero
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
¤
Inherited from `generax.flows.base.BijectiveTransform.data_dependent_init`.
Source code in generax/flows/reshape.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
    """Data-dependent initialization hook; a no-op here.

    **Arguments**:

    - `x`: Sample data that would be used for initialization (ignored).
    - `y`: The conditioning information (ignored).
    - `key`: A `jax.random.PRNGKey` (ignored).

    **Returns**:

    The same layer object, unchanged.
    """
    del x, y, key  # nothing in this implementation depends on the data
    return self
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray, **kwargs)
¤
Arguments:
input_shape
: The input shape. Output size is the same as shape.
key
: A `jax.random.PRNGKey` for initialization
Source code in generax/flows/reshape.py
def __init__(self,
             input_shape: Tuple[int],
             key: PRNGKeyArray,
             **kwargs):
    """**Arguments**:

    - `input_shape`: Shape of a single input; the output has the same shape.
    - `key`: A `jax.random.PRNGKey`. Required by the signature but not used
      or forwarded by this constructor.
    """
    super().__init__(input_shape=input_shape, **kwargs)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
¤
Arguments:
x
: The input to the transformation
y
: The conditioning information
inverse
: Whether to inverse the transformation
Returns: The transformed input and 0
Source code in generax/flows/reshape.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
    """Flip `x` along its last axis.

    Reversing twice restores the input, so the same computation serves both
    directions.

    **Arguments**:

    - `x`: Array of shape `self.input_shape`.
    - `y`: The conditioning information (not consulted).
    - `inverse`: Whether to apply the inverse transformation (no effect).

    **Returns**:

    `(z, log_det)` with `z` the flipped input and `log_det` equal to `0.0`.
    """
    assert x.shape == self.input_shape
    zero = jnp.array(0.0)
    return x[..., ::-1], zero
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
¤
Inherited from `generax.flows.base.BijectiveTransform.inverse`.
Source code in generax/flows/reshape.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
    """Run the layer backwards by calling it with `inverse=True`.

    **Arguments**:

    - `x`: Value to map back through the layer.
    - `y`: The conditioning information, forwarded unchanged.
    - `**kwargs`: Extra keyword arguments forwarded to `__call__`.

    **Returns**:

    Whatever `self(x, y=y, inverse=True, ...)` returns, i.e. `(z, log_det)`.
    """
    backward = dict(y=y, inverse=True)
    return self(x, **backward, **kwargs)
generax.flows.reshape.Checkerboard (BijectiveTransform)
¤
Checkerboard pattern from https://arxiv.org/pdf/1605.08803.pdf
Source code in generax/flows/reshape.py
class Checkerboard(BijectiveTransform):
    """Checkerboard pattern from https://arxiv.org/pdf/1605.08803.pdf

    Maps an `(H, W, C)` input to `(H, W//2, 2*C)` by stacking each pair of
    adjacent columns along the channel axis:
    `out[h, w, 2*c + k] == x[h, 2*w + k, c]` for `k in {0, 1}`.
    NOTE(review): this interleaves even/odd columns only, rather than forming
    a 2-D checkerboard over both spatial axes; confirm this is intended.
    The map only rearranges entries, so the log-determinant is always 0.
    """

    output_shape: Tuple[int] = eqx.field(static=True)

    def __init__(self,
                 input_shape: Tuple[int],
                 **kwargs):
        """**Arguments**:

        - `input_shape`: `(H, W, C)` shape of a single input; `W` must be
          even. The output shape is `(H, W//2, 2*C)`.
        - `**kwargs`: Forwarded unchanged to the base-class constructor.
        """
        H, W, C = input_shape
        assert W % 2 == 0, 'Need even width'
        super().__init__(input_shape=input_shape,
                         **kwargs)
        self.output_shape = (H, W//2, C*2)

    def __call__(self,
                 x: Array,
                 y: Optional[Array] = None,
                 inverse: bool = False,
                 **kwargs) -> Tuple[Array, Array]:
        """**Arguments**:

        - `x`: The input to the transformation. Forward: shape
          `self.input_shape`. Inverse: shape `self.output_shape`.
        - `y`: The conditioning information (not used by this layer).
        - `inverse`: Whether to apply the inverse transformation.

        **Returns**:

        `(z, log_det)` where `log_det` is always `0.0`.
        """
        if not inverse:
            assert x.shape == self.input_shape
            # Width index 2*w + k moves to channel index 2*c + k.
            z = einops.rearrange(x, 'h (w k) c -> h w (c k)', k=2)
        else:
            assert x.shape == self.output_shape
            z = einops.rearrange(x, 'h w (c k) -> h (w k) c', k=2)
        log_det = jnp.array(0.0)
        return z, log_det
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
¤
Inherited from `generax.flows.base.BijectiveTransform.data_dependent_init`.
Source code in generax/flows/reshape.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
    """Data-dependent initialization hook; a no-op here.

    **Arguments**:

    - `x`: Sample data that would be used for initialization (ignored).
    - `y`: The conditioning information (ignored).
    - `key`: A `jax.random.PRNGKey` (ignored).

    **Returns**:

    The same layer object, unchanged.
    """
    del x, y, key  # nothing in this implementation depends on the data
    return self
__init__(self, input_shape: Tuple[int], **kwargs)
¤
Arguments:
input_shape
: The input shape. Output size is the same as shape.
Source code in generax/flows/reshape.py
def __init__(self,
             input_shape: Tuple[int],
             **kwargs):
    """**Arguments**:

    - `input_shape`: `(H, W, C)` shape of a single input; `W` must be even.
      The output shape is `(H, W//2, 2*C)`.
    - `**kwargs`: Forwarded unchanged to the base-class constructor.
    """
    H, W, C = input_shape
    assert W % 2 == 0, 'Need even width'
    super().__init__(input_shape=input_shape,
                     **kwargs)
    self.output_shape = (H, W//2, C*2)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
¤
Arguments:
x
: The input to the transformation
y
: The conditioning information
inverse
: Whether to inverse the transformation
Returns: The transformed input and 0
Source code in generax/flows/reshape.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
    """Stack adjacent column pairs into channels, or undo it.

    Forward: `(H, W, C) -> (H, W//2, 2*C)` with
    `out[h, w, 2*c + k] == x[h, 2*w + k, c]`.

    **Arguments**:

    - `x`: The input to the transformation. Forward: shape
      `self.input_shape`. Inverse: shape `self.output_shape`.
    - `y`: The conditioning information (not used by this layer).
    - `inverse`: Whether to apply the inverse transformation.

    **Returns**:

    `(z, log_det)` where `log_det` is always `0.0`.
    """
    if not inverse:
        assert x.shape == self.input_shape
        z = einops.rearrange(x, 'h (w k) c -> h w (c k)', k=2)
    else:
        assert x.shape == self.output_shape
        z = einops.rearrange(x, 'h w (c k) -> h (w k) c', k=2)
    log_det = jnp.array(0.0)
    return z, log_det
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
¤
Inherited from `generax.flows.base.BijectiveTransform.inverse`.
Source code in generax/flows/reshape.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
    """Run the layer backwards by calling it with `inverse=True`.

    **Arguments**:

    - `x`: Value to map back through the layer.
    - `y`: The conditioning information, forwarded unchanged.
    - `**kwargs`: Extra keyword arguments forwarded to `__call__`.

    **Returns**:

    Whatever `self(x, y=y, inverse=True, ...)` returns, i.e. `(z, log_det)`.
    """
    backward = dict(y=y, inverse=True)
    return self(x, **backward, **kwargs)
generax.flows.reshape.Squeeze (BijectiveTransform)
¤
Space to depth. (H, W, C) -> (H//2, W//2, C*4)
Source code in generax/flows/reshape.py
class Squeeze(BijectiveTransform):
    """Space to depth. (H, W, C) -> (H//2, W//2, C*4)

    Each 2x2 spatial block is moved into the channel axis:
    `out[h, w, 4*c + 2*m + n] == x[2*h + m, 2*w + n, c]` for `m, n in {0, 1}`.
    The map only rearranges entries, so the log-determinant is always 0.
    """

    output_shape: Tuple[int] = eqx.field(static=True)

    def __init__(self,
                 input_shape: Tuple[int],
                 **kwargs):
        """**Arguments**:

        - `input_shape`: `(H, W, C)` shape of a single input; `H` and `W`
          must both be even. The output shape is `(H//2, W//2, 4*C)`.
        - `**kwargs`: Forwarded unchanged to the base-class constructor
          (e.g. `key`, a `jax.random.PRNGKey`, if the base class uses one).
        """
        H, W, C = input_shape
        assert H % 2 == 0, 'Need even height'
        assert W % 2 == 0, 'Need even width'
        super().__init__(input_shape=input_shape,
                         **kwargs)
        self.output_shape = (H//2, W//2, C*4)

    def __call__(self,
                 x: Array,
                 y: Optional[Array] = None,
                 inverse: bool = False,
                 **kwargs) -> Tuple[Array, Array]:
        """**Arguments**:

        - `x`: The input to the transformation. Forward: shape
          `self.input_shape`. Inverse: shape `self.output_shape`.
        - `y`: The conditioning information (not used by this layer).
        - `inverse`: Whether to apply the inverse transformation.

        **Returns**:

        `(z, log_det)` where `log_det` is always `0.0`.
        """
        if not inverse:
            assert x.shape == self.input_shape
            z = einops.rearrange(x, '(h m) (w n) c -> h w (c m n)', m=2, n=2)
        else:
            assert x.shape == self.output_shape
            z = einops.rearrange(x, 'h w (c m n) -> (h m) (w n) c', m=2, n=2)
        log_det = jnp.array(0.0)
        return z, log_det
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
¤
Inherited from `generax.flows.base.BijectiveTransform.data_dependent_init`.
Source code in generax/flows/reshape.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
    """Data-dependent initialization hook; a no-op here.

    **Arguments**:

    - `x`: Sample data that would be used for initialization (ignored).
    - `y`: The conditioning information (ignored).
    - `key`: A `jax.random.PRNGKey` (ignored).

    **Returns**:

    The same layer object, unchanged.
    """
    del x, y, key  # nothing in this implementation depends on the data
    return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
¤
Inherited from `generax.flows.base.BijectiveTransform.inverse`.
Source code in generax/flows/reshape.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
    """Run the layer backwards by calling it with `inverse=True`.

    **Arguments**:

    - `x`: Value to map back through the layer.
    - `y`: The conditioning information, forwarded unchanged.
    - `**kwargs`: Extra keyword arguments forwarded to `__call__`.

    **Returns**:

    Whatever `self(x, y=y, inverse=True, ...)` returns, i.e. `(z, log_det)`.
    """
    backward = dict(y=y, inverse=True)
    return self(x, **backward, **kwargs)
__init__(self, input_shape: Tuple[int], **kwargs)
¤
Arguments:
input_shape
: The input shape. Output size is the same as shape.
key
: A `jax.random.PRNGKey` for initialization
Source code in generax/flows/reshape.py
def __init__(self,
             input_shape: Tuple[int],
             **kwargs):
    """**Arguments**:

    - `input_shape`: `(H, W, C)` shape of a single input; `H` and `W` must
      both be even. The output shape is `(H//2, W//2, 4*C)`.
    - `**kwargs`: Forwarded unchanged to the base-class constructor
      (e.g. `key`, a `jax.random.PRNGKey`, if the base class uses one).
    """
    H, W, C = input_shape
    assert H % 2 == 0, 'Need even height'
    assert W % 2 == 0, 'Need even width'
    super().__init__(input_shape=input_shape,
                     **kwargs)
    self.output_shape = (H//2, W//2, C*4)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
¤
Arguments:
x
: The input to the transformation
y
: The conditioning information
inverse
: Whether to inverse the transformation
Returns: The transformed input and 0
Source code in generax/flows/reshape.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
    """Move each 2x2 spatial block into channels, or undo it.

    Forward: `(H, W, C) -> (H//2, W//2, 4*C)` with
    `out[h, w, 4*c + 2*m + n] == x[2*h + m, 2*w + n, c]`.

    **Arguments**:

    - `x`: The input to the transformation. Forward: shape
      `self.input_shape`. Inverse: shape `self.output_shape`.
    - `y`: The conditioning information (not used by this layer).
    - `inverse`: Whether to apply the inverse transformation.

    **Returns**:

    `(z, log_det)` where `log_det` is always `0.0`.
    """
    if not inverse:
        assert x.shape == self.input_shape
        z = einops.rearrange(x, '(h m) (w n) c -> h w (c m n)', m=2, n=2)
    else:
        assert x.shape == self.output_shape
        z = einops.rearrange(x, 'h w (c m n) -> (h m) (w n) c', m=2, n=2)
    log_det = jnp.array(0.0)
    return z, log_det
generax.flows.reshape.Slice (InjectiveTransform)
¤
Slice an input to reduce the dimension
Source code in generax/flows/reshape.py
class Slice(InjectiveTransform):
    """Slice an input along its last axis to reduce the dimension.

    The forward pass keeps the first `output_shape[-1]` entries of the last
    axis. The inverse pass appends zeros to restore `input_shape`. Both
    directions report a log-determinant of 0.
    """

    def __init__(self,
                 input_shape: Tuple[int],
                 output_shape: Tuple[int],
                 **kwargs):
        """**Arguments**:

        - `input_shape`: Shape of a single (unbatched) input.
        - `output_shape`: Shape of the output. It must equal `input_shape` on
          every axis except the last, and its last axis must not be larger
          than `input_shape[-1]`.
        - `**kwargs`: Forwarded unchanged to the base-class constructor
          (e.g. `key`, a `jax.random.PRNGKey`, if the base class uses one).
        """
        assert input_shape[:-1] == output_shape[:-1], 'Need to keep the same spatial dimensions'
        # A larger output would make the forward slice too short and the
        # inverse padding size negative; fail early with a clear message.
        assert output_shape[-1] <= input_shape[-1], 'Output cannot have a larger last axis than the input'
        super().__init__(input_shape=input_shape,
                         output_shape=output_shape,
                         **kwargs)

    def __call__(self,
                 x: Array,
                 y: Optional[Array] = None,
                 inverse: bool = False,
                 **kwargs) -> Tuple[Array, Array]:
        """**Arguments**:

        - `x`: The input to the transformation. Forward: shape
          `self.input_shape`. Inverse: shape `self.output_shape`.
        - `y`: The conditioning information (not used by this layer).
        - `inverse`: Whether to apply the inverse transformation.

        **Returns**:

        `(z, log_det)` where `log_det` is always `0.0`.
        """
        if not inverse:
            assert x.shape == self.input_shape, 'Only works on unbatched data'
            z = x[..., :self.output_shape[-1]]
        else:
            assert x.shape == self.output_shape, 'Only works on unbatched data'
            pad_size = self.input_shape[-1] - self.output_shape[-1]
            pad_shape = self.input_shape[:-1] + (pad_size,)
            # Match x's dtype; jnp.zeros otherwise uses the default float
            # dtype and concatenation would promote the result (e.g. bf16->f32).
            z = jnp.concatenate([x, jnp.zeros(pad_shape, dtype=x.dtype)], axis=-1)
        log_det = jnp.array(0.0)
        return z, log_det

    def log_determinant(self,
                        z: Array,
                        **kwargs) -> Array:
        """Log-determinant term for the Jacobian `J` of the injective map.

        The injective map used by the inverse pass embeds `z` by zero-padding
        the last axis. Its Jacobian therefore has orthonormal columns,
        `J^T J` is the identity, and the term is 0 for every `z`.

        **Arguments**:

        - `z`: An element of the base space (not used).

        **Returns**:

        `0.0` as a JAX scalar.
        """
        return jnp.array(0.0)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
¤
Initialize the parameters of the layer based on the data.
Arguments:
x
: The data to initialize the parameters with.
y
: The conditioning information
key
: A `jax.random.PRNGKey` for initialization
Returns: A new layer with the parameters initialized.
Source code in generax/flows/reshape.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
    """Data-dependent initialization hook; a no-op here.

    **Arguments**:

    - `x`: Sample data that would be used for initialization (ignored).
    - `y`: The conditioning information (ignored).
    - `key`: A `jax.random.PRNGKey` (ignored).

    **Returns**:

    The same layer object, unchanged.
    """
    del x, y, key  # nothing in this implementation depends on the data
    return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
¤
Apply the inverse transformation.
Arguments:
x
: The input to the transformation
y
: The conditioning information
Returns: (z, log_det)
Source code in generax/flows/reshape.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
    """Run the layer backwards by calling it with `inverse=True`.

    **Arguments**:

    - `x`: Value to map back through the layer.
    - `y`: The conditioning information, forwarded unchanged.
    - `**kwargs`: Extra keyword arguments forwarded to `__call__`.

    **Returns**:

    Whatever `self(x, y=y, inverse=True, ...)` returns, i.e. `(z, log_det)`.
    """
    backward = dict(y=y, inverse=True)
    return self(x, **backward, **kwargs)
__init__(self, input_shape: Tuple[int], output_shape: Tuple[int], **kwargs)
¤
Arguments:
input_shape
: The input shape.
output_shape
: The output shape; it must match `input_shape` on every axis except the last.
key
: A `jax.random.PRNGKey` for initialization
Source code in generax/flows/reshape.py
def __init__(self,
             input_shape: Tuple[int],
             output_shape: Tuple[int],
             **kwargs):
    """**Arguments**:

    - `input_shape`: Shape of a single (unbatched) input.
    - `output_shape`: Shape of the output. It must equal `input_shape` on
      every axis except the last, and its last axis must not be larger than
      `input_shape[-1]`.
    - `**kwargs`: Forwarded unchanged to the base-class constructor
      (e.g. `key`, a `jax.random.PRNGKey`, if the base class uses one).
    """
    assert input_shape[:-1] == output_shape[:-1], 'Need to keep the same spatial dimensions'
    # A larger output would make the forward slice too short and the inverse
    # padding size negative; fail early with a clear message.
    assert output_shape[-1] <= input_shape[-1], 'Output cannot have a larger last axis than the input'
    super().__init__(input_shape=input_shape,
                     output_shape=output_shape,
                     **kwargs)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
¤
Arguments:
x
: The input to the transformation
y
: The conditioning information
inverse
: Whether to inverse the transformation
Returns: (z, log_det)
Source code in generax/flows/reshape.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Tuple[Array, Array]:
    """Drop trailing entries of the last axis, or zero-pad them back.

    **Arguments**:

    - `x`: The input to the transformation. Forward: shape
      `self.input_shape`. Inverse: shape `self.output_shape`.
    - `y`: The conditioning information (not used by this layer).
    - `inverse`: Whether to apply the inverse transformation.

    **Returns**:

    `(z, log_det)` where `log_det` is always `0.0`.
    """
    if not inverse:
        assert x.shape == self.input_shape, 'Only works on unbatched data'
        z = x[..., :self.output_shape[-1]]
    else:
        assert x.shape == self.output_shape, 'Only works on unbatched data'
        pad_size = self.input_shape[-1] - self.output_shape[-1]
        pad_shape = self.input_shape[:-1] + (pad_size,)
        # Match x's dtype; jnp.zeros otherwise uses the default float dtype
        # and concatenation would promote the result (e.g. bf16 -> f32).
        z = jnp.concatenate([x, jnp.zeros(pad_shape, dtype=x.dtype)], axis=-1)
    log_det = jnp.array(0.0)
    return z, log_det
project(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
¤
Inherited from `generax.flows.base.InjectiveTransform.project`.
Source code in generax/flows/reshape.py
def project(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
    """Map `x` onto the image of the transformation.

    Runs the forward pass, then feeds the result through the inverse pass.
    Both log-determinants are discarded.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information, forwarded to both passes
    - `**kwargs`: Extra keyword arguments forwarded to both passes

    **Returns**:

    The reconstruction produced by the inverse pass.
    """
    latent, _ = self(x, y=y, **kwargs)
    reconstruction, _ = self(latent, y=y, inverse=True, **kwargs)
    return reconstruction