Samplers
The samplers module provides sampling strategies for diffusion models, covering both deterministic (ODE) and stochastic (SDE) approaches.
Overview
This module contains interface-agnostic samplers that work with any diffusion interface. The samplers support:
- Deterministic sampling: Euler and Heun methods
- Stochastic sampling: Euler-Maruyama for SDE integration
- Flexible time scheduling: uniform and exponential schedules, plus custom time grids
- Multiple time variables: support for two-time-variable models such as MeanFlow
- Guidance support: optional guidance networks for conditional generation
Available Samplers
Base Sampler
The samplers.samplers.Samplers class is the abstract base class for all samplers. It provides:
- Time grid generation for the supported schedules
- The main sampling loop, built on a JAX/NNX-friendly scan for performance
- Support for guidance networks and custom time grids
- An interface-agnostic design
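The sketch below shows how a time grid and a scan-based loop fit together. It is illustrative only and is not the library's code. The names `toy_velocity`, `forward`, `last_step`, and `sample` are hypothetical stand-ins, and the "network" is the toy ODE dx/dt = -x. The split into a scanned body plus a separate final step is an assumption based on the `forward`/`last_step` methods described under Custom Samplers.

```python
import jax
import jax.numpy as jnp


def toy_velocity(x, t):
    # Hypothetical stand-in for a network: velocity field of dx/dt = -x.
    return -x


def forward(x, t_curr, t_next):
    # One Euler update from t_curr to t_next.
    return x + (t_next - t_curr) * toy_velocity(x, t_curr)


def last_step(x, t_curr, t_last):
    # Final interval handled separately; here it simply reuses forward.
    return forward(x, t_curr, t_last)


@jax.jit
def sample(x0, timegrid):
    def body(x, ts):
        t_curr, t_next = ts
        return forward(x, t_curr, t_next), None

    # Scan over every interval except the last, then take the final step.
    x, _ = jax.lax.scan(body, x0, (timegrid[:-2], timegrid[1:-1]))
    return last_step(x, timegrid[-2], timegrid[-1])


timegrid = jnp.linspace(1.0, 0.0, 33)  # 32 uniform intervals from t=1 to t=0
x_final = sample(jnp.ones((4,)), timegrid)
```

`jax.lax.scan` keeps the loop a single traced operation. Compilation time therefore does not grow with the number of sampling steps, which it would with a Python `for` loop that JIT unrolls.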
Deterministic Samplers
Euler Sampler
First-order deterministic sampler for ODE integration:
```python
from samplers.samplers import EulerSampler
from samplers.samplers import SamplingTimeDistType

sampler = EulerSampler(
    num_sampling_steps=50,
    sampling_time_dist=SamplingTimeDistType.UNIFORM,
)
```
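The first-order update fits in a few lines. This sketch assumes that the network predicts a velocity v(x, t) and that time runs from 1 toward 0. `euler_step` and `toy_velocity` are hypothetical names, not the library's API.

```python
import jax.numpy as jnp


def euler_step(velocity_fn, x, t_curr, t_next):
    """First-order update: follow the velocity at t_curr across the whole step."""
    dt = t_next - t_curr  # negative when integrating from t=1 toward t=0
    return x + dt * velocity_fn(x, t_curr)


def toy_velocity(x, t):
    # Hypothetical stand-in for a trained model's velocity prediction.
    return -x


x = jnp.array([1.0, 2.0])
x_next = euler_step(toy_velocity, x, t_curr=1.0, t_next=0.5)
```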
Heun Sampler
Second-order deterministic sampler (recommended for most cases):
```python
from samplers.samplers import HeunSampler, SamplingTimeDistType

sampler = HeunSampler(
    num_sampling_steps=32,
    sampling_time_dist=SamplingTimeDistType.UNIFORM,
)
```
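Heun's method spends a second network evaluation per step to correct the Euler prediction. The sketch below compares one step of each method against the exact solution. It uses a hypothetical `heun_step`, the toy ODE dx/dt = -x, and an assumed velocity parameterization.

```python
import jax.numpy as jnp


def heun_step(velocity_fn, x, t_curr, t_next):
    """Second-order update: average the slopes at both ends of the step."""
    dt = t_next - t_curr
    v_curr = velocity_fn(x, t_curr)
    x_euler = x + dt * v_curr                # Euler predictor
    v_next = velocity_fn(x_euler, t_next)    # slope at the predicted point
    return x + dt * 0.5 * (v_curr + v_next)  # trapezoidal corrector


def toy_velocity(x, t):
    return -x


x = jnp.array([1.0])
x_heun = heun_step(toy_velocity, x, 1.0, 0.5)
x_euler = x + (0.5 - 1.0) * toy_velocity(x, 1.0)
exact = x * jnp.exp(0.5)  # solution of dx/dt = -x integrated from t=1 back to t=0.5
```

Each Heun step costs roughly two network evaluations. Compare samplers at equal numbers of evaluations rather than equal step counts.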
Stochastic Samplers
Euler-Maruyama Sampler
Stochastic sampler for SDE integration with a configurable diffusion coefficient:
```python
from samplers.samplers import EulerMaruyamaSampler
from samplers.samplers import DiffusionCoeffType, SamplingTimeDistType

sampler = EulerMaruyamaSampler(
    num_sampling_steps=250,
    sampling_time_dist=SamplingTimeDistType.UNIFORM,
    diffusion_coeff=DiffusionCoeffType.LINEAR_KL,
    diffusion_coeff_norm=1.0,
)
```
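An Euler-Maruyama step adds a Gaussian increment, scaled by sqrt(|dt|), to the deterministic drift update. The sketch below integrates a generic dx = f(x, t) dt + g(t) dW with a constant, hypothetical g. It does not reproduce what DiffusionCoeffType.LINEAR_KL or diffusion_coeff_norm compute. It also assumes the drift passed in already contains any score-based correction the reverse-time SDE needs.

```python
import jax
import jax.numpy as jnp


def euler_maruyama_step(key, drift_fn, diffusion_fn, x, t_curr, t_next):
    """One Euler-Maruyama step for dx = f(x, t) dt + g(t) dW."""
    dt = t_next - t_curr
    noise = jax.random.normal(key, x.shape, dtype=x.dtype)
    # Brownian increments scale with sqrt(|dt|), whichever way time runs.
    return x + dt * drift_fn(x, t_curr) + diffusion_fn(t_curr) * jnp.sqrt(jnp.abs(dt)) * noise


def toy_drift(x, t):
    return -x


def constant_diffusion(t):
    # Hypothetical constant g(t); not one of the library's DiffusionCoeffType options.
    return 0.5


key = jax.random.PRNGKey(0)
x = jnp.ones((100_000,))
x_next = euler_maruyama_step(key, toy_drift, constant_diffusion, x, 1.0, 0.96)
```

In a scanned loop, split a fresh key for every step (for example with `jax.random.split`) so that the noise differs between steps.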
Specialized Samplers
EulerJump Sampler
For two-time-variable models such as MeanFlow:
```python
from samplers.samplers import EulerJumpSampler, SamplingTimeDistType

sampler = EulerJumpSampler(
    num_sampling_steps=50,
    sampling_time_dist=SamplingTimeDistType.UNIFORM,
)
```
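MeanFlow-style models predict an average velocity u(z, r, t) over an interval rather than an instantaneous velocity, so a single update can jump across a large interval. The sketch follows the MeanFlow relation z_r = z_t - (t - r) u(z_t, r, t). The argument order, `euler_jump_step`, and the toy model (the exact average velocity of dz/dt = -z) are assumptions for illustration, not the EulerJumpSampler implementation.

```python
import jax.numpy as jnp


def euler_jump_step(avg_velocity_fn, z, t_curr, t_next):
    """Jump from t_curr to t_next using an average-velocity prediction u(z, r, t)."""
    # MeanFlow convention: z_r = z_t - (t - r) * u(z_t, r, t), with r = t_next, t = t_curr.
    return z - (t_curr - t_next) * avg_velocity_fn(z, t_next, t_curr)


def toy_avg_velocity(z, r, t):
    # Exact average velocity of dz/dt = -z over [r, t], standing in for a trained model.
    return z * (1.0 - jnp.exp(t - r)) / (t - r)


z1 = jnp.array([1.0, -2.0])
one_jump = euler_jump_step(toy_avg_velocity, z1, 1.0, 0.0)
halfway = euler_jump_step(toy_avg_velocity, z1, 1.0, 0.5)
two_jumps = euler_jump_step(toy_avg_velocity, halfway, 0.5, 0.0)
```

With an exact average velocity, one jump from t = 1 to 0 matches two half-interval jumps. A trained model only approximates this, so additional steps can still help.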
Time Scheduling
Uniform Schedule
Equal time steps between t_start and t_end:
```python
from samplers.samplers import HeunSampler, SamplingTimeDistType

sampler = HeunSampler(
    num_sampling_steps=32,
    sampling_time_dist=SamplingTimeDistType.UNIFORM,
    sampling_time_kwargs={
        't_start': 1.0,
        't_end': 0.0,
    },
)
```
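A uniform grid is a linspace from t_start to t_end. This sketch assumes num_sampling_steps intervals, which means num_sampling_steps + 1 grid points. The library's exact endpoint handling is not documented here, and `uniform_timegrid` is a hypothetical helper.

```python
import jax.numpy as jnp


def uniform_timegrid(num_steps, t_start=1.0, t_end=0.0):
    """num_steps equal intervals need num_steps + 1 grid points."""
    return jnp.linspace(t_start, t_end, num_steps + 1)


grid = uniform_timegrid(32)
```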
Exponential Schedule
Non-uniform, EDM-style time steps controlled by sigma_min, sigma_max, and rho:
```python
from samplers.samplers import HeunSampler, SamplingTimeDistType

sampler = HeunSampler(
    num_sampling_steps=32,
    sampling_time_dist=SamplingTimeDistType.EXP,
    sampling_time_kwargs={
        'sigma_min': 0.002,
        'sigma_max': 80.0,
        'rho': 7.0,
    },
)
```
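The parameter names sigma_min, sigma_max, and rho match the EDM noise schedule of Karras et al. (2022), which interpolates linearly in sigma^(1/rho) space. The sketch below implements that formula, assuming this is what SamplingTimeDistType.EXP computes. `edm_sigmas` is a hypothetical helper, and how the library converts these noise levels to its time variable is not shown.

```python
import jax.numpy as jnp


def edm_sigmas(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Karras et al. (2022) noise levels, interpolated linearly in sigma**(1/rho) space."""
    ramp = jnp.linspace(0.0, 1.0, num_steps)
    max_inv_rho = sigma_max ** (1.0 / rho)
    min_inv_rho = sigma_min ** (1.0 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    # EDM appends a final sigma = 0 so the last step lands on clean data.
    return jnp.append(sigmas, 0.0)


sigmas = edm_sigmas(32)
```

A larger rho concentrates more steps at small noise levels. The EDM paper uses rho = 7.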
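Custom Time Grid
Provide your own time steps:
```python
import jax
import jax.numpy as jnp

# Assumes `sampler` is configured as above, `interface` is your diffusion
# interface, and `x` holds the initial noise.
rng = jax.random.PRNGKey(0)
custom_times = jnp.linspace(1, 0, 50)
samples = sampler.sample(
    rng, interface, x,
    custom_timegrid=custom_times,
)
```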
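Any decreasing array can serve as a custom grid. For example, squaring a uniform ramp keeps the endpoints 1 and 0 but makes the steps progressively smaller toward t = 0. This spacing is only an example. It is assumed, not confirmed, that the sampler accepts any 1-D decreasing grid of this kind.

```python
import jax.numpy as jnp

num_points = 50
# Warp a uniform ramp so steps shrink as t approaches 0 (quadratic spacing).
custom_times = jnp.linspace(1.0, 0.0, num_points) ** 2

step_sizes = -jnp.diff(custom_times)  # positive step sizes, largest first
```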
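Sampler Configuration
All samplers accept schedule options through sampling_time_kwargs. The example below combines the uniform schedule's endpoints with the time-shift parameters t_shift_base and t_shift_cur:
```python
from samplers.samplers import HeunSampler, SamplingTimeDistType

sampler = HeunSampler(
    num_sampling_steps=32,
    sampling_time_dist=SamplingTimeDistType.UNIFORM,
    sampling_time_kwargs={
        't_start': 1.0,
        't_end': 0.0,
        't_shift_base': 4096,
        't_shift_cur': 4096,
    },
)
```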
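This section does not document what t_shift_base and t_shift_cur mean. One plausible reading, stated here as an assumption, is the resolution-dependent timestep shift used for rectified-flow models in Stable Diffusion 3. Under that reading, with s = sqrt(t_shift_cur / t_shift_base), each time t is mapped to s * t / (1 + (s - 1) * t). The equal values in the example then give s = 1 and leave the grid unchanged. `shift_timegrid` is a hypothetical helper.

```python
import jax.numpy as jnp


def shift_timegrid(t, t_shift_base, t_shift_cur):
    """Hypothetical SD3-style resolution shift; not confirmed to match the library."""
    s = jnp.sqrt(t_shift_cur / t_shift_base)
    return s * t / (1.0 + (s - 1.0) * t)


t = jnp.linspace(1.0, 0.0, 33)
unchanged = shift_timegrid(t, 4096, 4096)  # s = 1: identity
shifted = shift_timegrid(t, 4096, 16384)   # s = 2: grid points pushed toward t = 1 (high noise)
```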
Advanced Usage
Custom Samplers
You can create custom samplers by extending the base samplers.samplers.Samplers class.
```python
from flax import nnx
import jax.numpy as jnp

from samplers.samplers import Samplers, SamplingTimeDistType


class CustomSampler(Samplers):
    """Custom sampler implementation."""

    def __init__(
        self,
        num_sampling_steps: int,
        sampling_time_dist: SamplingTimeDistType = SamplingTimeDistType.UNIFORM,
        sampling_time_kwargs: dict | None = None,
        custom_param: float = 1.0,
    ):
        # Use None instead of a mutable {} default so instances never share a dict.
        super().__init__(num_sampling_steps, sampling_time_dist, sampling_time_kwargs or {})
        self.custom_param = custom_param

    def forward(
        self, rng, net: nnx.Module, x: jnp.ndarray,
        t_curr: jnp.ndarray, t_next: jnp.ndarray,
        g_net: nnx.Module | None = None, guidance_scale: float = 1.0,
        **net_kwargs,
    ) -> jnp.ndarray:
        """Advance x from t_curr to t_next."""
        # 1. Get the prediction from the network (the call depends on your interface).
        # 2. If g_net is given, combine its prediction using guidance_scale.
        # 3. Apply your update rule (using self.custom_param) and return x_next.
        raise NotImplementedError("Implement the custom forward step.")

    def last_step(
        self, rng, net: nnx.Module, x: jnp.ndarray,
        t_curr: jnp.ndarray, t_last: jnp.ndarray,
        g_net: nnx.Module | None = None, guidance_scale: float = 1.0,
        **net_kwargs,
    ) -> jnp.ndarray:
        """Take the final step from t_curr to t_last."""
        # Simplest choice: reuse forward; override with a more careful final update if needed.
        return self.forward(
            rng, net, x, t_curr, t_last,
            g_net=g_net, guidance_scale=guidance_scale, **net_kwargs,
        )
```
Integration with Configuration System
To integrate your custom sampler with the configuration system:
1. Add it to the registry by updating utils/initialize.py:
```python
# In utils/initialize.py
SAMPLER_REGISTRY = {
    'euler': samplers.EulerSampler,
    'heun': samplers.HeunSampler,
    'euler-maruyama': samplers.EulerMaruyamaSampler,
    # your_module is a placeholder for the module that defines CustomSampler.
    'custom': your_module.CustomSampler,
}
```
2. Reference it in your config files by its registry key:
```python
config.sampler = {
    'sampler_class': 'custom',  # key in SAMPLER_REGISTRY
    'num_sampling_steps': 32,
    'sampling_time_dist': 'uniform',
    'custom_param': 0.5,  # extra constructor argument for CustomSampler
}
```
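The registry pattern amounts to a name-to-class lookup followed by a constructor call with the remaining config entries. The sketch below uses stand-in classes and a hypothetical `build_sampler` helper rather than the real utils/initialize.py code. In particular, converting the string 'uniform' to SamplingTimeDistType.UNIFORM is assumed to happen in the real initializer and is not shown.

```python
class StandInSampler:
    # Stand-in for a real sampler class, for illustration only.
    def __init__(self, num_sampling_steps, **kwargs):
        self.num_sampling_steps = num_sampling_steps
        self.kwargs = kwargs


class StandInCustomSampler(StandInSampler):
    def __init__(self, num_sampling_steps, custom_param=1.0, **kwargs):
        super().__init__(num_sampling_steps, **kwargs)
        self.custom_param = custom_param


SAMPLER_REGISTRY = {'euler': StandInSampler, 'custom': StandInCustomSampler}


def build_sampler(sampler_cfg):
    """Hypothetical helper: look up the class by name and pass the rest as kwargs."""
    cfg = dict(sampler_cfg)  # copy so the caller's config is not mutated
    name = cfg.pop('sampler_class')
    if name not in SAMPLER_REGISTRY:
        raise KeyError(f"Unknown sampler {name!r}; registered: {sorted(SAMPLER_REGISTRY)}")
    return SAMPLER_REGISTRY[name](**cfg)


sampler = build_sampler({
    'sampler_class': 'custom',
    'num_sampling_steps': 32,
    'sampling_time_dist': 'uniform',
    'custom_param': 0.5,
})
```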