Models¤

generax.flows.models.GeneralTransform (Repeat) ¤

GeneralTransform(*args, **kwargs)

Source code in generax/flows/models.py
class GeneralTransform(Repeat):

  def __init__(self,
               TransformType: type,
               input_shape: Tuple[int],
               n_flow_layers: int = 3,
               working_size: int = 16,
               hidden_size: int = 32,
               n_blocks: int = 4,
               filter_shape: Optional[Tuple[int]] = (3, 3),
               cond_shape: Optional[Tuple[int]] = None,
               coupling_split_dim: Optional[int] = None,
               reverse_conditioning: Optional[bool] = False,
               create_net: Optional[Callable[[PRNGKeyArray], Any]] = None,
               *,
               key: PRNGKeyArray,
               **kwargs):

    def init_transform(transform_input_shape, key):
      return TransformType(input_shape=transform_input_shape,
                           cond_shape=cond_shape,
                           key=key)

    def _create_net(net_input_shape, net_output_size, key):
      return ResNet(input_shape=net_input_shape,
                    working_size=working_size,
                    hidden_size=hidden_size,
                    out_size=net_output_size,
                    n_blocks=n_blocks,
                    filter_shape=filter_shape,
                    cond_shape=cond_shape,
                    key=key)
    create_net = create_net if create_net is not None else _create_net

    def make_single_flow_layer(key: PRNGKeyArray) -> Sequential:
      k1, k2, k3 = random.split(key, 3)

      layers = []
      layer = Coupling(init_transform,
                       create_net,
                       input_shape=input_shape,
                       cond_shape=cond_shape,
                       split_dim=coupling_split_dim,
                       reverse_conditioning=reverse_conditioning,
                       key=k1)
      layers.append(layer)
      layers.append(PLUAffine(input_shape=input_shape,
                              cond_shape=cond_shape,
                              key=k2))
      layers.append(ShiftScale(input_shape=input_shape,
                               cond_shape=cond_shape,
                               key=k3))
      return Sequential(*layers, **kwargs)

    super().__init__(make_single_flow_layer, n_flow_layers, key=key)
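
A minimal construction sketch (not taken from the library's own examples; the ShiftScale import path and the unbatched-input convention are assumptions):

import jax.numpy as jnp
from jax import random
from generax.flows.models import GeneralTransform
from generax.flows import ShiftScale  # import path assumed

key = random.PRNGKey(0)
# Each of the n_flow_layers blocks is Coupling -> PLUAffine -> ShiftScale.
flow = GeneralTransform(TransformType=ShiftScale,
                        input_shape=(2,),  # unbatched event shape
                        n_flow_layers=3,
                        key=key)
z, log_det = flow(jnp.zeros((2,)))
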
__init__(self, TransformType: type, input_shape: Tuple[int], n_flow_layers: int = 3, working_size: int = 16, hidden_size: int = 32, n_blocks: int = 4, filter_shape: Optional[Tuple[int]] = (3, 3), cond_shape: Optional[Tuple[int]] = None, coupling_split_dim: Optional[int] = None, reverse_conditioning: Optional[bool] = False, create_net: Optional[Callable[[PRNGKeyArray], Any]] = None, *, key: PRNGKeyArray, **kwargs) ¤
Source code in generax/flows/models.py
def __init__(self,
             TransformType: type,
             input_shape: Tuple[int],
             n_flow_layers: int = 3,
             working_size: int = 16,
             hidden_size: int = 32,
             n_blocks: int = 4,
             filter_shape: Optional[Tuple[int]] = (3, 3),
             cond_shape: Optional[Tuple[int]] = None,
             coupling_split_dim: Optional[int] = None,
             reverse_conditioning: Optional[bool] = False,
             create_net: Optional[Callable[[PRNGKeyArray], Any]] = None,
             *,
             key: PRNGKeyArray,
             **kwargs):

  def init_transform(transform_input_shape, key):
    return TransformType(input_shape=transform_input_shape,
                         cond_shape=cond_shape,
                         key=key)

  def _create_net(net_input_shape, net_output_size, key):
    return ResNet(input_shape=net_input_shape,
                  working_size=working_size,
                  hidden_size=hidden_size,
                  out_size=net_output_size,
                  n_blocks=n_blocks,
                  filter_shape=filter_shape,
                  cond_shape=cond_shape,
                  key=key)
  create_net = create_net if create_net is not None else _create_net

  def make_single_flow_layer(key: PRNGKeyArray) -> Sequential:
    k1, k2, k3 = random.split(key, 3)

    layers = []
    layer = Coupling(init_transform,
                     create_net,
                     input_shape=input_shape,
                     cond_shape=cond_shape,
                     split_dim=coupling_split_dim,
                     reverse_conditioning=reverse_conditioning,
                     key=k1)
    layers.append(layer)
    layers.append(PLUAffine(input_shape=input_shape,
                            cond_shape=cond_shape,
                            key=k2))
    layers.append(ShiftScale(input_shape=input_shape,
                             cond_shape=cond_shape,
                             key=k3))
    return Sequential(*layers, **kwargs)

  super().__init__(make_single_flow_layer, n_flow_layers, key=key)
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Apply the inverse transformation.

Arguments:

  • x: The input to the transformation
  • y: The conditioning information

Returns: (z, log_det)

Source code in generax/flows/models.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)
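
Continuing the sketch above, a round-trip sanity check: inverse simply calls __call__ with inverse=True, so the log-determinants should negate (tolerances are illustrative).

import jax.numpy as jnp

x = jnp.ones((2,))
z, log_det = flow(x)                  # forward
x_rec, log_det_inv = flow.inverse(z)  # inverse
assert jnp.allclose(x_rec, x, atol=1e-4)
assert jnp.allclose(log_det_inv, -log_det, atol=1e-4)
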
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform ¤

Inherited from generax.flows.base.Repeat.data_dependent_init.

Source code in generax/flows/models.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
  seq = self.to_sequential()

  # Apply the data-dependent initialization
  out_seq_layers = seq.data_dependent_init(x, y=y, key=key)

  # Turn the sequential layers into a repeat layer
  all_params = []
  for i, layer in enumerate(out_seq_layers):
    params, _ = eqx.partition(layer, eqx.is_array)
    all_params.append(params)

  # Combine the parameters back into a single layer
  params = jax.tree_util.tree_map(lambda *args: jnp.array(args), *all_params)
  _, static = eqx.partition(self.layers, eqx.is_array)
  layers = eqx.combine(params, static)

  get_layers = lambda tree: tree.layers
  updated_module = eqx.tree_at(get_layers, self, layers)
  return updated_module
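
Usage sketch, continuing from above: the module is immutable (Equinox), so data_dependent_init returns a new flow that must be rebound. Passing a batch of data here is an assumption based on the docstring ("the data to initialize the parameters with").

from jax import random

x_batch = random.normal(random.PRNGKey(1), (128, 2))  # a batch of training data
flow = flow.data_dependent_init(x_batch, key=random.PRNGKey(2))
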
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Inherited from generax.flows.base.Repeat.__call__.

Source code in generax/flows/models.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  (z, log_det)
  """
  dynamic, static = eqx.partition(self.layers, eqx.is_array)

  def scan_body(x, params):
    block = eqx.combine(params, static)
    x, log_det = block(x, y=y, inverse=inverse, **kwargs)
    return x, log_det

  x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
  return x, log_dets.sum()
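
Note that __call__ stacks the per-layer parameters and runs them with jax.lax.scan, reversing the layer order when inverse=True. Assuming the unbatched-input convention above, a batch can be mapped with jax.vmap:

import jax

z_batch, log_dets = jax.vmap(flow)(x_batch)  # one (z, log_det) pair per example
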

generax.flows.models.NICETransform (GeneralTransform) ¤

NICETransform(*args, **kwargs)

Source code in generax/flows/models.py
class NICETransform(GeneralTransform):
  def __init__(self,
               *args,
               **kwargs):
    super().__init__(TransformType=Shift,
                     *args,
                     **kwargs)
__init__(self, *args, **kwargs) ¤
Source code in generax/flows/models.py
def __init__(self,
             *args,
             **kwargs):
  super().__init__(TransformType=Shift,
                   *args,
                   **kwargs)
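
A construction sketch. Shift couplings are additive and contribute zero log-determinant; the interleaved PLUAffine and ShiftScale layers still carry the volume change.

import jax.numpy as jnp
from jax import random
from generax.flows.models import NICETransform

flow = NICETransform(input_shape=(2,), n_flow_layers=4, key=random.PRNGKey(0))
z, log_det = flow(jnp.zeros((2,)))
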
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Apply the inverse transformation.

Arguments:

  • x: The input to the transformation
  • y: The conditioning information

Returns: (z, log_det)

Source code in generax/flows/models.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/flows/models.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
  seq = self.to_sequential()

  # Apply the data-dependent initialization
  out_seq_layers = seq.data_dependent_init(x, y=y, key=key)

  # Turn the sequential layers into a repeat layer
  all_params = []
  for i, layer in enumerate(out_seq_layers):
    params, _ = eqx.partition(layer, eqx.is_array)
    all_params.append(params)

  # Combine the parameters back into a single layer
  params = jax.tree_util.tree_map(lambda *args: jnp.array(args), *all_params)
  _, static = eqx.partition(self.layers, eqx.is_array)
  layers = eqx.combine(params, static)

  get_layers = lambda tree: tree.layers
  updated_module = eqx.tree_at(get_layers, self, layers)
  return updated_module
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/models.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  (z, log_det)
  """
  dynamic, static = eqx.partition(self.layers, eqx.is_array)

  def scan_body(x, params):
    block = eqx.combine(params, static)
    x, log_det = block(x, y=y, inverse=inverse, **kwargs)
    return x, log_det

  x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
  return x, log_dets.sum()

generax.flows.models.RealNVPTransform (GeneralTransform) ¤

RealNVPTransform(*args, **kwargs)

Source code in generax/flows/models.py
class RealNVPTransform(GeneralTransform):
  def __init__(self,
               *args,
               **kwargs):
    super().__init__(TransformType=ShiftScale,
                     *args,
                     **kwargs)
__init__(self, *args, **kwargs) ¤
Source code in generax/flows/models.py
def __init__(self,
             *args,
             **kwargs):
  super().__init__(TransformType=ShiftScale,
                   *args,
                   **kwargs)
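
A conditional construction sketch; passing y with a matching cond_shape at call time is an assumption consistent with the signatures above.

import jax.numpy as jnp
from jax import random
from generax.flows.models import RealNVPTransform

flow = RealNVPTransform(input_shape=(2,),
                        cond_shape=(5,),  # shape of the conditioning vector
                        key=random.PRNGKey(0))
y = jnp.ones((5,))
z, log_det = flow(jnp.zeros((2,)), y=y)
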
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Apply the inverse transformation.

Arguments:

  • x: The input to the transformation
  • y: The conditioning information

Returns: (z, log_det)

Source code in generax/flows/models.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/flows/models.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
  seq = self.to_sequential()

  # Apply the data-dependent initialization
  out_seq_layers = seq.data_dependent_init(x, y=y, key=key)

  # Turn the sequential layers into a repeat layer
  all_params = []
  for i, layer in enumerate(out_seq_layers):
    params, _ = eqx.partition(layer, eqx.is_array)
    all_params.append(params)

  # Combine the parameters back into a single layer
  params = jax.tree_util.tree_map(lambda *args: jnp.array(args), *all_params)
  _, static = eqx.partition(self.layers, eqx.is_array)
  layers = eqx.combine(params, static)

  get_layers = lambda tree: tree.layers
  updated_module = eqx.tree_at(get_layers, self, layers)
  return updated_module
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/models.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  (z, log_det)
  """
  dynamic, static = eqx.partition(self.layers, eqx.is_array)

  def scan_body(x, params):
    block = eqx.combine(params, static)
    x, log_det = block(x, y=y, inverse=inverse, **kwargs)
    return x, log_det

  x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
  return x, log_dets.sum()

generax.flows.models.NeuralSplineTransform (GeneralTransform) ¤

NeuralSplineTransform(*args, **kwargs)

Source code in generax/flows/models.py
class NeuralSplineTransform(GeneralTransform):
  def __init__(self,
               *args,
               n_spline_knots: int = 8,
               **kwargs):
    super().__init__(TransformType=partial(RationalQuadraticSpline, K=n_spline_knots),
                     *args,
                     **kwargs)
__init__(self, *args, n_spline_knots: int = 8, **kwargs) ¤
Source code in generax/flows/models.py
def __init__(self,
             *args,
             n_spline_knots: int = 8,
             **kwargs):
  super().__init__(TransformType=partial(RationalQuadraticSpline, K=n_spline_knots),
                   *args,
                   **kwargs)
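
Construction sketch: n_spline_knots is forwarded as K to each RationalQuadraticSpline coupling.

from jax import random
from generax.flows.models import NeuralSplineTransform

flow = NeuralSplineTransform(input_shape=(2,),
                             n_spline_knots=16,  # K knots per spline
                             key=random.PRNGKey(0))
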
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Apply the inverse transformation.

Arguments:

  • x: The input to the transformation
  • y: The conditioning information

Returns: (z, log_det)

Source code in generax/flows/models.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/flows/models.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
  seq = self.to_sequential()

  # Apply the data-dependent initialization
  out_seq_layers = seq.data_dependent_init(x, y=y, key=key)

  # Turn the sequential layers into a repeat layer
  all_params = []
  for i, layer in enumerate(out_seq_layers):
    params, _ = eqx.partition(layer, eqx.is_array)
    all_params.append(params)

  # Combine the parameters back into a single layer
  params = jax.tree_util.tree_map(lambda *args: jnp.array(args), *all_params)
  _, static = eqx.partition(self.layers, eqx.is_array)
  layers = eqx.combine(params, static)

  get_layers = lambda tree: tree.layers
  updated_module = eqx.tree_at(get_layers, self, layers)
  return updated_module
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/models.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  (z, log_det)
  """
  dynamic, static = eqx.partition(self.layers, eqx.is_array)

  def scan_body(x, params):
    block = eqx.combine(params, static)
    x, log_det = block(x, y=y, inverse=inverse, **kwargs)
    return x, log_det

  x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
  return x, log_dets.sum()

generax.flows.models.GeneralImageTransform (Repeat) ¤

GeneralImageTransform(*args, **kwargs)

Source code in generax/flows/models.py
class GeneralImageTransform(Repeat):

  def __init__(self,
               TransformType: type,
               input_shape: Tuple[int],
               n_flow_layers: int = 3,
               cond_shape: Optional[Tuple[int]] = None,
               unet: Optional[bool] = True,
               coupling_split_dim: Optional[int] = None,
               reverse_conditioning: Optional[bool] = False,
               *,
               key: PRNGKeyArray,
               **kwargs):

    def init_transform(transform_input_shape, key):
      return TransformType(input_shape=transform_input_shape,
                           cond_shape=cond_shape,
                           key=key)

    if unet:
      def create_net(net_input_shape, net_output_size, key):
        H, W, C = net_input_shape
        return UNet(input_shape=net_input_shape,
                      dim=kwargs.pop('dim', 32),
                      out_channels=net_output_size//(H*W),
                      dim_mults=kwargs.pop('dim_mults', (1, 2, 4)),
                      resnet_block_groups=kwargs.pop('resnet_block_groups', 8),
                      attn_heads=kwargs.pop('attn_heads', 4),
                      attn_dim_head=kwargs.pop('attn_dim_head', 32),
                      cond_shape=cond_shape,
                      time_dependent=False,
                      key=key)
    else:
      def create_net(net_input_shape, net_output_size, key):
        return Encoder(input_shape=net_input_shape,
                      dim=kwargs.pop('dim', 32),
                      dim_mults=kwargs.pop('dim_mults', (1, 2, 4)),
                      resnet_block_groups=kwargs.pop('resnet_block_groups', 8),
                      attn_heads=kwargs.pop('attn_heads', 4),
                      attn_dim_head=kwargs.pop('attn_dim_head', 32),
                      out_size=net_output_size,
                      cond_shape=cond_shape,
                      key=key)

    def make_single_flow_layer(key: PRNGKeyArray) -> Sequential:
      k1, k2, k3 = random.split(key, 3)

      layers = []
      layer = Coupling(init_transform,
                       create_net,
                       input_shape=input_shape,
                       cond_shape=cond_shape,
                       split_dim=coupling_split_dim,
                       reverse_conditioning=reverse_conditioning,
                       key=k1)
      layers.append(layer)
      layers.append(OneByOneConv(input_shape=input_shape,
                                 key=k2))
      layers.append(ShiftScale(input_shape=input_shape,
                               cond_shape=cond_shape,
                               key=k3))
      return Sequential(*layers, **kwargs)

    super().__init__(make_single_flow_layer, n_flow_layers, key=key)
__init__(self, TransformType: type, input_shape: Tuple[int], n_flow_layers: int = 3, cond_shape: Optional[Tuple[int]] = None, unet: Optional[bool] = True, coupling_split_dim: Optional[int] = None, reverse_conditioning: Optional[bool] = False, *, key: PRNGKeyArray, **kwargs) ¤
Source code in generax/flows/models.py
def __init__(self,
             TransformType: type,
             input_shape: Tuple[int],
             n_flow_layers: int = 3,
             cond_shape: Optional[Tuple[int]] = None,
             unet: Optional[bool] = True,
             coupling_split_dim: Optional[int] = None,
             reverse_conditioning: Optional[bool] = False,
             *,
             key: PRNGKeyArray,
             **kwargs):

  def init_transform(transform_input_shape, key):
    return TransformType(input_shape=transform_input_shape,
                         cond_shape=cond_shape,
                         key=key)

  if unet:
    def create_net(net_input_shape, net_output_size, key):
      H, W, C = net_input_shape
      return UNet(input_shape=net_input_shape,
                    dim=kwargs.pop('dim', 32),
                    out_channels=net_output_size//(H*W),
                    dim_mults=kwargs.pop('dim_mults', (1, 2, 4)),
                    resnet_block_groups=kwargs.pop('resnet_block_groups', 8),
                    attn_heads=kwargs.pop('attn_heads', 4),
                    attn_dim_head=kwargs.pop('attn_dim_head', 32),
                    cond_shape=cond_shape,
                    time_dependent=False,
                    key=key)
  else:
    def create_net(net_input_shape, net_output_size, key):
      return Encoder(input_shape=net_input_shape,
                    dim=kwargs.pop('dim', 32),
                    dim_mults=kwargs.pop('dim_mults', (1, 2, 4)),
                    resnet_block_groups=kwargs.pop('resnet_block_groups', 8),
                    attn_heads=kwargs.pop('attn_heads', 4),
                    attn_dim_head=kwargs.pop('attn_dim_head', 32),
                    out_size=net_output_size,
                    cond_shape=cond_shape,
                    key=key)

  def make_single_flow_layer(key: PRNGKeyArray) -> Sequential:
    k1, k2, k3 = random.split(key, 3)

    layers = []
    layer = Coupling(init_transform,
                     create_net,
                     input_shape=input_shape,
                     cond_shape=cond_shape,
                     split_dim=coupling_split_dim,
                     reverse_conditioning=reverse_conditioning,
                     key=k1)
    layers.append(layer)
    layers.append(OneByOneConv(input_shape=input_shape,
                               key=k2))
    layers.append(ShiftScale(input_shape=input_shape,
                             cond_shape=cond_shape,
                             key=k3))
    return Sequential(*layers, **kwargs)

  super().__init__(make_single_flow_layer, n_flow_layers, key=key)
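
Construction sketch for image data. The channels-last (H, W, C) layout follows the H, W, C = net_input_shape unpacking above; the ShiftScale import path is again an assumption.

from jax import random
from generax.flows.models import GeneralImageTransform
from generax.flows import ShiftScale  # import path assumed

flow = GeneralImageTransform(TransformType=ShiftScale,
                             input_shape=(32, 32, 3),  # (H, W, C)
                             unet=False,               # Encoder conditioner instead of the UNet
                             key=random.PRNGKey(0))
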
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Apply the inverse transformation.

Arguments:

  • x: The input to the transformation
  • y: The conditioning information

Returns: (z, log_det)

Source code in generax/flows/models.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform ¤

Inherited from generax.flows.base.Repeat.data_dependent_init.

Source code in generax/flows/models.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
  seq = self.to_sequential()

  # Apply the data-dependent initialization
  out_seq_layers = seq.data_dependent_init(x, y=y, key=key)

  # Turn the sequential layers into a repeat layer
  all_params = []
  for i, layer in enumerate(out_seq_layers):
    params, _ = eqx.partition(layer, eqx.is_array)
    all_params.append(params)

  # Combine the parameters back into a single layer
  params = jax.tree_util.tree_map(lambda *args: jnp.array(args), *all_params)
  _, static = eqx.partition(self.layers, eqx.is_array)
  layers = eqx.combine(params, static)

  get_layers = lambda tree: tree.layers
  updated_module = eqx.tree_at(get_layers, self, layers)
  return updated_module
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Inherited from generax.flows.base.Repeat.__call__.

Source code in generax/flows/models.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  (z, log_det)
  """
  dynamic, static = eqx.partition(self.layers, eqx.is_array)

  def scan_body(x, params):
    block = eqx.combine(params, static)
    x, log_det = block(x, y=y, inverse=inverse, **kwargs)
    return x, log_det

  x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
  return x, log_dets.sum()

generax.flows.models.NICEImageTransform (GeneralImageTransform) ¤

NICEImageTransform(*args, **kwargs)

Source code in generax/flows/models.py
class NICEImageTransform(GeneralImageTransform):
  def __init__(self,
               *args,
               **kwargs):
    super().__init__(TransformType=Shift,
                     *args,
                     **kwargs)
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Apply the inverse transformation.

Arguments:

  • x: The input to the transformation
  • y: The conditioning information

Returns: (z, log_det)

Source code in generax/flows/models.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)
__init__(self, *args, **kwargs) ¤
Source code in generax/flows/models.py
def __init__(self,
             *args,
             **kwargs):
  super().__init__(TransformType=Shift,
                   *args,
                   **kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/flows/models.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
  seq = self.to_sequential()

  # Apply the data-dependent initialization
  out_seq_layers = seq.data_dependent_init(x, y=y, key=key)

  # Turn the sequential layers into a repeat layer
  all_params = []
  for i, layer in enumerate(out_seq_layers):
    params, _ = eqx.partition(layer, eqx.is_array)
    all_params.append(params)

  # Combine the parameters back into a single layer
  params = jax.tree_util.tree_map(lambda *args: jnp.array(args), *all_params)
  _, static = eqx.partition(self.layers, eqx.is_array)
  layers = eqx.combine(params, static)

  get_layers = lambda tree: tree.layers
  updated_module = eqx.tree_at(get_layers, self, layers)
  return updated_module
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/models.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  (z, log_det)
  """
  dynamic, static = eqx.partition(self.layers, eqx.is_array)

  def scan_body(x, params):
    block = eqx.combine(params, static)
    x, log_det = block(x, y=y, inverse=inverse, **kwargs)
    return x, log_det

  x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
  return x, log_dets.sum()

generax.flows.models.RealNVPImageTransform (GeneralImageTransform) ¤

RealNVPImageTransform(*args, **kwargs)

Source code in generax/flows/models.py
class RealNVPImageTransform(GeneralImageTransform):
  def __init__(self,
               *args,
               **kwargs):
    super().__init__(TransformType=ShiftScale,
                     *args,
                     **kwargs)
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Apply the inverse transformation.

Arguments:

  • x: The input to the transformation
  • y: The conditioning information

Returns: (z, log_det)

Source code in generax/flows/models.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)
__init__(self, *args, **kwargs) ¤
Source code in generax/flows/models.py
def __init__(self,
             *args,
             **kwargs):
  super().__init__(TransformType=ShiftScale,
                   *args,
                   **kwargs)
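
With the log-determinant in hand, a change-of-variables log-likelihood sketch; the standard normal prior is an assumption, not part of this module.

import jax.scipy.stats as jss
from jax import random
from generax.flows.models import RealNVPImageTransform

flow = RealNVPImageTransform(input_shape=(32, 32, 3), key=random.PRNGKey(0))
x = random.normal(random.PRNGKey(1), (32, 32, 3))
z, log_det = flow(x)
log_px = jss.norm.logpdf(z).sum() + log_det  # log p(x) = log p(z) + log|det J|
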
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/flows/models.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
  seq = self.to_sequential()

  # Apply the data-dependent initialization
  out_seq_layers = seq.data_dependent_init(x, y=y, key=key)

  # Turn the sequential layers into a repeat layer
  all_params = []
  for i, layer in enumerate(out_seq_layers):
    params, _ = eqx.partition(layer, eqx.is_array)
    all_params.append(params)

  # Combine the parameters back into a single layer
  params = jax.tree_util.tree_map(lambda *args: jnp.array(args), *all_params)
  _, static = eqx.partition(self.layers, eqx.is_array)
  layers = eqx.combine(params, static)

  get_layers = lambda tree: tree.layers
  updated_module = eqx.tree_at(get_layers, self, layers)
  return updated_module
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/models.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  (z, log_det)
  """
  dynamic, static = eqx.partition(self.layers, eqx.is_array)

  def scan_body(x, params):
    block = eqx.combine(params, static)
    x, log_det = block(x, y=y, inverse=inverse, **kwargs)
    return x, log_det

  x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
  return x, log_dets.sum()

generax.flows.models.NeuralSplineImageTransform (GeneralImageTransform) ¤

NeuralSplineImageTransform(*args, **kwargs)

Source code in generax/flows/models.py
class NeuralSplineImageTransform(GeneralImageTransform):
  def __init__(self,
               *args,
               n_spline_knots: int = 8,
               **kwargs):
    super().__init__(TransformType=partial(RationalQuadraticSpline, K=n_spline_knots),
                     *args,
                     **kwargs)
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Apply the inverse transformation.

Arguments:

  • x: The input to the transformation
  • y: The conditioning information

Returns: (z, log_det)

Source code in generax/flows/models.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)
__init__(self, *args, n_spline_knots: int = 8, **kwargs) ¤
Source code in generax/flows/models.py
def __init__(self,
             *args,
             n_spline_knots: int = 8,
             **kwargs):
  super().__init__(TransformType=partial(RationalQuadraticSpline, K=n_spline_knots),
                   *args,
                   **kwargs)
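
Construction sketch, combining the image architecture with rational-quadratic spline couplings:

from jax import random
from generax.flows.models import NeuralSplineImageTransform

flow = NeuralSplineImageTransform(input_shape=(32, 32, 3),
                                  n_spline_knots=8,
                                  key=random.PRNGKey(0))
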
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> BijectiveTransform ¤

Initialize the parameters of the layer based on the data.

Arguments:

  • x: The data to initialize the parameters with.
  • y: The conditioning information
  • key: A jax.random.PRNGKey for initialization

Returns: A new layer with the parameters initialized.

Source code in generax/flows/models.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> BijectiveTransform:
  seq = self.to_sequential()

  # Apply the data-dependent initialization
  out_seq_layers = seq.data_dependent_init(x, y=y, key=key)

  # Turn the sequential layers into a repeat layer
  all_params = []
  for i, layer in enumerate(out_seq_layers):
    params, _ = eqx.partition(layer, eqx.is_array)
    all_params.append(params)

  # Combine the parameters back into a single layer
  params = jax.tree_util.tree_map(lambda *args: jnp.array(args), *all_params)
  _, static = eqx.partition(self.layers, eqx.is_array)
  layers = eqx.combine(params, static)

  get_layers = lambda tree: tree.layers
  updated_module = eqx.tree_at(get_layers, self, layers)
  return updated_module
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to invert the transformation

Returns: (z, log_det)

Source code in generax/flows/models.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             **kwargs) -> Array:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to invert the transformation

  **Returns**:
  (z, log_det)
  """
  dynamic, static = eqx.partition(self.layers, eqx.is_array)

  def scan_body(x, params):
    block = eqx.combine(params, static)
    x, log_det = block(x, y=y, inverse=inverse, **kwargs)
    return x, log_det

  x, log_dets = jax.lax.scan(scan_body, x, dynamic, reverse=inverse)
  return x, log_dets.sum()