FFJORD¤

generax.flows.ffjord.FFJORDTransform (BijectiveTransform) ¤

Flow parametrized by a neural ODE https://arxiv.org/pdf/1810.01367.pdf

Attributes:

  • input_shape: The input shape. Output shape will have the same dimensionality as the input.
  • cond_shape: The shape of the conditioning information. If there is no conditioning information, this is None.
  • neural_ode: The neural ODE
  • adjoint: The adjoint method to use. Can be "recursive_checkpoint", "direct", or "seminorm".
  • key: The random key to use for initialization
Source code in generax/flows/ffjord.py
class FFJORDTransform(BijectiveTransform):
  """Flow parametrized by a neural ODE https://arxiv.org/pdf/1810.01367.pdf

  **Attributes**:

  - `input_shape`: The input shape.  Output shape will have the same dimensionality
                  as the input.
  - `cond_shape`: The shape of the conditioning information.  If there is no
                  conditioning information, this is None.
  - `neural_ode`: The neural ODE
  - `adjoint`: The adjoint method to use.  Can be `"recursive_checkpoint"`, `"direct"`, or `"seminorm"`.
  - `key`: The random key to use for initialization
  """

  neural_ode: NeuralODE

  def __init__(self,
               input_shape: Tuple[int],
               net: eqx.Module = None,
               working_size: int = 16,
               hidden_size: int = 32,
               n_blocks: int = 4,
               time_embedding_size = 16,
               n_time_features = 8,
               cond_shape: Optional[Tuple[int]] = None,
               *,
               controller_rtol: Optional[float] = 1e-8,
               controller_atol: Optional[float] = 1e-8,
               adjoint='recursive_checkpoint',
               key: PRNGKeyArray,
               **kwargs):
    """**Arguments**:

    - `input_shape`: The shape of the input to the transformation
    - `net`: The neural network to use for the vector field.  If None, a default
              network will be used.  `net` should accept `net(t, x, y=y)`
    - `working_size`: The working size of the default network.
    - `hidden_size`: The hidden size of the default network.
    - `n_blocks`: The number of blocks in the default network.
    - `time_embedding_size`: The size of the time embedding in the default network.
    - `n_time_features`: The number of time features in the default network.
    - `cond_shape`: The shape of the conditioning information.  If there is no
                    conditioning information, this is None.
    - `controller_rtol`: The relative tolerance of the stepsize controller.
    - `controller_atol`: The absolute tolerance of the stepsize controller.
    - `adjoint`: The adjoint method to use.  Can be one of the following:
        - `"recursive_checkpoint"`: Use the recursive checkpoint method.  Doesn't support jvp.
        - `"direct"`: Use the direct method.  Supports jvps.
        - `"seminorm"`: Use the seminorm method.  Does fast backprop through the solver.
    - `key`: The random key to use for initialization
    """

    if net is None:
      net = TimeDependentResNet(input_shape=input_shape,
                                working_size=working_size,
                                hidden_size=hidden_size,
                                out_size=input_shape[-1],
                                n_blocks=n_blocks,
                                cond_shape=cond_shape,
                                embedding_size=time_embedding_size,
                                out_features=n_time_features,
                                key=key)

    self.neural_ode = NeuralODE(vf=net,
                                adjoint=adjoint,
                                controller_rtol=controller_rtol,
                                controller_atol=controller_atol)

    super().__init__(input_shape=input_shape,
                     **kwargs)

  @property
  def vector_field(self):
    return self.neural_ode.vector_field

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               inverse: bool = False,
               log_likelihood: bool = True,
               trace_estimate_likelihood: Optional[bool] = False,
               save_at: Optional[Array] = None,
               key: Optional[PRNGKeyArray] = None,
               **kwargs) -> Tuple[Array, Array]:
    """**Arguments**:

    - `x`: The input to the transformation
    - `y`: The conditioning information
    - `inverse`: Whether to inverse the transformation
    - `log_likelihood`: Whether to compute the log likelihood of the transformation
    - `trace_estimate_likelihood`: Whether to compute a trace estimate of the likelihood of the neural ODE.
    - `save_at`: The times to save the neural ODE at.
    - `key`: The random key used for Hutchinson's trace estimation

    **Returns**:
    `(z, log_det)`
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    if log_likelihood and trace_estimate_likelihood and (key is None):
      raise TypeError('When using trace estimation, a random key must be passed')

    if not log_likelihood:
      trace_estimate_likelihood = False

    solution = self.neural_ode(x,
                               y=y,
                               inverse=inverse,
                               log_likelihood=log_likelihood,
                               trace_estimate_likelihood=trace_estimate_likelihood,
                               save_at=save_at,
                               key=key,
                               **kwargs)
    return solution.ys, solution.log_det
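
A minimal usage sketch of the transform (the 2-dimensional input shape and the PRNG seed are illustrative, not defaults):

import jax.random as random
from generax.flows.ffjord import FFJORDTransform

key = random.PRNGKey(0)
x = random.normal(key, (2,))  # one unbatched sample matching input_shape

# Build the transform with the default TimeDependentResNet vector field.
flow = FFJORDTransform(input_shape=(2,), key=key)

# Forward pass (data -> base): returns the transformed point and the
# log determinant of the transformation.
z, log_det = flow(x)
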
neural_ode: NeuralODE dataclass-field ¤
vector_field property readonly ¤
__init__(self, input_shape: Tuple[int], net: Module = None, working_size: int = 16, hidden_size: int = 32, n_blocks: int = 4, time_embedding_size = 16, n_time_features = 8, cond_shape: Optional[Tuple[int]] = None, *, controller_rtol: Optional[float] = 1e-08, controller_atol: Optional[float] = 1e-08, adjoint = 'recursive_checkpoint', key: PRNGKeyArray, **kwargs) ¤

Arguments:

  • input_shape: The shape of the input to the transformation
  • net: The neural network to use for the vector field. If None, a default network will be used. net should accept net(t, x, y=y)
  • working_size: The working size of the default network.
  • hidden_size: The hidden size of the default network.
  • n_blocks: The number of blocks in the default network.
  • time_embedding_size: The size of the time embedding in the default network.
  • n_time_features: The number of time features in the default network.
  • cond_shape: The shape of the conditioning information. If there is no conditioning information, this is None.
  • controller_rtol: The relative tolerance of the stepsize controller.
  • controller_atol: The absolute tolerance of the stepsize controller.
  • adjoint: The adjoint method to use. Can be one of the following:
    • "recursive_checkpoint": Use the recursive checkpoint method. Doesn't support jvp.
    • "direct": Use the direct method. Supports jvps.
    • "seminorm": Use the seminorm method. Does fast backprop through the solver.
  • key: The random key to use for initialization
Source code in generax/flows/ffjord.py
def __init__(self,
             input_shape: Tuple[int],
             net: eqx.Module = None,
             working_size: int = 16,
             hidden_size: int = 32,
             n_blocks: int = 4,
             time_embedding_size = 16,
             n_time_features = 8,
             cond_shape: Optional[Tuple[int]] = None,
             *,
             controller_rtol: Optional[float] = 1e-8,
             controller_atol: Optional[float] = 1e-8,
             adjoint='recursive_checkpoint',
             key: PRNGKeyArray,
             **kwargs):
  """**Arguments**:

  - `input_shape`: The shape of the input to the transformation
  - `net`: The neural network to use for the vector field.  If None, a default
            network will be used.  `net` should accept `net(t, x, y=y)`
  - `working_size`: The working size of the default network.
  - `hidden_size`: The hidden size of the default network.
  - `n_blocks`: The number of blocks in the default network.
  - `time_embedding_size`: The size of the time embedding in the default network.
  - `n_time_features`: The number of time features in the default network.
  - `cond_shape`: The shape of the conditioning information.  If there is no
                  conditioning information, this is None.
  - `controller_rtol`: The relative tolerance of the stepsize controller.
  - `controller_atol`: The absolute tolerance of the stepsize controller.
  - `adjoint`: The adjoint method to use.  Can be one of the following:
      - `"recursive_checkpoint"`: Use the recursive checkpoint method.  Doesn't support jvp.
      - `"direct"`: Use the direct method.  Supports jvps.
      - `"seminorm"`: Use the seminorm method.  Does fast backprop through the solver.
  - `key`: The random key to use for initialization
  """

  if net is None:
    net = TimeDependentResNet(input_shape=input_shape,
                              working_size=working_size,
                              hidden_size=hidden_size,
                              out_size=input_shape[-1],
                              n_blocks=n_blocks,
                              cond_shape=cond_shape,
                              embedding_size=time_embedding_size,
                              out_features=n_time_features,
                              key=key)

  self.neural_ode = NeuralODE(vf=net,
                              adjoint=adjoint,
                              controller_rtol=controller_rtol,
                              controller_atol=controller_atol)

  super().__init__(input_shape=input_shape,
                   **kwargs)
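
If a custom net is passed, it must accept net(t, x, y=y) and, as the NeuralODE source below suggests, expose an input_shape attribute matching the data. A hypothetical minimal vector field (TinyVF is illustrative, not part of generax):

import equinox as eqx
import jax.numpy as jnp
import jax.random as random
from generax.flows.ffjord import FFJORDTransform

class TinyVF(eqx.Module):
  """Hypothetical vector field: one linear layer over [t, x]."""
  linear: eqx.nn.Linear
  input_shape: tuple

  def __init__(self, input_shape, *, key):
    dim = input_shape[-1]
    self.linear = eqx.nn.Linear(dim + 1, dim, key=key)  # +1 input feature for time
    self.input_shape = input_shape

  def __call__(self, t, x, y=None):
    return self.linear(jnp.concatenate([jnp.atleast_1d(t), x]))

key = random.PRNGKey(0)
flow = FFJORDTransform(input_shape=(2,), net=TinyVF((2,), key=key), key=key)
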
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) ¤

Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.

Source code in generax/flows/ffjord.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None):
  """Initialize the parameters of the layer based on the data.

  **Arguments**:

  - `x`: The data to initialize the parameters with.
  - `y`: The conditioning information
  - `key`: A `jax.random.PRNGKey` for initialization

  **Returns**:
  This layer, unchanged (FFJORD has no data-dependent initialization).
  """
  return self
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, log_likelihood: bool = True, trace_estimate_likelihood: Optional[bool] = False, save_at: Optional[Array] = None, key: Optional[PRNGKeyArray] = None, **kwargs) -> Tuple[Array, Array] ¤

Arguments:

  • x: The input to the transformation
  • y: The conditioning information
  • inverse: Whether to inverse the transformation
  • log_likelihood: Whether to compute the log likelihood of the transformation
  • trace_estimate_likelihood: Whether to compute a trace estimate of the likelihood of the neural ODE.
  • save_at: The times to save the neural ODE at.
  • key: The random key used for Hutchinson's trace estimation

Returns: (z, log_det)

Source code in generax/flows/ffjord.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             inverse: bool = False,
             log_likelihood: bool = True,
             trace_estimate_likelihood: Optional[bool] = False,
             save_at: Optional[Array] = None,
             key: Optional[PRNGKeyArray] = None,
             **kwargs) -> Tuple[Array, Array]:
  """**Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information
  - `inverse`: Whether to inverse the transformation
  - `log_likelihood`: Whether to compute the log likelihood of the transformation
  - `trace_estimate_likelihood`: Whether to compute a trace estimate of the likelihood of the neural ODE.
  - `save_at`: The times to save the neural ODE at.
  - `key`: The random key used for Hutchinson's trace estimation

  **Returns**:
  `(z, log_det)`
  """
  assert x.shape == self.input_shape, 'Only works on unbatched data'
  if log_likelihood and trace_estimate_likelihood and (key is None):
    raise TypeError('When using trace estimation, a random key must be passed')

  if not log_likelihood:
    trace_estimate_likelihood = False

  solution = self.neural_ode(x,
                             y=y,
                             inverse=inverse,
                             log_likelihood=log_likelihood,
                             trace_estimate_likelihood=trace_estimate_likelihood,
                             save_at=save_at,
                             key=key,
                             **kwargs)
  return solution.ys, solution.log_det
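
A sketch of the forward/inverse round trip implied by the code above; the tolerances are illustrative:

import jax.numpy as jnp
import jax.random as random
from generax.flows.ffjord import FFJORDTransform

key = random.PRNGKey(0)
flow = FFJORDTransform(input_shape=(2,), key=key)
x = random.normal(key, (2,))

z, log_det = flow(x)                        # data -> base
x_rec, log_det_inv = flow(z, inverse=True)  # base -> data

# Up to solver tolerance the round trip recovers x and the two log
# determinants cancel.
assert jnp.allclose(x, x_rec, atol=1e-4)
assert jnp.allclose(log_det + log_det_inv, 0.0, atol=1e-4)

# Hutchinson trace estimation requires a key, otherwise a TypeError is raised.
z_est, log_det_est = flow(x, trace_estimate_likelihood=True, key=key)

Since the transform asserts unbatched inputs, batches can be handled with jax.vmap(flow).
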
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array ¤

Inherited from generax.flows.base.BijectiveTransform.inverse.

Source code in generax/flows/ffjord.py
def inverse(self,
            x: Array,
            y: Optional[Array] = None,
            **kwargs) -> Array:
  """Apply the inverse transformation.

  **Arguments**:

  - `x`: The input to the transformation
  - `y`: The conditioning information

  **Returns**:
  (z, log_det)
  """
  return self(x, y=y, inverse=True, **kwargs)

generax.nn.neural_ode.NeuralODE ¤

Neural ODE

Source code in generax/nn/neural_ode.py
class NeuralODE(eqx.Module):
  """Neural ODE"""

  vector_field: eqx.Module
  adjoint: diffrax.AbstractAdjoint
  stepsize_controller: diffrax.AbstractAdaptiveStepSizeController

  def __init__(self,
               vf: eqx.Module,
               adjoint: Optional[str] = 'recursive_checkpoint',
               controller_rtol: Optional[float] = 1e-3,
               controller_atol: Optional[float] = 1e-5,
  ):
    """**Arguments**:

    - `vf`: A function that computes the vector field.  It must output
            a vector of the same shape as its input.
    - `adjoint`: The adjoint method to use.  Can be one of the following:
       - `"recursive_checkpoint"`: Use the recursive checkpoint method.  Doesn't support jvp.
       - `"direct"`: Use the direct method.  Supports jvps.
       - `"seminorm"`: Use the seminorm method.  Does fast backprop through the solver.
    - `controller_rtol`: The relative tolerance of the stepsize controller.
    - `controller_atol`: The absolute tolerance of the stepsize controller.
    """
    self.vector_field = vf

    if adjoint == 'recursive_checkpoint':
      self.adjoint = diffrax.RecursiveCheckpointAdjoint()
    elif adjoint == 'direct':
      self.adjoint = diffrax.DirectAdjoint()
    elif adjoint == 'seminorm':
      adjoint_controller = diffrax.PIDController(
          rtol=1e-3, atol=1e-6, norm=diffrax.adjoint_rms_seminorm)
      self.adjoint = diffrax.BacksolveAdjoint(stepsize_controller=adjoint_controller)
    else:
      raise ValueError(f'Unknown adjoint method: {adjoint}')

    self.stepsize_controller = diffrax.PIDController(rtol=controller_rtol, atol=controller_atol)

  def __call__(self,
               x: Array,
               y: Optional[Array] = None,
               *,
               inverse: Optional[bool] = False,
               log_likelihood: Optional[bool] = False,
               trace_estimate_likelihood: Optional[bool] = False,
               save_at: Optional[Array] = None,
               key: Optional[PRNGKeyArray] = None,
               t0: Optional[float] = 0.0,
               t1: Optional[float] = 1.0) -> NeuralODESolution:
    """**Arguemnts**:

    - `x`: The input to the neural ODE.  Must be a rank 1 array.
    - `key`: The random number generator key.
    - `inverse`: Whether to compute the inverse of the neural ODE.  `inverse=True`
               corresponds to going from the base space to the data space.
    - `log_likelihood`: Whether to compute the log likelihood of the neural ODE.
    - `trace_estimate_likelihood`: Whether to compute a trace estimate of the likelihood of the neural ODE.
    - `save_at`: The times to save the neural ODE at.
    - `key`: The random key to use for initialization
    - `t0`: The initial time.
    - `t1`: The final time.

    **Returns**:
    - `z`: The output of the neural ODE.
    - `log_likelihood`: The log likelihood of the neural ODE if `log_likelihood=True`.
    """
    assert x.shape == self.vector_field.input_shape

    if trace_estimate_likelihood:
      # Get a random vector for Hutchinson's trace estimator
      k1, _ = random.split(key, 2)
      v = random.normal(k1, x.shape)

    # Split the model into its static and dynamic parts so that backprop
    # through the ode solver can be faster.
    params, static = eqx.partition(self.vector_field, eqx.is_array)

    def f(t, carry, params):
      x, log_det, total_vf_norm, total_jac_frob_norm = carry

      if not inverse:
        # In the forward direction (data -> base) we run time in reverse
        t = t1 - t

      # Recombine the model
      model = eqx.combine(params, static)

      # Fill the model with the current time
      def apply_vf(x):
        return model(t, x, y=y)

      if log_likelihood:
        if trace_estimate_likelihood:
          # Hutchinson's trace estimator.  See FFJORD https://arxiv.org/pdf/1810.01367.pdf
          dxdt, dudxv = jax.jvp(apply_vf, (x,), (v,))
          dlogpxdt = -jnp.sum(dudxv*v)
          dtjfndt = jnp.sum(dudxv**2)
        else:
          # Brute force dlogpx/dt.  See NeuralODE https://arxiv.org/pdf/1806.07366.pdf
          x_flat = x.ravel()
          eye = jnp.eye(x_flat.shape[-1])
          x_shape = x.shape

          def jvp_flat(x_flat, dx_flat):
            x = x_flat.reshape(x_shape)
            dx = dx_flat.reshape(x_shape)
            dxdt, d2dx_dtdx = jax.jvp(apply_vf, (x,), (dx,))
            return dxdt, d2dx_dtdx.ravel()

          dxdt, d2dx_dtdx_flat = jax.vmap(jvp_flat, in_axes=(None, 0))(x_flat, eye)
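          # Every vmapped row computes the same dxdt; keep a single copy.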
          dxdt = dxdt[0]
          dlogpxdt = -jnp.trace(d2dx_dtdx_flat)
          dtjfndt = jnp.sum(d2dx_dtdx_flat**2)

      else:
        # Don't worry about the log likelihood
        dxdt = apply_vf(x)
        dlogpxdt = jnp.zeros_like(log_det)
        dtjfndt = jnp.zeros_like(total_jac_frob_norm)

      if not inverse:
        # In the forward direction (data -> base) we integrate the negated vector field
        dxdt = -dxdt

      # Accumulate the norm of the vector field
      dvfnormdt = jnp.sum(dxdt**2)

      if inverse:
        dlogpxdt = -dlogpxdt
      return dxdt, dlogpxdt, dvfnormdt, dtjfndt

    term = diffrax.ODETerm(f)
    solver = diffrax.Dopri5()

    # Determine which times we want to save the neural ODE at.
    if save_at is None:
      saveat = diffrax.SaveAt(ts=[t1])
    else:
      saveat = diffrax.SaveAt(ts=save_at)

    log_det = jnp.array(0.0)
    total_vf_norm = jnp.array(0.0)
    total_jac_frob_norm = jnp.array(0.0)

    # Run the ODE solver
    solution = diffrax.diffeqsolve(term,
                                   solver,
                                   saveat=saveat,
                                   t0=t0,
                                   t1=t1,
                                   dt0=0.001,
                                   y0=(x,
                                       log_det,
                                       total_vf_norm,
                                       total_jac_frob_norm),
                                   args=params,
                                   adjoint=self.adjoint,
                                   stepsize_controller=self.stepsize_controller,
                                   throw=True)
    outs = solution.ys

    if save_at is None:
      # Only take the first time
      outs = jax.tree_util.tree_map(lambda x: x[0], outs)

    z, log_det, total_vf_norm, total_jac_frob_norm = outs

    # Construct the new solution
    kwargs = {f.name: getattr(solution, f.name)
              for f in fields(solution)}
    kwargs['ys'] = z

    new_solution = NeuralODESolution(log_det=log_det,
                                     total_vf_norm=total_vf_norm,
                                     total_jac_frob_norm=total_jac_frob_norm,
                                     **kwargs)
    return new_solution
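
The likelihood computation in __call__ relies on the identity tr(J) = E[vᵀJv] for v ~ N(0, I) (Hutchinson's estimator), trading the O(d) Jacobian products of the brute-force branch for a single jvp. A standalone sketch of the two approaches on a toy function:

import jax
import jax.numpy as jnp
import jax.random as random

def exact_trace(f, x):
  # Exact trace via the full Jacobian, like the brute-force branch above.
  return jnp.trace(jax.jacfwd(f)(x))

def hutchinson_trace(f, x, key):
  # E[v^T (df/dx) v] = tr(df/dx) for v ~ N(0, I), one jvp per sample.
  v = random.normal(key, x.shape)
  _, jv = jax.jvp(f, (x,), (v,))
  return jnp.sum(jv * v)

f = lambda x: x * jnp.tanh(x)  # toy vector field
x = jnp.arange(4.0)
keys = random.split(random.PRNGKey(0), 1000)
estimates = jax.vmap(lambda k: hutchinson_trace(f, x, k))(keys)
print(exact_trace(f, x), jnp.mean(estimates))  # close in expectation
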
__init__(self, vf: Module, adjoint: Optional[str] = 'recursive_checkpoint', controller_rtol: Optional[float] = 0.001, controller_atol: Optional[float] = 1e-05) ¤

Arguments:

  • vf: A function that computes the vector field. It must output a vector of the same shape as its input.
  • adjoint: The adjoint method to use. Can be one of the following:
  • "recursive_checkpoint": Use the recursive checkpoint method. Doesn't support jvp.
  • "direct": Use the direct method. Supports jvps.
  • "seminorm": Use the seminorm method. Does fast backprop through the solver.
  • controller_rtol: The relative tolerance of the stepsize controller.
  • controller_atol: The absolute tolerance of the stepsize controller.
Source code in generax/nn/neural_ode.py
def __init__(self,
             vf: eqx.Module,
             adjoint: Optional[str] = 'recursive_checkpoint',
             controller_rtol: Optional[float] = 1e-3,
             controller_atol: Optional[float] = 1e-5,
):
  """**Arguments**:

  - `vf`: A function that computes the vector field.  It must output
          a vector of the same shape as its input.
  - `adjoint`: The adjoint method to use.  Can be one of the following:
     - `"recursive_checkpoint"`: Use the recursive checkpoint method.  Doesn't support jvp.
     - `"direct"`: Use the direct method.  Supports jvps.
     - `"seminorm"`: Use the seminorm method.  Does fast backprop through the solver.
  - `controller_rtol`: The relative tolerance of the stepsize controller.
  - `controller_atol`: The absolute tolerance of the stepsize controller.
  """
  self.vector_field = vf

  if adjoint == 'recursive_checkpoint':
    self.adjoint = diffrax.RecursiveCheckpointAdjoint()
  elif adjoint == 'direct':
    self.adjoint = diffrax.DirectAdjoint()
  elif adjoint == 'seminorm':
    adjoint_controller = diffrax.PIDController(
        rtol=1e-3, atol=1e-6, norm=diffrax.adjoint_rms_seminorm)
    self.adjoint = diffrax.BacksolveAdjoint(stepsize_controller=adjoint_controller)
  else:
    raise ValueError(f'Unknown adjoint method: {adjoint}')

  self.stepsize_controller = diffrax.PIDController(rtol=controller_rtol, atol=controller_atol)
__call__(self, x: Array, y: Optional[Array] = None, *, inverse: Optional[bool] = False, log_likelihood: Optional[bool] = False, trace_estimate_likelihood: Optional[bool] = False, save_at: Optional[Array] = None, key: Optional[PRNGKeyArray] = None, t0: Optional[float] = 0.0, t1: Optional[float] = 1.0) -> NeuralODESolution ¤

Arguments:

  • x: The input to the neural ODE. Must be a rank 1 array.
  • y: The conditioning information
  • inverse: Whether to compute the inverse of the neural ODE. inverse=True corresponds to going from the base space to the data space.
  • log_likelihood: Whether to compute the log likelihood of the neural ODE.
  • trace_estimate_likelihood: Whether to compute a trace estimate of the likelihood of the neural ODE.
  • save_at: The times to save the neural ODE at.
  • key: The random key used for Hutchinson's trace estimator.
  • t0: The initial time.
  • t1: The final time.

Returns: A NeuralODESolution whose ys holds the output of the neural ODE and whose log_det holds the accumulated log determinant when log_likelihood=True.

Source code in generax/nn/neural_ode.py
def __call__(self,
             x: Array,
             y: Optional[Array] = None,
             *,
             inverse: Optional[bool] = False,
             log_likelihood: Optional[bool] = False,
             trace_estimate_likelihood: Optional[bool] = False,
             save_at: Optional[Array] = None,
             key: Optional[PRNGKeyArray] = None,
             t0: Optional[float] = 0.0,
             t1: Optional[float] = 1.0) -> NeuralODESolution:
  """**Arguemnts**:

  - `x`: The input to the neural ODE.  Must be a rank 1 array.
  - `key`: The random number generator key.
  - `inverse`: Whether to compute the inverse of the neural ODE.  `inverse=True`
             corresponds to going from the base space to the data space.
  - `log_likelihood`: Whether to compute the log likelihood of the neural ODE.
  - `trace_estimate_likelihood`: Whether to compute a trace estimate of the likelihood of the neural ODE.
  - `save_at`: The times to save the neural ODE at.
  - `key`: The random key to use for initialization
  - `t0`: The initial time.
  - `t1`: The final time.

  **Returns**:
  - `z`: The output of the neural ODE.
  - `log_likelihood`: The log likelihood of the neural ODE if `log_likelihood=True`.
  """
  assert x.shape == self.vector_field.input_shape

  if trace_estimate_likelihood:
    # Get a random vector for Hutchinson's trace estimator
    k1, _ = random.split(key, 2)
    v = random.normal(k1, x.shape)

  # Split the model into its static and dynamic parts so that backprop
  # through the ode solver can be faster.
  params, static = eqx.partition(self.vector_field, eqx.is_array)

  def f(t, carry, params):
    x, log_det, total_vf_norm, total_jac_frob_norm = carry

    if not inverse:
      # In the forward direction (data -> base) we run time in reverse
      t = t1 - t

    # Recombine the model
    model = eqx.combine(params, static)

    # Fill the model with the current time
    def apply_vf(x):
      return model(t, x, y=y)

    if log_likelihood:
      if trace_estimate_likelihood:
        # Hutchinson's trace estimator.  See FFJORD https://arxiv.org/pdf/1810.01367.pdf
        dxdt, dudxv = jax.jvp(apply_vf, (x,), (v,))
        dlogpxdt = -jnp.sum(dudxv*v)
        dtjfndt = jnp.sum(dudxv**2)
      else:
        # Brute force dlogpx/dt.  See NeuralODE https://arxiv.org/pdf/1806.07366.pdf
        x_flat = x.ravel()
        eye = jnp.eye(x_flat.shape[-1])
        x_shape = x.shape

        def jvp_flat(x_flat, dx_flat):
          x = x_flat.reshape(x_shape)
          dx = dx_flat.reshape(x_shape)
          dxdt, d2dx_dtdx = jax.jvp(apply_vf, (x,), (dx,))
          return dxdt, d2dx_dtdx.ravel()

        dxdt, d2dx_dtdx_flat = jax.vmap(jvp_flat, in_axes=(None, 0))(x_flat, eye)
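        # Every vmapped row computes the same dxdt; keep a single copy.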
        dxdt = dxdt[0]
        dlogpxdt = -jnp.trace(d2dx_dtdx_flat)
        dtjfndt = jnp.sum(d2dx_dtdx_flat**2)

    else:
      # Don't worry about the log likelihood
      dxdt = apply_vf(x)
      dlogpxdt = jnp.zeros_like(log_det)
      dtjfndt = jnp.zeros_like(total_jac_frob_norm)

    if not inverse:
      # In the forward direction (data -> base) we integrate the negated vector field
      dxdt = -dxdt

    # Accumulate the norm of the vector field
    dvfnormdt = jnp.sum(dxdt**2)

    if inverse:
      dlogpxdt = -dlogpxdt
    return dxdt, dlogpxdt, dvfnormdt, dtjfndt

  term = diffrax.ODETerm(f)
  solver = diffrax.Dopri5()

  # Determine which times we want to save the neural ODE at.
  if save_at is None:
    saveat = diffrax.SaveAt(ts=[t1])
  else:
    saveat = diffrax.SaveAt(ts=save_at)

  log_det = jnp.array(0.0)
  total_vf_norm = jnp.array(0.0)
  total_jac_frob_norm = jnp.array(0.0)

  # Run the ODE solver
  solution = diffrax.diffeqsolve(term,
                                 solver,
                                 saveat=saveat,
                                 t0=t0,
                                 t1=t1,
                                 dt0=0.001,
                                 y0=(x,
                                     log_det,
                                     total_vf_norm,
                                     total_jac_frob_norm),
                                 args=params,
                                 adjoint=self.adjoint,
                                 stepsize_controller=self.stepsize_controller,
                                 throw=True)
  outs = solution.ys

  if save_at is None:
    # Only take the first time
    outs = jax.tree_util.tree_map(lambda x: x[0], outs)

  z, log_det, total_vf_norm, total_jac_frob_norm = outs

  # Construct the new solution
  kwargs = {f.name: getattr(solution, f.name)
            for f in fields(solution)}
  kwargs['ys'] = z

  new_solution = NeuralODESolution(log_det=log_det,
                                   total_vf_norm=total_vf_norm,
                                   total_jac_frob_norm=total_jac_frob_norm,
                                   **kwargs)
  return new_solution
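
A sketch of saving intermediate states by calling the wrapped NeuralODE directly through an FFJORDTransform (the five save times are illustrative):

import jax.numpy as jnp
import jax.random as random
from generax.flows.ffjord import FFJORDTransform

key = random.PRNGKey(0)
flow = FFJORDTransform(input_shape=(2,), key=key)
x = random.normal(key, (2,))

ts = jnp.linspace(0.0, 1.0, 5)
solution = flow.neural_ode(x, save_at=ts)
print(solution.ys.shape)  # (5, 2): one state per save time
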

generax.nn.neural_ode.NeuralODESolution ¤

The solution to a neural ODE. This wraps the diffrax solution class and adds the log determinant of the transformation and some other items from http://proceedings.mlr.press/v119/finlay20a/finlay20a.pdf

Attributes:

  • log_det: The log determinant of the transformation.
  • total_vf_norm: The total norm of the vector on the path. This can help determine how straight the path is.
  • total_jac_frob_norm: The total Frobenius norm of the Jacobian of the vector field along the path. This can be used to regularize the dynamics.
Source code in generax/nn/neural_ode.py
class NeuralODESolution(Solution):
  """The solution to a neural ODE.  This wraps the diffrax solution
  class and adds the log determinant of the transformation and some
  other items from http://proceedings.mlr.press/v119/finlay20a/finlay20a.pdf

  **Attributes**:

  - `log_det`: The log determinant of the transformation.
  - `total_vf_norm`: The total norm of the vector on the path.
                     This can help determine how straight the path is.
  - `total_jac_frob_norm`: The total Frobenius norm of the Jacobian of the
                           vector field along the path.  This can be used to
                           regularize the dynamics.
  """

  log_det: Array
  total_vf_norm: Array
  total_jac_frob_norm: Array
__init__(self, t0: Scalar, t1: Scalar, ts: Optional[Array['times']], ys: Optional[PyTree['times', ...]], interpolation: Optional[diffrax.global_interpolation.DenseInterpolation], stats: Dict[str, Any], result: RESULTS, solver_state: Optional[PyTree], controller_state: Optional[PyTree], made_jump: Optional[bool], log_det: Array, total_vf_norm: Array, total_jac_frob_norm: Array) ¤
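
The two extra fields are the regularizers from Finlay et al. (linked above): total_vf_norm accumulates the squared norm of the vector field along the path (the kinetic energy) and total_jac_frob_norm accumulates the squared Frobenius norm of its Jacobian. A hypothetical training objective using them, assuming a standard normal base distribution and the convention log p(x) = log p(z) + log_det:

import jax.numpy as jnp

def regularized_loss(flow, x, key, lam_kinetic=0.01, lam_jac=0.01):
  # Hypothetical objective: negative log likelihood plus the two
  # regularizers accumulated by the solver.
  solution = flow.neural_ode(x,
                             log_likelihood=True,
                             trace_estimate_likelihood=True,
                             key=key)
  z = solution.ys
  log_pz = -0.5 * jnp.sum(z**2) - 0.5 * z.size * jnp.log(2 * jnp.pi)
  log_px = log_pz + solution.log_det
  return (-log_px
          + lam_kinetic * solution.total_vf_norm
          + lam_jac * solution.total_jac_frob_norm)
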