Affine
generax.flows.affine.Shift (BijectiveTransform)
This represents a shift transformation. This is NICE https://arxiv.org/pdf/1410.8516.pdf when used in a coupling layer.
Attributes:
- b: The shift parameter.
Source code in generax/flows/affine.py
class Shift(BijectiveTransform):
  """This represents a shift transformation
  This is NICE https://arxiv.org/pdf/1410.8516.pdf when used
  in a coupling layer.

  **Attributes**:
  - `b`: The shift parameter.
  """
  b: Array

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:
    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)

    # Initialize the parameters randomly
    self.b = random.normal(key, shape=input_shape)

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the layer based on the data.

    **Arguments**:
    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information
    - `key`: A `jax.random.PRNGKey` for initialization

    **Returns**:
    A new layer with the parameters initialized.
    """
    assert x.shape[1:] == self.input_shape, 'x must be batched'
    mean, std = misc.mean_and_std(x, axis=0)

    # Initialize the parameters so that z will have
    # zero mean and unit variance
    b = mean

    # Turn the new parameters into a new module
    get_b = lambda tree: tree.b
    updated_layer = eqx.tree_at(get_b, self, b)
    return updated_layer

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Array:
    """**Arguments**:
    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    if inverse == False:
      z = x - self.b
    else:
      z = x + self.b

    log_det = jnp.array(0.0)
    return z, log_det
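A minimal usage sketch (the shapes and PRNG keys below are illustrative): the forward pass subtracts b, the inverse adds it back, and the log determinant of a pure shift is always zero.

import jax.numpy as jnp
from jax import random
from generax.flows.affine import Shift

key = random.PRNGKey(0)
layer = Shift(input_shape=(3,), key=key)

x = jnp.array([1.0, 2.0, 3.0])
z, log_det = layer(x)              # z = x - b, log_det = 0.0
x_rec, _ = layer(z, inverse=True)  # adds b back, so x_rec == x
assert jnp.allclose(x_rec, x)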
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray, **kwargs)
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform
Initialize the parameters of the layer based on the data.
Arguments:
- x: The data to initialize the parameters with.
- y: The conditioning information.
- key: A jax.random.PRNGKey for initialization.
Returns: A new layer with the parameters initialized.
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Arguments:
- x: The input to the transformation.
- y: The conditioning information.
- inverse: Whether to invert the transformation.
Returns:
(z, log_det)
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.BijectiveTransform.inverse.
Source code in generax/flows/affine.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:
  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)
generax.flows.affine.ShiftScale (BijectiveTransform)
This represents a shift and scale transformation. This is RealNVP https://arxiv.org/pdf/1605.08803.pdf when used in a coupling layer.
Attributes:
- s_unbounded: The unbounded scaling parameter.
- b: The shift parameter.
Source code in generax/flows/affine.py
class ShiftScale(BijectiveTransform):
  """This represents a shift and scale transformation.
  This is RealNVP https://arxiv.org/pdf/1605.08803.pdf when used
  in a coupling layer.

  **Attributes**:
  - `s_unbounded`: The unbounded scaling parameter.
  - `b`: The shift parameter.
  """
  s_unbounded: Array
  b: Array

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:
    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)

    # Initialize the parameters randomly
    self.s_unbounded, self.b = random.normal(key, shape=(2,) + input_shape)

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the layer based on the data.

    **Arguments**:
    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information
    - `key`: A `jax.random.PRNGKey` for initialization

    **Returns**:
    A new layer with the parameters initialized.
    """
    assert x.shape[1:] == self.input_shape, 'x must be batched'
    mean, std = misc.mean_and_std(x, axis=0)
    std += 1e-4

    # Initialize the parameters so that z will have
    # zero mean and unit variance
    b = mean
    s_unbounded = std - 1/std

    # Turn the new parameters into a new module
    get_b = lambda tree: tree.b
    get_s_unbounded = lambda tree: tree.s_unbounded
    updated_layer = eqx.tree_at(get_b, self, b)
    updated_layer = eqx.tree_at(get_s_unbounded, updated_layer, s_unbounded)
    return updated_layer

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Array:
    """**Arguments**:
    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    # s must be strictly positive
    s = misc.square_plus(self.s_unbounded, gamma=1.0)  # + 1e-4
    log_s = jnp.log(s)

    if inverse == False:
      z = (x - self.b)/s
    else:
      z = x*s + self.b

    if inverse == False:
      log_det = -log_s.sum()
    else:
      log_det = log_s.sum()

    return z, log_det
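A sketch of the sign convention (illustrative shapes and keys): the forward Jacobian of z = (x - b)/s is diag(1/s), so forward mode reports -sum(log s) and inverse mode reports +sum(log s); the two cancel over a round trip.

import jax.numpy as jnp
from jax import random
from generax.flows.affine import ShiftScale

layer = ShiftScale(input_shape=(3,), key=random.PRNGKey(0))
x = jnp.ones(3)

z, fwd_log_det = layer(x)                    # (x - b)/s, log_det = -sum(log s)
x_rec, inv_log_det = layer(z, inverse=True)  # z*s + b,   log_det = +sum(log s)
assert jnp.allclose(x_rec, x, atol=1e-5)
assert jnp.allclose(fwd_log_det + inv_log_det, 0.0, atol=1e-5)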
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray, **kwargs)
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.BijectiveTransform.inverse.
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform
Initialize the parameters of the layer based on the data.
Arguments:
- x: The data to initialize the parameters with.
- y: The conditioning information.
- key: A jax.random.PRNGKey for initialization.
Returns: A new layer with the parameters initialized.
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Arguments:
- x: The input to the transformation.
- y: The conditioning information.
- inverse: Whether to invert the transformation.
Returns:
(z, log_det)
generax.flows.affine.DenseLinear (BijectiveTransform)
Multiply the last axis by a dense matrix. When applied to images, this is GLOW https://arxiv.org/pdf/1807.03039.pdf
Attributes:
- W: The weight matrix.
Source code in generax/flows/affine.py
class DenseLinear(BijectiveTransform):
  """Multiply the last axis by a dense matrix. When applied to images,
  this is GLOW https://arxiv.org/pdf/1807.03039.pdf

  **Attributes**:
  - `W`: The weight matrix
  """
  W: Array

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:
    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)
    dim = self.input_shape[-1]
    self.W = random.normal(key, shape=(dim, dim))
    self.W = misc.whiten(self.W)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Array:
    """**Arguments**:
    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    (z, log_det)
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    if inverse == False:
      z = jnp.einsum('ij,...j->...i', self.W, x)
    else:
      W_inv = jnp.linalg.inv(self.W)
      z = jnp.einsum('ij,...j->...i', W_inv, x)

    # Need to multiply the log determinant by the number of times
    # that we're applying the transformation.
    if len(self.input_shape) > 1:
      dim_mult = np.prod(self.input_shape[:-1])
    else:
      dim_mult = 1
    log_det = jnp.linalg.slogdet(self.W)[1]*dim_mult

    if inverse:
      log_det *= -1
    return z, log_det
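Because only the last axis is transformed, an H x W x C input applies the same C x C matrix at every pixel, and the log determinant picks up a factor of H*W (the dim_mult above). A sketch with illustrative shapes:

import jax.numpy as jnp
from jax import random
from generax.flows.affine import DenseLinear

layer = DenseLinear(input_shape=(4, 4, 3), key=random.PRNGKey(0))  # a 4x4 "image" with 3 channels
x = random.normal(random.PRNGKey(1), (4, 4, 3))

z, log_det = layer(x)
expected = 16*jnp.linalg.slogdet(layer.W)[1]  # H*W = 16 copies of the channel transform
assert jnp.allclose(log_det, expected)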
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.
Source code in generax/flows/affine.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Initialize the parameters of the layer based on the data.

  **Arguments**:
  - `x`: The data to initialize the parameters with.
  - `y`: The conditioning information
  - `key`: A `jax.random.PRNGKey` for initialization

  **Returns**:
  A new layer with the parameters initialized.
  """
  return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.BijectiveTransform.inverse.
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray, **kwargs)
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Implements generax.flows.base.BijectiveTransform.__call__.
generax.flows.affine.DenseAffine (BijectiveTransform)
Multiply the last axis by a dense matrix. When applied to images, this is GLOW https://arxiv.org/pdf/1807.03039.pdf
Attributes:
- W: The weight matrix.
- b: The bias vector.
Source code in generax/flows/affine.py
class DenseAffine(BijectiveTransform):
  """Multiply the last axis by a dense matrix. When applied to images,
  this is GLOW https://arxiv.org/pdf/1807.03039.pdf

  **Attributes**:
  - `W`: The weight matrix
  - `b`: The bias vector
  """
  W: DenseLinear
  b: Array

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:
    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)
    self.W = DenseLinear(input_shape=input_shape,
                         key=key,
                         **kwargs)
    self.b = jnp.zeros(input_shape)

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the layer based on the data.

    **Arguments**:
    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information
    - `key`: A `jax.random.PRNGKey` for initialization

    **Returns**:
    A new layer with the parameters initialized.
    """
    assert x.shape[1:] == self.input_shape, 'x must be batched'
    b = -jnp.mean(x, axis=0)
    return eqx.tree_at(lambda tree: tree.b, self, b)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Array:
    """**Arguments**:
    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    (z, log_det)
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    if inverse == False:
      x = x + self.b
      z, log_det = self.W(x, y=y, inverse=False)
    else:
      z, log_det = self.W(x, y=y, inverse=True)
      z = z - self.b
    return z, log_det
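data_dependent_init sets b to the negative batch mean, and the forward pass adds b before the linear map, so initialized outputs are centered. A sketch (illustrative batch; the layer itself is unbatched, hence the jax.vmap):

import jax
import jax.numpy as jnp
from jax import random
from generax.flows.affine import DenseAffine

layer = DenseAffine(input_shape=(3,), key=random.PRNGKey(0))
x_batch = random.normal(random.PRNGKey(1), (128, 3)) + 5.0  # off-center data

layer = layer.data_dependent_init(x_batch)                  # b <- -mean(x_batch)
z_batch, _ = jax.vmap(layer)(x_batch)
# x + b is centered before the linear map, so z has (approximately) zero mean
assert jnp.allclose(z_batch.mean(axis=0), 0.0, atol=1e-4)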
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.BijectiveTransform.inverse.
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray, **kwargs)
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform
Initialize the parameters of the layer based on the data.
Arguments:
- x: The data to initialize the parameters with.
- y: The conditioning information.
- key: A jax.random.PRNGKey for initialization.
Returns: A new layer with the parameters initialized.
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Implements generax.flows.base.BijectiveTransform.__call__.
generax.flows.affine.CaleyOrthogonalMVP (BijectiveTransform)
Cayley transform parametrization of an orthogonal matrix. This performs a matrix vector product with an orthogonal matrix.
Attributes:
- W: The weight matrix.
- b: The bias vector.
Source code in generax/flows/affine.py
class CaleyOrthogonalMVP(BijectiveTransform):
  """Cayley transform parametrization of an orthogonal matrix. This performs
  a matrix vector product with an orthogonal matrix.

  **Attributes**:
  - `W`: The weight matrix
  - `b`: The bias vector
  """
  W: Array
  b: Array

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:
    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)
    dim = self.input_shape[-1]
    self.W = random.normal(key, shape=(dim, dim))
    self.b = jnp.zeros(input_shape)

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> BijectiveTransform:
    assert x.shape[1:] == self.input_shape, 'x must be batched'
    b = -jnp.mean(x, axis=0)
    return eqx.tree_at(lambda tree: tree.b, self, b)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Array:
    """**Arguments**:
    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    (z, log_det)
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    A = self.W - self.W.T
    dim = self.input_shape[-1]

    # So that we can multiply with channel dim of images
    @partial(jnp.vectorize, signature='(i,j),(j)->(i)')
    def matmul(A, x):
      return A@x

    if inverse == False:
      x += self.b
      IpA_inv = jnp.linalg.inv(jnp.eye(dim) + A)
      y = matmul(IpA_inv, x)
      z = y - matmul(A, y)
    else:
      ImA_inv = jnp.linalg.inv(jnp.eye(dim) - A)
      y = matmul(ImA_inv, x)
      z = y + matmul(A, y)
      z -= self.b

    log_det = jnp.zeros(1)
    return z, log_det
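Why log_det is zero here: for skew-symmetric A = W - W^T, the Cayley transform Q = (I - A)(I + A)^{-1} (which is what the forward branch computes) is orthogonal, so |det Q| = 1. A quick numerical check of this identity in plain JAX, independent of the layer:

import jax.numpy as jnp
from jax import random

W = random.normal(random.PRNGKey(0), (4, 4))
A = W - W.T                                              # skew-symmetric
Q = (jnp.eye(4) - A) @ jnp.linalg.inv(jnp.eye(4) + A)    # Cayley transform
assert jnp.allclose(Q.T @ Q, jnp.eye(4), atol=1e-4)      # orthogonal
# hence jnp.linalg.slogdet(Q)[1] is 0 and the flow's log_det vanishes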
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.BijectiveTransform.inverse.
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray, **kwargs)
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform
Initialize the parameters of the layer based on the data.
Arguments:
- x: The data to initialize the parameters with.
- y: The conditioning information.
- key: A jax.random.PRNGKey for initialization.
Returns: A new layer with the parameters initialized.
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Implements generax.flows.base.BijectiveTransform.__call__.
generax.flows.affine.PLUAffine (BijectiveTransform)
Multiply the last axis by a matrix that is parametrized using the LU decomposition. This is more efficient than the dense parametrization.
Attributes:
- A: The weight matrix components. The upper triangle (including the diagonal) holds the upper triangular factor; the strict lower triangle holds the lower triangular factor, whose diagonal is implicitly unit.
- b: The bias vector.
Source code in generax/flows/affine.py
class PLUAffine(BijectiveTransform):
  """Multiply the last axis by a matrix that is parametrized using the LU
  decomposition. This is more efficient than the dense parametrization.

  **Attributes**:
  - `A`: The weight matrix components. The upper triangle (including the
         diagonal) holds the upper triangular factor; the strict lower
         triangle holds the lower triangular factor, whose diagonal is
         implicitly unit.
  - `b`: The bias vector
  """
  A: Array
  b: Array

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:
    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)

    # Initialize so that this will be approximately the identity matrix
    dim = input_shape[-1]
    self.A = random.normal(key, shape=(dim, dim))*0.01
    self.A = self.A.at[jnp.arange(dim),jnp.arange(dim)].set(1.0)
    self.b = jnp.zeros(input_shape)

  def data_dependent_init(self,
                          x: Array,
                          y: Optional[Array] = None,
                          key: PRNGKeyArray = None) -> BijectiveTransform:
    """Initialize the parameters of the layer based on the data.

    **Arguments**:
    - `x`: The data to initialize the parameters with.
    - `y`: The conditioning information
    - `key`: A `jax.random.PRNGKey` for initialization

    **Returns**:
    A new layer with the parameters initialized.
    """
    assert x.shape[1:] == self.input_shape, 'x must be batched'
    b = -jnp.mean(x, axis=0)
    return eqx.tree_at(lambda tree: tree.b, self, b)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Array:
    """**Arguments**:
    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    (z, log_det)
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    dim = x.shape[-1]
    mask = jnp.ones((dim, dim), dtype=bool)
    upper_mask = jnp.triu(mask)
    lower_mask = jnp.tril(mask, k=-1)

    if inverse == False:
      x += self.b
      z = jnp.einsum("ij,...j->...i", self.A*upper_mask, x)
      z = jnp.einsum("ij,...j->...i", self.A*lower_mask, z) + z
    else:
      # vmap in order to handle images
      L_solve_vmap = L_solve
      U_solve_vmap = U_solve_with_diag
      for _ in x.shape[:-1]:
        L_solve_vmap = jax.vmap(L_solve_vmap, in_axes=(None, 0))
        U_solve_vmap = jax.vmap(U_solve_vmap, in_axes=(None, 0))
      z = L_solve_vmap(self.A*lower_mask, x)
      z = U_solve_vmap(self.A*upper_mask, z)
      z -= self.b

    log_det = jnp.log(jnp.abs(jnp.diag(self.A))).sum()*misc.list_prod(x.shape[:-1])

    if inverse:
      log_det *= -1
    return z, log_det
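A round-trip sketch (illustrative shapes): the forward pass multiplies by the U factor and then the unit-diagonal L factor, so the determinant comes entirely from diag(A) and inversion needs only two triangular solves.

import jax.numpy as jnp
from jax import random
from generax.flows.affine import PLUAffine

layer = PLUAffine(input_shape=(5,), key=random.PRNGKey(0))
x = random.normal(random.PRNGKey(1), (5,))

z, log_det = layer(x)              # z = L @ (U @ (x + b)) with unit-diagonal L
x_rec, _ = layer(z, inverse=True)  # two triangular solves instead of a dense inverse
assert jnp.allclose(x_rec, x, atol=1e-4)
# log_det is sum(log|diag(A)|), the diagonal of the U factor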
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.BijectiveTransform.inverse.
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray, **kwargs)
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform
Initialize the parameters of the layer based on the data.
Arguments:
- x: The data to initialize the parameters with.
- y: The conditioning information.
- key: A jax.random.PRNGKey for initialization.
Returns: A new layer with the parameters initialized.
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Implements generax.flows.base.BijectiveTransform.__call__.
generax.flows.affine.ConditionalOptionalTransport (TimeDependentBijectiveTransform)
Given x1, compute f(t, x0) = t*x1 + (1-t)*x0. This is the optimal transport map between the two points. Used in flow matching https://arxiv.org/pdf/2210.02747.pdf
Non-inverse mode goes t -> 0 while inverse mode goes t -> 1.
Source code in generax/flows/affine.py
class ConditionalOptionalTransport(TimeDependentBijectiveTransform):
  """Given x1, compute f(t, x0) = t*x1 + (1-t)*x0. This is the optimal transport
  map between the two points. Used in flow matching https://arxiv.org/pdf/2210.02747.pdf
  Non-inverse mode goes t -> 0 while inverse mode goes t -> 1.
  """

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:
    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)

  def __call__(self,
               t: Array,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Array:
    """**Arguments**:
    - `t`: The time point.
    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation (0 -> t)

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    if y is None:
      raise ValueError(f'Expected a conditional input')
    if y.shape != x.shape:
      raise ValueError(f'Expected y.shape ({y.shape}) to match x.shape ({x.shape})')
    x1 = y

    if inverse:
      x0 = x
      xt = (1 - t)*x0 + t*x1
      log_det = jnp.log(1 - t)
      return xt, log_det
    else:
      xt = x
      x0 = (xt - t*x1)/(1 - t)
      log_det = -jnp.log(1 - t)
      return x0, log_det

  def vector_field(self,
                   t: Array,
                   xt: Array,
                   y: Optional[Array] = None,
                   **kwargs) -> Array:
    """The vector field that samples evolve on as t changes

    **Arguments**:
    - `t`: Time.
    - `xt`: The point at time `t`.
    - `y`: The conditioning information.

    **Returns**:
    The vector field that samples evolve on at (t, xt).
    """
    assert xt.shape == self.input_shape, 'Only works on unbatched data'
    if y is None:
      raise ValueError(f'Expected a conditional input')
    if y.shape != xt.shape:
      raise ValueError(f'Expected y.shape ({y.shape}) to match xt.shape ({xt.shape})')
    return y - xt
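A sketch of how this layer is used for flow matching (the model in the final comment is a hypothetical network, and the squared-error loss is the usual flow-matching choice, not something defined in this module): build the interpolant xt with inverse mode, then regress on the vector_field target.

import jax.numpy as jnp
from jax import random
from generax.flows.affine import ConditionalOptionalTransport

layer = ConditionalOptionalTransport(input_shape=(2,), key=random.PRNGKey(0))

x1 = jnp.array([1.0, -1.0])                   # a data sample (the conditioning input y)
x0 = random.normal(random.PRNGKey(1), (2,))   # a base sample
t = jnp.array(0.3)

xt, _ = layer(t, x0, y=x1, inverse=True)      # xt = (1 - t)*x0 + t*x1
target = layer.vector_field(t, xt, y=x1)      # conditional regression target
# a flow-matching loss would be e.g. jnp.mean((model(t, xt) - target)**2),
# where `model` is a hypothetical learned vector field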
data_dependent_init(self, t: Array, xt: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
Inherited from generax.flows.base.TimeDependentBijectiveTransform.data_dependent_init.
Source code in generax/flows/affine.py
def data_dependent_init(self,
                        t: Array,
                        xt: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Initialize the parameters of the layer based on the data.

  **Arguments**:
  - `t`: Time.
  - `xt`: The data to initialize the parameters with.
  - `y`: The conditioning information
  - `key`: A `jax.random.PRNGKey` for initialization

  **Returns**:
  A new layer with the parameters initialized.
  """
  return self
inverse(self, t: Array, x0: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.TimeDependentBijectiveTransform.inverse.
Source code in generax/flows/affine.py
def inverse(self,
            t: Array,
            x0: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:
  - `t`: Time.
  - `x0`: A point in the base space.
  - `y`: The conditioning information

  **Returns**:
  (xt, log_det)
  """
  return self(t, x0, y=y, inverse=True, **kwargs)
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray, **kwargs)
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
__call__(self, t: Array, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Arguments:
- t: The time point.
- x: The input to the transformation.
- y: The conditioning information.
- inverse: Whether to invert the transformation (0 -> t).
Returns:
(z, log_det)
vector_field(self, t: Array, xt: Array, y: Optional[Array] = None, **kwargs) -> Array
The vector field that samples evolve on as t changes.
Arguments:
- t: Time.
- xt: The point at time t.
- y: The conditioning information.
Returns: The vector field that samples evolve on at (t, xt).
generax.flows.affine.TallDenseLinear (InjectiveTransform)
Matrix vector product with a tall matrix.
Attributes:
- W: The weight matrix.
Source code in generax/flows/affine.py
class TallDenseLinear(InjectiveTransform):
  """Matrix vector product with a tall matrix.

  **Attributes**:
  - `W`: The weight matrix
  """
  W: Array

  def __init__(self,
               input_shape: Tuple[int],
               output_shape: Tuple[int],
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:
    - `input_shape`: The input shape.
    - `output_shape`: The output shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    assert len(input_shape) == 1, 'Only implemented for 1d data'
    super().__init__(input_shape=input_shape,
                     output_shape=output_shape,
                     **kwargs)
    dim_in = self.input_shape[-1]
    dim_out = self.output_shape[-1]
    self.W = random.normal(key, shape=(dim_in, dim_out))
    self.W = misc.whiten(self.W)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               **kwargs) -> Array:
    """**Arguments**:
    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to inverse the transformation

    **Returns**:
    (z, log_det)
    """
    if inverse == False:
      assert x.shape == self.input_shape, 'Only works on unbatched data'
    else:
      assert x.shape == self.output_shape, 'Only works on unbatched data'

    if inverse == False:
      W_pinv = jnp.linalg.pinv(self.W)
      z = jnp.einsum('ij,...j->...i', W_pinv, x)
    else:
      z = jnp.einsum('ij,...j->...i', self.W, x)

    log_det = -0.5*jnp.linalg.slogdet(self.W.T@self.W)[1]

    if inverse:
      log_det *= -1
    return z, log_det

  def log_determinant(self,
                      z: Array,
                      **kwargs) -> Array:
    """Compute -0.5*log(det(J^TJ))

    **Arguments**:
    - `z`: An element of the base space

    **Returns**:
    The log determinant of (J^TJ)^0.5
    """
    log_det = -0.5*jnp.linalg.slogdet(self.W.T@self.W)[1]
    return log_det
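A shape sketch with illustrative dimensions: inverse mode embeds a low-dimensional base point into the data space with W, and forward mode maps back down with the pseudo-inverse, recovering the base point exactly on the image of W.

import jax.numpy as jnp
from jax import random
from generax.flows.affine import TallDenseLinear

layer = TallDenseLinear(input_shape=(4,), output_shape=(2,), key=random.PRNGKey(0))

z_base = random.normal(random.PRNGKey(1), (2,))
x, _ = layer(z_base, inverse=True)    # embed: x = W @ z_base, a point on a 2d plane in R^4
z, log_det = layer(x)                 # map back with pinv(W); recovers z_base
assert jnp.allclose(z, z_base, atol=1e-4)
# log_det is -0.5*logdet(W^T W), the volume change of the injective map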
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
Initialize the parameters of the layer based on the data.
Arguments:
- x: The data to initialize the parameters with.
- y: The conditioning information.
- key: A jax.random.PRNGKey for initialization.
Returns: A new layer with the parameters initialized.
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Apply the inverse transformation.
Arguments:
- x: The input to the transformation.
- y: The conditioning information.
Returns: (z, log_det)
project(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.InjectiveTransform.project.
Source code in generax/flows/affine.py
def project(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Project a point onto the image of the transformation.

  **Arguments**:
  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  z
  """
  z, _ = self(x, y=y, **kwargs)
  x_proj, _ = self(z, y=y, inverse=True, **kwargs)
  return x_proj
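Since the forward map uses pinv(W), project computes x -> W @ pinv(W) @ x, the orthogonal projection onto the column space of W, and it is idempotent. A quick check (illustrative dims):

import jax.numpy as jnp
from jax import random
from generax.flows.affine import TallDenseLinear

layer = TallDenseLinear(input_shape=(4,), output_shape=(2,), key=random.PRNGKey(0))
x = random.normal(random.PRNGKey(1), (4,))

x_proj = layer.project(x)
P = layer.W @ jnp.linalg.pinv(layer.W)        # orthogonal projector onto col(W)
assert jnp.allclose(x_proj, P @ x, atol=1e-4)
assert jnp.allclose(layer.project(x_proj), x_proj, atol=1e-4)  # idempotent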
__init__(self, input_shape: Tuple[int], output_shape: Tuple[int], key: PRNGKeyArray, **kwargs)
Arguments:
- input_shape: The input shape.
- output_shape: The output shape.
- key: A jax.random.PRNGKey for initialization.
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Arguments:
- x: The input to the transformation.
- y: The conditioning information.
- inverse: Whether to invert the transformation.
Returns: (z, log_det)