Samplers

The samplers module provides sampling strategies for diffusion models, covering both deterministic (ODE) and stochastic (SDE) integration.

Overview

This module contains interface-agnostic samplers that can work with any diffusion interface. The samplers support:

  • Deterministic sampling: Euler and Heun methods

  • Stochastic sampling: Euler-Maruyama for SDE integration

  • Flexible time scheduling: Uniform and exponential schedules

  • Multiple time variables: Support for two-time-variable models such as MeanFlow

  • Guidance support: Optional guidance networks for conditional generation

Available Samplers

Base Sampler

The samplers.samplers.Samplers class is the abstract base class for all samplers. It provides:

  • Time grid generation with different schedules

  • Main sampling loop built on a scan, for high performance and compatibility with JAX/NNX

  • Support for guidance networks and custom time grids

  • Interface-agnostic design
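The split between the shared loop and the per-step methods can be pictured with a minimal, stdlib-only sketch. It is an illustration under stated assumptions, not the library's code: ToySampler and ToyEuler are hypothetical names, states are plain floats rather than JAX arrays, a Python for-loop stands in for the scan, and it assumes every interval except the last goes through forward while the final one goes through last_step (the two methods a custom sampler implements, as shown under Custom Samplers).

```python
from abc import ABC, abstractmethod
from typing import Callable, Sequence

# Hypothetical network type: maps (x, t) to a velocity dx/dt
Velocity = Callable[[float, float], float]


class ToySampler(ABC):
    """Simplified, hypothetical stand-in for the Samplers base class."""

    @abstractmethod
    def forward(self, net: Velocity, x: float, t_curr: float, t_next: float) -> float:
        """Advance x over one interval of the grid."""

    @abstractmethod
    def last_step(self, net: Velocity, x: float, t_curr: float, t_last: float) -> float:
        """Advance x over the final interval."""

    def sample(self, net: Velocity, x: float, timegrid: Sequence[float]) -> float:
        # The library runs its main loop as a JAX/NNX scan; a plain loop is used here
        for t_curr, t_next in zip(timegrid[:-2], timegrid[1:-1]):
            x = self.forward(net, x, t_curr, t_next)
        # Assumption: the final interval is always delegated to last_step
        return self.last_step(net, x, timegrid[-2], timegrid[-1])


class ToyEuler(ToySampler):
    """Concrete subclass: plain Euler steps everywhere, including the last one."""

    def forward(self, net, x, t_curr, t_next):
        return x + (t_next - t_curr) * net(x, t_curr)

    def last_step(self, net, x, t_curr, t_last):
        return self.forward(net, x, t_curr, t_last)


# A constant velocity of 2 integrated from t = 1 to t = 0 lowers x by 2
x0 = ToyEuler().sample(lambda x, t: 2.0, 5.0, [1.0, 0.75, 0.5, 0.25, 0.0])
print(x0)  # 3.0
```

Keeping the loop in the base class means a new sampler only has to describe a single step, which is what makes the design interface-agnostic.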

Deterministic Samplers

Euler Sampler

First-order deterministic sampler for ODE integration:

from samplers.samplers import EulerSampler
from samplers.samplers import SamplingTimeDistType

sampler = EulerSampler(
    num_sampling_steps=50,
    sampling_time_dist=SamplingTimeDistType.UNIFORM
)
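The Euler update is x_next = x + (t_next - t_curr) * v(x, t_curr), with one network evaluation per step. The sketch below assumes the network predicts a velocity dx/dt (the actual interface may parameterize its output differently) and uses the toy ODE dx/dt = x, whose exact solution from x(1) = 1 back to t = 0 is exp(-1). Doubling the number of steps roughly halves the error, as expected of a first-order method. All names are hypothetical.

```python
import math
from typing import Callable

Velocity = Callable[[float, float], float]  # hypothetical (x, t) -> dx/dt


def euler_step(v: Velocity, x: float, t_curr: float, t_next: float) -> float:
    """One explicit Euler step: a single velocity evaluation at t_curr."""
    return x + (t_next - t_curr) * v(x, t_curr)


def integrate(step, v: Velocity, x: float, num_steps: int) -> float:
    """Apply `step` over a uniform grid from t = 1 down to t = 0."""
    ts = [1.0 - i / num_steps for i in range(num_steps + 1)]
    for t_curr, t_next in zip(ts[:-1], ts[1:]):
        x = step(v, x, t_curr, t_next)
    return x


def velocity(x: float, t: float) -> float:
    return x  # toy ODE dx/dt = x


exact = math.exp(-1.0)  # x(0) when x(1) = 1
err_16 = abs(integrate(euler_step, velocity, 1.0, 16) - exact)
err_32 = abs(integrate(euler_step, velocity, 1.0, 32) - exact)
print(err_16 / err_32)  # roughly 2: first-order convergence
```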

Heun Sampler

Second-order deterministic sampler (recommended for most cases):

from samplers.samplers import HeunSampler, SamplingTimeDistType

sampler = HeunSampler(
    num_sampling_steps=32,
    sampling_time_dist=SamplingTimeDistType.UNIFORM
)
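Heun's method adds a corrector: it takes an Euler step to predict x at t_next, evaluates the slope there, and averages the two slopes. This costs two network evaluations per step instead of one but reduces the global error from O(h) to O(h^2). A stdlib-only sketch on the same toy ODE (dx/dt = x, velocity-predicting network assumed, helper names hypothetical) shows doubling the steps cutting the error by roughly four:

```python
import math
from typing import Callable

Velocity = Callable[[float, float], float]  # hypothetical (x, t) -> dx/dt


def heun_step(v: Velocity, x: float, t_curr: float, t_next: float) -> float:
    """One Heun (explicit trapezoidal) step: two velocity evaluations."""
    dt = t_next - t_curr
    d1 = v(x, t_curr)                # slope at the start of the interval
    x_pred = x + dt * d1             # Euler predictor
    d2 = v(x_pred, t_next)           # slope at the predicted end point
    return x + dt * 0.5 * (d1 + d2)  # corrector: average of the two slopes


def integrate(step, v: Velocity, x: float, num_steps: int) -> float:
    """Apply `step` over a uniform grid from t = 1 down to t = 0."""
    ts = [1.0 - i / num_steps for i in range(num_steps + 1)]
    for t_curr, t_next in zip(ts[:-1], ts[1:]):
        x = step(v, x, t_curr, t_next)
    return x


def velocity(x: float, t: float) -> float:
    return x  # toy ODE dx/dt = x


exact = math.exp(-1.0)
err_16 = abs(integrate(heun_step, velocity, 1.0, 16) - exact)
err_32 = abs(integrate(heun_step, velocity, 1.0, 32) - exact)
print(err_16 / err_32)  # roughly 4: second-order convergence
```

This is why Heun with 32 steps can match or beat Euler with many more steps, despite the extra evaluation per step.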

Stochastic Samplers

Euler-Maruyama Sampler

Stochastic sampler for SDE integration with configurable diffusion coefficients:

from samplers.samplers import EulerMaruyamaSampler
from samplers.samplers import DiffusionCoeffType, SamplingTimeDistType

sampler = EulerMaruyamaSampler(
    num_sampling_steps=250,
    sampling_time_dist=SamplingTimeDistType.UNIFORM,
    diffusion_coeff=DiffusionCoeffType.LINEAR_KL,
    diffusion_coeff_norm=1.0
)
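Euler-Maruyama extends the Euler step with a Gaussian increment: x_next = x + f(x, t) dt + g(t) sqrt(|dt|) z with z ~ N(0, 1). In the stdlib-only sketch below, the drift f and diffusion g are placeholder callables; how the library builds them from the network output, DiffusionCoeffType.LINEAR_KL and diffusion_coeff_norm is not covered here. With g = 0 the step reduces to a deterministic Euler step.

```python
import math
import random
from typing import Callable

Drift = Callable[[float, float], float]   # placeholder f(x, t)
Diffusion = Callable[[float], float]      # placeholder g(t)


def euler_maruyama_step(
    f: Drift, g: Diffusion, x: float, t_curr: float, t_next: float, rng: random.Random
) -> float:
    """One Euler-Maruyama step for dx = f(x, t) dt + g(t) dW."""
    dt = t_next - t_curr
    z = rng.gauss(0.0, 1.0)  # standard normal sample
    # The Brownian increment has std sqrt(|dt|); |dt| because the grid runs from t = 1 to t = 0
    return x + f(x, t_curr) * dt + g(t_curr) * math.sqrt(abs(dt)) * z


rng = random.Random(0)
# With g = 0 the noise term vanishes and this is an ordinary Euler step
x_det = euler_maruyama_step(lambda x, t: -x, lambda t: 0.0, 1.0, 1.0, 0.5, rng)
print(x_det)  # 1.5
```

Because each step injects fresh noise, stochastic samplers typically need more steps than deterministic ones, consistent with the 250 steps in the example above.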

Specialized Samplers

EulerJump Sampler

For two-time-variable models such as MeanFlow, whose network is conditioned on both the current and the target time:

from samplers.samplers import EulerJumpSampler, SamplingTimeDistType

sampler = EulerJumpSampler(
    num_sampling_steps=50,
    sampling_time_dist=SamplingTimeDistType.UNIFORM
)
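Two-time-variable models such as MeanFlow predict an average velocity u(x, r, t) over an interval [r, t] rather than an instantaneous one, so a single step can jump x_r = x_t - (t - r) u(x_t, r, t). The sketch below assumes EulerJumpSampler applies this MeanFlow update between consecutive grid times; the argument order of u and all names are hypothetical. On the curved toy trajectory x(t) = a + b t^2, the exact average velocity lands on x(0) in one jump, whereas an Euler step with the instantaneous velocity 2bt does not.

```python
from typing import Callable

# Hypothetical average-velocity network: u(x, r, t) ~ (x(t) - x(r)) / (t - r)
AvgVelocity = Callable[[float, float, float], float]


def euler_jump_step(u: AvgVelocity, x: float, t_curr: float, t_next: float) -> float:
    """MeanFlow-style jump from t_curr to t_next: x_r = x_t - (t - r) * u(x_t, r, t)."""
    return x - (t_curr - t_next) * u(x, t_next, t_curr)


# Toy curved trajectory x(t) = a + b * t**2
a, b = 0.5, 2.0


def u_exact(x: float, r: float, t: float) -> float:
    return b * (r + t)  # exact mean velocity of the toy trajectory over [r, t]


x1 = a + b * 1.0 ** 2                             # state at t = 1
x0_jump = euler_jump_step(u_exact, x1, 1.0, 0.0)  # one jump lands exactly on a
x0_euler = x1 - 1.0 * (2 * b * 1.0)               # Euler with instantaneous velocity 2bt overshoots
print(x0_jump, x0_euler)  # 0.5 -1.5
```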

Time Scheduling

Uniform Schedule

Equal time steps between t_start and t_end:

from samplers.samplers import HeunSampler, SamplingTimeDistType

sampler = HeunSampler(
    num_sampling_steps=32,
    sampling_time_dist=SamplingTimeDistType.UNIFORM,
    sampling_time_kwargs={
        't_start': 1.0,
        't_end': 0.0
    }
)
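A uniform grid is linear interpolation between t_start and t_end. The sketch below assumes num_sampling_steps counts intervals, giving num_sampling_steps + 1 grid points; whether the library counts intervals or points is not stated on this page. uniform_grid is a hypothetical helper.

```python
from typing import List


def uniform_grid(num_steps: int, t_start: float = 1.0, t_end: float = 0.0) -> List[float]:
    """num_steps equal intervals from t_start to t_end (num_steps + 1 points)."""
    dt = (t_end - t_start) / num_steps
    # Pin the last point to t_end exactly instead of relying on floating-point rounding
    return [t_start + i * dt for i in range(num_steps)] + [t_end]


print(uniform_grid(4))  # [1.0, 0.75, 0.5, 0.25, 0.0]
```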

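Exponential Schedule

EDM-style noise-level schedule, selected with SamplingTimeDistType.EXP. The parameters match the EDM schedule (Karras et al., 2022), which spaces noise levels between sigma_max and sigma_min; a larger rho concentrates more steps near sigma_min: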

sampler = HeunSampler(
    num_sampling_steps=32,
    sampling_time_dist=SamplingTimeDistType.EXP,
    sampling_time_kwargs={
        'sigma_min': 0.002,
        'sigma_max': 80.0,
        'rho': 7.0
    }
)
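The EDM schedule interpolates linearly in sigma^(1/rho) and raises the result back to the power rho, so step sizes shrink toward sigma_min. The stdlib-only sketch follows the formula from Karras et al. (2022), including EDM's trailing sigma = 0; edm_sigmas is a hypothetical name, and how the library maps these noise levels onto its time variable is not covered here.

```python
from typing import List


def edm_sigmas(
    num_steps: int, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: float = 7.0
) -> List[float]:
    """EDM noise levels (Karras et al., 2022): sigma_max down to sigma_min, then 0."""
    hi = sigma_max ** (1.0 / rho)
    lo = sigma_min ** (1.0 / rho)
    # Interpolate linearly in sigma**(1/rho) space, then raise back to the power rho
    sigmas = [(hi + i / (num_steps - 1) * (lo - hi)) ** rho for i in range(num_steps)]
    return sigmas + [0.0]  # EDM appends a final sigma of 0


s = edm_sigmas(32)
print(s[0], s[-2], s[-1])  # approximately 80.0 and 0.002, then exactly 0.0
```

Spending small steps at low noise levels is where fine detail is resolved; rho = 7 is the value recommended in the EDM paper.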

Custom Time Grid

To bypass the built-in schedules, pass your own time steps to sample() via custom_timegrid:

import jax.numpy as jnp

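# rng, interface, and the initial state x are assumed to be defined already
custom_times = jnp.linspace(1, 0, 50)  # 50 points from t = 1 down to t = 0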
samples = sampler.sample(
    rng, interface, x,
    custom_timegrid=custom_times
)
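Any monotone sequence can serve as a custom grid. As an illustration with hypothetical helpers, power_grid below spaces points by (1 - i/N)^power, which for power > 1 packs steps more densely toward t_end, and check_grid verifies that the grid is strictly decreasing, matching the t = 1 to t = 0 direction used on this page (an assumption about what the library expects). The resulting list can be converted with jnp.asarray before being passed as custom_timegrid.

```python
from typing import List


def power_grid(num_steps: int, power: float = 2.0, t_start: float = 1.0, t_end: float = 0.0) -> List[float]:
    """Hypothetical helper: spacing shrinks toward t_end when power > 1."""
    return [t_end + (t_start - t_end) * (1.0 - i / num_steps) ** power for i in range(num_steps + 1)]


def check_grid(ts: List[float]) -> None:
    """Hypothetical check: require a strictly decreasing grid (t = 1 toward t = 0)."""
    if len(ts) < 2 or any(nxt >= cur for cur, nxt in zip(ts, ts[1:])):
        raise ValueError("time grid must have at least two strictly decreasing points")


grid = power_grid(4)
check_grid(grid)
print(grid)  # [1.0, 0.5625, 0.25, 0.0625, 0.0]
```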

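Sampler Configuration

All samplers share the same base configuration: the number of sampling steps, the time distribution, and schedule-specific options passed in sampling_time_kwargs. The uniform schedule, for instance, also accepts the time-shift parameters t_shift_base and t_shift_cur: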

sampler = HeunSampler(
    num_sampling_steps=32,
    sampling_time_dist=SamplingTimeDistType.UNIFORM,
    sampling_time_kwargs={
        't_start': 1.0,
        't_end': 0.0,
        't_shift_base': 4096,
        't_shift_cur': 4096
    }
)
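This page does not define t_shift_base and t_shift_cur. Their names suggest a resolution-dependent time shift; one widely used form (Stable Diffusion 3, Esser et al., 2024) maps t to alpha t / (1 + (alpha - 1) t) with alpha = sqrt(t_shift_cur / t_shift_base). The sketch below implements that form as an assumption, not as the library's confirmed behavior, and its helper names are hypothetical. It keeps t = 0 and t = 1 fixed, is the identity when the two values are equal (as in the example above), and for alpha > 1 moves grid points toward t = 1.

```python
import math
from typing import List


def shift_time(t: float, t_shift_base: float, t_shift_cur: float) -> float:
    """Assumed SD3-style shift: alpha * t / (1 + (alpha - 1) * t)."""
    alpha = math.sqrt(t_shift_cur / t_shift_base)
    return alpha * t / (1.0 + (alpha - 1.0) * t)


def shifted_uniform_grid(num_steps: int, t_shift_base: float, t_shift_cur: float) -> List[float]:
    """Uniform grid from t = 1 to t = 0 with the shift applied pointwise (hypothetical)."""
    return [shift_time(1.0 - i / num_steps, t_shift_base, t_shift_cur) for i in range(num_steps + 1)]


print(shifted_uniform_grid(4, 4096, 4096))   # equal values: no shift
print(shifted_uniform_grid(4, 4096, 16384))  # alpha = 2: points pushed toward t = 1
```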

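Advanced Usage

Custom Samplers

You can create a custom sampler by subclassing the base samplers.samplers.Samplers class and implementing forward (one step from t_curr to t_next) and last_step (the final step to t_last):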

from samplers.samplers import Samplers, SamplingTimeDistType
from flax import nnx
import jax.numpy as jnp

class CustomSampler(Samplers):
    """Custom sampler implementation."""

    def __init__(
        self,
        num_sampling_steps: int,
        sampling_time_dist: SamplingTimeDistType = SamplingTimeDistType.UNIFORM,
        sampling_time_kwargs: dict | None = None,
        custom_param: float = 1.0
    ):
        # Use None instead of a mutable {} default, which would be shared across instances
        super().__init__(num_sampling_steps, sampling_time_dist, sampling_time_kwargs or {})
        self.custom_param = custom_param

    def forward(
        self, rng, net: nnx.Module, x: jnp.ndarray,
        t_curr: jnp.ndarray, t_next: jnp.ndarray,
        g_net: nnx.Module | None = None, guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        """Implement your custom forward step."""
        # Query net (and g_net, if given), then compute the state at t_next
        x_next = ...  # placeholder: replace with your update rule

        return x_next

    def last_step(
        self, rng, net: nnx.Module, x: jnp.ndarray,
        t_curr: jnp.ndarray, t_last: jnp.ndarray,
        g_net: nnx.Module | None = None, guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        """Implement your custom final step."""
        # A simple final step works here; it can be more sophisticated
        x_final = ...  # placeholder: replace with your final update

        return x_final
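How forward combines net and g_net is up to the implementation. One common choice, sketched here as an assumption, is a guided velocity v = v_g + w (v - v_g), as in classifier-free guidance or autoguidance, where w is guidance_scale and v_g comes from the guidance network; with w = 1 (the default in the signatures above) it reduces to the unguided prediction. Networks are plain callables and the state is a float, and every name is hypothetical.

```python
from typing import Callable, Optional

Net = Callable[[float, float], float]  # hypothetical: (x, t) -> velocity


def guided_velocity(
    net: Net, x: float, t: float, g_net: Optional[Net] = None, guidance_scale: float = 1.0
) -> float:
    v = net(x, t)
    if g_net is None:
        return v
    v_g = g_net(x, t)
    # Move from the guidance prediction toward the main one (and past it if scale > 1)
    return v_g + guidance_scale * (v - v_g)


def guided_euler_forward(net, x, t_curr, t_next, g_net=None, guidance_scale=1.0):
    """An Euler-style forward step that uses the guided velocity (illustrative)."""
    return x + (t_next - t_curr) * guided_velocity(net, x, t_curr, g_net, guidance_scale)


def main_net(x: float, t: float) -> float:
    return 2.0  # toy main network


def guide_net(x: float, t: float) -> float:
    return 1.0  # toy guidance network


print(guided_velocity(main_net, 0.0, 1.0, guide_net, 1.0))  # 2.0: scale 1 recovers the main prediction
print(guided_velocity(main_net, 0.0, 1.0, guide_net, 3.0))  # 4.0
```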

Integration with Configuration System

To integrate your custom sampler with the configuration system:

  1. Add to Registry: Update utils/initialize.py:

# In utils/initialize.py
SAMPLER_REGISTRY = {
    'euler': samplers.EulerSampler,
    'heun': samplers.HeunSampler,
    'euler-maruyama': samplers.EulerMaruyamaSampler,
    'custom': your_module.CustomSampler,  # Add your sampler
}

  2. Use in Configuration: Reference the registry key in your config files:

config.sampler = {
    'sampler_class': 'custom',
    'num_sampling_steps': 32,
    'sampling_time_dist': 'uniform',
    'custom_param': 0.5
}
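How utils/initialize.py turns such a config into an object is not shown here. A plausible pattern, sketched with hypothetical names (build_sampler, TimeDist, DummySampler, and the enum's string values are all assumptions), pops the registry key, converts the schedule string to its enum, and forwards the remaining entries as constructor keyword arguments. Under this pattern, every remaining key, such as custom_param, must match a parameter of the sampler's __init__.

```python
from dataclasses import dataclass
from enum import Enum


class TimeDist(Enum):  # hypothetical stand-in for SamplingTimeDistType
    UNIFORM = "uniform"
    EXP = "exp"


@dataclass
class DummySampler:  # hypothetical stand-in for a registered sampler class
    num_sampling_steps: int
    sampling_time_dist: TimeDist = TimeDist.UNIFORM
    custom_param: float = 1.0


REGISTRY = {"custom": DummySampler}


def build_sampler(cfg: dict):
    kwargs = dict(cfg)                              # copy so the caller's config is not mutated
    cls = REGISTRY[kwargs.pop("sampler_class")]     # look up the class by its registry key
    if "sampling_time_dist" in kwargs:
        kwargs["sampling_time_dist"] = TimeDist(kwargs["sampling_time_dist"])  # string -> enum
    return cls(**kwargs)                            # remaining keys become constructor arguments


config_sampler = {
    "sampler_class": "custom",
    "num_sampling_steps": 32,
    "sampling_time_dist": "uniform",
    "custom_param": 0.5,
}
sampler = build_sampler(config_sampler)
print(sampler)
```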