Nonlinearities
generax.flows.nonlinearities.Softplus (BijectiveTransform)
¤
Softplus(*args, **kwargs)
Source code in generax/flows/nonlinearities.py
class Softplus(BijectiveTransform):
    """Elementwise softplus bijection, `z = log(1 + exp(x))`.

    The forward map sends real inputs to strictly positive outputs; the
    inverse maps positive inputs back via `z = x + log(1 - exp(-x))`.
    """

    def __init__(self,
                 input_shape: Tuple[int],
                 key: PRNGKeyArray = None,
                 **kwargs):
        """**Arguments**:

        - `input_shape`: The input shape. Output size is the same as shape.
        - `key`: A `jax.random.PRNGKey` for initialization. Not read here;
          this transform has no parameters.
        """
        super().__init__(input_shape=input_shape,
                         **kwargs)

    def __call__(self,
                 x: Array,
                 y: Optional[Array] = None,
                 inverse: bool = False,
                 **kwargs) -> Array:
        """**Arguments**:

        - `x`: The input to the transformation
        - `y`: The conditioning information (not used by this transform)
        - `inverse`: Whether to invert the transformation

        **Returns**:
        `(z, log_det)`
        """
        assert x.shape == self.input_shape, 'Only works on unbatched data'

        # One truthiness test decides both the direction and the sign of the
        # log-determinant, so the two can never disagree.
        if inverse:
            # Softplus only produces strictly positive values. Clamp every
            # input outside that range -- including exactly 0, where
            # log1p(-exp(-0)) = log(0) = -inf -- to a small positive number.
            x = jnp.where(x <= 0.0, 1e-5, x)
            # log(1 - exp(-x)) is both the correction that turns x into
            # softplus^{-1}(x) and the forward log-derivative at that point.
            dx = jnp.log1p(-jnp.exp(-x))
            z = x + dx
            log_det = -dx.sum()
        else:
            z = jax.nn.softplus(x)
            # d softplus(x)/dx = sigmoid(x) = 1 - exp(-softplus(x))
            log_det = jnp.log1p(-jnp.exp(-z)).sum()
        return z, log_det
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray = None, **kwargs)
¤
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             key: PRNGKeyArray = None,
             **kwargs):
    """Set up the parameter-free softplus bijection.

    **Arguments**:

    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization. Accepted for a
      uniform constructor signature but never read.
    """
    # Everything (including input_shape) is forwarded to the base class.
    super().__init__(**kwargs, input_shape=input_shape)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
¤
Arguments:
- x: The input to the transformation.
- y: The conditioning information.
- inverse: Whether to invert the transformation.
Returns:
(z, log_det)
Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
    """Apply softplus (or its inverse) elementwise.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information (not used by this transform)
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    # One truthiness test decides both the direction and the sign of the
    # log-determinant, so the two can never disagree.
    if inverse:
        # Softplus only produces strictly positive values. Clamp every input
        # outside that range -- including exactly 0, where
        # log1p(-exp(-0)) = log(0) = -inf -- to a small positive number.
        x = jnp.where(x <= 0.0, 1e-5, x)
        # log(1 - exp(-x)) is both the correction that turns x into
        # softplus^{-1}(x) and the forward log-derivative at that point.
        dx = jnp.log1p(-jnp.exp(-x))
        z = x + dx
        log_det = -dx.sum()
    else:
        z = jax.nn.softplus(x)
        # d softplus(x)/dx = sigmoid(x) = 1 - exp(-softplus(x))
        log_det = jnp.log1p(-jnp.exp(-z)).sum()
    return z, log_det
generax.flows.nonlinearities.GaussianCDF (BijectiveTransform)
¤
GaussianCDF(*args, **kwargs)
Source code in generax/flows/nonlinearities.py
class GaussianCDF(BijectiveTransform):
    """Elementwise standard-normal CDF, mapping the real line into (0, 1).

    The inverse direction applies the normal quantile function (`ppf`).
    """

    def __init__(self,
                 input_shape: Tuple[int],
                 key: PRNGKeyArray = None,
                 **kwargs):
        """**Arguments**:

        - `input_shape`: The input shape. Output size is the same as shape.
        - `key`: A `jax.random.PRNGKey` for initialization (never read; the
          transform is parameter-free).
        """
        super().__init__(**kwargs, input_shape=input_shape)

    def __call__(self,
                 x: Array,
                 y: Optional[Array] = None,
                 inverse: bool=False,
                 **kwargs) -> Array:
        """Push `x` through the normal CDF, or pull it back with the quantile.

        **Arguments**:

        - `x`: The input to the transformation
        - `y`: The conditioning information (unused)
        - `inverse`: Whether to invert the transformation

        **Returns**:
        `(z, log_det)`
        """
        assert x.shape == self.input_shape, 'Only works on unbatched data'
        norm = jax.scipy.stats.norm
        if inverse == False:
            z = norm.cdf(x)
            real_side = x
        else:
            z = norm.ppf(x)
            real_side = z
        # The forward Jacobian is diagonal with entries pdf(t), where t is the
        # point on the real-line side of the map in either direction.
        log_det = norm.logpdf(real_side).sum()
        if inverse:
            log_det = -log_det
        return z, log_det
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray = None, **kwargs)
¤
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             key: PRNGKeyArray = None,
             **kwargs):
    """Set up the parameter-free Gaussian-CDF bijection.

    **Arguments**:

    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization (never read).
    """
    super().__init__(**kwargs, input_shape=input_shape)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
¤
Arguments:
- x: The input to the transformation.
- y: The conditioning information.
- inverse: Whether to invert the transformation.
Returns:
(z, log_det)
Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
    """Push `x` through the normal CDF, or pull it back with the quantile.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information (unused)
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    norm = jax.scipy.stats.norm
    if inverse == False:
        z = norm.cdf(x)
        real_side = x
    else:
        z = norm.ppf(x)
        real_side = z
    # The forward Jacobian is diagonal with entries pdf(t), where t is the
    # point on the real-line side of the map in either direction.
    log_det = norm.logpdf(real_side).sum()
    if inverse:
        log_det = -log_det
    return z, log_det
