Nonlinearities¤

generax.flows.nonlinearities.Softplus (BijectiveTransform) ¤

Softplus(*args, **kwargs)

Source code in generax/flows/nonlinearities.py
class Softplus(BijectiveTransform):

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray = None,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape. The output has the same shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    if inverse == True:
      x = jnp.where(x < 0.0, 1e-5, x)
      dx = jnp.log1p(-jnp.exp(-x))
      z = x + dx
      log_det = dx.sum()
    else:
      z = jax.nn.softplus(x)
      log_det = jnp.log1p(-jnp.exp(-z)).sum()

    if inverse:
      log_det = -log_det

    return z, log_det
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray = None, **kwargs) ¤

Arguments:

  • input_shape: The input shape. The output has the same shape.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             key: PRNGKeyArray = None,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape. The output has the same shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  """
  super().__init__(input_shape=input_shape,
                   **kwargs)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  if inverse == True:
    x = jnp.where(x < 0.0, 1e-5, x)
    dx = jnp.log1p(-jnp.exp(-x))
    z = x + dx
    log_det = dx.sum()
  else:
    z = jax.nn.softplus(x)
    log_det = jnp.log1p(-jnp.exp(-z)).sum()

  if inverse:
    log_det = -log_det

  return z, log_det
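
A minimal usage sketch based on the constructor and __call__ signatures above (shapes and key values are illustrative):

import jax
import jax.numpy as jnp
from generax.flows.nonlinearities import Softplus

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, shape=(4,))

layer = Softplus(input_shape=(4,))
z, log_det = layer(x)                      # z = softplus(x), elementwise
x_r, log_det_inv = layer(z, inverse=True)  # recovers x; log_det_inv == -log_det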

generax.flows.nonlinearities.GaussianCDF (BijectiveTransform) ¤

GaussianCDF(*args, **kwargs)

Source code in generax/flows/nonlinearities.py
class GaussianCDF(BijectiveTransform):

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray = None,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape. The output has the same shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    if inverse == False:
      z = jax.scipy.stats.norm.cdf(x)
      log_det = jax.scipy.stats.norm.logpdf(x).sum()
    else:
      z = jax.scipy.stats.norm.ppf(x)
      log_det = jax.scipy.stats.norm.logpdf(z).sum()

    if inverse:
      log_det = -log_det

    return z, log_det
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray = None, **kwargs) ¤

Arguments:

  • input_shape: The input shape. The output has the same shape.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             key: PRNGKeyArray = None,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape. The output has the same shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  """
  super().__init__(input_shape=input_shape,
                   **kwargs)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  if inverse == False:
    z = jax.scipy.stats.norm.cdf(x)
    log_det = jax.scipy.stats.norm.logpdf(x).sum()
  else:
    z = jax.scipy.stats.norm.ppf(x)
    log_det = jax.scipy.stats.norm.logpdf(z).sum()

  if inverse:
    log_det = -log_det

  return z, log_det
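
The forward map sends each element through the standard normal CDF, so outputs lie in (0, 1); the inverse applies the quantile function. A minimal sketch (shapes illustrative):

import jax
from generax.flows.nonlinearities import GaussianCDF

layer = GaussianCDF(input_shape=(4,))
x = jax.random.normal(jax.random.PRNGKey(0), shape=(4,))
z, log_det = layer(x)            # z in (0, 1)
x_r, _ = layer(z, inverse=True)  # quantile function recovers x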

generax.flows.nonlinearities.LogisticCDF (BijectiveTransform) ¤

LogisticCDF(*args, **kwargs)

Source code in generax/flows/nonlinearities.py
class LogisticCDF(BijectiveTransform):

  def __init__(self,
               input_shape: Tuple[int],
               key: PRNGKeyArray = None,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape. The output has the same shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    if inverse == False:
      z = jax.scipy.stats.logistic.cdf(x)
      log_det = jax.scipy.stats.logistic.logpdf(x).sum()
    else:
      z = jax.scipy.stats.logistic.ppf(x)
      log_det = jax.scipy.stats.logistic.logpdf(z).sum()

    if inverse:
      log_det = -log_det

    return z, log_det
__init__(self, input_shape: Tuple[int], key: PRNGKeyArray = None, **kwargs) ¤

Arguments:

  • input_shape: The input shape. The output has the same shape.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             key: PRNGKeyArray = None,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape. The output has the same shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  """
  super().__init__(input_shape=input_shape,
                   **kwargs)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  if inverse == False:
    z = jax.scipy.stats.logistic.cdf(x)
    log_det = jax.scipy.stats.logistic.logpdf(x).sum()
  else:
    z = jax.scipy.stats.logistic.ppf(x)
    log_det = jax.scipy.stats.logistic.logpdf(z).sum()

  if inverse:
    log_det = -log_det

  return z, log_det
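
Because __call__ asserts unbatched inputs, batched data can be mapped with jax.vmap; a minimal sketch (batch size and shapes illustrative):

import jax
from generax.flows.nonlinearities import LogisticCDF

layer = LogisticCDF(input_shape=(3,))
xs = jax.random.normal(jax.random.PRNGKey(1), shape=(8, 3))  # batch of 8
zs, log_dets = jax.vmap(layer)(xs)  # zs: (8, 3), log_dets: (8,)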

generax.flows.nonlinearities.LeakyReLU (BijectiveTransform) ¤

LeakyReLU(*args, **kwargs)

Source code in generax/flows/nonlinearities.py
class LeakyReLU(BijectiveTransform):

  alpha: float

  def __init__(self,
               input_shape: Tuple[int],
               *,
               key: PRNGKeyArray = None,
               alpha: Optional[float] = 0.01,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape. The output has the same shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    - `alpha`: The slope for negative inputs.
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)
    self.alpha = alpha

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    if inverse == False:
      z = jnp.where(x > 0, x, self.alpha*x)
    else:
      z = jnp.where(x > 0, x, x/self.alpha)

    log_dx_dz = jnp.where(x > 0, 0, jnp.log(self.alpha))
    log_det = log_dx_dz.sum()

    if inverse:
      log_det = -log_det

    return z, log_det
__init__(self, input_shape: Tuple[int], *, key: PRNGKeyArray = None, alpha: Optional[float] = 0.01, **kwargs) ¤

Arguments:

  • input_shape: The input shape. The output has the same shape.
  • key: A jax.random.PRNGKey for initialization
  • alpha: The slope for negative inputs.
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             *,
             key: PRNGKeyArray = None,
             alpha: Optional[float] = 0.01,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape. The output has the same shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  - `alpha`: The slope for negative inputs.
  """
  super().__init__(input_shape=input_shape,
                   **kwargs)
  self.alpha = alpha
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  if inverse == False:
    z = jnp.where(x > 0, x, self.alpha*x)
  else:
    z = jnp.where(x > 0, x, x/self.alpha)

  log_dx_dz = jnp.where(x > 0, 0, jnp.log(self.alpha))
  log_det = log_dx_dz.sum()

  if inverse:
    log_det = -log_det

  return z, log_det
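
A small worked sketch showing the effect of alpha (values illustrative):

import jax.numpy as jnp
from generax.flows.nonlinearities import LeakyReLU

layer = LeakyReLU(input_shape=(2,), alpha=0.1)
x = jnp.array([-1.0, 2.0])
z, log_det = layer(x)            # z = [-0.1, 2.0]; log_det = log(0.1) from the negative entry
x_r, _ = layer(z, inverse=True)  # negative entries divided by alpha recover x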

generax.flows.nonlinearities.SneakyReLU (BijectiveTransform) ¤

Originally from https://invertibleworkshop.github.io/INNF_2019/accepted_papers/pdfs/INNF_2019_paper_26.pdf

Source code in generax/flows/nonlinearities.py
class SneakyReLU(BijectiveTransform):
  """ Originally from https://invertibleworkshop.github.io/INNF_2019/accepted_papers/pdfs/INNF_2019_paper_26.pdf
  """

  alpha: float

  def __init__(self,
               input_shape: Tuple[int],
               *,
               key: PRNGKeyArray = None,
               alpha: Optional[float] = 0.01,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape. The output has the same shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    - `alpha`: The slope for negative inputs; remapped internally to `(1 - alpha)/(1 + alpha)`.
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)
    # Sneaky ReLU uses a different convention
    self.alpha = (1.0 - alpha)/(1.0 + alpha)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    if inverse == False:
      sqrt_1px2 = jnp.sqrt(1 + x**2)
      z = (x + self.alpha*(sqrt_1px2 - 1))/(1 + self.alpha)
      log_det = jnp.log(1 + self.alpha*x/sqrt_1px2) - jnp.log(1 + self.alpha)
    else:
      alpha_sq = self.alpha**2
      b = (1 + self.alpha)*x + self.alpha
      z = (jnp.sqrt(alpha_sq*(1 + b**2 - alpha_sq)) - b)/(alpha_sq - 1)
      sqrt_1px2 = jnp.sqrt(1 + z**2)
      log_det = jnp.log(1 + self.alpha*z/sqrt_1px2) - jnp.log(1 + self.alpha)

    log_det = log_det.sum()

    if inverse:
      log_det = -log_det
    return z, log_det
__init__(self, input_shape: Tuple[int], *, key: PRNGKeyArray = None, alpha: Optional[float] = 0.01, **kwargs) ¤

Arguments:

  • input_shape: The input shape. The output has the same shape.
  • key: A jax.random.PRNGKey for initialization
  • alpha: The slope for negative inputs; remapped internally to (1 - alpha)/(1 + alpha).
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             *,
             key: PRNGKeyArray = None,
             alpha: Optional[float] = 0.01,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape. The output has the same shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  - `alpha`: The slope for negative inputs; remapped internally to `(1 - alpha)/(1 + alpha)`.
  """
  super().__init__(input_shape=input_shape,
                   **kwargs)
  # Sneaky ReLU uses a different convention
  self.alpha = (1.0 - alpha)/(1.0 + alpha)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  if inverse == False:
    sqrt_1px2 = jnp.sqrt(1 + x**2)
    z = (x + self.alpha*(sqrt_1px2 - 1))/(1 + self.alpha)
    log_det = jnp.log(1 + self.alpha*x/sqrt_1px2) - jnp.log(1 + self.alpha)
  else:
    alpha_sq = self.alpha**2
    b = (1 + self.alpha)*x + self.alpha
    z = (jnp.sqrt(alpha_sq*(1 + b**2 - alpha_sq)) - b)/(alpha_sq - 1)
    sqrt_1px2 = jnp.sqrt(1 + z**2)
    log_det = jnp.log(1 + self.alpha*z/sqrt_1px2) - jnp.log(1 + self.alpha)

  log_det = log_det.sum()

  if inverse:
    log_det = -log_det
  return z, log_det
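
A minimal sketch based on the signatures above; unlike LeakyReLU, the map is smooth everywhere and has an analytic inverse (values illustrative):

import jax.numpy as jnp
from generax.flows.nonlinearities import SneakyReLU

layer = SneakyReLU(input_shape=(2,), alpha=0.1)
x = jnp.array([-1.0, 2.0])
z, log_det = layer(x)            # smooth leaky-ReLU-like map
x_r, _ = layer(z, inverse=True)  # closed-form inverse, no iteration needed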

generax.flows.nonlinearities.SquarePlus (BijectiveTransform) ¤

SquarePlus(*args, **kwargs)

Source code in generax/flows/nonlinearities.py
class SquarePlus(BijectiveTransform):

  gamma: float

  def __init__(self,
               input_shape: Tuple[int],
               *,
               key: PRNGKeyArray = None,
               gamma: Optional[float] = 0.5,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape. The output has the same shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    - `gamma`: The smoothing parameter; as `gamma` approaches 0 the transform approaches ReLU.
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)
    self.gamma = gamma

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    if inverse == False:
      sqrt_arg = x**2 + 4*self.gamma
      z = 0.5*(x + jnp.sqrt(sqrt_arg))
      z = jnp.maximum(z, 0.0)
      dzdx = 0.5*(1 + x*jax.lax.rsqrt(sqrt_arg)) # Always positive
      dzdx = jnp.maximum(dzdx, 1e-5)
    else:
      z = x - self.gamma/x
      dzdx = 0.5*(1 + z*jax.lax.rsqrt(z**2 + 4*self.gamma))

    log_det = jnp.log(dzdx).sum()

    if inverse:
      log_det = -log_det

    return z, log_det
__init__(self, input_shape: Tuple[int], *, key: PRNGKeyArray = None, gamma: Optional[float] = 0.5, **kwargs) ¤

Arguments:

  • input_shape: The input shape. The output has the same shape.
  • key: A jax.random.PRNGKey for initialization
  • gamma: The smoothing parameter; as gamma approaches 0 the transform approaches ReLU.
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             *,
             key: PRNGKeyArray = None,
             gamma: Optional[float] = 0.5,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape. The output has the same shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  - `gamma`: The smoothing parameter; as `gamma` approaches 0 the transform approaches ReLU.
  """
  super().__init__(input_shape=input_shape,
                   **kwargs)
  self.gamma = gamma
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  if inverse == False:
    sqrt_arg = x**2 + 4*self.gamma
    z = 0.5*(x + jnp.sqrt(sqrt_arg))
    z = jnp.maximum(z, 0.0)
    dzdx = 0.5*(1 + x*jax.lax.rsqrt(sqrt_arg)) # Always positive
    dzdx = jnp.maximum(dzdx, 1e-5)
  else:
    z = x - self.gamma/x
    dzdx = 0.5*(1 + z*jax.lax.rsqrt(z**2 + 4*self.gamma))

  log_det = jnp.log(dzdx).sum()

  if inverse:
    log_det = -log_det

  return z, log_det
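
A minimal sketch; from the source, the forward map is 0.5*(x + sqrt(x**2 + 4*gamma)), a smooth, strictly positive surrogate for ReLU (values illustrative):

import jax.numpy as jnp
from generax.flows.nonlinearities import SquarePlus

layer = SquarePlus(input_shape=(2,), gamma=0.5)
x = jnp.array([-1.0, 2.0])
z, log_det = layer(x)            # z > 0 elementwise
x_r, _ = layer(z, inverse=True)  # inverse is x = z - gamma/z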

generax.flows.nonlinearities.SquareSigmoid (BijectiveTransform) ¤

SquareSigmoid(*args, **kwargs)

Source code in generax/flows/nonlinearities.py
class SquareSigmoid(BijectiveTransform):

  gamma: float

  def __init__(self,
               input_shape: Tuple[int],
               *,
               key: PRNGKeyArray = None,
               gamma: Optional[float] = 0.5,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape. The output has the same shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    - `gamma`: The smoothing parameter; smaller values give a sharper sigmoid.
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)
    self.gamma = gamma

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    if inverse == False:
      rsqrt = jax.lax.rsqrt(x**2 + 4*self.gamma)
      z = 0.5*(1 + x*rsqrt)
    else:
      arg = 2*x - 1
      z = 2*jnp.sqrt(self.gamma)*arg*jax.lax.rsqrt(1 - arg**2)
      rsqrt = jax.lax.rsqrt(z**2 + 4*self.gamma)

    dzdx = 2*self.gamma*rsqrt**3
    log_det = jnp.log(dzdx).sum()

    if inverse:
      log_det = -log_det

    return z, log_det
__init__(self, input_shape: Tuple[int], *, key: PRNGKeyArray = None, gamma: Optional[float] = 0.5, **kwargs) ¤

Arguments:

  • input_shape: The input shape. The output has the same shape.
  • key: A jax.random.PRNGKey for initialization
  • gamma: The smoothing parameter; smaller values give a sharper sigmoid.
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             *,
             key: PRNGKeyArray = None,
             gamma: Optional[float] = 0.5,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape. The output has the same shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  - `gamma`: The smoothing parameter; smaller values give a sharper sigmoid.
  """
  super().__init__(input_shape=input_shape,
                   **kwargs)
  self.gamma = gamma
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  if inverse == False:
    rsqrt = jax.lax.rsqrt(x**2 + 4*self.gamma)
    z = 0.5*(1 + x*rsqrt)
  else:
    arg = 2*x - 1
    z = 2*jnp.sqrt(self.gamma)*arg*jax.lax.rsqrt(1 - arg**2)
    rsqrt = jax.lax.rsqrt(z**2 + 4*self.gamma)

  dzdx = 2*self.gamma*rsqrt**3
  log_det = jnp.log(dzdx).sum()

  if inverse:
    log_det = -log_det

  return z, log_det
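
A minimal sketch; from the source, the forward map 0.5*(1 + x/sqrt(x**2 + 4*gamma)) squashes each element into (0, 1) (values illustrative):

import jax.numpy as jnp
from generax.flows.nonlinearities import SquareSigmoid

layer = SquareSigmoid(input_shape=(2,), gamma=0.5)
x = jnp.array([-1.0, 2.0])
z, log_det = layer(x)            # z in (0, 1)
x_r, _ = layer(z, inverse=True)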

generax.flows.nonlinearities.SLog (BijectiveTransform) ¤

https://papers.nips.cc/paper/2019/file/b1f62fa99de9f27a048344d55c5ef7a6-Paper.pdf

Source code in generax/flows/nonlinearities.py
class SLog(BijectiveTransform):
  """ https://papers.nips.cc/paper/2019/file/b1f62fa99de9f27a048344d55c5ef7a6-Paper.pdf
  """

  alpha: Union[float,None]

  def __init__(self,
               input_shape: Tuple[int],
               *,
               key: PRNGKeyArray = None,
               alpha: Optional[float] = 0.0,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape. The output has the same shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    - `alpha`: Unconstrained parameter; it is mapped to a strictly positive value inside `__call__`.
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)
    self.alpha = alpha

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    # Bound alpha to be positive
    alpha = misc.square_plus(self.alpha) + 1e-4

    if inverse == False:
      log_det = jnp.log1p(alpha*jnp.abs(x))
      z = jnp.sign(x)/alpha*log_det
    else:
      z = jnp.sign(x)/alpha*(jnp.exp(alpha*jnp.abs(x)) - 1)
      log_det = jnp.log1p(alpha*jnp.abs(z))

    log_det = -log_det.sum()

    if inverse:
      log_det = -log_det
    return z, log_det
__init__(self, input_shape: Tuple[int], *, key: PRNGKeyArray = None, alpha: Optional[float] = 0.0, **kwargs) ¤

Arguments:

  • input_shape: The input shape. The output has the same shape.
  • key: A jax.random.PRNGKey for initialization
  • alpha: Unconstrained parameter; it is mapped to a strictly positive value inside __call__.
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             *,
             key: PRNGKeyArray = None,
             alpha: Optional[float] = 0.0,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape. The output has the same shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  - `alpha`: Unconstrained parameter; it is mapped to a strictly positive value inside `__call__`.
  """
  super().__init__(input_shape=input_shape,
                   **kwargs)
  self.alpha = alpha
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  # Bound alpha to be positive
  alpha = misc.square_plus(self.alpha) + 1e-4

  if inverse == False:
    log_det = jnp.log1p(alpha*jnp.abs(x))
    z = jnp.sign(x)/alpha*log_det
  else:
    z = jnp.sign(x)/alpha*(jnp.exp(alpha*jnp.abs(x)) - 1)
    log_det = jnp.log1p(alpha*jnp.abs(z))

  log_det = -log_det.sum()

  if inverse:
    log_det = -log_det
  return z, log_det
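
A minimal sketch based on the signatures above; note from the source that alpha is passed through misc.square_plus inside __call__, so the effective parameter is always positive (values illustrative):

import jax.numpy as jnp
from generax.flows.nonlinearities import SLog

layer = SLog(input_shape=(2,), alpha=0.0)
x = jnp.array([-3.0, 5.0])
z, log_det = layer(x)            # signed logarithmic squashing of each element
x_r, _ = layer(z, inverse=True)  # signed exponential map recovers x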

generax.flows.nonlinearities.CartesianToSpherical (BijectiveTransform) ¤

CartesianToSpherical(*args, **kwargs)

Source code in generax/flows/nonlinearities.py
class CartesianToSpherical(BijectiveTransform):

  def __init__(self,
               input_shape: Tuple[int],
               *,
               key: PRNGKeyArray = None,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape. The output has the same shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)

  def forward_fun(self, x, eps=1e-5):
    r = jnp.linalg.norm(x)
    denominators = jnp.sqrt(jnp.cumsum(x[::-1]**2)[::-1])[:-1]
    cos_phi = x[:-1]/denominators

    cos_phi = jnp.maximum(-1.0 + eps, cos_phi)
    cos_phi = jnp.minimum(1.0 - eps, cos_phi)
    phi = jnp.arccos(cos_phi)

    last_value = jnp.where(x[-1] >= 0, phi[-1], 2*jnp.pi - phi[-1])
    phi = phi.at[-1].set(last_value)

    return jnp.concatenate([r[None], phi])

  def inverse_fun(self, x):
    r = x[:1]
    phi = x[1:]
    sin_prod = jnp.cumprod(jnp.sin(phi))
    first_part = jnp.concatenate([jnp.ones(r.shape), sin_prod])
    second_part = jnp.concatenate([jnp.cos(phi), jnp.ones(r.shape)])
    return r*first_part*second_part

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    @partial(jnp.vectorize, signature='(n)->(n),()')
    def unbatched_apply(x):
      def _forward(x):
        z = self.forward_fun(x)
        r, phi = z[0], z[1:]
        return z, r, phi

      def _inverse(x):
        z = self.inverse_fun(x)
        r, phi = x[0], x[1:]
        return z, r, phi

      if inverse == False:
        z, r, phi = _forward(x)
      else:
        z, r, phi = _inverse(x)

      n = x.shape[-1]
      n_range = jnp.arange(n - 2, -1, -1)
      log_abs_sin_phi = jnp.log(jnp.abs(jnp.sin(phi)))
      log_det = -(n - 1)*jnp.log(r) - jnp.sum(n_range*log_abs_sin_phi, axis=-1)
      log_det = log_det.sum()

      if inverse:
        log_det = -log_det
      return z, log_det

    z, log_det = unbatched_apply(x)
    return z, log_det.sum()
__init__(self, input_shape: Tuple[int], *, key: PRNGKeyArray = None, **kwargs) ¤

Arguments:

  • input_shape: The input shape. The output has the same shape.
  • key: A jax.random.PRNGKey for initialization
Source code in generax/flows/nonlinearities.py
def __init__(self,
             input_shape: Tuple[int],
             *,
             key: PRNGKeyArray = None,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape. The output has the same shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  """
  super().__init__(input_shape=input_shape,
                   **kwargs)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/nonlinearities.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  @partial(jnp.vectorize, signature='(n)->(n),()')
  def unbatched_apply(x):
    def _forward(x):
      z = self.forward_fun(x)
      r, phi = z[0], z[1:]
      return z, r, phi

    def _inverse(x):
      z = self.inverse_fun(x)
      r, phi = x[0], x[1:]
      return z, r, phi

    if inverse == False:
      z, r, phi = _forward(x)
    else:
      z, r, phi = _inverse(x)

    n = x.shape[-1]
    n_range = jnp.arange(n - 2, -1, -1)
    log_abs_sin_phi = jnp.log(jnp.abs(jnp.sin(phi)))
    log_det = -(n - 1)*jnp.log(r) - jnp.sum(n_range*log_abs_sin_phi, axis=-1)
    log_det = log_det.sum()

    if inverse:
      log_det = -log_det
    return z, log_det

  z, log_det = unbatched_apply(x)
  return z, log_det.sum()
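
A minimal sketch based on the signatures above; the forward map sends an n-dimensional Cartesian vector to radius and angles [r, phi_1, ..., phi_{n-1}] (values illustrative):

import jax.numpy as jnp
from generax.flows.nonlinearities import CartesianToSpherical

layer = CartesianToSpherical(input_shape=(3,))
x = jnp.array([1.0, 1.0, 1.0])
z, log_det = layer(x)            # z = [r, phi_1, phi_2]
x_r, _ = layer(z, inverse=True)  # back to Cartesian coordinates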

generax.flows.logistic_cdf_mixture_logit.LogisticCDFMixtureLogit (BijectiveTransform) ¤

Used in Flow++ (https://arxiv.org/pdf/1902.00275.pdf). This is a logistic CDF mixture model followed by a logit.

Attributes:

  • theta: The parameters of the transformation.

Source code in generax/flows/logistic_cdf_mixture_logit.py
class LogisticCDFMixtureLogit(BijectiveTransform):
  """Used in Flow++ https://arxiv.org/pdf/1902.00275.pdf
  This is a logistic CDF mixture model followed by a logit.

  **Attributes**:
  - `theta`: The parameters of the transformation.
  """

  theta: Array

  K: int = eqx.field(static=True)

  def __init__(self,
               input_shape: Tuple[int],
               K: int = 8,
               *,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape. The output has the same shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    - `K`: The number of components in the logistic mixture.
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)
    self.K = K

    x_dim = util.list_prod(input_shape)
    self.theta = random.normal(key, shape=(x_dim*(3*self.K),))*0.1

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    # Flatten x
    x = x.ravel()

    theta = self.theta.reshape(x.shape + (3*self.K,))

    # Split the parameters
    weight_logits, means, scales = theta[...,:self.K], theta[...,self.K:2*self.K], theta[...,2*self.K:]
    scales = misc.square_plus(scales, gamma=1.0) + 1e-4

    # Create the jvp function that we'll need
    def f_and_df(x, *args):
      primals = weight_logits, means, scales, x
      tangents = jax.tree_util.tree_map(jnp.zeros_like, primals[:-1]) + (jnp.ones_like(x),)
      return jax.jvp(logistic_cdf_mixture_logit, primals, tangents)

    if inverse == False:
      # Only need a single pass
      z, dzdx = f_and_df(x)
    else:
      # Invert with bisection method.
      f = lambda x, *args: f_and_df(x, *args)[0]
      lower, upper = -1000.0, 1000.0
      lower, upper = jnp.broadcast_to(lower, x.shape), jnp.broadcast_to(upper, x.shape)
      z = util.bisection(f, lower, upper, x)
      reconstr, dzdx = f_and_df(z)
    ew_log_det = jnp.log(dzdx)

    log_det = ew_log_det.sum()

    if inverse:
      log_det *= -1

    # Unflatten the output
    z = z.reshape(self.input_shape)
    return z, log_det
__init__(self, input_shape: Tuple[int], K: int = 8, *, key: PRNGKeyArray, **kwargs) ¤

Arguments:

  • input_shape: The input shape. The output has the same shape.
  • key: A jax.random.PRNGKey for initialization
  • K: The number of components in the logistic mixture.
Source code in generax/flows/logistic_cdf_mixture_logit.py
def __init__(self,
             input_shape: Tuple[int],
             K: int = 8,
             *,
             key: PRNGKeyArray,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape. The output has the same shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  - `K`: The number of components in the logistic mixture.
  """
  super().__init__(input_shape=input_shape,
                   **kwargs)
  self.K = K

  x_dim = util.list_prod(input_shape)
  self.theta = random.normal(key, shape=(x_dim*(3*self.K),))*0.1
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/logistic_cdf_mixture_logit.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'
  # Flatten x
  x = x.ravel()

  theta = self.theta.reshape(x.shape + (3*self.K,))

  # Split the parameters
  weight_logits, means, scales = theta[...,:self.K], theta[...,self.K:2*self.K], theta[...,2*self.K:]
  scales = misc.square_plus(scales, gamma=1.0) + 1e-4

  # Create the jvp function that we'll need
  def f_and_df(x, *args):
    primals = weight_logits, means, scales, x
    tangents = jax.tree_util.tree_map(jnp.zeros_like, primals[:-1]) + (jnp.ones_like(x),)
    return jax.jvp(logistic_cdf_mixture_logit, primals, tangents)

  if inverse == False:
    # Only need a single pass
    z, dzdx = f_and_df(x)
  else:
    # Invert with bisection method.
    f = lambda x, *args: f_and_df(x, *args)[0]
    lower, upper = -1000.0, 1000.0
    lower, upper = jnp.broadcast_to(lower, x.shape), jnp.broadcast_to(upper, x.shape)
    z = util.bisection(f, lower, upper, x)
    reconstr, dzdx = f_and_df(z)
  ew_log_det = jnp.log(dzdx)

  log_det = ew_log_det.sum()

  if inverse:
    log_det *= -1

  # Unflatten the output
  z = z.reshape(self.input_shape)
  return z, log_det
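
A minimal sketch based on the signatures above; key is required (keyword-only) because theta is randomly initialized, and the inverse is computed numerically by bisection:

import jax
from generax.flows.logistic_cdf_mixture_logit import LogisticCDFMixtureLogit

key = jax.random.PRNGKey(0)
layer = LogisticCDFMixtureLogit(input_shape=(4,), K=8, key=key)
x = jax.random.normal(jax.random.PRNGKey(1), shape=(4,))
z, log_det = layer(x)            # single forward evaluation
x_r, _ = layer(z, inverse=True)  # bisection search over [-1000, 1000]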

generax.flows.spline.RationalQuadraticSpline (BijectiveTransform) ¤

Splines from https://arxiv.org/pdf/1906.04032.pdf. This is the best overall choice to use in flows.

Attributes:

  • theta: The parameters of the spline.

Source code in generax/flows/spline.py
class RationalQuadraticSpline(BijectiveTransform):
  """Splines from https://arxiv.org/pdf/1906.04032.pdf.  This is the best overall choice to use in flows.


  **Attributes**:
  - `theta`: The parameters of the spline.
  """

  theta: Array

  K: int = eqx.field(static=True)
  min_width: float = eqx.field(static=True)
  min_height: float = eqx.field(static=True)
  min_derivative: float = eqx.field(static=True)
  bounds: Sequence[float] = eqx.field(static=True)

  def __init__(self,
               input_shape: Tuple[int],
               K: int = 8,
               min_width: Optional[float] = 1e-3,
               min_height: Optional[float] = 1e-3,
               min_derivative: Optional[float] = 1e-8,
               bounds: Sequence[float] = ((-10.0, 10.0), (-10.0, 10.0)),
               *,
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The input shape. The output has the same shape.
    - `key`: A `jax.random.PRNGKey` for initialization
    - `K`: The number of knots to use.
    - `min_width`: The minimum bin width between knots.
    - `min_height`: The minimum bin height between knots.
    - `min_derivative`: The minimum derivative at the knots.
    - `bounds`: The input and output bounds of the spline.
    """
    super().__init__(input_shape=input_shape,
                     **kwargs)
    self.K = K
    self.min_width = min_width
    self.min_height = min_height
    self.min_derivative = min_derivative
    self.bounds = bounds

    x_dim = util.list_prod(input_shape)
    self.theta = random.normal(key, shape=(x_dim*(3*self.K - 1),))*0.1

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool=False,
               **kwargs) -> Array:
    """**Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to invert the transformation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'

    # Flatten x
    x = x.ravel()

    # Get the parameters
    settings = self.K, self.min_width, self.min_height, self.min_derivative, self.bounds
    theta = jnp.broadcast_to(self.theta, x.shape + self.theta.shape)
    knot_x, knot_y, knot_derivs = get_knot_params(settings, theta)

    # The relevant knot depends on if we are inverting or not
    if inverse == False:
      mask = (x > self.bounds[0][0] + 1e-5) & (x < self.bounds[0][1] - 1e-5)
      apply_fun = forward_spline
    else:
      mask = (x > self.bounds[1][0] + 1e-5) & (x < self.bounds[1][1] - 1e-5)
      apply_fun = inverse_spline

    args = find_knots(x, knot_x, knot_y, knot_derivs, inverse)

    z, dzdx = apply_fun(x, mask, *args)
    elementwise_log_det = jnp.log(dzdx)

    log_det = elementwise_log_det.sum()
    if inverse:
      log_det = -log_det

    # Unflatten the output
    z = z.reshape(self.input_shape)

    return z, log_det
__init__(self, input_shape: Tuple[int], K: int = 8, min_width: Optional[float] = 0.001, min_height: Optional[float] = 0.001, min_derivative: Optional[float] = 1e-08, bounds: Sequence[float] = ((-10.0, 10.0), (-10.0, 10.0)), *, key: PRNGKeyArray, **kwargs) ¤

Arguments:

  • input_shape: The input shape. The output has the same shape.
  • key: A jax.random.PRNGKey for initialization
  • K: The number of knots to use.
  • min_width: The minimum bin width between knots.
  • min_height: The minimum bin height between knots.
  • min_derivative: The minimum derivative at the knots.
  • bounds: The input and output bounds of the spline.
Source code in generax/flows/spline.py
def __init__(self,
             input_shape: Tuple[int],
             K: int = 8,
             min_width: Optional[float] = 1e-3,
             min_height: Optional[float] = 1e-3,
             min_derivative: Optional[float] = 1e-8,
             bounds: Sequence[float] = ((-10.0, 10.0), (-10.0, 10.0)),
             *,
             key: PRNGKeyArray,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The input shape. The output has the same shape.
  - `key`: A `jax.random.PRNGKey` for initialization
  - `K`: The number of knots to use.
  - `min_width`: The minimum bin width between knots.
  - `min_height`: The minimum bin height between knots.
  - `min_derivative`: The minimum derivative at the knots.
  - `bounds`: The input and output bounds of the spline.
  """
  super().__init__(input_shape=input_shape,
                   **kwargs)
  self.K = K
  self.min_width = min_width
  self.min_height = min_height
  self.min_derivative = min_derivative
  self.bounds = bounds

  x_dim = util.list_prod(input_shape)
  self.theta = random.normal(key, shape=(x_dim*(3*self.K - 1),))*0.1
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/spline.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool=False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'

  # Flatten x
  x = x.ravel()

  # Get the parameters
  settings = self.K, self.min_width, self.min_height, self.min_derivative, self.bounds
  theta = jnp.broadcast_to(self.theta, x.shape + self.theta.shape)
  knot_x, knot_y, knot_derivs = get_knot_params(settings, theta)

  # The relevant knot depends on if we are inverting or not
  if inverse == False:
    mask = (x > self.bounds[0][0] + 1e-5) & (x < self.bounds[0][1] - 1e-5)
    apply_fun = forward_spline
  else:
    mask = (x > self.bounds[1][0] + 1e-5) & (x < self.bounds[1][1] - 1e-5)
    apply_fun = inverse_spline

  args = find_knots(x, knot_x, knot_y, knot_derivs, inverse)

  z, dzdx = apply_fun(x, mask, *args)
  elementwise_log_det = jnp.log(dzdx)

  log_det = elementwise_log_det.sum()
  if inverse:
    log_det = -log_det

  # Unflatten the output
  z = z.reshape(self.input_shape)

  return z, log_det
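
A minimal sketch based on the signatures above; key is required (keyword-only) and initializes the spline parameters theta (shapes illustrative):

import jax
from generax.flows.spline import RationalQuadraticSpline

key = jax.random.PRNGKey(0)
layer = RationalQuadraticSpline(input_shape=(4,), K=8, key=key)
x = jax.random.normal(jax.random.PRNGKey(1), shape=(4,))
z, log_det = layer(x)            # forward spline evaluation
x_r, _ = layer(z, inverse=True)  # analytic inverse of the rational-quadratic spline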