"""File containing samplers. Samplers are made model / interface agnostic."""
# built-in libs
from abc import ABC, abstractmethod
import copy
from enum import Enum
import math
from typing import Callable
# external libs
import flax.linen as nn
from flax import nnx
import jax
import jax.numpy as jnp
class SamplingTimeDistType(Enum):
    """Class for Sampling Time Distribution Types.
    
    :meta private:
    """
    UNIFORM = 1
    EXP     = 2
    # TODO: Add more sampling time distribution types
DEFAULT_SAMPLING_TIME_KWARGS = {
    SamplingTimeDistType.UNIFORM: {
        't_start': 1.0,
        't_end': 0.0,
        't_shift_base': 4096,
        't_shift_cur': 4096
    },
    SamplingTimeDistType.EXP: {
        'sigma_min': 0.002,
        'sigma_max': 80.0,
        'rho': 7.0
    }
}
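# NOTE (illustrative): user-supplied kwargs override these defaults key-by-key via
# get_default_sampling_kwargs, e.g. passing {'t_shift_cur': 1024} to a UNIFORM
# sampler keeps t_start / t_end / t_shift_base at the values above.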
class Samplers(ABC):
    r"""Base class for all samplers.
    All samplers should support:
        - Sample discretized timegrid t
        - A single forward step in integration
    """
    def __init__(
        self,
        num_sampling_steps: int,
        sampling_time_dist: SamplingTimeDistType | str,
        sampling_time_kwargs: dict | None = None,
    ):
        self.num_sampling_steps = num_sampling_steps
        if isinstance(sampling_time_dist, str):
            self.sampling_time_dist = SamplingTimeDistType[sampling_time_dist.replace('_', '').upper()]
        else:
            self.sampling_time_dist = sampling_time_dist
        # avoid a mutable default argument; merge user overrides into the defaults
        self.sampling_time_kwargs = self.get_default_sampling_kwargs(
            sampling_time_kwargs or {}, self.sampling_time_dist
        )
    
    @abstractmethod
    def forward(
        self, rng, net: nn.Module, x: jnp.ndarray, t_curr: jnp.ndarray, t_next: jnp.ndarray,
        g_net: nn.Module | None = None, guidance_scale: float = 1.0,
        **net_kwargs
    ):
        r"""A single forward step in integration.

        Args:
            - rng: random key stream for potentially stochastic samplers.
            - net: network to integrate vector field with.
            - x: current state.
            - t_curr: current time step.
            - t_next: next time step.
            - g_net: guidance network.
            - guidance_scale: scale of guidance.
            - net_kwargs: extra net args.
        Return:
            - jnp.ndarray: x_next, next state.
        """
    
    @abstractmethod
    def last_step(
        self, rng, net: nn.Module, x: jnp.ndarray, t_curr: jnp.ndarray, t_last: jnp.ndarray,
        g_net: nn.Module | None = None, guidance_scale: float = 1.0,
        **net_kwargs
    ):
        r"""Last step in integration.

        This interface is exposed since many samplers treat the last step specially:
            - Heun: the last step is a single first-order Euler step.
            - Stochastic: the last step returns the expected marginal value.

        Args:
            - rng: random key stream for potentially stochastic samplers.
            - net: network to integrate vector field with.
            - x: current state.
            - t_curr: current time step.
            - t_last: last time step. Note: the model is never evaluated at this step.
            - g_net: guidance network.
            - guidance_scale: scale of guidance.
            - net_kwargs: extra net args.
        Return:
            - jnp.ndarray: x_last, final state.
        """
    ########## Sampling ##########
    def sample_t(self, steps: int) -> jnp.ndarray:
        r"""Sampling time grid.
        Args:
            - steps: number of steps.
        Returns:
            - jnp.ndarray: t, time grid.
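
        Example (illustrative; with the default UNIFORM kwargs the shift ratio is 1,
        so the grid is a plain linspace)::

            >>> sampler.sample_t(5)  # doctest: +SKIP
            Array([1.  , 0.75, 0.5 , 0.25, 0.  ], dtype=float32)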
        """
        if self.sampling_time_dist == SamplingTimeDistType.UNIFORM:
            t_start = self.sampling_time_kwargs['t_start']
            t_end = self.sampling_time_kwargs['t_end']
            t = jnp.linspace(t_start, t_end, steps)
            t_shift_base = self.sampling_time_kwargs['t_shift_base']
            t_shift_cur = self.sampling_time_kwargs['t_shift_cur']
            shift_ratio = math.sqrt(t_shift_cur / t_shift_base)
            return shift_ratio * t / (1 + (shift_ratio - 1) * t)
        elif self.sampling_time_dist == SamplingTimeDistType.EXP:
            # following aligns with EDM implementation
            step_indices = jnp.arange(steps)
            sigma_min = self.sampling_time_kwargs['sigma_min']
            sigma_max = self.sampling_time_kwargs['sigma_max']
            rho = self.sampling_time_kwargs['rho']
            t_steps = (
                sigma_max ** (1 / rho)
                +
                step_indices / (steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
            ) ** rho
            # ensure last step is 0
            return jnp.concatenate([t_steps, jnp.array([0.])])
        else:
            raise ValueError(f"Sampling Time Distribution {self.sampling_time_dist} not supported.") 
    def sample(
        self, rng, net: nn.Module, x: jnp.ndarray,
        g_net: nn.Module | None = None, guidance_scale: float = 1.0,
        num_sampling_steps: int | None = None,
        custom_timegrid: jnp.ndarray | None = None,
        **net_kwargs
    ) -> jnp.ndarray:
        r"""Main sample loop
        Args:
            - rng: random key for potentially stochastic samplers
            - net: network to integrate vector field with.
            - x: current state.
            - t: current time.
            - g_net: guidance network.
            - guidance_scale: scale of guidance.
            - net_kwargs: extra net args.
        Return:
            - jnp.ndarray: x_final, final clean state.
        """
        if custom_timegrid is not None:
            timegrid = custom_timegrid
        elif num_sampling_steps is not None:
            # exposing this pathway for flexibility in sampling
            timegrid = self.sample_t(num_sampling_steps + 1)
        else:
            # if not provided, use the default number of sampling steps
            timegrid = self.sample_t(self.num_sampling_steps + 1)
        def _fn(carry, t_index):
            t_curr, t_next = timegrid[t_index], timegrid[t_index + 1]
            net, g_net, x_curr, rng = carry
            x_next = self.forward(
                rng, net, x_curr, t_curr, t_next, g_net, guidance_scale, **net_kwargs
            )
            return (net, g_net, x_next, rng), x_next
        # Lift the scan to nnx.scan so the references passed in via net & g_net are
        # captured; otherwise the rng state would leak, since a global counter is maintained.
        # Note: the initial state is scaled by timegrid[0] (identity for UNIFORM, where
        # t starts at 1; sigma_max for EXP).
        (_, _, x_curr, rng), _ = nnx.scan(
            _fn, in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0)
        )((net, g_net, x * timegrid[0], rng), jnp.arange(len(timegrid) - 2))
        x_final = self.last_step(rng, net, x_curr, timegrid[-2], timegrid[-1], g_net, guidance_scale, **net_kwargs)
        return x_final
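
    # Example (hypothetical usage; assumes `net` exposes a `pred(x, t, **kwargs)`
    # method and `rng` is an nnx.Rngs stream):
    #
    #   sampler = EulerSampler(num_sampling_steps=50,
    #                          sampling_time_dist=SamplingTimeDistType.UNIFORM)
    #   x_init = jax.random.normal(jax.random.key(0), (4, 32, 32, 3))
    #   samples = sampler.sample(nnx.Rngs(0), net, x_init, y=labels)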
    
    ########## Helper Functions ##########
    def get_default_sampling_kwargs(self, kwargs: dict, sampling_time_dist: SamplingTimeDistType) -> dict:
        """Get the default kwargs for the sampling time distribution, with user-provided overrides applied."""
        default_kwargs = copy.deepcopy(DEFAULT_SAMPLING_TIME_KWARGS[sampling_time_dist])
        for key in default_kwargs:
            if key in kwargs:
                # overwrite the default value
                default_kwargs[key] = kwargs[key]
        return default_kwargs

    def expand_right(self, x: jnp.ndarray | float, y: jnp.ndarray) -> jnp.ndarray:
        """Expand x (a scalar or per-batch array) to shape (batch,), matching y's leading dimension."""
        if isinstance(x, jnp.ndarray):
            assert y.ndim >= x.ndim
        return jnp.ones((y.shape[0],)) * x
    def bcast_right(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        """Broadcast x to the right to match the shape of y."""
        assert y.ndim >= x.ndim
        return x.reshape(x.shape + (1,) * (y.ndim - x.ndim))
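
    # Shape example (illustrative): per-sample scalars broadcast against batched states.
    #
    #   x = jnp.ones((4,))              # one scalar per batch element
    #   y = jnp.zeros((4, 32, 32, 3))   # batch of images
    #   bcast_right(x, y).shape         # -> (4, 1, 1, 1), broadcastable against y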
 
    
class EulerSampler(Samplers):
    r"""Euler Sampler.
    First Order Deterministic Sampler.
    """
    def forward(
        self, rng, net: nn.Module, x: jnp.ndarray, t_curr: jnp.ndarray, t_next: jnp.ndarray,
        g_net: nn.Module | None = None, guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        r"""Euler step in integration.
        .. math::
            x_{t_{i+1}} = x_{t_i} + (t_{i+1} - t_i) * f(x_{t_i}, t_i)
        """
        del rng
        t_curr = self.expand_right(t_curr, x)
        net_out = net.pred(x, t_curr, **net_kwargs)
        if g_net is None:
            g_net = net
        
        # make uncond generation: swap class labels `y` for the null class id (hard-coded as 1000)
        g_net_kwargs = {
            k: (v if k != 'y' else jnp.ones_like(v, dtype=jnp.int32) * 1000)
            for k, v in net_kwargs.items()
        }
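        # CFG combination used below: guided = g_out + guidance_scale * (net_out - g_out);
        # guidance_scale == 1 reduces to the plain conditional prediction.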
        def guided_fn(g_net, x, t):
            g_net_out = g_net.pred(x, t, **g_net_kwargs)
            # TODO: consider using different set of args for g_net
            return g_net_out + guidance_scale * (net_out - g_net_out)
        def unguided_fn(g_net, x, t):
            return net_out
        
        d_curr = nnx.cond(
            guidance_scale == 1., unguided_fn, guided_fn, g_net, x, t_curr
        )
        dt = t_next - t_curr
        return x + d_curr * self.bcast_right(dt, d_curr) 
    
    def last_step(
        self, rng, net: nn.Module, x: jnp.ndarray, t_curr: jnp.ndarray, t_next: jnp.ndarray,
        g_net: nn.Module | None = None, guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        """:meta private:"""
        return self.forward(rng, net, x, t_curr, t_next, g_net, guidance_scale, **net_kwargs) 
class EulerJumpSampler(EulerSampler):
    r"""Euler Sampler that supports Jump with distilled models.
    First Order Deterministic Sampler.
    """
    def forward(
        self, rng, net: nn.Module, x: jnp.ndarray, t_curr: jnp.ndarray, t_next: jnp.ndarray,
        g_net: nn.Module | None = None, guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        r"""Euler step with jump in integration.
        .. math::
            x_{r} = x_{t} + (t - r) * f(x_{t}, t, r)
        """
        del rng
        t_curr = self.expand_right(t_curr, x)
        t_next = self.expand_right(t_next, x)
        net_out = net.pred(x, t_curr, r=t_next, **net_kwargs)
        dt = t_next - t_curr
        return x + net_out * self.bcast_right(dt, net_out) 
 
class HeunSampler(Samplers):
    r"""Heun Sampler.
    Second Order Deterministic Sampler.
    """
    
    def forward(
        self, rng, net: nn.Module, x: jnp.ndarray, t_curr: jnp.ndarray, t_next: jnp.ndarray,
        g_net: nn.Module | None = None, guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        r"""Heun step in integration.
        .. math::
            \tilde{x}_{t_i} = x_{t_i} + (t_{i+1} - t_i) * f(x_{t_i}, t_i)
            x_{t_{i+1}} = x_{t_i} + \frac{t_{i+1} - t_i}{2} * (f(x_{t_i}, t_i) + f(\tilde{x}_{i_i}, t_{i+1}))
        """
        del rng
        t_curr = self.expand_right(t_curr, x)
        net_out = net.pred(x, t_curr, **net_kwargs)
        if g_net is None:
            g_net = net
        
        # make uncond generation
        g_net_kwargs = {
            k: (v if k != 'y' else jnp.ones_like(v, dtype=jnp.int32) * 1000)
            for k, v in net_kwargs.items()
        }
        
        def guided_fn(g_net, x, t, net_out):
            # TODO: consider using different set of args for g_net
            g_net_out = g_net.pred(x, t, **g_net_kwargs)
            return g_net_out + guidance_scale * (net_out - g_net_out)
        def unguided_fn(g_net, x, t, net_out):
            return net_out
        
        d_curr = nnx.cond(
            guidance_scale == 1., unguided_fn, guided_fn, g_net, x, t_curr, net_out
        )
        dt = t_next - t_curr
        x_next = x + d_curr * self.bcast_right(dt, d_curr)
        t_next = self.expand_right(t_next, x)
        # Heun's method: re-evaluate the field at the predicted point (x_next, t_next);
        # this second evaluation is what makes the step second order.
        net_out_next = net.pred(x_next, t_next, **net_kwargs)
        d_next = nnx.cond(
            guidance_scale == 1., unguided_fn, guided_fn, g_net, x_next, t_next, net_out_next
        )
        return x + 0.5 * self.bcast_right(dt, d_curr) * (d_curr + d_next)
    
    def last_step(
        self, rng, net: nn.Module, x: jnp.ndarray, t_curr: jnp.ndarray, t_next: jnp.ndarray,
        g_net: nn.Module | None = None, guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        """:meta private:"""
        del rng
        # Heun's last step is one first order Euler step
        t_curr = self.expand_right(t_curr, x)
        net_out = net.pred(x, t_curr, **net_kwargs)
        g_net_kwargs = {
            k: (v if k != 'y' else jnp.ones_like(v, dtype=jnp.int32) * 1000)
            for k, v in net_kwargs.items()
        }
        if g_net is None:
            g_net = net
        
        def guided_fn(x, t):
            # TODO: consider using different set of args for g_net
            g_net_out = g_net.pred(x, t, **g_net_kwargs)
            return g_net_out + guidance_scale * (net_out - g_net_out)
        def unguided_fn(x, t):
            return net_out
        
        d_curr = nnx.cond(
            guidance_scale == 1.0, unguided_fn, guided_fn, x, t_curr
        )
        dt = t_next - t_curr
        return x + d_curr * self.bcast_right(dt, d_curr) 
    
class DiffusionCoeffType(Enum):
    """Class for Sampling Time Distribution Types.
    
    :meta private:
    """
    CONSTANT  = 1
    LINEAR_KL = 2
    LINEAR    = 3
    COS       = 4
    CONCAVE   = 5
class EulerMaruyamaSampler(Samplers):
    r"""EulerMaruyama Sampler.
    
    First Order Stochastic Sampler.
    """
    def __init__(
        self,
        num_sampling_steps: int,
        sampling_time_dist: SamplingTimeDistType,
        sampling_time_kwargs: dict | None = None,
        # below are args for stochastic samplers
        diffusion_coeff: DiffusionCoeffType | Callable[[jnp.ndarray], jnp.ndarray] = DiffusionCoeffType.LINEAR_KL,
        diffusion_coeff_norm: float = 1.0
    ):
        super().__init__(
            num_sampling_steps,
            sampling_time_dist,
            sampling_time_kwargs
        )
        self.diffusion_coeff_fn = self.instantiate_diffusion_coeff(
            diffusion_coeff, diffusion_coeff_norm
        ) 
    def instantiate_diffusion_coeff(
        self, coeff: DiffusionCoeffType | Callable[[jnp.ndarray], jnp.ndarray], norm: float
    ):
        """Instantiate the diffusion coefficient for SDE sampling.
        
        Args:
            - coeff: the desired diffusion coefficient. If a Callable is passed in, it is
              returned directly; otherwise the coefficient function is instantiated from
              the default settings.
            - norm: the norm of the diffusion coefficient.
        Returns:
            - Callable: diffusion_coeff_fn, w(t)
        """
        if callable(coeff):
            return coeff
        choices = {
            DiffusionCoeffType.CONSTANT:  lambda t: norm * jnp.ones_like(t),
            DiffusionCoeffType.LINEAR_KL: lambda t: norm * (t**2 / (1 - t) + t),
            DiffusionCoeffType.LINEAR:    lambda t: norm * t,
            DiffusionCoeffType.COS:       lambda t: 0.25 * (norm * jnp.cos(jnp.pi * t) + 1) ** 2,
            DiffusionCoeffType.CONCAVE:   lambda t: 0.25 * (norm * jnp.sin(jnp.pi * t) + 1) ** 2,
        }
        try:
            fn = choices[coeff]
        except KeyError:
            raise ValueError(f"Diffusion coefficient function {coeff} not supported. Consider using a custom function.")
        
        return fn
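
    # Example (illustrative): a custom w(t) may be passed directly as a callable,
    # bypassing the table above:
    #
    #   sampler = EulerMaruyamaSampler(250, SamplingTimeDistType.UNIFORM,
    #                                  diffusion_coeff=lambda t: 2.0 * t)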
    def drift(
        self, net: nn.Module, x: jnp.ndarray, t_curr: jnp.ndarray, **net_kwargs
    ):
        """:meta private:"""
        tangent = net.pred(x, t_curr, **net_kwargs)
        score = net.score(x, t_curr, **net_kwargs)
        return tangent - 0.5 * self.bcast_right(
            self.diffusion_coeff_fn(t_curr), score
        ) * score
    def forward(
        self, rng, net: nn.Module, x: jnp.ndarray, t_curr: jnp.ndarray, t_next: jnp.ndarray,
        g_net: nn.Module | None = None, guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        """Euler-Maruyama step in integration.
        .. math::
            x_{t_{i+1}} = x_{t_i} + (t_{i+1} - t_i) * f(x_{t_i}, t_i) + \sqrt{2 * w(t_i)} * \epsilon
        """
        t_curr = self.expand_right(t_curr, x)
        
        net_out = self.drift(net, x, t_curr, **net_kwargs)
        if g_net is None:
            g_net = net
        
        # make uncond generation
        g_net_kwargs = {
            k: (v if k != 'y' else jnp.ones_like(v, dtype=jnp.int32) * 1000)
            for k, v in net_kwargs.items()
        }
        
        def guided_fn(x, t):
            # TODO: consider using different set of args for g_net
            g_net_out = self.drift(g_net, x, t, **g_net_kwargs)
            return g_net_out + guidance_scale * (net_out - g_net_out)
        def unguided_fn(x, t):
            return net_out
        
        d_curr = nnx.cond(
            guidance_scale == 1., unguided_fn, guided_fn, x, t_curr
        )
        dt = t_next - t_curr
        x_mean = x + d_curr * self.bcast_right(dt, d_curr)
        wiener = jax.random.normal(rng(), x_mean.shape) * self.bcast_right(
            jnp.sqrt(jnp.abs(dt)), x_mean
        )
        x = x_mean + self.bcast_right(
            jnp.sqrt(self.diffusion_coeff_fn(t_curr)), x_mean
        ) * wiener
        return x 
    def last_step(
        self, rng, net: nn.Module, x: jnp.ndarray, t_curr: jnp.ndarray, t_next: jnp.ndarray,
        g_net: nn.Module | None = None, guidance_scale: float = 1.0,
        **net_kwargs
    ) -> jnp.ndarray:
        """:meta private:"""
        del rng
        t_curr = self.expand_right(t_curr, x)
        net_out = self.drift(net, x, t_curr, **net_kwargs)
        if g_net is None:
            g_net = net
        
        def guided_fn(x, t):
            # TODO: consider using different set of args for g_net
            g_net_out = self.drift(g_net, x, t, **net_kwargs)
            return g_net_out + guidance_scale * (net_out - g_net_out)
        def unguided_fn(x, t):
            return net_out
        
        d_curr = nnx.cond(
            guidance_scale == 1., unguided_fn, guided_fn, x, t_curr
        )
        dt = t_next - t_curr
        return x + d_curr * self.bcast_right(dt, d_curr) 
class EDMSampler(Samplers):
    r"""EDM Stochastic Sampler.
    
    Second Order Stochastic Sampler proposed in https://arxiv.org/abs/2206.00364
    :meta private:
    """
    pass