generax.flows.nonlinearities.LogisticCDF (BijectiveTransform)
¤
LogisticCDF(*args, **kwargs)
Source code in generax/flows/nonlinearities.py
class LogisticCDF(BijectiveTransform):
    """Elementwise standard-logistic CDF (the sigmoid), mapping reals into (0, 1).

    The inverse direction applies the logistic quantile function (`ppf`).
    """

    def __init__(self,
                 input_shape: Tuple[int],
                 key: PRNGKeyArray = None,
                 **kwargs):
        """**Arguments**:

        - `input_shape`: The input shape. Output size is the same as shape.
        - `key`: A `jax.random.PRNGKey` for initialization (never read; the
          transform is parameter-free).
        """
        super().__init__(**kwargs, input_shape=input_shape)

    def __call__(self,
                 x: Array,
                 y: Optional[Array] = None,
                 inverse: bool=False,
                 **kwargs) -> Array:
        """Push `x` through the logistic CDF, or pull it back with the quantile.

        **Arguments**:

        - `x`: The input to the transformation
        - `y`: The conditioning information (unused)
        - `inverse`: Whether to invert the transformation

        **Returns**:
        `(z, log_det)`
        """
        assert x.shape == self.input_shape, 'Only works on unbatched data'
        logistic = jax.scipy.stats.logistic
        if inverse == False:
            z, real_side = logistic.cdf(x), x
        else:
            z = logistic.ppf(x)
            real_side = z
        # Diagonal Jacobian: log|dz/dx| = logpdf at the real-line-side point.
        log_det = logistic.logpdf(real_side).sum()
        return z, (-log_det if inverse else log_det)
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray = None, **kwargs)
¤
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             key: PRNGKeyArray = None,
             **kwargs):
    """Set up the parameter-free logistic-CDF bijection.

    **Arguments**:

    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization (never read).
    """
    super().__init__(**kwargs, input_shape=input_shape)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
¤
Arguments:
- x: The input to the transformation.
- y: The conditioning information.
- inverse: Whether to invert the transformation.
Returns:
(z, log_det)
Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
    """Push `x` through the logistic CDF, or pull it back with the quantile.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information (unused)
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    logistic = jax.scipy.stats.logistic
    if inverse == False:
        z, real_side = logistic.cdf(x), x
    else:
        z = logistic.ppf(x)
        real_side = z
    # Diagonal Jacobian: log|dz/dx| = logpdf at the real-line-side point.
    log_det = logistic.logpdf(real_side).sum()
    return z, (-log_det if inverse else log_det)
generax.flows.nonlinearities.LeakyReLU (BijectiveTransform)
¤
LeakyReLU(*args, **kwargs)
Source code in generax/flows/nonlinearities.py
class LeakyReLU(BijectiveTransform):
    """Elementwise leaky ReLU: identity for positive inputs, slope `alpha` otherwise."""

    alpha: float  # slope applied to non-positive inputs

    def __init__(self,
                 input_shape: Tuple[int],
                 *,
                 key: PRNGKeyArray = None,
                 alpha: Optional[float] = 0.01,
                 **kwargs):
        """**Arguments**:

        - `input_shape`: The input shape. Output size is the same as shape.
        - `key`: A `jax.random.PRNGKey` for initialization (never read).
        - `alpha`: Slope used for non-positive inputs.
        """
        super().__init__(**kwargs, input_shape=input_shape)
        self.alpha = alpha

    def __call__(self,
                 x: Array,
                 y: Optional[Array] = None,
                 inverse: bool=False,
                 **kwargs) -> Array:
        """Apply the leaky ReLU (or its inverse) elementwise.

        **Arguments**:

        - `x`: The input to the transformation
        - `y`: The conditioning information (unused)
        - `inverse`: Whether to invert the transformation

        **Returns**:
        `(z, log_det)`
        """
        assert x.shape == self.input_shape, 'Only works on unbatched data'
        positive = x > 0
        if inverse == False:
            z = jnp.where(positive, x, self.alpha*x)
        else:
            z = jnp.where(positive, x, x/self.alpha)
        # Forward log|dz/dx| is 0 on the identity branch and log(alpha) on the
        # scaled branch. The sign of the input picks the branch in either
        # direction (this relies on alpha > 0).
        log_det = jnp.where(positive, 0, jnp.log(self.alpha)).sum()
        if inverse:
            log_det = -log_det
        return z, log_det
__init__(self, input_shape: Tuple[int], *, key: PRNGKeyArray = None, alpha: Optional[float] = 0.01, **kwargs)
¤
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             *,
             key: PRNGKeyArray = None,
             alpha: Optional[float] = 0.01,
             **kwargs):
    """Set up the leaky ReLU bijection.

    **Arguments**:

    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization (never read).
    - `alpha`: Slope used for non-positive inputs; stored unchanged.
    """
    super().__init__(**kwargs, input_shape=input_shape)
    self.alpha = alpha
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
¤
Arguments:
- x: The input to the transformation.
- y: The conditioning information.
- inverse: Whether to invert the transformation.
Returns:
(z, log_det)
Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
    """Apply the leaky ReLU (or its inverse) elementwise.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information (unused)
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    positive = x > 0
    if inverse == False:
        z = jnp.where(positive, x, self.alpha*x)
    else:
        z = jnp.where(positive, x, x/self.alpha)
    # Forward log|dz/dx| is 0 on the identity branch and log(alpha) on the
    # scaled branch. The sign of the input picks the branch in either
    # direction (this relies on alpha > 0).
    log_det = jnp.where(positive, 0, jnp.log(self.alpha)).sum()
    if inverse:
        log_det = -log_det
    return z, log_det
