Samplers API
The samplers module provides various sampling strategies for diffusion models.
This module contains the samplers, which are designed to be agnostic to the model and its interface.
- class samplers.samplers.Samplers(num_sampling_steps: int, sampling_time_dist: SamplingTimeDistType, sampling_time_kwargs: dict = {})
Bases: ABC
Base class for all samplers.
- All samplers should support:
  - sampling a discretized time grid t, and
  - a single forward step of the integration.
- __init__(num_sampling_steps: int, sampling_time_dist: SamplingTimeDistType, sampling_time_kwargs: dict = {})
- abstract forward(net: Module, x: Array, t_curr: Array, t_next: Array, g_net: Module | None = None, guidance_scale: float = 1.0, **net_kwargs)
A single forward step in integration.
- Parameters:
net – network whose vector field is integrated.
x – current state.
t_curr – current time step.
t_next – next time step.
g_net – guidance network.
guidance_scale – scale of the guidance.
net_kwargs – extra keyword arguments passed to the network.
- Returns:
x_next, the next state.
- Return type:
jnp.ndarray
- abstract last_step(net: Module, x: Array, t_curr: Array, t_last: Array, g_net: Module | None = None, guidance_scale: float = 1.0, **net_kwargs)
Last step of the integration.
- This method is exposed separately because many samplers treat the last step specially:
Heun: the last step is a single first-order Euler step.
Stochastic samplers: the last step returns the expected marginal value.
- Parameters:
net – network whose vector field is integrated.
x – current state.
t_curr – current time step.
t_last – last time step. Note: the model is never evaluated at this time.
g_net – guidance network.
guidance_scale – scale of the guidance.
net_kwargs – extra keyword arguments passed to the network.
- Returns:
x_last, the final state.
- Return type:
jnp.ndarray
- sample_t(steps: int) → Array
Sample the discretized time grid.
- Parameters:
steps – number of sampling steps.
- Returns:
t, the time grid.
- Return type:
jnp.ndarray
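As a rough illustration of what `sample_t` produces, the sketch below builds a uniform grid with `steps` intervals (so `steps + 1` points). It uses NumPy in place of `jax.numpy`. The helper name `uniform_time_grid`, the point count, and the choice of integrating from t = 1 down to t = 0 are all assumptions; the real spacing and direction depend on `sampling_time_dist`.

```python
import numpy as np


def uniform_time_grid(steps: int, t_start: float = 1.0, t_end: float = 0.0) -> np.ndarray:
    # Hypothetical helper: `steps` equal intervals, hence steps + 1 grid points.
    # Assumption: sampling runs from t_start (noise) to t_end (data).
    return np.linspace(t_start, t_end, steps + 1)


grid = uniform_time_grid(4)
# grid holds 1.0, 0.75, 0.5, 0.25, 0.0
```

A non-uniform `sampling_time_dist` would change only how the interior points are placed; the endpoints and the number of intervals stay the same.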
- sample(rng, net: Module, x: Array, g_net: Module | None = None, guidance_scale: float = 1.0, num_sampling_steps: int | None = None, custom_timegrid: Array | None = None, **net_kwargs) → Array
Main sampling loop.
- Parameters:
rng – random key for potentially stochastic samplers.
net – network whose vector field is integrated.
x – initial state from which the integration starts.
g_net – guidance network.
guidance_scale – scale of the guidance.
num_sampling_steps – optional number of sampling steps.
custom_timegrid – optional custom time grid.
net_kwargs – extra keyword arguments passed to the network.
- Returns:
x_final, the final clean state.
- Return type:
jnp.ndarray
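The interplay between `sample_t`, `forward`, and `last_step` can be sketched as the loop below: regular steps over every interval except the last, then `last_step` for the final interval, which matches the note that the model is never evaluated at `t_last`. This is a simplified NumPy sketch, not the library's code. `ToySampler` and `ToyEuler` are hypothetical, `net` is assumed to be a plain callable `net(x, t)`, and `rng`, `g_net`, `guidance_scale`, and `net_kwargs` are omitted.

```python
from abc import ABC, abstractmethod
from typing import Callable, Optional

import numpy as np

# Hypothetical network signature: a plain callable net(x, t) -> velocity.
Net = Callable[[np.ndarray, float], np.ndarray]


class ToySampler(ABC):
    """Hypothetical, stripped-down stand-in for `Samplers`."""

    def __init__(self, num_sampling_steps: int):
        self.num_sampling_steps = num_sampling_steps

    def sample_t(self, steps: int) -> np.ndarray:
        # Assumption: uniform grid from t = 1 (noise) to t = 0 (data).
        return np.linspace(1.0, 0.0, steps + 1)

    @abstractmethod
    def forward(self, net: Net, x, t_curr, t_next): ...

    @abstractmethod
    def last_step(self, net: Net, x, t_curr, t_last): ...

    def sample(self, net: Net, x, num_sampling_steps: Optional[int] = None,
               custom_timegrid: Optional[np.ndarray] = None):
        steps = self.num_sampling_steps if num_sampling_steps is None else num_sampling_steps
        t = self.sample_t(steps) if custom_timegrid is None else custom_timegrid
        # Regular steps over every interval except the final one...
        for t_curr, t_next in zip(t[:-2], t[1:-1]):
            x = self.forward(net, x, t_curr, t_next)
        # ...then the special-cased final interval.
        return self.last_step(net, x, t[-2], t[-1])


class ToyEuler(ToySampler):
    def forward(self, net, x, t_curr, t_next):
        return x + (t_next - t_curr) * net(x, t_curr)

    def last_step(self, net, x, t_curr, t_last):
        # Euler only evaluates net at t_curr, so the last step is a plain step.
        return self.forward(net, x, t_curr, t_last)


x_final = ToyEuler(num_sampling_steps=8).sample(lambda x, t: np.ones_like(x), np.zeros(3))
# With a constant unit field, each entry moves by the total elapsed time, 0 - 1 = -1.
```

Keeping the loop in the base class and only the per-step updates in subclasses is what lets each sampler override `last_step` without duplicating the iteration logic.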
- get_default_sampling_kwargs(kwargs: dict, sampling_time_dist: SamplingTimeDistType) → dict
Get the default kwargs for the sampling time distribution.
- class samplers.samplers.EulerSampler(num_sampling_steps: int, sampling_time_dist: SamplingTimeDistType, sampling_time_kwargs: dict = {})
Bases: Samplers
Euler sampler.
First-order deterministic sampler.
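The Euler update is x_{t_{i+1}} = x_{t_i} + (t_{i+1} - t_i) f(x_{t_i}, t_i). The NumPy sketch below implements that step and applies it to the toy ODE dx/dt = x (chosen only because its exact solution e^t is known; it is not a diffusion model). Doubling the number of steps roughly halves the error, which is what "first order" means in practice. The function names are hypothetical.

```python
import numpy as np


def euler_step(f, x, t_curr, t_next):
    # x_{t_{i+1}} = x_{t_i} + (t_{i+1} - t_i) * f(x_{t_i}, t_i)
    return x + (t_next - t_curr) * f(x, t_curr)


def integrate(f, x0, t_grid):
    # Apply euler_step over each consecutive pair of grid points.
    x = x0
    for t_curr, t_next in zip(t_grid[:-1], t_grid[1:]):
        x = euler_step(f, x, t_curr, t_next)
    return x


def toy_field(x, t):
    # dx/dt = x, so x(1) = e when x(0) = 1.
    return x


err_100 = abs(integrate(toy_field, 1.0, np.linspace(0.0, 1.0, 101)) - np.e)
err_200 = abs(integrate(toy_field, 1.0, np.linspace(0.0, 1.0, 201)) - np.e)
print(err_100 / err_200)  # close to 2 for a first-order method
```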
- class samplers.samplers.EulerJumpSampler(num_sampling_steps: int, sampling_time_dist: SamplingTimeDistType, sampling_time_kwargs: dict = {})
Bases: EulerSampler
Euler sampler that supports jumps with distilled models.
First-order deterministic sampler.
- class samplers.samplers.HeunSampler(num_sampling_steps: int, sampling_time_dist: SamplingTimeDistType, sampling_time_kwargs: dict = {})
Bases: Samplers
Heun sampler.
Second-order deterministic sampler.
- forward(rng, net: Module, x: Array, t_curr: Array, t_next: Array, g_net: Module | None = None, guidance_scale: float = 1.0, **net_kwargs) → Array
Heun step in integration.
\[
\begin{aligned}
\tilde{x}_{t_i} &= x_{t_i} + (t_{i+1} - t_i)\, f(x_{t_i}, t_i) \\
x_{t_{i+1}} &= x_{t_i} + \frac{t_{i+1} - t_i}{2} \left( f(x_{t_i}, t_i) + f(\tilde{x}_{t_i}, t_{i+1}) \right)
\end{aligned}
\]
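The two equations translate directly into a predictor-corrector step: an Euler predictor, then a trapezoidal average of the slopes at both ends of the interval. The NumPy sketch below (function names are hypothetical) also includes the Euler-only final step described under `last_step` and counts network evaluations: with N intervals, N - 1 Heun steps plus one Euler step cost 2N - 1 evaluations instead of 2N.

```python
import numpy as np


def heun_step(f, x, t_curr, t_next):
    h = t_next - t_curr
    d_curr = f(x, t_curr)
    x_tilde = x + h * d_curr                 # Euler predictor
    d_next = f(x_tilde, t_next)              # slope at the predicted point
    return x + 0.5 * h * (d_curr + d_next)   # trapezoidal corrector


def heun_last_step(f, x, t_curr, t_last):
    # Plain Euler step: f is evaluated at t_curr only, never at t_last.
    return x + (t_last - t_curr) * f(x, t_curr)


def heun_sample(f, x, t_grid):
    for t_curr, t_next in zip(t_grid[:-2], t_grid[1:-1]):
        x = heun_step(f, x, t_curr, t_next)
    return heun_last_step(f, x, t_grid[-2], t_grid[-1])


calls = [0]


def counted_field(x, t):
    calls[0] += 1
    return -x  # toy vector field dx/dt = -x


heun_sample(counted_field, 1.0, np.linspace(0.0, 1.0, 5))  # N = 4 intervals
print(calls[0])  # 7 == 2 * 4 - 1
```

For a linear field f(x) = λx, one `heun_step` reproduces the second-order Taylor expansion x(1 + hλ + (hλ)²/2), which is where the "second order" label comes from.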
- class samplers.samplers.EulerMaruyamaSampler(num_sampling_steps: int, sampling_time_dist: SamplingTimeDistType, sampling_time_kwargs: dict = {}, diffusion_coeff: DiffusionCoeffType | Callable[[Array], Array] = DiffusionCoeffType.LINEAR_KL, diffusion_coeff_norm: float = 1.0)
Bases: Samplers
Euler-Maruyama sampler.
First-order stochastic sampler.
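For a generic SDE dx = f(x, t) dt + w(t) dW, one Euler-Maruyama step applies the Euler drift update and adds Gaussian noise scaled by w(t) sqrt(|Δt|). The NumPy sketch below shows only that generic scheme, with a NumPy `Generator` standing in for a JAX PRNG key. `drift` and `w` are hypothetical callables; how this sampler derives the SDE drift from the network output and from `diffusion_coeff` is not covered by these docs. Because the noise has zero mean, dropping it yields the expected next state, which is one plausible reading of the "expected marginal value" returned by the stochastic last step; treat that as an assumption.

```python
import numpy as np


def euler_maruyama_step(rng, drift, w, x, t_curr, t_next):
    """One Euler-Maruyama step for dx = drift(x, t) dt + w(t) dW (generic sketch)."""
    dt = t_next - t_curr
    noise = rng.standard_normal(np.shape(x))
    # sqrt(|dt|) so the step also works when integrating backwards in time (dt < 0).
    return x + dt * drift(x, t_curr) + w(t_curr) * np.sqrt(abs(dt)) * noise


rng = np.random.default_rng(0)
x0 = np.zeros(100_000)
x1 = euler_maruyama_step(rng, lambda x, t: np.zeros_like(x), lambda t: 1.0, x0, 1.0, 0.75)
print(x1.var())  # approximately |dt| = 0.25
```

With w(t) = 0 the step reduces exactly to the Euler update, so the deterministic sampler is the zero-noise special case.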
- __init__(num_sampling_steps: int, sampling_time_dist: SamplingTimeDistType, sampling_time_kwargs: dict = {}, diffusion_coeff: DiffusionCoeffType | Callable[[Array], Array] = DiffusionCoeffType.LINEAR_KL, diffusion_coeff_norm: float = 1.0)
- instantiate_diffusion_coeff(coeff: DiffusionCoeffType | Callable[[Array], Array], norm: float)
Instantiate the diffusion coefficient for SDE sampling.
- Parameters:
coeff – the desired diffusion coefficient. If a Callable is passed, it is returned directly; otherwise, the coefficient function is instantiated from our default settings.
norm – the norm of the diffusion coefficient.
- Returns:
diffusion_coeff_fn, the diffusion coefficient w(t).
- Return type:
Callable
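The dispatch described above (a callable is returned unchanged, anything else is resolved from built-in defaults) can be sketched as a free function. `CoeffKind` and its two formulas are hypothetical illustrations; they are not the definitions behind `DiffusionCoeffType`, including `LINEAR_KL`, which these docs do not spell out. Applying `norm` as a multiplicative scale is likewise an assumption.

```python
from enum import Enum
from typing import Callable, Union

import numpy as np


class CoeffKind(Enum):
    # Hypothetical stand-in; NOT the library's DiffusionCoeffType.
    CONSTANT = "constant"
    LINEAR = "linear"


def instantiate_diffusion_coeff(coeff: Union[CoeffKind, Callable], norm: float) -> Callable:
    if isinstance(coeff, CoeffKind):
        # Built-in kinds: build w(t) and (by assumption) scale it by `norm`.
        if coeff is CoeffKind.CONSTANT:
            return lambda t: norm * np.ones_like(t)
        if coeff is CoeffKind.LINEAR:
            return lambda t: norm * t
    if callable(coeff):
        return coeff  # user-supplied w(t) is returned unchanged; norm is ignored
    raise ValueError(f"unsupported diffusion coefficient: {coeff!r}")


w = instantiate_diffusion_coeff(CoeffKind.LINEAR, 2.0)
print(w(0.5))  # 1.0
```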