Samplers#
The samplers module provides various sampling strategies for diffusion models, supporting both deterministic and stochastic approaches.
Overview#
The samplers in this module are interface-agnostic: they work with any diffusion interface. They support:
- Deterministic sampling: Euler, Heun methods 
- Stochastic sampling: Euler-Maruyama for SDE integration 
- Flexible time scheduling: Uniform and exponential schedules 
- Multiple time variables: Support for two-time variable models like MeanFlow 
- Guidance support: Optional guidance networks for conditional generation 
Available Samplers#
Base Sampler#
The samplers.samplers.Samplers class is the abstract base class for all samplers. It provides:
- Time grid generation with different schedules 
- Main sampling loop built on a scan, for performance and JAX/NNX compatibility 
- Support for guidance networks and custom time grids 
- Interface-agnostic design 
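The split between the main loop and the per-step methods can be sketched in plain Python. This is not the library's implementation, which uses a scan. The name `run_sampling_loop` is hypothetical, and the sketch assumes that `forward` handles every interval except the final one, which goes to `last_step` (the two method names come from the Custom Samplers example further down).

```python
from typing import Callable, Sequence

Step = Callable[[float, float, float], float]  # (x, t_curr, t_next) -> x_next


def run_sampling_loop(x: float, timegrid: Sequence[float],
                      forward: Step, last_step: Step) -> float:
    """Hypothetical driver: `forward` on all but the last interval, then `last_step`."""
    if len(timegrid) < 2:
        raise ValueError("timegrid needs at least two points")
    # Pair consecutive grid points, leaving out the final interval.
    for t_curr, t_next in zip(timegrid[:-2], timegrid[1:-1]):
        x = forward(x, t_curr, t_next)
    return last_step(x, timegrid[-2], timegrid[-1])


def euler_step(x: float, t_curr: float, t_next: float) -> float:
    # Toy velocity field v(x, t) = x, standing in for a network prediction.
    return x + (t_next - t_curr) * x


# Integrate dx/dt = x from t=1 down to t=0 in four steps.
grid = [1.0, 0.75, 0.5, 0.25, 0.0]
x_final = run_sampling_loop(1.0, grid, euler_step, euler_step)
```

A Python loop keeps the control flow easy to read; a scan-based loop expresses the same iteration in a form that JAX can compile as a single operation.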
Deterministic Samplers#
Euler Sampler#
First-order deterministic sampler for ODE integration:
from samplers.samplers import EulerSampler
from samplers.samplers import SamplingTimeDistType
sampler = EulerSampler(
    num_sampling_steps=50,
    sampling_time_dist=SamplingTimeDistType.UNIFORM
)
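A first-order Euler update advances the state using only the velocity predicted at the current time: x_next = x + (t_next - t_curr) * v(x, t_curr). The sketch below shows that update on a toy ODE in plain Python. `euler_update`, `toy_velocity`, and `integrate_euler` are hypothetical names, and the toy field stands in for the network prediction; it is not the EulerSampler source.

```python
import math
from typing import Callable

Velocity = Callable[[float, float], float]  # (x, t) -> dx/dt


def euler_update(velocity: Velocity, x: float, t_curr: float, t_next: float) -> float:
    """One first-order step using only the velocity at the current time."""
    return x + (t_next - t_curr) * velocity(x, t_curr)


def toy_velocity(x: float, t: float) -> float:
    # Hypothetical stand-in for a network prediction: dx/dt = x.
    return x


def integrate_euler(num_steps: int) -> float:
    """Integrate from t=1 (x=1) down to t=0 on a uniform grid."""
    x = 1.0
    for i in range(num_steps):
        t_curr = 1.0 - i / num_steps
        t_next = 1.0 - (i + 1) / num_steps
        x = euler_update(toy_velocity, x, t_curr, t_next)
    return x


exact = math.exp(-1.0)  # closed-form solution x(0) for x(1) = 1
error_10 = abs(integrate_euler(10) - exact)
error_50 = abs(integrate_euler(50) - exact)
```

On this toy problem the error shrinks roughly in proportion to the step size, which is the expected behavior of a first-order method.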
Heun Sampler#
Second-order deterministic sampler (recommended for most cases):
from samplers.samplers import HeunSampler, SamplingTimeDistType
sampler = HeunSampler(
    num_sampling_steps=32,
    sampling_time_dist=SamplingTimeDistType.UNIFORM
)
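Heun's method takes an Euler predictor step, evaluates the velocity again at the predicted point, and averages the two slopes. The sketch below compares it with plain Euler on a toy ODE at the same step count. All function names are hypothetical and the toy velocity stands in for the network; this illustrates the numerical method rather than the HeunSampler code.

```python
import math
from typing import Callable

Velocity = Callable[[float, float], float]  # (x, t) -> dx/dt


def euler_update(velocity: Velocity, x: float, t_curr: float, t_next: float) -> float:
    return x + (t_next - t_curr) * velocity(x, t_curr)


def heun_update(velocity: Velocity, x: float, t_curr: float, t_next: float) -> float:
    """Second-order step: Euler predictor followed by a trapezoidal corrector."""
    dt = t_next - t_curr
    v_curr = velocity(x, t_curr)
    x_pred = x + dt * v_curr            # predictor (plain Euler)
    v_next = velocity(x_pred, t_next)   # second velocity evaluation
    return x + dt * 0.5 * (v_curr + v_next)


def integrate(update, num_steps: int) -> float:
    """Integrate the toy ODE dx/dt = x from t=1 (x=1) down to t=0."""
    x = 1.0
    for i in range(num_steps):
        t_curr = 1.0 - i / num_steps
        t_next = 1.0 - (i + 1) / num_steps
        x = update(lambda x_, t_: x_, x, t_curr, t_next)
    return x


exact = math.exp(-1.0)
euler_error = abs(integrate(euler_update, 32) - exact)
heun_error = abs(integrate(heun_update, 32) - exact)
```

Each Heun step costs two velocity evaluations instead of one, so the fair comparison is against Euler with twice as many steps. On smooth problems the second-order accuracy usually still wins.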
Stochastic Samplers#
Euler-Maruyama Sampler#
Stochastic sampler for SDE integration with configurable diffusion coefficients:
from samplers.samplers import (
    DiffusionCoeffType,
    EulerMaruyamaSampler,
    SamplingTimeDistType,
)
sampler = EulerMaruyamaSampler(
    num_sampling_steps=250,
    sampling_time_dist=SamplingTimeDistType.UNIFORM,
    diffusion_coeff=DiffusionCoeffType.LINEAR_KL,
    diffusion_coeff_norm=1.0
)
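The Euler-Maruyama scheme extends the Euler step with a Gaussian noise term scaled by the diffusion coefficient and the square root of the step size. The sketch below uses a generic SDE dx = f(x, t) dt + g(t) dW with hypothetical `drift` and `diffusion` callables. How the library derives these from DiffusionCoeffType and diffusion_coeff_norm is not shown in this document, so the sketch does not attempt to reproduce it.

```python
import math
import random
from typing import Callable


def euler_maruyama_update(
    drift: Callable[[float, float], float],
    diffusion: Callable[[float], float],
    x: float, t_curr: float, t_next: float,
    rng: random.Random,
) -> float:
    """One Euler-Maruyama step; works for decreasing time via sqrt(|dt|)."""
    dt = t_next - t_curr
    noise = rng.gauss(0.0, 1.0)
    return x + drift(x, t_curr) * dt + diffusion(t_curr) * math.sqrt(abs(dt)) * noise


def simulate(num_steps: int, diffusion_scale: float, seed: int) -> float:
    """Integrate a toy SDE (drift = x, constant diffusion) from t=1 down to t=0."""
    rng = random.Random(seed)
    x = 1.0
    for i in range(num_steps):
        t_curr = 1.0 - i / num_steps
        t_next = 1.0 - (i + 1) / num_steps
        x = euler_maruyama_update(
            lambda x_, t_: x_, lambda t_: diffusion_scale, x, t_curr, t_next, rng
        )
    return x


deterministic = simulate(250, 0.0, seed=0)  # zero diffusion reduces to Euler
sample_a = simulate(250, 0.5, seed=0)
sample_b = simulate(250, 0.5, seed=1)       # a different seed gives a different path
```

With the diffusion set to zero the update is exactly the Euler step, which is a convenient sanity check when debugging a stochastic sampler.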
Specialized Samplers#
EulerJump Sampler#
For two-time variable models like MeanFlow:
from samplers.samplers import EulerJumpSampler, SamplingTimeDistType
sampler = EulerJumpSampler(
    num_sampling_steps=50,
    sampling_time_dist=SamplingTimeDistType.UNIFORM
)
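In MeanFlow, the network is conditioned on two times and predicts the average velocity u(x, r, t) over the interval [r, t], so a step can jump directly: x_r = x_t - (t - r) * u(x_t, r, t). The sketch below assumes EulerJumpSampler follows this update rule, which the text above does not confirm. `average_velocity` and `jump_update` are hypothetical names, and the closed-form average velocity of a toy ODE stands in for a trained two-time network.

```python
import math


def average_velocity(x: float, r: float, t: float) -> float:
    """Exact average velocity of dx/dt = x over [r, t], starting from x at time t."""
    return x * (1.0 - math.exp(r - t)) / (t - r)


def jump_update(avg_velocity, x: float, t_curr: float, t_next: float) -> float:
    # x_r = x_t - (t - r) * u(x_t, r, t), with t = t_curr and r = t_next.
    return x - (t_curr - t_next) * avg_velocity(x, t_next, t_curr)


def sample(num_steps: int) -> float:
    """Jump from t=1 (x=1) down to t=0 in `num_steps` equal intervals."""
    x = 1.0
    for i in range(num_steps):
        t_curr = 1.0 - i / num_steps
        t_next = 1.0 - (i + 1) / num_steps
        x = jump_update(average_velocity, x, t_curr, t_next)
    return x


one_step = sample(1)
four_steps = sample(4)
```

Because the average velocity here is exact, one jump and four jumps land on the same point. With a learned network the prediction is only approximate, so more steps can still help.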
Time Scheduling#
Uniform Schedule#
Equal time steps between t_start and t_end:
from samplers.samplers import HeunSampler, SamplingTimeDistType
sampler = HeunSampler(
    num_sampling_steps=32,
    sampling_time_dist=SamplingTimeDistType.UNIFORM,
    sampling_time_kwargs={
        't_start': 1.0,
        't_end': 0.0
    }
)
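A uniform schedule amounts to evenly spaced points between t_start and t_end. The sketch below is a plain-Python illustration in which `uniform_timegrid` is a hypothetical helper. It assumes the grid holds num_steps + 1 points so that there are num_steps intervals; the library's exact convention may differ.

```python
from typing import List


def uniform_timegrid(num_steps: int, t_start: float = 1.0, t_end: float = 0.0) -> List[float]:
    """Return num_steps + 1 evenly spaced times from t_start to t_end."""
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    step = (t_end - t_start) / num_steps
    return [t_start + i * step for i in range(num_steps + 1)]


grid = uniform_timegrid(4)
# grid == [1.0, 0.75, 0.5, 0.25, 0.0]
```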
Exponential Schedule#
Exponentially spaced time steps (EDM-style):
from samplers.samplers import HeunSampler, SamplingTimeDistType
sampler = HeunSampler(
    num_sampling_steps=32,
    sampling_time_dist=SamplingTimeDistType.EXP,
    sampling_time_kwargs={
        'sigma_min': 0.002,
        'sigma_max': 80.0,
        'rho': 7.0
    }
)
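The parameter names sigma_min, sigma_max, and rho match the noise schedule from the EDM paper (Karras et al., 2022), which interpolates linearly in sigma^(1/rho) and then raises the result to the power rho. The sketch below implements that published formula in plain Python. Whether SamplingTimeDistType.EXP computes exactly this is an assumption, and `edm_sigma_schedule` is a hypothetical name.

```python
from typing import List


def edm_sigma_schedule(
    num_points: int,
    sigma_min: float = 0.002,
    sigma_max: float = 80.0,
    rho: float = 7.0,
) -> List[float]:
    """Noise levels from sigma_max down to sigma_min, spaced as in the EDM paper."""
    if num_points < 2:
        raise ValueError("num_points must be at least 2")
    inv_rho = 1.0 / rho
    high = sigma_max ** inv_rho
    low = sigma_min ** inv_rho
    return [
        (high + i / (num_points - 1) * (low - high)) ** rho
        for i in range(num_points)
    ]


sigmas = edm_sigma_schedule(32)
first_gap = sigmas[0] - sigmas[1]    # step size near sigma_max
last_gap = sigmas[-2] - sigmas[-1]   # step size near sigma_min
```

With rho = 7 the steps near sigma_min are far smaller than those near sigma_max, so most of the sampling budget is spent at low noise levels.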
Custom Time Grid#
Provide your own time steps:
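import jax.numpy as jnp
# 50 time points from t=1.0 down to t=0.0
custom_times = jnp.linspace(1.0, 0.0, 50)
# rng, interface, and x are assumed to come from your existing sampling setup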
samples = sampler.sample(
    rng, interface, x,
    custom_timegrid=custom_times
)
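A custom grid is most useful when it is non-uniform. As a hedged illustration, the sketch below builds a grid whose spacing shrinks toward t_end. `power_timegrid` is a hypothetical helper, not part of the samplers module, and it returns a plain Python list that could be wrapped with jnp.asarray before being passed as custom_timegrid.

```python
from typing import List


def power_timegrid(
    num_points: int, power: float = 2.0, t_start: float = 1.0, t_end: float = 0.0
) -> List[float]:
    """Grid from t_start to t_end whose steps shrink toward t_end when power > 1."""
    if num_points < 2:
        raise ValueError("num_points must be at least 2")
    grid = []
    for i in range(num_points):
        frac = 1.0 - i / (num_points - 1)  # runs from 1 down to 0
        grid.append(t_end + (t_start - t_end) * frac ** power)
    return grid


custom_times = power_timegrid(5)
# custom_times == [1.0, 0.5625, 0.25, 0.0625, 0.0]
```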
Sampler Configuration#
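All samplers accept the same core constructor arguments: num_sampling_steps, sampling_time_dist, and sampling_time_kwargs. The example below also passes t_shift_base and t_shift_cur through sampling_time_kwargs:
from samplers.samplers import HeunSampler, SamplingTimeDistType
sampler = HeunSampler(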
    num_sampling_steps=32,
    sampling_time_dist=SamplingTimeDistType.UNIFORM,
    sampling_time_kwargs={
        't_start': 1.0,
        't_end': 0.0,
        't_shift_base': 4096,
        't_shift_cur': 4096
    }
)
Advanced Usage#
Custom Samplers#
Create a custom sampler by subclassing samplers.samplers.Samplers and implementing forward (one integration step) and last_step (the final step):
from samplers.samplers import Samplers, SamplingTimeDistType
from flax import nnx
import jax.numpy as jnp
class CustomSampler(Samplers):
    """Custom sampler implementation."""
    def __init__(
        self,
        num_sampling_steps: int,
        sampling_time_dist: SamplingTimeDistType = SamplingTimeDistType.UNIFORM,
        sampling_time_kwargs: dict | None = None,
        custom_param: float = 1.0
    ):
        # Avoid a shared mutable default: fall back to a fresh dict per instance.
        super().__init__(num_sampling_steps, sampling_time_dist, sampling_time_kwargs or {})
        self.custom_param = custom_param
    def forward(
        self, rng, net: nnx.Module, x: jnp.ndarray,
        t_curr: jnp.ndarray, t_next: jnp.ndarray,
        g_net: nnx.Module | None = None, guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        """Implement your custom forward step."""
        # Get prediction from network
        ...
        return x_next
    def last_step(
        self, rng, net: nnx.Module, x: jnp.ndarray,
        t_curr: jnp.ndarray, t_last: jnp.ndarray,
        g_net: nnx.Module | None = None, guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        """Implement your custom final step."""
        # Simple final step - can be more sophisticated
        ...
        return x_final
Integration with Configuration System#
To integrate your custom sampler with the configuration system:
- Add to the registry: register your class in utils/initialize.py: 
# In utils/initialize.py
SAMPLER_REGISTRY = {
    'euler': samplers.EulerSampler,
    'heun': samplers.HeunSampler,
    'euler-maruyama': samplers.EulerMaruyamaSampler,
    'custom': your_module.CustomSampler,  # Add your sampler
}
- Use in configuration: reference the registered key in your config files: 
config.sampler = {
    'sampler_class': 'custom',
    'num_sampling_steps': 32,
    'sampling_time_dist': 'uniform',
    'custom_param': 0.5
}
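How a registry and a config dict fit together can be sketched as follows. The lookup code in utils/initialize.py is not shown in this document, so `build_sampler` and the two dummy classes are hypothetical stand-ins that only illustrate the pattern: pop the class key, look it up, and pass the remaining entries as keyword arguments.

```python
from dataclasses import dataclass


@dataclass
class DummyHeunSampler:  # stand-in for samplers.HeunSampler
    num_sampling_steps: int
    sampling_time_dist: str = "uniform"


@dataclass
class DummyCustomSampler:  # stand-in for your_module.CustomSampler
    num_sampling_steps: int
    sampling_time_dist: str = "uniform"
    custom_param: float = 1.0


SAMPLER_REGISTRY = {"heun": DummyHeunSampler, "custom": DummyCustomSampler}


def build_sampler(config: dict):
    """Hypothetical factory: resolve 'sampler_class' and forward the other keys."""
    kwargs = dict(config)  # copy so the caller's config is not mutated
    name = kwargs.pop("sampler_class")
    if name not in SAMPLER_REGISTRY:
        raise ValueError(
            f"unknown sampler {name!r}; choose from {sorted(SAMPLER_REGISTRY)}"
        )
    return SAMPLER_REGISTRY[name](**kwargs)


config = {
    "sampler_class": "custom",
    "num_sampling_steps": 32,
    "sampling_time_dist": "uniform",
    "custom_param": 0.5,
}
sampler = build_sampler(config)
```

Under this pattern, any extra constructor argument such as custom_param needs no special handling: it reaches the class as long as the config key matches the parameter name.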