generax.flows.nonlinearities.SneakyReLU (BijectiveTransform)
¤
Originally from https://invertibleworkshop.github.io/INNF_2019/accepted_papers/pdfs/INNF_2019_paper_26.pdf
Source code in generax/flows/nonlinearities.py
class SneakyReLU(BijectiveTransform):
    """Smooth, invertible ReLU-like activation ("Sneaky ReLU").

    Originally from https://invertibleworkshop.github.io/INNF_2019/accepted_papers/pdfs/INNF_2019_paper_26.pdf
    Forward map: `z = (x + a*(sqrt(1 + x^2) - 1)) / (1 + a)`.
    """

    alpha: float  # the internal `a` above, i.e. (1 - alpha_arg)/(1 + alpha_arg)

    def __init__(self,
                 input_shape: Tuple[int],
                 *,
                 key: PRNGKeyArray = None,
                 alpha: Optional[float] = 0.01,
                 **kwargs):
        """**Arguments**:

        - `input_shape`: The input shape. Output size is the same as shape.
        - `key`: A `jax.random.PRNGKey` for initialization (never read).
        - `alpha`: Shape parameter; stored as `(1 - alpha)/(1 + alpha)`.
        """
        super().__init__(**kwargs, input_shape=input_shape)
        # Sneaky ReLU uses a different convention from the `alpha` argument.
        numerator, denominator = 1.0 - alpha, 1.0 + alpha
        self.alpha = numerator/denominator

    def __call__(self,
                 x: Array,
                 y: Optional[Array] = None,
                 inverse: bool=False,
                 **kwargs) -> Array:
        """Apply the Sneaky ReLU (or its closed-form inverse) elementwise.

        **Arguments**:

        - `x`: The input to the transformation
        - `y`: The conditioning information (unused)
        - `inverse`: Whether to invert the transformation

        **Returns**:
        `(z, log_det)`
        """
        assert x.shape == self.input_shape, 'Only works on unbatched data'
        a = self.alpha
        if inverse == False:
            z = (x + a*(jnp.sqrt(1 + x**2) - 1))/(1 + a)
            pre_image = x
        else:
            # Root of the quadratic obtained by solving the forward map for
            # its input.
            a_sq = a**2
            b = (1 + a)*x + a
            z = (jnp.sqrt(a_sq*(1 + b**2 - a_sq)) - b)/(a_sq - 1)
            pre_image = z
        # Forward derivative (1 + a*t/sqrt(1 + t^2)) / (1 + a), at the input-side t.
        log_det = jnp.log(1 + a*pre_image/jnp.sqrt(1 + pre_image**2)) - jnp.log(1 + a)
        log_det = log_det.sum()
        if inverse:
            log_det = -log_det
        return z, log_det
__init__(self, input_shape: Tuple[int], *, key: PRNGKeyArray = None, alpha: Optional[float] = 0.01, **kwargs)
¤
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             *,
             key: PRNGKeyArray = None,
             alpha: Optional[float] = 0.01,
             **kwargs):
    """Set up the Sneaky ReLU bijection.

    **Arguments**:

    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization (never read).
    - `alpha`: Shape parameter; stored as `(1 - alpha)/(1 + alpha)`.
    """
    super().__init__(**kwargs, input_shape=input_shape)
    # Sneaky ReLU uses a different convention from the `alpha` argument.
    numerator, denominator = 1.0 - alpha, 1.0 + alpha
    self.alpha = numerator/denominator
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
¤
Arguments:
- x: The input to the transformation.
- y: The conditioning information.
- inverse: Whether to invert the transformation.
Returns:
(z, log_det)
Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
    """Apply the Sneaky ReLU (or its closed-form inverse) elementwise.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information (unused)
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    a = self.alpha
    if inverse == False:
        z = (x + a*(jnp.sqrt(1 + x**2) - 1))/(1 + a)
        pre_image = x
    else:
        # Root of the quadratic obtained by solving the forward map for its
        # input.
        a_sq = a**2
        b = (1 + a)*x + a
        z = (jnp.sqrt(a_sq*(1 + b**2 - a_sq)) - b)/(a_sq - 1)
        pre_image = z
    # Forward derivative (1 + a*t/sqrt(1 + t^2)) / (1 + a), at the input-side t.
    log_det = jnp.log(1 + a*pre_image/jnp.sqrt(1 + pre_image**2)) - jnp.log(1 + a)
    log_det = log_det.sum()
    if inverse:
        log_det = -log_det
    return z, log_det
generax.flows.nonlinearities.SquarePlus (BijectiveTransform)
¤
SquarePlus(*args, **kwargs)
Source code in generax/flows/nonlinearities.py
class SquarePlus(BijectiveTransform):
    """Elementwise squareplus, `z = (x + sqrt(x^2 + 4*gamma)) / 2`.

    A softplus-like map onto the positive reals with the closed-form inverse
    `x = z - gamma/z`.
    """

    gamma: float  # curvature parameter in the formula above

    def __init__(self,
                 input_shape: Tuple[int],
                 *,
                 key: PRNGKeyArray = None,
                 gamma: Optional[float] = 0.5,
                 **kwargs):
        """**Arguments**:

        - `input_shape`: The input shape. Output size is the same as shape.
        - `key`: A `jax.random.PRNGKey` for initialization (never read).
        - `gamma`: The `gamma` in the formula above.
        """
        super().__init__(**kwargs, input_shape=input_shape)
        self.gamma = gamma

    def __call__(self,
                 x: Array,
                 y: Optional[Array] = None,
                 inverse: bool=False,
                 **kwargs) -> Array:
        """Apply squareplus (or its inverse) elementwise.

        **Arguments**:

        - `x`: The input to the transformation
        - `y`: The conditioning information (unused)
        - `inverse`: Whether to invert the transformation

        **Returns**:
        `(z, log_det)`
        """
        assert x.shape == self.input_shape, 'Only works on unbatched data'
        four_gamma = 4*self.gamma
        if inverse == False:
            radicand = x**2 + four_gamma
            z = jnp.maximum(0.5*(x + jnp.sqrt(radicand)), 0.0)
            # Analytically positive; the floor guards against round-off
            # driving the log below to -inf for very negative x.
            dzdx = jnp.maximum(0.5*(1 + x*jax.lax.rsqrt(radicand)), 1e-5)
        else:
            z = x - self.gamma/x
            # Forward derivative evaluated at the recovered input.
            dzdx = 0.5*(1 + z*jax.lax.rsqrt(z**2 + four_gamma))
        log_det = jnp.log(dzdx).sum()
        if inverse:
            log_det = -log_det
        return z, log_det
__init__(self, input_shape: Tuple[int], *, key: PRNGKeyArray = None, gamma: Optional[float] = 0.5, **kwargs)
¤
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             *,
             key: PRNGKeyArray = None,
             gamma: Optional[float] = 0.5,
             **kwargs):
    """Set up the squareplus bijection.

    **Arguments**:

    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization (never read).
    - `gamma`: Curvature parameter; stored unchanged.
    """
    super().__init__(**kwargs, input_shape=input_shape)
    self.gamma = gamma
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
¤
Arguments:
- x: The input to the transformation.
- y: The conditioning information.
- inverse: Whether to invert the transformation.
Returns:
(z, log_det)
Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
    """Apply squareplus (or its inverse) elementwise.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information (unused)
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    four_gamma = 4*self.gamma
    if inverse == False:
        radicand = x**2 + four_gamma
        z = jnp.maximum(0.5*(x + jnp.sqrt(radicand)), 0.0)
        # Analytically positive; the floor guards against round-off driving
        # the log below to -inf for very negative x.
        dzdx = jnp.maximum(0.5*(1 + x*jax.lax.rsqrt(radicand)), 1e-5)
    else:
        z = x - self.gamma/x
        # Forward derivative evaluated at the recovered input.
        dzdx = 0.5*(1 + z*jax.lax.rsqrt(z**2 + four_gamma))
    log_det = jnp.log(dzdx).sum()
    if inverse:
        log_det = -log_det
    return z, log_det
