Source code for utils.checkpoint

"""File containing the functions for model checkpointing."""

# built-in libs

# external libs
from absl import logging
from etils import epath
import jax
from flax import nnx
import orbax.checkpoint as ocp

# deps


# TODO: update the following functions to support sharding.
def build_checkpoint_manager(
    ckpt_dir: str,
    *,
    save_interval_steps: int,
    max_to_keep: int,
    keep_period: int,
    step_prefix: str = 'checkpoint',
    enable_async_checkpointing: bool = True,
) -> ocp.CheckpointManager:
  """Create a checkpoint manager for saving and restoring checkpoints during training."""
  options = ocp.CheckpointManagerOptions(
      save_interval_steps=save_interval_steps,  # save every `save_interval_steps` steps
      max_to_keep=max_to_keep,  # keep at most this many recent checkpoints
      step_prefix=step_prefix,
      keep_period=keep_period,  # also keep steps where step % keep_period == 0; useful as long-term backups
      enable_async_checkpointing=enable_async_checkpointing,
  )
  return ocp.CheckpointManager(ckpt_dir, options=options)
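
# Usage sketch for build_checkpoint_manager (the directory and interval
# values below are illustrative assumptions, not defaults of this module):
#
#   mngr = build_checkpoint_manager(
#       '/tmp/run0/checkpoints',
#       save_interval_steps=1_000,
#       max_to_keep=3,
#       keep_period=10_000,
#   )
#   last_step = mngr.latest_step()  # most recent saved step, or None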


def save_checkpoints(
    ckpt_dir: str,
    step: int,
    optimizer_state: nnx.State,
    rng_state: nnx.RngKey,
    ema_state: nnx.State,
    *,
    mngr: ocp.CheckpointManager | None = None,
):
  """Save checkpoints for the model and optimizer state.

  Args:
  - ckpt_dir: checkpoint directory.
  - step: current step.
  - optimizer_state: optimizer state. **Note** this is analogous to Flax's
    TrainState in that it includes both the optimizer state and the model state.
  - rng_state: the current rng key.
  - ema_state: ema state.
  - mngr: optional checkpoint manager; if None, a standalone Checkpointer is used.
  """
  if mngr is None:
    # No persistent manager supplied; fall back to a standalone Checkpointer.
    logging.warning('Checkpoint Manager not supplied; using default Checkpointer instead.')
    ckptr = ocp.Checkpointer(ocp.CompositeCheckpointHandler())
    ckptr.save(
        epath.Path(ckpt_dir) / f'checkpoint_{step}',
        args=ocp.args.Composite(
            state=ocp.args.StandardSave(optimizer_state),
            ema_state=ocp.args.StandardSave(ema_state),
            # jax.random keys are typed; convert to raw key data for serialization.
            rng_state=ocp.args.StandardSave(
                jax.tree.map(jax.random.key_data, rng_state)
            ),
        ),
    )
  else:
    # Persistent manager supplied; let it manage the saving logic.
    mngr.save(
        step,
        args=ocp.args.Composite(
            state=ocp.args.StandardSave(optimizer_state),
            ema_state=ocp.args.StandardSave(ema_state),
            rng_state=ocp.args.StandardSave(
                jax.tree.map(jax.random.key_data, rng_state)
            ),
        ),
    )
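
# Save sketch, assuming an nnx-based training loop; `optimizer`, `ema_model`,
# `model`, and `step` are hypothetical names, not part of this module:
#
#   optimizer_state = nnx.state(optimizer)
#   ema_state = nnx.state(ema_model)
#   rng_state = nnx.state(model, nnx.RngKey)
#   save_checkpoints(
#       '/tmp/run0/checkpoints', step,
#       optimizer_state=optimizer_state,
#       rng_state=rng_state,
#       ema_state=ema_state,
#       mngr=mngr,
#   )
#   mngr.wait_until_finished()  # block until async writes complete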


def restore_checkpoints(
    ckpt_dir: str,
    step: int,
    abstract_optimizer_state: nnx.State,
    abstract_rng_state: nnx.RngKey,
    abstract_ema_state: nnx.State,
    *,
    ema_only: bool = False,
    mngr: ocp.CheckpointManager | None = None,
) -> nnx.State | tuple[nnx.State, nnx.State, nnx.State]:
  """Restore checkpoints for the model and optimizer state.

  Args:
  - ckpt_dir: checkpoint directory.
  - step: step to restore; if None, the abstract states are returned unchanged.
  - abstract_optimizer_state: abstract optimizer state. **Note** this is
    analogous to Flax's TrainState in that it includes both the optimizer
    state and the model state.
  - abstract_rng_state: abstract rng state.
  - abstract_ema_state: abstract ema state.
  - ema_only: if True, restore and return only the ema state.
  - mngr: optional checkpoint manager; if None, a standalone Checkpointer is used.

  Returns:
  - state: restored training state.
  - rng_state: restored rng state, as raw key data.
  - ema_state: restored ema state.
  When ema_only is True, only the restored ema state is returned.
  """
  if step is None:
    return abstract_optimizer_state, abstract_rng_state, abstract_ema_state

  if ema_only:
    ckptr = ocp.Checkpointer(ocp.CompositeCheckpointHandler())
    restore_args = ocp.args.Composite(
        ema_state=ocp.args.StandardRestore(abstract_ema_state['ema']['network'])
    )
    # Note: in this branch ckpt_dir is expected to point at the checkpoint
    # itself, not at the parent directory used by the other branches.
    state_restored = ckptr.restore(
        epath.Path(ckpt_dir),
        args=restore_args,
    )
    state_restored = {'ema': {'network': state_restored.ema_state.to_pure_dict()}}
    nnx.State.replace_by_pure_dict(abstract_ema_state, state_restored)
    return abstract_ema_state
  else:
    restore_args = ocp.args.Composite(
        state=ocp.args.StandardRestore(abstract_optimizer_state),
        ema_state=ocp.args.StandardRestore(abstract_ema_state),
        rng_state=ocp.args.StandardRestore(
            jax.tree.map(jax.random.key_data, abstract_rng_state)
        ),
    )
    if mngr is None:
      # No persistent manager supplied; fall back to a standalone Checkpointer.
      logging.warning('Checkpoint Manager not supplied; using default Checkpointer instead.')
      ckptr = ocp.Checkpointer(ocp.CompositeCheckpointHandler())
      state_restored = ckptr.restore(
          epath.Path(ckpt_dir) / f'checkpoint_{step}',
          args=restore_args,
      )
    else:
      # Persistent manager supplied; let it manage the restore logic.
      state_restored = mngr.restore(
          step,
          args=restore_args,
      )
    # The restored rng_state holds raw key data (keys were converted via
    # jax.random.key_data at save time); callers needing typed keys should
    # re-wrap it with jax.random.wrap_key_data.
    return state_restored.state, state_restored.rng_state, state_restored.ema_state
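
# Restore sketch: build abstract (shape/dtype-only) targets with
# nnx.eval_shape over the same constructors used at init time;
# `create_optimizer`, `create_ema_model`, and `create_model` are
# hypothetical names, not part of this module:
#
#   abs_optimizer_state = nnx.state(nnx.eval_shape(create_optimizer))
#   abs_ema_state = nnx.state(nnx.eval_shape(create_ema_model))
#   abs_rng_state = nnx.state(nnx.eval_shape(create_model), nnx.RngKey)
#   state, rng_state, ema_state = restore_checkpoints(
#       '/tmp/run0/checkpoints', mngr.latest_step(),
#       abs_optimizer_state, abs_rng_state, abs_ema_state,
#       mngr=mngr,
#   )
#   # rng_state comes back as raw key data; re-wrap if typed keys are needed:
#   rng_state = jax.tree.map(jax.random.wrap_key_data, rng_state)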