# built-in libs
from abc import ABC, abstractmethod
from enum import Enum
import math
# external libs
from flax import nnx
import jax
import jax.numpy as jnp
class TrainingTimeDistType(Enum):
    """Class for Training Time Distribution Types.
    
    :meta private:
    """
    UNIFORM = 1
    LOGNORMAL = 2
    LOGITNORMAL = 3
    # TODO: Add more training time distribution types
class Interfaces(nnx.Module, ABC):
    r"""
    Base class for all diffusion / flow matching interfaces.
    
    All interfaces be a wrapper around network backbone and should support:
        - Define the pre-conditionings (see EDM)
        - Calculate losses for training
            - Define transport path (\alpha_t & \sigma_t)
            - Sample t
            - Sample X_t
        - Give tangent for sampling
    Required RNG Key:
        - time: for sampling t
        - noise: for sampling n
    """
    def __init__(self, network: nnx.Module, train_time_dist_type: str | TrainingTimeDistType):
        self.network = network
        if isinstance(train_time_dist_type, str):
            self.train_time_dist_type = TrainingTimeDistType[train_time_dist_type.replace('_', '').upper()]
        else:
            self.train_time_dist_type = train_time_dist_type 
    
    @abstractmethod
    def c_in(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""Calculate c_in for the interface.
        
        Args:
            t: current timestep.
        Returns:
            jnp.ndarray: c_in, the input-scaling pre-conditioning coefficient.
        """ 
    
    @abstractmethod
    def c_out(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""Calculate c_out for the interface.
        
        Args:
            t: current timestep.
        Returns:
            jnp.ndarray: c_out, the output-scaling pre-conditioning coefficient.
        """ 
    
    @abstractmethod
    def c_skip(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""Calculate c_skip for the interface.
        
        Args:
            t: current timestep.
        Returns:
            jnp.ndarray: c_skip, the skip-connection pre-conditioning coefficient.
        """ 
    
    @abstractmethod
    def c_noise(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""Calculate c_noise for the interface.
        
        Args:
            t: current timestep.
        Returns:
            jnp.ndarray: c_noise, the noise-level conditioning passed to the network.
        """ 
    @abstractmethod
    def sample_t(self, shape: tuple[int, ...]) -> jnp.ndarray:
        r"""Sample t from the training time distribution.
        
        Args:
            shape: shape of timestep t.
        
        Returns:
            jnp.ndarray: t, sampled timestep t.
        """ 
    
    @abstractmethod
    def sample_n(self, shape: tuple[int, ...]) -> jnp.ndarray:
        r"""Sample noises.
        
        Args:
            shape: shape of noise.
        Returns:
            jnp.ndarray: n, sampled noise.
        """ 
        # Exposing this function to the interface allows for more flexibility in noise sampling
    @abstractmethod
    def sample_x_t(self, x: jnp.ndarray, n: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray:
        r"""Sample X_t according to the defined interface.
        
        Args:
            x: input clean sample.
            n: noise.
            t: current timestep.
        Returns:
            jnp.ndarray: x_t, sampled X_t according to transport path.
        """ 
    @abstractmethod
    def target(self, x: jnp.ndarray, n: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray:
        r"""Get training target.
        
        Args:
            x: input clean sample.
            n: noise.
            t: current timestep.
        Returns:
            jnp.ndarray: target, training target.
        """ 
    @abstractmethod
    def pred(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        r"""Predict ODE tangent according to the defined interface.
        
        Args:
            x_t: input noisy sample.
            t: current timestep.
            
        Returns:
            jnp.ndarray: tangent, predicted ODE tangent.
        """ 
    
    @abstractmethod
    def score(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        r"""Transform ODE tangent to the Score Function \nabla \log p_t(x).
        
        Args:
            x_t: input noisy sample.
            t: current timestep.
            
        Returns:
            jnp.ndarray: score, score function \nabla \log p_t(x).
        """ 
    
    @abstractmethod
    def loss(self, x: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        r"""Calculate loss for training.
        
        Args:
            x: input clean sample.
            args: additional arguments for network forward.
            kwargs: additional keyword arguments for network forward.
        Returns:
            jnp.ndarray: loss, calculated loss.
        """ 
    def __call__(self, x: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        return self.loss(x, *args, **kwargs)
    
    ########## Helper Functions ##########
    @staticmethod
    def mean_flat(x: jnp.ndarray) -> jnp.ndarray:
        r"""Take mean w.r.t. all dimensions of x except the first.
        
        Args:
            x: input array.
            
        Returns:
            jnp.ndarray: mean, mean across all dimensions except the first.
        """
        return jnp.mean(x, axis=list(range(1, x.ndim))) 
    
    @staticmethod
    def bcast_right(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        r"""Broadcast x to the right to match the shape of y.
        
        Args:
            x: array to broadcast.
            y: target array to match shape.
            
        Returns:
            jnp.ndarray: broadcasted, x broadcasted to match y's shape.
        """
        assert y.ndim >= x.ndim
        return x.reshape(x.shape + (1,) * (y.ndim - x.ndim)) 
    
    @staticmethod
    def t_shift(t: jnp.ndarray, shift: float) -> jnp.ndarray:
        r"""Shift t by a constant shift value.
        
        Args:
            t: input timestep array.
            shift: shift value.
            
        Returns:
            jnp.ndarray: shifted_t, t shifted by the shift value.
        """
        return shift * t / (1 + (shift - 1) * t) 
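
# A minimal sketch (illustrative, not part of the library) of how the static helpers
# above behave; the shapes and values here are assumptions for demonstration only.
def _demo_interface_helpers():
    x = jnp.ones((2, 4, 4))      # a batch of two 4x4 samples
    t = jnp.array([0.25, 0.75])  # one timestep per batch element
    per_sample = Interfaces.mean_flat(x)  # shape (2,): mean over all but the batch dim
    t_b = Interfaces.bcast_right(t, x)    # shape (2, 1, 1): ready to scale x
    # t_shift is a monotone reparameterization fixing both endpoints:
    # t_shift(0, s) == 0 and t_shift(1, s) == 1 for any shift s.
    shifted = Interfaces.t_shift(t, 2.0)  # 0.25 -> 0.4, 0.75 -> ~0.857
    return per_sample, t_b, shifted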
 
class SiTInterface(Interfaces):
    r"""Interface for SiT.
    
    Transport path:
    
    .. math::
    
        x_t = (1 - t) * x + t * n
    
    Losses:
    
    .. math::
    
        L = \mathbb{E} \Vert D(x_t, t) - (n - x) \Vert ^ 2
    
    Predictions:
    
    .. math::
    
        x = x_t - t * D(x_t, t)
    """
    def __init__(
        self, network: nnx.Module, train_time_dist_type: str | TrainingTimeDistType,
        t_mu: float = 0., t_sigma: float = 1.0, n_mu: float = 0., n_sigma: float = 1.0, x_sigma: float = 0.5,
        t_shift_base: int = 4096,
    ):
        super().__init__(network, train_time_dist_type)
        self.t_mu = t_mu
        self.t_sigma = t_sigma
        self.n_mu = n_mu
        self.n_sigma = n_sigma
        self.x_sigma = x_sigma
        self.t_shift_base = t_shift_base 
    def c_in(self, t: jnp.ndarray) -> jnp.ndarray:
        """Flow matching preconditioning.
        
        .. math::
            c_{in} = 1
        """
        # return 1 / jnp.sqrt((1 - t) ** 2 * self.x_sigma ** 2 + t ** 2)
        return jnp.ones_like(t) 
    
    def c_out(self, t: jnp.ndarray) -> jnp.ndarray:
        """Flow matching preconditioning.
        
        .. math::
            c_{out} = 1
        """
        return jnp.ones_like(t) 
    
    def c_skip(self, t: jnp.ndarray) -> jnp.ndarray:
        """Flow matching preconditioning.
        
        .. math::
            c_{skip} = 0
        """
        return jnp.zeros_like(t) 
    def c_noise(self, t: jnp.ndarray) -> jnp.ndarray:
        """Flow matching preconditioning.
        
        .. math::
            c_{noise} = t
        """
        return t 
    def sample_t(self, shape: tuple[int, ...]) -> jnp.ndarray:
        """:meta private:"""
        rng = self.network.rngs.time()
        if self.train_time_dist_type == TrainingTimeDistType.UNIFORM:
            return jax.random.uniform(rng, shape=shape)
        elif self.train_time_dist_type == TrainingTimeDistType.LOGITNORMAL:
            return jax.nn.sigmoid(jax.random.normal(rng, shape=shape) * self.t_sigma + self.t_mu)
        else:
            raise ValueError(f"Training Time Distribution Type {self.train_time_dist_type} not supported.")
    
    def sample_n(self, shape: tuple[int, ...]) -> jnp.ndarray:
        """:meta private:"""
        # rng = self.make_rng('noise')
        rng = self.network.rngs.noise()
        return jax.random.normal(rng, shape=shape) * self.n_sigma + self.n_mu
    
    def sample_x_t(self, x: jnp.ndarray, n: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray:
        """Sample x_t defined by flow matching.
        
        .. math::
            x_t = (1 - t) * x + t * n
            
        Args:
            x: input clean sample.
            n: noise.
            t: current timestep.
            
        Returns:
            jnp.ndarray: x_t, sampled x_t according to flow matching.
        """
        t = self.bcast_right(t, x)
        return (1 - t) * x + t * n 
    
    def target(self, x: jnp.ndarray, n: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray:
        """Return flow matching target
        .. math::
            v = n - x
            
        Args:
            x: input clean sample.
            n: noise.
            t: current timestep.
            
        Returns:
            jnp.ndarray: v, flow matching target.
        """
        return n - x 
    
    def pred(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        """Predict flow matching tangent.
        
        .. math::
            v = D(x_t, t)
            
        Args:
            x_t: input noisy sample.
            t: current timestep.
            *args: additional arguments for network forward.
            **kwargs: additional keyword arguments for network forward.
            
        Returns:
            jnp.ndarray: v, predicted flow matching tangent.
        """
        return self.network(
            (self.bcast_right(self.c_in(t), x_t) * x_t), t, *args, **kwargs
        )[0] 
    
    def score(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        r"""Transform flow matching tangent to the score function.
        
        .. math::
            \nabla \log p_t(x) = -(x_t + (1 - t) * D(x_t, t)) / t
            
        Args:
            x_t: input noisy sample.
            t: current timestep.
            *args: additional arguments for network forward.
            **kwargs: additional keyword arguments for network forward.
            
        Returns:
            jnp.ndarray: score, score function \nabla \log p_t(x).
        """
        tangent = self.pred(x_t, t, *args, **kwargs)
        t = self.bcast_right(t, x_t)
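        # Derivation sketch (assuming unit-variance noise): with x_hat = x_t - t * v and
        # p(x_t | x) = N((1 - t) * x, t^2 * I), the Gaussian score
        # -(x_t - (1 - t) * x_hat) / t^2 simplifies to -(x_t + (1 - t) * v) / t.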
        return -(x_t + (1 - t) * tangent) / t 
    
    def loss(self, x: jnp.ndarray, *args, return_aux=False, **kwargs) -> jnp.ndarray:
        r"""Calculate flow matching loss.
        
        .. math::
            L = \mathbb{E} \Vert D(x_t, t) - (n - x) \Vert ^ 2
            
        Args:
            x: input clean sample.
            *args: additional arguments for network forward.
            return_aux: whether to return auxiliary outputs.
            **kwargs: additional keyword arguments for network forward.
            
        Returns:
            jnp.ndarray or tuple: loss, calculated loss (or tuple with aux outputs if return_aux=True).
        """
        t = self.sample_t((x.shape[0],))
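        # Resolution-dependent timestep shift: larger inputs push t toward the noisier
        # end (cf. the shift used in SD3); t_shift_base is the reference dimensionality.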
        t = self.t_shift(t, math.sqrt(math.prod(x.shape[1:]) / self.t_shift_base))
        n = self.sample_n(x.shape)
        x_t = self.sample_x_t(x, n, t)
        target = self.target(x, n, t)
        net_out, features = self.network(
            (self.bcast_right(self.c_in(t), x_t) * x_t), t, *args, **kwargs
        )
        if return_aux:
            # specifically for auxiliary loss wrappers
            return self.mean_flat((net_out - target) ** 2), net_out, features
        else:
            return {
                'loss': self.mean_flat((net_out - target) ** 2)
            } 
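
# A minimal end-to-end sketch (illustrative only) wiring SiTInterface to a toy
# backbone. `_ToySiTNet` is a hypothetical stand-in: real backbones must likewise
# expose `rngs` with `time`/`noise` streams and return an (output, features) pair.
class _ToySiTNet(nnx.Module):
    def __init__(self, dim: int, rngs: nnx.Rngs):
        self.rngs = rngs
        self.proj = nnx.Linear(dim, dim, rngs=rngs)
    def __call__(self, x_t, t, *args, **kwargs):
        h = self.proj(x_t)  # ignore t for simplicity
        return h, h         # (prediction, features)

def _demo_sit_loss():
    rngs = nnx.Rngs(0, time=1, noise=2)
    interface = SiTInterface(_ToySiTNet(8, rngs), 'logit_normal')
    x = jnp.ones((4, 8))         # a batch of 4 "clean" samples
    return interface(x)['loss']  # per-sample flow matching loss, shape (4,)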
 
class EDMInterface(Interfaces):
    r"""Interface for EDM.
    
    Transport Path:
    
    .. math::
    
        x_t = x + t * n
    
    Losses:
    
    .. math::
    
        L = \mathbb{E} \Vert D(x_t, t) - x \Vert ^ 2
    
    Predictions:
    
    .. math::
    
        x = D(x_t, t)
    """
    def __init__(
        self, network: nnx.Module, train_time_dist_type: str | TrainingTimeDistType,
        t_mu: float = 0., t_sigma: float = 1.0, n_mu: float = 0., n_sigma: float = 1.0, x_sigma: float = 0.5
    ):
        super().__init__(network, train_time_dist_type)
        self.t_mu = t_mu
        self.t_sigma = t_sigma
        self.n_mu = n_mu
        self.n_sigma = n_sigma
        self.x_sigma = x_sigma 
    def c_in(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""EDM preconditioning.
        
        .. math::
            c_{in} = 1 / \sqrt{\sigma_x ^ 2 + t ^ 2}
        """
        return 1 / jnp.sqrt(self.x_sigma ** 2 + t ** 2) 
    
    def c_out(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""EDM preconditioning.
        
        .. math::
            c_{out} = t * \sigma_x / \sqrt{t ^ 2 + \sigma_x ^ 2}
        """
        return t * self.x_sigma / jnp.sqrt(t ** 2 + self.x_sigma ** 2) 
    
    def c_skip(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""EDM preconditioning.
        
        .. math::
            c_{skip} = \sigma_x ^ 2 / (t ^ 2 + \sigma_x ^ 2)
        """
        return self.x_sigma ** 2 / (t ** 2 + self.x_sigma ** 2) 
    def c_noise(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""EDM preconditioning.
        
        .. math::
            c_{noise} = \log(t) / 4
        """
        return jnp.log(t) / 4 
    def sample_t(self, shape: tuple[int, ...]) -> jnp.ndarray:
        """:meta private:"""
        rng = self.network.rngs.time()
        if self.train_time_dist_type == TrainingTimeDistType.UNIFORM:
            return jax.random.uniform(rng, shape=shape)
        elif self.train_time_dist_type == TrainingTimeDistType.LOGNORMAL:
            return jnp.exp(jax.random.normal(rng, shape=shape) * self.t_sigma + self.t_mu)
        elif self.train_time_dist_type == TrainingTimeDistType.LOGITNORMAL:
            return jax.nn.sigmoid(jax.random.normal(rng, shape=shape) * self.t_sigma + self.t_mu)
        else:
            raise ValueError(f"Training Time Distribution Type {self.train_time_dist_type} not supported.")
    
    def sample_n(self, shape: tuple[int, ...]) -> jnp.ndarray:
        """:meta private:"""
        rng = self.network.rngs.noise()
        return jax.random.normal(rng, shape=shape) * self.n_sigma + self.n_mu
    
    def sample_x_t(self, x: jnp.ndarray, n: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray:
        r"""Sample x_t defined by EDM.
        
        .. math::
            x_t = x + t * n
            
        Args:
            x: input clean sample.
            n: noise.
            t: current timestep.
            
        Returns:
            jnp.ndarray: x_t, sampled x_t according to EDM.
        """
        return x + self.bcast_right(t, n) * n 
    
    def target(self, x: jnp.ndarray, n: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray:
        r"""Return EDM target.
        
        .. math::
            target = x
            
        Args:
            x: input clean sample.
            n: noise.
            t: current timestep.
            
        Returns:
            jnp.ndarray: target, EDM target.
        """
        return x 
    
    def pred(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        r"""Predict EDM tangent.
        
        .. math::
            v = (x_t - D(x_t, t)) / t
            
        Args:
            x_t: input noisy sample.
            t: current timestep.
            *args: additional arguments for network forward.
            **kwargs: additional keyword arguments for network forward.
            
        Returns:
            jnp.ndarray: v, predicted EDM tangent.
        """
        F_x = self.network((self.bcast_right(self.c_in(t), x_t) * x_t), self.c_noise(t), *args, **kwargs)[0]
        D_x = self.bcast_right(self.c_skip(t), x_t) * x_t + self.bcast_right(self.c_out(t), F_x) * F_x
        return (x_t - D_x) / self.bcast_right(t, x_t) 
    
    def score(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        r"""Transform EDM tangent to the score function.
        
        .. math::
            \nabla \log p_t(x) = (D(x_t, t) - x_t) / t ^ 2 = -v / t
            
        Args:
            x_t: input noisy sample.
            t: current timestep.
            *args: additional arguments for network forward.
            **kwargs: additional keyword arguments for network forward.
            
        Returns:
            jnp.ndarray: score, score function \nabla \log p_t(x).
        """
        tangent = self.pred(x_t, t, *args, **kwargs)
        t = self.bcast_right(t, x_t)
        # with D = x_t - t * v, the score (D - x_t) / t^2 reduces to -v / t
        return -tangent / t 
    
    def loss(self, x: jnp.ndarray, *args, return_aux=False, **kwargs) -> jnp.ndarray:
        r"""Calculate EDM loss.
        
        .. math::
            L = \mathbb{E} \Vert D(x_t, t) - x \Vert ^ 2
            
        Args:
            x: input clean sample.
            *args: additional arguments for network forward.
            return_aux: whether to return auxiliary outputs.
            **kwargs: additional keyword arguments for network forward.
            
        Returns:
            jnp.ndarray or tuple: loss, calculated loss (or tuple with aux outputs if return_aux=True).
        """
        sigma = self.sample_t((x.shape[0],))
        n = self.sample_n(x.shape)
        x_t = self.sample_x_t(x, n, sigma)
        target = self.target(x, n, sigma)
        F_x, features = self.network((self.bcast_right(self.c_in(sigma), x_t) * x_t), self.c_noise(sigma), *args, **kwargs)
        D_x = self.bcast_right(self.c_skip(sigma), x_t) * x_t + self.bcast_right(self.c_out(sigma), F_x) * F_x
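        # EDM loss weighting: (sigma^2 + x_sigma^2) / (sigma * x_sigma)^2 == 1 / c_out(sigma)^2,
        # which equalizes the effective loss magnitude across noise levels.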
        weight = (sigma ** 2 + self.x_sigma ** 2) / (sigma * self.x_sigma) ** 2
        if return_aux:
            # specifically for auxiliary loss wrappers
            return self.mean_flat(self.bcast_right(weight, D_x) * (D_x - target) ** 2), D_x, features
        else:
            return {
                'loss': self.mean_flat(self.bcast_right(weight, D_x) * (D_x - target) ** 2)
            } 
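
# A sampling sketch (illustrative, not a library API): plain Euler integration of the
# probability-flow ODE using `pred` as the tangent, from t_max down to t_min. The
# noise schedule and step count are assumptions, not tuned values.
def _demo_edm_euler_sampler(interface: EDMInterface, key: jax.Array,
                            shape: tuple[int, ...], t_max: float = 80.0,
                            t_min: float = 1e-3, steps: int = 32) -> jnp.ndarray:
    ts = jnp.linspace(t_max, t_min, steps + 1)
    x = jax.random.normal(key, shape) * t_max  # x_{t_max} ~ N(0, t_max^2 I)
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        v = interface.pred(x, jnp.full((shape[0],), t_cur))
        x = x + (t_next - t_cur) * v           # Euler step along dx/dt = v
    return x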
 
class sCTInterface(EDMInterface):
    r"""Interface for CM.
    
    Transport Path:
    .. math::
        x_t = x + t * n
    
    Losses:
    .. math::
        L =  \mathbb{E} \Vert f_{t - 1} - f_{t} \Vert ^ 2
    Predictions:
    .. math::
        x = f(x_t, t)
    
    :meta private:
    """
    def target(self, x: jnp.ndarray, n: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray:
        """Get the effective training target for sCT."""
        raise NotImplementedError("sCTInterface.target is not implemented yet.")
    def pred(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        """Predict the average velocity from noise to data."""
        raise NotImplementedError("sCTInterface.pred is not implemented yet.")
    def score(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        raise ValueError("sCTInterface does not support score calculation.")
    
    def loss(self, x: jnp.ndarray, *args, return_aux=False, **kwargs) -> jnp.ndarray:
        """Calculate the sCT loss."""
        raise NotImplementedError("sCTInterface.loss is not implemented yet.")
class sCDInterface(sCTInterface):
    r"""Interface for CM.
    
    Transport Path:
    .. math::
        x_t = x + t * n
    
    Losses:
    .. math::
        L =  \mathbb{E} \Vert f_{t - 1} - f_{t} \Vert ^ 2
    Predictions:
    .. math::
        x = f(x_t, t)
    :meta private:
    """
    def __init__(
        self, network: nnx.Module, train_time_dist_type: str | TrainingTimeDistType,
        t_mu: float = 0., t_sigma: float = 1.0, n_mu: float = 0., n_sigma: float = 1.0, x_sigma: float = 0.5,
        teacher: nnx.Module | None = None, guidance_scale: float = 1.0
    ):
        assert teacher is not None, "Teacher model must be provided for sCDInterface."
        super().__init__(
            network,
            train_time_dist_type,
            t_mu=t_mu, t_sigma=t_sigma, n_mu=n_mu, n_sigma=n_sigma, x_sigma=x_sigma
        )
        self.teacher = teacher
        self.guidance_scale = guidance_scale
    def score(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        raise ValueError("sCDInterface does not support score calculation.")
    def loss(self, x: jnp.ndarray, *args, return_aux=False, **kwargs) -> jnp.ndarray:
        """Calculate the sCD loss."""
        raise NotImplementedError("sCDInterface.loss is not implemented yet.")
class MeanFlowInterface(SiTInterface):
    r"""Interface for Mean Flow.
    
    Transport Path:
    
    .. math::
    
        x_t = (1 - t) * x + t * n
    
    Losses:
    
    .. math::
    
        L = \mathbb{E} \Vert u(x_t, t, r) - \text{sg}(v - (t - r) * \frac{du}{dt}) \Vert ^ 2
    
    Predictions:
    
    .. math::
    
        x_r = x_t - (t - r) * u(x_t, t, r)
    """
    def __init__(
        self, network: nnx.Module, train_time_dist_type: str | TrainingTimeDistType,
        t_mu: float = 0., t_sigma: float = 1.0, n_mu: float = 0., n_sigma: float = 1.0, x_sigma: float = 0.5,
        guidance_scale: float = 1.0, guidance_mixture_ratio: float = 0.5, guidance_t_min: float = 0.0, guidance_t_max: float = 1.0,
        norm_eps: float = 1e-3, norm_power: float = 1.0, fm_portion: float = 0.75, cond_drop_ratio: float = 0.5,
        t_shift_base: int = 4096,
    ):
        super().__init__(
            network,
            train_time_dist_type,
            t_mu=t_mu, t_sigma=t_sigma, n_mu=n_mu, n_sigma=n_sigma, x_sigma=x_sigma
        )
        # omega in meanflow
        self.guidance_scale = guidance_scale
        # kappa in meanflow
        self.guidance_mixture_ratio = guidance_mixture_ratio
        # effectively guidance interval
        self.guidance_t_min = guidance_t_min
        self.guidance_t_max = guidance_t_max
        self.norm_eps = norm_eps
        self.norm_power = norm_power
        self.fm_portion = fm_portion
        self.cond_drop_ratio = cond_drop_ratio 
    def sample_t_r(self, shape: tuple[int, ...]) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Sample time pairs (t, r) for Mean Flow training.
        
        Args:
            shape: shape of the time arrays.
            
        Returns:
            tuple[jnp.ndarray, jnp.ndarray]: (t, r), time pairs where t >= r.
        """
        t = self.sample_t(shape)
        r = self.sample_t(shape)
        t, r = jnp.maximum(t, r), jnp.minimum(t, r)
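        # Force r = t on the first fm_portion of the batch (jump size dt = 0), which
        # reduces those samples to plain flow matching targets.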
        fm_mask = jnp.arange(t.shape[0]) < int(t.shape[0] * self.fm_portion)
        r = jnp.where(fm_mask, t, r)
        return t, r 
    def cond_drop(self, x: jnp.ndarray, n: jnp.ndarray, v: jnp.ndarray, y: jnp.ndarray, neg_y: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Drop the condition with a certain ratio.
        
        Note: the condition must be dropped outside of the model because the effective
              regression target depends on the instantaneous velocity that results from the dropout.
              
        Args:
            x: input clean sample.
            n: noise.
            v: velocity.
            y: condition.
            neg_y: negative condition.
            
        Returns:
            tuple[jnp.ndarray, jnp.ndarray]: (v, y), updated velocity and condition after dropout.
        """
        unguided_v = n - x
        mask = jax.random.uniform(self.network.rngs.label_dropout(), shape=y.shape) < self.cond_drop_ratio
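        # Turn the Bernoulli mask into a count, then drop that many *leading* batch
        # entries; this presumably assumes the batch order carries no information.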
        num_drop = jnp.sum(mask).astype(jnp.int32)
        drop_mask = jnp.arange(y.shape[0]) < num_drop
        # TODO: consider supporting more generalized negative condition
        y = jnp.where(drop_mask, neg_y, y)
        v = jnp.where(self.bcast_right(drop_mask, v), unguided_v, v)
        return v, y 
    def insta_velocity(
        self, x: jnp.ndarray, n: jnp.ndarray, t: jnp.ndarray, *args,
        y: jnp.ndarray | None = None, neg_y: jnp.ndarray | None = None, **kwargs
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Instantaneous velocity of the mean flow. For exact formulation, see https://arxiv.org/pdf/2505.13447.
        
        Args:
            x: input clean sample.
            n: noise.
            t: current timestep.
            *args: additional arguments for network forward.
            y: condition (optional).
            neg_y: negative condition (optional).
            **kwargs: additional keyword arguments for network forward.
            
        Returns:
            tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: (v, y, neg_y), instantaneous velocity and conditions.
        """
        self.network.eval()
        v = n - x
        x_t = self.sample_x_t(x, n, t)
        # TODO: fix the hardcoding
        # unconditional generation
        if y is None:
            y = jnp.zeros((t.shape[0],), dtype=jnp.int32) + 1000
        # default negative condition
        if neg_y is None:
            neg_y = jnp.zeros((t.shape[0],), dtype=jnp.int32) + 1000
        # no guidance case
        if self.guidance_scale == 1.0 and self.guidance_mixture_ratio == 0.0:
            self.network.train()
            return v, y, neg_y
        v_uncond = self.network(
            (self.bcast_right(self.c_in(t), x_t) * x_t),
            t,
            *args,
            y=neg_y,
            dt=jnp.zeros_like(t),
            **kwargs
        )[0]
        if self.guidance_mixture_ratio == 0.0:
            self.network.train()
            return jnp.where(
                self.bcast_right((t >= self.guidance_t_min) & (t <= self.guidance_t_max), v),
                v_uncond + self.guidance_scale * (v - v_uncond),
                v
            ), y, neg_y
        
        v_cond = self.network(
            (self.bcast_right(self.c_in(t), x_t) * x_t),
            t,
            *args,
            y=y,
            dt=jnp.zeros_like(t),
            **kwargs
        )[0]
        self.network.train()
        return jnp.where(
            self.bcast_right((t >= self.guidance_t_min) & (t <= self.guidance_t_max), v),
            self.guidance_scale * v + (1 - self.guidance_scale - self.guidance_mixture_ratio) * v_uncond + self.guidance_mixture_ratio * v_cond,
            v
        ), y, neg_y 
    def target(
        self, x: jnp.ndarray, n: jnp.ndarray, t: jnp.ndarray, r: jnp.ndarray, *args,
        y: jnp.ndarray | None = None, neg_y: jnp.ndarray | None = None, **kwargs
    ) -> jnp.ndarray:
        r"""Get training target for Mean Flow.
        
        Note: the network must take the jump size, dt = t - r, as an additional input.
        
        .. math::
            target = v - (t - r) * \frac{du}{dt}
        """
        v, y, neg_y = self.insta_velocity(x, n, t, *args, y=y, neg_y=neg_y, **kwargs)
        v, y = self.cond_drop(x, n, v, y, neg_y=neg_y)
        x_t = self.sample_x_t(x, n, t)
        def u_fn(x_t, t, r):
            return self.network(
                (self.bcast_right(self.c_in(t), x_t) * x_t),
                t,
                *args,
                dt=t - r,
                y=y,
                **kwargs
            )
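        # Push the tangent (dx_t/dt, dt/dt, dr/dt) = (v, 1, 0) through u_fn below:
        # the JVP then yields the total derivative du/dt along the ODE trajectory.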
        
        dtdt = jnp.ones_like(t)
        drdt = jnp.zeros_like(r)
        u, dudt, feat = jax.jvp(
            u_fn,
            (x_t, t, r),
            (v, dtdt, drdt),
            has_aux=True
        )
        return (u, feat), jax.lax.stop_gradient(v - jnp.clip(self.bcast_right(t - r, v), 0., 1.) * dudt) 
    def pred(self, x_t: jnp.ndarray, t: jnp.ndarray, r: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        r"""Predict ODE tangent according to the Mean Flow interface.
        
        .. math::
            v_{(t, r)} = u(x_t, t, r)
        """
        return self.network(
            (self.bcast_right(self.c_in(t), x_t) * x_t),
            t,
            *args,
            dt=t - r,
            **kwargs
        )[0] 
    def score(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        """:meta private:"""
        # the score is given at r = t, i.e. zero jump size (instantaneous velocity)
        tangent = self.pred(x_t, t, t, *args, **kwargs)
        t = self.bcast_right(t, x_t)
        return -(x_t + (1 - t) * tangent) / t
    def loss(self, x: jnp.ndarray, *args, return_aux=False, **kwargs) -> jnp.ndarray:
        r"""Calculate the Mean Flow loss.
        
        .. math::
            L = \mathbb{E} \Vert u(x_t, t, r) - v_{(t, r)} \Vert ^ 2
        """
        
        t, r = self.sample_t_r((x.shape[0],))
        n = self.sample_n(x.shape)
        (net_out, features), target = self.target(x, n, t, r, *args, **kwargs)
        # following the implementation of meanflow we use sum loss
        loss = jnp.sum((net_out - target) ** 2, axis=list(range(1, x.ndim)))
        adp_w = 1.0 / (loss + self.norm_eps) ** self.norm_power
        loss = jax.lax.stop_gradient(adp_w) * loss
        
        if return_aux:
            return loss, net_out, features
        else:
            return {
                'loss': loss
            }
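
# A one-step generation sketch for MeanFlowInterface (illustrative; assumes a trained
# backbone accepting the `y` and `dt` keyword arguments used above, and unit-variance
# noise). Starting from pure noise at t = 1 and jumping to r = 0 in a single step:
#     x_0 = x_1 - (1 - 0) * u(x_1, t=1, r=0)
def _demo_meanflow_one_step(interface: MeanFlowInterface, key: jax.Array,
                            shape: tuple[int, ...], **net_kwargs) -> jnp.ndarray:
    n = jax.random.normal(key, shape)               # x_1 = n since x_t = (1 - t) x + t n
    t = jnp.ones((shape[0],))
    r = jnp.zeros((shape[0],))
    u = interface.pred(n, t, r, **net_kwargs)       # average velocity over [r, t]
    return n - interface.bcast_right(t - r, n) * u  # x_r = x_t - (t - r) * u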