generax.flows.nonlinearities.SquareSigmoid (BijectiveTransform)
¤
SquareSigmoid(*args, **kwargs)
Source code in generax/flows/nonlinearities.py
class SquareSigmoid(BijectiveTransform):
    """Elementwise algebraic sigmoid, `z = (1 + x/sqrt(x^2 + 4*gamma)) / 2`.

    Maps the real line into (0, 1) and has a closed-form inverse.
    """

    gamma: float  # scale parameter in the formula above

    def __init__(self,
                 input_shape: Tuple[int],
                 *,
                 key: PRNGKeyArray = None,
                 gamma: Optional[float] = 0.5,
                 **kwargs):
        """**Arguments**:

        - `input_shape`: The input shape. Output size is the same as shape.
        - `key`: A `jax.random.PRNGKey` for initialization (never read).
        - `gamma`: The `gamma` in the formula above.
        """
        super().__init__(**kwargs, input_shape=input_shape)
        self.gamma = gamma

    def __call__(self,
                 x: Array,
                 y: Optional[Array] = None,
                 inverse: bool=False,
                 **kwargs) -> Array:
        """Apply the square sigmoid (or its inverse) elementwise.

        **Arguments**:

        - `x`: The input to the transformation
        - `y`: The conditioning information (unused)
        - `inverse`: Whether to invert the transformation

        **Returns**:
        `(z, log_det)`
        """
        assert x.shape == self.input_shape, 'Only works on unbatched data'
        if inverse == False:
            inv_root = jax.lax.rsqrt(x**2 + 4*self.gamma)
            z = 0.5*(1 + x*inv_root)
        else:
            centered = 2*x - 1
            z = 2*jnp.sqrt(self.gamma)*centered*jax.lax.rsqrt(1 - centered**2)
            inv_root = jax.lax.rsqrt(z**2 + 4*self.gamma)
        # Forward derivative 2*gamma / (t^2 + 4*gamma)^(3/2) at the real-line-side t.
        log_det = jnp.log(2*self.gamma*inv_root**3).sum()
        if inverse:
            log_det = -log_det
        return z, log_det
