# Source code for utils.ema

"""File containing the Exponential Moving Average (EMA) implementation."""

# built-in libs
import copy

# external libs
from flax import nnx
import jax
import jax.numpy as jnp
import numpy as np

# deps


def get_network(module: nnx.Module):
    """Recursively search `module` for the first attribute named `network`.

    Useful when one or more loss wrappers surround the actual network and
    only the underlying network's parameters are needed (e.g. in Eval / EMA).

    Returns the first `network` attribute found via depth-first traversal of
    child `nnx.Module` attributes, or None when no such attribute exists.
    """
    if hasattr(module, 'network'):
        return module.network
    for child in module.__dict__.values():
        if not isinstance(child, nnx.Module):
            continue
        found = get_network(child)
        if found is not None:
            return found
    return None
class EMA(nnx.Module):
    """Exponential Moving Average (EMA) copy of a network's parameters."""

    def __init__(self, net: nnx.Module, decay: float):
        """Initialize the EMA object."""
        self.ema = copy.deepcopy(net)
        # Start the averaged parameters from zero.
        # NOTE(review): with zero init and no bias correction, get() is biased
        # toward zero early in training — confirm this matches the reference.
        zeroed = jax.tree.map(
            lambda leaf: jnp.zeros_like(leaf), nnx.state(net, nnx.Param)
        )
        nnx.update(self.ema, zeroed)
        self.ema.eval()
        self.decay = decay

    def update(self, net: nnx.Module):
        """Update the EMA model state."""
        net_params = nnx.state(net, nnx.Param)
        ema_params = nnx.state(self.ema, nnx.Param)
        # ema <- decay * ema + (1 - decay) * net, leaf by leaf.
        blended = jax.tree.map(
            lambda p_net, p_ema: p_ema * self.decay + p_net * (1 - self.decay),
            net_params,
            ema_params,
        )
        nnx.update(self.ema, blended)

    def get(self):
        """Return the pure EMA model state."""
        # Split off RNG keys; [-1] is the remaining (non-RNG) state.
        return jax.device_get(nnx.split(self.ema, nnx.RngKey, ...)[-1])

    def load(self, state: nnx.State):
        """Load the saved / pretrained EMA model state."""
        graphdef, rng_state, _ = nnx.split(self.ema, nnx.RngKey, ...)
        self.ema = nnx.merge(graphdef, rng_state, state)
#----------------------------------------------------------------------------
# Below are PowerEMA from EDM2 https://github.com/NVlabs/edm2


def exp_to_std(exp):
    """Convert a power-function EMA exponent to its relative std.

    :meta private:"""
    exp = np.float64(exp)
    std = np.sqrt((exp + 1) / (exp + 2) ** 2 / (exp + 3))
    return std


def std_to_exp(std):
    """Invert :func:`exp_to_std`: recover the exponent from a relative std.

    Solves the cubic whose largest real root is the exponent, elementwise.

    :meta private:"""
    std = np.float64(std)
    tmp = std.flatten() ** -2
    exp = [np.roots([1, 7, 16 - t, 12 - t]).real.max() for t in tmp]
    exp = np.float64(exp).reshape(std.shape)
    return exp


def power_function_response(ofs, std, len, axis=0):  # noqa: A002 (`len` kept for call compatibility)
    """Impulse response of the power-function EMA profile (ofs, std),
    sampled at `len` integer steps and normalized to sum to 1 along `axis`.

    :meta private:"""
    ofs, std = np.broadcast_arrays(ofs, std)
    ofs = np.stack([np.float64(ofs)], axis=axis)
    exp = np.stack([std_to_exp(std)], axis=axis)
    s = [1] * exp.ndim
    s[axis] = -1
    t = np.arange(len).reshape(s)
    # Response is (t/ofs)^exp up to t == ofs, zero afterwards.
    resp = np.where(t <= ofs, (t / ofs) ** exp, 0) / ofs * (exp + 1)
    resp = resp / np.sum(resp, axis=axis, keepdims=True)
    return resp


def power_function_correlation(a_ofs, a_std, b_ofs, b_std):
    """Closed-form correlation between two power-function EMA profiles.

    :meta private:"""
    a_exp = std_to_exp(a_std)
    b_exp = std_to_exp(b_std)
    t_ratio = a_ofs / b_ofs
    t_exp = np.where(a_ofs < b_ofs, b_exp, -a_exp)
    t_max = np.maximum(a_ofs, b_ofs)
    num = (a_exp + 1) * (b_exp + 1) * t_ratio ** t_exp
    den = (a_exp + b_exp + 1) * t_max
    return num / den


def power_function_beta(exp, step):
    """Per-step EMA decay coefficient for exponent `exp` at training `step`.

    :meta private:"""
    beta = (1 - 1 / step) ** (exp + 1)
    return beta


def solve_posthoc_coefficients(in_ofs, in_std, out_ofs, out_std):  # => [in, out]
    """Solve the least-squares mixing coefficients that reconstruct the
    target EMA profiles `out` from the stored snapshots `in`.

    Returns an array of shape [num_in, num_out]; each column sums to 1.

    :meta private:"""
    in_ofs, in_std = np.broadcast_arrays(in_ofs, in_std)
    out_ofs, out_std = np.broadcast_arrays(out_ofs, out_std)

    def rv(x):
        # Column vector.
        return np.float64(x).reshape(-1, 1)

    def cv(x):
        # Row vector.
        return np.float64(x).reshape(1, -1)

    A = power_function_correlation(rv(in_ofs), rv(in_std), cv(in_ofs), cv(in_std))
    B = power_function_correlation(rv(in_ofs), rv(in_std), cv(out_ofs), cv(out_std))
    X = np.linalg.solve(A, B)
    X = X / np.sum(X, axis=0)
    return X


class PowerEMA:
    """Power-function EMA from EDM2: keeps one EMA copy of the network per
    relative std in `stds`, updated with a step-dependent decay so that the
    profiles can be recombined post hoc.

    TODO: to be updated.

    :meta private:"""

    def __init__(self, net: nnx.Module, stds: list[float]):
        """Create one EMA replica of `net` per entry in `stds`."""
        self.net = net
        self.stds = stds
        # One power-function exponent per tracked std.
        self.exps = [
            std_to_exp(np.array(std, dtype=np.float64))
            for std in self.stds
        ]
        self.emas = [copy.deepcopy(net) for _ in self.stds]
        for ema in self.emas:
            ema.eval()

    def update(self, net: nnx.Module, step: int):
        """Blend `net`'s params into every EMA replica using the
        step-dependent power-function decay."""
        for exp, ema in zip(self.exps, self.emas):
            state, ema_state = nnx.state(net, nnx.Param), nnx.state(ema, nnx.Param)
            beta = power_function_beta(exp=exp, step=step)
            ema_state = jax.tree.map(
                lambda p_net, p_ema: p_ema * beta + p_net * (1 - beta),
                state,
                ema_state,
            )
            nnx.update(ema, ema_state)

    def get(self):
        """Return [(state, suffix)] pairs, one per tracked std; the suffix
        (e.g. '-0.050') tags checkpoints with their std."""
        return [
            (nnx.state(ema), f'-{std:.3f}')
            for std, ema in zip(self.stds, self.emas)
        ]

    def load(self, state: list[nnx.State]):
        """Load saved EMA states, one per tracked std (order must match)."""
        for ema, ema_state in zip(self.emas, state):
            nnx.update(ema, ema_state)