"""File containing the functions for model checkpointing."""
# built-in libs
# external libs
from absl import logging
from etils import epath
import jax
import jax.numpy as jnp
import flax
from flax import nnx
import orbax.checkpoint as ocp
# deps
# TODO: update the following functions to support sharding.
def build_checkpoint_manager(
ckpt_dir: str,
*,
save_interval_steps: int,
max_to_keep: int,
keep_period: int,
step_prefix: str = 'checkpoint',
enable_async_checkpointing: bool = True,
) -> ocp.CheckpointManager:
"""Create a checkpoint manager for saving and restoring checkpoints during training."""
    options = ocp.CheckpointManagerOptions(
        save_interval_steps=save_interval_steps,  # save a checkpoint every `save_interval_steps` steps
        max_to_keep=max_to_keep,  # keep at most this many recent checkpoints
        step_prefix=step_prefix,
        keep_period=keep_period,  # also keep steps where step % keep_period == 0, as long-term backups
        enable_async_checkpointing=enable_async_checkpointing,
    )
return ocp.CheckpointManager(ckpt_dir, options=options)
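# Example usage (a minimal sketch; the directory path and hyperparameter values
# below are illustrative, not taken from any particular training config):
#
#   mngr = build_checkpoint_manager(
#       '/tmp/run_0/checkpoints',
#       save_interval_steps=1_000,  # write a checkpoint every 1k steps
#       max_to_keep=3,              # keep only the 3 most recent checkpoints
#       keep_period=10_000,         # ...but never garbage-collect multiples of 10k
#   )
#   # With enable_async_checkpointing=True (the default here), call
#   # mngr.wait_until_finished() before exiting so the last save is committed.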
def save_checkpoints(
ckpt_dir: str,
step: int,
optimizer_state: nnx.State,
rng_state: nnx.RngKey,
ema_state: nnx.State,
*,
mngr: ocp.CheckpointManager | None = None,
):
"""Save checkpoints for model and optimizer state.
Args:
- ckpt_dir: checkpoint directory.
- step: current step.
- optimizer_state: optimizer state. **Note** this is an analogy to Flax.TrainState,
which includes both opt_state & model_state
- rng_state: the current rng key.
- ema_state: ema state.
"""
if mngr is None:
        # Persistent manager not supplied; fall back to a standalone Checkpointer.
logging.warning('Checkpoint Manager not supplied; using default Checkpointer instead.')
ckptr = ocp.Checkpointer(ocp.CompositeCheckpointHandler())
ckptr.save(
epath.Path(ckpt_dir) / f'checkpoint_{step}',
args=ocp.args.Composite(
state=ocp.args.StandardSave(optimizer_state),
ema_state=ocp.args.StandardSave(ema_state),
rng_state=ocp.args.StandardSave(
jax.tree.map(jax.random.key_data, rng_state)
),
)
)
else:
        # Persistent manager supplied; let it manage the saving logic.
mngr.save(
step,
args=ocp.args.Composite(
state=ocp.args.StandardSave(optimizer_state),
ema_state=ocp.args.StandardSave(ema_state),
rng_state=ocp.args.StandardSave(
jax.tree.map(jax.random.key_data, rng_state)
),
)
)
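# Example usage (a sketch; `model`, `optimizer`, and `ema` are hypothetical nnx
# objects created elsewhere, e.g. optimizer = nnx.Optimizer(model, tx)):
#
#   optimizer_state = nnx.state(optimizer)    # model params + opt_state, like a TrainState
#   rng_state = nnx.state(model, nnx.RngKey)  # only the rng-key leaves
#   ema_state = nnx.state(ema)
#   save_checkpoints(ckpt_dir, step, optimizer_state, rng_state, ema_state, mngr=mngr)
#   # When `mngr` is supplied, mngr.save() only writes on steps permitted by
#   # save_interval_steps, so this can safely be called every step.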
def restore_checkpoints(
ckpt_dir: str,
step: int,
abstract_optimizer_state: nnx.State,
abstract_rng_state: nnx.RngKey,
abstract_ema_state: nnx.State,
*,
ema_only: bool = False,
mngr: ocp.CheckpointManager | None = None,
) -> nnx.State | tuple[nnx.State, nnx.RngKey, nnx.State]:
"""Restore checkpoints for model and optimizer state.
Args:
- ckpt_dir: checkpoint directory.
- step: current step.
- optimizer_state: abstract optimizer state. **Note** this is an analogy to Flax.TrainState,
which includes both opt_state & model_state
- rng_state: abstract rng state
- ema_state: abstract ema state.
Return:
- state: restored training state.
- ema_state: restored ema state.
"""
if step is None:
return abstract_optimizer_state, abstract_rng_state, abstract_ema_state
if ema_only:
ckptr = ocp.Checkpointer(ocp.CompositeCheckpointHandler())
restore_args = ocp.args.Composite(
ema_state=ocp.args.StandardRestore(abstract_ema_state['ema']['network'])
)
state_restored = ckptr.restore(
epath.Path(ckpt_dir),
args=restore_args
)
state_restored = {'ema': {'network': state_restored.ema_state.to_pure_dict()}}
nnx.State.replace_by_pure_dict(abstract_ema_state, state_restored)
return abstract_ema_state
else:
restore_args = ocp.args.Composite(
state=ocp.args.StandardRestore(abstract_optimizer_state),
ema_state=ocp.args.StandardRestore(abstract_ema_state),
rng_state=ocp.args.StandardRestore(
jax.tree.map(jax.random.key_data, abstract_rng_state)
),
)
if mngr is None:
            # Persistent manager not supplied; fall back to a standalone Checkpointer.
logging.warning('Checkpoint Manager not supplied; using default Checkpointer instead.')
ckptr = ocp.Checkpointer(ocp.CompositeCheckpointHandler())
state_restored = ckptr.restore(
epath.Path(ckpt_dir) / f'checkpoint_{step}',
args=restore_args
)
else:
            # Persistent manager supplied; let it manage the restore logic.
state_restored = mngr.restore(
step,
args=restore_args
)
return state_restored.state, state_restored.rng_state, state_restored.ema_state
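# Example usage (a sketch; `build_model_and_optimizer` is a hypothetical factory.
# nnx.eval_shape gives abstract states carrying only shapes/dtypes, so no real
# arrays are materialised before restoring):
#
#   abstract_model, abstract_optimizer = nnx.eval_shape(build_model_and_optimizer)
#   abstract_optimizer_state = nnx.state(abstract_optimizer)
#   abstract_rng_state = nnx.state(abstract_model, nnx.RngKey)
#   abstract_ema_state = nnx.state(abstract_ema)  # built the same way as above
#   state, rng_state, ema_state = restore_checkpoints(
#       ckpt_dir,
#       mngr.latest_step() if mngr is not None else step,
#       abstract_optimizer_state,
#       abstract_rng_state,
#       abstract_ema_state,
#       mngr=mngr,
#   )
#   # Push the restored values back into live objects with nnx.update, e.g.
#   # nnx.update(optimizer, state). The restored rng_state holds raw key data
#   # (the keys were unwrapped with jax.random.key_data at save time); wrap it
#   # back with jax.random.wrap_key_data before reuse.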