__init__(self, input_shape: Tuple[int], *, key: PRNGKeyArray = None, gamma: Optional[float] = 0.5, **kwargs)
¤
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             *,
             key: PRNGKeyArray = None,
             gamma: Optional[float] = 0.5,
             **kwargs):
    """Set up the square-sigmoid bijection.

    **Arguments**:

    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization (never read).
    - `gamma`: Scale parameter; stored unchanged.
    """
    super().__init__(**kwargs, input_shape=input_shape)
    self.gamma = gamma
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
¤
Arguments:
- x: The input to the transformation.
- y: The conditioning information.
- inverse: Whether to invert the transformation.
Returns:
(z, log_det)
Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
    """Apply the square sigmoid (or its inverse) elementwise.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information (unused)
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    if inverse == False:
        inv_root = jax.lax.rsqrt(x**2 + 4*self.gamma)
        z = 0.5*(1 + x*inv_root)
    else:
        centered = 2*x - 1
        z = 2*jnp.sqrt(self.gamma)*centered*jax.lax.rsqrt(1 - centered**2)
        inv_root = jax.lax.rsqrt(z**2 + 4*self.gamma)
    # Forward derivative 2*gamma / (t^2 + 4*gamma)^(3/2) at the real-line-side t.
    log_det = jnp.log(2*self.gamma*inv_root**3).sum()
    if inverse:
        log_det = -log_det
    return z, log_det
generax.flows.nonlinearities.SLog (BijectiveTransform)
¤
https://papers.nips.cc/paper/2019/file/b1f62fa99de9f27a048344d55c5ef7a6-Paper.pdf
Source code in generax/flows/nonlinearities.py
class SLog(BijectiveTransform):
    """Signed-log bijection, `z = sign(x) * log(1 + a*|x|) / a`.

    From https://papers.nips.cc/paper/2019/file/b1f62fa99de9f27a048344d55c5ef7a6-Paper.pdf

    **Attributes**:
    - `alpha`: Unconstrained parameter. The positive scale `a` used in the
      transform is `misc.square_plus(alpha) + 1e-4`.
    """
    alpha: Union[float, None]

    def __init__(self,
                 input_shape: Tuple[int],
                 *,
                 key: PRNGKeyArray = None,
                 alpha: Optional[float] = 0.0,
                 **kwargs):
        """**Arguments**:

        - `input_shape`: The input shape. Output size is the same as shape.
        - `key`: A `jax.random.PRNGKey` for initialization (not read).
        - `alpha`: Initial unconstrained value of the scale parameter.
        """
        super().__init__(input_shape=input_shape,
                         **kwargs)
        self.alpha = alpha

    def __call__(self,
                 x: Array,
                 y: Optional[Array] = None,
                 inverse: bool = False,
                 **kwargs) -> Array:
        """**Arguments**:

        - `x`: The input to the transformation
        - `y`: The conditioning information (unused)
        - `inverse`: Whether to invert the transformation

        **Returns**:
        `(z, log_det)`
        """
        assert x.shape == self.input_shape, 'Only works on unbatched data'

        # Bound alpha to be positive
        alpha = misc.square_plus(self.alpha) + 1e-4

        # One truthiness test decides both the direction and the sign of the
        # log-determinant.
        if inverse:
            # expm1(t) stays accurate for small t, where exp(t) - 1 cancels.
            z = jnp.sign(x)/alpha*jnp.expm1(alpha*jnp.abs(x))
            # The forward log-derivative at the recovered point is
            # -log1p(alpha*|z|); the inverse contributes its negation.
            log_det = jnp.log1p(alpha*jnp.abs(z)).sum()
        else:
            log_term = jnp.log1p(alpha*jnp.abs(x))
            z = jnp.sign(x)/alpha*log_term
            # dz/dx = 1 / (1 + alpha*|x|)
            log_det = -log_term.sum()
        return z, log_det
__init__(self, input_shape: Tuple[int], *, key: PRNGKeyArray = None, alpha: Optional[float] = 0.0, **kwargs)
¤
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             *,
             key: PRNGKeyArray = None,
             alpha: Optional[float] = 0.0,
             **kwargs):
    """Set up the signed-log bijection.

    **Arguments**:

    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization (never read).
    - `alpha`: Unconstrained scale parameter; stored unchanged.
    """
    super().__init__(**kwargs, input_shape=input_shape)
    self.alpha = alpha
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
¤
Arguments:
- x: The input to the transformation.
- y: The conditioning information.
- inverse: Whether to invert the transformation.
Returns:
(z, log_det)
Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
    """Apply the signed-log transform `z = sign(x) * log1p(a*|x|) / a`.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information (unused)
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    # Bound alpha to be positive
    alpha = misc.square_plus(self.alpha) + 1e-4

    # One truthiness test decides both the direction and the sign of the
    # log-determinant.
    if inverse:
        # expm1(t) stays accurate for small t, where exp(t) - 1 cancels.
        z = jnp.sign(x)/alpha*jnp.expm1(alpha*jnp.abs(x))
        # The forward log-derivative at the recovered point is
        # -log1p(alpha*|z|); the inverse contributes its negation.
        log_det = jnp.log1p(alpha*jnp.abs(z)).sum()
    else:
        log_term = jnp.log1p(alpha*jnp.abs(x))
        z = jnp.sign(x)/alpha*log_term
        # dz/dx = 1 / (1 + alpha*|x|)
        log_det = -log_term.sum()
    return z, log_det
