FSDP in JAX NNX
If you face the daunting task of implementing production-level FSDP (Fully Sharded Data Parallel) training in JAX NNX, this tutorial is for you. This notebook guides you through the process step by step.
You will learn how to implement fully working FSDP on TPU, with all critical operations JIT-compiled. It shards the large weights evenly across the devices (small ones stay replicated) and combines this with data parallelism (DDP). You will also see how to:
- use Orbax distributed checkpointing to save to and restore from local disk or a Google Cloud Storage (GCS) bucket;
- set up reproducible nnx.Rngs for noise generation and dropout;
- maintain an EMA model that is also sharded.
Let’s begin.
First, let’s set a configuration flag. It determines which packages we install.
[ ]:
COLAB=True # Set this to False if you are running this notebook outside of Google Colab
Install the Python dependencies based on the flag above. Outside Colab, a few extra packages are pinned as well.
[2]:
import sys
import subprocess
packages = ["jax[tpu]==0.5.1", "optax==0.2.4", "orbax-checkpoint==0.11.16", "flax==0.10.4"]
if not COLAB:
    packages += ["numpy==1.26.4", "torch==2.7.0", "matplotlib==3.10.3", "pillow==11.3.0", "gcsfs==2025.9.0"]
print(f"Installing {packages} ...")
subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])
Installing ['jax[tpu]==0.5.1', 'optax==0.2.4', 'orbax-checkpoint==0.11.16', 'flax==0.10.4', 'numpy==1.26.4', 'torch==2.7.0', 'matplotlib==3.10.3', 'pillow==11.3.0', 'gcsfs==2025.9.0'] ...
[3]:
import argparse
import functools
import logging
import os
from typing import Any, Generator, Tuple
import jax
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp
import torch
from flax import nnx
from jax import random
from jax.experimental import mesh_utils
from matplotlib.figure import Figure
from torch.utils.data import DataLoader, Dataset
Here, we define the hyperparameters and other settings we are most likely to adjust.
[ ]:
args = argparse.Namespace(
    experiment_name="fsdp",
    gpu=False,
    steps=5_000,
    test_interval=1000,
    batch_size=256,
    log_interval=100,
    save_interval=2500,
    checkpoint_dir=os.path.abspath("checkpoints/"),
    output_dir=os.path.abspath("outputs/"),
    lr=1e-4,
    add_noise=False,
)
Enabling INFO level logging.
[5]:
log_format = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(
    level=logging.INFO,
    format=log_format,
    handlers=[logging.StreamHandler()],
    force=True,
)
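At the very beginning of the program, we initialize the JAX distributed runtime. jax.distributed.initialize() must be called before any JAX computation. On Cloud TPU it detects the coordinator address and the number of processes automatically, so no arguments are needed.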
[6]:
jax.distributed.initialize()
INFO:2025-10-05 20:32:00,707:jax._src.distributed:130: Starting JAX distributed service on [::]:8476
2025-10-05 20:32:00,707 - INFO - Starting JAX distributed service on [::]:8476
INFO:2025-10-05 20:32:00,709:jax._src.distributed:147: Connecting to JAX distributed service on 10.202.0.129:8476
2025-10-05 20:32:00,709 - INFO - Connecting to JAX distributed service on 10.202.0.129:8476
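A quick way to confirm that initialization succeeded is to print how many processes joined the runtime and how many devices each one can address. This sketch uses only standard JAX introspection functions. On a single TPU VM like the one in this notebook, you would expect one process and four local devices. On a multi-host slice, the global device count spans all hosts.

```python
import jax

# Run right after jax.distributed.initialize().
# process_index() identifies this host; process_count() is the number of hosts.
print("process index: ", jax.process_index())
print("process count: ", jax.process_count())
# Devices this process can address directly vs. devices across all processes.
print("local devices: ", jax.local_device_count())
print("global devices:", jax.device_count())
```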
Here you should see the devices available to you. In my case, there are four TPU chips.
[7]:
jax.devices()
[7]:
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
Now let’s define our JAX mesh and sharding axis. We have only one axis, data, and we use it for both the model and the data. JAX shards the large model parameters across all devices in the mesh and also splits each batch across all devices. Concretely, with N devices in the mesh, each device stores 1/N of every sharded weight and processes 1/N of the global batch. Small weights stay replicated, as we will see below.
[8]:
data_axis = "data"
device_mesh = mesh_utils.create_device_mesh(
    (jax.device_count(),), devices=jax.devices()
)
mesh = jax.sharding.Mesh(device_mesh, (data_axis,))
Here, we define two kinds of sharding:
- one that splits an array’s leading dimension along the data axis;
- one that keeps a full copy of the array on every device (replication).
[9]:
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec(data_axis)
)
repl_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec()
)
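To see the difference between the two shardings, you can place a small array with each of them and inspect the per-device shards. This self-contained sketch rebuilds an equivalent one-axis mesh. The batch contents and the per-device size of 8 rows are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Equivalent one-axis mesh over all devices, named "data".
mesh = Mesh(np.array(jax.devices()), ("data",))
data_sharding = NamedSharding(mesh, PartitionSpec("data"))
repl_sharding = NamedSharding(mesh, PartitionSpec())

n = jax.device_count()
per_device = 8  # illustrative per-device batch size
batch = jnp.arange(per_device * n, dtype=jnp.float32).reshape(per_device * n, 1)

# Split the leading (batch) dimension across devices...
sharded = jax.device_put(batch, data_sharding)
# ...or keep a full copy on every device.
replicated = jax.device_put(batch, repl_sharding)

for shard in sharded.addressable_shards:
    print("sharded   ", shard.device, shard.data.shape)
for shard in replicated.addressable_shards:
    print("replicated", shard.device, shard.data.shape)
```

With four devices, each sharded piece has shape (8, 1), while each replicated piece has the full (32, 1) shape.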
In this tutorial, we use a simple MLP as our model. It learns a mapping from a real number to a real number (a scalar function). It also has a dropout layer, so we can show how the implementation handles RNGs, which any real-world model would require.
[10]:
IN_FEATURES = 1
OUT_FEATURES = 1
HIDDEN_DIM = 1024
[11]:
class MLP(nnx.Module):
    """A Multi-Layer Perceptron (MLP) neural network using Flax NNX.

    This is a simple feedforward neural network with two hidden layers,
    ReLU activations, and dropout regularization.

    Args:
        din: Number of input features.
        dmid: Number of hidden units in each hidden layer.
        dout: Number of output features.
        rngs: Random number generators for parameter initialization and dropout.
    """

    def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs) -> None:
        """Initialize the MLP with specified dimensions.

        Args:
            din: Number of input features.
            dmid: Number of hidden units in each hidden layer.
            dout: Number of output features.
            rngs: Random number generators for parameter initialization and dropout.
        """
        self.fc1 = nnx.Linear(din, dmid, rngs=rngs)
        self.fc2 = nnx.Linear(dmid, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.fc3 = nnx.Linear(dmid, dout, rngs=rngs)
        self.rngs = rngs

    def __call__(self, x: jax.Array) -> jax.Array:
        """Forward pass through the MLP.

        Args:
            x: Input tensor of shape (batch_size, din).

        Returns:
            Output tensor of shape (batch_size, dout).
        """
        x = self.fc1(x)
        x = nnx.relu(x)
        x = self.fc2(x)
        x = nnx.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)
        return x
Many modern codebases maintain an EMA (Exponential Moving Average) of the model weights. We keep one as well to show how it fits into our FSDP implementation. To initialize the EMA, we take the model state and replace every leaf with zeros. This gives a state tree with exactly the same structure and shapes as the model state.
[12]:
def init_ema(model: nnx.Module) -> nnx.State:
    """Initialize exponential moving average (EMA) state for a model.

    Creates a zero-initialized state tree with the same structure as the model's state.

    Args:
        model: The neural network model to create EMA state for.

    Returns:
        EMA state with the same structure as the model state, but zero-initialized.
    """
    ema_state = jax.tree.map(lambda x: jnp.zeros_like(x), nnx.state(model))
    return ema_state
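For context, here is a minimal sketch of how a zero-initialized EMA is usually updated: ema ← decay · ema + (1 − decay) · params. Because the average starts from zero, it is biased toward zero early in training. Dividing by 1 − decay^t (the same correction Adam applies to its moment estimates) removes that bias.

The functions ema_update() and ema_debiased() are hypothetical names, and the sketch uses a plain dict of arrays. In this notebook’s EMA state, the RNG leaves should not be averaged. Restricting the update to nnx.Param leaves is one way to handle that, but that choice is an assumption here.

```python
import jax
import jax.numpy as jnp


def ema_update(ema, params, decay=0.999):
    """Hypothetical EMA step: ema <- decay * ema + (1 - decay) * params."""
    return jax.tree.map(lambda e, p: decay * e + (1.0 - decay) * p, ema, params)


def ema_debiased(ema, decay, step):
    """Remove the zero-initialization bias after `step` updates."""
    return jax.tree.map(lambda e: e / (1.0 - decay**step), ema)


params = {"w": jnp.ones((2, 2)), "b": jnp.full((2,), 3.0)}
ema = jax.tree.map(jnp.zeros_like, params)  # zero init, like init_ema()

decay = 0.9
for _ in range(5):
    ema = ema_update(ema, params, decay)

# After 5 steps with constant params, ema["w"] is 1 - 0.9**5 (about 0.41);
# the debiased value recovers the true weights.
print(ema["w"], ema_debiased(ema, decay, 5)["w"])
```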
This is the core initialization function. It creates everything we want to shard with FSDP: the model, the optimizer, and the EMA state.
[13]:
def init(learning_rate: float) -> Tuple[nnx.GraphDef, nnx.State, nnx.State]:
    """Initialize the model, optimizer, and EMA state.

    Creates a new MLP model, wraps it in an AdamW optimizer, and initializes
    the exponential moving average state.

    Args:
        learning_rate: Learning rate for the AdamW optimizer.

    Returns:
        Tuple of (optimizer_graph, optimizer_state, ema_state).
    """
    model = MLP(
        IN_FEATURES,
        HIDDEN_DIM,
        OUT_FEATURES,
        rngs=nnx.Rngs(0, dropout=random.key(1), noise=random.key(2)),
    )
    opt = nnx.Optimizer(
        model,
        optax.adamw(learning_rate=learning_rate),
    )
    opt_graph, opt_state = nnx.split(opt)
    ema_state = init_ema(model)
    return opt_graph, opt_state, ema_state
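It is convenient to turn init() into a function with no arguments by binding the learning rate with functools.partial. We then pass it to jax.eval_shape(), which traces the function abstractly. This returns the shapes and dtypes of the states we want to shard without allocating any device memory.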
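As a standalone illustration of the same pattern, the sketch below evaluates a hypothetical make_params() function abstractly. The 1024 × 1024 kernel is never allocated; eval_shape returns jax.ShapeDtypeStruct placeholders instead. Binding the integer with functools.partial keeps it a concrete Python value during tracing. eval_shape treats its own positional arguments as abstract arrays, and an abstract value cannot be used as an array shape.

```python
import functools

import jax
import jax.numpy as jnp


def make_params(hidden: int):
    # Hypothetical stand-in for init(): one large kernel and one bias.
    return {"kernel": jnp.ones((hidden, hidden)), "bias": jnp.zeros((hidden,))}


# Trace without allocating: the result holds only shapes and dtypes.
shapes = jax.eval_shape(functools.partial(make_params, 1024))
print(shapes["kernel"])
print(shapes["bias"])
```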
[14]:
init_fn = functools.partial(init, args.lr)
_, opt_state_shape, ema_state_shape = jax.eval_shape(init_fn)
logging.info(f"Opt state shape: {opt_state_shape}")
logging.info(f"EMA state shape: {ema_state_shape}")
2025-10-05 20:32:08,664 - INFO - Opt state shape: State({
'model': {
'dropout': {
'rngs': {
'default': {
'count': VariableState( # 1 (4 B)
type=RngCount,
value=ShapeDtypeStruct(shape=(), dtype=uint32),
tag='default'
),
'key': VariableState( # 1 (8 B)
type=RngKey,
value=ShapeDtypeStruct(shape=(), dtype=key<fry>),
tag='default'
)
},
'dropout': {
'count': VariableState( # 1 (4 B)
type=RngCount,
value=ShapeDtypeStruct(shape=(), dtype=uint32),
tag='dropout'
),
'key': VariableState( # 1 (8 B)
type=RngKey,
value=ShapeDtypeStruct(shape=(), dtype=key<fry>),
tag='dropout'
)
},
'noise': {
'count': VariableState( # 1 (4 B)
type=RngCount,
value=ShapeDtypeStruct(shape=(), dtype=uint32),
tag='noise'
),
'key': VariableState( # 1 (8 B)
type=RngKey,
value=ShapeDtypeStruct(shape=(), dtype=key<fry>),
tag='noise'
)
}
}
},
'fc1': {
'bias': VariableState( # 1,024 (4.1 KB)
type=Param,
value=ShapeDtypeStruct(shape=(1024,), dtype=float32)
),
'kernel': VariableState( # 1,024 (4.1 KB)
type=Param,
value=ShapeDtypeStruct(shape=(1, 1024), dtype=float32)
)
},
'fc2': {
'bias': VariableState( # 1,024 (4.1 KB)
type=Param,
value=ShapeDtypeStruct(shape=(1024,), dtype=float32)
),
'kernel': VariableState( # 1,048,576 (4.2 MB)
type=Param,
value=ShapeDtypeStruct(shape=(1024, 1024), dtype=float32)
)
},
'fc3': {
'bias': VariableState( # 1 (4 B)
type=Param,
value=ShapeDtypeStruct(shape=(1,), dtype=float32)
),
'kernel': VariableState( # 1,024 (4.1 KB)
type=Param,
value=ShapeDtypeStruct(shape=(1024, 1), dtype=float32)
)
}
},
'opt_state': {
0: {
'count': VariableState( # 1 (4 B)
type=OptArray,
value=ShapeDtypeStruct(shape=(), dtype=int32)
),
'mu': {
'fc1': {
'bias': VariableState( # 1,024 (4.1 KB)
type=OptVariable,
value=ShapeDtypeStruct(shape=(1024,), dtype=float32),
source_type=Param
),
'kernel': VariableState( # 1,024 (4.1 KB)
type=OptVariable,
value=ShapeDtypeStruct(shape=(1, 1024), dtype=float32),
source_type=Param
)
},
'fc2': {
'bias': VariableState( # 1,024 (4.1 KB)
type=OptVariable,
value=ShapeDtypeStruct(shape=(1024,), dtype=float32),
source_type=Param
),
'kernel': VariableState( # 1,048,576 (4.2 MB)
type=OptVariable,
value=ShapeDtypeStruct(shape=(1024, 1024), dtype=float32),
source_type=Param
)
},
'fc3': {
'bias': VariableState( # 1 (4 B)
type=OptVariable,
value=ShapeDtypeStruct(shape=(1,), dtype=float32),
source_type=Param
),
'kernel': VariableState( # 1,024 (4.1 KB)
type=OptVariable,
value=ShapeDtypeStruct(shape=(1024, 1), dtype=float32),
source_type=Param
)
}
},
'nu': {
'fc1': {
'bias': VariableState( # 1,024 (4.1 KB)
type=OptVariable,
value=ShapeDtypeStruct(shape=(1024,), dtype=float32),
source_type=Param
),
'kernel': VariableState( # 1,024 (4.1 KB)
type=OptVariable,
value=ShapeDtypeStruct(shape=(1, 1024), dtype=float32),
source_type=Param
)
},
'fc2': {
'bias': VariableState( # 1,024 (4.1 KB)
type=OptVariable,
value=ShapeDtypeStruct(shape=(1024,), dtype=float32),
source_type=Param
),
'kernel': VariableState( # 1,048,576 (4.2 MB)
type=OptVariable,
value=ShapeDtypeStruct(shape=(1024, 1024), dtype=float32),
source_type=Param
)
},
'fc3': {
'bias': VariableState( # 1 (4 B)
type=OptVariable,
value=ShapeDtypeStruct(shape=(1,), dtype=float32),
source_type=Param
),
'kernel': VariableState( # 1,024 (4.1 KB)
type=OptVariable,
value=ShapeDtypeStruct(shape=(1024, 1), dtype=float32),
source_type=Param
)
}
}
}
},
'step': VariableState( # 1 (4 B)
type=OptState,
value=ShapeDtypeStruct(shape=(), dtype=uint32)
)
})
2025-10-05 20:32:08,665 - INFO - EMA state shape: State({
'dropout': {
'rngs': {
'default': {
'count': VariableState( # 1 (4 B)
type=RngCount,
value=ShapeDtypeStruct(shape=(), dtype=uint32),
tag='default'
),
'key': VariableState( # 1 (8 B)
type=RngKey,
value=ShapeDtypeStruct(shape=(), dtype=key<fry>),
tag='default'
)
},
'dropout': {
'count': VariableState( # 1 (4 B)
type=RngCount,
value=ShapeDtypeStruct(shape=(), dtype=uint32),
tag='dropout'
),
'key': VariableState( # 1 (8 B)
type=RngKey,
value=ShapeDtypeStruct(shape=(), dtype=key<fry>),
tag='dropout'
)
},
'noise': {
'count': VariableState( # 1 (4 B)
type=RngCount,
value=ShapeDtypeStruct(shape=(), dtype=uint32),
tag='noise'
),
'key': VariableState( # 1 (8 B)
type=RngKey,
value=ShapeDtypeStruct(shape=(), dtype=key<fry>),
tag='noise'
)
}
}
},
'fc1': {
'bias': VariableState( # 1,024 (4.1 KB)
type=Param,
value=ShapeDtypeStruct(shape=(1024,), dtype=float32)
),
'kernel': VariableState( # 1,024 (4.1 KB)
type=Param,
value=ShapeDtypeStruct(shape=(1, 1024), dtype=float32)
)
},
'fc2': {
'bias': VariableState( # 1,024 (4.1 KB)
type=Param,
value=ShapeDtypeStruct(shape=(1024,), dtype=float32)
),
'kernel': VariableState( # 1,048,576 (4.2 MB)
type=Param,
value=ShapeDtypeStruct(shape=(1024, 1024), dtype=float32)
)
},
'fc3': {
'bias': VariableState( # 1 (4 B)
type=Param,
value=ShapeDtypeStruct(shape=(1,), dtype=float32)
),
'kernel': VariableState( # 1,024 (4.1 KB)
type=Param,
value=ShapeDtypeStruct(shape=(1024, 1), dtype=float32)
)
}
})
As you can see, the EMA state has the same parameters as the model stored inside the optimizer state (under 'model'). It also contains the RNG state (more on that later).
The next three functions annotate the state shapes with sharding information based on your mesh. For each array, they decide whether it should be sharded:
- only arrays with at least 2**20 = 1024*1024 elements are sharded;
- they are sharded along the largest dimension that is divisible by the number of devices on the mesh axis.

This step does not move any data. It only produces a NamedSharding that records how the array will be split into shards and assigned to devices. We will see what that looks like below.
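To make the size threshold concrete, here is the per-device arithmetic for this notebook’s MLP. It is a small standalone sketch that re-applies the same rule to the parameter shapes printed in the log above. NUM_DEVICES = 4 matches the four TPU chips shown earlier, and float32 is assumed. Only fc2/kernel (exactly 2**20 elements) qualifies, so the per-device weight footprint drops from about 4.2 MB to about 1.06 MB. Adam’s mu and nu follow the same rule.

```python
import math

# Parameter shapes copied from the eval_shape() log above.
shapes = {
    "fc1/kernel": (1, 1024), "fc1/bias": (1024,),
    "fc2/kernel": (1024, 1024), "fc2/bias": (1024,),
    "fc3/kernel": (1024, 1), "fc3/bias": (1,),
}
MIN_SIZE_TO_SHARD = 2**20  # same default as infer_sharding()
NUM_DEVICES = 4            # size of the "data" axis in this notebook
BYTES_PER_PARAM = 4        # float32


def per_device_bytes(shape):
    size = math.prod(shape)
    # Sharded only if large enough and some dimension divides evenly.
    if size >= MIN_SIZE_TO_SHARD and any(d % NUM_DEVICES == 0 for d in shape):
        return size * BYTES_PER_PARAM // NUM_DEVICES
    return size * BYTES_PER_PARAM  # replicated: a full copy on each device


for name, shape in shapes.items():
    print(f"{name:12s} {per_device_bytes(shape):>9,d} bytes per device")
print("total:", sum(per_device_bytes(s) for s in shapes.values()))
```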
[15]:
def fsdp(
    axis: str,
    cur_spec: Tuple[Any, ...],
    mesh: jax.sharding.Mesh,
    var_state: nnx.VariableState,
    min_size_to_shard: int,
) -> Tuple[Any, ...]:
    """Implement Fully Sharded Data Parallel (FSDP) sharding strategy.

    Determines how to shard a parameter tensor across devices. Shards the largest
    dimension that is divisible by the number of devices and meets the minimum size requirement.

    Args:
        axis: Name of the mesh axis to shard along.
        cur_spec: Current partition specification.
        mesh: JAX device mesh.
        var_state: Variable state containing the parameter tensor.
        min_size_to_shard: Minimum tensor size to consider for sharding.

    Returns:
        Updated partition specification with sharding applied if appropriate.
    """
    arr = var_state.value
    if arr is None:
        return cur_spec
    shape = tuple(arr.shape)
    axis_size = mesh.shape[axis]
    # Small tensors stay replicated.
    if arr.size < min_size_to_shard:
        return cur_spec
    # Try dimensions from largest to smallest; shard the first one that
    # is not already sharded and divides evenly across the mesh axis.
    dim_indices = sorted(range(len(shape)), key=lambda i: shape[i], reverse=True)
    for i in dim_indices:
        if cur_spec[i] is None and shape[i] % axis_size == 0:
            new_spec = list(cur_spec)
            new_spec[i] = axis
            return tuple(new_spec)
    return cur_spec


def flatten_state(
    state: nnx.State, path: Tuple[str, ...] = ()
) -> Generator[Tuple[str, nnx.VariableState], None, None]:
    """Recursively flatten a nested state tree into (name, variable_state) pairs.

    Traverses the state tree and yields each variable with its hierarchical path name.

    Args:
        state: The state tree to flatten (can be nested).
        path: Current path in the hierarchy (used for recursion).

    Yields:
        Tuples of (path_name, variable_state) for each leaf variable.
    """
    if isinstance(state, nnx.VariableState):
        name = "/".join(str(p) for p in path)
        yield name, state
    elif hasattr(state, "items"):
        for key, subtree in state.items():
            yield from flatten_state(subtree, path + (key,))
    elif isinstance(state, (list, tuple)):
        for idx, subtree in enumerate(state):
            yield from flatten_state(subtree, path + (str(idx),))


def infer_sharding(
    state: nnx.State,
    mesh: jax.sharding.Mesh,
    axis: str,
    min_size_to_shard: int = 2**20,
) -> nnx.State:
    """Infer optimal sharding strategy for a model state using FSDP.

    Analyzes each parameter in the state and determines the best sharding strategy
    based on tensor size and dimensions. Creates a sharding tree that matches
    the structure of the input state.

    Args:
        state: Model state to create sharding for.
        mesh: JAX device mesh for distributed computation.
        axis: Name of the mesh axis for sharding.
        min_size_to_shard: Minimum tensor size to consider for sharding.

    Returns:
        Sharding tree with the same structure as the input state.
    """
    flat_params = list(flatten_state(state))
    vars_states = [vs for _, vs in flat_params]
    # Start fully replicated: one None per dimension (scalars get an empty spec).
    specs = [
        (None,) * vs.value.ndim if vs.value is not None else () for vs in vars_states
    ]
    for i, _ in enumerate(flat_params):
        specs[i] = fsdp(axis, specs[i], mesh, vars_states[i], min_size_to_shard)
    shardings = [
        jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec))
        for spec in specs
    ]
    # tree_unflatten assigns shardings in pytree leaf order, so this relies on
    # flatten_state() visiting leaves in the same order as JAX's flattening.
    sharding_tree = jax.tree_util.tree_unflatten(
        jax.tree_util.tree_structure(
            state, is_leaf=lambda x: isinstance(x, nnx.VariableState)
        ),
        shardings,
    )
    return sharding_tree
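A sharding tree like the one infer_sharding() returns must mirror the structure of the state it describes. One common way to consume it is to pass it as out_shardings to jax.jit, so XLA places each output directly in its target layout. That way you never build the full, unsharded state yourself and reshard it afterwards.

The sketch below shows the idea on a plain dict with a hypothetical make_params() initializer and a hand-written sharding tree. It is not the notebook’s own initialization code.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("data",))
n = jax.device_count()


def make_params():
    # Hypothetical initializer: one large kernel and one small bias.
    key = jax.random.key(0)
    return {
        "kernel": jax.random.normal(key, (256 * n, 128)),
        "bias": jnp.zeros((128,)),
    }


# Same structure as the params: shard the big kernel's first dimension,
# replicate the bias (what infer_sharding() would do for a large kernel).
shardings = {
    "kernel": NamedSharding(mesh, PartitionSpec("data", None)),
    "bias": NamedSharding(mesh, PartitionSpec()),
}

# Outputs come back already laid out according to `shardings`.
params = jax.jit(make_params, out_shardings=shardings)()
print(params["kernel"].sharding)
```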
Here, we call the top-level function infer_sharding() to get sharding trees for both the optimizer state and the EMA state.
[16]:
opt_state_sharding = infer_sharding(opt_state_shape, mesh, data_axis)
ema_state_sharding = infer_sharding(ema_state_shape, mesh, data_axis)
logging.info(f"Opt state sharding: {opt_state_sharding}")
logging.info(f"EMA state sharding: {ema_state_sharding}")
2025-10-05 20:32:08,680 - INFO - Opt state sharding: State({
'model': {
'dropout': {
'rngs': {
'default': {
'count': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device),
'key': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device)
},
'dropout': {
'count': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device),
'key': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device)
},
'noise': {
'count': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device),
'key': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device)
}
}
},
'fc1': {
'bias': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),
'kernel': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, None), memory_kind=device)
},
'fc2': {
'bias': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),
'kernel': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec('data', None), memory_kind=device)
},
'fc3': {
'bias': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),
'kernel': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, None), memory_kind=device)
}
},
'opt_state': {
0: {
'count': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device),
'mu': {
'fc1': {
'bias': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),
'kernel': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, None), memory_kind=device)
},
'fc2': {
'bias': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),
'kernel': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec('data', None), memory_kind=device)
},
'fc3': {
'bias': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),
'kernel': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, None), memory_kind=device)
}
},
'nu': {
'fc1': {
'bias': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),
'kernel': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, None), memory_kind=device)
},
'fc2': {
'bias': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),
'kernel': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec('data', None), memory_kind=device)
},
'fc3': {
'bias': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),
'kernel': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, None), memory_kind=device)
}
}
}
},
'step': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device)
})
2025-10-05 20:32:08,681 - INFO - EMA state sharding: State({
'dropout': {
'rngs': {
'default': {
'count': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device),
'key': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device)
},
'dropout': {
'count': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device),
'key': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device)
},
'noise': {
'count': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device),
'key': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device)
}
}
},
'fc1': {
'bias': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),
'kernel': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, None), memory_kind=device)
},
'fc2': {
'bias': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),
'kernel': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec('data', None), memory_kind=device)
},
'fc3': {
'bias': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),
'kernel': NamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, None), memory_kind=device)
}
})
As you can see, every array except the kernel in fc2 is left unsharded. The weight kernels get PartitionSpec(None, None), the biases get PartitionSpec(None,), and the scalar RNG and step counters get PartitionSpec(). This is what we want, because these arrays have fewer than 1024*1024 = 2**20 elements (min_size_to_shard). Only the fc2 kernel reaches the threshold and gets PartitionSpec('data', None). It maps from 1024 to 1024 features and therefore has exactly 2**20 elements. The same spec applies to its mu and nu optimizer moments. This means its (1024, 1024) weight array is sharded along the first dimension over the data mesh axis. The second dimension stays unsharded, because sharding a single dimension is enough to distribute the parameters evenly across the devices.
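To see what PartitionSpec('data', None) means for the fc2 kernel in practice, here is a NumPy-only sketch of the arithmetic. It assumes float32 weights (the dtype is our assumption) and the 4-device mesh shown in the log above. Each device ends up with a contiguous 256-row block of the kernel.

```python
import numpy as np

num_devices = 4  # size of the 'data' mesh axis in this notebook's run
kernel = np.zeros((1024, 1024), dtype=np.float32)  # stand-in for the fc2 kernel

# PartitionSpec('data', None): split the rows across devices, keep all columns.
shards = np.split(kernel, num_devices, axis=0)
rows_per_device = kernel.shape[0] // num_devices  # 1024 / 4 = 256
row_ranges = [(i * rows_per_device, (i + 1) * rows_per_device) for i in range(num_devices)]

for (start, stop), shard in zip(row_ranges, shards):
    print(f"rows {start}:{stop} -> shard shape {shard.shape}")

full_mib = kernel.nbytes / 2**20
shard_mib = shards[0].nbytes / 2**20
print(f"full kernel: {full_mib} MiB, per-device shard: {shard_mib} MiB")  # 4.0 vs 1.0
```

These row ranges correspond to the slice(0, 256), slice(256, 512), ... indices that appear in the shard-to-device log further down.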
Now it’s time to materialize our NNX modules on all our devices. The jax.jit() wrapper is crucial here. Because we pass out_shardings, JAX first traces init_fn abstractly and then compiles it so that each output array is created directly with its target sharding. When the compiled function runs, every device allocates only the shards it owns. This ensures the entire state never materializes on a single device. Without the wrapper, JAX would first materialize the objects on a single device and only then shard them. If your model doesn’t fit into a single device’s memory, the program would run out of memory (OOM) at that point.
[17]:
opt_graph, opt_state, ema_state = jax.jit(
init_fn,
out_shardings=(repl_sharding, opt_state_sharding, ema_state_sharding),
)()
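The same pattern works on a single array. Below is a minimal sketch with a hypothetical `init_weights` standing in for init_fn. It uses a 1-D "data" mesh over whatever devices are available, and a single CPU device also works. The output is produced directly in the requested row sharding, and `devices_indices_map` (the call that log_shard_map uses below) shows which block each device holds.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over all available devices.
n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ("data",))
row_sharding = NamedSharding(mesh, P("data", None))


def init_weights():
    # Hypothetical stand-in for init_fn: builds a weight matrix from scratch.
    key = jax.random.key(0)
    return jax.random.normal(key, (8 * n, 16))


# With out_shardings, the output comes back already laid out in row_sharding;
# there is no separate step that builds the full matrix on one device and reshards it.
w = jax.jit(init_weights, out_shardings=row_sharding)()

for device, index in w.sharding.devices_indices_map(w.shape).items():
    print(device, index)
```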
Let’s define a helper debug logging function to see what our sharded states look like in more detail.
[18]:
def log_shard_map(tag: str, state: nnx.State) -> None:
    """Log the sharding mapping of arrays to devices for debugging.
    Logs a detailed breakdown of how each parameter is sharded across devices,
    showing which array indices are stored on which devices.
    Args:
        tag: Descriptive tag for the logging output.
        state: Model state to analyze for sharding information.
    """
    logging.info(f"── Shard ↦ device map: {tag} ──")
    for name, var in flatten_state(state):
        arr = var.value if isinstance(var, nnx.VariableState) else var
        # devices_indices_map maps each device to the index (tuple of slices) of its shard.
        for d, idx in arr.sharding.devices_indices_map(arr.shape).items():
            logging.info(f" {name} {idx} → {d}")
if jax.process_index() == 0:
    log_shard_map("Opt state sharding", opt_state)
    log_shard_map("EMA state sharding", ema_state)
2025-10-05 20:32:09,510 - INFO - ── Shard ↦ device map: Opt state sharding ──
2025-10-05 20:32:09,510 - INFO - model/dropout/rngs/default/count () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,511 - INFO - model/dropout/rngs/default/count () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,511 - INFO - model/dropout/rngs/default/count () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,511 - INFO - model/dropout/rngs/default/count () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,512 - INFO - model/dropout/rngs/default/key () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,512 - INFO - model/dropout/rngs/default/key () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,512 - INFO - model/dropout/rngs/default/key () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,512 - INFO - model/dropout/rngs/default/key () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,513 - INFO - model/dropout/rngs/dropout/count () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,513 - INFO - model/dropout/rngs/dropout/count () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,513 - INFO - model/dropout/rngs/dropout/count () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,513 - INFO - model/dropout/rngs/dropout/count () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,514 - INFO - model/dropout/rngs/dropout/key () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,514 - INFO - model/dropout/rngs/dropout/key () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,514 - INFO - model/dropout/rngs/dropout/key () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,514 - INFO - model/dropout/rngs/dropout/key () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,515 - INFO - model/dropout/rngs/noise/count () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,515 - INFO - model/dropout/rngs/noise/count () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,515 - INFO - model/dropout/rngs/noise/count () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,515 - INFO - model/dropout/rngs/noise/count () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,516 - INFO - model/dropout/rngs/noise/key () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,516 - INFO - model/dropout/rngs/noise/key () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,516 - INFO - model/dropout/rngs/noise/key () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,517 - INFO - model/dropout/rngs/noise/key () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,517 - INFO - model/fc1/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,517 - INFO - model/fc1/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,517 - INFO - model/fc1/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,518 - INFO - model/fc1/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,518 - INFO - model/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,518 - INFO - model/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,518 - INFO - model/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,519 - INFO - model/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,519 - INFO - model/fc2/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,519 - INFO - model/fc2/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,519 - INFO - model/fc2/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,519 - INFO - model/fc2/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,520 - INFO - model/fc2/kernel (slice(0, 256, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,520 - INFO - model/fc2/kernel (slice(256, 512, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,520 - INFO - model/fc2/kernel (slice(512, 768, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,520 - INFO - model/fc2/kernel (slice(768, 1024, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,521 - INFO - model/fc3/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,521 - INFO - model/fc3/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,521 - INFO - model/fc3/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,521 - INFO - model/fc3/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,521 - INFO - model/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,522 - INFO - model/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,522 - INFO - model/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,522 - INFO - model/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,522 - INFO - opt_state/0/count () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,523 - INFO - opt_state/0/count () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,523 - INFO - opt_state/0/count () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,523 - INFO - opt_state/0/count () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,523 - INFO - opt_state/0/mu/fc1/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,523 - INFO - opt_state/0/mu/fc1/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,524 - INFO - opt_state/0/mu/fc1/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,524 - INFO - opt_state/0/mu/fc1/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,524 - INFO - opt_state/0/mu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,524 - INFO - opt_state/0/mu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,525 - INFO - opt_state/0/mu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,525 - INFO - opt_state/0/mu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,525 - INFO - opt_state/0/mu/fc2/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,525 - INFO - opt_state/0/mu/fc2/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,525 - INFO - opt_state/0/mu/fc2/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,526 - INFO - opt_state/0/mu/fc2/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,526 - INFO - opt_state/0/mu/fc2/kernel (slice(0, 256, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,526 - INFO - opt_state/0/mu/fc2/kernel (slice(256, 512, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,526 - INFO - opt_state/0/mu/fc2/kernel (slice(512, 768, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,527 - INFO - opt_state/0/mu/fc2/kernel (slice(768, 1024, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,527 - INFO - opt_state/0/mu/fc3/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,527 - INFO - opt_state/0/mu/fc3/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,527 - INFO - opt_state/0/mu/fc3/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,527 - INFO - opt_state/0/mu/fc3/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,528 - INFO - opt_state/0/mu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,528 - INFO - opt_state/0/mu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,528 - INFO - opt_state/0/mu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,528 - INFO - opt_state/0/mu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,529 - INFO - opt_state/0/nu/fc1/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,529 - INFO - opt_state/0/nu/fc1/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,529 - INFO - opt_state/0/nu/fc1/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,529 - INFO - opt_state/0/nu/fc1/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,529 - INFO - opt_state/0/nu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,530 - INFO - opt_state/0/nu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,530 - INFO - opt_state/0/nu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,530 - INFO - opt_state/0/nu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,530 - INFO - opt_state/0/nu/fc2/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,531 - INFO - opt_state/0/nu/fc2/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,531 - INFO - opt_state/0/nu/fc2/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,531 - INFO - opt_state/0/nu/fc2/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,531 - INFO - opt_state/0/nu/fc2/kernel (slice(0, 256, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,531 - INFO - opt_state/0/nu/fc2/kernel (slice(256, 512, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,532 - INFO - opt_state/0/nu/fc2/kernel (slice(512, 768, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,532 - INFO - opt_state/0/nu/fc2/kernel (slice(768, 1024, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,532 - INFO - opt_state/0/nu/fc3/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,532 - INFO - opt_state/0/nu/fc3/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,532 - INFO - opt_state/0/nu/fc3/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,533 - INFO - opt_state/0/nu/fc3/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,533 - INFO - opt_state/0/nu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,533 - INFO - opt_state/0/nu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,533 - INFO - opt_state/0/nu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,534 - INFO - opt_state/0/nu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,534 - INFO - step () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,534 - INFO - step () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,534 - INFO - step () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,534 - INFO - step () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,535 - INFO - ── Shard ↦ device map: EMA state sharding ──
2025-10-05 20:32:09,535 - INFO - dropout/rngs/default/count () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,535 - INFO - dropout/rngs/default/count () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,535 - INFO - dropout/rngs/default/count () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,536 - INFO - dropout/rngs/default/count () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,536 - INFO - dropout/rngs/default/key () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,536 - INFO - dropout/rngs/default/key () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,536 - INFO - dropout/rngs/default/key () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,537 - INFO - dropout/rngs/default/key () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,537 - INFO - dropout/rngs/dropout/count () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,537 - INFO - dropout/rngs/dropout/count () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,537 - INFO - dropout/rngs/dropout/count () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,538 - INFO - dropout/rngs/dropout/count () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,538 - INFO - dropout/rngs/dropout/key () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,538 - INFO - dropout/rngs/dropout/key () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,538 - INFO - dropout/rngs/dropout/key () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,539 - INFO - dropout/rngs/dropout/key () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,539 - INFO - dropout/rngs/noise/count () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,539 - INFO - dropout/rngs/noise/count () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,539 - INFO - dropout/rngs/noise/count () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,539 - INFO - dropout/rngs/noise/count () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,540 - INFO - dropout/rngs/noise/key () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,540 - INFO - dropout/rngs/noise/key () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,540 - INFO - dropout/rngs/noise/key () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,541 - INFO - dropout/rngs/noise/key () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,541 - INFO - fc1/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,541 - INFO - fc1/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,541 - INFO - fc1/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,541 - INFO - fc1/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,542 - INFO - fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,542 - INFO - fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,542 - INFO - fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,545 - INFO - fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,547 - INFO - fc2/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,547 - INFO - fc2/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,547 - INFO - fc2/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,547 - INFO - fc2/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,548 - INFO - fc2/kernel (slice(0, 256, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,548 - INFO - fc2/kernel (slice(256, 512, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,548 - INFO - fc2/kernel (slice(512, 768, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,548 - INFO - fc2/kernel (slice(768, 1024, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,549 - INFO - fc3/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,549 - INFO - fc3/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,549 - INFO - fc3/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,549 - INFO - fc3/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:32:09,549 - INFO - fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:32:09,550 - INFO - fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:32:09,550 - INFO - fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:32:09,550 - INFO - fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
As you can see, each device holds a different 256-row slice of fc2/kernel (and of its mu and nu moments in the optimizer state), whereas every other array is fully replicated on all devices.
Through some NNX model surgery, we define two graphs for our model: train and eval, which we will later use for training and testing, respectively.
train() and eval() only change the graph (the static information of the model, such as whether dropout is deterministic). They do not change the state.
It’s also worth noting that, at the global program level, we store our models (NNX modules) as a separate graph and state rather than as a single module, and we only merge them right before we need to run a JAX operation on them. I find it easier to think about and manage models as two separate objects, a static graph and a dynamic state, at all other times.
[19]:
opt = nnx.merge(opt_graph, opt_state)
opt.model.train()
opt_graph, opt_state = nnx.split(opt)
opt.model.eval()
model_graph_eval, _ = nnx.split(opt.model)
Here we initialize our distributed Orbax checkpoint manager.
[20]:
ckpt_mngr = ocp.CheckpointManager(
args.checkpoint_dir,
options=ocp.CheckpointManagerOptions(
save_interval_steps=args.save_interval,
max_to_keep=2,
step_prefix=args.experiment_name,
enable_async_checkpointing=False,
),
)
2025-10-05 20:32:09,563 - INFO - [thread=MainThread] Failed to get flag value for EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.
2025-10-05 20:32:09,564 - INFO - [process=0][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers=None, handler_registry=None
2025-10-05 20:32:09,564 - INFO - Initialized registry DefaultCheckpointHandlerRegistry({('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7f16887108c0>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7f16887108c0>}).
2025-10-05 20:32:09,565 - INFO - orbax-checkpoint version: 0.11.16
2025-10-05 20:32:09,565 - INFO - [process=0] Created OpTracker for checkpoint_manager_save with operation id 1
2025-10-05 20:32:09,819 - INFO - Created directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints
2025-10-05 20:32:09,871 - INFO - [process=0][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2500, max_to_keep=2, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix='fsdp', step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, enable_background_delete=False, read_only=False, enable_async_checkpointing=False, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False), root_directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7f29be097950>
For this tutorial, we will use a simple synthetic dataset of (x, y) pairs sampled from the sine function, y = sin(x). This is the mapping our model will learn.
[21]:
class SinDataset(Dataset):
    """A PyTorch dataset that generates sine function data points.
    This dataset generates random x values from [-π, π) and computes y = sin(x).
    The dataset uses a seeded random number generator for reproducible results.
    Args:
        seed: Random seed for reproducible data generation.
    """
    def __init__(self, seed: int) -> None:
        """Initialize the dataset with a random seed.
        Args:
            seed: Random seed for data generation.
        """
        self.seed = seed
        self.reset_seed()
    def reset_seed(self) -> None:
        """Reset the random number generator to the initial seed.
        This is useful for ensuring reproducible evaluation data.
        """
        self.rng = torch.Generator()
        self.rng.manual_seed(self.seed)
    def __len__(self) -> int:
        """Return the length of the dataset.
        Returns:
            A very large number representing the dataset size.
        """
        return 2**31 - 1
    def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
        """Generate a single data point.
        Args:
            idx: Index (unused, but required for Dataset interface).
        Returns:
            Tuple of (x, y) where x is a random value in [-π, π) and y = sin(x).
        """
        x = torch.rand(1, generator=self.rng) * 2 * torch.pi - torch.pi
        y = torch.sin(x)
        return x.numpy(), y.numpy()
Next, we convert our global batch size (across all JAX processes and devices) to a local batch size, the number of samples processed by the current JAX/Python process. Don’t confuse it with the per-device (per-chip) batch size: a single JAX process usually has more than one device assigned.
[22]:
local_batch_size = args.batch_size // jax.process_count()
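The relationship between the three batch sizes can be sketched as plain arithmetic. `split_batch_sizes` is a hypothetical helper, and the numbers in the example are illustrative. In the notebook, the inputs would be args.batch_size, jax.process_count() and jax.local_device_count().

```python
from typing import Tuple


def split_batch_sizes(
    global_batch_size: int, process_count: int, local_device_count: int
) -> Tuple[int, int]:
    """Hypothetical helper: global batch -> per-process batch -> per-device batch."""
    if global_batch_size % process_count:
        raise ValueError("global batch size must be divisible by the number of processes")
    local = global_batch_size // process_count
    if local % local_device_count:
        raise ValueError("local batch size must be divisible by the local device count")
    return local, local // local_device_count


# Example: a global batch of 512 on 2 hosts with 4 chips each.
print(split_batch_sizes(512, process_count=2, local_device_count=4))  # (256, 64)
```

Unlike the plain integer division in the cell above, this sketch fails loudly when the sizes don’t divide evenly instead of silently dropping samples.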
We are getting close to the training code. Below, we define our train_step() function. It merges the graph and state of our optimizer into an NNX module, passes a batch through it, computes the loss, and updates the parameters with the gradients.
It’s important to understand that x and y in this function are global arrays. Their first (batch) dimension equals the global batch size, and all aggregate operations we run on them, like .mean(), run across the entire global array.
The function also optionally adds noise to the targets. This doesn’t make much sense in our example problem, but working with noise is essential in some real-world models, such as diffusion models. It requires drawing from an RNG stream, and I added it here to show how our code supports it. The random values we sample are also global, so they don’t repeat across JAX processes or devices. To see how our model handles the noise, set args.add_noise=True.
[23]:
def train_step(
    opt_graph: nnx.GraphDef,
    opt_state: nnx.State,
    x: jax.Array,
    y: jax.Array,
    add_noise: bool = False,
) -> Tuple[nnx.State, jax.Array]:
    """Perform a single training step with gradient computation and parameter update.
    Computes the forward pass, loss, gradients, and updates model parameters.
    Optionally adds noise to the target values for data augmentation.
    Args:
        opt_graph: Optimizer graph definition (static structure).
        opt_state: Optimizer state (parameters and optimizer state).
        x: Input batch of shape (batch_size, input_dim).
        y: Target batch of shape (batch_size, output_dim).
        add_noise: Whether to add noise to targets for data augmentation.
    Returns:
        Tuple of (updated_optimizer_state, loss_value).
    """
    optimizer = nnx.merge(opt_graph, opt_state)
    model = optimizer.model
    def loss_fn(model: MLP) -> jax.Array:
        y_hat = model(x)
        if add_noise:
            noise_key = model.rngs["noise"]()
            noise = jax.random.normal(noise_key, y.shape)
            y_noisy = y + noise
            loss = jnp.mean((y_hat - y_noisy) ** 2)
        else:
            loss = jnp.mean((y_hat - y) ** 2)
        return loss
    grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = grad_fn(model)
    optimizer.update(grads)
    _, opt_state = nnx.split(optimizer)
    return opt_state, loss
Now we can wrap our train_step() function with jax.jit() so that JAX compiles it into optimized XLA code that runs across our cluster of devices. It’s important to mark plain Python arguments that are not JAX arrays, here the add_noise flag at index 4, as static. Additionally, to save memory, we use donate_argnums=(1,) to tell JAX that it may reuse the memory buffers of argument 1 (opt_state), since the function returns an updated version of it. Lastly, we specify out_shardings: the first output, the updated optimizer state, keeps opt_state_sharding, and the second, the scalar loss, is replicated.
[24]:
train_step_fn = jax.jit(
train_step,
donate_argnums=(1,),
static_argnums=(4,),
out_shardings=(opt_state_sharding, repl_sharding),
)
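Why add_noise must be static can be shown with a tiny, hypothetical `scaled_loss` function that branches on a Python bool. With static_argnums, JAX compiles one version per distinct flag value. Without it, the flag is traced and the `if` fails. Donation is left out to keep the sketch minimal.

```python
import jax
import jax.numpy as jnp


def scaled_loss(x: jax.Array, double: bool) -> jax.Array:
    loss = jnp.mean(x ** 2)
    if double:  # Python control flow on a non-array argument
        loss = loss * 2
    return loss


# Marking `double` static makes JAX compile one version per distinct value.
scaled_loss_fn = jax.jit(scaled_loss, static_argnums=(1,))
x = jnp.arange(4.0)
print(scaled_loss_fn(x, False))  # 3.5
print(scaled_loss_fn(x, True))   # 7.0

# Without static_argnums, `double` becomes a tracer and `if double:` raises.
try:
    jax.jit(scaled_loss)(x, True)
except jax.errors.ConcretizationTypeError as err:
    print(type(err).__name__)
```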
Below, we define our test function, which works very similarly to the train_step() function, except it doesn’t update the parameters.
[25]:
def test_step(
    model_graph: nnx.GraphDef,
    model_state: nnx.State,
    x: jax.Array,
    y: jax.Array,
) -> Tuple[jax.Array, jax.Array]:
    """Perform a single evaluation step without parameter updates.
    Computes the forward pass and loss for evaluation purposes.
    Args:
        model_graph: Model graph definition (static structure).
        model_state: Model state (parameters only, no optimizer state).
        x: Input batch of shape (batch_size, input_dim).
        y: Target batch of shape (batch_size, output_dim).
    Returns:
        Tuple of (loss_value, predictions).
    """
    model = nnx.merge(model_graph, model_state)
    y_hat = model(x)
    loss = jnp.mean((y_hat - y) ** 2)
    return loss, y_hat
[26]:
test_step_fn = jax.jit(
test_step,
out_shardings=(repl_sharding, data_sharding),
)
Here is our function for maintaining the EMA model. It updates every parameter (nnx.Param) of the EMA state using the ema_decay constant and leaves the RNG state untouched.
[27]:
def update_ema(
    model_state: nnx.State,
    ema_state: nnx.State,
    ema_decay: float,
) -> nnx.State:
    """Update exponential moving average (EMA) of model parameters.
    Computes the exponential moving average using the formula:
    ema_new = ema_decay * ema_old + (1 - ema_decay) * model_param
    Args:
        model_state: Current model state with updated parameters.
        ema_state: Current EMA state to be updated.
        ema_decay: Decay factor for EMA (typically close to 1.0, e.g., 0.9999).
    Returns:
        Updated EMA state.
    """
    def update_param(p_model: jax.Array, p_ema: jax.Array) -> jax.Array:
        return p_ema * ema_decay + p_model * (1 - ema_decay)
    # Average only the nnx.Param leaves; RNG keys and counters are not averaged.
    ema_state_no_rng = jax.tree.map(
        update_param,
        nnx.filter_state(model_state, nnx.Param),
        nnx.filter_state(ema_state, nnx.Param),
    )
    # Write the averaged params back into the full EMA state (RNG state unchanged).
    ema_state = nnx.merge_state(ema_state, ema_state_no_rng)
    return ema_state
[28]:
update_ema_fn = jax.jit(
update_ema,
out_shardings=ema_state_sharding,
donate_argnums=(1,),
)
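A quick way to reason about ema_decay is to unroll the update rule. Starting from ema_0 and applying ema_t = d * ema_{t-1} + (1 - d) * p_t for T steps, the initial EMA parameters keep a weight of d**T. The helper below is a hypothetical illustration of that arithmetic in plain Python, not part of the training code.

```python
def initial_weight_remaining(ema_decay: float, steps: int) -> float:
    """Weight that the initial EMA parameters still carry after `steps` updates.

    Unrolling ema_t = d * ema_{t-1} + (1 - d) * p_t gives
    ema_T = d**T * ema_0 + sum_t (1 - d) * d**(T - t) * p_t,
    so the initial value contributes with weight d**T.
    """
    return ema_decay ** steps


for step in (1000, 2000, 3000, 4000, 5000):
    print(step, round(initial_weight_remaining(0.999, step), 4))
```

With ema_decay = 0.999, roughly 37% of the EMA still comes from the untrained initial weights after 1,000 steps, and less than 1% after 5,000 steps. This likely explains why the EMA test loss in the training log lags well behind the raw model's test loss early on and only catches up later.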
This function is essential for DDP. Our JAX cluster consists of several JAX/Python processes, and each of them loads only local_batch_size samples. However, as we saw before, train_step() and test_step() expect global arrays with the global batch size. make_fsarray_from_local_slice() bridges this gap: it assembles the per-process local batches into one global, sharded array in which each process's data is placed in shards on that process's local devices.
[29]:
def make_fsarray_from_local_slice(
    local_slice: jnp.ndarray,
    global_devices: list[jax.Device],
    axis: str,
) -> jax.Array:
    """Create a globally sharded array from a local data slice.
    Takes a local data slice and creates a globally sharded JAX array
    by distributing the data across multiple devices and processes.
    This function is adapted from:
    https://github.com/google-research/big_vision/blob/0127fb6b337ee2a27bf4e54dea79cff176527356/big_vision/utils.py#L1388-L1409
    Args:
        local_slice: Local portion of the data on this process.
        global_devices: List of all devices across all processes.
        axis: Name of the axis for sharding.
    Returns:
        Globally sharded JAX array with proper device placement.
    """
    mesh = jax.sharding.Mesh(global_devices, (axis,))
    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(axis))
    local_ds = mesh.local_devices
    x = np.asarray(local_slice)
    # One equal slice of the local batch per local device.
    xs = jax.device_put(np.split(x, len(local_ds), axis=0), local_ds)
    global_shape = (x.shape[0] * jax.process_count(), *x.shape[1:])
    return jax.make_array_from_single_device_arrays(global_shape, sharding, xs)
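The shape bookkeeping in make_fsarray_from_local_slice() can be checked without a cluster. The NumPy simulation below assumes a hypothetical setup of 2 processes with 4 local devices each and a local batch of 8 rows; it reproduces the np.split() per local device and the global_shape computation, then concatenates the shards under the assumption that the mesh lists devices process by process. It only models the layout; the real function hands the shards to jax.make_array_from_single_device_arrays().

```python
import numpy as np

# Hypothetical cluster: 2 processes (hosts), each with 4 local devices.
process_count, local_device_count = 2, 4
local_batch_size, feature_dim = 8, 1

# Each process loads its own local batch (here: distinct ranges of values).
local_batches = [
    np.arange(
        p * local_batch_size, (p + 1) * local_batch_size, dtype=np.float32
    ).reshape(local_batch_size, feature_dim)
    for p in range(process_count)
]

# Per process: one equal slice per local device, as in np.split(x, len(local_ds), axis=0).
# np.split raises ValueError if local_batch_size is not divisible by the device count.
shards = [np.split(x, local_device_count, axis=0) for x in local_batches]

# Same formula as global_shape in make_fsarray_from_local_slice().
global_shape = (local_batch_size * process_count, feature_dim)

# On a 1-D "data" mesh, device k holds the k-th block of rows. Assuming devices are
# ordered process by process, concatenating the shards gives the logical global array.
global_view = np.concatenate(
    [shard for per_process in shards for shard in per_process], axis=0
)
print(global_view.shape, global_shape)
```

If the mesh interleaves devices from different processes, the rows land in a different order, which is usually harmless for data parallelism because each row is an independent sample.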
Finally, training. We define the training loop as a function because we will want to reuse it later.
It takes a start step and the optimizer and EMA states, initializes the dataloaders, and then trains and evaluates the model according to the hyperparameters in args.
As you can see, we call make_fsarray_from_local_slice() on the local batch that the dataloader returns before supplying it to train_step() and test_step().
We update the EMA model after every train step, and periodically run evaluation and checkpoint our states to disk. Before visualizing, we call jax.experimental.multihost_utils.process_allgather() on the global arrays to bring them into host memory, so that we can apply NumPy operations to them. This call is a collective, so every process must make it, even though only process 0 draws the plot.
At least at the time of writing, Orbax had no straightforward way to save an nnx.Rngs object in a checkpoint. That is why we filter the RNG keys out of the state and save them separately as plain JAX arrays (via jax.random.key_data). We do the same for the EMA state, because we will later need its RNG object to restore the full EMA model. Note, however, that the RNG inside the EMA is never drawn from, so it stays in its initial state.
[30]:
def train_loop(start_step: int, opt_state: nnx.State, ema_state: nnx.State):
    train_dataloader = DataLoader(
        SinDataset(seed=start_step), batch_size=local_batch_size, shuffle=False
    )
    test_dataset = SinDataset(seed=-1)
    test_dataloader = DataLoader(
        test_dataset, batch_size=local_batch_size, shuffle=False
    )
    train_iter = iter(train_dataloader)
    ema_decay = 0.999
    for step in range(start_step, start_step + args.steps):
        # Each process loads a local batch, then we build global sharded arrays.
        x_batch, y_batch = next(train_iter)
        x_batch = make_fsarray_from_local_slice(
            x_batch, mesh.devices.flatten(), data_axis
        )
        y_batch = make_fsarray_from_local_slice(
            y_batch, mesh.devices.flatten(), data_axis
        )
        opt_state, train_loss = train_step_fn(
            opt_graph, opt_state, x_batch, y_batch, args.add_noise
        )
        ema_state = update_ema_fn(opt_state["model"], ema_state, ema_decay)
        if jax.process_index() == 0 and (step + 1) % args.log_interval == 0:
            logging.info(f"Step {step+1}, Train Loss: {train_loss:.6f}")
        if (step + 1) % args.test_interval == 0:
            test_dataset.reset_seed()
            test_iter = iter(test_dataloader)
            x_test, y_test = next(test_iter)
            x_test = make_fsarray_from_local_slice(
                x_test, mesh.devices.flatten(), data_axis
            )
            y_test = make_fsarray_from_local_slice(
                y_test, mesh.devices.flatten(), data_axis
            )
            test_loss, y_pred_model = test_step_fn(
                model_graph_eval, opt_state["model"], x_test, y_test
            )
            test_loss_ema, y_pred_ema = test_step_fn(
                model_graph_eval, ema_state, x_test, y_test
            )
            # Collective calls: every process must run them.
            y_pred_model = jax.experimental.multihost_utils.process_allgather(
                y_pred_model, tiled=True
            )
            y_pred_ema = jax.experimental.multihost_utils.process_allgather(
                y_pred_ema, tiled=True
            )
            x_test = jax.experimental.multihost_utils.process_allgather(
                x_test, tiled=True
            )
            y_test = jax.experimental.multihost_utils.process_allgather(
                y_test, tiled=True
            )
            if jax.process_index() == 0:
                x_plot = np.array(x_test).flatten()
                y_true_plot = np.array(y_test).flatten()
                y_pred_ema_plot = np.array(y_pred_ema).flatten()
                y_pred_model_plot = np.array(y_pred_model).flatten()
                sort_idx = np.argsort(x_plot)
                x_plot = x_plot[sort_idx]
                y_true_plot = y_true_plot[sort_idx]
                y_pred_ema_plot = y_pred_ema_plot[sort_idx]
                y_pred_model_plot = y_pred_model_plot[sort_idx]
                experiment_output_dir = os.path.join(
                    args.output_dir, args.experiment_name
                )
                os.makedirs(experiment_output_dir, exist_ok=True)
                fig = Figure(figsize=(10, 6))
                ax = fig.add_subplot(111)
                ax.scatter(x_plot, y_true_plot, alpha=0.7, label="Ground Truth", s=20)
                ax.scatter(
                    x_plot,
                    y_pred_model_plot,
                    alpha=0.7,
                    label="Model Prediction",
                    s=20,
                )
                ax.scatter(
                    x_plot,
                    y_pred_ema_plot,
                    alpha=0.7,
                    label="EMA Prediction",
                    s=20,
                )
                ax.set_xlabel("X")
                ax.set_ylabel("Y")
                ax.set_title("Sin Function: Ground Truth vs Model vs EMA Prediction")
                ax.legend()
                ax.grid(True, alpha=0.3)
                plot_path = os.path.join(experiment_output_dir, f"eval_{step+1}.png")
                fig.savefig(plot_path, dpi=300, bbox_inches="tight")
                logging.info(f"Plot saved to {plot_path}")
            if jax.process_index() == 0:
                logging.info(
                    f"Step {step+1}, Test Loss: {test_loss:.6f}, "
                    f"EMA Test Loss: {test_loss_ema:.6f}"
                )
        if (step + 1) % args.save_interval == 0:
            if jax.process_index() == 0:
                logging.info(f"Saving checkpoint at step {step + 1}")
            # Split RNG keys from the rest of the state and save them as raw key data.
            opt_rngs, opt_state_no_rngs = nnx.filter_state(opt_state, nnx.RngKey, ...)
            opt_rng_keys = jax.tree.map(jax.random.key_data, opt_rngs)
            ema_rngs, ema_state_no_rngs = nnx.filter_state(ema_state, nnx.RngKey, ...)
            ema_rng_keys = jax.tree.map(jax.random.key_data, ema_rngs)
            ckpt_mngr.save(
                step + 1,
                args=ocp.args.Composite(
                    opt_state=ocp.args.StandardSave(opt_state_no_rngs),
                    opt_rngs=ocp.args.StandardSave(opt_rng_keys),
                    ema_state=ocp.args.StandardSave(ema_state_no_rngs),
                    ema_rngs=ocp.args.StandardSave(ema_rng_keys),
                ),
            )
            if jax.process_index() == 0:
                logging.info("Checkpoint saved successfully")
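The checkpointing branch of train_loop() converts typed PRNG keys into raw data with jax.random.key_data before saving. The sketch below shows the round trip in isolation on a hypothetical dict of named keys standing in for the filtered nnx.RngKey state: key_data produces plain uint32 arrays that can be serialized like any other array, and jax.random.wrap_key_data turns them back into typed keys that generate the same random numbers.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the filtered RNG state: typed keys keyed by stream name.
keys = {"dropout": jax.random.key(0), "noise": jax.random.key(1)}

# Typed keys -> raw uint32 arrays (what gets written to the checkpoint).
raw = jax.tree.map(jax.random.key_data, keys)

# Raw arrays -> typed keys again (what a restore would do after loading `raw`).
restored = jax.tree.map(jax.random.wrap_key_data, raw)

before = jax.random.normal(keys["noise"], (3,))
after = jax.random.normal(restored["noise"], (3,))
print(raw["noise"].dtype, bool(jnp.array_equal(before, after)))
```

Because the restored keys are bit-identical to the saved ones, a resumed run draws the same noise and dropout masks it would have drawn without the interruption.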
Let’s train our model!
[31]:
start_step = 0
train_loop(start_step, opt_state, ema_state)
2025-10-05 20:32:12,456 - INFO - Step 100, Train Loss: 0.232647
2025-10-05 20:32:13,548 - INFO - Step 200, Train Loss: 0.178514
2025-10-05 20:32:14,642 - INFO - Step 300, Train Loss: 0.137890
2025-10-05 20:32:15,734 - INFO - Step 400, Train Loss: 0.103556
2025-10-05 20:32:16,828 - INFO - Step 500, Train Loss: 0.071781
2025-10-05 20:32:17,921 - INFO - Step 600, Train Loss: 0.064227
2025-10-05 20:32:19,014 - INFO - Step 700, Train Loss: 0.038317
2025-10-05 20:32:20,155 - INFO - Step 800, Train Loss: 0.031034
2025-10-05 20:32:21,256 - INFO - Step 900, Train Loss: 0.018365
2025-10-05 20:32:22,357 - INFO - Step 1000, Train Loss: 0.011609
2025-10-05 20:32:22,900 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_1000.png
2025-10-05 20:32:22,901 - INFO - Step 1000, Test Loss: 0.010148, EMA Test Loss: 0.306501
2025-10-05 20:32:24,108 - INFO - Step 1100, Train Loss: 0.007852
2025-10-05 20:32:25,202 - INFO - Step 1200, Train Loss: 0.007552
2025-10-05 20:32:26,299 - INFO - Step 1300, Train Loss: 0.003781
2025-10-05 20:32:27,393 - INFO - Step 1400, Train Loss: 0.003243
2025-10-05 20:32:28,493 - INFO - Step 1500, Train Loss: 0.002446
2025-10-05 20:32:29,586 - INFO - Step 1600, Train Loss: 0.002355
2025-10-05 20:32:30,682 - INFO - Step 1700, Train Loss: 0.002892
2025-10-05 20:32:31,777 - INFO - Step 1800, Train Loss: 0.001434
2025-10-05 20:32:32,872 - INFO - Step 1900, Train Loss: 0.002903
2025-10-05 20:32:33,970 - INFO - Step 2000, Train Loss: 0.001691
2025-10-05 20:32:34,360 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_2000.png
2025-10-05 20:32:34,361 - INFO - Step 2000, Test Loss: 0.001314, EMA Test Loss: 0.070212
2025-10-05 20:32:35,459 - INFO - Step 2100, Train Loss: 0.001170
2025-10-05 20:32:36,554 - INFO - Step 2200, Train Loss: 0.001386
2025-10-05 20:32:37,773 - INFO - Step 2300, Train Loss: 0.001572
2025-10-05 20:32:38,883 - INFO - Step 2400, Train Loss: 0.000817
2025-10-05 20:32:39,985 - INFO - Step 2500, Train Loss: 0.001038
2025-10-05 20:32:39,985 - INFO - Saving checkpoint at step 2500
2025-10-05 20:32:39,986 - INFO - Using JaxDistributedSignalingClient
2025-10-05 20:32:39,987 - INFO - [process=0][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
2025-10-05 20:32:39,987 - INFO - [process=0] Created OpTracker for checkpoint_manager_save with operation id 2
2025-10-05 20:32:40,052 - INFO - [process=0] Saving checkpoint at step 2500
2025-10-05 20:32:40,052 - INFO - [process=0] Started saving checkpoint to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500.
2025-10-05 20:32:40,169 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500
2025-10-05 20:32:40,326 - INFO - Wrote Metadata={'item_handlers': None, 'metrics': {}, 'performance_metrics': {}, 'init_timestamp_nsecs': 1759696360220900523, 'commit_timestamp_nsecs': None, 'custom_metadata': {}}, json={"item_handlers": null, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1759696360220900523, "commit_timestamp_nsecs": null, "custom_metadata": {}} to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/_CHECKPOINT_METADATA
2025-10-05 20:32:40,380 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/ema_rngs
2025-10-05 20:32:40,387 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/opt_rngs
2025-10-05 20:32:40,388 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/opt_state
2025-10-05 20:32:40,390 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/ema_state
2025-10-05 20:32:40,390 - INFO - No entry found in handler registry for item: ema_rngs and args with type: <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardSaveArgs'>. Falling back to global handler registry.
2025-10-05 20:32:40,391 - INFO - Deferred registration for item: "ema_rngs". Adding handler `<orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x7f1618151d60>` for item "ema_rngs" and save args `<class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardRestoreArgs'>` to `_handler_registry`.
2025-10-05 20:32:40,391 - INFO - No entry found in handler registry for item: ema_state and args with type: <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardSaveArgs'>. Falling back to global handler registry.
2025-10-05 20:32:40,391 - INFO - Deferred registration for item: "ema_state". Adding handler `<orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x7f24cf89f4d0>` for item "ema_state" and save args `<class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardRestoreArgs'>` to `_handler_registry`.
2025-10-05 20:32:40,392 - INFO - No entry found in handler registry for item: opt_rngs and args with type: <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardSaveArgs'>. Falling back to global handler registry.
2025-10-05 20:32:40,392 - INFO - Deferred registration for item: "opt_rngs". Adding handler `<orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x7f24cf2214f0>` for item "opt_rngs" and save args `<class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardRestoreArgs'>` to `_handler_registry`.
2025-10-05 20:32:40,392 - INFO - No entry found in handler registry for item: opt_state and args with type: <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardSaveArgs'>. Falling back to global handler registry.
2025-10-05 20:32:40,392 - INFO - Deferred registration for item: "opt_state". Adding handler `<orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x7f24cf8ee540>` for item "opt_state" and save args `<class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardRestoreArgs'>` to `_handler_registry`.
2025-10-05 20:32:40,400 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False
2025-10-05 20:32:40,401 - INFO - [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.type_handlers.ArrayHandler".serialize. Time taken: 0.001122s
2025-10-05 20:32:40,401 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False
I1005 20:32:40.411459 39356 google_auth_provider.cc:181] Running on GCE, using service account 373177222751-compute@developer.gserviceaccount.com
2025-10-05 20:32:40,498 - INFO - [process=0][thread=array_type_handler] Wrote 3 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/ema_rngs/array_metadatas/process_0
2025-10-05 20:32:40,710 - INFO - [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.type_handlers.ArrayHandler".serialize. Time taken: 0.308578s
2025-10-05 20:32:40,713 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False
2025-10-05 20:32:40,715 - INFO - [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.type_handlers.ArrayHandler".serialize. Time taken: 0.002202s
2025-10-05 20:32:40,718 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False
2025-10-05 20:32:40,726 - INFO - [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.type_handlers.ArrayHandler".serialize. Time taken: 0.008384s
2025-10-05 20:32:40,731 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 70 Bytes/s (total bytes: 24 Bytes) (time elapsed: 338 milliseconds) (per-host)
2025-10-05 20:32:40,732 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.339135s (batch_requests_ready=0.000657s, total_serialization_initiated=0.334970s, others=0.003508s)
2025-10-05 20:32:40,734 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 11.8 MiB/s (total bytes: 4.0 MiB) (time elapsed: 339 milliseconds) (per-host)
2025-10-05 20:32:40,734 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.339792s (batch_requests_ready=0.000642s, total_serialization_initiated=0.337574s, others=0.001577s)
2025-10-05 20:32:40,735 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 70 Bytes/s (total bytes: 24 Bytes) (time elapsed: 339 milliseconds) (per-host)
2025-10-05 20:32:40,736 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.339844s (batch_requests_ready=0.000344s, total_serialization_initiated=0.338248s, others=0.001252s)
2025-10-05 20:32:40,737 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 35.4 MiB/s (total bytes: 12.0 MiB) (time elapsed: 340 milliseconds) (per-host)
2025-10-05 20:32:40,739 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.341751s (batch_requests_ready=0.001705s, total_serialization_initiated=0.338225s, others=0.001821s)
2025-10-05 20:32:40,739 - INFO - [process=0][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.412713s (all_items=0.002202s, per_item={'ema_rngs': '0.00070572', 'ema_state': '0.00052786', 'opt_rngs': '0.00048780', 'opt_state': '0.00048065'}, temp_paths=0.410511)
2025-10-05 20:32:40,811 - INFO - [process=0][thread=array_type_handler] Wrote 9 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/ema_state/array_metadatas/process_0
2025-10-05 20:32:40,815 - INFO - [process=0][thread=array_type_handler] Wrote 3 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/opt_rngs/array_metadatas/process_0
2025-10-05 20:32:40,835 - INFO - [process=0][thread=array_type_handler] Wrote 23 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/opt_state/array_metadatas/process_0
2025-10-05 20:32:41,187 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.454033s (commit=0.270430s, array_metadata_write=0.183603s)
2025-10-05 20:32:41,188 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 30 Bytes/s (total bytes: 24 Bytes) (time elapsed: 794 milliseconds) (per-host)
2025-10-05 20:32:41,510 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.773855s (commit=0.577525s, array_metadata_write=0.196330s)
2025-10-05 20:32:41,539 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.803838s (commit=0.621967s, array_metadata_write=0.181870s)
2025-10-05 20:32:41,539 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 3.5 MiB/s (total bytes: 4.0 MiB) (time elapsed: a second) (per-host)
2025-10-05 20:32:41,540 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 20 Bytes/s (total bytes: 24 Bytes) (time elapsed: a second) (per-host)
2025-10-05 20:32:41,667 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.929092s (commit=0.732886s, array_metadata_write=0.196206s)
2025-10-05 20:32:41,667 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 9.5 MiB/s (total bytes: 12.0 MiB) (time elapsed: a second) (per-host)
2025-10-05 20:32:41,829 - INFO - Read Metadata={'item_handlers': None, 'metrics': {}, 'performance_metrics': {}, 'init_timestamp_nsecs': 1759696360220900523, 'commit_timestamp_nsecs': None, 'custom_metadata': {}} from gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/_CHECKPOINT_METADATA
2025-10-05 20:32:42,064 - INFO - Updated Metadata={'item_handlers': {'ema_rngs': 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler', 'ema_state': 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler', 'opt_rngs': 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler', 'opt_state': 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler'}, 'metrics': {}, 'performance_metrics': {}, 'init_timestamp_nsecs': 1759696360220900523, 'commit_timestamp_nsecs': None, 'custom_metadata': {}} to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/_CHECKPOINT_METADATA
2025-10-05 20:32:42,137 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
2025-10-05 20:32:42,302 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.237342s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/ema_rngs
2025-10-05 20:32:42,303 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/ema_rngs
2025-10-05 20:32:42,477 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
2025-10-05 20:32:42,632 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.230508s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/ema_state
2025-10-05 20:32:42,632 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/ema_state
2025-10-05 20:32:42,802 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
2025-10-05 20:32:42,967 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.241697s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/opt_rngs
2025-10-05 20:32:42,967 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/opt_rngs
2025-10-05 20:32:43,138 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
2025-10-05 20:32:43,315 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.250376s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/opt_state
2025-10-05 20:32:43,315 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/opt_state
2025-10-05 20:32:43,401 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500
2025-10-05 20:32:43,718 - INFO - [process=0][thread=MainThread] Finished saving checkpoint (finalized tmp dir) to `gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500`.
2025-10-05 20:32:43,718 - INFO - Finished synchronous save in 3.67 seconds to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500
2025-10-05 20:32:43,719 - INFO - [process=0][thread=MainThread][step=2500] CheckpointManager Save Finalize is syncing with other hosts...
2025-10-05 20:32:43,720 - INFO - [process=0][thread=MainThread][step=2500] CheckpointManager Save Finalize is done on all hosts.
2025-10-05 20:32:43,720 - INFO - [process=0][thread=MainThread][step=2500] Finished synchronous save.
2025-10-05 20:32:43,720 - INFO - {'step': 2500, 'event_type': 'save', 'directory': 'gs://solaris-east5/georgy/fsdp-jax/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': True, 'wait_for_prev_start_time': 1759696359.9876494, 'wait_for_prev_duration_secs': 0.0002562999725341797, 'checkpointer_blocking_start_time': 1759696360.0524569, 'checkpointer_blocking_duration_secs': 3.6667628288269043, 'get_old_steps_start_time': 1759696363.7192426, 'get_old_steps_duration_secs': 0.00011086463928222656, 'checkpoint_manager_blocking_start_time': 1759696359.98681, 'checkpoint_manager_blocking_duration_secs': 3.7337262630462646}
2025-10-05 20:32:43,720 - INFO - Checkpoint saved successfully
2025-10-05 20:32:44,824 - INFO - Step 2600, Train Loss: 0.001182
2025-10-05 20:32:45,925 - INFO - Step 2700, Train Loss: 0.000822
2025-10-05 20:32:47,020 - INFO - Step 2800, Train Loss: 0.001010
2025-10-05 20:32:48,132 - INFO - Step 2900, Train Loss: 0.001088
2025-10-05 20:32:49,232 - INFO - Step 3000, Train Loss: 0.001274
2025-10-05 20:32:49,624 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_3000.png
2025-10-05 20:32:49,625 - INFO - Step 3000, Test Loss: 0.000465, EMA Test Loss: 0.008318
2025-10-05 20:32:50,726 - INFO - Step 3100, Train Loss: 0.000773
2025-10-05 20:32:51,822 - INFO - Step 3200, Train Loss: 0.001725
2025-10-05 20:32:52,920 - INFO - Step 3300, Train Loss: 0.000757
2025-10-05 20:32:54,167 - INFO - Step 3400, Train Loss: 0.001064
2025-10-05 20:32:55,268 - INFO - Step 3500, Train Loss: 0.001229
2025-10-05 20:32:56,372 - INFO - Step 3600, Train Loss: 0.001523
2025-10-05 20:32:57,472 - INFO - Step 3700, Train Loss: 0.000991
2025-10-05 20:32:58,573 - INFO - Step 3800, Train Loss: 0.004077
2025-10-05 20:32:59,674 - INFO - Step 3900, Train Loss: 0.001708
2025-10-05 20:33:00,776 - INFO - Step 4000, Train Loss: 0.001069
2025-10-05 20:33:01,158 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_4000.png
2025-10-05 20:33:01,159 - INFO - Step 4000, Test Loss: 0.000324, EMA Test Loss: 0.000607
2025-10-05 20:33:02,271 - INFO - Step 4100, Train Loss: 0.000972
2025-10-05 20:33:03,371 - INFO - Step 4200, Train Loss: 0.000642
2025-10-05 20:33:04,469 - INFO - Step 4300, Train Loss: 0.000937
2025-10-05 20:33:05,571 - INFO - Step 4400, Train Loss: 0.000771
2025-10-05 20:33:06,673 - INFO - Step 4500, Train Loss: 0.001577
2025-10-05 20:33:07,894 - INFO - Step 4600, Train Loss: 0.000988
2025-10-05 20:33:09,016 - INFO - Step 4700, Train Loss: 0.001011
2025-10-05 20:33:10,124 - INFO - Step 4800, Train Loss: 0.000930
2025-10-05 20:33:11,225 - INFO - Step 4900, Train Loss: 0.001215
2025-10-05 20:33:12,324 - INFO - Step 5000, Train Loss: 0.000650
2025-10-05 20:33:12,697 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_5000.png
2025-10-05 20:33:12,698 - INFO - Step 5000, Test Loss: 0.000054, EMA Test Loss: 0.000241
2025-10-05 20:33:12,698 - INFO - Saving checkpoint at step 5000
2025-10-05 20:33:12,701 - INFO - [process=0][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
2025-10-05 20:33:12,701 - INFO - [process=0] Created OpTracker for checkpoint_manager_save with operation id 4
2025-10-05 20:33:12,756 - INFO - [process=0] Saving checkpoint at step 5000
2025-10-05 20:33:12,757 - INFO - [process=0] Started saving checkpoint to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000.
2025-10-05 20:33:12,863 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000
2025-10-05 20:33:13,058 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/ema_rngs
2025-10-05 20:33:13,062 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/opt_rngs
2025-10-05 20:33:13,065 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/opt_state
2025-10-05 20:33:13,071 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/ema_state
2025-10-05 20:33:13,078 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False
2025-10-05 20:33:13,079 - INFO - [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.type_handlers.ArrayHandler".serialize. Time taken: 0.001208s
2025-10-05 20:33:13,079 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False
2025-10-05 20:33:13,085 - INFO - [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.type_handlers.ArrayHandler".serialize. Time taken: 0.005938s
2025-10-05 20:33:13,088 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False
2025-10-05 20:33:13,090 - INFO - [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.type_handlers.ArrayHandler".serialize. Time taken: 0.002231s
2025-10-05 20:33:13,092 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False
2025-10-05 20:33:13,101 - INFO - [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.type_handlers.ArrayHandler".serialize. Time taken: 0.009317s
2025-10-05 20:33:13,107 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 697 Bytes/s (total bytes: 24 Bytes) (time elapsed: 34 milliseconds) (per-host)
2025-10-05 20:33:13,107 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.034985s (batch_requests_ready=0.000446s, total_serialization_initiated=0.032820s, others=0.001719s)
2025-10-05 20:33:13,109 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 113.1 MiB/s (total bytes: 4.0 MiB) (time elapsed: 35 milliseconds) (per-host)
2025-10-05 20:33:13,109 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.036130s (batch_requests_ready=0.000589s, total_serialization_initiated=0.034237s, others=0.001304s)
2025-10-05 20:33:13,110 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 669 Bytes/s (total bytes: 24 Bytes) (time elapsed: 35 milliseconds) (per-host)
2025-10-05 20:33:13,111 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.036366s (batch_requests_ready=0.000304s, total_serialization_initiated=0.035025s, others=0.001038s)
2025-10-05 20:33:13,112 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 327.1 MiB/s (total bytes: 12.0 MiB) (time elapsed: 36 milliseconds) (per-host)
2025-10-05 20:33:13,113 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.038096s (batch_requests_ready=0.001270s, total_serialization_initiated=0.034801s, others=0.002025s)
2025-10-05 20:33:13,114 - INFO - [process=0][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.109971s (all_items=0.000032s, per_item={'ema_rngs': '0.00002027', 'ema_state': '0.00000477', 'opt_rngs': '0.00000286', 'opt_state': '0.00000453'}, temp_paths=0.109938)
2025-10-05 20:33:13,182 - INFO - [process=0][thread=array_type_handler] Wrote 9 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/ema_state/array_metadatas/process_0
2025-10-05 20:33:13,185 - INFO - [process=0][thread=array_type_handler] Wrote 3 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/ema_rngs/array_metadatas/process_0
2025-10-05 20:33:13,192 - INFO - [process=0][thread=array_type_handler] Wrote 3 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/opt_rngs/array_metadatas/process_0
2025-10-05 20:33:13,208 - INFO - [process=0][thread=array_type_handler] Wrote 23 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/opt_state/array_metadatas/process_0
2025-10-05 20:33:13,695 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.587158s (commit=0.389604s, array_metadata_write=0.197554s)
2025-10-05 20:33:13,696 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 38 Bytes/s (total bytes: 24 Bytes) (time elapsed: 623 milliseconds) (per-host)
2025-10-05 20:33:13,866 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.754855s (commit=0.563322s, array_metadata_write=0.191533s)
2025-10-05 20:33:13,904 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.794878s (commit=0.609138s, array_metadata_write=0.185740s)
2025-10-05 20:33:13,905 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 4.8 MiB/s (total bytes: 4.0 MiB) (time elapsed: 831 milliseconds) (per-host)
2025-10-05 20:33:13,905 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 28 Bytes/s (total bytes: 24 Bytes) (time elapsed: 830 milliseconds) (per-host)
2025-10-05 20:33:13,982 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.869188s (commit=0.674423s, array_metadata_write=0.194765s)
2025-10-05 20:33:13,982 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 13.3 MiB/s (total bytes: 12.0 MiB) (time elapsed: 907 milliseconds) (per-host)
2025-10-05 20:33:14,454 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
2025-10-05 20:33:14,612 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.230279s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/ema_rngs
2025-10-05 20:33:14,612 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/ema_rngs
2025-10-05 20:33:14,774 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
2025-10-05 20:33:14,946 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.240756s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/ema_state
2025-10-05 20:33:14,947 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/ema_state
2025-10-05 20:33:15,121 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
2025-10-05 20:33:15,295 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.246438s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/opt_rngs
2025-10-05 20:33:15,296 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/opt_rngs
2025-10-05 20:33:15,458 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
2025-10-05 20:33:15,611 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.228544s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/opt_state
2025-10-05 20:33:15,612 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/opt_state
2025-10-05 20:33:15,712 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000
2025-10-05 20:33:16,003 - INFO - [process=0][thread=MainThread] Finished saving checkpoint (finalized tmp dir) to `gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000`.
2025-10-05 20:33:16,004 - INFO - Finished synchronous save in 3.25 seconds to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000
2025-10-05 20:33:16,005 - INFO - [process=0][thread=MainThread][step=5000] CheckpointManager Save Finalize is syncing with other hosts...
2025-10-05 20:33:16,005 - INFO - [process=0][thread=MainThread][step=5000] CheckpointManager Save Finalize is done on all hosts.
2025-10-05 20:33:16,005 - INFO - [process=0][thread=MainThread][step=5000] Finished synchronous save.
2025-10-05 20:33:16,006 - INFO - {'step': 5000, 'event_type': 'save', 'directory': 'gs://solaris-east5/georgy/fsdp-jax/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': True, 'wait_for_prev_start_time': 1759696392.7011592, 'wait_for_prev_duration_secs': 0.00031065940856933594, 'checkpointer_blocking_start_time': 1759696392.7570903, 'checkpointer_blocking_duration_secs': 3.2475996017456055, 'get_old_steps_start_time': 1759696396.0047104, 'get_old_steps_duration_secs': 9.846687316894531e-05, 'checkpoint_manager_blocking_start_time': 1759696392.7007046, 'checkpoint_manager_blocking_duration_secs': 3.3052756786346436}
2025-10-05 20:33:16,006 - INFO - Checkpoint saved successfully
Our model has been trained. The logs above show a low test loss. You can also check the evaluation images: the predicted curve should closely follow the sinusoidal target function.
[32]:
import os
import glob
from PIL import Image
# Find the most recently written evaluation plot for this experiment
exp_dir = os.path.join(args.output_dir, args.experiment_name)
png_files = glob.glob(os.path.join(exp_dir, "*.png"))
latest_png = max(png_files, key=os.path.getmtime)
# Downscale by 4x so the image fits comfortably in the notebook
img = Image.open(latest_png)
resized_img = img.resize((img.width // 4, img.height // 4))
resized_img
Let’s now simulate loading a trained model. You would typically do this to evaluate the model or to resume training.
In our case, we will load the previously trained model from disk and run our training loop on it again.
Because JAX follows a functional programming model and our first train_loop() call never returned the trained states, the global opt_state and ema_state variables still hold the initial, randomly initialized networks. It is as if we had restarted the program, except that we now first load an existing model from disk using Orbax.
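To see why the globals are untouched, here is a minimal pure-JAX sketch. The functions sgd_step, train_without_return and train_with_return are hypothetical and are not the notebook's train_loop(). A JAX update produces new arrays rather than mutating its inputs, so unless a training function returns its final state, the caller's variables keep their original values.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sgd_step(params, grads):
    # Produces *new* arrays; JAX arrays are immutable, so params is not modified.
    return jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)


def train_without_return(params, steps=3):
    for _ in range(steps):
        grads = jax.tree.map(jnp.ones_like, params)  # dummy gradients of 1.0
        params = sgd_step(params, grads)  # rebinds the local name only
    # The trained params are discarded when the function exits.


def train_with_return(params, steps=3):
    for _ in range(steps):
        grads = jax.tree.map(jnp.ones_like, params)
        params = sgd_step(params, grads)
    return params


params = {"w": jnp.zeros(3)}
train_without_return(params)
print(params["w"])  # [0. 0. 0.]: the caller's params are unchanged
trained = train_with_return(params)
print(trained["w"])  # approximately -0.3 in every entry
```

Returning the state, as train_with_return does, is the usual functional-style way to hand trained values back to the caller.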
As we did when saving the checkpoints, we split the RNG keys out of each state and convert them to raw key data with jax.random.key_data(). This ensures that the types and shapes of what we ask Orbax to restore match what was saved. Because our states carry sharding information, Orbax never materializes the full state on a single device. Instead, each array is restored directly into its shards. In addition, Orbax ignores the sharding recorded in the checkpoint and lays out the arrays according to the sharding of the target states we pass in. This lets you train a model on one cluster configuration and then load it on another. Note that for this behavior to work, you need to use ocp.args.StandardRestore() and not ocp.args.PyTreeRestore().
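A restore target only has to describe the shape, dtype and sharding of each leaf, because its data is overwritten on restore. The following is a minimal pure-JAX sketch of such a target. It assumes a one-axis mesh whose axis is named "fsdp" (a hypothetical name here) and that shards the first dimension of a kernel, similar to how fc2/kernel is split into four 256-row slices in the logs below. This notebook passes the concrete sharded states as targets. Orbax's documentation also shows restoring into jax.ShapeDtypeStruct targets like the one built here, which carry the same information without allocating data.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-axis mesh over all local devices (a single CPU device also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("fsdp",))

# Shard the rows of the kernel across "fsdp" and replicate the columns.
row_sharding = NamedSharding(mesh, P("fsdp", None))
kernel = jax.device_put(jnp.zeros((8 * devices.size, 4)), row_sharding)

# Abstract restore target: shape, dtype and the *current* sharding, but no data.
abstract_kernel = jax.ShapeDtypeStruct(
    kernel.shape, kernel.dtype, sharding=kernel.sharding
)

# Each device holds only its own block of 8 rows.
for shard in kernel.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)
```

If you change the mesh (for example, moving to a cluster with a different number of devices), the same code produces a target with the new layout. That is the property that lets a checkpoint saved on one configuration be loaded on another.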
After loading the states and RNG keys separately, we wrap the keys back into typed PRNG keys with jax.random.wrap_key_data() and merge everything with nnx.merge_state(). Take a look at the dropout RNG stream in the restored opt state. Its count is Array(5000, dtype=uint32), which matches the number of training steps we have run, so we resume from the same point in the RNG sequence. In the EMA state, the dropout count is Array(0, dtype=uint32) because the EMA model's dropout stream was never drawn from.
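The sketch below illustrates both halves of this round trip in pure JAX. Typed PRNG keys (dtype key<fry>) are stored as raw uint32 data via jax.random.key_data() and rebuilt with jax.random.wrap_key_data(). The CountedStream class is hypothetical. It assumes, as we understand NNX RNG streams to work, that each draw folds a counter into a fixed base key and then increments the counter. Under that assumption, restoring (key, count) resumes exactly the same sequence, and a stream that was never drawn from keeps a count of 0.

```python
import jax
import jax.numpy as jnp


class CountedStream:
    """Hypothetical stand-in for an NNX RNG stream: a base key plus a counter."""

    def __init__(self, key, count=0):
        self.key = key
        self.count = count

    def __call__(self):
        # Derive a fresh key from (base key, count), then advance the counter.
        new_key = jax.random.fold_in(self.key, self.count)
        self.count += 1
        return new_key


stream = CountedStream(jax.random.key(1))
for _ in range(5):  # e.g. five training steps, one dropout draw each
    stream()

# "Save": the typed key becomes a plain uint32 array that a checkpointer can write.
saved_key_data = jax.random.key_data(stream.key)
saved_count = stream.count

# "Restore": wrap the raw data back into a typed key and rebuild the stream.
resumed = CountedStream(jax.random.wrap_key_data(saved_key_data), saved_count)

# Both streams now produce the same next key.
next_original = jax.random.key_data(stream())
next_resumed = jax.random.key_data(resumed())
print(saved_key_data.dtype, saved_count)  # uint32 5
print(jnp.array_equal(next_original, next_resumed))  # True
```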
[33]:
latest_step = args.steps
# Split the typed RNG keys out of each state and convert them to raw uint32
# key data, matching the structure that was saved
opt_rngs, opt_state_no_rngs = nnx.filter_state(opt_state, nnx.RngKey, ...)
opt_rng_keys = jax.tree.map(jax.random.key_data, opt_rngs)
ema_rngs, ema_state_no_rngs = nnx.filter_state(ema_state, nnx.RngKey, ...)
ema_rng_keys = jax.tree.map(jax.random.key_data, ema_rngs)
# The current (sharded) states act as restore targets, so Orbax loads each
# array according to the current sharding rather than the one on disk
state_restored = ckpt_mngr.restore(
    latest_step,
    args=ocp.args.Composite(
        opt_state=ocp.args.StandardRestore(opt_state_no_rngs),
        ema_state=ocp.args.StandardRestore(ema_state_no_rngs),
        opt_rngs=ocp.args.StandardRestore(opt_rng_keys),
        ema_rngs=ocp.args.StandardRestore(ema_rng_keys),
    ),
)
opt_state_no_rngs, ema_state_no_rngs, opt_rng_keys, ema_rng_keys = (
    state_restored.opt_state,
    state_restored.ema_state,
    state_restored.opt_rngs,
    state_restored.ema_rngs,
)
# Wrap the raw key data back into typed PRNG keys and merge them into the states
# (jax.tree.map replaces the deprecated jax.tree_map)
opt_rngs = jax.tree.map(jax.random.wrap_key_data, opt_rng_keys)
ema_rngs = jax.tree.map(jax.random.wrap_key_data, ema_rng_keys)
opt_state = nnx.merge_state(opt_state_no_rngs, opt_rngs)
ema_state = nnx.merge_state(ema_state_no_rngs, ema_rngs)
if jax.process_index() == 0:
    logging.info("Checkpoint restored successfully")
    log_shard_map("Opt state sharding after restore", opt_state)
    log_shard_map("EMA state sharding after restore", ema_state)
    logging.info(f"Opt state after restore: {opt_state}")
    logging.info(f"EMA state after restore: {ema_state}")
2025-10-05 20:33:16,340 - INFO - Restoring checkpoint from gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000.
2025-10-05 20:33:16,575 - INFO - [process=0] /jax/checkpoint/read/bytes_per_sec: 619 Bytes/s (total bytes: 96 Bytes) (time elapsed: 154 milliseconds) (per-host)
2025-10-05 20:33:16,817 - INFO - [process=0] /jax/checkpoint/read/bytes_per_sec: 24.0 MiB/s (total bytes: 4.1 MiB) (time elapsed: 169 milliseconds) (per-host)
2025-10-05 20:33:17,026 - INFO - [process=0] /jax/checkpoint/read/bytes_per_sec: 666 Bytes/s (total bytes: 96 Bytes) (time elapsed: 144 milliseconds) (per-host)
2025-10-05 20:33:17,335 - INFO - [process=0] /jax/checkpoint/read/bytes_per_sec: 51.8 MiB/s (total bytes: 12.2 MiB) (time elapsed: 235 milliseconds) (per-host)
2025-10-05 20:33:17,335 - INFO - Finished restoring checkpoint in 1.12 seconds from gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000.
2025-10-05 20:33:17,335 - INFO - {'step': 5000, 'event_type': 'restore', 'directory': 'gs://solaris-east5/georgy/fsdp-jax/checkpoints', 'checkpointer_start_time': 1759696396.2143493, 'checkpointer_duration_secs': 1.1216151714324951, 'checkpoint_manager_start_time': 1759696396.1660664, 'checkpoint_manager_duration_secs': 1.1698994636535645}
2025-10-05 20:33:17,337 - INFO - Checkpoint restored successfully
2025-10-05 20:33:17,337 - INFO - ── Shard ↦ device map: Opt state sharding after restore ──
2025-10-05 20:33:17,338 - INFO - model/dropout/rngs/default/count () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,338 - INFO - model/dropout/rngs/default/count () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,338 - INFO - model/dropout/rngs/default/count () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,338 - INFO - model/dropout/rngs/default/count () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,339 - INFO - model/dropout/rngs/default/key () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,339 - INFO - model/dropout/rngs/default/key () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,339 - INFO - model/dropout/rngs/default/key () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,339 - INFO - model/dropout/rngs/default/key () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,340 - INFO - model/dropout/rngs/dropout/count () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,340 - INFO - model/dropout/rngs/dropout/count () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,340 - INFO - model/dropout/rngs/dropout/count () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,340 - INFO - model/dropout/rngs/dropout/count () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,341 - INFO - model/dropout/rngs/dropout/key () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,341 - INFO - model/dropout/rngs/dropout/key () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,341 - INFO - model/dropout/rngs/dropout/key () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,341 - INFO - model/dropout/rngs/dropout/key () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,342 - INFO - model/dropout/rngs/noise/count () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,342 - INFO - model/dropout/rngs/noise/count () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,342 - INFO - model/dropout/rngs/noise/count () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,342 - INFO - model/dropout/rngs/noise/count () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,343 - INFO - model/dropout/rngs/noise/key () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,343 - INFO - model/dropout/rngs/noise/key () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,343 - INFO - model/dropout/rngs/noise/key () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,343 - INFO - model/dropout/rngs/noise/key () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,344 - INFO - model/fc1/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,344 - INFO - model/fc1/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,344 - INFO - model/fc1/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,344 - INFO - model/fc1/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,344 - INFO - model/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,345 - INFO - model/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,345 - INFO - model/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,345 - INFO - model/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,345 - INFO - model/fc2/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,346 - INFO - model/fc2/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,346 - INFO - model/fc2/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,346 - INFO - model/fc2/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,346 - INFO - model/fc2/kernel (slice(0, 256, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,347 - INFO - model/fc2/kernel (slice(256, 512, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,347 - INFO - model/fc2/kernel (slice(512, 768, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,347 - INFO - model/fc2/kernel (slice(768, 1024, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,347 - INFO - model/fc3/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,347 - INFO - model/fc3/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,348 - INFO - model/fc3/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,348 - INFO - model/fc3/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,348 - INFO - model/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,348 - INFO - model/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,349 - INFO - model/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,349 - INFO - model/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,349 - INFO - opt_state/0/count () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,349 - INFO - opt_state/0/count () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,349 - INFO - opt_state/0/count () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,350 - INFO - opt_state/0/count () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,350 - INFO - opt_state/0/mu/fc1/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,350 - INFO - opt_state/0/mu/fc1/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,350 - INFO - opt_state/0/mu/fc1/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,351 - INFO - opt_state/0/mu/fc1/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,351 - INFO - opt_state/0/mu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,351 - INFO - opt_state/0/mu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,351 - INFO - opt_state/0/mu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,351 - INFO - opt_state/0/mu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,352 - INFO - opt_state/0/mu/fc2/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,352 - INFO - opt_state/0/mu/fc2/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,352 - INFO - opt_state/0/mu/fc2/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,352 - INFO - opt_state/0/mu/fc2/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,353 - INFO - opt_state/0/mu/fc2/kernel (slice(0, 256, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,353 - INFO - opt_state/0/mu/fc2/kernel (slice(256, 512, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,353 - INFO - opt_state/0/mu/fc2/kernel (slice(512, 768, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,353 - INFO - opt_state/0/mu/fc2/kernel (slice(768, 1024, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,354 - INFO - opt_state/0/mu/fc3/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,354 - INFO - opt_state/0/mu/fc3/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,354 - INFO - opt_state/0/mu/fc3/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,354 - INFO - opt_state/0/mu/fc3/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,354 - INFO - opt_state/0/mu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,355 - INFO - opt_state/0/mu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,355 - INFO - opt_state/0/mu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,355 - INFO - opt_state/0/mu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,355 - INFO - opt_state/0/nu/fc1/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,355 - INFO - opt_state/0/nu/fc1/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,356 - INFO - opt_state/0/nu/fc1/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,356 - INFO - opt_state/0/nu/fc1/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,356 - INFO - opt_state/0/nu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,356 - INFO - opt_state/0/nu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,357 - INFO - opt_state/0/nu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,357 - INFO - opt_state/0/nu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,357 - INFO - opt_state/0/nu/fc2/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,357 - INFO - opt_state/0/nu/fc2/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,357 - INFO - opt_state/0/nu/fc2/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,358 - INFO - opt_state/0/nu/fc2/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,358 - INFO - opt_state/0/nu/fc2/kernel (slice(0, 256, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,358 - INFO - opt_state/0/nu/fc2/kernel (slice(256, 512, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,358 - INFO - opt_state/0/nu/fc2/kernel (slice(512, 768, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,359 - INFO - opt_state/0/nu/fc2/kernel (slice(768, 1024, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,359 - INFO - opt_state/0/nu/fc3/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,359 - INFO - opt_state/0/nu/fc3/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,359 - INFO - opt_state/0/nu/fc3/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,359 - INFO - opt_state/0/nu/fc3/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,360 - INFO - opt_state/0/nu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,360 - INFO - opt_state/0/nu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,360 - INFO - opt_state/0/nu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,361 - INFO - opt_state/0/nu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,361 - INFO - step () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,361 - INFO - step () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,361 - INFO - step () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,361 - INFO - step () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,362 - INFO - ── Shard ↦ device map: EMA state sharding after restore ──
2025-10-05 20:33:17,362 - INFO - dropout/rngs/default/count () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,362 - INFO - dropout/rngs/default/count () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,362 - INFO - dropout/rngs/default/count () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,363 - INFO - dropout/rngs/default/count () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,363 - INFO - dropout/rngs/default/key () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,363 - INFO - dropout/rngs/default/key () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,363 - INFO - dropout/rngs/default/key () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,364 - INFO - dropout/rngs/default/key () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,364 - INFO - dropout/rngs/dropout/count () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,364 - INFO - dropout/rngs/dropout/count () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,364 - INFO - dropout/rngs/dropout/count () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,365 - INFO - dropout/rngs/dropout/count () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,365 - INFO - dropout/rngs/dropout/key () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,365 - INFO - dropout/rngs/dropout/key () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,365 - INFO - dropout/rngs/dropout/key () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,366 - INFO - dropout/rngs/dropout/key () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,366 - INFO - dropout/rngs/noise/count () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,366 - INFO - dropout/rngs/noise/count () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,366 - INFO - dropout/rngs/noise/count () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,367 - INFO - dropout/rngs/noise/count () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,367 - INFO - dropout/rngs/noise/key () → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,367 - INFO - dropout/rngs/noise/key () → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,367 - INFO - dropout/rngs/noise/key () → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,368 - INFO - dropout/rngs/noise/key () → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,368 - INFO - fc1/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,368 - INFO - fc1/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,368 - INFO - fc1/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,369 - INFO - fc1/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,369 - INFO - fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,369 - INFO - fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,369 - INFO - fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,370 - INFO - fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,370 - INFO - fc2/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,370 - INFO - fc2/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,370 - INFO - fc2/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,370 - INFO - fc2/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,371 - INFO - fc2/kernel (slice(0, 256, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,371 - INFO - fc2/kernel (slice(256, 512, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,371 - INFO - fc2/kernel (slice(512, 768, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,371 - INFO - fc2/kernel (slice(768, 1024, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,372 - INFO - fc3/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,372 - INFO - fc3/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,372 - INFO - fc3/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,372 - INFO - fc3/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,372 - INFO - fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))
2025-10-05 20:33:17,373 - INFO - fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))
2025-10-05 20:33:17,373 - INFO - fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))
2025-10-05 20:33:17,373 - INFO - fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))
2025-10-05 20:33:17,389 - INFO - Opt state after restore: State({
'model': {
'dropout': {
'rngs': {
'default': {
'count': VariableState( # 1 (4 B)
type=RngCount,
value=Array(6, dtype=uint32),
tag='default'
),
'key': VariableState( # 1 (8 B)
type=RngKey,
value=Array((), dtype=key<fry>) overlaying:
[0 0],
tag='default'
)
},
'dropout': {
'count': VariableState( # 1 (4 B)
type=RngCount,
value=Array(5000, dtype=uint32),
tag='dropout'
),
'key': VariableState( # 1 (8 B)
type=RngKey,
value=Array((), dtype=key<fry>) overlaying:
[0 1],
tag='dropout'
)
},
'noise': {
'count': VariableState( # 1 (4 B)
type=RngCount,
value=Array(0, dtype=uint32),
tag='noise'
),
'key': VariableState( # 1 (8 B)
type=RngKey,
value=Array((), dtype=key<fry>) overlaying:
[0 2],
tag='noise'
)
}
}
},
'fc1': {
'bias': VariableState( # 1,024 (4.1 KB)
type=Param,
value=Array([-0.04030257, -0.05037591, 0.07226934, ..., -0.05214077,
0.05906752, -0.05086785], dtype=float32)
),
'kernel': VariableState( # 1,024 (4.1 KB)
type=Param,
value=Array([[ 1.0710918 , -0.9678744 , -0.8126575 , ..., 0.17166048,
0.36044577, -0.6832711 ]], dtype=float32)
)
},
'fc2': {
'bias': VariableState( # 1,024 (4.1 KB)
type=Param,
value=Array([-0.00861618, -0.04522035, 0. , ..., 0.04507154,
0.0422631 , 0. ], dtype=float32)
),
'kernel': VariableState( # 1,048,576 (4.2 MB)
type=Param,
value=Array([[ 0.03979657, 0.04424966, -0.0138542 , ..., -0.05651323,
-0.01515646, -0.02998248],
[-0.01223963, 0.00112585, -0.02565034, ..., -0.04619034,
-0.0092434 , 0.00962243],
[-0.01330583, -0.03468521, 0.01838579, ..., 0.02772252,
-0.02609745, -0.05185288],
...,
[-0.01194615, -0.05189555, 0.00999935, ..., -0.0643564 ,
0.01144396, -0.02076687],
[-0.01237209, -0.00924738, -0.03677849, ..., 0.00366568,
-0.01869369, 0.05434604],
[ 0.02957141, 0.0040535 , 0.02555469, ..., -0.00918872,
0.0008155 , 0.0077408 ]], dtype=float32)
)
},
'fc3': {
'bias': VariableState( # 1 (4 B)
type=Param,
value=Array([0.0070804], dtype=float32)
),
'kernel': VariableState( # 1,024 (4.1 KB)
type=Param,
value=Array([[-0.03213442],
[ 0.02066209],
[-0.04244323],
...,
[-0.0067385 ],
[-0.06366555],
[-0.04282695]], dtype=float32)
)
}
},
'opt_state': {
0: {
'count': VariableState( # 1 (4 B)
type=OptArray,
value=Array(5000, dtype=int32)
),
'mu': {
'fc1': {
'bias': VariableState( # 1,024 (4.1 KB)
type=OptVariable,
value=Array([-2.3567886e-06, -3.8756975e-06, 5.0598974e-06, ...,
2.5492732e-06, -1.8280221e-06, 2.3327759e-05], dtype=float32),
source_type=Param
),
'kernel': VariableState( # 1,024 (4.1 KB)
type=OptVariable,
value=Array([[-1.2811696e-05, 1.8371884e-05, -2.8107081e-05, ...,
-7.7499708e-06, 1.2058406e-05, -2.7912660e-05]], dtype=float32),
source_type=Param
)
},
'fc2': {
'bias': VariableState( # 1,024 (4.1 KB)
type=OptVariable,
value=Array([-2.4946366e-05, -5.1071429e-06, 0.0000000e+00, ...,
-6.0458387e-07, -1.6253693e-05, 0.0000000e+00], dtype=float32),
source_type=Param
),
'kernel': VariableState( # 1,048,576 (4.2 MB)
type=OptVariable,
value=Array([[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
1.10227113e-07, -9.35575190e-06, 0.00000000e+00],
[-1.04507053e-05, -2.12880859e-05, 0.00000000e+00, ...,
3.34136939e-06, -5.51595440e-06, 0.00000000e+00],
[-1.16010815e-05, -1.84936525e-05, 0.00000000e+00, ...,
2.57636566e-06, -6.29322312e-06, 0.00000000e+00],
...,
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
0.00000000e+00, -7.95417463e-07, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
1.59682756e-07, -3.07845085e-06, 0.00000000e+00],
[-7.00975761e-06, -1.49751568e-05, 0.00000000e+00, ...,
2.39362407e-06, -3.57614044e-06, 0.00000000e+00]], dtype=float32),
source_type=Param
)
},
'fc3': {
'bias': VariableState( # 1 (4 B)
type=OptVariable,
value=Array([0.0003306], dtype=float32),
source_type=Param
),
'kernel': VariableState( # 1,024 (4.1 KB)
type=OptVariable,
value=Array([[ 2.5945314e-05],
[-6.1757506e-05],
[ 0.0000000e+00],
...,
[-2.2441673e-05],
[-3.4819095e-06],
[ 0.0000000e+00]], dtype=float32),
source_type=Param
)
}
},
'nu': {
'fc1': {
'bias': VariableState( # 1,024 (4.1 KB)
type=OptVariable,
value=Array([6.5946330e-07, 4.2301590e-08, 1.9096869e-07, ..., 8.1455914e-07,
2.5145167e-07, 3.8475892e-07], dtype=float32),
source_type=Param
),
'kernel': VariableState( # 1,024 (4.1 KB)
type=OptVariable,
value=Array([[2.9782741e-06, 2.5221141e-07, 1.2813606e-06, ..., 4.2202855e-06,
1.4535699e-06, 1.6272666e-06]], dtype=float32),
source_type=Param
)
},
'fc2': {
'bias': VariableState( # 1,024 (4.1 KB)
type=OptVariable,
value=Array([6.7727728e-07, 1.5464612e-07, 0.0000000e+00, ..., 2.4988884e-08,
2.2302729e-06, 0.0000000e+00], dtype=float32),
source_type=Param
),
'kernel': VariableState( # 1,048,576 (4.2 MB)
type=OptVariable,
value=Array([[7.1847148e-07, 0.0000000e+00, 0.0000000e+00, ..., 1.9970895e-11,
1.9745292e-08, 0.0000000e+00],
[2.3998948e-06, 7.2880732e-07, 0.0000000e+00, ..., 9.9980468e-08,
8.5483980e-06, 0.0000000e+00],
[1.8887340e-06, 5.6486834e-07, 0.0000000e+00, ..., 7.8932914e-08,
6.7499118e-06, 0.0000000e+00],
...,
[1.7248375e-08, 0.0000000e+00, 0.0000000e+00, ..., 3.9913174e-13,
2.2283823e-10, 0.0000000e+00],
[8.6669672e-08, 4.9999395e-17, 0.0000000e+00, ..., 2.8976873e-12,
3.6252299e-09, 0.0000000e+00],
[1.1716367e-06, 3.5670226e-07, 0.0000000e+00, ..., 4.8790799e-08,
4.1711064e-06, 0.0000000e+00]], dtype=float32),
source_type=Param
)
},
'fc3': {
'bias': VariableState( # 1 (4 B)
type=OptVariable,
value=Array([0.00091323], dtype=float32),
source_type=Param
),
'kernel': VariableState( # 1,024 (4.1 KB)
type=OptVariable,
value=Array([[5.6587835e-04],
[2.0583873e-04],
[0.0000000e+00],
...,
[3.7648741e-04],
[1.6187840e-05],
[0.0000000e+00]], dtype=float32),
source_type=Param
)
}
}
}
},
'step': VariableState( # 1 (4 B)
type=OptState,
value=Array(5000, dtype=uint32)
)
})
2025-10-05 20:33:17,394 - INFO - EMA state after restore: State({
'dropout': {
'rngs': {
'default': {
'count': VariableState( # 1 (4 B)
type=RngCount,
value=Array(0, dtype=uint32),
tag='default'
),
'key': VariableState( # 1 (8 B)
type=RngKey,
value=Array((), dtype=key<fry>) overlaying:
[0 0],
tag='default'
)
},
'dropout': {
'count': VariableState( # 1 (4 B)
type=RngCount,
value=Array(0, dtype=uint32),
tag='dropout'
),
'key': VariableState( # 1 (8 B)
type=RngKey,
value=Array((), dtype=key<fry>) overlaying:
[0 0],
tag='dropout'
)
},
'noise': {
'count': VariableState( # 1 (4 B)
type=RngCount,
value=Array(0, dtype=uint32),
tag='noise'
),
'key': VariableState( # 1 (8 B)
type=RngKey,
value=Array((), dtype=key<fry>) overlaying:
[0 0],
tag='noise'
)
}
}
},
'fc1': {
'bias': VariableState( # 1,024 (4.1 KB)
type=Param,
value=Array([-0.03922425, -0.04768839, 0.06973939, ..., -0.05001258,
0.05696585, -0.04945442], dtype=float32)
),
'kernel': VariableState( # 1,024 (4.1 KB)
type=Param,
value=Array([[ 1.063954 , -0.9617063 , -0.80732816, ..., 0.170594 ,
0.35816 , -0.6788137 ]], dtype=float32)
)
},
'fc2': {
'bias': VariableState( # 1,024 (4.1 KB)
type=Param,
value=Array([-0.00952129, -0.04323541, 0. , ..., 0.04311211,
0.04140445, 0. ], dtype=float32)
),
'kernel': VariableState( # 1,048,576 (4.2 MB)
type=Param,
value=Array([[ 0.03955881, 0.04395215, -0.01376141, ..., -0.05614801,
-0.01507475, -0.02978184],
[-0.01205855, 0.00108851, -0.02547839, ..., -0.04580969,
-0.00915899, 0.00955787],
[-0.01322271, -0.03437575, 0.01826279, ..., 0.02750572,
-0.02593077, -0.05150444],
...,
[-0.01188646, -0.05154517, 0.00993251, ..., -0.0631559 ,
0.01273686, -0.0206273 ],
[-0.01222327, -0.00915908, -0.03653297, ..., 0.00275638,
-0.01871123, 0.05397995],
[ 0.02949264, 0.00397549, 0.02538221, ..., -0.00903669,
0.00084104, 0.00768876]], dtype=float32)
)
},
'fc3': {
'bias': VariableState( # 1 (4 B)
type=Param,
value=Array([0.00718139], dtype=float32)
),
'kernel': VariableState( # 1,024 (4.1 KB)
type=Param,
value=Array([[-0.03202476],
[ 0.02044498],
[-0.0421575 ],
...,
[-0.0065837 ],
[-0.06370619],
[-0.04253863]], dtype=float32)
)
}
})
Now we run train_loop() again with the states we just restored, resuming from the last saved step.
[34]:
start_step = latest_step
train_loop(start_step, opt_state, ema_state)
2025-10-05 20:33:18,505 - INFO - Step 5100, Train Loss: 0.000881
2025-10-05 20:33:19,607 - INFO - Step 5200, Train Loss: 0.001370
2025-10-05 20:33:20,711 - INFO - Step 5300, Train Loss: 0.000773
2025-10-05 20:33:21,810 - INFO - Step 5400, Train Loss: 0.004323
2025-10-05 20:33:22,912 - INFO - Step 5500, Train Loss: 0.001996
2025-10-05 20:33:24,012 - INFO - Step 5600, Train Loss: 0.002029
2025-10-05 20:33:25,284 - INFO - Step 5700, Train Loss: 0.003011
2025-10-05 20:33:26,391 - INFO - Step 5800, Train Loss: 0.000485
2025-10-05 20:33:27,494 - INFO - Step 5900, Train Loss: 0.000653
2025-10-05 20:33:28,596 - INFO - Step 6000, Train Loss: 0.000697
2025-10-05 20:33:28,982 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_6000.png
2025-10-05 20:33:28,983 - INFO - Step 6000, Test Loss: 0.000865, EMA Test Loss: 0.000275
2025-10-05 20:33:30,091 - INFO - Step 6100, Train Loss: 0.001780
2025-10-05 20:33:31,201 - INFO - Step 6200, Train Loss: 0.000800
2025-10-05 20:33:32,310 - INFO - Step 6300, Train Loss: 0.000672
2025-10-05 20:33:33,419 - INFO - Step 6400, Train Loss: 0.000786
2025-10-05 20:33:34,525 - INFO - Step 6500, Train Loss: 0.000826
2025-10-05 20:33:35,634 - INFO - Step 6600, Train Loss: 0.000914
2025-10-05 20:33:36,745 - INFO - Step 6700, Train Loss: 0.000973
2025-10-05 20:33:37,858 - INFO - Step 6800, Train Loss: 0.002495
2025-10-05 20:33:38,979 - INFO - Step 6900, Train Loss: 0.000930
2025-10-05 20:33:40,244 - INFO - Step 7000, Train Loss: 0.000603
2025-10-05 20:33:40,627 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_7000.png
2025-10-05 20:33:40,628 - INFO - Step 7000, Test Loss: 0.000240, EMA Test Loss: 0.000210
2025-10-05 20:33:41,729 - INFO - Step 7100, Train Loss: 0.000742
2025-10-05 20:33:42,832 - INFO - Step 7200, Train Loss: 0.000778
2025-10-05 20:33:43,945 - INFO - Step 7300, Train Loss: 0.001267
2025-10-05 20:33:45,061 - INFO - Step 7400, Train Loss: 0.000588
2025-10-05 20:33:46,161 - INFO - Step 7500, Train Loss: 0.001284
2025-10-05 20:33:46,161 - INFO - Saving checkpoint at step 7500
2025-10-05 20:33:46,163 - INFO - [process=0][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
2025-10-05 20:33:46,163 - INFO - [process=0] Created OpTracker for checkpoint_manager_save with operation id 6
2025-10-05 20:33:46,223 - INFO - [process=0] Saving checkpoint at step 7500
2025-10-05 20:33:46,223 - INFO - [process=0] Started saving checkpoint to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500.
2025-10-05 20:33:46,332 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500
2025-10-05 20:33:46,553 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/opt_rngs
2025-10-05 20:33:46,554 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/ema_rngs
2025-10-05 20:33:46,555 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/ema_state
2025-10-05 20:33:46,558 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/opt_state
2025-10-05 20:33:46,565 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False
2025-10-05 20:33:46,566 - INFO - [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.type_handlers.ArrayHandler".serialize. Time taken: 0.001162s
2025-10-05 20:33:46,566 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False
2025-10-05 20:33:46,572 - INFO - [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.type_handlers.ArrayHandler".serialize. Time taken: 0.006226s
2025-10-05 20:33:46,575 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False
2025-10-05 20:33:46,578 - INFO - [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.type_handlers.ArrayHandler".serialize. Time taken: 0.002278s
2025-10-05 20:33:46,578 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False
2025-10-05 20:33:46,589 - INFO - [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.type_handlers.ArrayHandler".serialize. Time taken: 0.010375s
2025-10-05 20:33:46,593 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 700 Bytes/s (total bytes: 24 Bytes) (time elapsed: 34 milliseconds) (per-host)
2025-10-05 20:33:46,594 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.035200s (batch_requests_ready=0.000461s, total_serialization_initiated=0.031193s, others=0.003546s)
2025-10-05 20:33:46,596 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 111.4 MiB/s (total bytes: 4.0 MiB) (time elapsed: 36 milliseconds) (per-host)
2025-10-05 20:33:46,597 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.036603s (batch_requests_ready=0.000614s, total_serialization_initiated=0.034220s, others=0.001769s)
2025-10-05 20:33:46,599 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 643 Bytes/s (total bytes: 24 Bytes) (time elapsed: 37 milliseconds) (per-host)
2025-10-05 20:33:46,599 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.038188s (batch_requests_ready=0.000318s, total_serialization_initiated=0.036442s, others=0.001429s)
2025-10-05 20:33:46,601 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 313.0 MiB/s (total bytes: 12.0 MiB) (time elapsed: 38 milliseconds) (per-host)
2025-10-05 20:33:46,602 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.039559s (batch_requests_ready=0.001333s, total_serialization_initiated=0.036562s, others=0.001664s)
2025-10-05 20:33:46,602 - INFO - [process=0][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.105153s (all_items=0.000026s, per_item={'ema_rngs': '0.00001717', 'ema_state': '0.00000381', 'opt_rngs': '0.00000238', 'opt_state': '0.00000262'}, temp_paths=0.105127)
2025-10-05 20:33:46,673 - INFO - [process=0][thread=array_type_handler] Wrote 3 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/ema_rngs/array_metadatas/process_0
2025-10-05 20:33:46,677 - INFO - [process=0][thread=array_type_handler] Wrote 3 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/opt_rngs/array_metadatas/process_0
2025-10-05 20:33:46,684 - INFO - [process=0][thread=array_type_handler] Wrote 9 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/ema_state/array_metadatas/process_0
2025-10-05 20:33:46,692 - INFO - [process=0][thread=array_type_handler] Wrote 23 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/opt_state/array_metadatas/process_0
2025-10-05 20:33:47,207 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.611453s (commit=0.420513s, array_metadata_write=0.190940s)
2025-10-05 20:33:47,208 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 37 Bytes/s (total bytes: 24 Bytes) (time elapsed: 648 milliseconds) (per-host)
2025-10-05 20:33:47,340 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.740908s (commit=0.551089s, array_metadata_write=0.189819s)
2025-10-05 20:33:47,445 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.847461s (commit=0.641053s, array_metadata_write=0.206408s)
2025-10-05 20:33:47,445 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 4.5 MiB/s (total bytes: 4.0 MiB) (time elapsed: 885 milliseconds) (per-host)
2025-10-05 20:33:47,446 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 27 Bytes/s (total bytes: 24 Bytes) (time elapsed: 884 milliseconds) (per-host)
2025-10-05 20:33:47,469 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.867791s (commit=0.645247s, array_metadata_write=0.222543s)
2025-10-05 20:33:47,470 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 13.3 MiB/s (total bytes: 12.0 MiB) (time elapsed: 907 milliseconds) (per-host)
2025-10-05 20:33:47,936 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
2025-10-05 20:33:48,081 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.222106s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/ema_rngs
2025-10-05 20:33:48,082 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/ema_rngs
2025-10-05 20:33:48,248 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
2025-10-05 20:33:48,414 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.239519s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/ema_state
2025-10-05 20:33:48,414 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/ema_state
2025-10-05 20:33:48,602 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
2025-10-05 20:33:48,732 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.224504s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/opt_rngs
2025-10-05 20:33:48,733 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/opt_rngs
2025-10-05 20:33:48,900 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
2025-10-05 20:33:49,052 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.223419s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/opt_state
2025-10-05 20:33:49,052 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/opt_state
2025-10-05 20:33:49,141 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500
2025-10-05 20:33:49,469 - INFO - [process=0][thread=MainThread] Finished saving checkpoint (finalized tmp dir) to `gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500`.
2025-10-05 20:33:49,469 - INFO - Finished synchronous save in 3.25 seconds to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500
2025-10-05 20:33:49,630 - INFO - Deleted step 2500.
2025-10-05 20:33:49,631 - INFO - [process=0][thread=MainThread][step=7500] CheckpointManager Save Finalize is syncing with other hosts...
2025-10-05 20:33:49,631 - INFO - [process=0][thread=MainThread][step=7500] CheckpointManager Save Finalize is done on all hosts.
2025-10-05 20:33:49,631 - INFO - [process=0][thread=MainThread][step=7500] Finished synchronous save.
2025-10-05 20:33:49,631 - INFO - {'step': 7500, 'event_type': 'save', 'directory': 'gs://solaris-east5/georgy/fsdp-jax/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': True, 'wait_for_prev_start_time': 1759696426.1633513, 'wait_for_prev_duration_secs': 0.00028967857360839844, 'checkpointer_blocking_start_time': 1759696426.223418, 'checkpointer_blocking_duration_secs': 3.2464306354522705, 'get_old_steps_start_time': 1759696429.4698665, 'get_old_steps_duration_secs': 9.059906005859375e-05, 'checkpoint_manager_blocking_start_time': 1759696426.162882, 'checkpoint_manager_blocking_duration_secs': 3.4689133167266846}
2025-10-05 20:33:49,632 - INFO - Checkpoint saved successfully
2025-10-05 20:33:50,739 - INFO - Step 7600, Train Loss: 0.000762
2025-10-05 20:33:51,843 - INFO - Step 7700, Train Loss: 0.000554
2025-10-05 20:33:52,947 - INFO - Step 7800, Train Loss: 0.000492
2025-10-05 20:33:54,050 - INFO - Step 7900, Train Loss: 0.000939
2025-10-05 20:33:55,155 - INFO - Step 8000, Train Loss: 0.000522
2025-10-05 20:33:55,540 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_8000.png
2025-10-05 20:33:55,541 - INFO - Step 8000, Test Loss: 0.000316, EMA Test Loss: 0.000152
2025-10-05 20:33:56,804 - INFO - Step 8100, Train Loss: 0.001092
2025-10-05 20:33:57,912 - INFO - Step 8200, Train Loss: 0.000585
2025-10-05 20:33:59,020 - INFO - Step 8300, Train Loss: 0.001435
2025-10-05 20:34:00,132 - INFO - Step 8400, Train Loss: 0.001997
2025-10-05 20:34:01,242 - INFO - Step 8500, Train Loss: 0.001370
2025-10-05 20:34:02,355 - INFO - Step 8600, Train Loss: 0.000390
2025-10-05 20:34:03,460 - INFO - Step 8700, Train Loss: 0.000709
2025-10-05 20:34:04,571 - INFO - Step 8800, Train Loss: 0.001060
2025-10-05 20:34:05,679 - INFO - Step 8900, Train Loss: 0.000630
2025-10-05 20:34:06,789 - INFO - Step 9000, Train Loss: 0.000743
2025-10-05 20:34:07,168 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_9000.png
2025-10-05 20:34:07,169 - INFO - Step 9000, Test Loss: 0.000382, EMA Test Loss: 0.000085
2025-10-05 20:34:08,279 - INFO - Step 9100, Train Loss: 0.000958
2025-10-05 20:34:09,400 - INFO - Step 9200, Train Loss: 0.000821
2025-10-05 20:34:10,507 - INFO - Step 9300, Train Loss: 0.000445
2025-10-05 20:34:11,772 - INFO - Step 9400, Train Loss: 0.001565
2025-10-05 20:34:12,879 - INFO - Step 9500, Train Loss: 0.000859
2025-10-05 20:34:13,987 - INFO - Step 9600, Train Loss: 0.001134
2025-10-05 20:34:15,096 - INFO - Step 9700, Train Loss: 0.000996
2025-10-05 20:34:16,199 - INFO - Step 9800, Train Loss: 0.001920
2025-10-05 20:34:17,310 - INFO - Step 9900, Train Loss: 0.000885
2025-10-05 20:34:18,420 - INFO - Step 10000, Train Loss: 0.000495
2025-10-05 20:34:18,799 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_10000.png
2025-10-05 20:34:18,799 - INFO - Step 10000, Test Loss: 0.000113, EMA Test Loss: 0.000071
2025-10-05 20:34:18,800 - INFO - Saving checkpoint at step 10000
2025-10-05 20:34:18,804 - INFO - [process=0][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
2025-10-05 20:34:18,804 - INFO - [process=0] Created OpTracker for checkpoint_manager_save with operation id 8
2025-10-05 20:34:18,858 - INFO - [process=0] Saving checkpoint at step 10000
2025-10-05 20:34:18,859 - INFO - [process=0] Started saving checkpoint to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000.
2025-10-05 20:34:18,967 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000
2025-10-05 20:34:19,157 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/ema_rngs
2025-10-05 20:34:19,163 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/ema_state
2025-10-05 20:34:19,164 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/opt_rngs
2025-10-05 20:34:19,169 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/opt_state
2025-10-05 20:34:19,176 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False
2025-10-05 20:34:19,177 - INFO - [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.type_handlers.ArrayHandler".serialize. Time taken: 0.001225s
2025-10-05 20:34:19,179 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False
2025-10-05 20:34:19,184 - INFO - [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.type_handlers.ArrayHandler".serialize. Time taken: 0.005068s
2025-10-05 20:34:19,185 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False
2025-10-05 20:34:19,188 - INFO - [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.type_handlers.ArrayHandler".serialize. Time taken: 0.003032s
2025-10-05 20:34:19,191 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False
2025-10-05 20:34:19,205 - INFO - [process=0][thread=MainThread] Initiated "orbax.checkpoint._src.serialization.type_handlers.ArrayHandler".serialize. Time taken: 0.014571s
2025-10-05 20:34:19,211 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 587 Bytes/s (total bytes: 24 Bytes) (time elapsed: 40 milliseconds) (per-host)
2025-10-05 20:34:19,212 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.041266s (batch_requests_ready=0.000434s, total_serialization_initiated=0.039421s, others=0.001411s)
2025-10-05 20:34:19,213 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 95.3 MiB/s (total bytes: 4.0 MiB) (time elapsed: 42 milliseconds) (per-host)
2025-10-05 20:34:19,214 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.042618s (batch_requests_ready=0.000598s, total_serialization_initiated=0.040283s, others=0.001736s)
2025-10-05 20:34:19,215 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 558 Bytes/s (total bytes: 24 Bytes) (time elapsed: 42 milliseconds) (per-host)
2025-10-05 20:34:19,216 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.043751s (batch_requests_ready=0.000308s, total_serialization_initiated=0.041589s, others=0.001853s)
2025-10-05 20:34:19,218 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 272.1 MiB/s (total bytes: 12.0 MiB) (time elapsed: 44 milliseconds) (per-host)
2025-10-05 20:34:19,219 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.045207s (batch_requests_ready=0.001276s, total_serialization_initiated=0.041893s, others=0.002038s)
2025-10-05 20:34:19,219 - INFO - [process=0][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.112145s (all_items=0.000021s, per_item={'ema_rngs': '0.00001311', 'ema_state': '0.00000334', 'opt_rngs': '0.00000262', 'opt_state': '0.00000215'}, temp_paths=0.112123)
2025-10-05 20:34:19,278 - INFO - [process=0][thread=array_type_handler] Wrote 9 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/ema_state/array_metadatas/process_0
2025-10-05 20:34:19,280 - INFO - [process=0][thread=array_type_handler] Wrote 3 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/ema_rngs/array_metadatas/process_0
2025-10-05 20:34:19,282 - INFO - [process=0][thread=array_type_handler] Wrote 3 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/opt_rngs/array_metadatas/process_0
2025-10-05 20:34:19,324 - INFO - [process=0][thread=array_type_handler] Wrote 23 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/opt_state/array_metadatas/process_0
2025-10-05 20:34:19,821 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.608847s (commit=0.419599s, array_metadata_write=0.189248s)
2025-10-05 20:34:19,822 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 36 Bytes/s (total bytes: 24 Bytes) (time elapsed: 651 milliseconds) (per-host)
2025-10-05 20:34:19,980 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.763334s (commit=0.564349s, array_metadata_write=0.198985s)
2025-10-05 20:34:19,993 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.778045s (commit=0.586967s, array_metadata_write=0.191078s)
2025-10-05 20:34:19,993 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 4.9 MiB/s (total bytes: 4.0 MiB) (time elapsed: 821 milliseconds) (per-host)
2025-10-05 20:34:19,994 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 29 Bytes/s (total bytes: 24 Bytes) (time elapsed: 821 milliseconds) (per-host)
2025-10-05 20:34:20,073 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.854778s (commit=0.669826s, array_metadata_write=0.184951s)
2025-10-05 20:34:20,074 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 13.4 MiB/s (total bytes: 12.0 MiB) (time elapsed: 900 milliseconds) (per-host)
2025-10-05 20:34:20,528 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
2025-10-05 20:34:20,706 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.244012s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/ema_rngs
2025-10-05 20:34:20,706 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/ema_rngs
2025-10-05 20:34:20,866 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
2025-10-05 20:34:21,034 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.241950s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/ema_state
2025-10-05 20:34:21,034 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/ema_state
2025-10-05 20:34:21,211 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
2025-10-05 20:34:21,360 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.232543s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/opt_rngs
2025-10-05 20:34:21,361 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/opt_rngs
2025-10-05 20:34:21,538 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
2025-10-05 20:34:21,685 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.231616s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/opt_state
2025-10-05 20:34:21,686 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/opt_state
2025-10-05 20:34:21,771 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000
2025-10-05 20:34:22,091 - INFO - [process=0][thread=MainThread] Finished saving checkpoint (finalized tmp dir) to `gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000`.
2025-10-05 20:34:22,092 - INFO - Finished synchronous save in 3.23 seconds to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000
2025-10-05 20:34:22,198 - INFO - Deleted step 5000.
2025-10-05 20:34:22,199 - INFO - [process=0][thread=MainThread][step=10000] CheckpointManager Save Finalize is syncing with other hosts...
2025-10-05 20:34:22,199 - INFO - [process=0][thread=MainThread][step=10000] CheckpointManager Save Finalize is done on all hosts.
2025-10-05 20:34:22,199 - INFO - [process=0][thread=MainThread][step=10000] Finished synchronous save.
2025-10-05 20:34:22,200 - INFO - {'step': 10000, 'event_type': 'save', 'directory': 'gs://solaris-east5/georgy/fsdp-jax/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': True, 'wait_for_prev_start_time': 1759696458.8044248, 'wait_for_prev_duration_secs': 0.00032973289489746094, 'checkpointer_blocking_start_time': 1759696458.858993, 'checkpointer_blocking_duration_secs': 3.2333545684814453, 'get_old_steps_start_time': 1759696462.0923612, 'get_old_steps_duration_secs': 9.632110595703125e-05, 'checkpoint_manager_blocking_start_time': 1759696458.803941, 'checkpoint_manager_blocking_duration_secs': 3.396230697631836}
2025-10-05 20:34:22,200 - INFO - Checkpoint saved successfully
That’s it. Our model has trained for another 5000 steps.
As a final note, here are a few JAX lessons I learned along the way that may help you.
When you call a JIT-compiled function, as we do with train_step_fn() inside train_loop(), JAX dispatches it asynchronously. The call returns right away with arrays whose values may not be computed yet, and your host Python code keeps running. This lets the host queue up the next operations while the accelerator is still busy with the current ones. JAX only makes the host wait when it needs a concrete value, for example when you print an array, convert it to NumPy, or call float() on it. As a consequence, if you measure the execution time of train_step_fn() by calling time.time() before and after it, you will see an unrealistically small number, because you are timing the dispatch rather than the computation. To get an accurate measurement, block on the output with jax.block_until_ready(opt_state) before you take the second time reading. Only do this for temporary debugging, though. Production code should not call block_until_ready() inside the training loop: it stalls the host on every step, prevents it from dispatching the next step ahead of time, and makes the loop slower overall.
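Here is a minimal sketch of the timing pitfall. The function chained_matmul() is a hypothetical stand-in for train_step_fn(), the matrix size and iteration count are arbitrary assumptions, and it uses time.perf_counter() as a higher-resolution alternative to time.time(). The absolute numbers depend on your hardware; the point is that the unblocked measurement mostly captures dispatch.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def chained_matmul(x):
    # Hypothetical stand-in for train_step_fn(): a few matmuls so that the
    # computation takes noticeably longer than dispatching it.
    for _ in range(10):
        x = jnp.tanh(x @ x)
    return x


x = jnp.full((512, 512), 1.0 / 512, dtype=jnp.float32)

# Warm-up: the first call traces and compiles, which we don't want to time.
jax.block_until_ready(chained_matmul(x))

# Misleading: only measures how long it takes to dispatch the call.
start = time.perf_counter()
out = chained_matmul(x)
dispatch_time = time.perf_counter() - start
jax.block_until_ready(out)  # drain the queue before the next measurement

# Accurate: wait for the device to finish before stopping the clock.
start = time.perf_counter()
out = jax.block_until_ready(chained_matmul(x))
total_time = time.perf_counter() - start

print(f"without blocking:       {dispatch_time * 1e3:.3f} ms")
print(f"with block_until_ready: {total_time * 1e3:.3f} ms")
```

Once you have the numbers you need, remove the blocking calls again so the training loop can keep dispatching steps ahead of the device.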
You cannot use logging.info() or print() to inspect values inside your JIT-compiled functions. These Python calls run only once, while JAX traces the function, and they see abstract tracers rather than concrete values. To print debugging info at run time, use jax.debug.print() instead. There are two things to keep in mind, though. First, if your function fails somewhere after the print and never finishes, you may not see that print: the JITed function needs to finish for you to see all of its output. Second, any jax.debug.* call in a JITed function can slow down its execution significantly, so make sure none of them make it into your production code.
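A small sketch of the difference between print() and jax.debug.print() under jax.jit. The function debug_step() is hypothetical and only doubles its input; what matters is when each print fires.

```python
import jax
import jax.numpy as jnp


@jax.jit
def debug_step(x):
    y = x * 2.0
    # Plain print() runs only while JAX traces the function: it shows a tracer
    # (shape and dtype), not the runtime value, and it does not run again on
    # later calls that reuse the compiled program.
    print("print() sees:", y)
    # jax.debug.print() is staged into the compiled program and prints the
    # concrete value on every call.
    jax.debug.print("jax.debug.print sees: {}", y)
    return y


x = jnp.arange(3.0)
first = debug_step(x)   # traces: both lines print
second = debug_step(x)  # cached: only the jax.debug.print line appears
```

The same trace-time behavior applies to logging.info(), which is why it cannot report values computed inside a JITed function.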
This concludes our JAX FSDP tutorial. I hope it was useful. Happy JAXing!