Interfaces API

The interfaces module provides unified APIs for different diffusion and flow matching formulations.

Continuous Interfaces

class interfaces.continuous.Interfaces(*args: Any, **kwargs: Any)

Bases: Module, ABC

Base class for all diffusion / flow matching interfaces.

Every interface wraps a network backbone and should support:
  • Define the pre-conditionings (see EDM)

  • Calculate losses for training
    • Define transport path (alpha_t & sigma_t)

    • Sample t

    • Sample X_t

  • Give tangent for sampling

Required RNG Key:
  • time: for sampling t

  • noise: for sampling n

__init__(network: Module, train_time_dist_type: str | TrainingTimeDistType)

abstract c_in(t: Array) -> Array

Calculate c_in for the interface.

Parameters:

t – current timestep.

Returns:

c_in, c_in for the interface.

Return type:

jnp.ndarray

abstract c_out(t: Array) -> Array

Calculate c_out for the interface.

Parameters:

t – current timestep.

Returns:

c_out, c_out for the interface.

Return type:

jnp.ndarray

abstract c_skip(t: Array) -> Array

Calculate c_skip for the interface.

Parameters:

t – current timestep.

Returns:

c_skip, c_skip for the interface.

Return type:

jnp.ndarray

abstract c_noise(t: Array) -> Array

Calculate c_noise for the interface.

Parameters:

t – current timestep.

Returns:

c_noise, c_noise for the interface.

Return type:

jnp.ndarray

abstract sample_t(shape: tuple[int, ...]) -> Array

Sample t from the training time distribution.

Parameters:

shape – shape of timestep t.

Returns:

t, sampled timestep t.

Return type:

jnp.ndarray
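
As an illustration of what a concrete sample_t might do, the sketch below draws logit-normal timesteps, a common choice for flow matching. The helper name `sample_t_logit_normal`, the explicit `key` argument, and the reading of `t_mu` / `t_sigma` as the mean and standard deviation of the underlying normal are all assumptions; the distribution actually used is selected by train_time_dist_type.

```python
import jax
import jax.numpy as jnp


def sample_t_logit_normal(key, shape, t_mu=0.0, t_sigma=1.0):
    """Hypothetical helper: draw t in (0, 1) as sigmoid(N(t_mu, t_sigma^2))."""
    z = t_mu + t_sigma * jax.random.normal(key, shape)
    return jax.nn.sigmoid(z)


# One timestep per sample in a batch of 8.
t = sample_t_logit_normal(jax.random.PRNGKey(0), (8,))
```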

abstract sample_n(shape: tuple[int, ...]) -> Array

Sample noises.

Parameters:

shape – shape of noise.

Returns:

n, sampled noise.

Return type:

jnp.ndarray

abstract sample_x_t(x: Array, n: Array, t: Array) -> Array

Sample X_t according to the defined interface.

Parameters:
  • x – input clean sample.

  • n – noise.

  • t – current timestep.

Returns:

x_t, sampled X_t according to transport path.

Return type:

jnp.ndarray

abstract target(x: Array, n: Array, t: Array) -> Array

Get training target.

Parameters:
  • x – input clean sample.

  • n – noise.

  • t – current timestep.

Returns:

target, training target.

Return type:

jnp.ndarray

abstract pred(x_t: Array, t: Array, *args, **kwargs) -> Array

Predict ODE tangent according to the defined interface.

Parameters:
  • x_t – input noisy sample.

  • t – current timestep.

Returns:

tangent, predicted ODE tangent.

Return type:

jnp.ndarray

abstract score(x_t: Array, t: Array, *args, **kwargs) -> Array

Transform ODE tangent to the Score Function nabla log p_t(x).

Parameters:
  • x_t – input noisy sample.

  • t – current timestep.

Returns:

score, score function nabla log p_t(x).

Return type:

jnp.ndarray

abstract loss(x: Array, *args, **kwargs) -> Array

Calculate loss for training.

Parameters:
  • x – input clean sample.

  • args – additional arguments for network forward.

  • kwargs – additional keyword arguments for network forward.

Returns:

loss, calculated loss.

Return type:

jnp.ndarray

static mean_flat(x: Array) -> Array

Take mean w.r.t. all dimensions of x except the first.

Parameters:

x – input array.

Returns:

mean, mean across all dimensions except the first.

Return type:

jnp.ndarray
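
The reduction described above can be sketched in a few lines. This is a standalone re-implementation written for illustration, not the library's source.

```python
import jax.numpy as jnp


def mean_flat(x):
    """Average over every axis except the leading (batch) axis."""
    return jnp.mean(x, axis=tuple(range(1, x.ndim)))


x = jnp.arange(24.0).reshape(2, 3, 4)
per_sample = mean_flat(x)  # one value per batch element, shape (2,)
```

Reducing to one value per sample keeps the batch axis available for per-sample weighting before the final average.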

static bcast_right(x: Array, y: Array) -> Array

Broadcast x to the right to match the shape of y.

Parameters:
  • x – array to broadcast.

  • y – target array to match shape.

Returns:

broadcasted, x broadcasted to match y’s shape.

Return type:

jnp.ndarray
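
A minimal sketch of right-broadcasting, assuming it is enough to append singleton axes so that ordinary broadcasting handles the rest. The actual method may instead return an array with exactly y's shape.

```python
import jax.numpy as jnp


def bcast_right(x, y):
    """Append singleton axes to x until it has as many axes as y."""
    return x.reshape(x.shape + (1,) * (y.ndim - x.ndim))


t = jnp.array([0.1, 0.9])        # one timestep per sample, shape (2,)
x = jnp.ones((2, 4, 4, 3))       # image batch, shape (2, 4, 4, 3)
t_b = bcast_right(t, x)          # shape (2, 1, 1, 1)
scaled = t_b * x                 # each sample scaled by its own t
```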

static t_shift(t: Array, shift: float) -> Array

Shift t by a constant shift value.

Parameters:
  • t – input timestep array.

  • shift – shift value.

Returns:

shifted_t, t shifted by the shift value.

Return type:

jnp.ndarray

class interfaces.continuous.SiTInterface(*args: Any, **kwargs: Any)

Bases: Interfaces

Interface for SiT.

Transport path:

\[x_t = (1 - t) * x + t * n\]

Losses:

\[L = \mathbb{E} \Vert D(x_t, t) - (n - x) \Vert ^ 2\]

Predictions:

\[x = x_t - t * D(x_t, t)\]

__init__(network: Module, train_time_dist_type: str | TrainingTimeDistType, t_mu: float = 0.0, t_sigma: float = 1.0, n_mu: float = 0.0, n_sigma: float = 1.0, x_sigma: float = 0.5, t_shift_base: int = 4096)

c_in(t: Array) -> Array

Flow matching preconditioning.

\[c_{in} = 1\]
c_out(t: Array) -> Array

Flow matching preconditioning.

\[c_{out} = 1\]
c_skip(t: Array) -> Array

Flow matching preconditioning.

\[c_{skip} = 0\]
c_noise(t: Array) -> Array

Flow matching preconditioning.

\[c_{noise} = t\]
sample_x_t(x: Array, n: Array, t: Array) -> Array

Sample x_t defined by flow matching.

\[x_t = (1 - t) * x + t * n\]
Parameters:
  • x – input clean sample.

  • n – noise.

  • t – current timestep.

Returns:

x_t, sampled x_t according to flow matching.

Return type:

jnp.ndarray

target(x: Array, n: Array, t: Array) -> Array

Return flow matching target

\[v = n - x\]
Parameters:
  • x – input clean sample.

  • n – noise.

  • t – current timestep.

Returns:

v, flow matching target.

Return type:

jnp.ndarray
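
The transport path and its target can be checked against each other numerically. The sketch below re-implements both as free functions (they are methods in the library) and confirms that, with the exact velocity, x_t - t * v recovers the clean sample.

```python
import jax
import jax.numpy as jnp


def sample_x_t(x, n, t):
    """Linear interpolation between data x (t = 0) and noise n (t = 1)."""
    t = t.reshape(t.shape + (1,) * (x.ndim - t.ndim))
    return (1.0 - t) * x + t * n


def target(x, n, t):
    """Velocity of the linear path; it does not depend on t."""
    return n - x


key_x, key_n = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (4, 8))
n = jax.random.normal(key_n, (4, 8))
t = jnp.array([0.0, 0.25, 0.5, 1.0])

x_t = sample_x_t(x, n, t)
v = target(x, n, t)
# With the exact velocity, x = x_t - t * v recovers the clean sample.
x_rec = x_t - t[:, None] * v
```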

pred(x_t: Array, t: Array, *args, **kwargs) -> Array

Predict flow matching tangent.

\[v = D(x_t, t)\]
Parameters:
  • x_t – input noisy sample.

  • t – current timestep.

  • *args – additional arguments for network forward.

  • **kwargs – additional keyword arguments for network forward.

Returns:

v, predicted flow matching tangent.

Return type:

jnp.ndarray

score(x_t: Array, t: Array, *args, **kwargs) -> Array

Transform flow matching tangent to the score function.

\[\nabla \log p_t(x) = -(x_t + (1 - t) * D(x_t, t)) / t\]
Parameters:
  • x_t – input noisy sample.

  • t – current timestep.

  • *args – additional arguments for network forward.

  • **kwargs – additional keyword arguments for network forward.

Returns:

score, score function nabla log p_t(x).

Return type:

jnp.ndarray
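
For standard normal noise (n_mu = 0, n_sigma = 1) on the path x_t = (1 - t) * x + t * n, the expected noise is E[n | x_t] = x_t + (1 - t) * v and the score is -E[n | x_t] / t. The sketch below checks that conversion on a one-dimensional Gaussian toy problem where both the velocity and the score are known in closed form. The helper name `velocity_to_score` is hypothetical, and other noise settings would change the constants.

```python
import jax.numpy as jnp


def velocity_to_score(x_t, t, v):
    """Score implied by a velocity v on the path x_t = (1 - t) x + t n."""
    return -(x_t + (1.0 - t) * v) / t


# Toy case: x ~ N(0, s^2) and n ~ N(0, 1), so x_t ~ N(0, var_t) and both
# the exact velocity and the exact score are available in closed form.
s, t = 0.5, 0.3
var_t = (1.0 - t) ** 2 * s**2 + t**2
x_t = jnp.linspace(-2.0, 2.0, 5)

v_exact = (t - (1.0 - t) * s**2) * x_t / var_t  # E[n - x | x_t]
score_exact = -x_t / var_t                      # gradient of log N(x_t; 0, var_t)
score_from_v = velocity_to_score(x_t, t, v_exact)
```

The division by t means the conversion is undefined at t = 0, so samplers that rely on the score typically stop slightly before that endpoint.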

loss(x: Array, *args, return_aux=False, **kwargs) -> Array

Calculate flow matching loss.

\[L = \mathbb{E} \Vert D(x_t, t) - (n - x) \Vert ^ 2\]
Parameters:
  • x – input clean sample.

  • *args – additional arguments for network forward.

  • return_aux – whether to return auxiliary outputs.

  • **kwargs – additional keyword arguments for network forward.

Returns:

loss, calculated loss (or tuple with aux outputs if return_aux=True).

Return type:

jnp.ndarray or tuple
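
Putting the pieces together, a flow matching loss can be sketched as follows. The linear `toy_network`, the uniform time distribution, the explicit `key` argument, and the per-sample reduction are assumptions made for the example; the real interface draws t from its configured training distribution and calls the wrapped backbone.

```python
import jax
import jax.numpy as jnp


def toy_network(params, x_t, t):
    """Hypothetical stand-in for the backbone D(x_t, t)."""
    return x_t @ params["w"] + t[:, None] * params["b"]


def flow_matching_loss(params, key, x):
    """Per-sample squared error between D(x_t, t) and the target n - x."""
    key_t, key_n = jax.random.split(key)
    t = jax.random.uniform(key_t, (x.shape[0],))   # assumed: uniform t
    n = jax.random.normal(key_n, x.shape)          # assumed: N(0, I) noise
    x_t = (1.0 - t[:, None]) * x + t[:, None] * n  # transport path
    err = toy_network(params, x_t, t) - (n - x)
    return jnp.mean(err**2, axis=tuple(range(1, x.ndim)))


params = {"w": jnp.eye(8), "b": jnp.zeros(8)}
x = jnp.ones((4, 8))
key = jax.random.PRNGKey(0)

loss = flow_matching_loss(params, key, x)  # shape (4,)
grads = jax.grad(lambda p: flow_matching_loss(p, key, x).mean())(params)
```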

class interfaces.continuous.EDMInterface(*args: Any, **kwargs: Any)

Bases: Interfaces

Interface for EDM.

Transport Path:

\[x_t = x + t * n\]

Losses:

\[L = \mathbb{E} \Vert D(x_t, t) - x \Vert ^ 2\]

Predictions:

\[x = D(x_t, t)\]
__init__(network: Module, train_time_dist_type: str | TrainingTimeDistType, t_mu: float = 0.0, t_sigma: float = 1.0, n_mu: float = 0.0, n_sigma: float = 1.0, x_sigma: float = 0.5)

c_in(t: Array) -> Array

EDM preconditioning.

\[c_{in} = 1 / \sqrt{\text{x\_sigma} ^ 2 + t ^ 2}\]

c_out(t: Array) -> Array

EDM preconditioning.

\[c_{out} = t * \text{x\_sigma} / \sqrt{t ^ 2 + \text{x\_sigma} ^ 2}\]

c_skip(t) -> Array

EDM preconditioning.

\[c_{skip} = \text{x\_sigma} ^ 2 / (t ^ 2 + \text{x\_sigma} ^ 2)\]

c_noise(t: Array) -> Array

EDM preconditioning.

\[c_{noise} = \log(t) / 4\]
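
The four coefficients can be computed together, as in the sketch below. The way they are combined in `denoise` follows the usual EDM formulation, D = c_skip * x_t + c_out * F(c_in * x_t, c_noise), which is an assumption about this library; `edm_coefficients`, `denoise`, and `raw_network` are hypothetical names.

```python
import jax.numpy as jnp


def edm_coefficients(t, x_sigma=0.5):
    """Return (c_in, c_out, c_skip, c_noise) for noise level t."""
    total = t**2 + x_sigma**2
    c_in = 1.0 / jnp.sqrt(total)
    c_out = t * x_sigma / jnp.sqrt(total)
    c_skip = x_sigma**2 / total
    c_noise = jnp.log(t) / 4.0
    return c_in, c_out, c_skip, c_noise


def denoise(raw_network, x_t, t, x_sigma=0.5):
    """Assumed EDM-style wrapper around a raw network F."""
    c_in, c_out, c_skip, c_noise = edm_coefficients(t, x_sigma)
    return c_skip * x_t + c_out * raw_network(c_in * x_t, c_noise)


t = jnp.array(0.5)  # equal to x_sigma, so data and noise variances match
c_in, c_out, c_skip, c_noise = edm_coefficients(t)
# A network that outputs zeros leaves only the skip connection.
d = denoise(lambda z, c: jnp.zeros_like(z), jnp.ones(3), t)
```

With c_in chosen this way the network input has unit variance when the data has standard deviation x_sigma, which is the motivation for the preconditioning.
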
sample_x_t(x: Array, n: Array, t: Array) -> Array

Sample x_t defined by EDM.

\[x_t = x + t * n\]
Parameters:
  • x – input clean sample.

  • n – noise.

  • t – current timestep.

Returns:

x_t, sampled x_t according to EDM.

Return type:

jnp.ndarray

target(x: Array, n: Array, t: Array) -> Array

Return EDM target.

\[target = x\]
Parameters:
  • x – input clean sample.

  • n – noise.

  • t – current timestep.

Returns:

target, EDM target.

Return type:

jnp.ndarray

pred(x_t: Array, t: Array, *args, **kwargs) -> Array

Predict EDM tangent.

\[v = (x_t - D(x_t, t)) / t\]
Parameters:
  • x_t – input noisy sample.

  • t – current timestep.

  • *args – additional arguments for network forward.

  • **kwargs – additional keyword arguments for network forward.

Returns:

v, predicted EDM tangent.

Return type:

jnp.ndarray

score(x_t: Array, t: Array, *args, **kwargs) -> Array

Transform EDM tangent to the score function.

\[\nabla \log p_t(x) = -(x_t - D(x_t, t)) / t ^ 2 = -v / t\]
Parameters:
  • x_t – input noisy sample.

  • t – current timestep.

  • *args – additional arguments for network forward.

  • **kwargs – additional keyword arguments for network forward.

Returns:

score, score function nabla log p_t(x).

Return type:

jnp.ndarray
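
The relation between the denoiser output, the ODE tangent, and the score can be verified on a Gaussian toy problem, where the ideal denoiser is known in closed form. The function names below are hypothetical, and standard normal noise is assumed.

```python
import jax.numpy as jnp


def edm_tangent(x_t, t, denoised):
    """ODE tangent v = (x_t - D(x_t, t)) / t."""
    return (x_t - denoised) / t


def edm_score(x_t, t, denoised):
    """Score -(x_t - D(x_t, t)) / t^2, which equals -v / t."""
    return -(x_t - denoised) / t**2


# Toy case: x ~ N(0, s^2), so x_t = x + t * n ~ N(0, s^2 + t^2) and the
# ideal denoiser is E[x | x_t] = s^2 / (s^2 + t^2) * x_t.
s, t = 0.5, 2.0
x_t = jnp.linspace(-3.0, 3.0, 7)
denoised = s**2 / (s**2 + t**2) * x_t

v = edm_tangent(x_t, t, denoised)
score = edm_score(x_t, t, denoised)
```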

loss(x: Array, *args, return_aux=False, **kwargs) -> Array

Calculate EDM loss.

\[L = \mathbb{E} \Vert D(x_t, t) - x \Vert ^ 2\]
Parameters:
  • x – input clean sample.

  • *args – additional arguments for network forward.

  • return_aux – whether to return auxiliary outputs.

  • **kwargs – additional keyword arguments for network forward.

Returns:

loss, calculated loss (or tuple with aux outputs if return_aux=True).

Return type:

jnp.ndarray or tuple

class interfaces.continuous.MeanFlowInterface(*args: Any, **kwargs: Any)

Bases: SiTInterface

Interface for Mean Flow.

Transport Path:

\[x_t = (1 - t) * x + t * n\]

Losses:

\[L = \mathbb{E} \Vert u(x_t, t, r) - \text{sg}(v - (t - r) * \frac{du}{dt}) \Vert ^ 2\]

Predictions:

\[x_r = x_t - (t - r) * u(x_t, t, r)\]
__init__(network: Module, train_time_dist_type: str | TrainingTimeDistType, t_mu: float = 0.0, t_sigma: float = 1.0, n_mu: float = 0.0, n_sigma: float = 1.0, x_sigma: float = 0.5, guidance_scale: float = 1.0, guidance_mixture_ratio: float = 0.5, guidance_t_min: float = 0.0, guidance_t_max: float = 1.0, norm_eps: float = 0.001, norm_power: float = 1.0, fm_portion: float = 0.75, cond_drop_ratio: float = 0.5, t_shift_base: int = 4096)

sample_t_r(shape: tuple[int, ...]) -> tuple[Array, Array]

Sample time pairs (t, r) for Mean Flow training.

Parameters:

shape – shape of the time arrays.

Returns:

(t, r), time pairs where t >= r.

Return type:

tuple[jnp.ndarray, jnp.ndarray]
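
One way to produce ordered pairs is to draw two times independently and sort them, then collapse a fraction of the pairs to r = t. The sketch below does this with uniform draws. Reading `fm_portion` as the fraction of pairs with r = t is an assumption, as are the uniform distribution and the explicit `key` argument.

```python
import jax
import jax.numpy as jnp


def sample_t_r(key, shape, fm_portion=0.75):
    """Hypothetical sampler: pairs with t >= r, and r = t for a fraction of them."""
    key_a, key_b, key_mask = jax.random.split(key, 3)
    a = jax.random.uniform(key_a, shape)
    b = jax.random.uniform(key_b, shape)
    t, r = jnp.maximum(a, b), jnp.minimum(a, b)
    # When r == t the interval is empty and the Mean Flow target reduces
    # to the instantaneous velocity, i.e. plain flow matching.
    is_fm = jax.random.uniform(key_mask, shape) < fm_portion
    return t, jnp.where(is_fm, t, r)


t, r = sample_t_r(jax.random.PRNGKey(0), (1024,))
frac_equal = jnp.mean(t == r)
```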

cond_drop(x: Array, n: Array, v: Array, y: Array, neg_y: Array) -> tuple[Array, Array]

Drop the condition with a certain ratio.

Note: the condition must be dropped outside of the model because the effective regression target depends on the instantaneous velocity that results from the dropout.

Parameters:
  • x – input clean sample.

  • n – noise.

  • v – velocity.

  • y – condition.

  • neg_y – negative condition.

Returns:

(v, y), updated velocity and condition after dropout.

Return type:

tuple[jnp.ndarray, jnp.ndarray]

insta_velocity(x: Array, n: Array, t: Array, *args, y: Array | None = None, neg_y: Array | None = None, **kwargs) -> tuple[Array, Array, Array]

Instantaneous velocity of the mean flow. For the exact formulation, see https://arxiv.org/pdf/2505.13447.

Parameters:
  • x – input clean sample.

  • n – noise.

  • t – current timestep.

  • *args – additional arguments for network forward.

  • y – condition (optional).

  • neg_y – negative condition (optional).

  • **kwargs – additional keyword arguments for network forward.

Returns:

(v, y, neg_y), instantaneous velocity and conditions.

Return type:

tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]

target(x: Array, n: Array, t: Array, r: Array, *args, y: Array | None = None, neg_y: Array | None = None, **kwargs) -> Array

Get training target for Mean Flow.

Note: network must be augmented with r, the jump size, as an additional input

\[target = v - (t - r) * \frac{du}{dt}\]
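
The derivative du/dt in the target is a total derivative along the flow: d/dt u(x_t, t, r) = (du/dx) * v + du/dt with r held fixed. In JAX this can be obtained with a single Jacobian-vector product, as in the sketch below. `mean_flow_target` and `toy_u` are hypothetical names, and the toy function is chosen only so the result can be checked by hand.

```python
import jax
import jax.numpy as jnp


def mean_flow_target(u_fn, x_t, t, r, v):
    """Return u and sg(v - (t - r) * du/dt), with du/dt taken along the flow."""
    # Tangents (v, 1, 0) correspond to dx_t/dt = v, dt/dt = 1, dr/dt = 0.
    tangents = (v, jnp.ones_like(t), jnp.zeros_like(r))
    u, du_dt = jax.jvp(u_fn, (x_t, t, r), tangents)
    return u, jax.lax.stop_gradient(v - (t - r) * du_dt)


def toy_u(x_t, t, r):
    """Hypothetical average-velocity function with an easy derivative."""
    return x_t * t - r


x_t = jnp.array([1.0, 2.0])
v = jnp.array([0.5, -1.0])
t = jnp.array([0.8, 0.8])
r = jnp.array([0.2, 0.2])

# Here du/dt = t * v + x_t = [1.4, 1.2], so the target is v - 0.6 * du/dt.
u, tgt = mean_flow_target(toy_u, x_t, t, r, v)
```

The stop-gradient keeps the target fixed during backpropagation, so only u(x_t, t, r) on the prediction side receives gradients.
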
pred(x_t: Array, t: Array, r: Array, *args, **kwargs) -> Array

Predict ODE tangent according to the Mean Flow interface.

\[v_{(t, r)} = u(x_t, t, r)\]
loss(x: Array, *args, return_aux=False, **kwargs) -> Array

Calculate the Mean Flow loss.

\[L = \mathbb{E} \Vert u(x_t, t, r) - \text{sg}(v - (t - r) * \frac{du}{dt}) \Vert ^ 2\]

REPA Interfaces

interfaces.repa.build_mlp(hidden_size, projector_dim, feature_dim, rngs, dtype=jnp.float32)

Build a multi-layer perceptron for feature projection.

Parameters:
  • hidden_size – input hidden size.

  • projector_dim – projector dimension.

  • feature_dim – output feature dimension.

  • rngs – random number generators.

  • dtype – data type.

Returns:

mlp, multi-layer perceptron for feature projection.

Return type:

nnx.Sequential

class interfaces.repa.DiT_REPA(*args: Any, **kwargs: Any)

Bases: Module

DiT with REPA (Representation Alignment) wrapper.

This class wraps a diffusion interface with REPA functionality for representation alignment.

__init__(interface, *, feature_dim: int, repa_loss_weight: float, repa_depth: int, proj_dim: int, dtype: numpy.dtype = jnp.float32)

Initialize DiT_REPA.

Parameters:
  • interface – diffusion interface to wrap.

  • feature_dim – feature dimension for alignment.

  • repa_loss_weight – weight for REPA loss.

  • repa_depth – depth for REPA feature extraction.

  • proj_dim – projection dimension.

  • dtype – data type.

loss(x: Array, x_feature: Array, *args, **kwargs) -> tuple[Array, Array]

Calculate combined diffusion and REPA loss.

Parameters:
  • x – input clean sample.

  • x_feature – target features for alignment.

  • *args – additional arguments for interface.

  • **kwargs – additional keyword arguments for interface.

Returns:

(diffusion_loss, repa_loss), diffusion and REPA losses.

Return type:

tuple[jnp.ndarray, jnp.ndarray]

pred(*args, **kwargs) -> Array

Predict ODE tangent.

Parameters:
  • *args – arguments passed to interface.

  • **kwargs – keyword arguments passed to interface.

Returns:

tangent, predicted ODE tangent from interface.

Return type:

jnp.ndarray

score(*args, **kwargs) -> Array

Calculate score function.

Parameters:
  • *args – arguments passed to interface.

  • **kwargs – keyword arguments passed to interface.

Returns:

score, score function from interface.

Return type:

jnp.ndarray

__call__(x: Array, x_feature: Array, *args, **kwargs) -> dict[str, Array]

Forward pass with combined diffusion and REPA loss.

Parameters:
  • x – input clean sample.

  • x_feature – target features for alignment.

  • *args – additional arguments for interface.

  • **kwargs – additional keyword arguments for interface.

Returns:

losses, dictionary containing total, diffusion, and REPA losses.

Return type:

dict[str, jnp.ndarray]
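
The sketch below illustrates how an alignment loss and a weighted total might be formed. It assumes the REPA term is the negative mean cosine similarity between projected hidden states and the target features, and that the total is diffusion_loss + repa_loss_weight * repa_loss. Neither is confirmed by this reference, and the function names and dictionary keys are hypothetical.

```python
import jax
import jax.numpy as jnp


def repa_alignment_loss(projected, x_feature, eps=1e-8):
    """Negative mean cosine similarity between projected and target features."""
    p = projected / (jnp.linalg.norm(projected, axis=-1, keepdims=True) + eps)
    f = x_feature / (jnp.linalg.norm(x_feature, axis=-1, keepdims=True) + eps)
    return -jnp.mean(jnp.sum(p * f, axis=-1))


def combine_losses(diffusion_loss, repa_loss, repa_loss_weight):
    """Assumed combination; the dictionary keys are hypothetical."""
    total = diffusion_loss + repa_loss_weight * repa_loss
    return {"loss": total, "diffusion_loss": diffusion_loss, "repa_loss": repa_loss}


# Features laid out as (batch, tokens, feature_dim).
feats = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 32))
aligned = repa_alignment_loss(feats, feats)    # identical features
opposed = repa_alignment_loss(feats, -feats)   # opposite features
losses = combine_losses(jnp.array(0.4), aligned, repa_loss_weight=0.5)
```

Perfectly aligned features give a loss of -1 and opposite features give +1, so the REPA term lowers the total as alignment improves.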