generax.flows.nonlinearities.CartesianToSpherical (BijectiveTransform)
¤
CartesianToSpherical(*args, **kwargs)
Source code in generax/flows/nonlinearities.py
class CartesianToSpherical(BijectiveTransform):
    """Convert between Cartesian and hyperspherical coordinates.

    Forward maps each length-`n` vector on the last axis to
    `[r, phi_1, ..., phi_{n-1}]`; inverse maps such a vector back.
    """

    def __init__(self,
                 input_shape: Tuple[int],
                 *,
                 key: PRNGKeyArray = None,
                 **kwargs):
        """**Arguments**:

        - `input_shape`: The input shape. Output size is the same as shape.
        - `key`: A `jax.random.PRNGKey` for initialization (never read).
        """
        super().__init__(**kwargs, input_shape=input_shape)

    def forward_fun(self, x, eps=1e-5):
        """Map a Cartesian vector `x` of length `n` to `[r, phi_1, ..., phi_{n-1}]`.

        `eps` keeps the cosines strictly inside (-1, 1) before `arccos`.
        """
        radius = jnp.linalg.norm(x)
        # tail_norms[i] = ||x[i:]|| for i = 0, ..., n - 2
        tail_norms = jnp.sqrt(jnp.cumsum(x[::-1]**2)[::-1])[:-1]
        cosines = jnp.minimum(1.0 - eps, jnp.maximum(-1.0 + eps, x[:-1]/tail_norms))
        angles = jnp.arccos(cosines)
        # The last angle spans [0, 2*pi): reflect it when x[-1] is negative.
        last_angle = jnp.where(x[-1] >= 0, angles[-1], 2*jnp.pi - angles[-1])
        angles = angles.at[-1].set(last_angle)
        return jnp.concatenate([radius[None], angles])

    def inverse_fun(self, x):
        """Map `[r, phi_1, ..., phi_{n-1}]` back to a Cartesian vector of length `n`."""
        radius, angles = x[:1], x[1:]
        ones = jnp.ones(radius.shape)
        # x_i = r * prod_{j<i} sin(phi_j) * cos(phi_i), with a trailing cos := 1
        sin_products = jnp.concatenate([ones, jnp.cumprod(jnp.sin(angles))])
        cosines = jnp.concatenate([jnp.cos(angles), ones])
        return radius*sin_products*cosines

    def __call__(self,
                 x: Array,
                 y: Optional[Array] = None,
                 inverse: bool=False,
                 **kwargs) -> Array:
        """Convert `x` along its last axis and report the log-determinant.

        **Arguments**:

        - `x`: The input to the transformation
        - `y`: The conditioning information (unused)
        - `inverse`: Whether to invert the transformation

        **Returns**:
        `(z, log_det)`
        """
        assert x.shape == self.input_shape, 'Only works on unbatched data'

        @partial(jnp.vectorize, signature='(n)->(n),()')
        def unbatched_apply(vec):
            if inverse == False:
                out = self.forward_fun(vec)
                spherical = out
            else:
                out = self.inverse_fun(vec)
                spherical = vec
            radius, angles = spherical[0], spherical[1:]
            n = vec.shape[-1]
            # Exponents n-2, n-3, ..., 0 applied to |sin(phi_k)|.
            powers = jnp.arange(n - 2, -1, -1)
            log_abs_sines = jnp.log(jnp.abs(jnp.sin(angles)))
            # Cartesian -> spherical log|det| is minus the log volume element
            # r^(n-1) * prod_k |sin(phi_k)|^(n-1-k).
            log_det = -(n - 1)*jnp.log(radius) - jnp.sum(powers*log_abs_sines, axis=-1)
            log_det = log_det.sum()
            if inverse:
                log_det = -log_det
            return out, log_det

        z, log_det = unbatched_apply(x)
        return z, log_det.sum()
__init__(self, input_shape: Tuple[int], *, key: PRNGKeyArray = None, **kwargs)
¤
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             *,
             key: PRNGKeyArray = None,
             **kwargs):
    """Set up the parameter-free Cartesian/spherical conversion.

    **Arguments**:

    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` for initialization (never read).
    """
    super().__init__(**kwargs, input_shape=input_shape)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
