"""File containing the sharding utils."""
# built-in libs
import dataclasses
import re
from typing import Mapping
# external libs
from absl import logging
import flax
from flax import nnx
import jax
from jax.experimental import mesh_utils
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import ml_collections
import numpy as np
# deps
from utils import ema
def flatten_state(
state: nnx.State,
path: tuple[str, ...] = ()
):
"""Recursively traverse an NNX VariableState, yielding (path, VariableState)."""
if isinstance(state, nnx.VariableState):
# Join path components into a string name (e.g. "Encoder/Layer_0/kernel")
name = "/".join(str(p) for p in path)
yield name, state
elif hasattr(state, "items"): # state behaves like a dict of submodules/vars
for key, subtree in state.items():
yield from flatten_state(subtree, path + (key,))
elif isinstance(state, (list, tuple)):
for idx, subtree in enumerate(state):
yield from flatten_state(subtree, path + (str(idx),))
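# A minimal usage sketch (the module and names below are illustrative, not part of this file):
#
#   model = nnx.Linear(4, 8, rngs=nnx.Rngs(0))
#   _, state = nnx.split(model)
#   for name, var in flatten_state(state):
#       print(name, var.value.shape)   # e.g. "kernel (4, 8)" and "bias (8,)"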
def place_like_target(tree, target):
"""Place the tree following the sharding of the target."""
def _put(x, ref):
# Ensure array-like (helps if some leaves are numpy arrays / lists)
x = jnp.asarray(x)
if isinstance(ref, jax.Array):
# Use the *exact* Sharding carried by the reference leaf
return jax.device_put(x, ref.sharding)
else:
# If target leaf isn't a jax.Array, just return x (or replicate if you prefer)
return x
return jax.tree.map(_put, tree, target)
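# Sketch: move a host-resident copy of a pytree onto the devices/shardings of an
# already-sharded reference tree (e.g. after pulling a checkpoint to host memory).
# `sharded_params` below is a hypothetical tree of sharded jax.Arrays:
#
#   host_params = jax.device_get(sharded_params)               # NumPy leaves on host
#   restored = place_like_target(host_params, sharded_params)  # same shardings again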
def replicate():
"""Sharding tactic to fully replicate a parameter (no sharding on any axis)."""
def update_spec(cur_spec, mesh, name, var_state):
# Ensure no other sharding has been applied to this parameter
if not all(axis is None for axis in cur_spec):
raise ValueError(f"Conflict: {name} already has a sharding spec {cur_spec}, cannot replicate.")
return cur_spec # All None => fully replicated
return update_spec
def fsdp(
axis: str,
min_size_to_shard_mb: float = 4
):
"""Fully Sharded Data Parallel tactic - shard largest available dimension along given mesh axis."""
# Allow axis to be a single name or tuple of names (for multiple mesh axes)
axis_names = axis if isinstance(axis, tuple) else (axis,)
def update_spec(cur_spec, mesh, name, var_state):
arr = var_state.value
if arr is None:
# it's possible for a parameter to be None (e.g. in an optimizer state / norm layer)
return cur_spec
shape = arr.shape
# Compute total devices for the given axis/axes in the mesh
axis_size = np.prod([mesh.shape[a] for a in axis_names])
# Skip sharding if tensor is too small
if arr.size * arr.dtype.itemsize <= min_size_to_shard_mb * (2 ** 20):
return cur_spec # leave as is (no sharding)
# Find the largest dimension that is not yet sharded and divisible by axis_size
dim_indices = np.argsort(shape)[::-1] # dims sorted by size (largest first)
for i in dim_indices:
if cur_spec[i] is None and shape[i] % axis_size == 0:
# Shard this dimension along the given mesh axis (or tuple of axes)
new_spec = list(cur_spec)
new_spec[i] = axis if isinstance(axis, tuple) else axis_names[0]
return tuple(new_spec)
# If no suitable dimension found, leave spec unchanged (param stays replicated)
return cur_spec
return update_spec
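# Sketch of a single tactic call, assuming 8 devices on a 1-D "data" mesh and a
# hypothetical 8192x1024 float32 kernel (32 MiB, so above the 4 MiB threshold):
#
#   mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(8), ("data",))
#   var = nnx.VariableState(nnx.Param, jnp.zeros((8192, 1024)))
#   fsdp(axis="data")((None, None), mesh, "Dense_0/kernel", var)
#   # -> ("data", None): the largest divisible dimension is sharded over "data"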
def infer_sharding(
    state: nnx.State,
    strategy: list[tuple[str, str]],
    mesh: jax.sharding.Mesh
):
  """Infer a sharding specification for an NNX model state based on a regex strategy.
  :param state: nnx.State (VariableState pytree) of the model's parameters.
  :param strategy: list of (regex_pattern, tactic) pairs. A tactic is either a
    string such as 'fsdp(axis="data")' or 'replicate', or a callable.
  :param mesh: jax.sharding.Mesh defining the device mesh axes.
  :return: A PyTree with the same structure as state, whose leaves are
    jax.sharding.NamedSharding objects.
  """
# Flatten state to list of (name, VariableState)
flat_params = list(flatten_state(state))
names = [name for name, _ in flat_params]
vars_states = [vs for _, vs in flat_params]
# Initialize spec: tuple[None,...] for each param (length = param.ndim)
specs = [
(None,) * vars_states[i].value.ndim if vars_states[i].value is not None else ()
for i in range(len(vars_states))
]
matched = set() # track indices of params already matched by a rule
# Helper to get tactic callable from strategy entry
def get_tactic_fn(tactic_descr):
# If already a callable (function), use it
if callable(tactic_descr):
return tactic_descr
# If string, parse basic format e.g. "fsdp(axis=\"data\")" or "replicate"
tactic_descr = tactic_descr.strip()
if tactic_descr.startswith("fsdp"):
# Extract axis argument inside parentheses if present
# e.g. fsdp(axis="model")
axis_match = re.search(r'axis\s*=\s*\"([A-Za-z0-9_, ]+)\"', tactic_descr)
axis_names = axis_match.group(1) if axis_match else None
if axis_names is not None:
# support multiple axis names separated by comma
axis_tuple = tuple(n.strip() for n in axis_names.split(','))
# if only one axis was provided, use string instead of tuple of length 1
axis_arg = axis_tuple if len(axis_tuple) > 1 else axis_tuple[0]
else:
axis_arg = None
return fsdp(axis=axis_arg) if axis_arg else fsdp(axis='data')
elif tactic_descr.startswith("replicate"):
return replicate()
else:
raise ValueError(f"Unknown tactic: {tactic_descr}")
# Apply each pattern in order
for pattern, tactic in strategy:
prog = re.compile(pattern)
tactic_fn = get_tactic_fn(tactic)
for idx, name in enumerate(names):
if idx in matched:
continue # already handled by earlier rule
if prog.search(name): # regex match (search anywhere in name)
        # Apply the tactic to update this parameter's partition spec
specs[idx] = tactic_fn(specs[idx], mesh, name, vars_states[idx])
matched.add(idx)
# Convert specs (tuples) to PartitionSpec and wrap in NamedSharding
sharding_tree = []
for spec in specs:
pspec = P(*spec) # convert tuple of axis names/None to PartitionSpec
sharding_tree.append(NamedSharding(mesh, pspec))
# Reconstruct the tree structure of sharding_tree to mirror `state` structure
sharding_tree = jax.tree_util.tree_unflatten(
jax.tree_util.tree_structure(state, is_leaf=lambda x: isinstance(x, nnx.VariableState)),
sharding_tree)
return sharding_tree
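# Sketch: derive a sharding tree for a model's parameters and place them. The mesh
# axis name "data" and the parameter-name regexes below are illustrative:
#
#   mesh = create_device_mesh([("data", -1)])
#   strategy = [
#       (".*(bias|scale)", "replicate"),      # keep small vectors replicated
#       (".*", 'fsdp(axis="data")'),          # everything else: FSDP over "data"
#   ]
#   graphdef, state = nnx.split(model)
#   shardings = infer_sharding(state, strategy, mesh)
#   with mesh:
#       state = jax.lax.with_sharding_constraint(state, shardings)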
def create_device_mesh(
config_mesh: list[tuple[str, int]],
*,
allow_split_physical_axes: bool = False,
):
"""Returns a JAX device mesh.
Args:
config_mesh: A list of tuples of (axis_name, axis_size). It is advised to
sort the axis in increasing order of network communication intensity.
allow_split_physical_axes: Whether to allow splitting physical axes.
"""
devices = jax.devices()
mesh_axes, mesh_size = tuple(zip(*config_mesh))
  # mesh_utils does not accept -1 axis sizes, so resolve them via a NumPy reshape.
mesh_size = np.array(devices).reshape(mesh_size).shape
device_mesh = mesh_utils.create_device_mesh(
mesh_size,
devices=devices,
allow_split_physical_axes=allow_split_physical_axes
)
return jax.sharding.Mesh(device_mesh, mesh_axes)
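# Sketch: a 2-D mesh over, say, 32 devices, with the most communication-intensive
# axis listed last. A size of -1 lets the remaining axis be inferred:
#
#   mesh = create_device_mesh([("data", -1), ("fsdp", 4)])
#   mesh.shape   # -> {"data": 8, "fsdp": 4} on 32 devices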
def make_fsarray_from_local_slice(
local_slice: jnp.ndarray,
global_devices: list,
):
"""Create a fully-sharded global device array from local host arrays.
Args:
    local_slice: Something convertible to a NumPy array (e.g. a TF tensor)
that is this host's slice of the global array.
global_devices: The list of global devices. Needed for consistent ordering.
Returns:
The global on-device array which consists of all local slices stacked
together in the order consistent with the devices.
"""
mesh = jax.sharding.Mesh(global_devices, ("devices",))
sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("devices"))
local_ds = mesh.local_devices
x = np.asarray(local_slice)
xs = jax.device_put(np.split(x, len(local_ds), axis=0), local_ds)
global_shape = (x.shape[0] * jax.process_count(), *x.shape[1:])
return jax.make_array_from_single_device_arrays(global_shape, sharding, xs)
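# Sketch: turn each host's slice of a data batch into one globally sharded array.
# The local leading dimension must be divisible by the number of local devices,
# and every host must pass the same device ordering (here plain jax.devices()):
#
#   local_batch = np.zeros((256, 224, 224, 3), np.float32)   # hypothetical per-host batch
#   global_batch = make_fsarray_from_local_slice(local_batch, jax.devices())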
def get_local_slice_from_fsarray(
global_array: jnp.ndarray
):
"""Return numpy array for the host-local slice of fully-sharded array.
Args:
global_array: JAX array, globally sharded on devices across hosts (potentially undressable).
Returns:
NumPy array that holds the part of `global_array` that is held by the
devices on the host that calls this function.
"""
# For now, for simplicity, we only implement slicing along the first axis.
for shard in global_array.addressable_shards:
assert all(idx == slice(None) for idx in shard.index[1:]), (
f"global_array is sharded along non-first dimensions:\n{shard.index}")
# Get the shards back in the same order in which the global array was created
# in the first place. This makes sure it's consistent with other things in the
# batch, for example (assuming the whole batch is consistent).
m = {s.device: s for s in global_array.addressable_shards}
local_shards = [m[d] for d in global_array.sharding.mesh.local_devices]
return np.concatenate([jax.device_get(s.data) for s in local_shards], axis=0)
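# Sketch: round-trip with the previous function; for the hypothetical arrays above,
# the recovered slice matches the original host-local batch:
#
#   recovered = get_local_slice_from_fsarray(global_batch)
#   np.testing.assert_array_equal(recovered, local_batch)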
def update_model_sharding(
graphdef: nnx.GraphDef,
loaded_state: nnx.State,
loaded_rng_state: nnx.RngKey,
ema: ema.EMA,
loaded_ema_state: nnx.State,
mesh: Mesh,
sharding_strategy: list[tuple[str, str]],
):
"""Updates the model sharding for optimizer and EMA state.
Args:
graphdef: The graph definition of the optimizer.
loaded_state: The loaded state of the optimizer.
loaded_rng_state: The loaded rng state of the optimizer.
ema: The EMA object.
loaded_ema_state: The loaded state of the EMA.
mesh: The mesh.
sharding_strategy: The sharding strategy.
Returns:
graphdef: The graph definition of the optimizer.
state: The resharded state of the optimizer.
ema_graphdef: The graph definition of the EMA.
ema_state: The resharded state of the EMA.
state_sharding: The sharding of the optimizer.
ema_state_sharding: The sharding of the EMA.
"""
  loaded_state = jax.device_get(loaded_state)  # required: otherwise the Orbax-restored arrays stay on a single device
loaded_rng_state = jax.device_get(loaded_rng_state)
loaded_ema_state = jax.device_get(loaded_ema_state.ema)
optimizer = nnx.merge(graphdef, loaded_rng_state, loaded_state)
ema.load(loaded_ema_state)
with mesh:
graphdef, state = nnx.split(optimizer)
ema_graphdef, ema_state = nnx.split(ema)
state_sharding = infer_sharding(state, sharding_strategy, mesh)
state = jax.lax.with_sharding_constraint(state, state_sharding)
ema_state_sharding = infer_sharding(ema_state, sharding_strategy, mesh)
ema_state = jax.lax.with_sharding_constraint(ema_state, ema_state_sharding)
return graphdef, state, ema_graphdef, ema_state, state_sharding, ema_state_sharding
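# Sketch of the intended call pattern after restoring a checkpoint to host memory
# (the `restored` dict and the strategy below are illustrative):
#
#   (graphdef, state, ema_graphdef, ema_state,
#    state_sharding, ema_state_sharding) = update_model_sharding(
#       graphdef, restored["state"], restored["rng_state"],
#       ema, restored["ema_state"], mesh,
#       sharding_strategy=[(".*", 'fsdp(axis="data")')],
#   )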