"""File containing the Exponential Moving Average (EMA) implementation."""
# built-in libs
import copy
# external libs
from flax import nnx
import jax
import jax.numpy as jnp
import numpy as np
# deps
def get_network(
    module: nnx.Module,
):
    """Helper function that recursively traverses modules to find the first object named `network`.
    
    This function is used in the case where there are multiple loss wrappers around the network, and in
    Eval / EMA only the network parameters are needed.
    """
    if hasattr(module, 'network'):
        return module.network
    for attr in module.__dict__.values():
        if isinstance(attr, nnx.Module):
            result = get_network(attr)
            if result is not None:
                return result
    return None 
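
# Example (hypothetical wrapper, for illustration only): if the training loss is
# wrapped as `loss = DiffusionLoss(network=UNet(...))`, then `get_network(loss)`
# returns the inner `UNet` instance, so eval / EMA code can operate on the
# network parameters without the surrounding loss wrapper.
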
class EMA(nnx.Module):
    def __init__(self, net: nnx.Module, decay: float):
        """Initialize the EMA object."""
        self.ema = copy.deepcopy(net)
        ema_state = jax.tree.map(lambda x: jnp.zeros_like(x), nnx.state(net, nnx.Param))
        nnx.update(self.ema, ema_state)
        self.ema.eval()
        self.decay = decay 
    def update(self, net: nnx.Module):
        """Update the EMA model state."""
        # target_net = get_network(net)
        # target_ema = get_network(self.ema)
        state, ema_state = nnx.state(net, nnx.Param), nnx.state(self.ema, nnx.Param)
        # Standard exponential moving average: ema <- decay * ema + (1 - decay) * net.
        ema_state = jax.tree.map(
            lambda p_net, p_ema: p_ema * self.decay + p_net * (1 - self.decay),
            state, ema_state
        )
        nnx.update(self.ema, ema_state) 
    
    def get(self):
        """Return the pure EMA model state."""
        return jax.device_get(nnx.split(self.ema, nnx.RngKey, ...)[-1]) 
    
    def load(self, state: nnx.State):
        """Load the saved / pretrained EMA model state."""
        graphdef, rng_state, _ = nnx.split(self.ema, nnx.RngKey, ...)
        self.ema = nnx.merge(graphdef, rng_state, state) 
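
# Minimal usage sketch (assumes `net` is any nnx.Module, e.g. nnx.Linear):
#
#     net = nnx.Linear(4, 4, rngs=nnx.Rngs(0))
#     ema = EMA(net, decay=0.999)
#     for _ in range(num_steps):      # `num_steps` is a placeholder
#         ...                         # optimizer step that mutates `net`
#         ema.update(net)             # blend the new params into the EMA copy
#     ema_state = ema.get()           # host-side state, e.g. for checkpointing
#     ema.load(ema_state)             # restore a saved / pretrained EMA state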
 
#----------------------------------------------------------------------------
# Below is the power-function EMA (PowerEMA) from EDM2: https://github.com/NVlabs/edm2
def exp_to_std(exp):
    """:meta private:"""
    exp = np.float64(exp)
    std = np.sqrt((exp + 1) / (exp + 2) ** 2 / (exp + 3))
    return std
def std_to_exp(std):
    """:meta private:"""
    std = np.float64(std)
    tmp = std.flatten() ** -2
    exp = [np.roots([1, 7, 16 - t, 12 - t]).real.max() for t in tmp]
    exp = np.float64(exp).reshape(std.shape)
    return exp
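# The two conversions are inverses of each other; e.g. (illustrative values):
#
#     >>> exp = std_to_exp(np.float64(0.05))   # exponent ~17
#     >>> float(exp_to_std(exp))               # recovers ~0.05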
def power_function_response(ofs, std, len, axis=0):
    """:meta private:"""
    ofs, std = np.broadcast_arrays(ofs, std)
    ofs = np.stack([np.float64(ofs)], axis=axis)
    exp = np.stack([std_to_exp(std)], axis=axis)
    s = [1] * exp.ndim
    s[axis] = -1
    t = np.arange(len).reshape(s)
    resp = np.where(t <= ofs, (t / ofs) ** exp, 0) / ofs * (exp + 1)
    resp = resp / np.sum(resp, axis=axis, keepdims=True)
    return resp
def power_function_correlation(a_ofs, a_std, b_ofs, b_std):
    """:meta private:"""
    a_exp = std_to_exp(a_std)
    b_exp = std_to_exp(b_std)
    t_ratio = a_ofs / b_ofs
    t_exp = np.where(a_ofs < b_ofs, b_exp, -a_exp)
    t_max = np.maximum(a_ofs, b_ofs)
    num = (a_exp + 1) * (b_exp + 1) * t_ratio ** t_exp
    den = (a_exp + b_exp + 1) * t_max
    return num / den
def power_function_beta(exp, step):
    """:meta private:"""
    beta = (1 - 1 / step) ** (exp + 1)
    return beta
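# Example: with exp ~17 (std ~0.05), the decay at step 1000 is
# (1 - 1/1000) ** 18 ~ 0.982 and approaches 1 as training progresses,
# giving a power-function averaging profile instead of a fixed decay constant.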
def solve_posthoc_coefficients(in_ofs, in_std, out_ofs, out_std): # => [in, out]
    """:meta private:"""
    in_ofs, in_std = np.broadcast_arrays(in_ofs, in_std)
    out_ofs, out_std = np.broadcast_arrays(out_ofs, out_std)
    rv = lambda x: np.float64(x).reshape(-1, 1)
    cv = lambda x: np.float64(x).reshape(1, -1)
    A = power_function_correlation(rv(in_ofs), rv(in_std), cv(in_ofs), cv(in_std))
    B = power_function_correlation(rv(in_ofs), rv(in_std), cv(out_ofs), cv(out_std))
    X = np.linalg.solve(A, B)
    X = X / np.sum(X, axis=0)
    return X
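# Post-hoc reconstruction sketch (shapes only; the numbers are placeholders):
# given snapshots of the tracked EMAs taken at training steps `in_ofs` with
# relative stds `in_std`, solve for mixing weights that approximate an EMA
# with a new target std at step `out_ofs`:
#
#     in_ofs  = np.array([10_000, 10_000, 20_000, 20_000])   # snapshot steps
#     in_std  = np.array([0.05,   0.10,   0.05,   0.10])     # tracked stds
#     out_ofs = np.array([20_000])                            # target step
#     out_std = np.array([0.075])                             # target std
#     X = solve_posthoc_coefficients(in_ofs, in_std, out_ofs, out_std)
#     # X has shape [len(in_ofs), len(out_ofs)] and each column sums to 1;
#     # the reconstructed EMA is the X-weighted sum of the snapshot weights.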
class PowerEMA:
    """TODO: to be updated.
    
    :meta private:"""
    def __init__(self, net: nnx.Module, stds: list[float]):
        self.net = net
        self.stds = stds
        # One power-function exponent and one EMA copy of the network per tracked std.
        self.exps = [
            std_to_exp(np.array(std, dtype=np.float64)) for std in self.stds
        ]
        self.emas = [copy.deepcopy(net) for _ in self.stds]
        for ema in self.emas:
            ema.eval()
    def update(self, net: nnx.Module, step: int):
        for exp, ema in zip(self.exps, self.emas):
            state, ema_state = nnx.state(net, nnx.Param), nnx.state(ema, nnx.Param)
            beta = power_function_beta(exp=exp, step=step)
            ema_state = jax.tree.map(
                lambda p_net, p_ema: p_ema * beta + p_net * (1 - beta),
                state, ema_state
            )
            nnx.update(ema, ema_state)
    
    def get(self):
        return [(nnx.state(ema), f'-{std:.3f}') for std, ema in zip(self.stds, self.emas)]
    def load(self, states: list[nnx.State]):
        for ema, ema_state in zip(self.emas, states):
            nnx.update(ema, ema_state)
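
# Minimal PowerEMA usage sketch (illustrative, mirrors the EMA class above):
#
#     net = nnx.Linear(4, 4, rngs=nnx.Rngs(0))
#     pema = PowerEMA(net, stds=[0.05, 0.10])   # two power-function profiles
#     for step in range(1, num_steps + 1):      # `num_steps` is a placeholder
#         ...                                   # optimizer step that mutates `net`
#         pema.update(net, step=step)           # step-dependent decay per profile
#     snapshots = pema.get()                    # [(state, '-0.050'), (state, '-0.100')]
#     pema.load([s for s, _ in snapshots])      # restore from saved states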