Samplers API

The samplers module provides various sampling strategies for diffusion models.

Samplers are designed to be agnostic to the model and its interface.

class samplers.samplers.Samplers(num_sampling_steps: int, sampling_time_dist: SamplingTimeDistType, sampling_time_kwargs: dict = {})

Bases: ABC

Base class for all samplers.

All samplers should support:
  • Sampling a discretized time grid t.

  • A single forward integration step.
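The two-part contract above can be sketched with Python's `abc` module. The classes `SamplerSketch` and `EulerSketch`, the uniform grid from 1.0 to 0.0 in `sample_t`, the `rng` argument on every step, and the plain-callable `net(x, t)` are illustrative assumptions rather than the library's implementation; guidance arguments and `net_kwargs` are left out.

```python
from abc import ABC, abstractmethod

import jax.numpy as jnp


class SamplerSketch(ABC):
    """Hypothetical minimal mirror of the Samplers contract (not the library class)."""

    def __init__(self, num_sampling_steps: int):
        self.num_sampling_steps = num_sampling_steps

    def sample_t(self, steps: int) -> jnp.ndarray:
        # Assumption: steps + 1 evenly spaced times from 1.0 (noise) to 0.0 (data).
        return jnp.linspace(1.0, 0.0, steps + 1)

    @abstractmethod
    def forward(self, rng, net, x, t_curr, t_next):
        """Advance x from t_curr to t_next."""

    @abstractmethod
    def last_step(self, rng, net, x, t_curr, t_last):
        """Advance x to t_last without evaluating net at t_last."""


class EulerSketch(SamplerSketch):
    def forward(self, rng, net, x, t_curr, t_next):
        # First-order update; rng is accepted for interface parity but unused.
        return x + (t_next - t_curr) * net(x, t_curr)

    def last_step(self, rng, net, x, t_curr, t_last):
        # An Euler step only evaluates net at t_curr, so it can be reused here.
        return self.forward(rng, net, x, t_curr, t_last)
```

Because `forward` and `last_step` are abstract, `SamplerSketch` itself cannot be instantiated; only concrete subclasses such as `EulerSketch` can.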

__init__(num_sampling_steps: int, sampling_time_dist: SamplingTimeDistType, sampling_time_kwargs: dict = {})
abstract forward(net: Module, x: Array, t_curr: Array, t_next: Array, g_net: Module | None = None, guidance_scale: float = 1.0, **net_kwargs)

Performs a single forward integration step.

Parameters:
  • net – network defining the vector field to integrate.

  • x – current state.

  • t_curr – current time step.

  • t_next – next time step.

  • g_net – guidance network.

  • guidance_scale – scale of the guidance.

  • net_kwargs – extra keyword arguments passed to the network.

Returns:

x_next, next state.

Return type:

  • jnp.ndarray

abstract last_step(net: Module, x: Array, t_curr: Array, t_last: Array, g_net: Module | None = None, guidance_scale: float = 1.0, **net_kwargs)

Performs the last integration step.

This method is exposed separately because many samplers treat the last step specially:
  • Heun: the last step is a single first-order Euler step.

  • Stochastic samplers: the last step returns the expected marginal value.

Parameters:
  • net – network defining the vector field to integrate.

  • x – current state.

  • t_curr – current time step.

  • t_last – last time step. Note: the model is never evaluated at this step.

  • g_net – guidance network.

  • guidance_scale – scale of the guidance.

  • net_kwargs – extra keyword arguments passed to the network.

Returns:

x_last, final state.

Return type:

  • jnp.ndarray

sample_t(steps: int) → Array

Samples the discretized time grid.

Parameters:

steps – number of steps.

Returns:

t, time grid.

Return type:

  • jnp.ndarray
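The concrete grid depends on `sampling_time_dist` and `sampling_time_kwargs`, whose options are not listed here. As a hedged illustration only, the sketch below builds two plausible grids with `steps + 1` points (so `steps` intervals), assuming time runs from 1.0 (noise) to 0.0 (data): a uniform grid, and a hypothetical shifted grid using the warp t' = s·t / (1 + (s − 1)·t) that is common in flow-matching samplers. Neither function is claimed to correspond to a `SamplingTimeDistType` member.

```python
import jax.numpy as jnp


def uniform_timegrid(steps: int) -> jnp.ndarray:
    """steps + 1 evenly spaced times from 1.0 down to 0.0 (assumed convention)."""
    return jnp.linspace(1.0, 0.0, steps + 1)


def shifted_timegrid(steps: int, shift: float = 3.0) -> jnp.ndarray:
    """Hypothetical shifted grid: warps a uniform grid with t' = s*t / (1 + (s-1)*t).

    The endpoints 1.0 and 0.0 are preserved; shift > 1 packs more of the
    steps near t = 1 (the high-noise end).
    """
    t = uniform_timegrid(steps)
    return shift * t / (1.0 + (shift - 1.0) * t)
```

With `steps=4` the uniform grid is `[1.0, 0.75, 0.5, 0.25, 0.0]`, while the shifted grid moves the midpoint from 0.5 to 0.75.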

sample(rng, net: Module, x: Array, g_net: Module | None = None, guidance_scale: float = 1.0, num_sampling_steps: int | None = None, custom_timegrid: Array | None = None, **net_kwargs) → Array

Main sampling loop.

Parameters:
  • rng – random key for potentially stochastic samplers.

  • net – network defining the vector field to integrate.

  • x – initial state.

  • g_net – guidance network.

  • guidance_scale – scale of the guidance.

  • num_sampling_steps – optional override of the number of sampling steps.

  • custom_timegrid – optional time grid to use instead of the sampled one.

  • net_kwargs – extra keyword arguments passed to the network.

Returns:

x_final, final clean state.

Return type:

  • jnp.ndarray
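A hedged sketch of how such a loop can be organized, assuming a grid of `num_sampling_steps + 1` times, that `forward` handles every interval except the final one (which goes to `last_step`), and that the random key is split once per step for stochastic samplers. The function `sample_loop`, the toy `euler_step`, and the plain-callable `net(x, t)` are hypothetical; guidance and `net_kwargs` are omitted.

```python
import jax
import jax.numpy as jnp


def sample_loop(rng, net, x, timegrid, forward, last_step):
    """Integrate x along timegrid: regular forward steps, then one special last step."""
    num_steps = timegrid.shape[0] - 1
    for i in range(num_steps - 1):
        rng, step_rng = jax.random.split(rng)
        x = forward(step_rng, net, x, timegrid[i], timegrid[i + 1])
    # The final interval gets its own treatment (e.g. plain Euler for Heun).
    _, step_rng = jax.random.split(rng)
    return last_step(step_rng, net, x, timegrid[-2], timegrid[-1])


def euler_step(rng, net, x, t_curr, t_next):
    # Deterministic first-order step; rng is accepted but unused.
    return x + (t_next - t_curr) * net(x, t_curr)
```

For example, with the toy field `net = lambda x, t: -x`, a grid `jnp.linspace(1.0, 0.0, 5)` and `euler_step` for both roles, each of the four steps multiplies the state by 1.25. A Python `for` loop keeps the sketch readable; a jitted implementation would more likely use `jax.lax.fori_loop` or `jax.lax.scan`.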

get_default_sampling_kwargs(kwargs: dict, sampling_time_dist: SamplingTimeDistType) → dict

Get the default kwargs for the given sampling time distribution.

expand_right(x: Array | float, y: Array) → Array

Expand x to match the batch dimension and broadcast x to the right to match the shape of y.

bcast_right(x: Array, y: Array) → Array

Broadcast x to the right to match the shape of y.
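Helpers like these are typically used to apply a per-sample time `t` of shape `(batch,)` to a state of shape `(batch, ...)`. The sketch below is one way to implement the two behaviours described above, assuming `x` is either a scalar or has leading dimensions matching `y`; the `_sketch` names are hypothetical, and the library's versions may differ in details such as error handling.

```python
import jax.numpy as jnp


def bcast_right_sketch(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Append trailing singleton axes to x, then broadcast it to y.shape."""
    extra = y.ndim - x.ndim
    if extra < 0:
        raise ValueError(f"x has more dims ({x.ndim}) than y ({y.ndim})")
    x = x.reshape(x.shape + (1,) * extra)
    return jnp.broadcast_to(x, y.shape)


def expand_right_sketch(x, y: jnp.ndarray) -> jnp.ndarray:
    """Turn a scalar x into a (batch,) vector, then right-broadcast it to y."""
    x = jnp.asarray(x)
    if x.ndim == 0:
        x = jnp.broadcast_to(x, (y.shape[0],))
    return bcast_right_sketch(x, y)
```

For an image batch `y` of shape `(4, 8, 8, 3)`, a time vector of shape `(4,)` becomes `(4, 1, 1, 1)` before broadcasting, so every pixel of sample `i` sees the same `t[i]`.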

class samplers.samplers.EulerSampler(num_sampling_steps: int, sampling_time_dist: SamplingTimeDistType, sampling_time_kwargs: dict = {})

Bases: Samplers

Euler Sampler.

First-order deterministic sampler.

forward(rng, net: Module, x: Array, t_curr: Array, t_next: Array, g_net: Module | None = None, guidance_scale: float = 1.0, **net_kwargs) → Array

One Euler integration step:

\[x_{t_{i+1}} = x_{t_i} + (t_{i+1} - t_i) * f(x_{t_i}, t_i)\]
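A sketch of this update as a standalone JAX function, assuming `net(x, t)` returns the vector field f and that `t_curr` and `t_next` are per-sample time arrays of shape `(batch,)` that must be broadcast against `x`. The names `euler_step` and `_bcast` are hypothetical, and guidance is omitted.

```python
import jax.numpy as jnp


def _bcast(t: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    # Reshape a (batch,) time array to (batch, 1, ..., 1) so it broadcasts against x.
    return t.reshape(t.shape + (1,) * (x.ndim - t.ndim))


def euler_step(net, x, t_curr, t_next):
    """x_{t_{i+1}} = x_{t_i} + (t_{i+1} - t_i) * f(x_{t_i}, t_i)."""
    dt = _bcast(t_next - t_curr, x)
    return x + dt * net(x, t_curr)
```

The step costs one network evaluation and is first-order accurate, which is why it is the baseline the other samplers build on.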

class samplers.samplers.EulerJumpSampler(num_sampling_steps: int, sampling_time_dist: SamplingTimeDistType, sampling_time_kwargs: dict = {})

Bases: EulerSampler

Euler sampler that supports jump steps with distilled models.

First-order deterministic sampler.

forward(rng, net: Module, x: Array, t_curr: Array, t_next: Array, g_net: Module | None = None, guidance_scale: float = 1.0, **net_kwargs) → Array

One Euler integration step that jumps directly from time t to time r:

\[x_{r} = x_{t} + (r - t) * f(x_{t}, t, r)\]
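Here the distilled network is conditioned on both the start time t and the target time r, so a single evaluation can cross an arbitrary interval. The sketch below assumes a hypothetical `net(x, t, r)` signature, per-sample time arrays of shape `(batch,)`, and that the displacement follows the same sign convention as the Euler step, i.e. `t_next - t_curr` with t = `t_curr` and r = `t_next`; `euler_jump_step` is an illustrative name, not the library's method.

```python
import jax.numpy as jnp


def euler_jump_step(net, x, t_curr, t_next):
    """Jump from t = t_curr to r = t_next with one evaluation of a two-time network."""
    # Per-sample times of shape (batch,) are reshaped to broadcast against x.
    dt = (t_next - t_curr).reshape((-1,) + (1,) * (x.ndim - 1))
    # net receives both the start time t and the target time r.
    return x + dt * net(x, t_curr, t_next)
```

When r is the next grid point, this reduces to the plain Euler update with a network that additionally sees r; the benefit of distillation comes from using coarse grids, down to a single jump from noise to data.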

class samplers.samplers.HeunSampler(num_sampling_steps: int, sampling_time_dist: SamplingTimeDistType, sampling_time_kwargs: dict = {})

Bases: Samplers

Heun Sampler.

Second-order deterministic sampler.

forward(rng, net: Module, x: Array, t_curr: Array, t_next: Array, g_net: Module | None = None, guidance_scale: float = 1.0, **net_kwargs) → Array

One Heun integration step (Euler predictor followed by a trapezoidal corrector):

\[\begin{aligned}\tilde{x}_{t_{i+1}} &= x_{t_i} + (t_{i+1} - t_i) * f(x_{t_i}, t_i)\\x_{t_{i+1}} &= x_{t_i} + \frac{t_{i+1} - t_i}{2} * \left(f(x_{t_i}, t_i) + f(\tilde{x}_{t_{i+1}}, t_{i+1})\right)\end{aligned}\]
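Heun's method pairs an Euler predictor with a trapezoidal corrector. The sketch below covers both the regular step and the Euler-only last step described under `last_step`; the corrector would otherwise evaluate f at the final time, where the model is never evaluated. It assumes a plain callable `net(x, t)` and scalar times; `heun_step` and `heun_last_step` are hypothetical names, not the library's methods.

```python
import jax.numpy as jnp


def heun_step(net, x, t_curr, t_next):
    """One second-order Heun step from t_curr to t_next (scalar times)."""
    dt = t_next - t_curr
    f_curr = net(x, t_curr)
    x_pred = x + dt * f_curr                 # Euler predictor at t_next
    f_pred = net(x_pred, t_next)             # slope at the predicted point
    return x + 0.5 * dt * (f_curr + f_pred)  # trapezoidal corrector


def heun_last_step(net, x, t_curr, t_last):
    """Final step: a single Euler step, so net is never evaluated at t_last."""
    return x + (t_last - t_curr) * net(x, t_curr)
```

Each regular step costs two network evaluations instead of one, in exchange for second-order accuracy; the last step falls back to one evaluation.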

class samplers.samplers.EulerMaruyamaSampler(num_sampling_steps: int, sampling_time_dist: SamplingTimeDistType, sampling_time_kwargs: dict = {}, diffusion_coeff: DiffusionCoeffType | Callable[[Array], Array] = DiffusionCoeffType.LINEAR_KL, diffusion_coeff_norm: float = 1.0)

Bases: Samplers

Euler-Maruyama Sampler.

First-order stochastic sampler.

__init__(num_sampling_steps: int, sampling_time_dist: SamplingTimeDistType, sampling_time_kwargs: dict = {}, diffusion_coeff: DiffusionCoeffType | Callable[[Array], Array] = DiffusionCoeffType.LINEAR_KL, diffusion_coeff_norm: float = 1.0)
instantiate_diffusion_coeff(coeff: DiffusionCoeffType | Callable[[Array], Array], norm: float)

Instantiate the diffusion coefficient for SDE sampling.

Parameters:
  • coeff – the desired diffusion coefficient. If a Callable is passed in, it is returned directly; otherwise the coefficient function is instantiated from the default settings.

  • norm – the norm of the diffusion coefficient.

Returns:

diffusion_coeff_fn, the diffusion coefficient function w(t).

Return type:

  • Callable
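A sketch of the dispatch described above: a callable is returned unchanged, and an enum member is mapped to a default function scaled by `norm`. The enum `ToyCoeffType` and its two formulas are hypothetical placeholders; the actual `DiffusionCoeffType` members (such as `LINEAR_KL`) and their formulas are not documented here.

```python
from enum import Enum
from typing import Callable, Union

import jax.numpy as jnp


class ToyCoeffType(Enum):
    # Hypothetical members, not the library's DiffusionCoeffType.
    CONSTANT = "constant"
    LINEAR = "linear"


def instantiate_coeff_sketch(
    coeff: Union[ToyCoeffType, Callable[[jnp.ndarray], jnp.ndarray]], norm: float
) -> Callable[[jnp.ndarray], jnp.ndarray]:
    """Return w(t): pass callables through, build defaults for enum members."""
    if callable(coeff):
        return coeff  # user-supplied w(t) is used as-is; norm is not applied
    if coeff is ToyCoeffType.CONSTANT:
        return lambda t: norm * jnp.ones_like(t)
    if coeff is ToyCoeffType.LINEAR:
        return lambda t: norm * t
    raise ValueError(f"unknown diffusion coefficient: {coeff}")
```

Enum members are not callable, so the `callable` check cleanly separates user-supplied functions from named defaults.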

forward(rng, net: Module, x: Array, t_curr: Array, t_next: Array, g_net: Module | None = None, guidance_scale: float = 1.0, **net_kwargs) → Array

One Euler-Maruyama integration step:

\[x_{t_{i+1}} = x_{t_i} + (t_{i+1} - t_i) * f(x_{t_i}, t_i) + \sqrt{2 * w(t_i)} * \epsilon\]
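A sketch of this step in JAX. The formula does not define ε; here it is read as the usual Euler-Maruyama Brownian increment, i.e. a standard normal draw scaled by sqrt(|t_{i+1} - t_i|). That scaling, the scalar times, the plain-callable drift `drift(x, t)` standing in for f, and the names `em_step` and `em_last_step` are assumptions. The last step is sketched as the drift-only update, which is one reading of "returns the expected marginal value" from the `last_step` description, since the noise term has zero mean.

```python
import jax
import jax.numpy as jnp


def em_step(rng, drift, w, x, t_curr, t_next):
    """x_next = x + dt * f(x, t) + sqrt(2 * w(t) * |dt|) * z,  z ~ N(0, I)."""
    dt = t_next - t_curr
    z = jax.random.normal(rng, x.shape, dtype=x.dtype)
    noise_scale = jnp.sqrt(2.0 * w(t_curr) * jnp.abs(dt))
    return x + dt * drift(x, t_curr) + noise_scale * z


def em_last_step(drift, x, t_curr, t_last):
    """Drift-only final update: the expectation of em_step over z."""
    return x + (t_last - t_curr) * drift(x, t_curr)
```

Setting w(t) = 0 removes the noise and recovers the deterministic Euler step, which is a convenient sanity check.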