Interfaces API
The interfaces module provides unified APIs for different diffusion and flow matching formulations.
Continuous Interfaces
- class interfaces.continuous.Interfaces(*args: Any, **kwargs: Any)
Base class for all diffusion / flow matching interfaces.
- All interfaces wrap a network backbone and should support:
Define the preconditioning (see EDM)
Calculate losses for training
Define the transport path (alpha_t & sigma_t)
Sample t
Sample X_t
Give the tangent for sampling
- Required RNG keys:
time: for sampling t
noise: for sampling n
- abstract c_in(t: Array) → Array
Calculate c_in for the interface.
- Parameters:
t – current timestep.
- Returns:
c_in, c_in for the interface.
- Return type:
jnp.ndarray
- abstract c_out(t: Array) → Array
Calculate c_out for the interface.
- Parameters:
t – current timestep.
- Returns:
c_out, c_out for the interface.
- Return type:
jnp.ndarray
- abstract c_skip(t: Array) → Array
Calculate c_skip for the interface.
- Parameters:
t – current timestep.
- Returns:
c_skip, c_skip for the interface.
- Return type:
jnp.ndarray
- abstract c_noise(t: Array) → Array
Calculate c_noise for the interface.
- Parameters:
t – current timestep.
- Returns:
c_noise, c_noise for the interface.
- Return type:
jnp.ndarray
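The four coefficients above are usually combined as in the EDM paper: D(x_t, t) = c_skip(t) * x_t + c_out(t) * F(c_in(t) * x_t, c_noise(t)). The following is a minimal sketch under that assumption. The coefficient formulas are the EDM paper's, `raw_network` is a hypothetical stand-in for the backbone F, and `X_SIGMA = 0.5` is an illustrative value; none of this is confirmed to match the library's implementation.

```python
import jax.numpy as jnp

X_SIGMA = 0.5  # assumed data standard deviation (illustrative value)

def c_skip(t):
    return X_SIGMA**2 / (t**2 + X_SIGMA**2)

def c_out(t):
    return t * X_SIGMA / jnp.sqrt(t**2 + X_SIGMA**2)

def c_in(t):
    return 1.0 / jnp.sqrt(t**2 + X_SIGMA**2)

def c_noise(t):
    return jnp.log(t) / 4.0

def raw_network(x_in, noise_cond):
    # Hypothetical stand-in for the backbone F; it simply returns zeros.
    return jnp.zeros_like(x_in)

def denoise(x_t, t):
    # t has shape (batch,); reshape so it broadcasts over the trailing dims of x_t.
    t_b = t.reshape((-1,) + (1,) * (x_t.ndim - 1))
    f = raw_network(c_in(t_b) * x_t, c_noise(t))
    return c_skip(t_b) * x_t + c_out(t_b) * f

x_t = jnp.ones((2, 3))
t = jnp.array([0.5, 1.0])
d = denoise(x_t, t)  # with a zero backbone this reduces to c_skip(t) * x_t
```

With a zero backbone the output is just the skip connection, which makes it easy to check the broadcasting of the per-sample coefficients.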
- abstract sample_t(shape: tuple[int, ...]) → Array
Sample t from the training time distribution.
- Parameters:
shape – shape of timestep t.
- Returns:
t, sampled timestep t.
- Return type:
jnp.ndarray
- abstract sample_n(shape: tuple[int, ...]) → Array
Sample noises.
- Parameters:
shape – shape of noise.
- Returns:
n, sampled noise.
- Return type:
jnp.ndarray
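As an illustration of how the "time" and "noise" RNG streams might feed sample_t and sample_n, here is a hedged sketch in plain JAX. The uniform and logit-normal time distributions are assumptions chosen for illustration (the actual distribution is selected by train_time_dist_type and is not described here), and the free functions below are hypothetical, not the library's methods.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Keep separate streams for time and noise, mirroring the "time" / "noise" keys.
time_key, noise_key = jax.random.split(key)

def sample_t_uniform(key, shape):
    # Assumed option 1: t ~ U(0, 1).
    return jax.random.uniform(key, shape)

def sample_t_logit_normal(key, shape, t_mu=0.0, t_sigma=1.0):
    # Assumed option 2: t = sigmoid(z), z ~ N(t_mu, t_sigma^2).
    return jax.nn.sigmoid(t_mu + t_sigma * jax.random.normal(key, shape))

def sample_n(key, shape, n_mu=0.0, n_sigma=1.0):
    # Gaussian noise with configurable mean and standard deviation.
    return n_mu + n_sigma * jax.random.normal(key, shape)

t = sample_t_logit_normal(time_key, (4,))
n = sample_n(noise_key, (4, 8, 8, 3))
```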
- abstract sample_x_t(x: Array, n: Array, t: Array) → Array
Sample X_t according to the defined interface.
- Parameters:
x – input clean sample.
n – noise.
t – current timestep.
- Returns:
x_t, sampled X_t according to transport path.
- Return type:
jnp.ndarray
- abstract target(x: Array, n: Array, t: Array) → Array
Get training target.
- Parameters:
x – input clean sample.
n – noise.
t – current timestep.
- Returns:
target, training target.
- Return type:
jnp.ndarray
- abstract pred(x_t: Array, t: Array, *args, **kwargs) → Array
Predict ODE tangent according to the defined interface.
- Parameters:
x_t – input noisy sample.
t – current timestep.
- Returns:
tangent, predicted ODE tangent.
- Return type:
jnp.ndarray
- abstract score(x_t: Array, t: Array, *args, **kwargs) → Array
Transform ODE tangent to the Score Function nabla log p_t(x).
- Parameters:
x_t – input noisy sample.
t – current timestep.
- Returns:
score, score function nabla log p_t(x).
- Return type:
jnp.ndarray
- abstract loss(x: Array, *args, **kwargs) → Array
Calculate loss for training.
- Parameters:
x – input clean sample.
args – additional arguments for network forward.
kwargs – additional keyword arguments for network forward.
- Returns:
loss, calculated loss.
- Return type:
jnp.ndarray
- static mean_flat(x: Array) → Array
Take mean w.r.t. all dimensions of x except the first.
- Parameters:
x – input array.
- Returns:
mean, mean across all dimensions except the first.
- Return type:
jnp.ndarray
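A minimal sketch of the reduction described above, written as a free function for illustration (the library's own implementation may differ in detail):

```python
import jax.numpy as jnp

def mean_flat(x):
    # Average over every axis except the leading (batch) axis.
    return jnp.mean(x, axis=tuple(range(1, x.ndim)))

x = jnp.arange(12.0).reshape(2, 2, 3)
per_sample = mean_flat(x)  # one value per batch element, shape (2,)
```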
- class interfaces.continuous.SiTInterface(*args: Any, **kwargs: Any)
Bases: Interfaces
Interface for SiT.
Transport path:
\[x_t = (1 - t) * x + t * n\]
Losses:
\[L = \mathbb{E} \Vert D(x_t, t) - (n - x) \Vert ^ 2\]
Predictions:
\[x = x_t - t * D(x_t, t)\]
- __init__(network: Module, train_time_dist_type: str | TrainingTimeDistType, t_mu: float = 0.0, t_sigma: float = 1.0, n_mu: float = 0.0, n_sigma: float = 1.0, x_sigma: float = 0.5, t_shift_base: int = 4096)
- sample_x_t(x: Array, n: Array, t: Array) → Array
Sample x_t defined by flow matching.
\[x_t = (1 - t) * x + t * n\]
- Parameters:
x – input clean sample.
n – noise.
t – current timestep.
- Returns:
x_t, sampled x_t according to flow matching.
- Return type:
jnp.ndarray
- target(x: Array, n: Array, t: Array) → Array
Return the flow matching target.
\[v = n - x\]
- Parameters:
x – input clean sample.
n – noise.
t – current timestep.
- Returns:
v, flow matching target.
- Return type:
jnp.ndarray
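The transport path, target, and prediction rule of SiTInterface fit together as follows. This is a sketch using free functions that mirror the formulas above; passing t already shaped to broadcast against x is an assumption made for the example, not a statement about how the library handles shapes.

```python
import jax
import jax.numpy as jnp

def sample_x_t(x, n, t):
    # Linear interpolation between the clean sample and the noise.
    return (1.0 - t) * x + t * n

def target(x, n, t):
    # Flow matching velocity; constant in t for the linear path.
    return n - x

key_x, key_n = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 4))
n = jax.random.normal(key_n, (2, 4))
t = jnp.array([[0.25], [0.75]])  # shaped (batch, 1) so it broadcasts against x

x_t = sample_x_t(x, n, t)
v = target(x, n, t)
# Prediction rule x = x_t - t * D(x_t, t), with the ideal output D == v.
x_rec = x_t - t * v
```

With the exact velocity, the prediction rule recovers the clean sample up to floating-point error.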
- pred(x_t: Array, t: Array, *args, **kwargs) → Array
Predict flow matching tangent.
\[v = D(x_t, t)\]
- Parameters:
x_t – input noisy sample.
t – current timestep.
*args – additional arguments for network forward.
**kwargs – additional keyword arguments for network forward.
- Returns:
v, predicted flow matching tangent.
- Return type:
jnp.ndarray
- score(x_t: Array, t: Array, *args, **kwargs) → Array
Transform flow matching tangent to the score function.
\[\nabla \log p_t(x) = -x_t - (1 - t) * D(x_t, t)\]
- Parameters:
x_t – input noisy sample.
t – current timestep.
*args – additional arguments for network forward.
**kwargs – additional keyword arguments for network forward.
- Returns:
score, score function nabla log p_t(x).
- Return type:
jnp.ndarray
- loss(x: Array, *args, return_aux=False, **kwargs) → Array
Calculate flow matching loss.
\[L = \mathbb{E} \Vert D(x_t, t) - (n - x) \Vert ^ 2\]
- Parameters:
x – input clean sample.
*args – additional arguments for network forward.
return_aux – whether to return auxiliary outputs.
**kwargs – additional keyword arguments for network forward.
- Returns:
loss, calculated loss (or tuple with aux outputs if return_aux=True).
- Return type:
jnp.ndarray or tuple
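The loss formula can be sketched end to end as below. `toy_network` is a hypothetical stand-in for D(x_t, t), the uniform time distribution is an assumption, and the sketch returns a per-sample loss; whether the library reduces further (or what it returns as auxiliary outputs) is not specified here.

```python
import jax
import jax.numpy as jnp

def toy_network(x_t, t):
    # Hypothetical stand-in for D(x_t, t); a real backbone would go here.
    return jnp.zeros_like(x_t)

def sit_loss(x, key_t, key_n):
    # One t per sample, shaped to broadcast over the remaining dims of x.
    t = jax.random.uniform(key_t, (x.shape[0],) + (1,) * (x.ndim - 1))
    n = jax.random.normal(key_n, x.shape)
    x_t = (1.0 - t) * x + t * n
    err = toy_network(x_t, t) - (n - x)
    # Squared error averaged over all non-batch dimensions.
    return jnp.mean(err**2, axis=tuple(range(1, x.ndim)))

key_t, key_n = jax.random.split(jax.random.PRNGKey(0))
loss = sit_loss(jnp.ones((4, 8)), key_t, key_n)  # shape (4,)
```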
- class interfaces.continuous.EDMInterface(*args: Any, **kwargs: Any)
Bases: Interfaces
Interface for EDM.
Transport path:
\[x_t = x + t * n\]
Losses:
\[L = \mathbb{E} \Vert D(x_t, t) - x \Vert ^ 2\]
Predictions:
\[x = D(x_t, t)\]
- __init__(network: Module, train_time_dist_type: str | TrainingTimeDistType, t_mu: float = 0.0, t_sigma: float = 1.0, n_mu: float = 0.0, n_sigma: float = 1.0, x_sigma: float = 0.5)
- c_out(t: Array) → Array
EDM preconditioning.
\[c_{out} = t * x_sigma / \sqrt{t ^ 2 + x_sigma ^ 2}\]
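To make the behaviour of this coefficient concrete, the sketch below evaluates the formula at a few noise levels. It is a direct transcription of the expression above into a free function, with x_sigma = 0.5 taken from the constructor default.

```python
import jax.numpy as jnp

def c_out(t, x_sigma=0.5):
    # c_out = t * x_sigma / sqrt(t^2 + x_sigma^2)
    return t * x_sigma / jnp.sqrt(t**2 + x_sigma**2)

t = jnp.array([0.0, 0.5, 1e3])
values = c_out(t)
# c_out vanishes as t -> 0 and saturates at x_sigma as t grows large.
```

So at low noise the network output is scaled down toward zero, while at high noise the output scale approaches x_sigma.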
- sample_x_t(x: Array, n: Array, t: Array) → Array
Sample x_t defined by EDM.
\[x_t = x + t * n\]
- Parameters:
x – input clean sample.
n – noise.
t – current timestep.
- Returns:
x_t, sampled x_t according to EDM.
- Return type:
jnp.ndarray
- target(x: Array, n: Array, t: Array) → Array
Return EDM target.
\[target = x\]
- Parameters:
x – input clean sample.
n – noise.
t – current timestep.
- Returns:
target, EDM target.
- Return type:
jnp.ndarray
- pred(x_t: Array, t: Array, *args, **kwargs) → Array
Predict EDM tangent.
\[v = (x_t - D(x_t, t)) / t\]
- Parameters:
x_t – input noisy sample.
t – current timestep.
*args – additional arguments for network forward.
**kwargs – additional keyword arguments for network forward.
- Returns:
v, predicted EDM tangent.
- Return type:
jnp.ndarray
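A short sketch of how the EDM path and tangent relate, using an ideal denoiser (D(x_t, t) = x) in place of a trained network. The free functions are illustrative transcriptions of the formulas above, and the broadcast shape of t is an assumption.

```python
import jax
import jax.numpy as jnp

def sample_x_t(x, n, t):
    # EDM path: add noise scaled by t to the clean sample.
    return x + t * n

def tangent(x_t, denoised, t):
    # ODE tangent v = (x_t - D(x_t, t)) / t.
    return (x_t - denoised) / t

key_x, key_n = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 4))
n = jax.random.normal(key_n, (2, 4))
t = jnp.array([[0.5], [2.0]])  # shaped (batch, 1) so it broadcasts against x

x_t = sample_x_t(x, n, t)
v = tangent(x_t, x, t)  # ideal denoiser: D(x_t, t) == x
x0 = x_t - t * v        # a single Euler step from t down to 0
```

With the ideal denoiser the tangent equals the noise n, and one Euler step of size t lands back on the clean sample.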
- score(x_t: Array, t: Array, *args, **kwargs) → Array
Transform EDM tangent to the score function.
\[\nabla \log p_t(x) = -(x_t - v) / t ^ 2\]
- Parameters:
x_t – input noisy sample.
t – current timestep.
*args – additional arguments for network forward.
**kwargs – additional keyword arguments for network forward.
- Returns:
score, score function nabla log p_t(x).
- Return type:
jnp.ndarray
- loss(x: Array, *args, return_aux=False, **kwargs) → Array
Calculate EDM loss.
\[L = \mathbb{E} \Vert D(x_t, t) - x \Vert ^ 2\]
- Parameters:
x – input clean sample.
*args – additional arguments for network forward.
return_aux – whether to return auxiliary outputs.
**kwargs – additional keyword arguments for network forward.
- Returns:
loss, calculated loss (or tuple with aux outputs if return_aux=True).
- Return type:
jnp.ndarray or tuple
- class interfaces.continuous.MeanFlowInterface(*args: Any, **kwargs: Any)
Bases: SiTInterface
Interface for Mean Flow.
Transport path:
\[x_t = (1 - t) * x + t * n\]
Losses:
\[L = \mathbb{E} \Vert u(x_t, t, r) - \text{sg}(v - (t - r) * \frac{du}{dt}) \Vert ^ 2\]
Predictions:
\[x_r = x_t - (t - r) * u(x_t, t, r)\]
- __init__(network: Module, train_time_dist_type: str | TrainingTimeDistType, t_mu: float = 0.0, t_sigma: float = 1.0, n_mu: float = 0.0, n_sigma: float = 1.0, x_sigma: float = 0.5, guidance_scale: float = 1.0, guidance_mixture_ratio: float = 0.5, guidance_t_min: float = 0.0, guidance_t_max: float = 1.0, norm_eps: float = 0.001, norm_power: float = 1.0, fm_portion: float = 0.75, cond_drop_ratio: float = 0.5, t_shift_base: int = 4096)
- sample_t_r(shape: tuple[int, ...]) → tuple[Array, Array]
Sample time pairs (t, r) for Mean Flow training.
- Parameters:
shape – shape of the time arrays.
- Returns:
(t, r), time pairs where t >= r.
- Return type:
tuple[jnp.ndarray, jnp.ndarray]
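One way to produce time pairs with t >= r is sketched below. This is an assumption-heavy illustration: the uniform base distribution is a placeholder, and reading fm_portion as the fraction of samples with r set equal to t (reducing to plain flow matching) is a guess based on the parameter name, not confirmed behaviour.

```python
import jax
import jax.numpy as jnp

def sample_t_r(key, shape, fm_portion=0.75):
    k1, k2, k3 = jax.random.split(key, 3)
    a = jax.random.uniform(k1, shape)
    b = jax.random.uniform(k2, shape)
    # Order the two draws so that t >= r holds elementwise.
    t = jnp.maximum(a, b)
    r = jnp.minimum(a, b)
    # Assumed meaning of fm_portion: fraction of samples where r == t.
    use_fm = jax.random.uniform(k3, shape) < fm_portion
    r = jnp.where(use_fm, t, r)
    return t, r

t, r = sample_t_r(jax.random.PRNGKey(0), (16,))
```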
- cond_drop(x: Array, n: Array, v: Array, y: Array, neg_y: Array) → tuple[Array, Array]
Drop the condition with a certain ratio.
- Note: the condition is dropped outside of the model because the effective regression target depends on the instantaneous velocity that results from the dropout.
- Parameters:
x – input clean sample.
n – noise.
v – velocity.
y – condition.
neg_y – negative condition.
- Returns:
(v, y), updated velocity and condition after dropout.
- Return type:
tuple[jnp.ndarray, jnp.ndarray]
- insta_velocity(x: Array, n: Array, t: Array, *args, y: Array | None = None, neg_y: Array | None = None, **kwargs) → tuple[Array, Array, Array]
Instantaneous velocity of the mean flow. For the exact formulation, see https://arxiv.org/pdf/2505.13447.
- Parameters:
x – input clean sample.
n – noise.
t – current timestep.
*args – additional arguments for network forward.
y – condition (optional).
neg_y – negative condition (optional).
**kwargs – additional keyword arguments for network forward.
- Returns:
(v, y, neg_y), instantaneous velocity and conditions.
- Return type:
tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
- target(x: Array, n: Array, t: Array, r: Array, *args, y: Array | None = None, neg_y: Array | None = None, **kwargs) → Array
Get training target for Mean Flow.
Note: the network must be augmented with r, the jump size, as an additional input.
\[target = v - (t - r) * \frac{du}{dt}\]
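The total derivative du/dt in this target can be obtained with a single Jacobian-vector product: along the path, dx_t/dt = v, dt/dt = 1 and dr/dt = 0. The sketch below assumes that formulation; `toy_u` is a hypothetical stand-in for the network u(x_t, t, r), and the stop-gradient mirrors the sg(.) in the loss above.

```python
import jax
import jax.numpy as jnp

def toy_u(x_t, t, r):
    # Hypothetical stand-in for the network u(x_t, t, r).
    return x_t * t

def meanflow_target(u_fn, x_t, v, t, r):
    # Tangents (v, 1, 0) give du/dt = v * d_x u + d_t u with r held fixed.
    tangents = (v, jnp.ones_like(t), jnp.zeros_like(r))
    _, dudt = jax.jvp(u_fn, (x_t, t, r), tangents)
    # The target is treated as a constant during backpropagation.
    return jax.lax.stop_gradient(v - (t - r) * dudt)

x_t = jnp.array([[1.0, 2.0]])
v = jnp.array([[0.5, -1.0]])
t = jnp.array([[0.8]])
r = jnp.array([[0.2]])
tgt = meanflow_target(toy_u, x_t, v, t, r)
```

For the toy function, du/dt = v * t + x_t, so the result can be checked by hand against v - (t - r) * (v * t + x_t).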
REPA Interfaces
- interfaces.repa.build_mlp(hidden_size, projector_dim, feature_dim, rngs, dtype=<class 'jax.numpy.float32'>)
Build a multi-layer perceptron for feature projection.
- Parameters:
hidden_size – input hidden size.
projector_dim – projector dimension.
feature_dim – output feature dimension.
rngs – random number generators.
dtype – data type.
- Returns:
mlp, multi-layer perceptron for feature projection.
- Return type:
nnx.Sequential
- class interfaces.repa.DiT_REPA(*args: Any, **kwargs: Any)
Bases: Module
DiT with REPA (Representation Alignment) wrapper.
This class wraps a diffusion interface with REPA functionality for representation alignment.
- __init__(interface, *, feature_dim: int, repa_loss_weight: float, repa_depth: int, proj_dim: int, dtype: numpy.dtype = <class 'jax.numpy.float32'>)
Initialize DiT_REPA.
- Parameters:
interface – diffusion interface to wrap.
feature_dim – feature dimension for alignment.
repa_loss_weight – weight for REPA loss.
repa_depth – depth for REPA feature extraction.
proj_dim – projection dimension.
dtype – data type.
- loss(x: Array, x_feature: Array, *args, **kwargs) → tuple[Array, Array]
Calculate combined diffusion and REPA loss.
- Parameters:
x – input clean sample.
x_feature – target features for alignment.
*args – additional arguments for interface.
**kwargs – additional keyword arguments for interface.
- Returns:
(diffusion_loss, repa_loss), diffusion and REPA losses.
- Return type:
tuple[jnp.ndarray, jnp.ndarray]
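The form of the REPA term is not spelled out here. As a rough illustration, the sketch below assumes the alignment objective from the REPA paper, a negative cosine similarity between projected hidden states and the target features, averaged over tokens and batch. The function name and the (batch, tokens, feature_dim) layout are hypothetical.

```python
import jax.numpy as jnp

def repa_alignment_loss(projected, x_feature, eps=1e-8):
    # Normalize both feature sets along the feature axis.
    p = projected / (jnp.linalg.norm(projected, axis=-1, keepdims=True) + eps)
    f = x_feature / (jnp.linalg.norm(x_feature, axis=-1, keepdims=True) + eps)
    # Negative mean cosine similarity: lower is better aligned.
    return -jnp.mean(jnp.sum(p * f, axis=-1))

feat = jnp.array([[[1.0, 0.0], [0.0, 2.0]]])  # (batch, tokens, feature_dim)
aligned = repa_alignment_loss(feat, feat)     # perfectly aligned features
opposed = repa_alignment_loss(-feat, feat)    # features pointing the opposite way
```

Under this assumption the term ranges from -1 (perfect alignment) to +1 (opposite directions).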
- pred(*args, **kwargs) → Array
Predict ODE tangent.
- Parameters:
*args – arguments passed to interface.
**kwargs – keyword arguments passed to interface.
- Returns:
tangent, predicted ODE tangent from interface.
- Return type:
jnp.ndarray
- score(*args, **kwargs) → Array
Calculate score function.
- Parameters:
*args – arguments passed to interface.
**kwargs – keyword arguments passed to interface.
- Returns:
score, score function from interface.
- Return type:
jnp.ndarray