¤
Arguments:
- x: The input to the transformation.
- y: The conditioning information.
- inverse: Whether to invert the transformation.
Returns:
(z, log_det)
Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
    """Convert `x` along its last axis and report the log-determinant.

    **Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information (unused)
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    @partial(jnp.vectorize, signature='(n)->(n),()')
    def unbatched_apply(vec):
        if inverse == False:
            out = self.forward_fun(vec)
            spherical = out
        else:
            out = self.inverse_fun(vec)
            spherical = vec
        radius, angles = spherical[0], spherical[1:]
        n = vec.shape[-1]
        # Exponents n-2, n-3, ..., 0 applied to |sin(phi_k)|.
        powers = jnp.arange(n - 2, -1, -1)
        log_abs_sines = jnp.log(jnp.abs(jnp.sin(angles)))
        # Cartesian -> spherical log|det| is minus the log volume element
        # r^(n-1) * prod_k |sin(phi_k)|^(n-1-k).
        log_det = -(n - 1)*jnp.log(radius) - jnp.sum(powers*log_abs_sines, axis=-1)
        log_det = log_det.sum()
        if inverse:
            log_det = -log_det
        return out, log_det

    z, log_det = unbatched_apply(x)
    return z, log_det.sum()
generax.flows.logistic_cdf_mixture_logit.LogisticCDFMixtureLogit (BijectiveTransform)
¤
Used in Flow++ (https://arxiv.org/pdf/1902.00275.pdf). This is a logistic CDF mixture model followed by a logit.
Attributes:
- theta: The parameters of the transformation.
Source code in generax/flows/logistic_cdf_mixture_logit.py
class LogisticCDFMixtureLogit(BijectiveTransform):
    """Used in Flow++ https://arxiv.org/pdf/1902.00275.pdf.
    This is a logistic CDF mixture model followed by a logit, with its own
    mixture parameters for every input element.

    **Attributes**:
    - `theta`: The parameters of the transformation, stored flat: `3*K`
      values (weight logits, means, unconstrained scales) per input element.
    - `K`: The number of logistic mixture components.
    """
    theta: Array
    K: int = eqx.field(static=True)

    def __init__(self,
                 input_shape: Tuple[int],
                 K: int = 8,
                 *,
                 key: PRNGKeyArray,
                 **kwargs):
        """**Arguments**:
        - `input_shape`: The input shape. Output size is the same as shape.
        - `key`: A `jax.random.PRNGKey` for initialization
        - `K`: The number of logistic mixture components per element.
        """
        super().__init__(input_shape=input_shape,
                         **kwargs)
        self.K = K
        x_dim = util.list_prod(input_shape)
        self.theta = random.normal(key, shape=(x_dim*(3*self.K),))*0.1

    def __call__(self,
                 x: Array,
                 y: Optional[Array] = None,
                 inverse: bool = False,
                 **kwargs) -> Array:
        """**Arguments**:
        - `x`: The input to the transformation
        - `y`: The conditioning information (unused)
        - `inverse`: Whether to invert the transformation

        **Returns**:
        `(z, log_det)`
        """
        assert x.shape == self.input_shape, 'Only works on unbatched data'

        # Flatten x and give every element its own row of 3*K parameters.
        x = x.ravel()
        theta = self.theta.reshape(x.shape + (3*self.K,))

        # Split the parameters
        weight_logits = theta[..., :self.K]
        means = theta[..., self.K:2*self.K]
        scales = misc.square_plus(theta[..., 2*self.K:], gamma=1.0) + 1e-4

        def f_and_df(x, *args):
            """Return `(f(x), df/dx)` via a JVP with a unit tangent on `x` only."""
            primals = weight_logits, means, scales, x
            tangents = jax.tree_util.tree_map(jnp.zeros_like, primals[:-1]) + (jnp.ones_like(x),)
            return jax.jvp(logistic_cdf_mixture_logit, primals, tangents)

        # One truthiness test decides both the direction and the sign of the
        # log-determinant.
        if inverse:
            # No closed-form inverse: invert with bisection on [-1000, 1000].
            def f(x, *args):
                return f_and_df(x, *args)[0]
            lower = jnp.broadcast_to(-1000.0, x.shape)
            upper = jnp.broadcast_to(1000.0, x.shape)
            z = util.bisection(f, lower, upper, x)
            # The log-det is the forward derivative at the recovered input;
            # the reconstruction itself is not needed.
            _, dzdx = f_and_df(z)
            log_det = -jnp.log(dzdx).sum()
        else:
            # Only need a single pass
            z, dzdx = f_and_df(x)
            log_det = jnp.log(dzdx).sum()

        # Unflatten the output
        z = z.reshape(self.input_shape)
        return z, log_det
__init__(self, input_shape: Tuple[int], K: int = 8, *, key: PRNGKeyArray, **kwargs)
¤
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
- K: The number of logistic mixture components to use.
Source code in generax/flows/logistic_cdf_mixture_logit.py
def __init__(self,
             input_shape: Tuple[int],
             K: int = 8,
             *,
             key: PRNGKeyArray,
             **kwargs):
    """Draw the initial mixture parameters for every input element.

    **Arguments**:
    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` used to sample the initial parameters.
    - `K`: The number of logistic mixture components per element.
    """
    super().__init__(**kwargs, input_shape=input_shape)
    self.K = K
    # One weight logit, one mean and one scale per component, per element.
    num_params = 3*self.K*util.list_prod(input_shape)
    self.theta = 0.1*random.normal(key, shape=(num_params,))
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
¤
Arguments:
- x: The input to the transformation.
- y: The conditioning information.
- inverse: Whether to invert the transformation.
Returns:
(z, log_det)
Source code in generax/flows/logistic_cdf_mixture_logit.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
    """Apply the logistic-CDF-mixture-then-logit map (or its inverse).

    **Arguments**:
    - `x`: The input to the transformation
    - `y`: The conditioning information (unused)
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    # Flatten x and give every element its own row of 3*K parameters.
    x = x.ravel()
    theta = self.theta.reshape(x.shape + (3*self.K,))

    # Split the parameters
    weight_logits = theta[..., :self.K]
    means = theta[..., self.K:2*self.K]
    scales = misc.square_plus(theta[..., 2*self.K:], gamma=1.0) + 1e-4

    def f_and_df(x, *args):
        """Return `(f(x), df/dx)` via a JVP with a unit tangent on `x` only."""
        primals = weight_logits, means, scales, x
        tangents = jax.tree_util.tree_map(jnp.zeros_like, primals[:-1]) + (jnp.ones_like(x),)
        return jax.jvp(logistic_cdf_mixture_logit, primals, tangents)

    # One truthiness test decides both the direction and the sign of the
    # log-determinant.
    if inverse:
        # No closed-form inverse: invert with bisection on [-1000, 1000].
        def f(x, *args):
            return f_and_df(x, *args)[0]
        lower = jnp.broadcast_to(-1000.0, x.shape)
        upper = jnp.broadcast_to(1000.0, x.shape)
        z = util.bisection(f, lower, upper, x)
        # The log-det is the forward derivative at the recovered input; the
        # reconstruction itself is not needed.
        _, dzdx = f_and_df(z)
        log_det = -jnp.log(dzdx).sum()
    else:
        # Only need a single pass
        z, dzdx = f_and_df(x)
        log_det = jnp.log(dzdx).sum()

    # Unflatten the output
    z = z.reshape(self.input_shape)
    return z, log_det
generax.flows.spline.RationalQuadraticSpline (BijectiveTransform)
¤
Splines from https://arxiv.org/pdf/1906.04032.pdf. This is the best overall choice to use in flows.
Attributes:
- theta: The parameters of the spline.
Source code in generax/flows/spline.py
class RationalQuadraticSpline(BijectiveTransform):
    """Splines from https://arxiv.org/pdf/1906.04032.pdf. This is the best overall choice to use in flows.

    **Attributes**:
    - `theta`: The parameters of the spline, stored flat with `3*K - 1`
      values per input element.
    - `K`, `min_width`, `min_height`, `min_derivative`, `bounds`: Static
      spline settings passed to `get_knot_params`. `bounds[0]` is the
      interval used for the forward mask and `bounds[1]` the one used for
      the inverse mask.
    """
    theta: Array
    K: int = eqx.field(static=True)
    min_width: float = eqx.field(static=True)
    min_height: float = eqx.field(static=True)
    min_derivative: float = eqx.field(static=True)
    bounds: Sequence[float] = eqx.field(static=True)

    def __init__(self,
                 input_shape: Tuple[int],
                 K: int = 8,
                 min_width: Optional[float] = 1e-3,
                 min_height: Optional[float] = 1e-3,
                 min_derivative: Optional[float] = 1e-8,
                 bounds: Sequence[float] = ((-10.0, 10.0), (-10.0, 10.0)),
                 *,
                 key: PRNGKeyArray,
                 **kwargs):
        """**Arguments**:
        - `input_shape`: The input shape. Output size is the same as shape.
        - `key`: A `jax.random.PRNGKey` for initialization
        - `K`: The number of knots to use.
        - `min_width`: The minimum width of the knots.
        - `min_height`: The minimum height of the knots.
        - `min_derivative`: The minimum derivative of the knots.
        - `bounds`: The bounds of the splines.
        """
        super().__init__(input_shape=input_shape,
                         **kwargs)
        self.K = K
        self.min_width = min_width
        self.min_height = min_height
        self.min_derivative = min_derivative
        self.bounds = bounds
        x_dim = util.list_prod(input_shape)
        self.theta = random.normal(key, shape=(x_dim*(3*self.K - 1),))*0.1

    def __call__(self,
                 x: Array,
                 y: Optional[Array] = None,
                 inverse: bool = False,
                 **kwargs) -> Array:
        """**Arguments**:
        - `x`: The input to the transformation
        - `y`: The conditioning information (unused)
        - `inverse`: Whether to invert the transformation

        **Returns**:
        `(z, log_det)`
        """
        assert x.shape == self.input_shape, 'Only works on unbatched data'

        # Flatten x
        x = x.ravel()

        # self.theta is a flat vector holding 3*K - 1 parameters per element,
        # so give each element its own row. (Broadcasting the whole vector to
        # every element would hand each row x.size*(3*K - 1) values.)
        settings = self.K, self.min_width, self.min_height, self.min_derivative, self.bounds
        theta = self.theta.reshape(x.shape + (3*self.K - 1,))
        knot_x, knot_y, knot_derivs = get_knot_params(settings, theta)

        # The relevant interval and spline direction depend on `inverse`.
        if inverse:
            lower, upper = self.bounds[1]
            apply_fun = inverse_spline
        else:
            lower, upper = self.bounds[0]
            apply_fun = forward_spline
        # Flags points strictly inside the interval (with a 1e-5 margin); how
        # the remaining points are treated is decided by apply_fun.
        mask = (x > lower + 1e-5) & (x < upper - 1e-5)

        args = find_knots(x, knot_x, knot_y, knot_derivs, inverse)
        z, dzdx = apply_fun(x, mask, *args)

        log_det = jnp.log(dzdx).sum()
        if inverse:
            log_det = -log_det

        # Unflatten the output
        z = z.reshape(self.input_shape)
        return z, log_det
__init__(self, input_shape: Tuple[int], K: int = 8, min_width: Optional[float] = 0.001, min_height: Optional[float] = 0.001, min_derivative: Optional[float] = 1e-08, bounds: Sequence[float] = ((-10.0, 10.0), (-10.0, 10.0)), *, key: PRNGKeyArray, **kwargs)
¤
Arguments:
- input_shape: The input shape. Output size is the same as shape.
- key: A jax.random.PRNGKey for initialization.
- K: The number of knots to use.
- min_width: The minimum width of the knots.
- min_height: The minimum height of the knots.
- min_derivative: The minimum derivative of the knots.
- bounds: The bounds of the splines.
Source code in generax/flows/spline.py
def __init__(self,
             input_shape: Tuple[int],
             K: int = 8,
             min_width: Optional[float] = 1e-3,
             min_height: Optional[float] = 1e-3,
             min_derivative: Optional[float] = 1e-8,
             bounds: Sequence[float] = ((-10.0, 10.0), (-10.0, 10.0)),
             *,
             key: PRNGKeyArray,
             **kwargs):
    """Store the spline settings and draw the initial spline parameters.

    **Arguments**:
    - `input_shape`: The input shape. Output size is the same as shape.
    - `key`: A `jax.random.PRNGKey` used to sample the initial parameters.
    - `K`: The number of knots to use.
    - `min_width`: The minimum width of the knots.
    - `min_height`: The minimum height of the knots.
    - `min_derivative`: The minimum derivative of the knots.
    - `bounds`: The bounds of the splines.
    """
    super().__init__(**kwargs, input_shape=input_shape)
    # Static configuration.
    self.K = K
    self.min_width, self.min_height, self.min_derivative = min_width, min_height, min_derivative
    self.bounds = bounds
    # 3*K - 1 parameters per element, flattened (their layout is decided by
    # get_knot_params -- presumably widths, heights and interior derivatives).
    num_params = util.list_prod(input_shape)*(3*K - 1)
    self.theta = 0.1*random.normal(key, shape=(num_params,))
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
¤
Arguments:
- x: The input to the transformation.
- y: The conditioning information.
- inverse: Whether to invert the transformation.
Returns:
(z, log_det)
Source code in generax/flows/spline.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
    """Apply the elementwise rational-quadratic spline (or its inverse).

    **Arguments**:
    - `x`: The input to the transformation
    - `y`: The conditioning information (unused)
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    # Flatten x
    x = x.ravel()

    # self.theta is a flat vector holding 3*K - 1 parameters per element, so
    # give each element its own row. (Broadcasting the whole vector to every
    # element would hand each row x.size*(3*K - 1) values.)
    settings = self.K, self.min_width, self.min_height, self.min_derivative, self.bounds
    theta = self.theta.reshape(x.shape + (3*self.K - 1,))
    knot_x, knot_y, knot_derivs = get_knot_params(settings, theta)

    # The relevant interval and spline direction depend on `inverse`.
    if inverse:
        lower, upper = self.bounds[1]
        apply_fun = inverse_spline
    else:
        lower, upper = self.bounds[0]
        apply_fun = forward_spline
    # Flags points strictly inside the interval (with a 1e-5 margin); how the
    # remaining points are treated is decided by apply_fun.
    mask = (x > lower + 1e-5) & (x < upper - 1e-5)

    args = find_knots(x, knot_x, knot_y, knot_derivs, inverse)
    z, dzdx = apply_fun(x, mask, *args)

    log_det = jnp.log(dzdx).sum()
    if inverse:
        log_det = -log_det

    # Unflatten the output
    z = z.reshape(self.input_shape)
    return z, log_det