Interfaces

The interfaces module is the core of the library. It exposes a unified interface to different diffusion and flow matching formulations.

Overview

This module contains the core abstractions that allow you to work with different diffusion algorithms through a consistent API. The main interfaces are:

  • Continuous-time interfaces: For algorithms like SiT, EDM, and MeanFlow

  • REPA wrapper: For representation alignment

  • Discrete-time interfaces: Currently experimental

Core Components

Base Interface

The interfaces.continuous.Interfaces class is the abstract base class for all diffusion and flow matching interfaces. It provides:

  • Unified API across different algorithms

  • Support for both deterministic and stochastic sampling

  • Flexible time scheduling

  • JAX/NNX compatibility

  • Required RNG infrastructure for time and noise sampling

Continuous-time Interfaces

The library provides several concrete implementations:

  • SiTInterface: SiT-style (Scalable Interpolant Transformers) transport with linear interpolation between data and noise

  • EDMInterface: EDM-style variance preserving diffusion with log-normal time families

  • MeanFlowInterface: MeanFlow (average-velocity) flow matching with guidance mixing and stochastic jump times

  • sCTInterface/sCDInterface: Continuous-time consistency training (sCT) and consistency distillation (sCD) (experimental)
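To make the idea of linear interpolation between data and noise concrete, here is a minimal sketch in plain JAX. It assumes the convention that t=0 is data and t=1 is noise; the library may use the opposite direction. The function names `linear_path` and `velocity_target` are hypothetical and are not part of the library.

```python
import jax
import jax.numpy as jnp

def linear_path(x, n, t):
    """Noisy state on a straight line from data x (t=0) to noise n (t=1)."""
    # Reshape per-sample times so they broadcast over the feature dimensions.
    t = t.reshape(-1, *([1] * (x.ndim - 1)))
    return (1.0 - t) * x + t * n

def velocity_target(x, n, t):
    """Time derivative of the linear path; it does not depend on t."""
    del t  # unused for a straight-line path
    return n - x

key_x, key_n, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (4, 8))   # toy "data" batch
n = jax.random.normal(key_n, (4, 8))   # Gaussian noise
t = jax.random.uniform(key_t, (4,))    # one time per sample
x_t = linear_path(x, n, t)
v = velocity_target(x, n, t)
```

With this convention, stepping back along the velocity recovers the data: `x_t - t * v` equals `x`.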

REPA Interface

The interfaces.repa.DiT_REPA class provides a wrapper for representation alignment:

  • Wraps any diffusion interface

  • Adds representation alignment capabilities

  • Aims to improve sample quality by aligning the network's intermediate features with pretrained representations

  • Uses feature projection networks for alignment
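The alignment step in the list above can be sketched as a negative cosine similarity between projected hidden features and target encoder features. This is only an illustration: a single linear projection and random arrays stand in for the projection network and the pretrained encoder, and `alignment_loss` is a hypothetical name, not the library's API.

```python
import jax
import jax.numpy as jnp

def alignment_loss(projected, target, eps=1e-8):
    """Mean negative cosine similarity over batch and tokens."""
    p = projected / (jnp.linalg.norm(projected, axis=-1, keepdims=True) + eps)
    z = target / (jnp.linalg.norm(target, axis=-1, keepdims=True) + eps)
    return -jnp.mean(jnp.sum(p * z, axis=-1))

key_h, key_w, key_z = jax.random.split(jax.random.PRNGKey(0), 3)
hidden = jax.random.normal(key_h, (2, 16, 64))         # (batch, tokens, hidden_dim)
w_proj = 0.1 * jax.random.normal(key_w, (64, 32))      # toy linear projection
encoder_feats = jax.random.normal(key_z, (2, 16, 32))  # stand-in for encoder output
loss = alignment_loss(hidden @ w_proj, encoder_feats)
```

A weight such as `repa_loss_weight` would presumably scale this term before it is added to the main training loss, but the exact combination depends on the library's implementation.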

Usage Examples

Basic Interface Usage

Because interfaces and samplers share a unified API, you can pair any interface with any sampler using the following syntax:

import jax
import jax.numpy as jnp
from flax import nnx
from networks.transformers.dit_nnx import DiT
from interfaces.continuous import SiTInterface
from samplers.samplers import HeunSampler

# Create DiT network
# NNX modules conventionally take an nnx.Rngs container rather than a raw PRNGKey
network = DiT(
    input_size=32,
    hidden_size=1152,
    depth=28,
    num_heads=16,
    rngs=nnx.Rngs(0)
)

# Create SiT interface
interface = SiTInterface(
    network=network,
    train_time_dist_type='uniform'
)

# Create sampler
sampler = HeunSampler(num_steps=32)

# Generate samples
key = jax.random.PRNGKey(42)
x = jax.random.normal(key, (4, 32, 32, 4))
samples = sampler.sample(interface, x)
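For intuition about what a Heun-type sampler does with the interface's prediction, the sketch below integrates a generic velocity field with Heun's second-order method. It is a standalone illustration under simplifying assumptions (a uniform time grid and a toy velocity field), not the implementation of `HeunSampler`; `heun_integrate` is a hypothetical helper.

```python
import jax.numpy as jnp

def heun_integrate(velocity, x, t_start, t_end, num_steps):
    """Integrate dx/dt = velocity(x, t) from t_start to t_end with Heun's method."""
    ts = jnp.linspace(t_start, t_end, num_steps + 1)
    for i in range(num_steps):
        t, t_next = ts[i], ts[i + 1]
        h = t_next - t
        d1 = velocity(x, t)              # slope at the current point
        x_euler = x + h * d1             # Euler predictor
        d2 = velocity(x_euler, t_next)   # slope at the predicted point
        x = x + 0.5 * h * (d1 + d2)      # trapezoidal corrector
    return x

# Toy check on dx/dt = -x, whose exact solution at t=1 is x0 * exp(-1).
x0 = jnp.ones((4,))
x1 = heun_integrate(lambda x, t: -x, x0, 0.0, 1.0, num_steps=32)
```

Each step evaluates the velocity twice, so a 32-step run costs 64 evaluations of the underlying model.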

REPA Usage

from interfaces.repa import DiT_REPA
from samplers.samplers import EulerMaruyamaSampler

# Create REPA wrapper around the `interface` built in the previous example
repa_interface = DiT_REPA(
    interface=interface,
    feature_dim=512,
    repa_loss_weight=0.1,
    repa_depth=6,
    proj_dim=256
)

# Use with stochastic sampler
# NOTE: `params` is not defined in this snippet; `key` is the PRNG key from the previous example
sampler = EulerMaruyamaSampler(num_steps=250)
samples = sampler.sample(repa_interface, params, key, batch_size=4)
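To illustrate the kind of update a stochastic sampler performs, here is a minimal Euler-Maruyama integrator for a generic SDE dx = drift(x, t) dt + diffusion(t) dW. It is a sketch on a toy Ornstein-Uhlenbeck process, not the library's `EulerMaruyamaSampler`; the `euler_maruyama` function and its arguments are hypothetical.

```python
import jax
import jax.numpy as jnp

def euler_maruyama(drift, diffusion, x, key, t_start, t_end, num_steps):
    """Simulate dx = drift(x, t) dt + diffusion(t) dW on a uniform time grid."""
    h = (t_end - t_start) / num_steps

    def step(carry, i):
        x, key = carry
        key, sub = jax.random.split(key)  # fresh randomness for every step
        t = t_start + i * h
        noise = jax.random.normal(sub, x.shape)
        x = x + drift(x, t) * h + diffusion(t) * jnp.sqrt(h) * noise
        return (x, key), None

    (x, _), _ = jax.lax.scan(step, (x, key), jnp.arange(num_steps))
    return x

# Toy Ornstein-Uhlenbeck process: dx = -x dt + sqrt(2) dW, stationary law N(0, 1).
x_init = jnp.zeros((4096,))
x_final = euler_maruyama(
    drift=lambda x, t: -x,
    diffusion=lambda t: jnp.sqrt(2.0),
    x=x_init,
    key=jax.random.PRNGKey(0),
    t_start=0.0,
    t_end=5.0,
    num_steps=250,
)
```

In a diffusion sampler the drift would typically combine the model's prediction with its score; here both are replaced by closed-form toy functions.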

Advanced Usage

Custom Algorithms

You can implement custom diffusion algorithms by extending the base interface:

from interfaces.continuous import Interfaces, TrainingTimeDistType

class CustomInterface(Interfaces):
    def __init__(self, network, train_time_dist_type='uniform'):
        super().__init__(network, train_time_dist_type)
        # Initialize your algorithm

    def c_in(self, t):
        # Implement c_in preconditioning
        pass

    def c_out(self, t):
        # Implement c_out preconditioning
        pass

    def c_skip(self, t):
        # Implement c_skip preconditioning
        pass

    def c_noise(self, t):
        # Implement c_noise preconditioning
        pass

    def sample_t(self, shape):
        # Implement time sampling
        pass

    def sample_n(self, shape):
        # Implement noise sampling
        pass

    def sample_x_t(self, x, n, t):
        # Implement transport path
        pass

    def target(self, x, n, t):
        # Implement training target
        pass

    def pred(self, x_t, t, *args, **kwargs):
        # Implement prediction
        pass

    def score(self, x_t, t, *args, **kwargs):
        # Implement score function
        pass

    def loss(self, x, *args, **kwargs):
        # Implement loss calculation
        pass
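As an example of what the `c_in`, `c_out`, `c_skip`, and `c_noise` hooks might compute, the sketch below uses the preconditioning coefficients from the EDM formulation, written as functions of a noise level `sigma`. Whether the library parameterizes these by `sigma` or by `t`, and the value of `SIGMA_DATA`, are assumptions here; `denoise` is a hypothetical helper showing how the four coefficients wrap a raw network.

```python
import jax.numpy as jnp

SIGMA_DATA = 0.5  # assumed data standard deviation (the EDM default)

def c_skip(sigma):
    return SIGMA_DATA**2 / (sigma**2 + SIGMA_DATA**2)

def c_out(sigma):
    return sigma * SIGMA_DATA / jnp.sqrt(sigma**2 + SIGMA_DATA**2)

def c_in(sigma):
    return 1.0 / jnp.sqrt(sigma**2 + SIGMA_DATA**2)

def c_noise(sigma):
    return jnp.log(sigma) / 4.0

def denoise(raw_network, x_t, sigma):
    """Preconditioned denoiser: skip connection plus scaled network output."""
    s = sigma.reshape(-1, *([1] * (x_t.ndim - 1)))  # broadcast over feature dims
    return c_skip(s) * x_t + c_out(s) * raw_network(c_in(s) * x_t, c_noise(sigma))

# Toy usage with a "network" that always outputs zeros.
sigma = jnp.array([0.5, 0.5])
x_t = jnp.ones((2, 3))
out = denoise(lambda x, cond: jnp.zeros_like(x), x_t, sigma)
```

The coefficients keep the network's input and training target at roughly unit variance across noise levels, which is the motivation for preconditioning in the first place.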

Time Distribution Types

The interfaces support different time distribution types:

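# EDMInterface is assumed to be exported from the same module as SiTInterface
from interfaces.continuous import SiTInterface, EDMInterface

# Uniform time distribution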
interface = SiTInterface(network, train_time_dist_type='uniform')

# Log-normal time distribution
interface = EDMInterface(network, train_time_dist_type='lognormal')

# Logit-normal time distribution
interface = SiTInterface(network, train_time_dist_type='logitnormal')
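The three distributions named above can be sampled in a few lines of JAX. This is a sketch of the underlying math rather than the library's `sample_t` implementation; the function names and the default means and standard deviations are assumptions (the log-normal defaults follow the values commonly used with EDM).

```python
import jax
import jax.numpy as jnp

def sample_uniform(key, shape, t_min=0.0, t_max=1.0):
    """t ~ U[t_min, t_max)."""
    return jax.random.uniform(key, shape, minval=t_min, maxval=t_max)

def sample_lognormal(key, shape, mean=-1.2, std=1.2):
    """log(t) ~ N(mean, std^2), so t is positive and unbounded above."""
    return jnp.exp(mean + std * jax.random.normal(key, shape))

def sample_logitnormal(key, shape, mean=0.0, std=1.0):
    """logit(t) ~ N(mean, std^2), so t lies strictly inside (0, 1)."""
    return jax.nn.sigmoid(mean + std * jax.random.normal(key, shape))

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
t_uniform = sample_uniform(k1, (1024,))
t_lognormal = sample_lognormal(k2, (1024,))
t_logitnormal = sample_logitnormal(k3, (1024,))
```

Uniform sampling weights all times equally, while the log-normal and logit-normal choices concentrate training on intermediate noise levels.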

Interface Methods

All interfaces provide a consistent API:

# Calculate loss for training
loss_dict = interface(x, *args, **kwargs)

# Get prediction for sampling
prediction = interface.pred(x_t, t, *args, **kwargs)

# Get score function for SDE sampling
score = interface.score(x_t, t, *args, **kwargs)

# Sample noisy state
x_t = interface.sample_x_t(x, n, t)

# Get training target
target = interface.target(x, n, t)
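To show how a prediction and a score can relate, the sketch below converts a velocity into a score for one specific path: x_t = (1 - t) x + t n with Gaussian noise n and velocity n - x. Both the path convention and the helper name `score_from_velocity` are assumptions for illustration; the library's `score` method may use a different parameterization.

```python
import jax
import jax.numpy as jnp

def score_from_velocity(v, x_t, t):
    """For x_t = (1 - t) x + t n and v = n - x, the noise is n = x_t + (1 - t) v,
    and the conditional score of x_t given x is -n / t."""
    t = t.reshape(-1, *([1] * (x_t.ndim - 1)))  # broadcast over feature dims
    return -(x_t + (1.0 - t) * v) / t

key_x, key_n, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (4, 8))
n = jax.random.normal(key_n, (4, 8))
t = jax.random.uniform(key_t, (4,), minval=0.1, maxval=0.9)  # stay away from t=0
x_t = (1.0 - t[:, None]) * x + t[:, None] * n
v = n - x
score = score_from_velocity(v, x_t, t)
```

The division by `t` is why this conversion becomes numerically unstable close to the data end of the path.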