Source code for interfaces.repa

# built-in libs

# external libs
import flax
import flax.linen as nn
from flax import nnx
import jax
import jax.numpy as jnp
import numpy as np

# deps
from networks.transformers import dit_nnx

def build_mlp(hidden_size, projector_dim, feature_dim, rngs, dtype=jnp.float32):
    """Construct the feature-projection MLP.

    The network is ``Linear -> SiLU -> Linear -> SiLU -> Linear`` with widths
    ``hidden_size -> projector_dim -> projector_dim -> feature_dim``.

    Args:
        hidden_size: width of the input features.
        projector_dim: width of the two hidden layers.
        feature_dim: width of the output features.
        rngs: random number generators handed to every ``nnx.Linear``.
        dtype: dtype passed to every ``nnx.Linear``.

    Returns:
        nnx.Sequential: the assembled projector.
    """
    widths = (hidden_size, projector_dim, projector_dim, feature_dim)

    stages = []
    # Walk consecutive (in, out) width pairs. Linear layers are created in
    # input-to-output order so `rngs` is consumed in the same order as before.
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        if stages:
            # A SiLU sits between consecutive linear layers, never after the last.
            stages.append(nnx.silu)
        stages.append(
            nnx.Linear(
                fan_in,
                fan_out,
                dtype=dtype,
                precision=dit_nnx.PRECISION,
                rngs=rngs,
            )
        )

    return nnx.Sequential(*stages)
class DiT_REPA(nnx.Module):
    """DiT with REPA (Representation Alignment) wrapper.

    Wraps a diffusion interface and adds an auxiliary loss that aligns a
    projected intermediate feature of the wrapped network with externally
    supplied target features via negative cosine similarity.
    """

    def __init__(
        self,
        interface,
        *,
        feature_dim: int,
        repa_loss_weight: float,
        repa_depth: int,
        proj_dim: int,
        dtype: jnp.dtype = jnp.float32,
    ):
        """Initialize DiT_REPA.

        Args:
            interface: diffusion interface to wrap. It must expose
                ``network.hidden_size`` and ``network.rngs``.
            feature_dim: feature dimension for alignment (projector output).
            repa_loss_weight: weight for REPA loss in the total loss.
            repa_depth: 1-based depth of the intermediate feature to align;
                ``loss`` reads ``intermediate_features[repa_depth - 1]``.
            proj_dim: hidden width of the projector MLP.
            dtype: data type for the projector layers.
        """
        self.interface = interface
        self.repa_depth = repa_depth
        self.repa_loss_weight = repa_loss_weight
        self.projector = build_mlp(
            interface.network.hidden_size,
            proj_dim,
            feature_dim,
            rngs=interface.network.rngs,
            dtype=dtype,
        )
        # Side effect on the wrapped object: the network is switched to also
        # return its intermediate features, which `loss` relies on.
        self.interface.network.return_intermediate_features = True

    @staticmethod
    def _unit_normalize(v: jnp.ndarray, eps: float = 1e-12) -> jnp.ndarray:
        """Scale `v` to unit L2 norm along the last axis.

        The norm is clamped from below by `eps` so an all-zero vector maps to
        zeros instead of NaN (0 / 0). Vectors whose norm is at least `eps`
        are divided by their exact norm.
        """
        # NOTE: this guards the forward value only. The gradient through an
        # exactly-zero vector can still be non-finite because of the sqrt
        # inside the norm.
        length = jnp.linalg.norm(v, axis=-1, keepdims=True)
        return v / jnp.maximum(length, eps)

    def loss(
        self, x: jnp.ndarray, x_feature: jnp.ndarray, *args, **kwargs
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Calculate combined diffusion and REPA loss.

        Args:
            x: input clean sample.
            x_feature: target features for alignment. It must broadcast
                against the projected features of shape
                ``(N, T, feature_dim)``.
            *args: additional arguments for interface.
            **kwargs: additional keyword arguments for interface.

        Returns:
            tuple[jnp.ndarray, jnp.ndarray]: (diffusion_loss, repa_loss),
            diffusion and REPA losses.
        """
        # With return_aux=True the interface yields a 3-tuple; the middle
        # element is not used here.
        diffusion_loss, _, intermediate_features = self.interface(
            x, *args, return_aux=True, **kwargs
        )

        # repa_depth is 1-based, hence the -1.
        repa_feature = intermediate_features[self.repa_depth - 1]
        N, T, D = repa_feature.shape

        # Flatten to (N * T, D) for the projector, then restore (N, T, -1).
        feature_proj = self.projector(repa_feature.reshape(-1, D)).reshape(N, T, -1)

        # TODO: update the following
        # Cosine similarity along the last axis. Zero-norm vectors are
        # normalised safely so they contribute 0 instead of NaN.
        x_feature_norm = self._unit_normalize(x_feature)
        feature_proj_norm = self._unit_normalize(feature_proj)
        feature_cos_sim = jnp.sum(x_feature_norm * feature_proj_norm, axis=-1)

        # Negated so that higher similarity gives a lower loss. The reduction
        # is delegated to the interface's mean_flat.
        repa_loss = self.interface.mean_flat(-feature_cos_sim)

        return diffusion_loss, repa_loss

    def pred(self, *args, **kwargs) -> jnp.ndarray:
        """Predict ODE tangent.

        Delegates to ``self.interface.pred`` unchanged.

        Args:
            *args: arguments passed to interface.
            **kwargs: keyword arguments passed to interface.

        Returns:
            jnp.ndarray: tangent, predicted ODE tangent from interface.
        """
        return self.interface.pred(*args, **kwargs)

    def score(self, *args, **kwargs) -> jnp.ndarray:
        """Calculate score function.

        Delegates to ``self.interface.score`` unchanged.

        Args:
            *args: arguments passed to interface.
            **kwargs: keyword arguments passed to interface.

        Returns:
            jnp.ndarray: score, score function from interface.
        """
        return self.interface.score(*args, **kwargs)

    def __call__(
        self, x: jnp.ndarray, x_feature: jnp.ndarray, *args, **kwargs
    ) -> dict[str, jnp.ndarray]:
        """Forward pass with combined diffusion and REPA loss.

        Args:
            x: input clean sample.
            x_feature: target features for alignment.
            *args: additional arguments for interface.
            **kwargs: additional keyword arguments for interface.

        Returns:
            dict[str, jnp.ndarray]: losses, with keys ``'loss'``
            (``diffusion_loss + repa_loss_weight * repa_loss``),
            ``'diffusion_loss'`` and ``'repa_loss'``.
        """
        diffusion_loss, repa_loss = self.loss(x, x_feature, *args, **kwargs)
        return {
            'loss': diffusion_loss + self.repa_loss_weight * repa_loss,
            'diffusion_loss': diffusion_loss,
            'repa_loss': repa_loss,
        }