# built-in libs
# external libs
import flax
import flax.linen as nn
from flax import nnx
import jax
import jax.numpy as jnp
import numpy as np
# deps
from networks.transformers import dit_nnx


def build_mlp(hidden_size, projector_dim, feature_dim, rngs, dtype=jnp.float32):
    """Create the three-layer SiLU MLP that projects hidden states onto features.

    Layer layout::

        Linear(hidden_size -> projector_dim) -> SiLU
        -> Linear(projector_dim -> projector_dim) -> SiLU
        -> Linear(projector_dim -> feature_dim)

    Every Linear layer is built with ``dtype`` as its ``dtype`` argument,
    ``dit_nnx.PRECISION`` as its matmul precision, and ``rngs`` for parameter
    initialization. There is no activation after the final projection.

    Args:
        hidden_size: width of the incoming hidden states.
        projector_dim: width of the two hidden layers.
        feature_dim: width of the produced feature vectors.
        rngs: random number generators used to initialize every layer.
        dtype: ``dtype`` passed to each ``nnx.Linear``.

    Returns:
        nnx.Sequential: mlp, multi-layer perceptron for feature projection.
    """
    def dense(in_features, out_features):
        # Shared constructor so all layers get identical dtype/precision/rngs.
        return nnx.Linear(
            in_features,
            out_features,
            dtype=dtype,
            precision=dit_nnx.PRECISION,
            rngs=rngs,
        )

    # List elements are evaluated left to right, so layers are created (and
    # draw from ``rngs``) in input-to-output order.
    layers = [
        dense(hidden_size, projector_dim),
        nnx.silu,
        dense(projector_dim, projector_dim),
        nnx.silu,
        dense(projector_dim, feature_dim),
    ]
    return nnx.Sequential(*layers)


class DiT_REPA(nnx.Module):
    """DiT with REPA (Representation Alignment) wrapper.

    Wraps a diffusion interface and adds a projector MLP that maps one
    intermediate hidden state of ``interface.network`` onto externally
    supplied target features. The alignment term is the negative cosine
    similarity between projected and target features. ``__call__`` adds it to
    the diffusion loss with weight ``repa_loss_weight``.
    """

    def __init__(
        self,
        interface,
        *,
        feature_dim: int,
        repa_loss_weight: float,
        repa_depth: int,
        proj_dim: int,
        dtype: jnp.dtype = jnp.float32,
    ):
        """Initialize DiT_REPA.

        Args:
            interface: diffusion interface to wrap. It is expected to expose
                ``network`` (with ``hidden_size``, ``rngs`` and a settable
                ``return_intermediate_features``), to be callable with
                ``return_aux=True``, and to provide ``pred``, ``score`` and
                ``mean_flat``.
            feature_dim: last-axis size of the target features to align with.
            repa_loss_weight: weight of the REPA loss in the total loss.
            repa_depth: 1-based index of the intermediate feature to align,
                i.e. ``intermediate_features[repa_depth - 1]`` is used.
            proj_dim: hidden width of the projector MLP.
            dtype: ``dtype`` passed to the projector's Linear layers.

        Raises:
            ValueError: if ``repa_depth`` < 1. Without this check, 0 (or a
                negative value) would index from the end of
                ``intermediate_features`` and silently align the wrong block.
        """
        if repa_depth < 1:
            raise ValueError(f"repa_depth must be >= 1, got {repa_depth}")
        self.interface = interface
        self.repa_depth = repa_depth
        self.repa_loss_weight = repa_loss_weight
        self.projector = build_mlp(
            interface.network.hidden_size, proj_dim, feature_dim,
            rngs=interface.network.rngs, dtype=dtype,
        )
        # Make the wrapped network also return its intermediate features;
        # ``loss`` reads them from the third element of the interface output.
        self.interface.network.return_intermediate_features = True

    @staticmethod
    def _l2_normalize(v: jnp.ndarray, eps: float = 1e-12) -> jnp.ndarray:
        """Scale ``v`` to unit L2 norm along its last axis.

        Computes ``v / max(||v||, eps)``, the same definition and default eps
        as ``torch.nn.functional.normalize``. It works on the squared norm with
        ``rsqrt`` so that both the value and the gradient stay finite for
        all-zero vectors. A plain ``v / jnp.linalg.norm(v)`` yields 0/0 there,
        and the sqrt backward pass is not finite at zero.
        """
        sq_norm = jnp.sum(jnp.square(v), axis=-1, keepdims=True)
        return v * jax.lax.rsqrt(jnp.maximum(sq_norm, eps * eps))

    def loss(self, x: jnp.ndarray, x_feature: jnp.ndarray, *args, **kwargs) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Calculate the diffusion loss and the REPA alignment loss.

        Args:
            x: input clean sample.
            x_feature: target features for alignment. Its last axis must be
                ``feature_dim``. Its leading axes must match (or broadcast
                against) those of the selected intermediate feature.
            *args: additional arguments for interface.
            **kwargs: additional keyword arguments for interface.

        Returns:
            tuple[jnp.ndarray, jnp.ndarray]: (diffusion_loss, repa_loss).
            ``repa_loss`` is ``interface.mean_flat`` applied to the negative
            cosine similarity, which is computed in float32.
        """
        diffusion_loss, _, intermediate_features = self.interface(
            x, *args, return_aux=True, **kwargs
        )
        repa_feature = intermediate_features[self.repa_depth - 1]
        # nnx.Linear contracts only the last axis, so the projector applies
        # directly to (..., hidden_size) without flattening leading axes.
        feature_proj = self.projector(repa_feature)
        # TODO: update the following
        # Normalize in float32: low-precision norms/dot products degrade the
        # cosine similarity.
        x_feature_norm = self._l2_normalize(x_feature.astype(jnp.float32))
        feature_proj_norm = self._l2_normalize(feature_proj.astype(jnp.float32))
        feature_cos_sim = jnp.sum(x_feature_norm * feature_proj_norm, axis=-1)
        repa_loss = self.interface.mean_flat(-feature_cos_sim)
        return diffusion_loss, repa_loss

    def pred(self, *args, **kwargs) -> jnp.ndarray:
        """Predict ODE tangent by delegating to ``interface.pred``.

        Args:
            *args: arguments passed to interface.
            **kwargs: keyword arguments passed to interface.

        Returns:
            jnp.ndarray: tangent, predicted ODE tangent from interface.
        """
        return self.interface.pred(*args, **kwargs)

    def score(self, *args, **kwargs) -> jnp.ndarray:
        """Calculate score function by delegating to ``interface.score``.

        Args:
            *args: arguments passed to interface.
            **kwargs: keyword arguments passed to interface.

        Returns:
            jnp.ndarray: score, score function from interface.
        """
        return self.interface.score(*args, **kwargs)

    def __call__(self, x: jnp.ndarray, x_feature: jnp.ndarray, *args, **kwargs) -> dict[str, jnp.ndarray]:
        """Forward pass with combined diffusion and REPA loss.

        Args:
            x: input clean sample.
            x_feature: target features for alignment.
            *args: additional arguments for interface.
            **kwargs: additional keyword arguments for interface.

        Returns:
            dict[str, jnp.ndarray]: losses with keys ``'loss'``
            (``diffusion_loss + repa_loss_weight * repa_loss``),
            ``'diffusion_loss'`` and ``'repa_loss'``.
        """
        diffusion_loss, repa_loss = self.loss(x, x_feature, *args, **kwargs)
        return {
            'loss': diffusion_loss + self.repa_loss_weight * repa_loss,
            'diffusion_loss': diffusion_loss,
            'repa_loss': repa_loss
        }