Source code for interfaces.continuous

# built-in libs
from abc import ABC, abstractmethod
import dataclasses
from enum import Enum
import math

# external libs
import flax
from flax import nnx

import jax
import jax.numpy as jnp


class TrainingTimeDistType(Enum):
    """Class for Training Time Distribution Types.
    
    :meta private:
    """
    UNIFORM = 1
    LOGNORMAL = 2
    LOGITNORMAL = 3

    # TODO: Add more training time distribution types


class Interfaces(nnx.Module, ABC):
    r"""Base class for all diffusion / flow matching interfaces.

    Every interface should be a wrapper around a network backbone and should support:

    - Define the pre-conditionings (see EDM)
    - Calculate losses for training
    - Define the transport path (\alpha_t & \sigma_t)
    - Sample t
    - Sample X_t
    - Give the tangent for sampling

    Required RNG Keys:

    - time: for sampling t
    - noise: for sampling n
    """
    def __init__(self, network: nnx.Module, train_time_dist_type: str | TrainingTimeDistType):
        self.network = network
        if isinstance(train_time_dist_type, str):
            self.train_time_dist_type = TrainingTimeDistType[train_time_dist_type.replace('_', '').upper()]
        else:
            self.train_time_dist_type = train_time_dist_type
    @abstractmethod
    def c_in(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""Calculate c_in for the interface.

        Args:
            t: current timestep.

        Returns:
            jnp.ndarray: c_in, c_in for the interface.
        """

    @abstractmethod
    def c_out(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""Calculate c_out for the interface.

        Args:
            t: current timestep.

        Returns:
            jnp.ndarray: c_out, c_out for the interface.
        """

    @abstractmethod
    def c_skip(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""Calculate c_skip for the interface.

        Args:
            t: current timestep.

        Returns:
            jnp.ndarray: c_skip, c_skip for the interface.
        """

    @abstractmethod
    def c_noise(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""Calculate c_noise for the interface.

        Args:
            t: current timestep.

        Returns:
            jnp.ndarray: c_noise, c_noise for the interface.
        """

    @abstractmethod
    def sample_t(self, shape: tuple[int, ...]) -> jnp.ndarray:
        r"""Sample t from the training time distribution.

        Args:
            shape: shape of timestep t.

        Returns:
            jnp.ndarray: t, sampled timestep t.
        """

    @abstractmethod
    def sample_n(self, shape: tuple[int, ...]) -> jnp.ndarray:
        r"""Sample noises.

        Args:
            shape: shape of noise.

        Returns:
            jnp.ndarray: n, sampled noise.
        """
    # Exposing this function to the interface allows for more flexibility in noise sampling
    @abstractmethod
    def sample_x_t(self, x: jnp.ndarray, n: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray:
        r"""Sample X_t according to the defined interface.

        Args:
            x: input clean sample.
            n: noise.
            t: current timestep.

        Returns:
            jnp.ndarray: x_t, sampled X_t according to transport path.
        """
    @abstractmethod
    def target(self, x: jnp.ndarray, n: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray:
        r"""Get training target.

        Args:
            x: input clean sample.
            n: noise.
            t: current timestep.

        Returns:
            jnp.ndarray: target, training target.
        """

    @abstractmethod
    def pred(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        r"""Predict ODE tangent according to the defined interface.

        Args:
            x_t: input noisy sample.
            t: current timestep.

        Returns:
            jnp.ndarray: tangent, predicted ODE tangent.
        """

    @abstractmethod
    def score(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        r"""Transform ODE tangent to the score function \nabla \log p_t(x).

        Args:
            x_t: input noisy sample.
            t: current timestep.

        Returns:
            jnp.ndarray: score, score function \nabla \log p_t(x).
        """

    @abstractmethod
    def loss(self, x: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        r"""Calculate loss for training.

        Args:
            x: input clean sample.
            args: additional arguments for network forward.
            kwargs: additional keyword arguments for network forward.

        Returns:
            jnp.ndarray: loss, calculated loss.
        """
    def __call__(self, x: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        return self.loss(x, *args, **kwargs)

    ########## Helper Functions ##########
    @staticmethod
    def mean_flat(x: jnp.ndarray) -> jnp.ndarray:
        r"""Take mean w.r.t. all dimensions of x except the first.

        Args:
            x: input array.

        Returns:
            jnp.ndarray: mean, mean across all dimensions except the first.
        """
        return jnp.mean(x, axis=list(range(1, x.ndim)))
    @staticmethod
    def bcast_right(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        r"""Broadcast x to the right to match the shape of y.

        Args:
            x: array to broadcast.
            y: target array to match shape.

        Returns:
            jnp.ndarray: broadcasted, x broadcasted to match y's shape.
        """
        assert len(y.shape) >= x.ndim
        return x.reshape(x.shape + (1,) * (len(y.shape) - x.ndim))
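    # Illustrative shape example (not in the original source): if x has shape (B,)
    # and y has shape (B, H, W, C), the result has shape (B, 1, 1, 1), which then
    # broadcasts elementwise against y.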
    @staticmethod
    def t_shift(t: jnp.ndarray, shift: float) -> jnp.ndarray:
        r"""Shift t by a constant shift value.

        Args:
            t: input timestep array.
            shift: shift value.

        Returns:
            jnp.ndarray: shifted_t, t shifted by the shift value.
        """
        return shift * t / (1 + (shift - 1) * t)
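    # Illustrative arithmetic (not in the original source): with shift = 2, t_shift
    # maps t = 0.25 to 2 * 0.25 / (1 + 1 * 0.25) = 0.4 and t = 0.5 to 1.0 / 1.5 ~ 0.667,
    # so shift > 1 pushes training timesteps toward the noise end (t = 1).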

class SiTInterface(Interfaces):
    r"""Interface for SiT.

    Transport path:

    .. math::
        x_t = (1 - t) * x + t * n

    Losses:

    .. math::
        L = \mathbb{E} \Vert D(x_t, t) - (n - x) \Vert ^ 2

    Predictions:

    .. math::
        x = x_t - t * D(x_t, t)
    """
    def __init__(
        self,
        network: nnx.Module,
        train_time_dist_type: str | TrainingTimeDistType,
        t_mu: float = 0.,
        t_sigma: float = 1.0,
        n_mu: float = 0.,
        n_sigma: float = 1.0,
        x_sigma: float = 0.5,
        t_shift_base: int = 4096,
    ):
        super().__init__(network, train_time_dist_type)
        self.t_mu = t_mu
        self.t_sigma = t_sigma
        self.n_mu = n_mu
        self.n_sigma = n_sigma
        self.x_sigma = x_sigma
        self.t_shift_base = t_shift_base
    def c_in(self, t: jnp.ndarray) -> jnp.ndarray:
        """Flow matching preconditioning.

        .. math::
            c_{in} = 1
        """
        # return 1 / jnp.sqrt((1 - t) ** 2 * self.x_sigma ** 2 + t ** 2)
        return jnp.ones_like(t)
    def c_out(self, t: jnp.ndarray) -> jnp.ndarray:
        """Flow matching preconditioning.

        .. math::
            c_{out} = 1
        """
        return jnp.ones_like(t)
    def c_skip(self, t: jnp.ndarray) -> jnp.ndarray:
        """Flow matching preconditioning.

        .. math::
            c_{skip} = 0
        """
        return jnp.zeros_like(t)
    def c_noise(self, t: jnp.ndarray) -> jnp.ndarray:
        """Flow matching preconditioning.

        .. math::
            c_{noise} = t
        """
        return t
    def sample_t(self, shape: tuple[int, ...]) -> jnp.ndarray:
        """:meta private:"""
        rng = self.network.rngs.time()
        if self.train_time_dist_type == TrainingTimeDistType.UNIFORM:
            return jax.random.uniform(rng, shape=shape)
        elif self.train_time_dist_type == TrainingTimeDistType.LOGITNORMAL:
            return jax.nn.sigmoid(jax.random.normal(rng, shape=shape) * self.t_sigma + self.t_mu)
        else:
            raise ValueError(f"Training Time Distribution Type {self.train_time_dist_type} not supported.")

    def sample_n(self, shape: tuple[int, ...]) -> jnp.ndarray:
        """:meta private:"""
        # rng = self.make_rng('noise')
        rng = self.network.rngs.noise()
        return jax.random.normal(rng, shape=shape) * self.n_sigma + self.n_mu
    def sample_x_t(self, x: jnp.ndarray, n: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray:
        """Sample x_t defined by flow matching.

        .. math::
            x_t = (1 - t) * x + t * n

        Args:
            x: input clean sample.
            n: noise.
            t: current timestep.

        Returns:
            jnp.ndarray: x_t, sampled x_t according to flow matching.
        """
        t = self.bcast_right(t, x)
        return (1 - t) * x + t * n
    def target(self, x: jnp.ndarray, n: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray:
        """Return flow matching target.

        .. math::
            v = n - x

        Args:
            x: input clean sample.
            n: noise.
            t: current timestep.

        Returns:
            jnp.ndarray: v, flow matching target.
        """
        return n - x
    def pred(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        """Predict flow matching tangent.

        .. math::
            v = D(x_t, t)

        Args:
            x_t: input noisy sample.
            t: current timestep.
            *args: additional arguments for network forward.
            **kwargs: additional keyword arguments for network forward.

        Returns:
            jnp.ndarray: v, predicted flow matching tangent.
        """
        return self.network(
            (self.bcast_right(self.c_in(t), x_t) * x_t), t, *args, **kwargs
        )[0]
    def score(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        r"""Transform flow matching tangent to the score function.

        .. math::
            \nabla \log p_t(x) = -(x_t + (1 - t) * D(x_t, t)) / t

        Args:
            x_t: input noisy sample.
            t: current timestep.
            *args: additional arguments for network forward.
            **kwargs: additional keyword arguments for network forward.

        Returns:
            jnp.ndarray: score, score function \nabla \log p_t(x).
        """
        tangent = self.pred(x_t, t, *args, **kwargs)
        t = self.bcast_right(t, x_t)
        return -(x_t + (1 - t) * tangent) / t
    def loss(self, x: jnp.ndarray, *args, return_aux=False, **kwargs) -> jnp.ndarray:
        r"""Calculate flow matching loss.

        .. math::
            L = \mathbb{E} \Vert D(x_t, t) - (n - x) \Vert ^ 2

        Args:
            x: input clean sample.
            *args: additional arguments for network forward.
            return_aux: whether to return auxiliary outputs.
            **kwargs: additional keyword arguments for network forward.

        Returns:
            jnp.ndarray or tuple: loss, calculated loss (or tuple with aux outputs if return_aux=True).
        """
        t = self.sample_t((x.shape[0],))
        t = self.t_shift(t, math.sqrt(math.prod(x.shape[1:]) / self.t_shift_base))
        n = self.sample_n(x.shape)
        x_t = self.sample_x_t(x, n, t)
        target = self.target(x, n, t)

        net_out, features = self.network(
            (self.bcast_right(self.c_in(t), x_t) * x_t), t, *args, **kwargs
        )

        if return_aux:
            # specifically for auxiliary loss wrappers
            return self.mean_flat((net_out - target) ** 2), net_out, features
        else:
            return {
                'loss': self.mean_flat((net_out - target) ** 2)
            }
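
# Illustrative usage sketch (not part of the library): assumes a hypothetical nnx
# backbone `Backbone` whose __call__(x, t, *args, **kwargs) returns (output, features)
# and whose `rngs` exposes the 'time' and 'noise' streams required by the base class.
#
#     network = Backbone(rngs=nnx.Rngs(params=0, time=1, noise=2))
#     interface = SiTInterface(network, train_time_dist_type='logit_normal')
#     loss = interface(x_batch)['loss'].mean()   # __call__ dispatches to loss()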

class EDMInterface(Interfaces):
    r"""Interface for EDM.

    Transport Path:

    .. math::
        x_t = x + t * n

    Losses:

    .. math::
        L = \mathbb{E} \Vert D(x_t, t) - x \Vert ^ 2

    Predictions:

    .. math::
        x = D(x_t, t)
    """
    def __init__(
        self,
        network: nnx.Module,
        train_time_dist_type: str | TrainingTimeDistType,
        t_mu: float = 0.,
        t_sigma: float = 1.0,
        n_mu: float = 0.,
        n_sigma: float = 1.0,
        x_sigma: float = 0.5
    ):
        super().__init__(network, train_time_dist_type)
        self.t_mu = t_mu
        self.t_sigma = t_sigma
        self.n_mu = n_mu
        self.n_sigma = n_sigma
        self.x_sigma = x_sigma
    def c_in(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""EDM preconditioning.

        .. math::
            c_{in} = 1 / \sqrt{x_sigma ^ 2 + t ^ 2}
        """
        return 1 / jnp.sqrt(self.x_sigma ** 2 + t ** 2)
    def c_out(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""EDM preconditioning.

        .. math::
            c_{out} = t * x_sigma / \sqrt{t ^ 2 + x_sigma ^ 2}
        """
        return t * self.x_sigma / jnp.sqrt(t ** 2 + self.x_sigma ** 2)
    def c_skip(self, t) -> jnp.ndarray:
        r"""EDM preconditioning.

        .. math::
            c_{skip} = x_sigma ^ 2 / (t ^ 2 + x_sigma ^ 2)
        """
        return self.x_sigma ** 2 / (t ** 2 + self.x_sigma ** 2)
    def c_noise(self, t: jnp.ndarray) -> jnp.ndarray:
        r"""EDM preconditioning.

        .. math::
            c_{noise} = \log(t) / 4
        """
        return jnp.log(t) / 4
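    # Illustrative sanity check of the preconditioning (not in the original source):
    # as t -> 0, c_skip -> 1 and c_out -> 0, so D(x_t, t) = c_skip * x_t + c_out * F(...)
    # reduces to the (nearly clean) input x_t; as t grows, c_skip -> 0 and c_in ~ 1/t
    # keeps the network input at roughly unit variance.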
    def sample_t(self, shape: tuple[int, ...]) -> jnp.ndarray:
        """:meta private:"""
        rng = self.network.rngs.time()
        if self.train_time_dist_type == TrainingTimeDistType.UNIFORM:
            return jax.random.uniform(rng, shape=shape)
        elif self.train_time_dist_type == TrainingTimeDistType.LOGNORMAL:
            return jnp.exp(jax.random.normal(rng, shape=shape) * self.t_sigma + self.t_mu)
        elif self.train_time_dist_type == TrainingTimeDistType.LOGITNORMAL:
            return jax.nn.sigmoid(jax.random.normal(rng, shape=shape) * self.t_sigma + self.t_mu)
        else:
            raise ValueError(f"Training Time Distribution Type {self.train_time_dist_type} not supported.")

    def sample_n(self, shape: tuple[int, ...]) -> jnp.ndarray:
        """:meta private:"""
        rng = self.network.rngs.noise()
        return jax.random.normal(rng, shape=shape) * self.n_sigma + self.n_mu
    def sample_x_t(self, x: jnp.ndarray, n: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray:
        r"""Sample x_t defined by EDM.

        .. math::
            x_t = x + t * n

        Args:
            x: input clean sample.
            n: noise.
            t: current timestep.

        Returns:
            jnp.ndarray: x_t, sampled x_t according to EDM.
        """
        return x + self.bcast_right(t, n) * n
    def target(self, x: jnp.ndarray, n: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray:
        r"""Return EDM target.

        .. math::
            target = x

        Args:
            x: input clean sample.
            n: noise.
            t: current timestep.

        Returns:
            jnp.ndarray: target, EDM target.
        """
        return x
    def pred(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        r"""Predict EDM tangent.

        .. math::
            v = (x_t - D(x_t, t)) / t

        Args:
            x_t: input noisy sample.
            t: current timestep.
            *args: additional arguments for network forward.
            **kwargs: additional keyword arguments for network forward.

        Returns:
            jnp.ndarray: v, predicted EDM tangent.
        """
        F_x = self.network((self.bcast_right(self.c_in(t), x_t) * x_t), self.c_noise(t), *args, **kwargs)[0]
        D_x = self.bcast_right(self.c_skip(t), x_t) * x_t + self.bcast_right(self.c_out(t), F_x) * F_x
        return (x_t - D_x) / self.bcast_right(t, x_t)
    def score(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        r"""Transform EDM tangent to the score function.

        .. math::
            \nabla \log p_t(x) = (D(x_t, t) - x_t) / t ^ 2

        Args:
            x_t: input noisy sample.
            t: current timestep.
            *args: additional arguments for network forward.
            **kwargs: additional keyword arguments for network forward.

        Returns:
            jnp.ndarray: score, score function \nabla \log p_t(x).
        """
        tangent = self.pred(x_t, t, *args, **kwargs)
        t = self.bcast_right(t, x_t)
        # D(x_t, t) = x_t - t * tangent, so (D(x_t, t) - x_t) / t ** 2 = -tangent / t
        return -tangent / t
    def loss(self, x: jnp.ndarray, *args, return_aux=False, **kwargs) -> jnp.ndarray:
        r"""Calculate EDM loss.

        .. math::
            L = \mathbb{E} \Vert D(x_t, t) - x \Vert ^ 2

        Args:
            x: input clean sample.
            *args: additional arguments for network forward.
            return_aux: whether to return auxiliary outputs.
            **kwargs: additional keyword arguments for network forward.

        Returns:
            jnp.ndarray or tuple: loss, calculated loss (or tuple with aux outputs if return_aux=True).
        """
        sigma = self.sample_t((x.shape[0],))
        n = self.sample_n(x.shape)
        x_t = self.sample_x_t(x, n, sigma)
        target = self.target(x, n, sigma)

        F_x, features = self.network((self.bcast_right(self.c_in(sigma), x_t) * x_t), self.c_noise(sigma), *args, **kwargs)
        D_x = self.bcast_right(self.c_skip(sigma), x_t) * x_t + self.bcast_right(self.c_out(sigma), F_x) * F_x
        weight = (sigma ** 2 + self.x_sigma ** 2) / (sigma * self.x_sigma) ** 2

        if return_aux:
            # specifically for auxiliary loss wrappers
            return self.mean_flat(self.bcast_right(weight, D_x) * (D_x - target) ** 2), D_x, features
        else:
            return {
                'loss': self.mean_flat(self.bcast_right(weight, D_x) * (D_x - target) ** 2)
            }

class sCTInterface(EDMInterface):
    r"""Interface for CM.

    Transport Path:

    .. math::
        x_t = x + t * n

    Losses:

    .. math::
        L = \mathbb{E} \Vert f_{t - 1} - f_{t} \Vert ^ 2

    Predictions:

    .. math::
        x = f(x_t, t)

    :meta private:
    """

    def target(self, x: jnp.ndarray, n: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray:
        """Get the effective training target for sCT."""
        pass

    def pred(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        """Predict the average velocity from noise to data."""
        pass

    def score(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        raise ValueError("sCTInterface does not support score calculation.")

    def loss(self, x: jnp.ndarray, *args, return_aux=False, **kwargs) -> jnp.ndarray:
        """Calculate the sCT loss."""
        pass


class sCDInterface(sCTInterface):
    r"""Interface for CM.

    Transport Path:

    .. math::
        x_t = x + t * n

    Losses:

    .. math::
        L = \mathbb{E} \Vert f_{t - 1} - f_{t} \Vert ^ 2

    Predictions:

    .. math::
        x = f(x_t, t)

    :meta private:
    """

    def __init__(
        self,
        network: nnx.Module,
        train_time_dist_type: str | TrainingTimeDistType,
        t_mu: float = 0.,
        t_sigma: float = 1.0,
        n_mu: float = 0.,
        n_sigma: float = 1.0,
        x_sigma: float = 0.5,
        teacher: nnx.Module | None = None,
        guidance_scale: float = 1.0
    ):
        assert teacher is not None, "Teacher model must be provided for sCDInterface."
        super().__init__(
            network,
            train_time_dist_type,
            t_mu=t_mu,
            t_sigma=t_sigma,
            n_mu=n_mu,
            n_sigma=n_sigma,
            x_sigma=x_sigma
        )
        self.teacher = teacher
        self.guidance_scale = guidance_scale

    def score(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        raise ValueError("sCDInterface does not support score calculation.")

    def loss(self, x: jnp.ndarray, *args, return_aux=False, **kwargs) -> jnp.ndarray:
        """Calculate the sCD loss."""
        pass

class MeanFlowInterface(SiTInterface):
    r"""Interface for Mean Flow.

    Transport Path:

    .. math::
        x_t = (1 - t) * x + t * n

    Losses:

    .. math::
        L = \mathbb{E} \Vert u(x_t, t, r) - \text{sg}(v - (t - r) * \frac{du}{dt}) \Vert ^ 2

    Predictions:

    .. math::
        x_r = x_t - (t - r) * u(x_t, t, r)
    """
    def __init__(
        self,
        network: nnx.Module,
        train_time_dist_type: str | TrainingTimeDistType,
        t_mu: float = 0.,
        t_sigma: float = 1.0,
        n_mu: float = 0.,
        n_sigma: float = 1.0,
        x_sigma: float = 0.5,
        guidance_scale: float = 1.0,
        guidance_mixture_ratio: float = 0.5,
        guidance_t_min: float = 0.0,
        guidance_t_max: float = 1.0,
        norm_eps: float = 1e-3,
        norm_power: float = 1.0,
        fm_portion: float = 0.75,
        cond_drop_ratio: float = 0.5,
        t_shift_base: int = 4096,
    ):
        super().__init__(
            network,
            train_time_dist_type,
            t_mu=t_mu,
            t_sigma=t_sigma,
            n_mu=n_mu,
            n_sigma=n_sigma,
            x_sigma=x_sigma,
            t_shift_base=t_shift_base
        )
        # omega in meanflow
        self.guidance_scale = guidance_scale
        # kappa in meanflow
        self.guidance_mixture_ratio = guidance_mixture_ratio
        # effectively guidance interval
        self.guidance_t_min = guidance_t_min
        self.guidance_t_max = guidance_t_max
        self.norm_eps = norm_eps
        self.norm_power = norm_power
        self.fm_portion = fm_portion
        self.cond_drop_ratio = cond_drop_ratio
    def sample_t_r(self, shape: tuple[int, ...]) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Sample time pairs (t, r) for Mean Flow training.

        Args:
            shape: shape of the time arrays.

        Returns:
            tuple[jnp.ndarray, jnp.ndarray]: (t, r), time pairs where t >= r.
        """
        t = self.sample_t(shape)
        r = self.sample_t(shape)
        t, r = jnp.maximum(t, r), jnp.minimum(t, r)
        fm_mask = jnp.arange(t.shape[0]) < int(t.shape[0] * self.fm_portion)
        r = jnp.where(fm_mask, t, r)
        return t, r
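    # Illustrative example (not in the original source): with fm_portion = 0.75 and a
    # batch of 8, the first 6 entries get r = t (dt = 0, plain flow-matching targets)
    # and the remaining 2 keep r <= t (mean-flow targets over a finite jump).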
    def cond_drop(self, x: jnp.ndarray, n: jnp.ndarray, v: jnp.ndarray, y: jnp.ndarray, neg_y: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Drop the condition with a certain ratio.

        Note: we drop the condition outside of the model because the effective
        regression target depends on the instantaneous velocity that results from
        the dropout.

        Args:
            x: input clean sample.
            n: noise.
            v: velocity.
            y: condition.
            neg_y: negative condition.

        Returns:
            tuple[jnp.ndarray, jnp.ndarray]: (v, y), updated velocity and condition after dropout.
        """
        unguided_v = n - x
        mask = jax.random.uniform(self.network.rngs.label_dropout(), shape=y.shape) < self.cond_drop_ratio
        num_drop = jnp.sum(mask).astype(jnp.int32)
        drop_mask = jnp.arange(y.shape[0]) < num_drop
        # TODO: consider supporting more generalized negative condition
        y = jnp.where(drop_mask, neg_y, y)
        v = jnp.where(self.bcast_right(drop_mask, v), unguided_v, v)
        return v, y
    def insta_velocity(
        self,
        x: jnp.ndarray,
        n: jnp.ndarray,
        t: jnp.ndarray,
        *args,
        y: jnp.ndarray | None = None,
        neg_y: jnp.ndarray | None = None,
        **kwargs
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Instantaneous velocity of the mean flow.

        For exact formulation, see https://arxiv.org/pdf/2505.13447.

        Args:
            x: input clean sample.
            n: noise.
            t: current timestep.
            *args: additional arguments for network forward.
            y: condition (optional).
            neg_y: negative condition (optional).
            **kwargs: additional keyword arguments for network forward.

        Returns:
            tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: (v, y, neg_y), instantaneous velocity and conditions.
        """
        self.network.eval()
        v = n - x
        x_t = self.sample_x_t(x, n, t)

        # TODO: fix the hardcoding
        # unconditional generation
        if y is None:
            y = jnp.zeros((t.shape[0],), dtype=jnp.int32) + 1000
        # default negative condition
        if neg_y is None:
            neg_y = jnp.zeros((t.shape[0],), dtype=jnp.int32) + 1000

        # no guidance case
        if self.guidance_scale == 1.0 and self.guidance_mixture_ratio == 0.0:
            return v, y, neg_y

        v_uncond = self.network(
            (self.bcast_right(self.c_in(t), x_t) * x_t), t, *args, y=neg_y, dt=jnp.zeros_like(t), **kwargs
        )[0]

        if self.guidance_mixture_ratio == 0.0:
            return jnp.where(
                self.bcast_right((t >= self.guidance_t_min) & (t <= self.guidance_t_max), v),
                v_uncond + self.guidance_scale * (v - v_uncond),
                v
            ), y, neg_y

        v_cond = self.network(
            (self.bcast_right(self.c_in(t), x_t) * x_t), t, *args, y=y, dt=jnp.zeros_like(t), **kwargs
        )[0]

        self.network.train()

        return jnp.where(
            self.bcast_right((t >= self.guidance_t_min) & (t <= self.guidance_t_max), v),
            self.guidance_scale * v + (1 - self.guidance_scale - self.guidance_mixture_ratio) * v_uncond + self.guidance_mixture_ratio * v_cond,
            v
        ), y, neg_y
    def target(
        self,
        x: jnp.ndarray,
        n: jnp.ndarray,
        t: jnp.ndarray,
        r: jnp.ndarray,
        *args,
        y: jnp.ndarray | None = None,
        neg_y: jnp.ndarray | None = None,
        **kwargs
    ) -> jnp.ndarray:
        r"""Get training target for Mean Flow.

        Note: the network must be augmented with r, the jump size, as an additional input.

        .. math::
            target = v - (t - r) * \frac{du}{dt}
        """
        v, y, neg_y = self.insta_velocity(x, n, t, *args, y=y, neg_y=neg_y, **kwargs)
        v, y = self.cond_drop(x, n, v, y, neg_y=neg_y)
        x_t = self.sample_x_t(x, n, t)

        def u_fn(x_t, t, r):
            return self.network(
                (self.bcast_right(self.c_in(t), x_t) * x_t), t, *args, dt=t - r, y=y, **kwargs
            )

        dtdt = jnp.ones_like(t)
        drdt = jnp.zeros_like(r)
        # JVP along the direction (dx_t/dt, dt/dt, dr/dt) = (v, 1, 0) gives the total
        # derivative du/dt used in the Mean Flow target.
        u, dudt, feat = jax.jvp(
            u_fn, (x_t, t, r), (v, dtdt, drdt), has_aux=True
        )
        return (u, feat), jax.lax.stop_gradient(v - jnp.clip(self.bcast_right(t - r, v), 0., 1.) * dudt)
    def pred(self, x_t: jnp.ndarray, t: jnp.ndarray, r: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        r"""Predict ODE tangent according to the Mean Flow interface.

        .. math::
            v_{(t, r)} = u(x_t, t, r)
        """
        return self.network(
            (self.bcast_right(self.c_in(t), x_t) * x_t), t, *args, dt=t - r, **kwargs
        )[0]
    def score(self, x_t: jnp.ndarray, t: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
        """:meta private:"""
        # evaluate the average velocity with r = 0, so that x_hat = x_t - t * u
        tangent = self.pred(x_t, t, jnp.zeros_like(t), *args, **kwargs)
        t = self.bcast_right(t, x_t)
        # same form as the SiT score: -(x_t + (1 - t) * u) / t
        return -(x_t + (1 - t) * tangent) / t
    def loss(self, x: jnp.ndarray, *args, return_aux=False, **kwargs) -> jnp.ndarray:
        r"""Calculate the Mean Flow loss.

        .. math::
            L = \mathbb{E} \Vert u(x_t, t, r) - v_{(t, r)} \Vert ^ 2
        """
        t, r = self.sample_t_r((x.shape[0],))
        n = self.sample_n(x.shape)
        (net_out, features), target = self.target(x, n, t, r, *args, **kwargs)

        # following the implementation of meanflow we use sum loss
        loss = jnp.sum((net_out - target) ** 2, axis=list(range(1, x.ndim)))
        adp_w = 1.0 / (loss + self.norm_eps) ** self.norm_power
        loss = jax.lax.stop_gradient(adp_w) * loss

        if return_aux:
            return loss, net_out, features
        else:
            return {
                'loss': loss
            }
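
# Illustrative usage sketch (not part of the library): MeanFlowInterface assumes a
# backbone augmented with the jump size, i.e. __call__(x, t, *, dt=..., y=...) returning
# (output, features), plus a 'label_dropout' rng stream used by cond_drop. The backbone
# name below is hypothetical.
#
#     network = DtAugmentedBackbone(rngs=nnx.Rngs(params=0, time=1, noise=2, label_dropout=3))
#     interface = MeanFlowInterface(network, train_time_dist_type='logit_normal')
#     loss = interface(x_batch, y=labels)['loss'].mean()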