Utilities API#
The utils module contains various utility functions and helpers.
Checkpointing#
File containing the functions for model checkpointing.
- utils.checkpoint.build_checkpoint_manager(ckpt_dir: str, *, save_interval_steps: int, max_to_keep: int, keep_period: int, step_prefix: str = 'checkpoint', enable_async_checkpointing: bool = True) CheckpointManager[source]#
Create a checkpoint manager for saving and restoring checkpoints during training.
- utils.checkpoint.save_checkpoints(ckpt_dir: str, step: int, optimizer_state: State, rng_state: RngKey, ema_state: State, *, mngr: CheckpointManager | None = None)[source]#
Save checkpoints for model and optimizer state.
- Parameters:
ckpt_dir – checkpoint directory.
step – current step.
optimizer_state – optimizer state. Note this is analogous to Flax's TrainState, which includes both opt_state and model_state.
rng_state – the current rng key.
ema_state – ema state.
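The save_interval_steps, max_to_keep, and keep_period options suggest an Orbax-style retention policy: checkpoints are written every save_interval_steps, only the most recent max_to_keep survive, and steps that are multiples of keep_period are kept permanently. A minimal pure-Python sketch of that policy (the function name here is illustrative, not the actual API):

```python
def steps_to_keep(saved_steps, max_to_keep, keep_period):
    """Return the subset of saved steps retained under an Orbax-style policy:
    the most recent `max_to_keep` steps, plus any step that is a multiple of
    `keep_period` (kept permanently)."""
    recent = set(sorted(saved_steps)[-max_to_keep:])
    periodic = {s for s in saved_steps if keep_period and s % keep_period == 0}
    return sorted(recent | periodic)

# Checkpoints were saved every 100 steps; keep the last 3 plus every 500th.
saved = list(range(0, 1000, 100))
print(steps_to_keep(saved, max_to_keep=3, keep_period=500))
# → [0, 500, 700, 800, 900]
```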
- utils.checkpoint.restore_checkpoints(ckpt_dir: str, step: int, abstract_optimizer_state: State, abstract_rng_state: RngKey, abstract_ema_state: State, *, ema_only: bool = False, mngr: CheckpointManager | None = None) State[source]#
Restore checkpoints for model and optimizer state.
- Parameters:
ckpt_dir – checkpoint directory.
step – current step.
abstract_optimizer_state – abstract optimizer state. Note this is analogous to Flax's TrainState, which includes both opt_state and model_state.
abstract_rng_state – abstract rng state.
abstract_ema_state – abstract ema state.
- Returns:
state: the restored training state. ema_state: the restored ema state.
- Return type:
State
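Restoration is guided by abstract states: pytrees that describe the shapes and dtypes of the target leaves without holding data, so the checkpointer knows what structure to materialize. A pure-Python sketch of that matching step, using nested dicts as a stand-in for nnx.State (names here are illustrative):

```python
import numpy as np

def restore_like(abstract_tree, stored_tree):
    """Restore leaves from `stored_tree` into the structure and dtypes
    described by `abstract_tree`, a pytree of (shape, dtype) specs --
    mimicking how abstract states guide checkpoint restoration."""
    out = {}
    for key, spec in abstract_tree.items():
        if isinstance(spec, dict):
            out[key] = restore_like(spec, stored_tree[key])
        else:
            shape, dtype = spec
            leaf = np.asarray(stored_tree[key], dtype=dtype)
            assert leaf.shape == shape, f"shape mismatch at {key!r}"
            out[key] = leaf
    return out

abstract = {"w": ((2, 2), np.float32), "opt": {"mu": ((2, 2), np.float32)}}
stored = {"w": [[1, 2], [3, 4]], "opt": {"mu": [[0, 0], [0, 0]]}}
restored = restore_like(abstract, stored)
print(restored["w"].dtype)  # → float32
```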
EMA (Exponential Moving Average)#
File containing the Exponential Moving Average (EMA) implementation.
- utils.ema.get_network(module: Module)[source]#
Helper function that recursively traverses modules to find the first object named network.
This function is used in the case where there are multiple loss wrappers around the network, and in Eval / EMA only the network parameters are needed.
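The unwrapping logic can be sketched as a recursive attribute search: stop at the first object reachable under the name network. A minimal pure-Python version (the wrapper classes are hypothetical stand-ins for the project's loss wrappers):

```python
def get_network(module):
    """Recursively search `module` and its attributes for the first object
    stored under the name 'network' (unwraps nested loss wrappers)."""
    if hasattr(module, "network"):
        return module.network
    for value in vars(module).values():
        found = get_network(value) if hasattr(value, "__dict__") else None
        if found is not None:
            return found
    return None

class Net: ...
class LossWrapper:            # hypothetical wrapper exposing .network
    def __init__(self, inner): self.network = inner
class OuterWrapper:           # hypothetical wrapper with no .network
    def __init__(self, inner): self.wrapped = inner

net = Net()
print(get_network(OuterWrapper(LossWrapper(net))) is net)  # → True
```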
Google Cloud Utilities#
File containing helper functions for accessing Google Cloud Storage.
- utils.gcloud_utils.list_directories(bucket_name, prefix)[source]#
List all directories in the given bucket.
- utils.gcloud_utils.count_directories(bucket_name, prefix)[source]#
Count the number of directories in the given bucket. Used to obtain the numeral prefix for the checkpoint.
- utils.gcloud_utils.directory_exists(bucket_name, prefix, directory)[source]#
Check whether the given directory exists under the given bucket.
- utils.gcloud_utils.get_directory_index(bucket_name, prefix, directory)[source]#
Get the index of the given directory under the given bucket.
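The index lookup amounts to collecting the distinct top-level directory names under a prefix, sorting them, and finding the target's position. A sketch over a flat list of object names standing in for iterating GCS blobs (names and paths here are illustrative):

```python
def get_directory_index(blob_names, prefix, directory):
    """Derive the sorted position of `directory` among the top-level
    directories found under `prefix` in a flat object listing."""
    dirs = sorted({
        name[len(prefix):].lstrip("/").split("/")[0]
        for name in blob_names if name.startswith(prefix)
    })
    return dirs.index(directory)

blobs = [
    "runs/0001-baseline/ckpt",
    "runs/0002-ema/ckpt",
    "runs/0003-fsdp/ckpt",
]
print(get_directory_index(blobs, "runs", "0002-ema"))  # → 1
```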
Sharding#
File containing the sharding utils.
- utils.sharding_utils.flatten_state(state: State, path: tuple[str, ...] = ())[source]#
Recursively traverse an NNX VariableState, yielding (path, VariableState).
- utils.sharding_utils.place_like_target(tree, target)[source]#
Place the tree following the sharding of the target.
- utils.sharding_utils.replicate()[source]#
Sharding tactic to fully replicate a parameter (no sharding on any axis).
- utils.sharding_utils.fsdp(axis: str, min_size_to_shard_mb: float = 4)[source]#
Fully Sharded Data Parallel tactic - shard largest available dimension along given mesh axis.
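The tactic's decision rule can be sketched in a few lines: replicate parameters below the size threshold, otherwise shard the largest dimension that divides evenly across the mesh axis. A pure-Python approximation (the function and axis names are illustrative, and float32 leaves are assumed):

```python
import numpy as np

def fsdp_spec(shape, axis_size, min_size_to_shard_mb=4, itemsize=4):
    """Pick the largest dimension divisible by `axis_size` to shard, or
    replicate (all None) if the parameter is below the size threshold."""
    size_mb = itemsize * int(np.prod(shape)) / 2**20
    spec = [None] * len(shape)
    if size_mb < min_size_to_shard_mb:
        return spec  # too small: replicate
    # Among dimensions that divide evenly across the axis, shard the largest.
    candidates = [d for d in range(len(shape)) if shape[d] % axis_size == 0]
    if candidates:
        spec[max(candidates, key=lambda d: shape[d])] = "data"
    return spec

print(fsdp_spec((4096, 1024), axis_size=8))  # → ['data', None]
print(fsdp_spec((16, 16), axis_size=8))      # → [None, None] (replicated)
```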
- utils.sharding_utils.infer_sharding(state: State, strategy: str, mesh: Mesh)[source]#
Infer a sharding specification for an NNX model state based on a regex strategy.
- Parameters:
state – nnx.State (VariableState pytree) of the model's parameters.
strategy – list of (regex_pattern, tactic) pairs. A tactic can be either a string like 'fsdp(axis="X")' or a callable.
mesh – jax.sharding.Mesh defining device mesh axes.
- Returns:
A PyTree with same structure as state, but leaves are nnx.sharding.NamedSharding.
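The strategy lookup itself is a first-match regex scan over each parameter's path. A minimal sketch of that selection step (the paths, patterns, and default tactic are illustrative):

```python
import re

def match_tactic(path, strategy, default="replicate"):
    """Return the tactic of the first regex in `strategy` that matches the
    parameter path; unmatched parameters fall back to `default`."""
    name = "/".join(path)
    for pattern, tactic in strategy:
        if re.search(pattern, name):
            return tactic
    return default

strategy = [
    (r"attn/.*kernel", 'fsdp(axis="data")'),
    (r"bias", "replicate"),
]
print(match_tactic(("layers_0", "attn", "q", "kernel"), strategy))
# → fsdp(axis="data")
```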
- utils.sharding_utils.create_device_mesh(config_mesh: list[tuple[str, int]], *, allow_split_physical_axes: bool = False)[source]#
Returns a JAX device mesh.
- Parameters:
config_mesh – A list of tuples of (axis_name, axis_size). It is advised to sort the axes in increasing order of network communication intensity.
allow_split_physical_axes – Whether to allow splitting physical axes.
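Conceptually, the (axis_name, axis_size) pairs define how the flat list of devices is reshaped into a named grid. A NumPy sketch of that layout step, with integer ids standing in for JAX devices (names here are illustrative):

```python
import numpy as np

def create_device_mesh(config_mesh, num_devices):
    """Reshape a flat list of device ids into a named mesh defined by
    (axis_name, axis_size) pairs; the sizes must multiply to num_devices."""
    names = [name for name, _ in config_mesh]
    sizes = [size for _, size in config_mesh]
    assert int(np.prod(sizes)) == num_devices, "mesh must use every device"
    return names, np.arange(num_devices).reshape(sizes)

names, mesh = create_device_mesh([("data", 2), ("model", 4)], num_devices=8)
print(names, mesh.shape)  # → ['data', 'model'] (2, 4)
```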
- utils.sharding_utils.extract_subtree_sharding(full_sharding, subtree, prefix_to_remove: str = 'model')[source]#
Extracts the sharding of a subtree from a fully-sharded tree.
Note: this function assumes subtree to be a strict subset of full_sharding.
- Parameters:
full_sharding – The fully-sharded tree.
subtree – The subtree whose sharding is to be extracted.
prefix_to_remove – The prefix to remove from the subtree's names to match the full sharding.
- Returns:
The sharding of the subtree.
- utils.sharding_utils.make_fsarray_from_local_slice(local_slice: Array, global_devices: list)[source]#
Create a fully-sharded global device array from local host arrays.
- Parameters:
local_slice – Something convertible to a numpy array (e.g. also a TF tensor) that is this host's slice of the global array.
global_devices – The list of global devices. Needed for consistent ordering.
- Returns:
The global on-device array which consists of all local slices stacked together in the order consistent with the devices.
- utils.sharding_utils.get_local_slice_from_fsarray(global_array: Array)[source]#
Return numpy array for the host-local slice of fully-sharded array.
- Parameters:
global_array – JAX array, globally sharded on devices across hosts (potentially not fully addressable).
- Returns:
NumPy array that holds the part of global_array that is held by the devices on the host that calls this function.
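The relationship between these two functions can be illustrated with NumPy, treating hosts as contiguous slices along the leading axis: local slices concatenate into the global array, and each host reads back only its own slice. A sketch of that round trip (the host-ordering scheme here is a simplifying assumption):

```python
import numpy as np

def stack_local_slices(slices_by_host):
    """Combine per-host slices into a 'global' array along the leading
    axis, mirroring how local slices form a fully-sharded array."""
    return np.concatenate([np.asarray(s) for s in slices_by_host], axis=0)

def local_slice(global_array, host_index, num_hosts):
    """Return the contiguous leading-axis slice owned by one host."""
    per_host = global_array.shape[0] // num_hosts
    return global_array[host_index * per_host:(host_index + 1) * per_host]

g = stack_local_slices([np.zeros((2, 3)), np.ones((2, 3))])
print(local_slice(g, 1, 2))  # the second host's slice: all ones
```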
- utils.sharding_utils.update_model_sharding(graphdef: GraphDef, loaded_state: State, loaded_rng_state: RngKey, ema: EMA, loaded_ema_state: State, mesh: Mesh, sharding_strategy: list[tuple[str, str]])[source]#
Updates the model sharding for optimizer and EMA state.
- Parameters:
graphdef – The graph definition of the optimizer.
loaded_state – The loaded state of the optimizer.
loaded_rng_state – The loaded rng state of the optimizer.
ema – The EMA object.
loaded_ema_state – The loaded state of the EMA.
mesh – The mesh.
sharding_strategy – The sharding strategy.
- Returns:
graphdef: the graph definition of the optimizer. state: the resharded state of the optimizer. ema_graphdef: the graph definition of the EMA. ema_state: the resharded state of the EMA. state_sharding: the sharding of the optimizer state. ema_state_sharding: the sharding of the EMA state.
- Return type:
tuple
Visualization#
File containing utilities for generating visualization.
- utils.visualize.visualize(config: ConfigDict, net: Module, ema_net: Module, encoder: Module, sampler: Samplers, step: int, g_net: Module | None = None, guidance_scale: float | None = None, mesh: Mesh | None = None)[source]#
Generate and log samples from the model.
- Parameters:
config – configuration for the training.
net – nnx.Module, the network for training.
ema_net – nnx.Module, the ema network.
encoder – nnx.Module, the encoder for training.
n – jnp.ndarray, the initial noise.
c – jnp.ndarray, the initial condition.
guidance_scale – float, the guidance weight for the guidance network.
sampler – samplers.Samplers, the sampler for the network.
- utils.visualize.visualize_reconstruction(config: ConfigDict, encoder: Module, x: Array, mesh: Mesh | None = None)[source]#
Reconstruct and log samples from the encoder.
- Parameters:
config – the configuration for the training.
encoder – the encoder for the network.
x – the original samples.
mesh – the mesh for the distributed sampling.