Source code for samplers.samplers

"""File containing samplers. Samplers are made model / interface agnostic."""

# built-in libs
from abc import ABC, abstractmethod
import copy
from enum import Enum
import math
from typing import Callable

# external libs
import flax.linen as nn
from flax import nnx
import jax
import jax.numpy as jnp


class SamplingTimeDistType(Enum):
    """Class for Sampling Time Distribution Types.
    
    :meta private:
    """
    UNIFORM = 1
    EXP     = 2

    # TODO: Add more sampling time distribution types


DEFAULT_SAMPLING_TIME_KWARGS = {
    SamplingTimeDistType.UNIFORM: {
        't_start': 1.0,
        't_end': 0.0,
        't_shift_base': 4096,
        't_shift_cur': 4096
    },
    SamplingTimeDistType.EXP: {
        'sigma_min': 0.002,
        'sigma_max': 80.0,
        'rho': 7.0
    }
}


class Samplers(ABC):
    r"""Base class for all samplers.

    All samplers should support:

    - Sampling a discretized time grid t
    - A single forward step in integration
    """

    def __init__(
        self,
        num_sampling_steps: int,
        sampling_time_dist: SamplingTimeDistType,
        sampling_time_kwargs: dict = {},
    ):
        self.num_sampling_steps = num_sampling_steps
        if isinstance(sampling_time_dist, str):
            self.sampling_time_dist = SamplingTimeDistType[
                sampling_time_dist.replace('_', '').upper()
            ]
        else:
            self.sampling_time_dist = sampling_time_dist
        self.sampling_time_kwargs = self.get_default_sampling_kwargs(
            sampling_time_kwargs, self.sampling_time_dist
        )

    @abstractmethod
    def forward(
        self,
        rng,
        net: nn.Module,
        x: jnp.ndarray,
        t_curr: jnp.ndarray,
        t_next: jnp.ndarray,
        g_net: nn.Module | None = None,
        guidance_scale: float = 1.0,
        **net_kwargs
    ):
        r"""A single forward step in integration.

        Args:
            - rng: random key stream for potentially stochastic samplers.
            - net: network to integrate vector field with.
            - x: current state.
            - t_curr: current time step.
            - t_next: next time step.
            - g_net: guidance network.
            - guidance_scale: scale of guidance.
            - net_kwargs: extra net args.

        Return:
            - jnp.ndarray: x_next, next state.
        """

    @abstractmethod
    def last_step(
        self,
        rng,
        net: nn.Module,
        x: jnp.ndarray,
        t_curr: jnp.ndarray,
        t_last: jnp.ndarray,
        g_net: nn.Module | None = None,
        guidance_scale: float = 1.0,
        **net_kwargs
    ):
        r"""Last step in integration.

        This interface is exposed since many samplers treat the last step
        specially:

        - Heun: the last step is one first order Euler step.
        - Stochastic: the last step returns the expected marginal value.

        Args:
            - rng: random key stream for potentially stochastic samplers.
            - net: network to integrate vector field with.
            - x: current state.
            - t_curr: current time step.
            - t_last: last time step. Note: the model is never evaluated at this step.
            - g_net: guidance network.
            - guidance_scale: scale of guidance.
            - net_kwargs: extra net args.

        Return:
            - jnp.ndarray: x_last, final state.
        """

    ########## Sampling ##########
    def sample_t(self, steps: int) -> jnp.ndarray:
        r"""Sample the discretized time grid.

        Args:
            - steps: number of grid points. Note: the EXP grid appends a final 0,
              returning steps + 1 values.

        Returns:
            - jnp.ndarray: t, time grid.
        """
        if self.sampling_time_dist == SamplingTimeDistType.UNIFORM:
            t_start = self.sampling_time_kwargs['t_start']
            t_end = self.sampling_time_kwargs['t_end']
            t = jnp.linspace(t_start, t_end, steps)

            # resolution-dependent time shift
            t_shift_base = self.sampling_time_kwargs['t_shift_base']
            t_shift_cur = self.sampling_time_kwargs['t_shift_cur']
            shift_ratio = math.sqrt(t_shift_cur / t_shift_base)
            return shift_ratio * t / (1 + (shift_ratio - 1) * t)
        elif self.sampling_time_dist == SamplingTimeDistType.EXP:
            # the following aligns with the EDM implementation
            step_indices = jnp.arange(steps)
            sigma_min = self.sampling_time_kwargs['sigma_min']
            sigma_max = self.sampling_time_kwargs['sigma_max']
            rho = self.sampling_time_kwargs['rho']
            t_steps = (
                sigma_max ** (1 / rho)
                + step_indices / (steps - 1)
                * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
            ) ** rho
            # ensure the last step is 0
            return jnp.concatenate([t_steps, jnp.array([0.])])
        else:
            raise ValueError(
                f"Sampling Time Distribution {self.sampling_time_dist} not supported."
            )

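    # Illustrative sketch of the UNIFORM time shift (not executed as part of this
    # module): with the default kwargs, shift_ratio = sqrt(4096 / 4096) = 1 and
    # the grid is a plain linspace. Sampling at a higher resolution, e.g. with
    # t_shift_cur = 16384, gives shift_ratio = 2 and maps t to 2t / (1 + t),
    # which clusters grid points toward the high-noise end (t close to 1):
    #
    #     sampler = EulerSampler(  # any concrete subclass of Samplers
    #         num_sampling_steps=8,
    #         sampling_time_dist=SamplingTimeDistType.UNIFORM,
    #         sampling_time_kwargs={'t_shift_cur': 16384},
    #     )
    #     t = sampler.sample_t(9)  # 9 grid points => 8 integration steps
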
    def sample(
        self,
        rng,
        net: nn.Module,
        x: jnp.ndarray,
        g_net: nn.Module | None = None,
        guidance_scale: float = 1.0,
        num_sampling_steps: int | None = None,
        custom_timegrid: jnp.ndarray | None = None,
        **net_kwargs
    ) -> jnp.ndarray:
        r"""Main sample loop.

        Args:
            - rng: random key for potentially stochastic samplers.
            - net: network to integrate vector field with.
            - x: initial state (noise).
            - g_net: guidance network.
            - guidance_scale: scale of guidance.
            - num_sampling_steps: optional override for the number of steps.
            - custom_timegrid: optional custom time grid; takes precedence.
            - net_kwargs: extra net args.

        Return:
            - jnp.ndarray: x_final, final clean state.
        """
        if custom_timegrid is not None:
            timegrid = custom_timegrid
        elif num_sampling_steps is not None:
            # exposing this pathway for flexibility in sampling
            timegrid = self.sample_t(num_sampling_steps + 1)
        else:
            # if not provided, use the default number of sampling steps
            timegrid = self.sample_t(self.num_sampling_steps + 1)

        def _fn(carry, t_index):
            t_curr, t_next = timegrid[t_index], timegrid[t_index + 1]
            net, g_net, x_curr, rng = carry
            x_next = self.forward(
                rng, net, x_curr, t_curr, t_next, g_net, guidance_scale, **net_kwargs
            )
            return (net, g_net, x_next, rng), x_next

        # lift scan to nnx.scan to capture the references passed in from net & g_net;
        # otherwise the rng state will leak since a global counter is maintained.
        # note: the initial state is scaled by timegrid[0] (1 for the default
        # UNIFORM grid, sigma_max for EXP).
        (_, _, x_curr, rng), _ = nnx.scan(
            _fn, in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0)
        )((net, g_net, x * timegrid[0], rng), jnp.arange(len(timegrid) - 2))

        x_final = self.last_step(
            rng, net, x_curr, timegrid[-2], timegrid[-1],
            g_net, guidance_scale, **net_kwargs
        )
        return x_final

    ########## Helper Functions ##########
    def get_default_sampling_kwargs(
        self, kwargs: dict, sampling_time_dist: SamplingTimeDistType
    ) -> dict:
        """Get default kwargs for the sampling time distribution, applying user overrides."""
        default_kwargs = copy.deepcopy(DEFAULT_SAMPLING_TIME_KWARGS[sampling_time_dist])
        for key in default_kwargs:
            if key in kwargs:
                # overwrite default value
                default_kwargs[key] = kwargs[key]
        return default_kwargs

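    # Sketch of the override semantics (not executed here): only keys already
    # present in the defaults are honored; unknown keys are silently dropped.
    #
    #     kwargs = {'sigma_max': 120.0, 'not_a_real_key': 1}  # hypothetical values
    #     merged = self.get_default_sampling_kwargs(kwargs, SamplingTimeDistType.EXP)
    #     # merged == {'sigma_min': 0.002, 'sigma_max': 120.0, 'rho': 7.0}
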
    def expand_right(self, x: jnp.ndarray | float, y: jnp.ndarray) -> jnp.ndarray:
        """Expand a scalar (or broadcastable array) x into a 1-D array over the
        batch dimension of y, i.e. shape (y.shape[0],)."""
        if isinstance(x, jnp.ndarray):
            assert len(y.shape) >= x.ndim
        return jnp.ones((y.shape[0],)) * x

    def bcast_right(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        """Broadcast x to the right to match the shape of y."""
        assert len(y.shape) >= x.ndim
        return x.reshape(x.shape + (1,) * (len(y.shape) - x.ndim))

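    # Shape conventions, as a sketch: for x of shape (B, H, W, C),
    #
    #     t = self.expand_right(0.5, x)   # shape (B,): one time per batch element
    #     w = self.bcast_right(t, x)      # shape (B, 1, 1, 1): broadcastable with x
    #
    # so `x + d * self.bcast_right(dt, d)` applies a per-sample step size.

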
class EulerSampler(Samplers):
    r"""Euler Sampler.

    First Order Deterministic Sampler.
    """

    def forward(
        self,
        rng,
        net: nn.Module,
        x: jnp.ndarray,
        t_curr: jnp.ndarray,
        t_next: jnp.ndarray,
        g_net: nn.Module | None = None,
        guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        r"""Euler step in integration.

        .. math::
            x_{t_{i+1}} = x_{t_i} + (t_{i+1} - t_i) f(x_{t_i}, t_i)
        """
        del rng
        t_curr = self.expand_right(t_curr, x)
        net_out = net.pred(x, t_curr, **net_kwargs)

        if g_net is None:
            g_net = net
        # make uncond generation: replace the class label with the null label (1000)
        g_net_kwargs = {
            k: (v if k != 'y' else jnp.ones_like(v, dtype=jnp.int32) * 1000)
            for k, v in net_kwargs.items()
        }

        def guided_fn(g_net, x, t):
            # TODO: consider using different set of args for g_net
            g_net_out = g_net.pred(x, t, **g_net_kwargs)
            return g_net_out + guidance_scale * (net_out - g_net_out)

        def unguided_fn(g_net, x, t):
            return net_out

        d_curr = nnx.cond(
            guidance_scale == 1., unguided_fn, guided_fn, g_net, x, t_curr
        )
        dt = t_next - t_curr
        return x + d_curr * self.bcast_right(dt, d_curr)

    def last_step(
        self,
        rng,
        net: nn.Module,
        x: jnp.ndarray,
        t_curr: jnp.ndarray,
        t_next: jnp.ndarray,
        g_net: nn.Module | None = None,
        guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        """:meta private:"""
        return self.forward(
            rng, net, x, t_curr, t_next, g_net, guidance_scale, **net_kwargs
        )

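# Illustrative usage (a sketch, not executed as part of this module; `net`, `rngs`,
# and `labels` are hypothetical: a model exposing `pred(x, t, **kwargs)`, an
# `nnx.Rngs` stream, and integer class labels):
#
#     sampler = EulerSampler(
#         num_sampling_steps=50,
#         sampling_time_dist=SamplingTimeDistType.UNIFORM,
#     )
#     noise = jax.random.normal(jax.random.key(0), (8, 32, 32, 4))
#     samples = sampler.sample(rngs, net, noise, y=labels, guidance_scale=1.5)

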
class EulerJumpSampler(EulerSampler):
    r"""Euler Sampler that supports jumps with distilled models.

    First Order Deterministic Sampler.
    """

    def forward(
        self,
        rng,
        net: nn.Module,
        x: jnp.ndarray,
        t_curr: jnp.ndarray,
        t_next: jnp.ndarray,
        g_net: nn.Module | None = None,
        guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        r"""Euler step with jump in integration.

        .. math::
            x_r = x_t + (r - t) f(x_t, t, r)
        """
        del rng
        t_curr = self.expand_right(t_curr, x)
        t_next = self.expand_right(t_next, x)
        net_out = net.pred(x, t_curr, r=t_next, **net_kwargs)
        dt = t_next - t_curr
        return x + net_out * self.bcast_right(dt, net_out)


class HeunSampler(Samplers):
    r"""Heun Sampler.

    Second Order Deterministic Sampler.
    """

    def forward(
        self,
        rng,
        net: nn.Module,
        x: jnp.ndarray,
        t_curr: jnp.ndarray,
        t_next: jnp.ndarray,
        g_net: nn.Module | None = None,
        guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        r"""Heun step in integration.

        .. math::
            \tilde{x}_{t_{i+1}} = x_{t_i} + (t_{i+1} - t_i) f(x_{t_i}, t_i)

            x_{t_{i+1}} = x_{t_i} + \frac{t_{i+1} - t_i}{2}
            \left( f(x_{t_i}, t_i) + f(\tilde{x}_{t_{i+1}}, t_{i+1}) \right)
        """
        del rng
        t_curr = self.expand_right(t_curr, x)
        net_out = net.pred(x, t_curr, **net_kwargs)

        if g_net is None:
            g_net = net
        # make uncond generation: replace the class label with the null label (1000)
        g_net_kwargs = {
            k: (v if k != 'y' else jnp.ones_like(v, dtype=jnp.int32) * 1000)
            for k, v in net_kwargs.items()
        }

        def guided_fn(g_net, x, t):
            # TODO: consider using different set of args for g_net
            g_net_out = g_net.pred(x, t, **g_net_kwargs)
            return g_net_out + guidance_scale * (net_out - g_net_out)

        def unguided_fn(g_net, x, t):
            return net_out

        d_curr = nnx.cond(
            guidance_scale == 1., unguided_fn, guided_fn, g_net, x, t_curr
        )
        dt = t_next - t_curr
        x_next = x + d_curr * self.bcast_right(dt, d_curr)

        # Heun's method: re-evaluate the slope at the predicted next state and
        # average it with the current slope
        t_next = self.expand_right(t_next, x)
        net_out = net.pred(x_next, t_next, **net_kwargs)
        d_next = nnx.cond(
            guidance_scale == 1., unguided_fn, guided_fn, g_net, x_next, t_next
        )
        return x + 0.5 * self.bcast_right(dt, d_curr) * (d_curr + d_next)

    def last_step(
        self,
        rng,
        net: nn.Module,
        x: jnp.ndarray,
        t_curr: jnp.ndarray,
        t_next: jnp.ndarray,
        g_net: nn.Module | None = None,
        guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        """:meta private:"""
        del rng
        # Heun's last step is one first order Euler step
        t_curr = self.expand_right(t_curr, x)
        net_out = net.pred(x, t_curr, **net_kwargs)

        g_net_kwargs = {
            k: (v if k != 'y' else jnp.ones_like(v, dtype=jnp.int32) * 1000)
            for k, v in net_kwargs.items()
        }
        if g_net is None:
            g_net = net

        def guided_fn(x, t):
            # TODO: consider using different set of args for g_net
            g_net_out = g_net.pred(x, t, **g_net_kwargs)
            return g_net_out + guidance_scale * (net_out - g_net_out)

        def unguided_fn(x, t):
            return net_out

        d_curr = nnx.cond(
            guidance_scale == 1.0, unguided_fn, guided_fn, x, t_curr
        )
        dt = t_next - t_curr
        return x + d_curr * self.bcast_right(dt, d_curr)

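# Cost note, as a sketch: each Heun `forward` evaluates the network twice, while
# `last_step` falls back to a single Euler evaluation, so N integration steps cost
# roughly
#
#     nfe = 2 * (N - 1) + 1  # network evaluations, ignoring extra g_net calls

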
class DiffusionCoeffType(Enum):
    """Class for Diffusion Coefficient Types.

    :meta private:
    """
    CONSTANT = 1
    LINEAR_KL = 2
    LINEAR = 3
    COS = 4
    CONCAVE = 5

class EulerMaruyamaSampler(Samplers):
    r"""Euler-Maruyama Sampler.

    First Order Stochastic Sampler.
    """

    def __init__(
        self,
        num_sampling_steps: int,
        sampling_time_dist: SamplingTimeDistType,
        sampling_time_kwargs: dict = {},
        # below are args for stochastic samplers
        diffusion_coeff: DiffusionCoeffType | Callable[[jnp.ndarray], jnp.ndarray] = DiffusionCoeffType.LINEAR_KL,
        diffusion_coeff_norm: float = 1.0
    ):
        super().__init__(
            num_sampling_steps, sampling_time_dist, sampling_time_kwargs
        )
        self.diffusion_coeff_fn = self.instantiate_diffusion_coeff(
            diffusion_coeff, diffusion_coeff_norm
        )

    def instantiate_diffusion_coeff(
        self,
        coeff: DiffusionCoeffType | Callable[[jnp.ndarray], jnp.ndarray],
        norm: float
    ):
        """Instantiate the diffusion coefficient for SDE sampling.

        Args:
            - coeff: the desired diffusion coefficient. If a Callable is passed in,
              it is returned directly; otherwise the coefficient function is
              instantiated from the default settings below.
            - norm: the norm of the diffusion coefficient.

        Returns:
            - Callable: diffusion_coeff_fn, w(t)
        """
        if callable(coeff):
            return coeff
        choices = {
            DiffusionCoeffType.CONSTANT: lambda t: norm,
            DiffusionCoeffType.LINEAR_KL: lambda t: norm * (1 / (1 - t) * t**2 + t),
            DiffusionCoeffType.LINEAR: lambda t: norm * t,
            DiffusionCoeffType.COS: lambda t: 0.25 * (norm * jnp.cos(jnp.pi * t) + 1) ** 2,
            DiffusionCoeffType.CONCAVE: lambda t: 0.25 * (norm * jnp.sin(jnp.pi * t) + 1) ** 2,
        }
        try:
            fn = choices[coeff]
        except KeyError:
            raise ValueError(
                f"Diffusion coefficient function {coeff} not supported. "
                "Consider using custom functions."
            )
        return fn

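    # Illustrative sketch: any callable w(t) may be passed instead of an enum
    # member, e.g. a clipped linear coefficient (a hypothetical choice):
    #
    #     sampler = EulerMaruyamaSampler(
    #         num_sampling_steps=250,
    #         sampling_time_dist=SamplingTimeDistType.UNIFORM,
    #         diffusion_coeff=lambda t: jnp.clip(t, 0.05, 0.95),
    #     )
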
    def drift(
        self,
        net: nn.Module,
        x: jnp.ndarray,
        t_curr: jnp.ndarray,
        **net_kwargs
    ):
        """:meta private:"""
        tangent = net.pred(x, t_curr, **net_kwargs)
        score = net.score(x, t_curr, **net_kwargs)
        return tangent - 0.5 * self.bcast_right(
            self.diffusion_coeff_fn(t_curr), score
        ) * score

    def forward(
        self,
        rng,
        net: nn.Module,
        x: jnp.ndarray,
        t_curr: jnp.ndarray,
        t_next: jnp.ndarray,
        g_net: nn.Module | None = None,
        guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        r"""Euler-Maruyama step in integration, with drift
        :math:`\tilde{f}(x, t) = f(x, t) - \frac{1}{2} w(t) \nabla_x \log p_t(x)`:

        .. math::
            x_{t_{i+1}} = x_{t_i} + (t_{i+1} - t_i) \tilde{f}(x_{t_i}, t_i)
            + \sqrt{w(t_i) \, |t_{i+1} - t_i|} \, \epsilon
        """
        t_curr = self.expand_right(t_curr, x)
        net_out = self.drift(net, x, t_curr, **net_kwargs)

        if g_net is None:
            g_net = net
        # make uncond generation: replace the class label with the null label (1000)
        g_net_kwargs = {
            k: (v if k != 'y' else jnp.ones_like(v, dtype=jnp.int32) * 1000)
            for k, v in net_kwargs.items()
        }

        def guided_fn(x, t):
            # TODO: consider using different set of args for g_net
            g_net_out = self.drift(g_net, x, t, **g_net_kwargs)
            return g_net_out + guidance_scale * (net_out - g_net_out)

        def unguided_fn(x, t):
            return net_out

        d_curr = nnx.cond(
            guidance_scale == 1., unguided_fn, guided_fn, x, t_curr
        )
        dt = t_next - t_curr
        x_mean = x + d_curr * self.bcast_right(dt, d_curr)
        wiener = jax.random.normal(rng(), x_mean.shape) * self.bcast_right(
            jnp.sqrt(jnp.abs(dt)), x_mean
        )
        x = x_mean + self.bcast_right(
            jnp.sqrt(self.diffusion_coeff_fn(t_curr)), x_mean
        ) * wiener
        return x

    def last_step(
        self,
        rng,
        net: nn.Module,
        x: jnp.ndarray,
        t_curr: jnp.ndarray,
        t_next: jnp.ndarray,
        g_net: nn.Module | None = None,
        guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        """:meta private:"""
        del rng
        # last step takes a noise-free Euler step: the expected marginal value
        t_curr = self.expand_right(t_curr, x)
        net_out = self.drift(net, x, t_curr, **net_kwargs)

        if g_net is None:
            g_net = net
        # make uncond generation: replace the class label with the null label (1000)
        g_net_kwargs = {
            k: (v if k != 'y' else jnp.ones_like(v, dtype=jnp.int32) * 1000)
            for k, v in net_kwargs.items()
        }

        def guided_fn(x, t):
            # TODO: consider using different set of args for g_net
            g_net_out = self.drift(g_net, x, t, **g_net_kwargs)
            return g_net_out + guidance_scale * (net_out - g_net_out)

        def unguided_fn(x, t):
            return net_out

        d_curr = nnx.cond(
            guidance_scale == 1., unguided_fn, guided_fn, x, t_curr
        )
        dt = t_next - t_curr
        return x + d_curr * self.bcast_right(dt, d_curr)

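# Illustrative usage (a sketch, not executed here; `net` must additionally expose a
# `score(x, t, **kwargs)` method for the drift correction, and `rngs` is assumed to
# be an `nnx.Rngs` stream whose keys are drawn via `rngs()` inside `forward`):
#
#     sampler = EulerMaruyamaSampler(
#         num_sampling_steps=250,
#         sampling_time_dist=SamplingTimeDistType.UNIFORM,
#         diffusion_coeff=DiffusionCoeffType.LINEAR_KL,
#     )
#     samples = sampler.sample(rngs, net, noise, y=labels)

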
class EDMSampler(Samplers):
    r"""EDM Stochastic Sampler.

    Second Order Stochastic Sampler proposed in https://arxiv.org/abs/2206.00364

    :meta private:
    """
    pass