{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "UEO53KYw8oBT" }, "source": [ "# FSDP in JAX NNX\n", "\n", "If you are faced with the daunting task of implementing production-level [FSDP](https://engineering.fb.com/2021/07/15/open-source/fsdp/) (Fully Sharded Data Parallel) in JAX NNX, this tutorial is for you. This notebook guides you through the process step by step.\n", "\n", "You will learn how to implement a fully working FSDP setup on TPU, with all critical operations JIT-compiled. It shards all weights evenly across the devices and also splits each batch across them, as in DDP (Distributed Data Parallel). You will also see how to use Orbax distributed checkpointing to save to and restore from local disk or a Google Cloud Storage bucket, how to set up reproducible `nnx.Rngs` for noise generation and dropout, and how to maintain an EMA model that is also sharded.\n", "\n", "Let's begin." ] }, { "cell_type": "markdown", "metadata": { "id": "PQqZeu8g9jZK" }, "source": [ "First, let's set a configuration variable. It determines which packages we install." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QHaxHmpjSl7E" }, "outputs": [], "source": [ "COLAB=True # Set this to False if you are running this notebook outside of Google Colab" ] }, { "cell_type": "markdown", "metadata": { "id": "Ry2NvkaRFRMJ" }, "source": [ "Install the Python dependencies based on the variable above."
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "uyglrYBymZpJ", "outputId": "0fe5bd4f-e70c-4b4f-eeb4-ed7ed4306e94" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Installing ['jax[tpu]==0.5.1', 'optax==0.2.4', 'orbax-checkpoint==0.11.16', 'flax==0.10.4', 'numpy==1.26.4', 'torch==2.7.0', 'matplotlib==3.10.3', 'pillow==11.3.0', 'gcsfs==2025.9.0'] ...\n" ] }, { "data": { "text/plain": [ "0" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import sys\n", "import subprocess\n", "\n", "packages = [\"jax[tpu]==0.5.1\", \"optax==0.2.4\", \"orbax-checkpoint==0.11.16\", \"flax==0.10.4\"]\n", "if not COLAB:\n", " packages += [\"numpy==1.26.4\", \"torch==2.7.0\", \"matplotlib==3.10.3\", \"pillow==11.3.0\", \"gcsfs==2025.9.0\"]\n", "print(f\"Installing {packages} ...\")\n", "subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", *packages])" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "xSMLegmvLLJe" }, "outputs": [], "source": [ "import argparse\n", "import functools\n", "import logging\n", "import os\n", "from typing import Any, Generator, Tuple\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "import optax\n", "import orbax.checkpoint as ocp\n", "import torch\n", "from flax import nnx\n", "from jax import random\n", "from jax.experimental import mesh_utils\n", "from matplotlib.figure import Figure\n", "from torch.utils.data import DataLoader, Dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "t7DafZpO_lQz" }, "source": [ "Here, we define our 
hyperparameters and the other settings we are most likely to want to adjust." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-b6V44MUSl7E" }, "outputs": [], "source": [ "args = argparse.Namespace(\n", " experiment_name=\"fsdp\",\n", " gpu=False,\n", " steps=5_000,\n", " test_interval=1000,\n", " batch_size=256,\n", " log_interval=100,\n", " save_interval=2500,\n", " checkpoint_dir=os.path.abspath(\"checkpoints/\"),\n", " output_dir=os.path.abspath(\"outputs/\"),\n", " lr=1e-4,\n", " add_noise=False\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "3qzuX0HBFXEi" }, "source": [ "Enable INFO-level logging." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "_U5MtEG0Sl7F" }, "outputs": [], "source": [ "\n", "log_format = \"%(asctime)s - %(levelname)s - %(message)s\"\n", "\n", "logging.basicConfig(\n", " level=logging.INFO,\n", " format=log_format,\n", " handlers=[logging.StreamHandler()],\n", " force=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "DABCSUHbAFVq" }, "source": [ "At the very beginning of our program, before running any JAX computation, we need to initialize the JAX distributed runtime."
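, "\n",
"\n",
"After initialization, every process (typically one per host) runs this same program, and `jax.devices()` returns the devices of the whole slice rather than only the local ones. The hypothetical helper `describe_topology` below is a small sketch, not part of the training pipeline. It collects what a process can see through `jax.process_index()`, `jax.process_count()`, `jax.local_device_count()` and `jax.device_count()`. On a single machine it simply reports one process.\n",
"\n",
"```python\n",
"import jax\n",
"\n",
"\n",
"def describe_topology() -> dict:\n",
"    \"\"\"Hypothetical helper: summarize this process's view of the cluster.\"\"\"\n",
"    return {\n",
"        \"process_index\": jax.process_index(),  # which host this process is\n",
"        \"process_count\": jax.process_count(),  # number of processes (hosts)\n",
"        \"local_devices\": jax.local_device_count(),  # devices attached to this host\n",
"        \"global_devices\": jax.device_count(),  # devices across all hosts\n",
"    }\n",
"\n",
"\n",
"info = describe_topology()\n",
"# Print from process 0 only, so multi-host runs do not repeat the output.\n",
"if info[\"process_index\"] == 0:\n",
"    print(info)\n",
"```"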
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "H9wGpLY_bG5X", "outputId": "1442e9f6-92de-45b7-f7f7-35d7106fd7e1" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:2025-10-05 20:32:00,707:jax._src.distributed:130: Starting JAX distributed service on [::]:8476\n", "2025-10-05 20:32:00,707 - INFO - Starting JAX distributed service on [::]:8476\n", "INFO:2025-10-05 20:32:00,709:jax._src.distributed:147: Connecting to JAX distributed service on 10.202.0.129:8476\n", "2025-10-05 20:32:00,709 - INFO - Connecting to JAX distributed service on 10.202.0.129:8476\n" ] } ], "source": [ "jax.distributed.initialize()" ] }, { "cell_type": "markdown", "metadata": { "id": "_hKGFEVfFZC1" }, "source": [ "Here you should see the devices available to you. In my case, there are 4 TPU chips." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "uKUl0CebbKGi", "outputId": "d68e250c-e1f3-432b-cc0a-98edfb10187c" }, "outputs": [ { "data": { "text/plain": [ "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n", " TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),\n", " TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),\n", " TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jax.devices()" ] }, { "cell_type": "markdown", "metadata": { "id": "0AkPUd6HA5k_" }, "source": [ "Now let's define our JAX device mesh and its sharding axis. There is only one axis, `data`, and we use it for both the model and the data: the model parameters are sharded across all devices in the mesh, and each batch is split across the same devices. Concretely, if your mesh has `N` devices, each device stores `1/N` of the model weights and processes `1/N` of the global batch."
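, "\n",
"\n",
"To make the `1/N` statement concrete, here is a standalone sketch for a single host. It builds its own mesh, so it does not depend on the cells below and also runs on one CPU device. It places a toy global batch with `PartitionSpec(\"data\")` and inspects the per-device shards: each device ends up holding a contiguous `1/N` slice of the rows.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"from jax.sharding import Mesh, NamedSharding, PartitionSpec\n",
"\n",
"n = jax.device_count()\n",
"# A 1-D mesh whose single axis, \"data\", spans every device.\n",
"mesh = Mesh(np.array(jax.devices()), (\"data\",))\n",
"\n",
"# Toy global batch of 8 rows per device; dim 0 must be divisible by n.\n",
"x = jnp.arange(n * 8 * 4, dtype=jnp.float32).reshape(n * 8, 4)\n",
"x = jax.device_put(x, NamedSharding(mesh, PartitionSpec(\"data\")))\n",
"\n",
"# One (8, 4) shard per local device: rows are split, columns are not.\n",
"shard_shapes = [shard.data.shape for shard in x.addressable_shards]\n",
"print(shard_shapes)\n",
"```"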
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "37dcGAsVblut" }, "outputs": [], "source": [ "data_axis = \"data\"\n", "device_mesh = mesh_utils.create_device_mesh(\n", " (jax.device_count(),), devices=jax.devices()\n", ")\n", "mesh = jax.sharding.Mesh(device_mesh, (data_axis,))" ] }, { "cell_type": "markdown", "metadata": { "id": "1aOToQrtB3ZM" }, "source": [ "Here, we define two kinds of sharding: `data_sharding` splits an array along its first dimension across the `data` axis, and `repl_sharding` replicates the full array on every device." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "gdxIDFYYb1TT" }, "outputs": [], "source": [ "data_sharding = jax.sharding.NamedSharding(\n", " mesh, jax.sharding.PartitionSpec(data_axis)\n", ")\n", "repl_sharding = jax.sharding.NamedSharding(\n", " mesh, jax.sharding.PartitionSpec()\n", ")\n" ] }, { "cell_type": "markdown", "metadata": { "id": "H_Fwi14hCIBQ" }, "source": [ "In this tutorial, we use a simple MLP as our model. It learns a function that maps a real number to a real number. It also has a dropout layer, which shows how our implementation handles RNGs, something a real-world model typically needs."
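, "\n",
"\n",
"Below is a minimal standalone sketch of how `nnx.Dropout` consumes RNGs. It uses a toy input and an illustrative rate of 0.5 instead of the MLP's 0.1. Each call draws a fresh key from the `dropout` stream of the `nnx.Rngs` object the layer was given, so consecutive calls produce different masks. Calling `.eval()` on the module switches it to deterministic mode, in which dropout is the identity.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"from flax import nnx\n",
"\n",
"# The dropout stream is seeded independently of the default stream.\n",
"rngs = nnx.Rngs(0, dropout=1)\n",
"drop = nnx.Dropout(rate=0.5, rngs=rngs)\n",
"\n",
"x = jnp.ones((4, 8))\n",
"y1 = drop(x)  # draws a key from the dropout stream and advances its counter\n",
"y2 = drop(x)  # a new key, hence (almost surely) a different mask\n",
"\n",
"# Kept activations are rescaled by 1 / (1 - rate), i.e. 2.0 here.\n",
"print(np.unique(np.asarray(y1)))\n",
"\n",
"drop.eval()  # sets deterministic=True\n",
"y3 = drop(x)  # identical to x: no dropout in evaluation mode\n",
"```"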
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "o1Lj5lHGSl7F" }, "outputs": [], "source": [ "IN_FEATURES = 1\n", "OUT_FEATURES = 1\n", "HIDDEN_DIM = 1024" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "jEQsFGXcSl7F" }, "outputs": [], "source": [ "class MLP(nnx.Module):\n", " \"\"\"A Multi-Layer Perceptron (MLP) neural network using Flax NNX.\n", "\n", " This is a simple feedforward neural network with two hidden layers,\n", " ReLU activations, and dropout regularization.\n", "\n", " Args:\n", " din: Number of input features.\n", " dmid: Number of hidden units in each hidden layer.\n", " dout: Number of output features.\n", " rngs: Random number generators for parameter initialization and dropout.\n", " \"\"\"\n", "\n", " def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs) -> None:\n", " \"\"\"Initialize the MLP with specified dimensions.\n", "\n", " Args:\n", " din: Number of input features.\n", " dmid: Number of hidden units in each hidden layer.\n", " dout: Number of output features.\n", " rngs: Random number generators for parameter initialization and dropout.\n", " \"\"\"\n", " self.fc1 = nnx.Linear(din, dmid, rngs=rngs)\n", " self.fc2 = nnx.Linear(dmid, dmid, rngs=rngs)\n", " self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)\n", " self.fc3 = nnx.Linear(dmid, dout, rngs=rngs)\n", " self.rngs = rngs\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " \"\"\"Forward pass through the MLP.\n", "\n", " Args:\n", " x: Input tensor of shape (batch_size, din).\n", "\n", " Returns:\n", " Output tensor of shape (batch_size, dout).\n", " \"\"\"\n", " x = self.fc1(x)\n", " x = nnx.relu(x)\n", " x = self.fc2(x)\n", " x = nnx.relu(x)\n", " x = self.dropout(x)\n", " x = self.fc3(x)\n", " return x" ] }, { "cell_type": "markdown", "metadata": { "id": "PboRJpsGCwe1" }, "source": [ "Many modern codebases have an EMA (Exponential Moving Average) model. 
We will keep one as well, to show how it fits into our FSDP implementation. To initialize the EMA, we take the model state and replace every array in it with zeros of the same shape and dtype." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "m0jNsPY7Sl7F" }, "outputs": [], "source": [ "def init_ema(model: nnx.Module) -> nnx.State:\n", " \"\"\"Initialize exponential moving average (EMA) state for a model.\n", "\n", " Creates a zero-initialized state tree with the same structure as the model's state.\n", "\n", " Args:\n", " model: The neural network model to create EMA state for.\n", "\n", " Returns:\n", " EMA state with the same structure as the model state, but zero-initialized.\n", " \"\"\"\n", " ema_state = jax.tree.map(lambda x: jnp.zeros_like(x), nnx.state(model))\n", " return ema_state" ] }, { "cell_type": "markdown", "metadata": { "id": "uNZ3QmI3DiZQ" }, "source": [ "This is the core initialization function. It creates everything we want to shard with FSDP: the model, the optimizer, and the EMA."
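, "\n",
"\n",
"The EMA created by `init_ema` gets advanced after optimizer steps. The sketch below shows the standard update `ema = decay * ema + (1 - decay) * param`. This rule is an assumption here and not necessarily the exact one used later in this notebook. `jax.tree.map` applies it leaf by leaf, toy dictionaries stand in for the `nnx.State` trees, and `DECAY = 0.999` is an illustrative value. Because the update is elementwise, each device can update its own shard without communication when the EMA and the parameters share the same sharding. Starting from zeros biases the average towards zero early on; dividing by `1 - decay**t`, as Adam does for its moment estimates, is one common correction.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"DECAY = 0.999  # illustrative value\n",
"\n",
"\n",
"@jax.jit\n",
"def ema_update(ema, params, decay):\n",
"    \"\"\"One EMA step, per leaf: ema <- decay * ema + (1 - decay) * params.\"\"\"\n",
"    return jax.tree.map(lambda e, p: decay * e + (1.0 - decay) * p, ema, params)\n",
"\n",
"\n",
"# Toy pytrees standing in for the model state and the EMA state.\n",
"params = {\"kernel\": jnp.ones((4, 4)), \"bias\": jnp.full((4,), 2.0)}\n",
"ema = jax.tree.map(jnp.zeros_like, params)  # zero init, like init_ema\n",
"\n",
"num_steps = 100\n",
"for _ in range(num_steps):\n",
"    ema = ema_update(ema, params, DECAY)\n",
"\n",
"# Undo the pull towards the zero initialization (bias correction).\n",
"corrected = jax.tree.map(lambda e: e / (1.0 - DECAY**num_steps), ema)\n",
"```"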
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "S3gE3Ht0Sl7F" }, "outputs": [], "source": [ "def init(learning_rate: float) -> Tuple[nnx.GraphDef, nnx.State, nnx.State]:\n", " \"\"\"Initialize the model, optimizer, and EMA state.\n", "\n", " Creates a new MLP model, wraps it in an AdamW optimizer, and initializes\n", " the exponential moving average state.\n", "\n", " Args:\n", " learning_rate: Learning rate for the AdamW optimizer.\n", "\n", " Returns:\n", " Tuple of (optimizer_graph, optimizer_state, ema_state).\n", " \"\"\"\n", " model = MLP(\n", " IN_FEATURES,\n", " HIDDEN_DIM,\n", " OUT_FEATURES,\n", " rngs=nnx.Rngs(0, dropout=random.key(1), noise=random.key(2)),\n", " )\n", " opt = nnx.Optimizer(\n", " model,\n", " optax.adamw(learning_rate=learning_rate),\n", " )\n", " opt_graph, opt_state = nnx.split(opt)\n", " ema_state = init_ema(model)\n", " return opt_graph, opt_state, ema_state" ] }, { "cell_type": "markdown", "metadata": { "id": "419932OAD9EG" }, "source": [ "It's convenient to have this as a self-contained function whose only input is the learning rate, which can be bound in advance (for example with `functools.partial`). We will later pass it to `jax.eval_shape()` to figure out the shapes of the states we want to shard without allocating any device memory."
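, "\n",
"\n",
"The pattern can be sketched on a toy example. In the standalone snippet below, the hypothetical zero-argument `toy_init` stands in for `init`, and `jax.eval_shape` returns `ShapeDtypeStruct`s without allocating device memory. An assumed rule, not necessarily the one this notebook uses, then turns each shape into a sharding: split the largest dimension that divides evenly across the devices, and replicate arrays, such as scalars, that have no such dimension. Passing the result as `out_shardings` to `jax.jit` creates every array directly in its sharded layout.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"from jax.sharding import Mesh, NamedSharding, PartitionSpec\n",
"\n",
"mesh = Mesh(np.array(jax.devices()), (\"data\",))\n",
"n = jax.device_count()\n",
"\n",
"\n",
"def toy_init():\n",
"    # Hypothetical stand-in for init(): shapes loosely mirror the MLP's state.\n",
"    return {\n",
"        \"fc1_kernel\": jnp.zeros((1, 1024)),\n",
"        \"fc2_kernel\": jnp.zeros((1024, 1024)),\n",
"        \"fc3_bias\": jnp.zeros((1,)),\n",
"        \"count\": jnp.zeros((), jnp.int32),\n",
"    }\n",
"\n",
"\n",
"# Abstract evaluation: only shapes and dtypes, no arrays are allocated.\n",
"abstract = jax.eval_shape(toy_init)\n",
"\n",
"\n",
"def choose_sharding(leaf: jax.ShapeDtypeStruct) -> NamedSharding:\n",
"    # Assumed rule: shard the largest dimension that splits evenly over n devices.\n",
"    candidates = [i for i, d in enumerate(leaf.shape) if d % n == 0]\n",
"    if not candidates:\n",
"        return NamedSharding(mesh, PartitionSpec())  # replicate\n",
"    axis = max(candidates, key=lambda i: leaf.shape[i])\n",
"    spec = [None] * len(leaf.shape)\n",
"    spec[axis] = \"data\"\n",
"    return NamedSharding(mesh, PartitionSpec(*spec))\n",
"\n",
"\n",
"shardings = jax.tree.map(choose_sharding, abstract)\n",
"# Materialize the state directly in its sharded layout.\n",
"state = jax.jit(toy_init, out_shardings=shardings)()\n",
"```"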
] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "mW_1idUwcZz1", "outputId": "8da35d3f-088b-491b-e959-582280b76d1b" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-10-05 20:32:08,664 - INFO - Opt state shape: \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'model'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'dropout'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'rngs'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'default'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngCount\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(), dtype=uint32),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (8 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngKey\u001b[0m,\n", " 
\u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(), dtype=key),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'dropout'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngCount\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(), dtype=uint32),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'dropout'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (8 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngKey\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(), dtype=key),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'dropout'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'noise'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: 
\u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngCount\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(), dtype=uint32),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'noise'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (8 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngKey\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(), dtype=key),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'noise'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc1'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1024,), dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: 
\u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1, 1024), dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc2'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1024,), dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,048,576 (4.2 MB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1024, 1024), dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc3'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " 
\u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1,), dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1024, 1), dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'opt_state'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m0\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptArray\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(), dtype=int32)\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'mu'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'fc1'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: 
\u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1024,), dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1, 1024), dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc2'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1024,), dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " 
\u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,048,576 (4.2 MB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1024, 1024), dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc3'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1,), dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1024, 1), dtype=float32),\n", " 
\u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'nu'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'fc1'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1024,), dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1, 1024), dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc2'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " 
\u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1024,), dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,048,576 (4.2 MB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1024, 1024), dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc3'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1,), dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " 
\u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1024, 1), dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'step'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptState\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(), dtype=uint32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", "\u001b[38;2;255;213;3m})\u001b[0m\n", "2025-10-05 20:32:08,665 - INFO - EMA state shape: \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'dropout'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'rngs'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'default'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " 
\u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngCount\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(), dtype=uint32),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (8 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngKey\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(), dtype=key),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'dropout'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngCount\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(), dtype=uint32),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'dropout'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " 
\u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (8 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngKey\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(), dtype=key),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'dropout'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'noise'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngCount\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(), dtype=uint32),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'noise'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (8 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngKey\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(), dtype=key),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'noise'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " 
\u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc1'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1024,), dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1, 1024), dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc2'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1024,), dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: 
\u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,048,576 (4.2 MB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1024, 1024), dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc3'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1,), dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(1024, 1), dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", "\u001b[38;2;255;213;3m})\u001b[0m\n" ] } ], "source": [ "init_fn = functools.partial(init, args.lr)\n", "_, opt_state_shape, ema_state_shape = jax.eval_shape(init_fn)\n", "logging.info(f\"Opt state shape: {opt_state_shape}\")\n", "logging.info(f\"EMA state shape: {ema_state_shape}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { 
"base_uri": "https://localhost:8080/" }, "id": "AsW6gdkOGdmo", "outputId": "ebec9032-7f3f-47a4-fef4-da0ee86fce7e" }, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": { "id": "J_iDlqQ5HI8F" }, "source": [ "As you can see, our EMA state has all the same parameters as the model that is stored inside the optimizer. The EMA state contains the rngs objects as well (more about it later)." ] }, { "cell_type": "markdown", "metadata": { "id": "Fxwbb0RaEimH" }, "source": [ "The next three functions annotate the state shapes with sharding info based on your mesh. What that means is that it determines whether a weight should be sharded (only weights with the number of parameters >= `1024*1024` are sharded). If it should be sharded, it virtually splits it into shards and assigns them to devices. We will see what it looks like below." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "H1y_d7sjSl7F" }, "outputs": [], "source": [ "def fsdp(\n", " axis: str,\n", " cur_spec: Tuple[Any, ...],\n", " mesh: jax.sharding.Mesh,\n", " var_state: nnx.VariableState,\n", " min_size_to_shard: int,\n", ") -> Tuple[Any, ...]:\n", " \"\"\"Implement Fully Sharded Data Parallel (FSDP) sharding strategy.\n", "\n", " Determines how to shard a parameter tensor across devices. 
Shards the largest\n", " dimension that is evenly divisible by the mesh axis size, provided the tensor\n", " has at least min_size_to_shard elements.\n", "\n", " Args:\n", " axis: Name of the mesh axis to shard along.\n", " cur_spec: Current partition specification.\n", " mesh: JAX device mesh.\n", " var_state: Variable state containing the parameter tensor.\n", " min_size_to_shard: Minimum number of elements a tensor must have to be sharded.\n", "\n", " Returns:\n", " Updated partition specification with sharding applied if appropriate.\n", " \"\"\"\n", " arr = var_state.value\n", " if arr is None:\n", " return cur_spec\n", " shape = tuple(arr.shape)\n", " axis_size = mesh.shape[axis]\n", " # Small tensors (e.g. biases and RNG state) stay replicated.\n", " if arr.size < min_size_to_shard:\n", " return cur_spec\n", " # Try dimensions from largest to smallest and shard the first one that\n", " # splits evenly across the mesh axis.\n", " dim_indices = sorted(range(len(shape)), key=lambda i: shape[i], reverse=True)\n", " for i in dim_indices:\n", " if cur_spec[i] is None and shape[i] % axis_size == 0:\n", " new_spec = list(cur_spec)\n", " new_spec[i] = axis\n", " return tuple(new_spec)\n", " # No dimension divides evenly: keep the tensor replicated.\n", " return cur_spec\n", "\n", "def flatten_state(\n", " state: nnx.State, path: Tuple[str, ...] 
= ()\n", ") -> Generator[Tuple[str, nnx.VariableState], None, None]:\n", " \"\"\"Recursively flatten a nested state tree into (name, variable_state) pairs.\n", "\n", " Traverses the state tree and yields each variable with its hierarchical path name.\n", "\n", " Args:\n", " state: The state tree to flatten (can be nested).\n", " path: Current path in the hierarchy (used for recursion).\n", "\n", " Yields:\n", " Tuples of (path_name, variable_state) for each leaf variable.\n", " \"\"\"\n", " if isinstance(state, nnx.VariableState):\n", " name = \"/\".join(str(p) for p in path)\n", " yield name, state\n", " elif hasattr(state, \"items\"):\n", " for key, subtree in state.items():\n", " yield from flatten_state(subtree, path + (key,))\n", " elif isinstance(state, (list, tuple)):\n", " for idx, subtree in enumerate(state):\n", " yield from flatten_state(subtree, path + (str(idx),))\n", "\n", "def infer_sharding(\n", " state: nnx.State,\n", " mesh: jax.sharding.Mesh,\n", " axis: str,\n", " min_size_to_shard: int = 2**20,\n", ") -> nnx.State:\n", " \"\"\"Infer optimal sharding strategy for a model state using FSDP.\n", "\n", " Analyzes each parameter in the state and determines the best sharding strategy\n", " based on tensor size and dimensions. 
Creates a sharding tree that matches\n", " the structure of the input state.\n", "\n", " Args:\n", " state: Model state to create sharding for.\n", " mesh: JAX device mesh for distributed computation.\n", " axis: Name of the mesh axis for sharding.\n", " min_size_to_shard: Minimum tensor size to consider for sharding.\n", "\n", " Returns:\n", " Sharding tree with the same structure as the input state.\n", " \"\"\"\n", " flat_params = list(flatten_state(state))\n", " vars_states = [vs for _, vs in flat_params]\n", "\n", " specs = [\n", " (None,) * vs.value.ndim if vs.value is not None else () for vs in vars_states\n", " ]\n", "\n", " for i, _ in enumerate(flat_params):\n", " specs[i] = fsdp(axis, specs[i], mesh, vars_states[i], min_size_to_shard)\n", "\n", " shardings = [\n", " jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec))\n", " for spec in specs\n", " ]\n", "\n", " sharding_tree = jax.tree_util.tree_unflatten(\n", " jax.tree_util.tree_structure(\n", " state, is_leaf=lambda x: isinstance(x, nnx.VariableState)\n", " ),\n", " shardings,\n", " )\n", " return sharding_tree" ] }, { "cell_type": "markdown", "metadata": { "id": "l0LKLHKWHkLL" }, "source": [ "Here, we call the top-level function `infer_sharding()` to get the state sharding objects." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "s7x4tKAycwam", "outputId": "51e11d33-da4a-4213-f8a8-f7e1e6a49402" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-10-05 20:32:08,680 - INFO - Opt state sharding: \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'model'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'dropout'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'rngs'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'default'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device),\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'dropout'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device),\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'noise'\u001b[0m\u001b[38;2;212;212;212m: 
\u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device),\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc1'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, None), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc2'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec('data', None), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc3'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), 
spec=PartitionSpec(None, None), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'opt_state'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m0\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device),\n", " \u001b[38;2;156;220;254m'mu'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'fc1'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, None), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc2'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec('data', None), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc3'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: 
\u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, None), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'nu'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'fc1'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, None), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc2'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec('data', None), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc3'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, None), 
memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'step'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device)\n", "\u001b[38;2;255;213;3m})\u001b[0m\n", "2025-10-05 20:32:08,681 - INFO - EMA state sharding: \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'dropout'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'rngs'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'default'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device),\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'dropout'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device),\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'noise'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " 
\u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device),\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc1'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, None), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc2'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec('data', None), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc3'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None,), memory_kind=device),\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0mNamedSharding(mesh=Mesh('data': 4), spec=PartitionSpec(None, None), memory_kind=device)\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", 
"\u001b[38;2;255;213;3m})\u001b[0m\n" ] } ], "source": [ "opt_state_sharding = infer_sharding(opt_state_shape, mesh, data_axis)\n", "ema_state_sharding = infer_sharding(ema_state_shape, mesh, data_axis)\n", "logging.info(f\"Opt state sharding: {opt_state_sharding}\")\n", "logging.info(f\"EMA state sharding: {ema_state_sharding}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "uY005CEDHzQZ" }, "source": [ "As you can see, every array except the `fc2` kernel gets a replicated spec: `PartitionSpec()` for scalars such as the RNG counts and keys and the step counter, `PartitionSpec(None,)` for the biases, and `PartitionSpec(None, None)` for the `fc1` and `fc3` kernels. This means these arrays won't be sharded, which is what we want because their sizes are below `1024*1024`. Only the `fc2` kernel, which maps from `1024` to `1024` features, reaches the threshold and gets `PartitionSpec('data', None)`. This means its weight array of shape `(1024, 1024)` is sharded along its first dimension over the `data` mesh axis, so each of the 4 devices holds a `(256, 1024)` block. The second dimension isn't sharded because sharding a single dimension is already enough to distribute the array evenly across the devices. The optimizer moments (`mu` and `nu`) and the EMA model get exactly the same specs as the parameters they mirror." ] }, { "cell_type": "markdown", "metadata": { "id": "dWHXrAh9JUgU" }, "source": [ "Now it's time to materialize our NNX modules on all our devices. The `jax.jit()` wrapper is crucial here: because `init_fn` is compiled with `out_shardings`, each array is created directly on the devices that own its shards when the compiled function runs. This ensures the entire state never materializes on a single device. Without the wrapper, JAX would first materialize the objects on a single device and only then shard them, so if your model doesn't fit into a single device's memory, the program would OOM at this point."
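, "\n", "\n",
"A minimal sketch of the same pattern for a single array may help. It is an illustration, not this notebook's `init_fn`. It assumes a 1-D mesh over all devices with one axis named `data` and a device count that divides 1024. The names `init_kernel`, `row_sharding` and `kernel` are hypothetical:\n",
"\n",
"```python\n",
"import jax\n",
"import numpy as np\n",
"from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n",
"\n",
"# Hypothetical 1-D mesh over every visible device, with a single 'data' axis.\n",
"mesh = Mesh(np.array(jax.devices()), axis_names=('data',))\n",
"# Split rows across 'data'; assumes len(jax.devices()) divides 1024.\n",
"row_sharding = NamedSharding(mesh, P('data', None))\n",
"\n",
"def init_kernel():\n",
"    # Traced under jit, so XLA can create each row block on its own device.\n",
"    return jax.random.normal(jax.random.PRNGKey(0), (1024, 1024))\n",
"\n",
"kernel = jax.jit(init_kernel, out_shardings=row_sharding)()\n",
"for shard in kernel.addressable_shards:\n",
"    print(shard.device, shard.data.shape)\n",
"```\n",
"\n",
"On a 4-device mesh, each printed shard should have shape `(256, 1024)`. This matches the `PartitionSpec('data', None)` layout that `infer_sharding` assigned to the `fc2` kernel."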
] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "yd6aXNN6c6m3" }, "outputs": [], "source": [ "opt_graph, opt_state, ema_state = jax.jit(\n", " init_fn,\n", " out_shardings=(repl_sharding, opt_state_sharding, ema_state_sharding),\n", ")()" ] }, { "cell_type": "markdown", "metadata": { "id": "HIiftxQIKdCw" }, "source": [ "Let's define a helper debug logging function to see what our sharded states look like in more detail." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Nm-hlk1kSl7F", "outputId": "15d1bddc-e2cb-46a5-dcec-ecdf91dfffd8" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-10-05 20:32:09,510 - INFO - ── Shard ↦ device map: Opt state sharding ──\n", "2025-10-05 20:32:09,510 - INFO - model/dropout/rngs/default/count () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,511 - INFO - model/dropout/rngs/default/count () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,511 - INFO - model/dropout/rngs/default/count () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,511 - INFO - model/dropout/rngs/default/count () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,512 - INFO - model/dropout/rngs/default/key () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,512 - INFO - model/dropout/rngs/default/key () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,512 - INFO - model/dropout/rngs/default/key () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,512 - INFO - model/dropout/rngs/default/key () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,513 - INFO - model/dropout/rngs/dropout/count () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,513 - INFO - model/dropout/rngs/dropout/count () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,513 - INFO - model/dropout/rngs/dropout/count () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,513 - INFO - model/dropout/rngs/dropout/count () → 
TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,514 - INFO - model/dropout/rngs/dropout/key () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,514 - INFO - model/dropout/rngs/dropout/key () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,514 - INFO - model/dropout/rngs/dropout/key () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,514 - INFO - model/dropout/rngs/dropout/key () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,515 - INFO - model/dropout/rngs/noise/count () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,515 - INFO - model/dropout/rngs/noise/count () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,515 - INFO - model/dropout/rngs/noise/count () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,515 - INFO - model/dropout/rngs/noise/count () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,516 - INFO - model/dropout/rngs/noise/key () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,516 - INFO - model/dropout/rngs/noise/key () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,516 - INFO - model/dropout/rngs/noise/key () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,517 - INFO - model/dropout/rngs/noise/key () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,517 - INFO - model/fc1/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,517 - INFO - model/fc1/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,517 - INFO - model/fc1/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,518 - INFO - model/fc1/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,518 - INFO - model/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,518 - INFO - model/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,518 - INFO - model/fc1/kernel 
(slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,519 - INFO - model/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,519 - INFO - model/fc2/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,519 - INFO - model/fc2/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,519 - INFO - model/fc2/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,519 - INFO - model/fc2/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,520 - INFO - model/fc2/kernel (slice(0, 256, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,520 - INFO - model/fc2/kernel (slice(256, 512, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,520 - INFO - model/fc2/kernel (slice(512, 768, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,520 - INFO - model/fc2/kernel (slice(768, 1024, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,521 - INFO - model/fc3/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,521 - INFO - model/fc3/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,521 - INFO - model/fc3/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,521 - INFO - model/fc3/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,521 - INFO - model/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,522 - INFO - model/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,522 - INFO - model/fc3/kernel (slice(None, None, None), slice(None, None, None)) → 
TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,522 - INFO - model/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,522 - INFO - opt_state/0/count () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,523 - INFO - opt_state/0/count () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,523 - INFO - opt_state/0/count () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,523 - INFO - opt_state/0/count () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,523 - INFO - opt_state/0/mu/fc1/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,523 - INFO - opt_state/0/mu/fc1/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,524 - INFO - opt_state/0/mu/fc1/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,524 - INFO - opt_state/0/mu/fc1/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,524 - INFO - opt_state/0/mu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,524 - INFO - opt_state/0/mu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,525 - INFO - opt_state/0/mu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,525 - INFO - opt_state/0/mu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,525 - INFO - opt_state/0/mu/fc2/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,525 - INFO - opt_state/0/mu/fc2/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,525 - INFO - opt_state/0/mu/fc2/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,526 - INFO - opt_state/0/mu/fc2/bias (slice(None, None, 
None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,526 - INFO - opt_state/0/mu/fc2/kernel (slice(0, 256, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,526 - INFO - opt_state/0/mu/fc2/kernel (slice(256, 512, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,526 - INFO - opt_state/0/mu/fc2/kernel (slice(512, 768, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,527 - INFO - opt_state/0/mu/fc2/kernel (slice(768, 1024, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,527 - INFO - opt_state/0/mu/fc3/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,527 - INFO - opt_state/0/mu/fc3/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,527 - INFO - opt_state/0/mu/fc3/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,527 - INFO - opt_state/0/mu/fc3/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,528 - INFO - opt_state/0/mu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,528 - INFO - opt_state/0/mu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,528 - INFO - opt_state/0/mu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,528 - INFO - opt_state/0/mu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,529 - INFO - opt_state/0/nu/fc1/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,529 - INFO - opt_state/0/nu/fc1/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,529 - INFO - opt_state/0/nu/fc1/bias (slice(None, None, None),) → 
TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,529 - INFO - opt_state/0/nu/fc1/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,529 - INFO - opt_state/0/nu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,530 - INFO - opt_state/0/nu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,530 - INFO - opt_state/0/nu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,530 - INFO - opt_state/0/nu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,530 - INFO - opt_state/0/nu/fc2/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,531 - INFO - opt_state/0/nu/fc2/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,531 - INFO - opt_state/0/nu/fc2/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,531 - INFO - opt_state/0/nu/fc2/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,531 - INFO - opt_state/0/nu/fc2/kernel (slice(0, 256, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,531 - INFO - opt_state/0/nu/fc2/kernel (slice(256, 512, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,532 - INFO - opt_state/0/nu/fc2/kernel (slice(512, 768, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,532 - INFO - opt_state/0/nu/fc2/kernel (slice(768, 1024, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,532 - INFO - opt_state/0/nu/fc3/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,532 - INFO - opt_state/0/nu/fc3/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", 
"2025-10-05 20:32:09,532 - INFO - opt_state/0/nu/fc3/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,533 - INFO - opt_state/0/nu/fc3/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,533 - INFO - opt_state/0/nu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,533 - INFO - opt_state/0/nu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,533 - INFO - opt_state/0/nu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,534 - INFO - opt_state/0/nu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,534 - INFO - step () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,534 - INFO - step () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,534 - INFO - step () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,534 - INFO - step () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,535 - INFO - ── Shard ↦ device map: EMA state sharding ──\n", "2025-10-05 20:32:09,535 - INFO - dropout/rngs/default/count () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,535 - INFO - dropout/rngs/default/count () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,535 - INFO - dropout/rngs/default/count () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,536 - INFO - dropout/rngs/default/count () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,536 - INFO - dropout/rngs/default/key () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,536 - INFO - dropout/rngs/default/key () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,536 - INFO - dropout/rngs/default/key () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,537 - INFO - dropout/rngs/default/key () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,537 - INFO - 
dropout/rngs/dropout/count () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,537 - INFO - dropout/rngs/dropout/count () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,537 - INFO - dropout/rngs/dropout/count () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,538 - INFO - dropout/rngs/dropout/count () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,538 - INFO - dropout/rngs/dropout/key () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,538 - INFO - dropout/rngs/dropout/key () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,538 - INFO - dropout/rngs/dropout/key () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,539 - INFO - dropout/rngs/dropout/key () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,539 - INFO - dropout/rngs/noise/count () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,539 - INFO - dropout/rngs/noise/count () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,539 - INFO - dropout/rngs/noise/count () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,539 - INFO - dropout/rngs/noise/count () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,540 - INFO - dropout/rngs/noise/key () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,540 - INFO - dropout/rngs/noise/key () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,540 - INFO - dropout/rngs/noise/key () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,541 - INFO - dropout/rngs/noise/key () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,541 - INFO - fc1/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,541 - INFO - fc1/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,541 - INFO - fc1/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,541 - INFO - fc1/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,542 - INFO - fc1/kernel (slice(None, None, None), slice(None, None, None)) → 
TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,542 - INFO - fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,542 - INFO - fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,545 - INFO - fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,547 - INFO - fc2/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,547 - INFO - fc2/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,547 - INFO - fc2/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,547 - INFO - fc2/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,548 - INFO - fc2/kernel (slice(0, 256, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,548 - INFO - fc2/kernel (slice(256, 512, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,548 - INFO - fc2/kernel (slice(512, 768, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,548 - INFO - fc2/kernel (slice(768, 1024, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,549 - INFO - fc3/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,549 - INFO - fc3/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,549 - INFO - fc3/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,549 - INFO - fc3/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:32:09,549 - INFO - fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:32:09,550 - INFO - fc3/kernel (slice(None, None, None), slice(None, None, None)) → 
TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:32:09,550 - INFO - fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:32:09,550 - INFO - fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n" ] } ], "source": [ "def log_shard_map(tag: str, state: nnx.State) -> None:\n", " \"\"\"Log the sharding mapping of arrays to devices for debugging.\n", "\n", " Prints a detailed breakdown of how each parameter is sharded across devices,\n", " showing which array indices are stored on which devices.\n", "\n", " Args:\n", " tag: Descriptive tag for the logging output.\n", " state: Model state to analyze for sharding information.\n", " \"\"\"\n", " logging.info(f\"── Shard ↦ device map: {tag} ──\")\n", "\n", " for name, var in flatten_state(state):\n", " arr = var.value if isinstance(var, nnx.VariableState) else var\n", " for d, idx in arr.sharding.devices_indices_map(arr.shape).items():\n", " logging.info(f\" {name} {idx} → {d}\")\n", "\n", "if jax.process_index() == 0:\n", " log_shard_map(\"Opt state sharding\", opt_state)\n", " log_shard_map(\"EMA state sharding\", ema_state)" ] }, { "cell_type": "markdown", "metadata": { "id": "4tH2XIPGKvIS" }, "source": [ "As you can see, only the `fc2/kernel` arrays are split. This applies to the parameters, their optimizer moments and the EMA copy, each of which is split into `(256, 1024)` row blocks, one per device. Every other array is replicated in full on all devices." ] }, { "cell_type": "markdown", "metadata": { "id": "lvT9p4JNPJVd" }, "source": [ "With a bit of NNX model surgery, we define two graphs for our model, train and eval, which we will later use for training and evaluation, respectively.\n", "\n", "`train()` and `eval()` only change the graph (the static information of the model, such as whether dropout is active). They do not change the state.\n", "\n", "It's also worth noting that, at the global program level, we store our models (NNX modules) as a separate graph and state instead of as a single module, and only merge them right before we need to run a JAX operation on them. 
I find it's easier to think about and manage your models as two separate objects: graph (static) and state (dynamic) at all other times." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "f-DtJRWBd2j-" }, "outputs": [], "source": [ "opt = nnx.merge(opt_graph, opt_state)\n", "opt.model.train()\n", "opt_graph, opt_state = nnx.split(opt)\n", "opt.model.eval()\n", "model_graph_eval, _ = nnx.split(opt.model)" ] }, { "cell_type": "markdown", "metadata": { "id": "Gmw6qZlAPckS" }, "source": [ "Here we are initializing our distributed Orbax checkpointer." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6RAfzhI1eD55", "outputId": "1c3fe55e-d378-4329-ee42-d4da3006f4a7" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-10-05 20:32:09,563 - INFO - [thread=MainThread] Failed to get flag value for EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.\n", "2025-10-05 20:32:09,564 - INFO - [process=0][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers=None, handler_registry=None\n", "2025-10-05 20:32:09,564 - INFO - Initialized registry DefaultCheckpointHandlerRegistry({('metrics', ): , ('metrics', ): }).\n", "2025-10-05 20:32:09,565 - INFO - orbax-checkpoint version: 0.11.16\n", "2025-10-05 20:32:09,565 - INFO - [process=0] Created OpTracker for checkpoint_manager_save with operation id 1\n", "2025-10-05 20:32:09,819 - INFO - Created directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints\n", "2025-10-05 20:32:09,871 - INFO - [process=0][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=2500, max_to_keep=2, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix='fsdp', step_format_fixed_length=None, step_name_format=None, create=True, 
cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, enable_background_delete=False, read_only=False, enable_async_checkpointing=False, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False), root_directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints: \n" ] } ], "source": [ "ckpt_mngr = ocp.CheckpointManager(\n", " args.checkpoint_dir,\n", " options=ocp.CheckpointManagerOptions(\n", " save_interval_steps=args.save_interval,\n", " max_to_keep=2,\n", " step_prefix=args.experiment_name,\n", " enable_async_checkpointing=False,\n", " ),\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "iggaEsJLPmri" }, "source": [ "For this tutorial, we use a simple synthetic dataset of `(x, y)` pairs sampled from the sine function: `y = sin(x)`, with `x` drawn uniformly from `[-π, π]`. This is the function our model will learn to approximate."
] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "Jqto_JiYSl7F" }, "outputs": [], "source": [ "class SinDataset(Dataset):\n", " \"\"\"A PyTorch dataset that generates sine function data points.\n", "\n", " This dataset generates random x values from [-π, π] and computes y = sin(x).\n", " The dataset uses a seeded random number generator for reproducible results.\n", "\n", " Args:\n", " seed: Random seed for reproducible data generation.\n", " \"\"\"\n", "\n", " def __init__(self, seed: int) -> None:\n", " \"\"\"Initialize the dataset with a random seed.\n", "\n", " Args:\n", " seed: Random seed for data generation.\n", " \"\"\"\n", " self.seed = seed\n", " self.reset_seed()\n", "\n", " def reset_seed(self) -> None:\n", " \"\"\"Reset the random number generator to the initial seed.\n", "\n", " This is useful for ensuring reproducible evaluation data.\n", " \"\"\"\n", " self.rng = torch.Generator()\n", " self.rng.manual_seed(self.seed)\n", "\n", " def __len__(self) -> int:\n", " \"\"\"Return the length of the dataset.\n", "\n", " Returns:\n", " A very large number representing the dataset size.\n", " \"\"\"\n", " return 2**31 - 1\n", "\n", " def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:\n", " \"\"\"Generate a single data point.\n", "\n", " Args:\n", " idx: Index (unused, but required for Dataset interface).\n", "\n", " Returns:\n", " Tuple of (x, y) where x is a random value in [-π, π] and y = sin(x).\n", " \"\"\"\n", " x = torch.rand(1, generator=self.rng) * 2 * torch.pi - torch.pi\n", " y = torch.sin(x)\n", " return x.numpy(), y.numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "9DN-pcbfQO8E" }, "source": [ "Next, we convert our global batch size (summed over all JAX processes and devices) to a local batch size: the number of samples handled by the current JAX/Python process. Don't confuse it with the per-device (per-chip) batch size. A single JAX process usually has more than one device assigned, so on our 1-D `data` mesh each device ends up with `local_batch_size // jax.local_device_count()` samples."
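, "\n", "\n",
"As a worked example with assumed values: a global batch size of 128 on 2 processes with 4 local devices each. None of these numbers come from this notebook's configuration:\n",
"\n",
"```python\n",
"import jax\n",
"\n",
"global_batch_size = 128  # hypothetical value of args.batch_size\n",
"local_batch_size = global_batch_size // jax.process_count()\n",
"per_device_batch_size = local_batch_size // jax.local_device_count()\n",
"# With 2 processes x 4 local devices: local_batch_size == 64, per_device_batch_size == 16.\n",
"print(local_batch_size, per_device_batch_size)\n",
"```"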
] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "zN6EyeOtg55K" }, "outputs": [], "source": [ "local_batch_size = args.batch_size // jax.process_count()" ] }, { "cell_type": "markdown", "metadata": { "id": "2P7BEXmdQwLd" }, "source": [ "We are getting close to the training code. Below, we define our `train_step()` function. It merges the optimizer's graph definition and state back into an NNX `Optimizer` object, passes a batch through the model, computes the loss, and updates the parameters with the gradients.\n", "\n", "It's important to understand that `x` and `y` in this function are global arrays: their first (batch) dimension equals the global batch size, not the local one. All reductions we run on them, like `.mean()`, are computed over the entire global array.\n", "\n", "The function also optionally adds noise to the targets. This doesn't make much sense in our example problem, but working with noise is essential in some real-world models, such as diffusion models. It requires drawing from an RNG stream, and I added it here to show how our code supports it. The random values we sample are also global, so they don't repeat across JAX processes or devices. To see how our model handles the noise, set `args.add_noise=True`." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "OiGZWQrISl7G" }, "outputs": [], "source": [ "def train_step(\n", "    opt_graph: nnx.GraphDef,\n", "    opt_state: nnx.State,\n", "    x: jax.Array,\n", "    y: jax.Array,\n", "    add_noise: bool = False,\n", ") -> Tuple[nnx.State, jax.Array]:\n", "    \"\"\"Perform a single training step with gradient computation and parameter update.\n", "\n", "    Computes the forward pass, loss, gradients, and updates model parameters.\n", "    Optionally adds noise to the target values for data augmentation.\n", "\n", "    Args:\n", "        opt_graph: Optimizer graph definition (static structure).\n", "        opt_state: Optimizer state (parameters and optimizer state).\n", "        x: Input batch of shape (batch_size, input_dim).\n", "        y: Target batch of shape (batch_size, output_dim).\n", "        add_noise: Whether to add noise to targets for data augmentation.\n", "\n", "    Returns:\n", "        Tuple of (updated_optimizer_state, loss_value).\n", "    \"\"\"\n", "    optimizer = nnx.merge(opt_graph, opt_state)\n", "    model = optimizer.model\n", "\n", "    def loss_fn(model: MLP) -> jax.Array:\n", "        y_hat = model(x)\n", "        if add_noise:\n", "            noise_key = model.rngs[\"noise\"]()\n", "            noise = jax.random.normal(noise_key, y.shape)\n", "            y_noisy = y + noise\n", "            loss = jnp.mean((y_hat - y_noisy) ** 2)\n", "        else:\n", "            loss = jnp.mean((y_hat - y) ** 2)\n", "        return loss\n", "\n", "    grad_fn = nnx.value_and_grad(loss_fn)\n", "    loss, grads = grad_fn(model)\n", "    optimizer.update(grads)\n", "\n", "    _, opt_state = nnx.split(optimizer)\n", "\n", "    return opt_state, loss\n" ] }, { "cell_type": "markdown", "metadata": { "id": "GSwpOwb4UcXf" }, "source": [ "Now we can wrap our `train_step()` function with `jax.jit()` so that JAX compiles it into an optimized XLA program that runs across our cluster of devices. Arguments that JAX can't trace, like the Python boolean `add_noise` (index `4`), must be marked as static with `static_argnums`. 
Additionally, to save memory, we use `donate_argnums` to donate the argument at index `1`, the optimizer state. Since the function returns an updated version of that state, XLA can reuse the input's memory buffers for the output instead of allocating new ones; as a consequence, the donated input must not be used after the call. Lastly, with `out_shardings`, we specify that the first output, the updated optimizer state, should be sharded with `opt_state_sharding`, and that the second output, the scalar loss, should be replicated across all devices." ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "X5YsTlo-gUQ7" }, "outputs": [], "source": [ "train_step_fn = jax.jit(\n", "    train_step,\n", "    donate_argnums=(1,),\n", "    static_argnums=(4,),\n", "    out_shardings=(opt_state_sharding, repl_sharding),\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "hmwj7KgYVgec" }, "source": [ "Below, we define our test step, which works much like `train_step()`, except that it doesn't update the parameters and it also returns the predictions." ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "H3Lj0tHtSl7F" }, "outputs": [], "source": [ "def test_step(\n", "    model_graph: nnx.GraphDef,\n", "    model_state: nnx.State,\n", "    x: jax.Array,\n", "    y: jax.Array,\n", ") -> Tuple[jax.Array, jax.Array]:\n", "    \"\"\"Perform a single evaluation step without parameter updates.\n", "\n", "    Computes the forward pass and loss for evaluation purposes.\n", "\n", "    Args:\n", "        model_graph: Model graph definition (static structure).\n", "        model_state: Model state (parameters only, no optimizer state).\n", "        x: Input batch of shape (batch_size, input_dim).\n", "        y: Target batch of shape (batch_size, output_dim).\n", "\n", "    Returns:\n", "        Tuple of (loss_value, predictions).\n", "    \"\"\"\n", "    model = nnx.merge(model_graph, model_state)\n", "    y_hat = model(x)\n", "    loss = jnp.mean((y_hat - y) ** 2)\n", "    return loss, y_hat" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "id": "XnUfm8LlgZPt" }, "outputs": [], "source": [ "test_step_fn = jax.jit(\n", "    test_step,\n", "    out_shardings=(repl_sharding, data_sharding),\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": 
"Ml4Ek1hWVx6l" }, "source": [ "Here is our function to maintain the EMA model. It updates every `nnx.Param` in the EMA state as a weighted average of its current value and the corresponding model parameter, controlled by `ema_decay`. All other EMA state, like the RNG streams, is left unchanged." ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "id": "mnngVe5uSl7F" }, "outputs": [], "source": [ "def update_ema(\n", "    model_state: nnx.State,\n", "    ema_state: nnx.State,\n", "    ema_decay: float,\n", ") -> nnx.State:\n", "    \"\"\"Update exponential moving average (EMA) of model parameters.\n", "\n", "    Computes the exponential moving average using the formula:\n", "        ema_new = ema_decay * ema_old + (1 - ema_decay) * model_param\n", "\n", "    Args:\n", "        model_state: Current model state with updated parameters.\n", "        ema_state: Current EMA state to be updated.\n", "        ema_decay: Decay factor for EMA (typically close to 1.0, e.g., 0.9999).\n", "\n", "    Returns:\n", "        Updated EMA state.\n", "    \"\"\"\n", "\n", "    def update_param(p_model: jax.Array, p_ema: jax.Array) -> jax.Array:\n", "        return p_ema * ema_decay + p_model * (1 - ema_decay)\n", "\n", "    ema_state_no_rng = jax.tree.map(\n", "        update_param,\n", "        nnx.filter_state(model_state, nnx.Param),\n", "        nnx.filter_state(ema_state, nnx.Param),\n", "    )\n", "    ema_state = nnx.merge_state(ema_state, ema_state_no_rng)\n", "    return ema_state" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "NmjRGK19gcRj" }, "outputs": [], "source": [ "update_ema_fn = jax.jit(\n", "    update_ema,\n", "    out_shardings=ema_state_sharding,\n", "    donate_argnums=(1,),\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "yb1MsCxZWA-w" }, "source": [ "The next function is essential for DDP. Our JAX cluster consists of several JAX/Python processes, and each of them loads only `local_batch_size` samples. However, as we saw before, our `train_step()` and `test_step()` functions expect global arrays with the global batch size. This function bridges that gap. 
It takes each process's local batch, splits it evenly across that process's local devices, and assembles the per-device pieces into a single global array sharded along the batch axis." ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "id": "TXpo8OiXSl7F" }, "outputs": [], "source": [ "def make_fsarray_from_local_slice(\n", "    local_slice: jnp.ndarray,\n", "    global_devices: list[jax.Device],\n", "    axis: str,\n", ") -> jax.Array:\n", "    \"\"\"Create a globally sharded array from a local data slice.\n", "\n", "    Takes a local data slice and creates a globally sharded JAX array\n", "    by distributing the data across multiple devices and processes.\n", "\n", "    This function is adapted from:\n", "    https://github.com/google-research/big_vision/blob/0127fb6b337ee2a27bf4e54dea79cff176527356/big_vision/utils.py#L1388-L1409\n", "\n", "    Args:\n", "        local_slice: Local portion of the data on this process.\n", "        global_devices: List of all devices across all processes.\n", "        axis: Name of the axis for sharding.\n", "\n", "    Returns:\n", "        Globally sharded JAX array with proper device placement.\n", "    \"\"\"\n", "    mesh = jax.sharding.Mesh(global_devices, (axis,))\n", "    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(axis))\n", "    local_ds = mesh.local_devices\n", "\n", "    x = np.asarray(local_slice)\n", "    xs = jax.device_put(np.split(x, len(local_ds), axis=0), local_ds)\n", "\n", "    global_shape = (x.shape[0] * jax.process_count(), *x.shape[1:])\n", "    return jax.make_array_from_single_device_arrays(global_shape, sharding, xs)" ] }, { "cell_type": "markdown", "metadata": { "id": "hA4Eq5B2XbQJ" }, "source": [ "Finally, training. 
We define our training loop as a function because we will reuse it later.\n", "\n", "It takes a start step and the optimizer and EMA states, initializes the dataloaders, and then trains and evaluates our model according to the hyperparameters in `args`.\n", "\n", "Note that we call `make_fsarray_from_local_slice()` on each local batch returned by the dataloader before passing it to `train_step_fn()` or `test_step_fn()`.\n", "\n", "We update our EMA model after every train step, and periodically run evaluation and checkpoint our states to local disk or a GCS bucket. Before plotting, we call `jax.experimental.multihost_utils.process_allgather()` on the global arrays. It gathers all shards, including those on other processes' devices, into a full NumPy array in host memory, so we can run regular NumPy operations on it.\n", "\n", "At the time of writing, Orbax didn't have a straightforward way to save the `nnx.Rngs` object in a checkpoint. That's why we filter the RNG keys out of the state and save them separately as raw key data, which `jax.random.key_data()` turns into standard JAX arrays. We do the same for the EMA, because we will later need its RNG keys to restore the full EMA state. Note, however, that the RNG streams inside the EMA are never drawn from, so they stay in their initial state." 
] }, { "cell_type": "code", "execution_count": 30, "metadata": { "id": "fCv3qez3WB4e" }, "outputs": [], "source": [ "def train_loop(start_step: int, opt_state: nnx.State, ema_state: nnx.State):\n", "    # Seed each process differently so that processes don't load identical batches.\n", "    train_seed = start_step * jax.process_count() + jax.process_index()\n", "    train_dataloader = DataLoader(\n", "        SinDataset(seed=train_seed), batch_size=local_batch_size, shuffle=False\n", "    )\n", "    test_dataset = SinDataset(seed=-1)\n", "    test_dataloader = DataLoader(\n", "        test_dataset, batch_size=local_batch_size, shuffle=False\n", "    )\n", "\n", "    train_iter = iter(train_dataloader)\n", "    ema_decay = 0.999\n", "\n", "    for step in range(start_step, start_step + args.steps):\n", "        x_batch, y_batch = next(train_iter)\n", "        x_batch = make_fsarray_from_local_slice(\n", "            x_batch, mesh.devices.flatten(), data_axis\n", "        )\n", "        y_batch = make_fsarray_from_local_slice(\n", "            y_batch, mesh.devices.flatten(), data_axis\n", "        )\n", "\n", "        opt_state, train_loss = train_step_fn(\n", "            opt_graph, opt_state, x_batch, y_batch, args.add_noise\n", "        )\n", "\n", "        ema_state = update_ema_fn(opt_state[\"model\"], ema_state, ema_decay)\n", "\n", "        if jax.process_index() == 0 and (step + 1) % args.log_interval == 0:\n", "            logging.info(f\"Step {step+1}, Train Loss: {train_loss:.6f}\")\n", "\n", "        if (step + 1) % args.test_interval == 0:\n", "            test_dataset.reset_seed()\n", "            test_iter = iter(test_dataloader)\n", "            x_test, y_test = next(test_iter)\n", "            x_test = make_fsarray_from_local_slice(\n", "                x_test, mesh.devices.flatten(), data_axis\n", "            )\n", "            y_test = make_fsarray_from_local_slice(\n", "                y_test, mesh.devices.flatten(), data_axis\n", "            )\n", "            test_loss, y_pred_model = test_step_fn(\n", "                model_graph_eval, opt_state[\"model\"], x_test, y_test\n", "            )\n", "\n", "            test_loss_ema, y_pred_ema = test_step_fn(\n", "                model_graph_eval, ema_state, x_test, y_test\n", "            )\n", "\n", "            y_pred_model = jax.experimental.multihost_utils.process_allgather(\n", "                y_pred_model, tiled=True\n", "            )\n", "            y_pred_ema = 
jax.experimental.multihost_utils.process_allgather(\n", "                y_pred_ema, tiled=True\n", "            )\n", "            x_test = jax.experimental.multihost_utils.process_allgather(\n", "                x_test, tiled=True\n", "            )\n", "            y_test = jax.experimental.multihost_utils.process_allgather(\n", "                y_test, tiled=True\n", "            )\n", "\n", "            if jax.process_index() == 0:\n", "                x_plot = np.array(x_test).flatten()\n", "                y_true_plot = np.array(y_test).flatten()\n", "                y_pred_ema_plot = np.array(y_pred_ema).flatten()\n", "                y_pred_model_plot = np.array(y_pred_model).flatten()\n", "\n", "                sort_idx = np.argsort(x_plot)\n", "                x_plot = x_plot[sort_idx]\n", "                y_true_plot = y_true_plot[sort_idx]\n", "                y_pred_ema_plot = y_pred_ema_plot[sort_idx]\n", "                y_pred_model_plot = y_pred_model_plot[sort_idx]\n", "\n", "                experiment_output_dir = os.path.join(\n", "                    args.output_dir, args.experiment_name\n", "                )\n", "                os.makedirs(experiment_output_dir, exist_ok=True)\n", "                fig = Figure(figsize=(10, 6))\n", "                ax = fig.add_subplot(111)\n", "                ax.scatter(x_plot, y_true_plot, alpha=0.7, label=\"Ground Truth\", s=20)\n", "                ax.scatter(\n", "                    x_plot,\n", "                    y_pred_model_plot,\n", "                    alpha=0.7,\n", "                    label=\"Model Prediction\",\n", "                    s=20,\n", "                )\n", "                ax.scatter(\n", "                    x_plot,\n", "                    y_pred_ema_plot,\n", "                    alpha=0.7,\n", "                    label=\"EMA Prediction\",\n", "                    s=20,\n", "                )\n", "                ax.set_xlabel(\"X\")\n", "                ax.set_ylabel(\"Y\")\n", "                ax.set_title(\"Sin Function: Ground Truth vs Model vs EMA Prediction\")\n", "                ax.legend()\n", "                ax.grid(True, alpha=0.3)\n", "\n", "                plot_path = os.path.join(experiment_output_dir, f\"eval_{step+1}.png\")\n", "                fig.savefig(plot_path, dpi=300, bbox_inches=\"tight\")\n", "\n", "                logging.info(f\"Plot saved to {plot_path}\")\n", "\n", "            if jax.process_index() == 0:\n", "                logging.info(\n", "                    f\"Step {step+1}, Test Loss: {test_loss:.6f}, \"\n", "                    f\"EMA Test Loss: {test_loss_ema:.6f}\"\n", "                )\n", "\n", "        if (step + 1) % args.save_interval == 0:\n", "            if jax.process_index() == 0:\n", "                logging.info(f\"Saving checkpoint at step 
{step + 1}\")\n", "            opt_rngs, opt_state_no_rngs = nnx.filter_state(opt_state, nnx.RngKey, ...)\n", "            opt_rng_keys = jax.tree.map(jax.random.key_data, opt_rngs)\n", "\n", "            ema_rngs, ema_state_no_rngs = nnx.filter_state(ema_state, nnx.RngKey, ...)\n", "            ema_rng_keys = jax.tree.map(jax.random.key_data, ema_rngs)\n", "            ckpt_mngr.save(\n", "                step + 1,\n", "                args=ocp.args.Composite(\n", "                    opt_state=ocp.args.StandardSave(opt_state_no_rngs),\n", "                    opt_rngs=ocp.args.StandardSave(opt_rng_keys),\n", "                    ema_state=ocp.args.StandardSave(ema_state_no_rngs),\n", "                    ema_rngs=ocp.args.StandardSave(ema_rng_keys),\n", "                ),\n", "            )\n", "            if jax.process_index() == 0:\n", "                logging.info(\"Checkpoint saved successfully\")\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "pi3KZYIobH1K" }, "source": [ "Let's train our model!" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7LaTMTa1ilah", "outputId": "3615cc5f-9cc5-4b3b-83e3-f1e9e23a871d" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-10-05 20:32:12,456 - INFO - Step 100, Train Loss: 0.232647\n", "2025-10-05 20:32:13,548 - INFO - Step 200, Train Loss: 0.178514\n", "2025-10-05 20:32:14,642 - INFO - Step 300, Train Loss: 0.137890\n", "2025-10-05 20:32:15,734 - INFO - Step 400, Train Loss: 0.103556\n", "2025-10-05 20:32:16,828 - INFO - Step 500, Train Loss: 0.071781\n", "2025-10-05 20:32:17,921 - INFO - Step 600, Train Loss: 0.064227\n", "2025-10-05 20:32:19,014 - INFO - Step 700, Train Loss: 0.038317\n", "2025-10-05 20:32:20,155 - INFO - Step 800, Train Loss: 0.031034\n", "2025-10-05 20:32:21,256 - INFO - Step 900, Train Loss: 0.018365\n", "2025-10-05 20:32:22,357 - INFO - Step 1000, Train Loss: 0.011609\n", "2025-10-05 20:32:22,900 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_1000.png\n", "2025-10-05 20:32:22,901 - INFO - Step 1000, Test Loss: 0.010148, EMA Test Loss: 0.306501\n", "2025-10-05 
20:32:24,108 - INFO - Step 1100, Train Loss: 0.007852\n", "2025-10-05 20:32:25,202 - INFO - Step 1200, Train Loss: 0.007552\n", "2025-10-05 20:32:26,299 - INFO - Step 1300, Train Loss: 0.003781\n", "2025-10-05 20:32:27,393 - INFO - Step 1400, Train Loss: 0.003243\n", "2025-10-05 20:32:28,493 - INFO - Step 1500, Train Loss: 0.002446\n", "2025-10-05 20:32:29,586 - INFO - Step 1600, Train Loss: 0.002355\n", "2025-10-05 20:32:30,682 - INFO - Step 1700, Train Loss: 0.002892\n", "2025-10-05 20:32:31,777 - INFO - Step 1800, Train Loss: 0.001434\n", "2025-10-05 20:32:32,872 - INFO - Step 1900, Train Loss: 0.002903\n", "2025-10-05 20:32:33,970 - INFO - Step 2000, Train Loss: 0.001691\n", "2025-10-05 20:32:34,360 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_2000.png\n", "2025-10-05 20:32:34,361 - INFO - Step 2000, Test Loss: 0.001314, EMA Test Loss: 0.070212\n", "2025-10-05 20:32:35,459 - INFO - Step 2100, Train Loss: 0.001170\n", "2025-10-05 20:32:36,554 - INFO - Step 2200, Train Loss: 0.001386\n", "2025-10-05 20:32:37,773 - INFO - Step 2300, Train Loss: 0.001572\n", "2025-10-05 20:32:38,883 - INFO - Step 2400, Train Loss: 0.000817\n", "2025-10-05 20:32:39,985 - INFO - Step 2500, Train Loss: 0.001038\n", "2025-10-05 20:32:39,985 - INFO - Saving checkpoint at step 2500\n", "2025-10-05 20:32:39,986 - INFO - Using JaxDistributedSignalingClient\n", "2025-10-05 20:32:39,987 - INFO - [process=0][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. 
Returning.\n", "2025-10-05 20:32:39,987 - INFO - [process=0] Created OpTracker for checkpoint_manager_save with operation id 2\n", "2025-10-05 20:32:40,052 - INFO - [process=0] Saving checkpoint at step 2500\n", "2025-10-05 20:32:40,052 - INFO - [process=0] Started saving checkpoint to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500.\n", "2025-10-05 20:32:40,169 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500\n", "2025-10-05 20:32:40,326 - INFO - Wrote Metadata={'item_handlers': None, 'metrics': {}, 'performance_metrics': {}, 'init_timestamp_nsecs': 1759696360220900523, 'commit_timestamp_nsecs': None, 'custom_metadata': {}}, json={\"item_handlers\": null, \"metrics\": {}, \"performance_metrics\": {}, \"init_timestamp_nsecs\": 1759696360220900523, \"commit_timestamp_nsecs\": null, \"custom_metadata\": {}} to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/_CHECKPOINT_METADATA\n", "2025-10-05 20:32:40,380 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/ema_rngs\n", "2025-10-05 20:32:40,387 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/opt_rngs\n", "2025-10-05 20:32:40,388 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/opt_state\n", "2025-10-05 20:32:40,390 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/ema_state\n", "2025-10-05 20:32:40,390 - INFO - No entry found in handler registry for item: ema_rngs and args with type: . Falling back to global handler registry.\n", "2025-10-05 20:32:40,391 - INFO - Deferred registration for item: \"ema_rngs\". Adding handler `` for item \"ema_rngs\" and save args `` and restore args `` to `_handler_registry`.\n", "2025-10-05 20:32:40,391 - INFO - No entry found in handler registry for item: ema_state and args with type: . 
Falling back to global handler registry.\n", "2025-10-05 20:32:40,391 - INFO - Deferred registration for item: \"ema_state\". Adding handler `` for item \"ema_state\" and save args `` and restore args `` to `_handler_registry`.\n", "2025-10-05 20:32:40,392 - INFO - No entry found in handler registry for item: opt_rngs and args with type: . Falling back to global handler registry.\n", "2025-10-05 20:32:40,392 - INFO - Deferred registration for item: \"opt_rngs\". Adding handler `` for item \"opt_rngs\" and save args `` and restore args `` to `_handler_registry`.\n", "2025-10-05 20:32:40,392 - INFO - No entry found in handler registry for item: opt_state and args with type: . Falling back to global handler registry.\n", "2025-10-05 20:32:40,392 - INFO - Deferred registration for item: \"opt_state\". Adding handler `` for item \"opt_state\" and save args `` and restore args `` to `_handler_registry`.\n", "2025-10-05 20:32:40,400 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False\n", "2025-10-05 20:32:40,401 - INFO - [process=0][thread=MainThread] Initiated \"orbax.checkpoint._src.serialization.type_handlers.ArrayHandler\".serialize. Time taken: 0.001122s\n", "2025-10-05 20:32:40,401 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False\n", "I1005 20:32:40.411459 39356 google_auth_provider.cc:181] Running on GCE, using service account 373177222751-compute@developer.gserviceaccount.com\n", "2025-10-05 20:32:40,498 - INFO - [process=0][thread=array_type_handler] Wrote 3 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/ema_rngs/array_metadatas/process_0\n", "2025-10-05 20:32:40,710 - INFO - [process=0][thread=MainThread] Initiated \"orbax.checkpoint._src.serialization.type_handlers.ArrayHandler\".serialize. 
Time taken: 0.308578s\n", "2025-10-05 20:32:40,713 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False\n", "2025-10-05 20:32:40,715 - INFO - [process=0][thread=MainThread] Initiated \"orbax.checkpoint._src.serialization.type_handlers.ArrayHandler\".serialize. Time taken: 0.002202s\n", "2025-10-05 20:32:40,718 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False\n", "2025-10-05 20:32:40,726 - INFO - [process=0][thread=MainThread] Initiated \"orbax.checkpoint._src.serialization.type_handlers.ArrayHandler\".serialize. Time taken: 0.008384s\n", "2025-10-05 20:32:40,731 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 70 Bytes/s (total bytes: 24 Bytes) (time elapsed: 338 milliseconds) (per-host)\n", "2025-10-05 20:32:40,732 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.339135s (batch_requests_ready=0.000657s, total_serialization_initiated=0.334970s, others=0.003508s)\n", "2025-10-05 20:32:40,734 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 11.8 MiB/s (total bytes: 4.0 MiB) (time elapsed: 339 milliseconds) (per-host)\n", "2025-10-05 20:32:40,734 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.339792s (batch_requests_ready=0.000642s, total_serialization_initiated=0.337574s, others=0.001577s)\n", "2025-10-05 20:32:40,735 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 70 Bytes/s (total bytes: 24 Bytes) (time elapsed: 339 milliseconds) (per-host)\n", "2025-10-05 20:32:40,736 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. 
Time taken: 0.339844s (batch_requests_ready=0.000344s, total_serialization_initiated=0.338248s, others=0.001252s)\n", "2025-10-05 20:32:40,737 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 35.4 MiB/s (total bytes: 12.0 MiB) (time elapsed: 340 milliseconds) (per-host)\n", "2025-10-05 20:32:40,739 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.341751s (batch_requests_ready=0.001705s, total_serialization_initiated=0.338225s, others=0.001821s)\n", "2025-10-05 20:32:40,739 - INFO - [process=0][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.412713s (all_items=0.002202s, per_item={'ema_rngs': '0.00070572', 'ema_state': '0.00052786', 'opt_rngs': '0.00048780', 'opt_state': '0.00048065'}, temp_paths=0.410511)\n", "2025-10-05 20:32:40,811 - INFO - [process=0][thread=array_type_handler] Wrote 9 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/ema_state/array_metadatas/process_0\n", "2025-10-05 20:32:40,815 - INFO - [process=0][thread=array_type_handler] Wrote 3 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/opt_rngs/array_metadatas/process_0\n", "2025-10-05 20:32:40,835 - INFO - [process=0][thread=array_type_handler] Wrote 23 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/opt_state/array_metadatas/process_0\n", "2025-10-05 20:32:41,187 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.454033s (commit=0.270430s, array_metadata_write=0.183603s)\n", "2025-10-05 20:32:41,188 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 30 Bytes/s (total bytes: 24 Bytes) (time elapsed: 794 milliseconds) (per-host)\n", "2025-10-05 20:32:41,510 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. 
Time taken: 0.773855s (commit=0.577525s, array_metadata_write=0.196330s)\n", "2025-10-05 20:32:41,539 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.803838s (commit=0.621967s, array_metadata_write=0.181870s)\n", "2025-10-05 20:32:41,539 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 3.5 MiB/s (total bytes: 4.0 MiB) (time elapsed: a second) (per-host)\n", "2025-10-05 20:32:41,540 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 20 Bytes/s (total bytes: 24 Bytes) (time elapsed: a second) (per-host)\n", "2025-10-05 20:32:41,667 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.929092s (commit=0.732886s, array_metadata_write=0.196206s)\n", "2025-10-05 20:32:41,667 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 9.5 MiB/s (total bytes: 12.0 MiB) (time elapsed: a second) (per-host)\n", "2025-10-05 20:32:41,829 - INFO - Read Metadata={'item_handlers': None, 'metrics': {}, 'performance_metrics': {}, 'init_timestamp_nsecs': 1759696360220900523, 'commit_timestamp_nsecs': None, 'custom_metadata': {}} from gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/_CHECKPOINT_METADATA\n", "2025-10-05 20:32:42,064 - INFO - Updated Metadata={'item_handlers': {'ema_rngs': 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler', 'ema_state': 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler', 'opt_rngs': 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler', 'opt_state': 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler'}, 'metrics': {}, 'performance_metrics': {}, 'init_timestamp_nsecs': 1759696360220900523, 'commit_timestamp_nsecs': None, 'custom_metadata': {}} to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/_CHECKPOINT_METADATA\n", "2025-10-05 20:32:42,137 - INFO - [process=0][thread=MainThread] 
Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.\n", "2025-10-05 20:32:42,302 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.237342s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/ema_rngs\n", "2025-10-05 20:32:42,303 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/ema_rngs\n", "2025-10-05 20:32:42,477 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.\n", "2025-10-05 20:32:42,632 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.230508s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/ema_state\n", "2025-10-05 20:32:42,632 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/ema_state\n", "2025-10-05 20:32:42,802 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.\n", "2025-10-05 20:32:42,967 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.241697s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/opt_rngs\n", "2025-10-05 20:32:42,967 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/opt_rngs\n", "2025-10-05 20:32:43,138 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.\n", "2025-10-05 20:32:43,315 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.250376s. 
use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/opt_state\n", "2025-10-05 20:32:43,315 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500/opt_state\n", "2025-10-05 20:32:43,401 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500\n", "2025-10-05 20:32:43,718 - INFO - [process=0][thread=MainThread] Finished saving checkpoint (finalized tmp dir) to `gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500`.\n", "2025-10-05 20:32:43,718 - INFO - Finished synchronous save in 3.67 seconds to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_2500\n", "2025-10-05 20:32:43,719 - INFO - [process=0][thread=MainThread][step=2500] CheckpointManager Save Finalize is syncing with other hosts...\n", "2025-10-05 20:32:43,720 - INFO - [process=0][thread=MainThread][step=2500] CheckpointManager Save Finalize is done on all hosts.\n", "2025-10-05 20:32:43,720 - INFO - [process=0][thread=MainThread][step=2500] Finished synchronous save.\n", "2025-10-05 20:32:43,720 - INFO - {'step': 2500, 'event_type': 'save', 'directory': 'gs://solaris-east5/georgy/fsdp-jax/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': True, 'wait_for_prev_start_time': 1759696359.9876494, 'wait_for_prev_duration_secs': 0.0002562999725341797, 'checkpointer_blocking_start_time': 1759696360.0524569, 'checkpointer_blocking_duration_secs': 3.6667628288269043, 'get_old_steps_start_time': 1759696363.7192426, 'get_old_steps_duration_secs': 0.00011086463928222656, 'checkpoint_manager_blocking_start_time': 1759696359.98681, 'checkpoint_manager_blocking_duration_secs': 3.7337262630462646}\n", "2025-10-05 20:32:43,720 - INFO - Checkpoint saved successfully\n", "2025-10-05 20:32:44,824 - INFO - Step 2600, Train Loss: 0.001182\n", "2025-10-05 20:32:45,925 - INFO - Step 2700, Train Loss: 0.000822\n", "2025-10-05 20:32:47,020 - INFO - Step 2800, Train Loss: 
0.001010\n", "2025-10-05 20:32:48,132 - INFO - Step 2900, Train Loss: 0.001088\n", "2025-10-05 20:32:49,232 - INFO - Step 3000, Train Loss: 0.001274\n", "2025-10-05 20:32:49,624 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_3000.png\n", "2025-10-05 20:32:49,625 - INFO - Step 3000, Test Loss: 0.000465, EMA Test Loss: 0.008318\n", "2025-10-05 20:32:50,726 - INFO - Step 3100, Train Loss: 0.000773\n", "2025-10-05 20:32:51,822 - INFO - Step 3200, Train Loss: 0.001725\n", "2025-10-05 20:32:52,920 - INFO - Step 3300, Train Loss: 0.000757\n", "2025-10-05 20:32:54,167 - INFO - Step 3400, Train Loss: 0.001064\n", "2025-10-05 20:32:55,268 - INFO - Step 3500, Train Loss: 0.001229\n", "2025-10-05 20:32:56,372 - INFO - Step 3600, Train Loss: 0.001523\n", "2025-10-05 20:32:57,472 - INFO - Step 3700, Train Loss: 0.000991\n", "2025-10-05 20:32:58,573 - INFO - Step 3800, Train Loss: 0.004077\n", "2025-10-05 20:32:59,674 - INFO - Step 3900, Train Loss: 0.001708\n", "2025-10-05 20:33:00,776 - INFO - Step 4000, Train Loss: 0.001069\n", "2025-10-05 20:33:01,158 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_4000.png\n", "2025-10-05 20:33:01,159 - INFO - Step 4000, Test Loss: 0.000324, EMA Test Loss: 0.000607\n", "2025-10-05 20:33:02,271 - INFO - Step 4100, Train Loss: 0.000972\n", "2025-10-05 20:33:03,371 - INFO - Step 4200, Train Loss: 0.000642\n", "2025-10-05 20:33:04,469 - INFO - Step 4300, Train Loss: 0.000937\n", "2025-10-05 20:33:05,571 - INFO - Step 4400, Train Loss: 0.000771\n", "2025-10-05 20:33:06,673 - INFO - Step 4500, Train Loss: 0.001577\n", "2025-10-05 20:33:07,894 - INFO - Step 4600, Train Loss: 0.000988\n", "2025-10-05 20:33:09,016 - INFO - Step 4700, Train Loss: 0.001011\n", "2025-10-05 20:33:10,124 - INFO - Step 4800, Train Loss: 0.000930\n", "2025-10-05 20:33:11,225 - INFO - Step 4900, Train Loss: 0.001215\n", "2025-10-05 20:33:12,324 - INFO - Step 5000, Train Loss: 0.000650\n", "2025-10-05 20:33:12,697 - INFO - Plot 
saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_5000.png\n", "2025-10-05 20:33:12,698 - INFO - Step 5000, Test Loss: 0.000054, EMA Test Loss: 0.000241\n", "2025-10-05 20:33:12,698 - INFO - Saving checkpoint at step 5000\n", "2025-10-05 20:33:12,701 - INFO - [process=0][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.\n", "2025-10-05 20:33:12,701 - INFO - [process=0] Created OpTracker for checkpoint_manager_save with operation id 4\n", "2025-10-05 20:33:12,756 - INFO - [process=0] Saving checkpoint at step 5000\n", "2025-10-05 20:33:12,757 - INFO - [process=0] Started saving checkpoint to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000.\n", "2025-10-05 20:33:12,863 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000\n", "2025-10-05 20:33:13,058 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/ema_rngs\n", "2025-10-05 20:33:13,062 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/opt_rngs\n", "2025-10-05 20:33:13,065 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/opt_state\n", "2025-10-05 20:33:13,071 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/ema_state\n", "2025-10-05 20:33:13,078 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False\n", "2025-10-05 20:33:13,079 - INFO - [process=0][thread=MainThread] Initiated \"orbax.checkpoint._src.serialization.type_handlers.ArrayHandler\".serialize. Time taken: 0.001208s\n", "2025-10-05 20:33:13,079 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False\n", "2025-10-05 20:33:13,085 - INFO - [process=0][thread=MainThread] Initiated \"orbax.checkpoint._src.serialization.type_handlers.ArrayHandler\".serialize. 
Time taken: 0.005938s\n", "2025-10-05 20:33:13,088 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False\n", "2025-10-05 20:33:13,090 - INFO - [process=0][thread=MainThread] Initiated \"orbax.checkpoint._src.serialization.type_handlers.ArrayHandler\".serialize. Time taken: 0.002231s\n", "2025-10-05 20:33:13,092 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False\n", "2025-10-05 20:33:13,101 - INFO - [process=0][thread=MainThread] Initiated \"orbax.checkpoint._src.serialization.type_handlers.ArrayHandler\".serialize. Time taken: 0.009317s\n", "2025-10-05 20:33:13,107 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 697 Bytes/s (total bytes: 24 Bytes) (time elapsed: 34 milliseconds) (per-host)\n", "2025-10-05 20:33:13,107 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.034985s (batch_requests_ready=0.000446s, total_serialization_initiated=0.032820s, others=0.001719s)\n", "2025-10-05 20:33:13,109 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 113.1 MiB/s (total bytes: 4.0 MiB) (time elapsed: 35 milliseconds) (per-host)\n", "2025-10-05 20:33:13,109 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.036130s (batch_requests_ready=0.000589s, total_serialization_initiated=0.034237s, others=0.001304s)\n", "2025-10-05 20:33:13,110 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 669 Bytes/s (total bytes: 24 Bytes) (time elapsed: 35 milliseconds) (per-host)\n", "2025-10-05 20:33:13,111 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. 
Time taken: 0.036366s (batch_requests_ready=0.000304s, total_serialization_initiated=0.035025s, others=0.001038s)\n", "2025-10-05 20:33:13,112 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 327.1 MiB/s (total bytes: 12.0 MiB) (time elapsed: 36 milliseconds) (per-host)\n", "2025-10-05 20:33:13,113 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.038096s (batch_requests_ready=0.001270s, total_serialization_initiated=0.034801s, others=0.002025s)\n", "2025-10-05 20:33:13,114 - INFO - [process=0][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. Time taken: 0.109971s (all_items=0.000032s, per_item={'ema_rngs': '0.00002027', 'ema_state': '0.00000477', 'opt_rngs': '0.00000286', 'opt_state': '0.00000453'}, temp_paths=0.109938)\n", "2025-10-05 20:33:13,182 - INFO - [process=0][thread=array_type_handler] Wrote 9 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/ema_state/array_metadatas/process_0\n", "2025-10-05 20:33:13,185 - INFO - [process=0][thread=array_type_handler] Wrote 3 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/ema_rngs/array_metadatas/process_0\n", "2025-10-05 20:33:13,192 - INFO - [process=0][thread=array_type_handler] Wrote 3 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/opt_rngs/array_metadatas/process_0\n", "2025-10-05 20:33:13,208 - INFO - [process=0][thread=array_type_handler] Wrote 23 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/opt_state/array_metadatas/process_0\n", "2025-10-05 20:33:13,695 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. 
Time taken: 0.587158s (commit=0.389604s, array_metadata_write=0.197554s)\n", "2025-10-05 20:33:13,696 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 38 Bytes/s (total bytes: 24 Bytes) (time elapsed: 623 milliseconds) (per-host)\n", "2025-10-05 20:33:13,866 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.754855s (commit=0.563322s, array_metadata_write=0.191533s)\n", "2025-10-05 20:33:13,904 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.794878s (commit=0.609138s, array_metadata_write=0.185740s)\n", "2025-10-05 20:33:13,905 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 4.8 MiB/s (total bytes: 4.0 MiB) (time elapsed: 831 milliseconds) (per-host)\n", "2025-10-05 20:33:13,905 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 28 Bytes/s (total bytes: 24 Bytes) (time elapsed: 830 milliseconds) (per-host)\n", "2025-10-05 20:33:13,982 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.869188s (commit=0.674423s, array_metadata_write=0.194765s)\n", "2025-10-05 20:33:13,982 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 13.3 MiB/s (total bytes: 12.0 MiB) (time elapsed: 907 milliseconds) (per-host)\n", "2025-10-05 20:33:14,454 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.\n", "2025-10-05 20:33:14,612 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.230279s. 
use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/ema_rngs\n", "2025-10-05 20:33:14,612 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/ema_rngs\n", "2025-10-05 20:33:14,774 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.\n", "2025-10-05 20:33:14,946 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.240756s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/ema_state\n", "2025-10-05 20:33:14,947 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/ema_state\n", "2025-10-05 20:33:15,121 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.\n", "2025-10-05 20:33:15,295 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.246438s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/opt_rngs\n", "2025-10-05 20:33:15,296 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/opt_rngs\n", "2025-10-05 20:33:15,458 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.\n", "2025-10-05 20:33:15,611 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.228544s. 
use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/opt_state\n", "2025-10-05 20:33:15,612 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000/opt_state\n", "2025-10-05 20:33:15,712 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000\n", "2025-10-05 20:33:16,003 - INFO - [process=0][thread=MainThread] Finished saving checkpoint (finalized tmp dir) to `gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000`.\n", "2025-10-05 20:33:16,004 - INFO - Finished synchronous save in 3.25 seconds to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000\n", "2025-10-05 20:33:16,005 - INFO - [process=0][thread=MainThread][step=5000] CheckpointManager Save Finalize is syncing with other hosts...\n", "2025-10-05 20:33:16,005 - INFO - [process=0][thread=MainThread][step=5000] CheckpointManager Save Finalize is done on all hosts.\n", "2025-10-05 20:33:16,005 - INFO - [process=0][thread=MainThread][step=5000] Finished synchronous save.\n", "2025-10-05 20:33:16,006 - INFO - {'step': 5000, 'event_type': 'save', 'directory': 'gs://solaris-east5/georgy/fsdp-jax/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': True, 'wait_for_prev_start_time': 1759696392.7011592, 'wait_for_prev_duration_secs': 0.00031065940856933594, 'checkpointer_blocking_start_time': 1759696392.7570903, 'checkpointer_blocking_duration_secs': 3.2475996017456055, 'get_old_steps_start_time': 1759696396.0047104, 'get_old_steps_duration_secs': 9.846687316894531e-05, 'checkpoint_manager_blocking_start_time': 1759696392.7007046, 'checkpoint_manager_blocking_duration_secs': 3.3052756786346436}\n", "2025-10-05 20:33:16,006 - INFO - Checkpoint saved successfully\n" ] } ], "source": [ "start_step = 0\n", "train_loop(start_step, opt_state, ema_state)" ] }, { "cell_type": "markdown", "metadata": { "id": "4gMQn2FSbTgy" }, "source": [ "Our model has been trained. 
You can see in the logs above that the test loss is low: by step 5000 it reached 0.000054 (0.000241 for the EMA model). Also, check the eval images; the predicted curve should closely match the sinusoidal function." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAooAAAGZCAYAAAAO3hAkAADOP0lEQVR4AeydBWAbR/bGP7GZ7SQOM3PSUNOm3CZXZmama++udL3Sv1futVdmZmZukkLaBhpOw4yOmW3h/32bW1dxbMeJZUu23iSyVrs79JvV7qc3M28sAQnQoASUgBJQAkpACSgBJaAEahGw1vqsH5WAElACSkAJKAEloASUgEFAhaJeCEpACSgBJaAElIASUAJ1ElChWCcW3akElIASUAJKQAkoASWgQlGvASWgBJSAElACSkAJKIE6CahQrBOL7lQCSkAJKAEloASUgBJQoajXgBJQAkpACSgBJaAElECdBFQo1olFd7YVAps2bcK8efOwatUqVFZW1lRrxowZeP7552s+78mG3+9Hfn4+tm3bhq1bt9a83G73niTT5HOZ34MPPog1a9Y0Oa09ScDr9Ro8yXX9+vUgj0gKOTk5eOCBB1BeXr5TsegJbPv27TXtxW2ea7Yh27SpgdcY22Tz5s1NTWqP4vt8PuTm5qK4uHiXeOTAYx6PZ5djjdmxYcMG/Oc//0FVVdVuTy8sLDTY8705Q0lJyU5txzbk95HXJkNRUZFR59re33icLMrKynYpXmlpKRpTbraxeR0xz8Zw2SWzBnbk5eUZDFlHhtdeew3ff/99AzHqP/T777/jiSeeqP8EPaIEGkOAfhQ1KIG2RkBu+IEzzzwz0Ldv38D48eMDw4YNCwwdOjSwdOlSo6qPPvpo4IgjjtirasuNPNCrV69Av379jHSZNl9z5szZq/QaE6m6ujpwxhlnBGbOnFlzujzsAoMGDQr88MMPNfuae+PDDz8MjB492sh3woQJgeHDhxt8P/vss+bOutHpi4ANpKamBuQhvlMc8po8eXJNm2VnZwe6detm1IHXxrnnnrvT+bv7UFeb8NpIT08P/Pbbb7uLHtLjq1evDrRv395oFxEYNWmLgDSuc/L48ccfa/bvyQavr6ysrAC/U7sL8oPMYM/35gwXXHBoHPnzkbbmd+/MWPGBNauXWtke/TRRwdSUlICX3/99U7FENFk7P/b3/62034R0YH999/fuB5ESO50rPYHptGuXTsj7yFDhgQGDx4c4P1EfjDVPnWvPi9ZssRguHHjRiP+ySefHPi///u/3aYlYjJALsHhlVdeCey3337Bu3RbCewxAXtjxKSeowRaG4H7778fIqrw5ZdfQh4ohjVx7ty5SE5ONqpy0UUX4bzzzquzWrQQxMTE1HmMO+VbZlgl5CaMiRMn1pyXmJhYs92YjYbyoYWIFqDgcvz88884/fTTa5KOi4vDr7/+itjY2Jp93GD5aG10uVw77Q/+0FDewecFb3/00UcQsQqylYcXWF9aZr744guIUAk+tWa7dj61P9ec2MgNEWf11ss8RmZkY7FYdkqV+954440aC6iIRohAxN13322cF8yaVqPaXHdK7H8farcJ82Q+ZlpNra+ZZ0VFhVFvm81m7trpndcKX7SQfvXVVzjxxBON47T6LliwwDjGstQODZXP5EkOdfGkdY7WZKfTWZOsWf/a7GtOqGeD16wZh/nys8mwrii06B1wwAF4+OGHaw4zvvkdpGWV1//LL7+MQw891DiHfNj+DodjF8sr7xUrVqyA3W4Hr3MRXDXp1t6gtTIzM9Ow8rH+8uMJF198MUQ0QkRZzen8DpNR8PewoevK5G1ev2ZCL774IqzWXTv/2HYsL
18MbPtffvnFjGa8n3LKKTjuuON22scPzIsc6kqXx1lutgHP0aAEdr36lIkSaAMEKApHjRqFHj16GDe7pKQkTJo0CR06dDBq98knn+Bf//qXsc0uJwowPlSOPfZYQ/xRRDbUDcUHNh8WYqmpefGGffPNN4M3djNs2bLFeGjzneHGG2/EQw89hCuuuMIoz8EHH2w8yM3zKQhYrn333dd4EB511FFg3Lvuust4EFx//fVGvP/+979GlEsvvRSLFi0ytnljf/zxx424Yh3BaaedBrGw1By75JJL8Mwzz+Css84Cjx9zzDFYt26dcZx/+IB89tlnaz4Hb1B4sm7kwjzT0tIMrqw/2V111VXG6ezWO/XUU410xGILsewYXXN8gB1++OFGvgcddJAhZsz02VXLlxnY5XbSSSdh+fLlxq57770X//73v2HWnQ/jn376yTwdPP+yyy4z6j1lyhSIFanOBxyFBH8omG1mPpD5mWKHPx7I58gjjwRFJLuPzz77bMyaNasmLwpDsTwan8XKs0ubUBTw4cry8QHNHxINXUt///vfd2Eu1in885//NPL4/PPPceCBBxrXAtPidVBfSEhIMNri1VdfrTmFP2YOO+wwiLWxRiDzIK9RpjdJvhMnnHBCDWseYxc868hrkO0nlsideLIr+4YbbjBEEa8jXg8FBQWM2mDgjxq2K8WSGcRqZuTBa5xdrhRobF+KQLYDu73rCqYoNNuS72JBhCmkKdKOP/54LF682BgewTSYPwUSeVIIBYcXXngBf/nLXyCWRnCb36X6AvPmtcM8xXqM888/37ivUJBLjwWkJwOPPfaYkQ/ZMHz77bfG9c+8+Z3njyszkOfVV19t8OZ3hm3Oa4j5MPCHzFtvvWWejtmzZxvMyJ6s+MONQ2z4PSEvtinz4Davw+uuu64mLvexfIzH9ud9xBw6wuuc1/vTTz9tiOuxY8caaZvHaxLRjegjIF8IDUqgzRFgV5A8/AOXX355gN2lMo5vpzrKzTXArlMGecgZ3VHyYDS6ceVGHOjfv39AHtY7xTE/sGtKhFKAaXzzzTdG95Y8hIzD7Nq86aabzFMD7ILLyMgw3rmT3d3sxpMbf0AEntE9zi4zsXYYXVdyoza6RqdOnRpgFxTPYxfUwoULAx07dgzcd999gWnTpgVWrlwZEFEZ6NSpU+C7774z8hORY+T18ccfB+bPnx9gl9WIESMC8iAyjrOLrHfv3gF2E8tDLXDIIYcE5GFqHOMfEXsBERU1n4M3WA8RQTV5BR8L3pYHlnGePKwC8pAKyIM68Mcffxjluv322406y8MpIMI9QM4M7PIN7vZl9y0Zmd3sPCaWooCMKTXiX3PNNYE+ffrU1Ese1AH5UWCcP336dGOb7bO7LkR50Ab++te/GmUQcWRcA/vss4/Bl+VmVyu7puXBbZzDP++//77BkF26ZFy7TcTaZHRLyoO4UdfSI488Ehg4cGBAhLiRh4goY1jDSy+9ZOTP7nG2q1i7ArzG5AdOTVmCN5YtWxaQH0FGnrx22V5s9wEDBhj16d69e0AsjUaUt99+26gr33kdkB+HaLC+DGKNNLorZXxbQASOUT5287JLW0SDcc3yumFcDuWQHwrGtcZj/J6RCbvCawcyZvd4ME9+hzh8gTyvvPLKgPyIMK51XvvvvvtugNdTXUFEuDGMgOXjd5AvXm8iAI3Tx40bF5AfZEZZ+Z1hEMEeuOOOOwLyoyLA75kZeL2xXPzOyY8DY+iA/NA0D+/yLsLNuMbMA2wXEagBEX/GNSgqwkif/OTHjtHlLz9YA+RNLryGeF3xOAPvURz6wHRYB5ZdftAERDwbx4PvKWxn+YFq3GN4T+AQB6bHe8Ftt91mXDv8DnCYAffxO8O0GfiZ1zfbjtcu7wNMS35cGsfJkGUnHw7fkB+OxvdumtxvNEQ3Af5y0qAE2hwBPjB4k6T4482QwoQPND6sGGRwfkCsFsY2hSLP4Y3eDGLhqndsDwUIx6GNH
DkyIL/cjYebWOuMqGKlC9xyyy1mMsaDIfjBSSHGB4MZKBb5kGK5OL6KDwjzAWKew3cKST7MTVHIfRxz17NnT0MI8DPLc+edd3LTCEyT5eTDlIFCUSZ5GNv8I93yRnyKEwYKi+DxbcbO//3hg5NCMXgcJgU4xRZf0vVmnMkHLesQ/HChQOS4xuBAwSzWFmOXWPICfJmB5aYwMYUkH+rBgpYPUI4Ro3hi21FUiqXPjB744IMPjPbeE6HIdCjo+XA0g3TtGYKUnMzAOlPY8fqiuKndJnt6LUl3oSHwzPGDfFhT/LMdZIKGUSbpLg2I1dssQp3vpoCQyUWG6BbrUoA/GNg2jMs0zfF6vO6Dx+hRQFCQsm7MkyKbIsQMYmEzrlGexx8oHPsnVssABQlfYp009rEuzD/4ejfTMN/Z5mLBND6SIUWiKVQo5HhdMI3dBbFMGvnw+2e+OCaZZWSg2JJu6QDH7XGb41X544JC7cILL9xJKPI+QXFtBrFkBsTCZ37c5Z33BrHeBqRLOyAWwkCXLl0M5hT7HM/J698UeYwsFvaAWLoN8cbvBduagk16FwL87vH7z3Kagfch/jDid4mB4y3Ne8o//vGPeu9LFNbkGRzYduZ3j2UTK2iA7WQGimiOs2TgtcfvAK9hM5CtWPPNj/oepQR0jGL0GZGjosbsgmKXH19y4wO7DEVwQB5iRhdNMAT57htdSSI4anZzTBbHNNUXGIfjo9g9Fxy43+wy4v7gbfOzPLRrorCLiWVlVxm739ilxe7y2kEeqkZ3WF1lYh7sHhKBBZncUhOVXWMiAGq63phPcN7Mi4F5M7DO9QV5+BmHxOpUc4o87CAiEGIpgTyEjP2sP7tAzS5+7mT3tjyIa+JxQ0Sr0S1o7gzmFLzN40wzuNzmuCwyETFoMBYLjZkURDwjPj7eiFezczcbzINj8Xh9mIH7GILLE7xdV5swDrk29lrieeyO5LAHdgWyS5jdjxxrxxevMb7Y/cg6ilAwzjPLWPud+Z9zzjlGFyrjc1gAy8P9Ztk5lo38zcB6cxwvuyV57bH7nJ/NIALL2MfxbOyOZ1cpy2mOb2O6LHtjuihZNg6n4LXKYRGcrWyOoRMxZHSTstufZWe3t1h8jfKbZTHf2YXMdBqa0ctz2MXKdmL3q0xAM+oX/B0iF7HeGu3O4SgM/M689957kB9dxrVs5mm+Mz1e39dee63BgKxkYptxmN8lEVs1Y6G5k1x5nYo100zC6CYXK68xe5pl4LYZuM36c3/tIJZiyOSd2ruNz+b1GNzWwSeyHPIDy+guN/dzXCXvj+Y9gENKgu8D/C7XNbbVjK/v0UFAhWJ0tHNU15I3Pz5UOP4weLxZbSjBD7r6brbBcXhjrh348Azez5ssx/eZD2meH/wAMLf5znLy4cYbN8db1Q48xxy4HnyM+5kvxRHFphmYN9MSa6m5q868aw42sMGHIQWDWOuM8U88lQ9UvsSaZQjx4OhmvbiP47ikqyv4sDGmivsZWHYyMgO3az+cgtMzt/lOAcuHHOtpijyKEMYPZm6mvbt3M22eZwoh8yHKfRzPWvs6qatNap/TUFk4HpDj3Di+jWPKKLzNQKHHsaacLMTxhpycQJbBbWqey3eykyEVxnVEUSFWphq2Zt0oQij4zMCykhnbgzyZBidsmNcgRQ73mbwpLDkOkm1fO+zOVZNYvQ3Rz/Gw0r1sjFkVi5qRTNeuXY26sy3Fqm78yOM1LV3StbMxPpv1qfOg7ORxtg3HHt5666148803dzmVZSBPCkqOzzXjcNyrWNhqRGxwRPIiG479rC8El411YL3NccXBceheh+eyzqZY5DbHcdZ1zfCHBQVfXYHpME5d8Xg+BSzblWlTADLwOqAw5LVultl853EzTW5riF4C1uituta8LRN4/fXXDSsXZz/yxihdaZBuW+MhynrzZh8sALhd+wYZfLw2q9rnm8dp0eNMaw7M5w2fDwfenM2bN+PVFhHcRyuHjB80Z
uFy0gYfICw3RRjT4WQLPqA5kJ0WIc42ZprB6VEI0/rEhzWP0/cdg7gHMt5rl5n15T4zSNejMZnD/Bz8TssnJ1JQIDAPzjo1hSgn0wSL49r50GIkY6kgY6IgXYNGu3z66afGzGnmQauGdMkZDy0+oDmZg+LEFGrkFcyMcZgH8+SDU8ZhGmWjtZNs6EORwmZ3gWkEp1u73GROKx6tS+RJn5FPPvlkzXVCEVK7Tcyy7cm1RGscRZp0iRoiipZaBjKmpZbilOXg5AJeE3XVzWxL1oeWY05+oNBhuuQUXFfOiH7qqacMYUrrICde8BqVbmrDksjrkG3NtqCQ4HXEdHmNyphHwxpJaxqtgbwGOJGCkzVMlsyrvsCycTIVJ16QKy3+ZuCkGVoZKWTJgGI42IJtnsd35sU24XUS/DKthcH15cQxfvdpoTTjmmWVbmejPrw2eU3ynSKWAj24rY2I//tjsgjeZ26b7WB+5jt/BPB7Q1+IvP7ZnvTjSiFPC5+MUTYma/F7zjbn5BSeZwbmZ5aXXgeYjoxLNCy75GNOZmNbs714zyAT83o224NMeU3z+0F2ZM28OKnNvJeY59aVt7lP36OPgFoUo6/No6LGFAz33HOPMXuQN0c+CGkBYNcdA39R8xc2AwUJb9g8zwz8lV2f1Ybn0wrCLrragbNv2c19gMzapNWBD3cKIVP08GbOB6EZKMCYN2/UfIiyC5IzJWVskFFGPiBoSaK1kTOm2QVJywitTKwL45rl4ExUWhQ5W5Tlp0DgA4r15AOM9aGwMQPjmXlzH0UJZxpz9m9dgZYZChCKiOeeew6cSc4HM7u4yZrBZMl6mUHGiBkPJM7wZRwKEM7qZRcjA+tCcS2Ta4x6khnFH3kwkKPZ9c3PZh4UauTG7kcOK5gksz3ZrpztzgcujzUUgrsIzTSDrwHGZfcjLX5Mm/WU8V5Glzl5sny124SzV3ltBKfT0LXEPHguZ6JyFjXZmhZKCgq2MduR5/DhzvY3LaeMawbGYb4mMwo6M5BD8HVC4cThADLu07BCU+yx+9Xs3mfbmnWmNUzGvdWIDpaDP8Jo5aPbGfKmqGEbs/1YDuZllsMsQ/A7Z1lTfDI/ilMzUKhxNjCvEabJrmLO1K8r0JrJ2e2m6xvzHM5YZnl5rZvfM353ONPeDLyeyITXLr+r5sxk8zjfKWb54qxsGYMYfMhIt757Q+3vFCNypjd/+NEzAHny2mGghwLWkT+Mgq9fXmP8wWneM1h+MmHgUBeZAGV4RuB3jvcH/hBkGnwnD/LlcAMKcX5vzLIyHd4PZDKYIYj5o4Pikd4MGFh2/vAK/t4E522cpH+ikoBFLtpdB0JEJQqtdFsjwBuhafXgDdAUhqwnu3j5oOCDjl8Bihdumw84HqflxnzYBLPh+bRU8nzzoR58nL/K2VXIuBzLxHO5zRs/H/aMEzw+kPvM42Y6fJDzAU5REFwG1olWIMZn/iw3hUhwOWjh4TnsyjPzYbo8l59NEUMBwgcy0+fDgdsse3B+ZnmC3/lw4gOU5eZDN1i48BgFTjBLMy7FOtuDbUHBHBzIlMz4sGI3N8vKhyrbg3Vh+VhPhrrai3VhfD4YKZjInA/X4IdecH7cZjnJjeK5rjTN85k/68u2ZPosm/ng5jm128S8NhpzLZl5kDvTZfqmQOAxXoNsT/Na4PG6AuOzPeribtaN/IIFPH9UME8KIZOtmTbzpcWJQoFCoy6eZEKG/F6xTRnM9q99PZvpmu/Ml3zYxsGBAp/WMNaD1299gW3C7yjrzWC2M/mwjiwX34OvfzMtXucMPMZrktdwMHPzPFrrWD7z+2LuZ1vw3lHX96T2d8qMw3eWiRY/pkeRHJwu69HYewbTYluTP69dfl/Ma43HWCeWgz9smC45BZeVn3l/YZszrhlYJ17LPNfkyXyC71fmufoeXQRUKEZXe2ttlYASUAJKQAkoASXQaAI6RrHRqPREJaAElIASUAJKQAlEFwEVitHV3lpbJ
aAElIASUAJKQAk0moAKxUaj0hOVgBJQAkpACSgBJRBdBKJu1jMH59IFAgdrmwN2o6vJtbZKQAkoASWgBJSAEtgxOZATq+jVgpPI6gpRN5mFvqsOP/xww/lypAhFzhSkcK1rhl5djab7dk+AM/s4EzB4NvDuY+kZ9RHgLErOoOSsZA2hIcDZs7xB1zXjNjQ5RFcqeo2Gtr31uRRankwtEp9L9IpAH7qyFrzhrqmuWkedRZHikMuJ0dFoJAU+NFQohq5F6Ooh2BVI6FKO3pSUaWjbXr/zoeXJ1PQaDS1TvUZDyzNSr086aW/IcBaVYxT5S4mvSAm01NB/lYbQESBP3uQ0hIYAWeo1GhqWZirkye++htAQ0Gs0NBzNVPS5ZJII3XskPpcao4eiUiiGrtlDkxKVfENqPjS5RFcqyjS62ltrqwSUgBJQAs1DIOq6nhvCyHGCHEPAPvuWDPzlxtUCdKxS6Khz5QZ2PbM9gwMFJLv4tVs6mIpuKwEloASUgBKom4AKxf9x4XJNeXl5hoBoaesehSlfLIOG0BCgOZ0isXZbkjPHiXDZudrLloUmZ01FCSgBJaAElEDbIaBCUdqSoiI/P99YIzZ4TcyWbGZaFYPX62zJvNtiXmzT+rqfuZYs21uFYltsea2TElACSkAJhJKAjlEUmqao4ALr4QimRTEcebfVPBtiynbWSQRtteW1XkpACSgBJRBKAioUQ0lT02oVBCgiNSgBJaAElIASUAK7J6BCcfeMwnYGx9IVFxeH1ZXPvHnzsHnz5joZBAuu4O06T97NzlmzZhljRHdzmh5WAkpACSgBJaAEWpCAjlFsQdiNzYrLDN59992YOXOmMW6RYxcnTJiAv//97y2+MgbLwaV9zjrrrJriv/fee3jppZeMMYAUsuzK5QoTWVlZePTRRxs19u/pp5/GEUccgS5duhjp3njjjbjuuutw2GGH1eSjG0pACSgBJaAElEB4CbQqoWg6UK5rBZPS0lKEayJKKJuQY+cuvPBCYwzdk08+iczMTGzZsgVvv/02ioqK0K5du5rs6AKGAo2uXjjOkp/3lAFFKSd11HbNw/1c95Gsa0+yOeSQQzBy5EhDKE6ZMgV/+9vfcOCBBxrlYFp0M8R6NDTmk0Jz0KBB6Ny5s5EO68C8GJd1qauNayquG0pACSgBJaAElECLEAirUKQgeOihhzB16lRDLLz55pv1Cp1nn30Wb7zxhgHl9NNPxwUXXGBs06J17bXXYsOGDYa1jUvzcYm+lgg5JVX4eP4WbCgox7DOqZgypANiHbYmZf3777/j119/xYIFC5CcnGyklZKSgttvv93YLigowDXXXGO4d2G38M0332y4e/n3v/9tiD26fXnggQeQlpaGm266CRR1BxxwgCG+rrjiCoMVxeaVV16Jrl27Yvbs2YY4e/jhhw3hRrFJnmvWrEHPnj2NbufaQpHlMsuWlJSETp06oVu3bvjXv/5lCMQ//vgD+++/P1hurr9K4ctw2223GWUpKSnBypUrcf3116NDhw54/PHHDVH50Ucf4bHHHjOE8eWXX47TTjvNiKd/lIASUAJtkYDb64fXH0Cc88/nhtsLeGXBnjhdVr0tNnmrrFNYhSKtThQahx9+OB588EFDsNRFcc6cOcbazFy4muG4447D8OHDDasWBRKdVVNEvv7667jsssvw7bffwm5v3qqVVHpw2ydLsCKnFAkuO35emYf1+eW45pA+sIpT570NFH/du3evEWI5OTmGcHK5XOjRo4ch+Cio7rjjDrz22mug6GIX7n/+8x/ss88+Rvc0BeJTTz2FJUuWGJxYFo4hnD9/vmF1JPePP/4Yd955p8GM3cVM75133gEF49atW439y5cvx1FHHdUgS6bL9Bg4zpAWTuadmpoKtk3w2EUKyN69extp8p2WSFom09PTQQsm60qhyHJfddVVoLXSFKRGBvpHCSgBJdDKCPhECJZUeZBTUo11eRVYsLFEnhVVKK3yolx6U
Fx2Kw7om4XTxnTG9BU5+GDhfBRVlyAzwYURnbrh8L798EfxTHy9/hNU+ytxSBf58d/lANitO8RljN0Fh9XZyqhocVsTgeZVU7shwe5GWps2bdpkWJXqO53C6KCDDsLAgQONU2gh++STTzBixAhDFD7xxBNGFy2tjM8884xhrarPqsguzdr+9Wo7Za6vHMH756wvNERit/R4QxgmxTgwfXkuTh7dBZ1S997NDi1wwRa8GTNmgJbWn3/+GR988IFhLeW4Po4ZpBjj8ezsbEyePNkoHq2N5513niEoKS6Du5TJm3WlJZdpnHHGGYYQo1CnVZeB7xwryPGGfO277771CngjQtAfpn/uueca1k7uDq4HP1NEMn92j7NbmpZIc4wi95999tlG1zrrFR8fD/o7VKFIcuEPbB++NISOgDINHUumFEk8q8RSuF5E4acLtuLnVQXYWkxh6EZxVTl8jo1wxm6ExVYFm70c8Q4L1s4dhoVbR2FJyZso9P8Kt60SS4uBOQUZeG15F+QF5sETKGclMW3DNHSIz4bT5jCePR0TO+G43sdicvfJqPZVi4XSK2nGNxluJPFscmUiJIFIZMoy7S6EVSiahau9zJq533xnt7IpErmPFrdly5YZlkRaE9l9yUBxQRFC4RksFH/55Re8/PLLxo2EXdW0wvFlAqI4o1WMLwqrYCuYkXCtPzarBVVuz//2BuAXq5rFEpDuAlkNRPYH/E7ZVytSAx+ZH8UbQ79+/bBu3TqjbhzvR+spX6NHjzb2sawUgOb5HLdJgWYGlp/HWBfzHB4zLX9mPRmHaTHwPZhFsMDjeYzDYL5z2zw/uOzcRzFoBp4fnBbzMctkltE8l+U20+T1wM88Pzh989zGvDN9psf4wYH7eMycUR58TLfrJkBmvM7M9qjNtO5Yund3BDiumix5rWtoGoFIuEZZBoZfVxfj68VbsWRzAYqqxLG/axsSXSVISCxFdkwuLPIZVrdxrvz8Qpzc+p2+bVhf9AfaO5cjxRYPvyVBrg05DQGUW9ahoyVTzs8y4sjdHL4qP2TBWYkNrCpZixcKXsfWvK3YXL4F5Z4KjMwahgO7HASXTZ4VgR3PFka2WnZca7zu+K+hwPske3rMejV0rh7bPQFyJE/2dtJgFQn3UZaJz+lgDVBXTSJCKNZVsOB9vGCDb6bc5j4TtHkhm+/mfjMNTpg48sgjjQueE0PYxUqxZZ5PMcNt82XGq/ddzh3VLQ3tkmKwubASibEO5JVVY3zPDHQWCyPkyygSpd7odR0wyzJ+/HhjggfH+90mY/oofLm0HyeymOVjo5qBVlV2NS9cuNAYY8gu+D59+hjikRNh2G1//PHHGzOoly5dauxnOsFpMC1TBI4aNQpvvfUWDhCr7dq1a/HTTz/hpJNOMrIzy8gP5jbbwQxMI5g9x0t+/fXXRtrkTsF+zDHHmKcb4x8rKyuNL01wXKbNz3w3XzWRGrlhxuN7cDD38xrSCTPBZBreZruSGb83GkJDgOKbPIN/TIUm5ehMpSWvUd5WbFZ5fFpscs/bYWjgvl9WFuLp6R/Db/sZ3pRiad8SeMQ6WG3chvjDldJPxJrvz/tSlTRXO+lSdouA9PgdKJV7Kh/MfIJUyWkeeZr4A9W7NCo7npmKXYwUuZUVeOGPV5FuyYB8xO85i1AWqMCZA86oEYp2KW+5Ryyafhny5Uoy9geLyNoZ8B7MyYV6n6xNZu8/kyeNL5F0HzX1VfCzu3YNI0IomiLQfK9dSFoMg335cbt9+/bGbF0C53JsnJjBGy8nYwTPDGZaFIp8MdDayPF5jGeKCCp85s2XKSSMkxv4k5UUi1uPHIg3Zq7HlqJK7NsrE2eM7QKnfce4ETPtBpKoOcQGMhuJXbIce0hXOBwfyK5X1mvSpEkYMmSIEYddzeYYzF69euGGG27AxRdfbIz1ozWOrmcYOF6T3fFMhww5rpMXKeMGp8EbAXkycIIJu4A5PpATYjiGkJZahrrax2wHH
mdXdfCyeBSYZE33OjyPVmEzLQ45+O9//4tXX33VGC7A8pizpPng5PmmdbKufJlfQ4E82QZ1xeV+5hFJX9aG6hIpx8hSmYWuNciS3726rtHQ5RJdKbXUNcrJJstXL0NB4RZ0yu6Lbp07GqAXb5mFSufTKI+tRoGNvVPyTPH/KQrrag0+hGP8NH64URSwiQVxx1kUiryPlYnhISD7uDs4JdoG+dkp3VfMIsljR5qnQPYxpgU/b/kFZww5Ay6rS/YE8OXaL/HV2q+M7unR7Ufj1H6nIt65494uEXYJFBAUNvqd3wXNXu8gTz7XIokpr7Hd3YMschKvqrAFCpsVK1YYlqbp06cbIo8gFy9ebHQPUtxw3BwnN/zwww9GOSmaHnnkEcPqdf755xtWN06G4WSXe+65xxjPV9+voNWrVxtpffrppzVwaF1bv369ITZNAbYnQNjlbJebwt4GNgF/vdXOm1Y4dk9xskdGRoaRPM+lFY6iKliMUizT8sjZx8Hp0NS9bds2Y7+ZBy+K4DS4n+1gijx+pjWRYs28oCkw6wrs+md7MU+maQpR81x28ZIthaop/My0KOqZL2dH853HKeDMOjblIco61CcUyZS8yEpD4wjwxwpvcpx8piE0BPh9JU+1KIaGZ0tco6tFGOaU5mLJgq8wY907KLWUI0P+nTr6Bkwa8xfc++W1mL7uc5S5HCjazSOBD172x9jlnp4pz6DT47phk7sIUz1yL5fniVVOyPYF0EnuZcvFm8ZWpxg3/icDSYzxmUWiCLpYEYvJohbT/TGyx4JSGc+YnNQPT5zwFmIk3q9bfsU9s+6BTSygm8s2o9RTioHpA3H5sMsxqfMkibNr4D2UQ7Q4ZlxDaAiQJ59z9emT0OSyZ6nwBwF7XKmp6OmkrhB2iyJnxnKiBgMnYRx88MGGVYsTN9jdSqF4wAEH4Oijj67ptuQ2xSLDbbfdZljNaLXi+XQN09KN0BSRaFSinj+0stUOFD+moAs+RjHJV+1AX4i0OjIEC8jgNPigqv3ZjFM7vdqfg+OZFsHgcygK60uL1kXTwhgct746Bqer20pACSiBliLgEzH39KxX8MGSl+H2FKJCupPTZMhRmj8eG1CA1xc9iFEDJiIhRiyIouBsYgKkZY9ijoHvMjDH6O7lD+FAwCdjEwPoKGKsOxw4NKUfjj7gLvjK8zBl8ZsoKN8mP3St6JwxAN0TOmH1wlexIGclckRAWiStCpnx7OGsZzmnnceBnu5qvJyaiK32AOSntihQG04urRJL5Y7x5TO3/makt7V8K8o8ZYZgXJi7ENdOvxYn9DlBJsMch35p/VhUDUpgFwJhF4r/+Mc/DL+AFDG0PplWJ3aB8gvFQOFAQblx40bjs9mNzA/c/vzzz40JIOz6VIuHgUj/KAEloASUwF4SKK4uxuyc2aj0VGKfDvtgS2EuPpj/kIwldItItKJMXhViDswUQZiGWGzx5mB72QZM6nkkvlv1hXTvyqRGUYzVnEgizzGniLrutjh0s8fL5JYUdI1rh0kxHdA5IRtJqT1g7bSPOE5ME5EHjOw+SYSeZ0fJZWYzQ/8hJ6H/qmnA1nlAyRagaANQvAmoLJV4YuVP6oCEvFx8E+9Elc2CscXlmNi7O+CihRHiwi0Vm4pzUOUvFSuk2CHl0eq0ybm+Kny06iMsyVuCW8ffij6pfYzz9Y8SCCYQdqFY30oiZpdncGGDBWLwforM+qxWwefpthJQAkpACSiB+ghUeivx2ZrP8eayN7GldAviZAZy19Tu6GJJQkxAJiJY40QAyjAhmXPssYizbHlZxCoY44pDrAi2HsmD8ff9/g9fLH0bpeKqJi25CzJjMzGy0wQMTOyKREciECtduXZnfUVgz7GYJGsN9YlLB4acsOPFmCJEUVEA5CwWMRiPgLjIGfrNv9FjxQyZrBJAUrveiJ10mWFx5Ok9KzORWVaItXE2sV/uMMDwjXbPtJgM8c/oxvcbvt9JKJozpBlfQ3QTC
LtQjG78WnsloASUgBKIBAKcAfzcoufw2uKXxc1ZORzSretGJXIqylDiikWAbtGkoAnSreuQSSplMhO51CfLlYq18MSBZ6K9iEKGCX2Px4Tex0g3s7hbE6tdswQKyURZzpUvCdSWMcc9hJjNv8s06UoZ3DhcLI0iLv8XOq+ZievyvbhHfDaudrH8MoGBVkV5jy+WrvRYdoxbsE26pmdum4U4exxGZo6UGdUqEUyG0fyuV0E0t77WXQkoASWgBAwC2ypy8NPKzxErTrGt0n1LEVVlFcudpxoJzkSMd7XHtIqNcIt4ai8rrYxrPwLxHUdggPgsnND1IMM6V4NSxKPV6Eiu2dP8G3bpZu46oc58kmMCSKksx3+3u/FcShIWuZzYbpN6cCKmNx9Wmdzizs/FTTk3o6AiT0RuAN2TuuOagdcgSf5piG4CKhSju/219kpACSiBqCRQ6i5FnCPOmNhBANUVhXCXbJPZw8A2u/hylX2yKTOQLRjndeLqA27H4cvexbaSDejdbhi6jLxAVnn402rHNCI1pAz9Cyr+eAf9ytbjgWoPSmSM5bfxsfgtxiWTYkQEV3AZ2s9R0bUPOid3lzGMfmws3ohfxMVOt/bdIrVaWq4WIsDZ9RoiiACnqtNPJGdwBwd+5n4eb2yg6xquAU13MA2F5557DlyHuXagax2uErNy5cqd/FjWPq+xn+ma6MknnzRO59KDr7zySmOjGudxeUauha1BCSgBJbC3BOge5q6Zd+GqqVfjxh9vxNycuUZSqbLaSc9KWV1LhGG83GdLxI2YTbpnJ5VU4VhfEuxdRmHgYffioONfRZf9rm81IpGVc3Qbg6TjH0Ugq79YSMvF52IVjhdXLf+XX4qLS9IwqMqOUrusnFXpRsHmVdiYswqFVYX4eNXHuO7H6zB72+y9xa3x2gABFYoR1oj0szRu3DhjNRRz1jffTznlFGM/jzc20O/d888/bywb1FAcrlSzatWqXU554YUXcMghhxhrP5922mk488wzdys6d0kkaAeXYnz99deNPfShWFBQEHR0180rrrgCCxYsqDmQl5dnLGNYs6OBDbr82Z0T0Qai6yEloATaKIHH5z+BWVtnif9CWY+5ZB0emvsw2O0cm9oNx3i74pASNwZW+3BKcREe2bQVV5fFoOeYs3cMBCQTrsjSCoOlz6GwnPMpMPGv8LcfjHK4sMWfJcsBliJJFHGXmB4ozF0MX9labKlYh6LqQmMll6/XfS2i+ip8sPKDVlhrLXIoCLTOKz4UNQ9VGrIAO2TgM2KSQ5IinZzStyCdZHNZvqFDhxrvdL7N/TxuhiVLlhgikOcE+46kFZDxu3TpAvpRpHshM3CNbAq0QYMG1bgS4gzzupz+0hLJJQW5TjadA0+cONFY/pArrjBNprN9+3Zw2T+Wa/78+YZLI/q+DE5v0aJFxnEKN9Nf4n777WcIX7NcdLjN+lIUDxgwwDh/5syZGDNmjFFvrrxz5ZVX1rhPYjwKT7pM4pKFXK6QgWt5cxZ8Tk6OwYBlMfM0TtA/SkAJRB0Bt0w6+T3ndyzMW4SF2+ago9xvrAWbkeGIwUanHQtzF+HQrgcja9KN2O/ru3FE0RpZZSsRSX32Q8r4M+DsMrJtMEvuBEx+ABZ5ZgXmvovEpd8gxRkL19ATMfm3t+CW2TozDRc7HsTJqMtYsawmWVwoFzc67yx/B/t32h/psa2ju71tNFhk1EKFYlPaYf2vwOxngdJtsn7dAGD8lUBq16akaAgl+oI84IAD8O677xpCkRa/yZMnY9q0aUbaFGXXXnutYW3jii20tFHMUUwxDn1Ocrk8ii9aICna2GV94403GmKOoopLGbIbuH9/6YqoJ1AMmquo0Ds/06cQ41rQL774orGyCVdc6dSpkyHimAeFHn1hPvvss8aKOVyHmuXu0aOHsRqKKSDffvtt/Pbbb0YZ1qxZYzhbZx4UeRSkffv2Nbq9uRzhJ598gkcffdRYqvDYY481nK/T2
vnEE08YInH58uXGijy0fj711FP44osv0L17d4MLy8S86nPDVE/VdbcSUAJthABnHz8y71H8uOknGXvnw1ZxZh2o9iLd50Sgqhhehw8JXCNPQt/h47Gly2vI37oBtpRMpHSSpUTbCIfgalhk6b6EsecgYcxZYim1gquTtf/uYZxdloLuMqP75cQ82S8zo8UQEvDK5B6xopaIuOS4ThWKwSSjY1u7nve2nYs2At/dCuSvlq4IcVWwTm5C0/4NeHddvH1Ps2CXMVefmTVrlmGxmz17tvHZXCfyxx9/NMQXBdR7771nrAF97733GsKQIvHhhx82unjPOussQyzR2vjNN99gzpw54Io3XEv6xBNPNARlQ2WjaOPyiuwu5go4tAxSsFJ8UoBRLHLpRK7ZTLH4/vvvGy/Go7ilgPvoo4/A5RKZxuDBg2u6rulcnUtuMfzzn//EhAkTjCUYKXS5DvRhhx1miOS77rrLEL9cTpD5UohSrLK+FJFvvPGGsZLPzTffbBzjMoIUoxSqn332mZEfx0NqUAJKIPoIUNh8Kn4Rp8kax53EXNatvBjtPV5slSXxcuW2vVVcxfSu8KJXQWENnOz0ZAweNBjd26hIrKkoN0QkMtjlnhnXYx/YqvIwsaQMEyqqZfUXCyrpEkhedl8lUqtcSHFlGufrn+gioBbFvW3vzXOAqiIgraekEABSOovzU5kQQo/5Gb33NlUjHn/dsTuVaxFTANGSx8/cTxH0+++/G9225io0FJV33nmnYSXk8bFjxxrpsHu3Y8eOxjbjsKuWYx0p8iioTCtbcNd0cMG5n13O7OJmFzaFV79+/QwxSKuf6RSdabObmssoUsixO3rYsGGGsGQXt7lO9ZQpU4yyMw+mbVoXOVnmhhtuqMm6vq5idl1ThNICSevjyJEjjTi0JN5+++1GGZguy2ZaQsmOFlcNSkAJRBeBHzf9iCcXPIl1xevgcZfB7gHSvUBHeMTFjQX9PPHo47Zh34IKWYElJrrg1K6t8Eg78ErYZKILln+FG3Lz0cdrxa/JWfLc8aOfWGAPlq76QJEI6vbxWCRd+Evzl6KTOPoe1W6UscpL7ST1c9shoEJxb9uSYxJFFIFjFGlRpCWR3vadcXubYk08ii0KnuOPPx4UV1yikJ+5n4GijVY1M9DSRtHGdZdpqaPlkVZEdj1TEDKwO5jC6rHHHjPOoeAyBZmZrpme+c60OEaQVr3awRR53E9Rduqpp9aIUObF8ZSmRc+MSzEZPMbSzJdlCa6PeT6P81hw4D6Wm9ZICmce5+xuil/zXDNdxuM22WlQAkogeghwPOJNP91kTMawy/ffLT/mN9pl8bqA3VhzeUB1Ff5aYIVDrIsrYocgtfu46IFTT01tCZlIO/5BVGw6F2VvnIez8tdioghHh8x+dvi5Ik0FEsQB+SvL3sATvz+CSrEyxthicGzvY/G3UX+TVWvkOaihTRLQrue9bdbOYwCuz1mwZocVketvDjxW1tzM3tsUjXgUNhRBfB144IFG9zPfzX0UgYceeijmzp1rCDFa+x544AEcddRRYPcsJ7CwW5azmNktzEkwFGfHHHOM4QKH3bAUWDxuupphmsECzqwAz6MIqx24n3HMcMYZZ+DNN9/E+vXrjf3sKuf2vvvua1gx2Q1N9zvsomY9GJgGhSwDxx3eeuutxmQWnscudwZaTL/++mvDCsn8KFwZh8s10hp6//33G/W44447MGLECEMo8zjPM0N9dTOP67sSUAJthwDFikfWSX5j6ZsoFYGTLPfTeK8bLnn3ypi7Ipnd297nwnGFNiy2D8d32efA8Zd7kZ2V3nYgNLEmcZ0GIf7YB1ApK7vEixsdpycBFR4HkrsMErHtxlMya9wv/5KdySK/A8ZsaNPFUBOz1ugRSmBnc02EFjIiiyWDgXG4WNqWfykeWddJf4Z0g/Y8QIraNOsVLYPswqVFkJY6CiAGfuZ+dr9yYsjjjz9uTOagKGLX80UXXWRYzjhBh
d24f/vb38Cu58suu8yI37t3b2OiByeF0G8irY/nnnuucWzSpEk1XdTGjv/94YSYtLS04F3GNruTgwUkrYkUfrQ8UnCyq5llaNeunVFGCjrmRzFr+nSk2GMcBpaV9eLEF75zVvU+++wDCkCKS465pBhmOTkWkpbDl156yehuvuaaa4yueZ7HMGTIEOPdtCpy7CO78DUoASXQtgnYLDbMzJmJjzZ9hBmbf5L7oTikkNnOTumIscqHRHGHc3mxVZxLV2KhfRRST34E47KTkeZq21z2pnbxvfeD+5iHUSYzoWOKViCtXS/EHfxP/FC2RgR4MVJFkAc8xQLWhnJ/Ne6f8wDOGXg2JnefLKzV/rQ3zCM5jkUeqDv6MyO5lCEsG50+X3XVVcYEC4oSBgoWWsA4ns3svgxhlrtNik1AgRXKvJlmfV2uzCu463i3BdyDE9gFbHLdg2h7fWp9+bGOrH9dZaFYzc/PVwG5B9RpCaZ11hwXuwdR9dR6CHD8L3k213exnmzb7O6txVtx60+3ocSfj+Ki9djsqzDuAXZ5wlWLaDyzuAxHFWbgO+wDz/AzcfXR41p6kb1WxZ6O2Eq3b5IJLAIwoR0g61b/tH46rv7+chHfYlMUQVgpj1Dp0EfvtD7wyIzy60dfh/HZ41tVPVuysBwmxqFZNPxESuAz9Mgjj8QjjzyCnj0552LXoNJ/VyZtYk99IpGVa84HU13CrDmBtnR+zVkXTVsJKIG9J7BafB96yvOQUbwNnSpKke31iYSRuRfeAK4oLMV5+VX4JOM8eCbdhLMOHq0icXeo5cd2QNa4RrJM1BSRyNBHJgYeWFIqq9eIJVFWsOG/jtKrlVkuk4X84i1u2+zdparHWyEB7XpuhY2mRVYCSkAJKAH6mwhg6oap+El8JJaUl8BdWQhvRbH4PrQjO1At3c0B3JlbiL5VFfjWPglTjjwBg7pmKbrGEhDLYXBwuqtxnfAc7bHjjWSXMe6zh0wIchesR7krBjMcvxgTiCZkT8DEThO1GzoYXiveVqHYihtPi64ElIASiGYCby17C/fNug9emcucYklCnNUFj9OKGHGg7RFr14mlZSgs74rbLOPQYdx5mNRFRWJTrhdnn0nwxnfHyQWr0UHWv34sLQGb7WJdFO4FVhkPWpyDRdYl+G3rTFR6K3F498Obkp3GjRAC2vUcIQ2hxVACSkAJKIHGE9hcthkvLH5RXN74kSYdyfEylE7WhkKqX3zJVifinOI4jC5OwYe9/40hp96OCw4ZBoeMVdSw9wTi0zsCJ76AtZ0Px6hyD67O82NUdZqsbOOQLn4LssVymySuh7ziLu7lP14xfFjufW4aM1IIqFCMlJbQcigBJaAElECjCEzfOB3nfXUeNpduMiZRVHirYPN5DKHYx+PG5UUV2L+0BGsyDsP1pxyE4walQobUaQgBgczeo9D+ovexddAFyK6w4aTSOAyutooLIi+qZCWcJQXLkFOyEasKVuCmGf/E/O3zQ5CrJhFOAioUw0lf81YCSkAJKIE9IrCxdCNu+eUWrC9ZB7vMvKXbjkrxYFEh2w6Riu0r0vExxuHltMuQctDVSIu17VH6evLuCcTKoLXsw65FoN/hsAfKsF9ZiYxXtGGNzOgtl7agsOgik4mqxL/w84tfQEFVwe4T1TMiloAKxQhrGrog4drIr776Kl5++WXjxZVZGNatW2ese0zXLmagg2musTx//nxzl/G+ePFiY33nnXYGfaD/RabLfPiik+ymhr/+9a8w11U+7bTTDGfY9aXJFWNycnJqDi9YsABcm1qDElACSqAhAl+u/RLbyrbKeESHdDvv8FzLKRc+cYd1UmExDkjti84n343Tz74EBw3o0FBSeqwJBGJTstDt9EeQeckH6Nl+f5xdYJPJQ3YZH+ozRKLPU4nNFdvx6+af8Y8f/qEzopvAOtxRVSiGuwVq5V9WVmY4z/7ll1/ANZBXrFhhrG7C07777jvDGfULL7xQE2v69Ok44
YQTDCfa5k76UKRo42osGzduNHfv9E6H2Zdeeil+/vlnY03m888/33BwvdNJe/hh+fLlxtrQjEan2XU56zaTZP2uvPJK86OxqgzroUEJKAEl0BCBYrFOWWWiitzmDMsVxSKF4nlFpTi6sBKV2RNxSO9EdEvRJeUa4hiKY1bpzue4RXePgzC61I3ziyqR7pPlVKVNNjvs8Ih4z/BbkVu2BU/Iii6lntJQZKtptDABnfXcBODVvmrM2joLHFQ9IH0AhmQOaUJqO6JS5KWkpBiiLTMzc6f06Bhz1KhR+OGHHwyn4VzF5YMPPsC4ceN2cta9aNEiY6k8Lo3H5fO48kntwHy4ZvT1119vrPQyefJkY6WWK664wlhKj6urfPPNN8YygsOGDcMnn3xiLMPXv39/Y4UV008jhSatiFwTmo5ETb+GXJWFK8sw0Er60Ucfgc7Ou3fvDubFOlBY0mrap08fcLUXxjEDRTItq6wjBSSP0XpKgZmVlYUvv/zSEJcnn3xyTT5mXH1XAkqg7RIYUeXDez4vSuwusSgGxJm2BfvJD99Di4GP4o7HPkP+0nYrH6E1az/qSCzcuBIjlr+AUllZ+43kVHHILcv8iWgs91egsHQz1stz8uHf/yvrQl+LOHtchNZEi1UXAbUo1kWlEfu4NNTDvz+M+2bfh7eXv41bZtyC91a814iYDZ9CAcaVYiikfvvtN/z666+GZZGxuH/w4MFITk42uopzc3MNi+HBBx+809rLXHd5//33Nyx2H374oRGv4VyB8vJywxE3xdg555yDm2++WX6xB4x9XI7v9ddfNwQshd0dd9xhJPfZZ58Z1k868H733XeNslLYMXD1G652wy7uM888E2+99ZYhTNlFznIXFBQYywBSEPLz0qVLce211xpxKXRpDeUKKhSX9BrPc7g6CNO68847DXH4zDPPGMsGGpH0jxJQAlFBoPe2AlxWWIbBVW50lnFwJ5aW4+6cPHwafyY6HH4NBndKjQoOkVTJxLhYDD/5n9g88X5MLEzA1YU+dBD/ioaQl2X+LGLkiBeXRTM2TMf34vdSQ+sioBbFvWyvBbkLZD3RGeiU2AlciL5UZnt9tOojTOo8CRmxGXuZ6o5otMC98sorSEwUr/gSaK3jWs0MtNIdfvjh+Pjjj9GvXz+MHj3aWFt5+/btxnGO/fv++++NtZAHDBgAWiHnzJmDsWPHGsfNPxSkzOef//ynIeBoqWN3Na2ZzINCkdZL07JHoZiamgqu/8x1pbnG8hNPPIF//etf4FrPTItd41w6j4GCkcsUMd1Vq1Zh1qxZO60Iw/WpN23aZIg+nk9RTIskA9er5rrQt956q/H5+OOPx9tvvw12jzPNW265BbRssu60iPKzaeE0IugfJaAE2gQBv7i++WHTD4ZD7Vh7LCb3/AtK4kdiStGrOKZqm0ygsCBd/PV9498XPQ69AJN6q0gMV8PHiZoYMPEYfF1YjZ5Ln8Bk+0Y8lypLVMrAAE456iJuc6q9eVgqs6KP7KFW33C1097kq0Jxb6hJnLzKPEOcUCTS8sabWElFiSEYmyIUjbRiY/H444+jc+fOO5WOx2jxo1B87LHHDKsjLXxfffVVzXnsBl67dq0xUeXbb781zqc1r7ZQZARaAml57NGjhyG4evXqZYwxpMXSzHvDhg2gCP33v/9dIwKHDh1qWPeKiooMwca0KPIY3xSK3MduaIpECtraywZSWAafy/NNscdxlZMmTeIuI9CKyjqx/hSyZhd1rHBiHhTDtdM34+q7ElACrZcAJ648ueBpJLoS4RH3N/Pz5uPojpdjUdoFGFH4ORJQhQ/9/VE09npc1qe9uGephCMmtvVWuJWXPDXGiiOPPxXLZrbDIV9ehlVOabNYC7p4fOI+B1jrL0O7Cm8rr2X0FV+F4l62OcckxthjkF+Zj0RZDzO3MhddE7uiQ3zTZ9lRMJlj/WoXj6IoKSnJsPYtW7bMEGocy2eG1157DYceeqghnNhVz
S7c9957z+jGNS2UPJeii+KK3dYUeLUDjzNwrGKHDh2MrmfT4meeS6FGAccxjCwXu5qDBRv3de3a1eg+NuOY70zfzCN4H7fbt29vdEWb+zmhh9ZD83ymy2B+Ns/TdyWgBNoOAY/fg2/Wf4tUix2p1ZXyS9KGzZ5i5FsXYdgJN+Lr36egvKIS/WWM85Wju8AivhRlWWcNYSYQbwMGjd4PS1adiLPXPoPKzERZvcUBjxhVBpW7MWprbphLqNnvKQEVintK7H/nd03qisuGXYa3lr6FwupC9ErphQsGX2CIx71MsiYaLXW0GHLSBsUQLXznnXeeYYHjOD2G+++/3xBn3Datc3myYPu0adMwffp0Y9IIjzHQ4siu6jPOOGPHDvnLdNlNbaZnHuB+7jNFGK15nCzD7mVOKiksLATd87DLmV3Q7B5meTnGkKLRDEyDr4kTJyI9Pd0Y90gBSzF59tlnGxNYOP6QlkqmHx8fX1OWCy+80HCVw3hMm65+WF9aU4PLRsHIzxqUgBJomwT84mLFu305yn3GPGf4ZGkVrwjCcd0TMbb7BGO2s+gSI1R52iaD1lgrl8x47nXiv1Dw6jpcKetwz0vpIBZFKwYVbUHyoB3DqFpjvaK1zCoUm9Dy+3faH6PbjTa6m9Nj02G3Nh0nZyLfc889xmQPU6yZRTzooIOMMYL8TMudab2bMmWKMTGEgpGCijOLgwP30boYHOLi4nD33XejY0dZkikoULDddttthjjlblo22Q3OySqciMJ4LAcDZxzzM2c+U+yxu5jdzAwc+5idnW2MVWTXN8dczps3z+jmpvDlMbr54WQdCj5aNW+88UYj7j777IM33njD8A/JsY6cMc1yUhSybKZllNbKm266qV7rq5GY/lECSqBVEnCIC5zRxeV42e9Dgs0lfhKlF8Jdhn3EDQuDRV6mSDR26J+IIpAU64T/8OvhfT8PEwo3GA3m774/0secHFHl1MLsnoBFxEhUGetpxeKMXHbXmt27FFG0dFF42O1NF3u7x77zGWwCjtcLR947l6TtfCLP+rrwOZuaVtFu3bq1nQo3c00o0vlDhMMeNISGAK3z5Gn+4AtNqm0nFY+sG7zuqeMw1Z+LOTJTwikWqYOLSzGq62HofOqDu1RUr9FdkDRpB++hJSUlxiTGpiRUVVqAsnW/w+KIRXKP0aiy+vHOHx9h3rb56JrcGacMOE4mhWY3JYtWE5c8OYSLkzIjJdBQQ88ijzzyCHr27FlnsVpeFdUqxrZt24zuR77TUnXxxRfXTGowT2XXKd28MJji7vTTT8e+++4LuoKhKxkGCi46ejYtXsZO/aMElIASUAKtjkBAhEVB0mD8Zf2n2NfdQZw4y4SIcg9y00dg52l+ra5qUVXgmMQ0xAw+xKgzR5ffO/1O/LD0DaRKe/4qPVY/bJqOm8b9DeOyx0UVl9ZUWWs4C0tL3gUXXABOiqBrFnZFchZv7UCHzJyUQQfSnKVLcWj+CqfjZbpzoQsVHq9PEddOUz8rASWgBJRA5BJw2ixwTrgUs+PGwSMeJUoqqzG7w0nIGnlU5BZaS9YggbUFazFn9SfoIutyF8uwIneMDauKF+GSby/B84uebzCuHgwfgbBaFDkBYs2aNcb4N4rFf/zjH4ZvPk52MF2lEA3dt/DFwPFq9OXHcWwMFIxDhgwxrIimtdE4UM8fc7WQ4MPMK5w98Mw/uL7BZdPtvSOgTPeOW32x9Pqsj0zT9ivXP/kVyNJ8G0s3guO9OyV0Mg4M79sDzrMewXJxzO90ujCufx90St6x4tOfMXdsKcvaRJr2uTl4BioL4fCUId/mQJ7Nayz155Sx/VbxhfmSCMUpPSejfVzTPYc0rebNG7s5uDalxI0pT1iFIscL0v0KRSIDnUpz7BjXIeakirrCc889hxNPPLHGosi4XKGDky24tvC9995rjDUMjsvZwJyRSyD00UdLJl8mIIpEbnP1D/rpM/cHp9Hc2xwPYlpJmzuvaEif4
y7YjrXbkm3NVWH4g4GrxmhoHAGOT+R3Rpk1jldjziJPco32771N3N78njMHTy98BtvKtyHZlYzjex+Pv4hzbc51HtAuDgPajxCkMplFJrbUdw3qNdqYq67x5/CZFOrvfLYzFcNsGfjOVoZYi8OYjOSXWUl2WwwqZUzq72tX4PC+WaCj9bYYTN3RGKNWS9Wfz8rdGcrCKhTp7iR4AgcHebLQvEDrChR7tEI+9dRTNYe5vJw5MJRLu9FlyxdffLHTzZdLwnEVEYoGTmSgEOWLn01AnInLY1wtJByBF1Awi3CUoS3lyWuI7VvXF5IikT8wuGyhht0TIEdOFOADmqLG/M7sPqae0RABuqfi9RnNQtEqXZBFVUV4fNYT2Fa2bccqV2Jx+nDJh+gW0w29U3vDJ2PZdhf0Gt0doT0/znsor9G6euH2PLUdMazOdFzQ/1xkL35V1uWWxRJkeT/OXDdkodWDqfPfwdC0fkiMi2lz9xleo9Qd5NkYcba3jPckHsvEdjafl/XFDatQpJ+84Id1cXFxgzOC3nnnHcO5c6dOO7olWKkuXbrU1I3LudFvH2cT0lJphgMOOAB8MXB282WXXWbMNqwtIhiHDdjSgY0UitllLV3uSM6vodlltds9kusRKWWjmyKd9Rza1qDg1lnPwFvr38LMopmIhwMusRi6xcK02ROD7cjFiGRaEhsX9BptHKfGnsXnEu+V7GULZUgafxYuclmw/pdb8VFyAihCnPJd6CSrt6zwb0dFYLsMLxgcyiwjJi0Ks0ic9UwjVUMGgLAKRTpzpqijlZBr99ISyH0Eye5iijY6nWagNYMrjNA/nxl4nJY4ns/A9YRpKaqv25rn8PyGQjhEBBuI+YYj74ZYtOZjJk9l2ppbUcve1gm8u+JdvLj4JbEqydrzFhkuItMrHXI/5FJ8ST5XW69+dNZPrIjWUefg4IIlmLPmXWRYnYgLiCP1gAMlTiuSYsMqS6KzTXZT67C2CNfsvfzyy41VOwYNGmQ4dH799deNIj/00EOghZErlDDMnj3bWJmDS86Zgd1h7Gqm02eKLTpv5koh5phH8zx9VwJKQAkogcgiwGVPP1rxMeKqPagWocCxatXsnrNaMKG0At1ytgF1u3WLrIpoafacgLT18AnXonvuLGwt3gi3NQblVjd6xY4X1zmxiI3ZhrF9MpGVoC7V9xxu6GOEVSiyOldffTX2228/Y2wgVwoxLYhXXHHFTmMV6SLnk08+AVcuMQMF4m233QauBUzrIkVicLe0eZ6+KwEloASUQGQRqPa6UVK0GZ0ri2Xclg0FMnWFInGwrMV3VWEpUpLTIqvAWpqQEkhL6ICbDnkE7y94Fjnlm2Fzd0NVjgtTt/5dDD92/L5wCi4+5mRkp+zoMQxp5prYHhEIu1BkaYcPH268gkveocPOU+Q5fjB43KF5Lpd+40uDElACSkAJtA4CC3IXYOb6GfBVi39Emx2dvR5j5RXpgcatW7agR+cD4ei9b+uojJZyrwl0azcUfzv0MWNI2OMfP4bp9qfgi5MJczK9pbBkMWYtzsIx+x621+lrxNAQiAihGJqqaCpKQAkoASUQ6QTeWPYGnl34nIxDLJfpK26UyizQgNeB9j4PTinIQ8dOh8BywuOwOuMivSpavhAR8EnbLy3/Cn7xr5gqY1O5rnCeqxTztnyJEfkHoUu6SpUQod6rZGTosAYloASUgBJQAs1PYE3xGjwx7wnklm+H21MlYxJlIp+MLz+sLB5X5TgxwNIZrmMfhDNl5x6l5i+Z5hBWAhZx0eLwwS4DVSkS/fInIH4WfyyYj4vfehzP/rRcuqPDWsKozlyFYlQ3v1ZeCSgBJdAyBMo95Xjo94eQW7EdsjqfBPpHDKDM6kOyrNLRqV03JB/zb7gyurdMgTSXiCHgcsZjZJf9UBoog9vvRq7NjW2yAE+lczuKnI/j+d9vwbd/bI6Y8kZbQVQoRluLa32VgBJQAmEgMHXDNPyycQZcVgfEk
AiviMUqeQLZxX9et1GXIeWCNxE74PAwlEyzDD8BC04deQUmDzkD/qRkWeJPFkxgoWQWvN/uQGXCHHy+7F34Wt7NcfjRREAJVChGQCNoEZSAElACbZ3AvG3zxRVOBWJEJBoiQCpMdziTS8sxWCxKcOxYyrWtc9D61U0gOSYVf933Dty0752Il75ncQENl1wrvF64zGNe1ffi9L+q7si6t1kJqFBsVryauBJQAkogugnkVebhlhm34It1n6HSbpUJLBAHy35j2bYJZeW4zONCQvfh0Q1Ja19DYEDWYHSOy5QuaJ8xXtFjHLEixlIOmyzzp6HlCahQbHnmmqMSUAJKICoIcJ3mW2WptteWvoZSd6l0NdtkTKJFrEQB7FdegZsKy5B50L+ADHVxFhUXRCMq6YxNwxkDzoDdV41yuX788oqRiU/7dNgHTseffpQbkZSeEiICOuc8RCA1GSWgBJSAEtiZwHfrv8P3679HvD3O6G4OiI2oWvqdu1bbcf3WQmQNPh4YffbOkfRT1BM4ZPDZKK4qwOerPoHX58aYzofg1H3+aoxZjHo4YQCgQjEM0DVLJaAElEA0EPhEHvR+6WYODvzUy+tDVo+JcB70t+BDuq0EDAI2mxMnj70ekwefC48IxbSkToZ7nC8WbcX05dsR67DjqGHZGNk1VYm1AAEVii0AWbNQAkpACUQTAXY5v7/iA8zN+R0O647HjMxJgE+sianVVZiSPgoxJzwLuBKjCYvWdQ8JJMZn1cT4YN4mPDJ1LVJiY0Q8VmH2+j9w7/GDMLhjcs05utE8BFQoNg9XTVUJKAElELUEZm6diWfmPglnZRmc4jTRKyRoSaRbnKvy8zG0/3gViVF7dex5xd3iF+frxbno49yAbMcSmSEfgz9KB+ObP/JUKO45zj2OoUJxj5FpBCWgBJSAEqiPQG5lLh6X1Ve2VW5FrM0qazhDXrICi4xPvCxvO05O7AeMOK2+6LpfCexCgFdPR+tcrLI9gvmWKvkEJNuzkGa5W7Z67nK+7ggtAZ31HFqempoSUAJKIGoJcDziHb/cgZlbf4NFZjhXip/EapnlHC9CcaK4wjmmOgmY8h8gPiNqGWnF95yAEx5Y4r4xVmxJ8CXCJa/Nzm1IcP2454lpjD0moBbFPUamEZSAElACSqAuAgtzF2L6xumIsbuMwzIkUR7xsgqLCMhDfamIPfm/sHcfZxzTP0qgsQTc3ipscech3R4DR8ACl8yh9wRc2Fi9pbFJ6HlNIKAWxSbA06hKQAkoASWwg0CJuwRPLHjCmKVqkQc5RSKDWzYmllRinz4nwd5n/x079a8S2AMCDlm5p0t6f1QEKpDgFKHoCMAnzrcTYgfg/bnb8eIvG7Aip2wPUtRT94SAWhT3hJaeqwSUgBJQArsQMLqcf70D02U9Z5dthzVRhibCJ69kWXbtmNRhSJp06S7xdIcSaAwBq8WK00ZdjW0V27E+fykcdrsIx4MxdckguEt+lzGLbnw0JwPXT+mP8T3TGpOknrMHBFQo7gEsPVUJKAEloAR2JTBz2yx8ueYLmeHsrDlIoWiRCSznFVdg+JQrZVxias0x3VACe0qgW2pv3D35RazIW4I4WRf8pyUx+GXls8hM+hV+ixfe0j74YvaFGN19X3HJtKep6/kNEVCh2BAdPaYElIASUAK7JfDmklfg8XtqrImM4JEu5wMLS3Bi/wth73fwbtPQE5TA7gjEO+IxXJbyY/i54kXkJX6IfFusWBTFppg4AxnuFPi8E+CQ7mkNoSOgujt0LDUlJaAElEDUEZi95TfM3vyrWHH+tCbSZ2KKdDmfmTweaZNk9RV9bkfdddHcFS4Tdzke2JDscyLRb5dJLnEosv8hkrG8ubOOuvRVKEZdk2uFlYASUAKhIfDDxum4/oe/o8Qvvu3EFY4ZvOJe+7hyC0YedDNs8eISR4MSCDEBe2wi7DK0wfK/684u15zNFStumbSjNMSoRXxrUAJKQAkoASWwhwTyq
wrwwMx7UFi2DbFi2TGDVR7Yo0vY5Xw+XF0Gm7v1XQmElMCkPschVn6ElFtK4bZWoMruxcF9T5CJLjEhzUcTgwhyDUpACSgBJaAE9pDAkq1zsKloDeJlAovVH4Cs1IdyMT30L/fjvMzT0GXSNXuYop6uBBpPYGD7Ubjx4MfwxfJ3UeWtwMRuh+KgXkc2PgE9s9EEVCg2GpWeqASUgBJQAiRQWp6Lb5e9IxNWLKgQ1yVx0gXIh4ndW439s47D+KPuhd31Z1e0UlMCzUFgePYY8KWheQmoUGxevpq6ElACSqBNEfD4qnHHD9fhy80/wiZuStxSO68IRrt4TRxd4sD+Y09CjIrENtXmra0yOSVVmL48F+VuLyb0zEDf9omtrQoRVV4VihHVHFoYJaAElEBkE1iwdTa+2zoDyVZZSM3nN9ZyrhLj4cQCH47tcA569xsZ2RXQ0rVpApsKK/HPD5ZgS1E1bFYrPvx9K26c3Afje2W06Xo3Z+V0Mktz0tW0lYASUAJtjMDm/BXw+n3GLGebWBJjZXyiVbqc+/Y+E2OO+YdMJtAu5zbW5K2qOl8sysHaos2Iz/gVrvRpqLJswDtztsg1SxfwGvaGgFoU94aaxlECSkAJRCEBLtXnrbIiXvzWVcnTwyEM3GJuiK2KwfB+xyI2lns0KIHwESio2gxv3JPweTbAFgjA7UpBke9SBAJDwleoVp6zCsVW3oBafCWgBJRASxBYsH0eHp3zH6zMXwW/wwGrGGisAQ9S3QH0tR6Dfh36tUQxNA8l0CCBhPhfkRxYhTRvovh5t8g1WiA/YL6R8YonISVWXec0CK+egyoU6wGju5WAElACSmAHgdzyHNw07VpsLF4v7nBc8MvYrxgRigNKxiDBtR+On3QUUhP0caLXS/gJuCvWylKSXCXIDq9YFG3WGGzbvg4XvjgDk4f2wlnjusg+HR6xJy0VEd/sBQsWYPPmzRg9ejQyMzN3KX9lZSVWrVqF6upq41hcXBz69+9f45F969atmDt3Lrp3744BAwbsEl93KAEloASUwN4TmLNxOjaUrEeKXVa+kGTsskZfaaAag4bvh1NGXYCshL1PW2MqgVASGJ3QHdOs3wJOKwIePyoCXiSXx2N+XiFmb1gGp4yhPX1Ml1Bm2ebTCvtklsceewwXXXQR3n33XUyePBnLly/fBfr8+fNx+OGH48EHH8QDDzyAV199FX6/3KkkzJs3D1OmTMEHH3yAs846Cy+//PIu8XWHElACSkAJ7D2BQKWsqhuwiLfEHcs2yyYCcg/ulRyrInHvsWrMZiAwfsDJOCdpIFxVxajyVcAreWyJl+7nLs/BmTINz/ywGrmlO4xOzZB9m0wyrBbF7du349FHH8WHH35oWAJvvvlm3H333XjppZd2gk1L4qBBg/Dmm2/utJ8f7rjjDpx22mn4+9//jt9++w3nn38+jj32WCQlJe1yLndYpctEgxJQAkpACTSOwMr8PzAvZxFcFqex8gonCPgtHmRXZyArbp/GJaJnKYEWImBL7ozjpzyNw1dPxd3zX8G3njUI2MplkckSxLdfh81brbjni/b493F9EeNQPdCYZgmrUFy4cCGSk5NruotpGbz88svh8XjgkMHSZqC4Y9fzfffdh6ysLBx33HGGECwrK8PKlSvxf//3f8apI0aMMOItW7YM++zz5w2MQrO8vNzoqi4sLDSskaZF0swjnO8sS4A33/9ZScNZlraSt8lTmYamRfUaDQ3H4FTMa9QiLmYiMVhkxZVZ66fhwR9vgruyBJlWJ4rlHpXqtiOuuCN6tzsf/br1ipj7ll6job2KWjNPS0IWHINPQM6qt+Er4nhaG9wWyh0/Etp9h1+2ZmPe+nYY2zPVePaGllz9qZnf+Uh6LpntXH+pw7zWc35+PhIS/hzcQtHodrvBMYnBQrF9+/aG1TAxMRFfffUVnn/+eXz22Wfwer3G+ab10Ol0IiYmBgUFBTvVefr06Ua3NW/IFIwMpaWlx
jsbLtyBDWWWK9xlaQv5s535I4LXEH8kREIbt2au5EmO/AHXmJtKa65rS5ad9yDyjMReDra5VUYjfrvgfSRVJyDBniVoAihFFVKdQzF42FmYOCgbvsoiFJWH/x6q12jor9zW/Fzijxy/pwLt/EnIsneE3WoXicgrWHoVbT4Ekn5GTtEYlJf2hVt0REuESHwusUw+n8/QUtyuL4TVouhyuYwCmoXjg4g3TZvNZu4y3vv06VNjNbz00ksxbtw4fPHFFzjmmGOM8xmPgRc2K03BGBwOPPBAjB8/3ti1du1a3HDDDYZFMlJu0Cwzy5KamhpcbN1uAgHyNH84NCEZjfo/AlVVVTv9KFMwoSHAH7m173ehSblpqfChsVhWYJlZPEeW6KtEus8JlwxMzPOXo0+33jj7sMGSQUB+hDUtn1DG1ms0lDRhPEtb83PJEkjCgVk98Pnyn+FzxkvX8w6hmCQ6IVZ8LFY71yMhabSxt6Wu40h8LlE30ajSkEElrEKxZ8+eyMvLMyyIsbGxRjdyWloaOKu5oUBBVVxcLL6RYg3BR/HHtIqKilBSUmLMfg6OTwimhZJWSd4EzVfweeHaNsvCdw2hIaBMQ8PRTEV5miRC9x7JTGes+w63/3QDcnzl4jNRRnf5q9DeZ4dH3tPQ838QeB8NHY+mphTJPJtat3DEb/U85eI8YPiFuGLTDDxauRZ+uxOJsjpLhtcKnyWAb39eijWb1uHSSR2REvfnULfmZB2JTFmm3YWwCkW6uOnWrRtuv/12wzp4//334+KLLzZEHGcvszv2sssuw7Rp04ztzp0748cffwTHID7yyCNG3U455RTceeedoADkbOhhw4YZadZXcapnDUpACSgBJVA3Aa/PjSfnPYoN7mIkyHrOXPmsTLxrl3rd6F21D4Z0P6HuiLpXCUQYAUtSNs4/6D9wfHouXvOWwikTsmTIokzdt6Mav2POsljEu47CVQf1ibCSR1ZxwioU7XY7nnvuOUPo0e0N3duce+65BqHgbml2zb711ltg1wKtiXSl06tXL+O8q666yuhypsjs1KkTHn/8cUNoRhZmLY0SUAJKoHUQ+GHdN5hfuAwuEYleMTZwBRabrMCSYR2Ewyfch5F9OreOimgplQAJtB+Es4ZfhqGzHsF78OBnlx/tfBwzsRKVievwW44dF3v6wNUyRsVW2SZhFYoklp2djSeeeGIXeLQUmuHggw8GX3UFik26xtGgBJSAElACTSPg83nw2eLXdoxXEsuL9NBJNx3gkQH/IzqNxtH7qEhsGmGNHRYCI8/B0KyB+PHXe5BcsASxlngEpMu1vQytsNp+gg+XSbF2ntsQlnJGaKbqRChCG0aLpQSUgBJocQIyqr+kJBepfqthSfSLSPRZ/Ih3p2JI+8ktXhzNUAmEhADH4XXZB9b0rtLlbBW7ooxTlMlZ1dIVnR4ogSPgDkk2bTURFYpttWW1XkpACSiBPSBQUZmPbxe+hFJPGewWmwz6tyBWBigmVVXi0NRjsN+AIXuQmp6qBCKPwIGyvF+C1Ytih4y5lVe11Y0jErrC4Wx4Am3k1aRlSxT2rueWra7mpgSUgBJQArUJVFeX4N/fXYXvt/4m7npi4BMTgjXgR9cKoJNlX5w78QLEac9cbWz6uZUR6N//ePxr/c/4qmARKsXIOCm2A3p1OQ1vzNqKksoqTOiVgcEdk1tZrZq/uCoUm5+x5qAElIASiGgCX6/+DJ9s+xVxjljpktvhby5WBicekH0NJo46Dr06pUZ0+bVwSqBRBJI7YshfHseQ1dNk4G0FcpJG4Orp1Zi7dgk8PiAhZi3++Ze+OGW0jsUN5qlCMZiGbisBJaAEooyAR9zhfLTyA/hl9QrOcJYeOVTJLJYysbgcPH4f9GqvIjHKLom2Xd34TGDISUYdP/ppI+asmw9XwgbEuUpRWZ6JB74CxnRLQ/fM+LbNYQ9qp0JxD2DpqUpACSiBtkbgjXlPYH7eIpn96US5VI5CUaawI
NaWjVhHdlurrtZHCdQQyCvLgyXjTdhTF8uClR44U+NQWjAJS7cNUqFYQ0mGoQRt66YSUAJKQAlEEYHtJRvw5ZI3kSKDEikQaVGsFkuixevB+MRD0CE1PYpoaFWjjUB6xgI4E+Ygwe9BkkzxT5CfSo6UaSjwrIk2FA3WV4Vig3j0oBJQAkqg7RIoLstFWVUZMv0uZHkDSBD3OLHy0OxTJc61B58ha9C33bprzZRAcdlCmeEfEIfyVlmByCKrl9tgt1Xg2Tlv4pmflstiHsqIBPQ2oNeBElACSiBKCcRI93JqIBXlNg+SfU6kikuc9pXAYf0uxghdgSVKr4roqXY/P0ffBeC2ik9FUUPl8m4R11CVttl49o+HsXBLbvTAaKCmKhQbgKOHlIASUAJtmUD7zA4Ymno+UstkpYpAlfhMtKGf9XgcNvpA2PXp0JabXusmBA7pchCOr/KDKxJViFNuh4jGrjL9ubfPLyMWF2JR/kLlJAR0MoteBkpACSiBKCRQVLoFXr8bJx50Kqw/98fmrUuQlNYBh48bh/YpMVFIRKscbQTieh6Mf438G/ab9wzuRhXaeQGXyEW//GjKtJQhMUYciWpQoajXgBJQAkogmgh4vdV4e+YD+HrVZ3CLUNyn42icc9BtCNjGIMFlgcsWTTS0rlFNwGqDZdzlmNhpFL76+mLMt/qRKN3R5TKrK8PiwT4Bca6oQYWiXgNKQAkogWgi8O3St/Ho4ufhssYgWR6KX6z8Fk5bAq446P5owqB1VQI1BGxJHXF5IBEvOjxYIUMuOss4xVNLfci2xtacE80b2vUcza2vdVcCSiCqCGzKXYIn5j+OaodLfOBYkCsWk2RLDOZu+R1V1WWIcSVGFQ+trBIwCCRlo3O3g3DL0o9RHJuM+Opy2NMHiWLcRwEJARWKehkoASWgBKKEwNd/vI3tFblwOmPhEp+JHql3vsx47mpJEuHoiBIKWk0lUIuARcyI+/8dSO2K5M2/G+8YegoQq6sSkZQKxVrXi35UAkpACbRVAltLc5Ei3c2V4ly4SsZh+WWWp99jxZhepyLGqRNY2mq7a70aQcCZAIw6d8erEadH0ynqACGaWlvrqgSUQFQTyIgbLW5w/MjwWZAiPhNd3kqMiD0Ax40+Maq5aOWVgBKon4AKxfrZ6BEloASUQJsisH+/45BlOwGV7gTY3DHIrNoPxw39K1LidKpzm2porYwSCCEB7XoOIUxNSgkoASUQyQT6Z6fguik345tFJ6LSU40JvXtj/36ZkVxkLZsSUAJhJqBCMcwNoNkrASWgBFqSwJBOCRjSaWRLZql5KQEl0IoJaNdzK248LboSUAJKQAkoASWgBJqTgArF5qSraSsBJaAEwkzA53OHuQSavRJQAq2ZgHY9t+bW07IrASWgBOohUChrOb/3++NYnLsA2QnZOGHYRejZYVQ9Z+tuJaAElEDdBFQo1s1F9yoBJaAEWi0Bv9+Dx6ffhOkbfkCSNR7Lc1died4y3H/Ua0hP7tJq66UFVwJKoOUJaNdzyzPXHJWAElACzUpg9fqf8N2231AdE4tyRwAptlhsLt2K30U4alACSkAJ7AkBtSjuCS09VwkoASUQ4QR8ngo8N+e/2G4LIFaWJquS1VcqrB64/AFUuS0RXnotnhJQApFGQC2KkdYiWh4loASUQBMIfL/8A3xT9AdcFjs8ogs9FgtKLT447dno12FiE1LWqEpACUQjARWK0djqWmcloATaJAGvtwqfLX0XAasF8QEgTqyIFnkPeL2Y3Pts9O7QtU3WWyulBJRA8xHQrufmY6spKwEloARalIDb60FxZQkSAjZ4xAxgF5FoCXiQ5eiKY4ceBTEualACSiBEBLx+L37b+isW5y1Gx4ROmNR5EhKdiSFKPXKSiQihWFxcjPz8fHTv3l1uZHXfyUpLS5GTk4O0tDTjZSKsrq5GVVUV/H6/sSsuLg4ul8s8rO9KQAkogagh4HLEo3PCCGzO+RAxiJOxi
T7Y3V6cNOBsdEzVpfqi5kLQirYAgQBeW/oaXl3yLrxeu2gXD2Zvm41/jv0nXLa2pUHCLhS//PJL3HHHHUhKSkJsbCyee+45ZGRk7NTIr732Gp555hkkJycjLy8Pp5xyCq6++mrjnNtuuw3Tpk1DamoqAoEALr74Yhx77LE7xdcPSkAJKIFoIGCzWXH8iKux6ftSbK+aBxcc6J92OI4eemI0VF/rqARajMD2iu34cPnnQFEVUv0+FNls+HrNNEzsNBGHdD2kxcrREhmFVSiWlJTguuuuw/3334+DDjoIF154Ie69917jc3DlBw8ejJdffhndunXDH3/8gcmTJ+OQQw7BgAEDsGXLFpxzzjlGXI/HA5s0lgYloASUQLQSGN6tC+445lHMX78GMQ4nRvfojtR4HY4erdeD1rt5CORXlqO0aCu6+AqR6wAq5StW6QngwTkPwWax4cAuBzZPxmFINaxCceHChXA4HDj00ENhtVpx9tln4/rrr4fP59tJ8A0dOrQGTf/+/REfH4+ioiJjH7uqly9fblgVBw4ciA4dOtScW9cGrZaMU18Xd11xmntfpJWnuevbEukr09BSVp6h5cnUQsm0In8VFiz/BFXiGmdwj4PRpfNYdEnvH/pCR3CKoeQZwdVssaIpz4ZRp1f70buyAitcFhTbRFOIGypOHrNXe/Dastcxsv1IJDuTd0okEpmyTLsLYRWKtAayO5kikSEzMxMVFRXGKzGx7gGhTzzxhNHNPHz4cCMOhePSpUvx+uuvY+bMmbjzzjtx3HHHGcfMP3PnzsUHH3xg3JgLCgpQVlZmvAiI3dXhDhxfWV5evpM4DneZWnP+bFfy5PhVWpkjoY1bO0+OA/bKzNlI+c60Zp5m2XmNkqd5/zP378m71WoXq8Y6PDftX1hYvgkBSS975VRcNPZ69OixH7w+754k12rPJUe9RkPbfPpcqp+nxWpDjLsCZ7nT8IITWGfzIUFEYlbAB4s7UZ4/fmzO3wJHogM+2ccQic8llomGOb64XV8Iq1DkDTL4Ic5tFra+An/yySd47LHH8P777xvjGVkpdl2b53P/LbfcgiOOOKLmOM9JSEhA165djfNoUVy2bJkhyhgvOH+eG47AcpCFdpuHhn4wTzKNhDYOTc3Ckwp5kiMfHMozdG1gfuebIhRt8sD66o93MKd8PbJsSWLVALYGKvHm4pdwY9exsNnlKRYBP4ZDR63ulPQarZtLU/YG30ebkk6bjMt7YkY3DOs8GOeunoHHMuxIkRnQcWJ32u7ZDEe5A8mONFhlzLAYGo0QzDNS7qMsU2NCWIVi586dUVhYaFgq7Ha7Md6Qoo4zl2sHTli54YYbwIkt7GI2Q3BFx44da1gj+UudgtAMffr0AV8Mmzdvxi+//GIcD45rnhuOdz6Aaa2pq97hKE9byJM8nU6nzoAPUWNSzLjd7p2+VyFKOmqTIU9+55siFAlvXfUWeP2lsgKL33gmBcQdznr3FlicVsS5/rwPtnXQeo2GtoX1ubQ7nvLdOvivGFmxBUdu+xkfpiSjXMYmpvitOHbLOqSJNT8mNmanRCLxuURDyu6E644+352q0nIfhgwZYjzIKf44m/nxxx/HlClTjBsnhSFnRDNQ2J1xxhnG+MWOHTti/fr1Nd2Ks2fPNlzrMP6DDz6Inj17Gl3T9dWC3RMEE0lWpkgrT33sWtN+ZRra1lKeoeXJ1ELFtE+7Maj0u0UmSveR/C31l8v4RLm3iqucaAqh4hlNzBqqq/JsiM7/jqV0hWfQKZhQmoBLirNxUWkWripMlwku1fh521T8kb90p0QikSnLtLsQVosirX4cc3jTTTcZlsLevXvjmmuuMco8b948cFY0u5F//fVXo+v4o48+wrvvvmscv/vuu41Zz3Sbs2HDBuOmSx+LTz/9tHbh7q7V9bgSUAJthsBxw8/E8oI1+G39F/CLNbFr5hicM+Ya2P439rvNVFQrogQikEBCj1GoSm+P7II8+GX4x
+y4CnyUDuQseRxYAhzQ+QD8fdTfkR4rO1tpsIia3L2cbObK0cTNCSb0pdhQ4HlmCO6uoTNuVmN38Rl39erVuOqqq/Dpp582ucvHLEtT3zmQlKKYviA1hIYAebLrOSZmZ9N/aFKPvlRoiWdXaWO+Y9FHZ+9qzGE35Mlun70JC9ZNxYy13yDemYD9eh2LSosLFZ5y9M/sK/ui77rXa3RvrqL64+hzqX42tY8ENvyG6l+eRl7RRtwYW4LFtoBMEBErf8CPal8VDu16KB484EF4K7ywO+wR9VyirjryyCPxyCOPGD2ytevGz2G1KJoFouhrzAMoWByacfle3wzp4HN0WwkoASXQVghMW/IWHvjpdmOCkU9GJn678nPc+5cXMKDj0LZSRa2HEmg1BCxdxiKm82h4pat589S/w12+RXwpBmCVfzG2WEzfNB1TN0zFQVkHybc17La5PeYa1jGKe1xajaAElIASiHICHm8VPlj8uqzjbJWZzoloL91duRX5+Gjhq1FORquvBMJIQCayOJ1dYK12yUhhWfxDBKFNxg0b057lu7q8YDns4s7Kaml9sqv1lTiM14FmrQSUgBIINwGP14OS6jI4LTs6hOjgwiEPqYKKonAXTfNXAlFNwF5WirO2FSBD9GGlfDEpEx3S/WyTaWYLchfik1WfoFyGh7S2oEKxtbWYllcJKIGoJhDjSsCA9DEoNFzieMUlhxtVfg+Gdtg/qrlo5ZVAuAmkWktxdFUe7swtQi+3Fy5xwk2RFSc+FnMKt+CD1R/iiflPoNzbusSiCsVwX1mavxJQAkpgDwhYxUnumeOuxeiso8R3Yqx0cKXiyF6X4ZBBx+5BKnqqElACoSbgSO2I+A6dsX9ZMZ7amo/Li0rFuuhHf1klrIO8d0joiAXbF2D+9vmhzrpZ04uIySzNWkNNXAkoASXQxghkp2Xh9mMexrrcLXDJ6itdMzJkjbA2VkmtjhJobQRcibDvezX8m2ahk6cKB4rh8LOEWDjky2mTmdB2GZ9oOIaXGdGtKahFsTW1lpZVCSgBJfA/Ak75md+nQ7b4TVSRqBeFEogYAv2mwDLmElnj2Y9Ujxc9pQt6i9OFYosTm0u3ocRditVFq2WccUnEFHl3BVGL4u4I6XEloASUgBJQAkpACTSSgOWgW2BLyETiquk4uzgXb1QU4wdsQGJVBeJscfhYJrWsK16Hf479J2Ltkb/MploUG9nwepoSUAJKQAkoASWgBHZLwBkHTPwbSo9+Fk5vKo4qj0VKwIHOPifaVZYi1RKHOdt/x+85v+82qUg4QYViJLSClkEJKAEloASUgBJoUwT8FQWwlm9FoSNZXOVwBrS4sbL6sKpkFXLKc/DMwmewrGBZxNdZhWLEN5EWUAkoASWgBJSAEmhtBGIT02CLTUBWVTU6e53Is3mQYw+gOuBGvHQ5F1UX4akFT8kyf9URXTUVihHdPFo4JaAElIASUAJKoDUScCa3R9p+FyLLUomr87ZjYHU5vDLz2eX1weKuQnF1MWZtmxXxXdAqFFvj1adlVgJKIHoIeFuXK43oaRitqRLYPYHEsWcj6ei7xI+iFRMrE5Hst0kHtE2siFWokhWWqsSa+OzCZ5FXmbf7xMJ0hgrFMIHXbJWAElACDRFwb1mI3z66Ch++eTL++OYOoDSnodP1mBJQAhFKIJDeG/nWZCQFUtDfbUWJzWv4PbX43eiZ0BkFVYUR7YRb3eNE6IWlxVICSiB6CVQXb8EjX12D76s2wG9zImbFApxdvhnHH/kYYHdELxituRJohQTi0jvB0n4gUnLX4wxPEZaleOCzOJDoCyClrAhbEpLEcX5MxNZMLYoR2zRaMCWgBKKVwPwNv4hIXI9UezLaIw52ZzLe2f4rtuevilYkWm8l0GoJ2J0xyDr6/2DpOBhdqsqxb6UNHqsTFVYHNpRvR9+kHhiYPjBi66cWxYhtGi2YElACUUmgJAfrV3wB8aWBWH+lONVwIl6WAMuHH/n+a
mRFJRSttBJo3QQS23XH9iEnYPuGWfhLuQ3dPLlY4QScfj9Ky4tx2293ICsmHSf2ORF90/pGVGXVohhRzaGFUQJKIKoJeD2Y8fVt+Gjrj8iJicFyZwCl1iqU+UqRntILHdN6RjUerbwSaM0EAul9sN2ZjR5Vq3F8SQn+ll8mVkU75mydheLizZi3fT7unXUvCmXMYiQFFYqR1BpaFiWgBKKaQGHeajyX/wtKXInI9jrgttiwWvyu2VN74Yp9b0aS7NegBJRA6ySQmZEJ/8jzUBCIQ64lBbOc7TArNhGZPgsqKvJR7C7BgtwFeHHxi/D6ZcJLhATteo6QhtBiKAEloAS2SVdzjsWLdH+MdDZbkeS3IkfcaJw4/ESM6ThKASkBJdCKCdhkfZa+A4Yhf3ZnwG9BpSsODuTJLGhgi0xysfjt8Aa8ePWPV2GTH4mXDb8MDhnHGO6gFsVwt4DmrwSUgBL4H4GMlM6ITegIt68M9oAHThGJNlscumUNUkZKQAm0cgIBvw8JGR1RPfhMWGRVlj5lWzC63IMtDgtkHjSqvDImORCAx+/B84ufx0uLX5IxyoGw11qFYtibQAugBJSAEthBIDMuA6ePvQYlyZ2wyerHFumWOmTImRjeYYQiUgJKoJUToAi0y6S0vodfhKqjnsKWQZfgZEcvTKryo0q6mgMBvxyHWBHlLNn+cu2XyK/MD3uttes57E2gBVACSiDqCVQUAEs+BgrW4bjO+2Dw4Y9iSdEadIzvgOHthsJu0Vt11F8jCqBNEKBYjHNYMXz0vvD374uKF97H8LIEfB1bKRIyYAhEvlssFmyv2I6NJRuREZsR1rrr3Ses+DVzJaAEopqArPsKdzm2fnwNvtn0A9bbbBi06E0cMepS9J50TVSj0corgbZOoKI4H7lFRejpTMVAdzUWOr3SAf1nKKguwN9++Bv+b8L/YULHCX8eaOEt7XpuYeCanRJQAkrAJGC1WlD2+8t4IOd7vJwag98THXgm0YLHl74Nb8FG8zR9VwJKoA0SyBdXOaul67m3eyOuzc8VsegWq6IEsTpSnLlsLmwu24z//P4flHvKw0ZAhWLY0GvGSkAJRDUBsSZainLw2eJ3MS0xHuU2OwqsNjjlrvyzpQQbK/KiGo9WXgm0dQLt0pLwe6/LscabglGVRTi3qBQumbtikYrbRCxaxBm3U8Qiu5+3lG8JGw4VimFDrxkrASUQzQSsIgq3y5rOH/pkzRUZvO6SweviMQM5NitKnQnwxiZHMx6tuxJo8wRixC3OKYfui5KOE8VdTgDDqjyyYosHVTI+kZZFTmjxBnzIjMtCWkxa2HioUAwbes1YCSiBaCYgzwjkOOzIs9sQK9YDt+E5MYAKmf3Yq+NodEvtEs14tO5KICoIdE22Y3vHQ7AtIOu6izusf+UVYnRVlbjLAapFLqZbHLhg8HlIl+X9whV0Mku4yGu+SkAJRDUBn/hUy0pqj5SUHrDlLUO1CMZyGcrucCTjbPGzRhcZGpSAEmj7BHwdhuNG3yX4j+1JDK8qw/Nb3ZgT48I2uxWDfVb0dWSGFYJaFMOKXzNXAkogWgmwW6ldXCbOGXEeHGk9ELAlIzG2Pc4adBbGdBgdrVi03kog6ggcPSgDOV2PxK3VZ8AjrrBi5N6wb0UVTiipRN9S8aM447/SDx2+Jf3C/pO1RBbGfvTRR7Fx40YcdthhOPbYY+u8SBYvXoxnn33WOHbRRRdh4MCBNee9+eabmD59Onr16oXLL78ccXFxNcd0QwkoASUQqQR8Mv7oyB5T0D+tL5YXrkCH+PYYmjnEWL4rUsus5VICSiC0BOJddtx0SDdc9fJELPF9h2G2VSIYnTIYJQCvNQ6uTb/DUrwJSO0W2owbmVqDFsXZs2fj8ccfN5aUaWR6e3SaX2b0XHLJJVi9ejUOPfRQ3Hrrrfjwww93SWPz5s049dRT0bt3b+PF7S1bdswAeumll/Dggw/iiCOOwJw5c3DttdfuEl93K
AEloAQimUCvlF6Y0n0yRmSNEJEY9t/vkYxKy6YE2iSBcT1Ssd/w/vjYN1bq55dBKFZUwiVjFW1wB+SeEEaLYoNCsUOHDvjkk09w9NFHY/ny5SFvnJUrV2Lu3Ll4+OGHcdxxx+Ef//gHnnzyyV3yeffdd9G3b19cccUVxouC8b333jPOo5WRAvOYY47Bf//7X0ydOhWbNonyrie4XC7D4zm9nkdKYFkiqTyRwqUp5VCmTaG3a1zluSuTpu5Rpk0luHN85bkzj6Z+Up5NJbhr/IaY2sSn6j8O7oElqQfhD28XWcrPC4es924PeDHVOwib0H7XBEOwh2XaXWjwp2unTp3wxRdfGFbFM844A2effTZGjBhhpMllaLp06YLOnTvvLo96j1N8ZmRkICkpyThn0KBB2LZtGyorKxEbG1sTb9GiRRg2bFjNZ26zK7pKZgYVFBSgf//+xrGsrCwjLVooWXYzrFu3DrSOEgjTZzy+GgPITGNP3sndbrWi2iu/CeSDQ9xdeHyGG806k/H5fKiurjbKVOcJunOPCbB9yZXXqYamEyBPj7htcDgcTU9MUzAI8DvPe51NVmPR0HQCeo02nWFwCvpcCqYRmu3dPZfaxVrQp3c/3Dr7Elxgn4YsSzGW+jvjXcsR+L9KLzKrq+iLO6SBPbt8NaSHGhSKLA1vYldddRXWrFmDG2+80bDs8eHLi+jSSy/FxRdfvNeFrqiogNPprIlPax/TdYt38mChWFs4cgyieVPg+YzHYBVxZrfbjZtvTaKykZOTg59++skAUSTL5fCBxxfBNIeQqPb68eHcTZi/sQhOmbV0QN8sHNi/Xb15sZFYHtZbQ9MJsF3Jk4HXb3O0cdNL2XpSMHma3xvluXdtZ5dZzFZxsm3yM77zHjdsPhWKe0f0z1h6jf7JIlRb+lwKFckd6ZjXKD/V91zy2yw4aWgGrlkxCLdU9UK8zY9Crw2H9m2HjgkWQyOEUiiyTNRQbOuGwm6F4pIlS4wuYVoSZs2ahZ49e9ak19RfwikpKTuJurKyMsNiERMTU5MHN5KTk0GBZwZu0wrJ81guxmPwer2GZc60UJrnjxkzBnwxcLwjxW1iYmKDCtqMuzfvL/+6Dm/My0NWoowvqPDj8Rlb0CEzDft0T6szObORape7zpN1Z6MJ8EeI+SOi0ZH0xDoJkCV/yPB7o2HvCKwvWYdvN3yPCm8l9s2egJ4JPZCYlGj8wN27FDVWMAG9RoNpNH1bn0tNZ1hXCrt7Lg3pkYjbjpOlPH9cg22l1ZjcLRWXH9ALqck766K60t6bffzhSgOb+QO2rjQaFIrz5s0Du5w54YSziWmxC2VgV3N+fj42bNhgdGP/8MMPxmQVPtxpbeSFmpCQgHHjxoGTVswwY8YMnHeeuJQQkcgu5p9//tnofl6xYoUhPDmesb5ASySB8EU1HepAa+IPy3ORJktyxcezm86CioJyTFu+vV6haJYn1GWJ5vSUaWhbX3k2jec6EYl/n3Y9NuWthUuWX/k84VNcPeQaTE4+vGkJa+waAnqN1qAIyYbyDAnGnRJpLNOJvTPAV6XHh1hH8/Y4sEy7Cw0KxbS0NGPSiDkGcHeJ7elxirxTTjkFZ555JsaPH4/PPvsML7zwgpHMPffcg+LiYmOCCl3mvPrqq4Y45EFaMjl5heGGG27AlVdeibVr12LatGm44IILkJ6ebhwLxx8Zj4o4px0dCmdiQtV8VFti8ZlnNGJdHcJRHM1TCSiBCCDwyuLX8EfufCRYHQjI721fcQmmr52GQ/odLGOYG7wNR0DptQhKQAmEg0Bzi8TG1qnBO1TXrl0bm85en3fHHXfg66+/xvr16/H++++jT58+Rlq0ZJrjzNjdxVnOn376qXHsqKOOMiyN/HDAAQfgrbfewi+//IK77rrL+GycFKY/nLhyeYdlsC79L6x2l0xw96G35VekdxkUphJptkpACYSTwKqiNfh27ddIkB6MGPGL5pYfk1XS2VBdmiPDZ
WSCkArFcDaP5q0ElMBuCDQoFHcTN2SH6Wi7djAFo7mf4xQpHusK7MLmKyKCz4P+BV+jtF0WtvqS5LEAtK/ejIVTX8PbGy7FKSM6oH0zjTWIiPprIZSAEtiJwIItCxAo2w6r0yYOL6RHRG4KFSIWu/sgKzDsdKp+UAJKQAlEHIHQDjqMuOqFoUCy9A6qKpAYH4c+7RNFKIpLnnI/Vq7fhHu/XI4zXpqHrcVVYSiYZqkElEA4CLQXFzipPi86evwyPpEWRQu6imucA5K6Aa7mGaAejnpqnkpACbRNAioUQ92u0t2MXgcCpeKvsTgfFUU56IjtOMb+G96Ouw/Jm3/AW79vDXWump4SUAIRSGB2zlxM2/4rvDYnyq02JIlQ7OiuxjnVMejW5zBZfyH0E+oiEIMWSQkogVZMICK6nlsxv7qLPuJs2W+RcYrfoEdgEWItlRD3zxiPeejj+APfFNFBePe64+peJaAE2gSBL9d+j7t+ugu2ylIkiQ/FSulfGFdmxaEVPgzc51xUdpD16v3S/6wOt9tEe2sllEBbJaAWxeZoWYesKjPmYvgOvQUWhx3VPpcIRRcKA4lIDpThePsvzZGrpqkElECEEHDLuqxvL3oH6WWb0StQhfYyODFWup8TMntgxNEPwnXg3xr0WxYh1dBiKAEloARkOUENISdQ5a3C+ys/wC/rv4MrLQb7l9owsVwsB9LLZBX/ObEO2dagBJRAmyXgk5WBAsXr4RJLotviFJ+tQMDvRm6G9CYMOnZHvQMVbbb+WjEloATaDgG1KDZDW76/8n28vvQ1lEtXU05KR7yQ4sAKWX4nJVAKV2IKAv2OQoVHpzs2A3pNUglEBIGArxLdZWJbntxhqywBFFllEovcD/b36vc+IhpIC6EElECjCahFsdGoGnei2+fGL5t/QVZsFpKcSUBGP3hkDNJcWzHGdhyLeRlH4625qSj4/md079gO50zohg6J2gyNo6tnKYHIJzB/+0L8d85j2O7Nh1t84eTJ+OQelX4cWQns33V85FdAS6gElIASCCKgCiUIRig2rRYrYuwxKKouMuYzBuxOVLsSMCcpG292OQFLfinChE23ooMtH4s2dcHr5efj8qP3M7qmQpG/pqEElED4CJR5KvDAL/9BweZf0U5WYcmw+Axr4pXWVAwdfSbQ/4jwFU5zVgJKQAnsBQEVinsBraEodhmTNKXHFDwy9xF4/B6Ue8qxqXQTNohrjF9zFmKw14sTY12IDSTgcMs8/LHqcWzLH4KM2OZdz7GhMusxJaAEQkNgQ+F65OYuQHubCz4Zm2iT7me3rwyrx5yLoSMvCE0mmooSUAJKoAUJ6BjFZoB9YJcD8a9x/8KwrGHIqchBjCMGya4kpMvD448YPz6Oj4XHEofttvbo7l2NNE8O/BYVis3QFJqkEmhRAiluN5JkdaZS+cEo81dQaZVFPAMBZAWcLVoOzUwJKAElECoCKhRDRbJWOiPbjcRp/U+DQ7qf7Ba7DGMXd2ny0LDJ42OzxY0qGdTOtaxT4sS6aPWI411tiloI9aMSaHUEOiRn4zhrlrjEqsQWqw/5qMYkjxMj03esYd/qKqQFVgJKIOoJqDppxksgMzYTXRK7oNRdKq4xpAtK8vJJV9QEdyWyLEXoiY1IKVsN9+unoey7e+EpK2jG0mjSSkAJNDcBS2ImTtrnUtzlScQFReW4pRT464BTEdt1VHNnrekrASWgBJqFgI5RbBasOxKNd8Tjun2uw10z78K64nVwySSXM/ufgVNdnWCb8yYsGxbCJxZHZ/k2xC1+A5X+QjhOeKQZS6RJKwEl0NwEbMNOxNCOQzE0ZwmQ3AnoNEJ8qOrQkubmrukrASXQPARUKDYP15pU2QX9yhGvYG3xWiQ6E9E1qatxrHDJl0gRK6NNPPH6bVZYnbGwLP8ChUW3ITUlrSa+bigBJdAKCWRKVzNfGpSAElACrZyACsUWaEBaFgdlDDJyWpy3GG8tfwfbZBJLv8xMn
FFSjix/QNzjWOGQiS+b5nyK1IPPboFSaRZKQAkoASWgBJSAEmiYgLXhw3o0lAQ4A/qB2Q9geeFyeJM64Ye4ODyRmoQAl/aTSS52sSym/XoXijYsDmW2mpYSUAJKQAkoASWgBPaKgArFvcK2d5EWbF+AwupCdIhrh6TkjmjnTMVclwNfxMehSHyueRADa2UBtvz6DsTIqEEJKAEloASUgBJQAmEloEKxBfHHOeIQEJ9qfpn57PH7sNJpxVqHAzdnpuP/MpKxKCYWlfY0zFi5HT+uK2/BkmlWSkAJKAEloASUgBLYlYAKxV2ZNNue4VnDMSB9ADaWbsSqolXIdReDS/7ZRTyudtnwaGo80lGKRE8BNiye0Wzl0ISVgBJQAkpACSgBJdAYAioUG0MpROdwUssNY27AmQPOFAfc7FuWvzLrmS+n9DUvjLXg8yQLjg98h4PmXY2Sma+HKGdNRgkogVATyCmtwpJteSit8oc6aU1PCSgBJRAxBFQotnBTpLhScEKfE8Bl/txeuuDeMRiRy3355PXftFTMio1DvLcc+T8+A09xTguXULNTAkqgIQLSAYDX5/2Esz+4DBd/cjbO+/Af+G396oai6DEloASUQKsloEIxTE13ev/TkZ2QLSu1UB6KP155sTGKxC/v54kueGU8o6eqDBVlRbJXgxJQApFCYHHOOrww51+Iq/wZ3QPrUFb0MR755U4UV1VFShG1HEpACSiBkBFQoRgylHuWEB1vXzL0EmMdaFoVTaHoFHPFz7L+81yXHwkWEZHleXuWsJ6tBJRAsxJYuvUX2D2bkWhJlO9tHDKsiSgs+R2bCtWq2KzgNXEloATCQkCFYliw78iUXdAXD7lYJrTYjA5o+lKM91uR4AvgD4cbme4tcLx7Onxf3ACUaRd0GJtKs1YCNQQ6O/ziIN8iQ0Us8i+AavkbY/EjxeKtOUc3lIASUAJthYCuzBLGlnTanLh8+OVYWbASW/O3wmK3obqyWHwtFmGj0y7rQHsQW52PwIyHgOL1wHFPA66kMJZYs1YC0U3gj60V2JSbgW5IxAZrtXQGWOH2VeCEtEHITusZ3XC09kpACbRJAmpRDHOz0iZxUp+TwEkupe4S5HtLxEYhFkWXHc8mJ8MrjrgDNhd8q6YCG2eFubSavRKIXgLvzVuAaz+5Ck8uuw+lYlHs57HjEFscbsgchfMn/gsWZ0L0wtGaKwEl0GYJqFCMgKad0GkCJmRPMEqSHZeNvu4AOnm8+CHehe12K2xitfB4PPC5KyOgtFoEJRB9BHJLy/DOgn/DEfgBHax5qBBr4lqrDQePvBUHH/0CHNkjog+K1lgJKIGoIKBCMUKaOTMuExmxGeiV3hflssTfGumGXu+w46mUBGxx2BAT8MD/6xPAOnXEHSFNpsWIIgK5pWtQVv0HkmUCiw1OJAZcKA+UYDnyAUdsFJHQqioBJRBtBFQoRkCLc1m/3qm9kR6Tjg2l67HJ7hZLoqz9LN1bHyXG46r2GfgtLhlY+wvw/gXiQ2dDBJRai6AEoodAqvxoi7FaUb3D7anh81TW40Q7O/0VaFACSkAJtF0CYZ/MUlxcjA8//BBerxfHHnss0tPTd6FdWFiIGTNmYM2aNejcuTMmT54Ml8tlnLdo0SKsXr3aiM8dw4cPR8+erWtQOdd+bhffDteNvg4Pz30Y62XiikUeSj7Zb5XxiotdTtyXZsV/hFHXgrWwzH0FOPDmXTjpDiWgBJqHQPu0XjgiewzeWT8VZdZYeAJujEzuhBFd9m+eDDVVJaAElECEEAirUCwvL8eJJ56ILl26wOl04pVXXsEnn3yClJSUnfC8+OKLmDVrFoYNGwZuv/7668aLYvHuu+9Gfn4+Bg0aBL/fjw4dOrQ6ocjK+vw+9Enrg3MGnoMFuQtQ4i6FTdzmWP1e6egKYIPTijVOC7q5ZTb0nJdg6zAE6H/UTpz0gxJQAs1DwCIeCs7Z///Q8/eOWLR9PjoldsLBwy5EfFLH5
slQU1UCSkAJRAiBsArFTz/9FNXV1XjuuecMHLQUvv3227j44ot3wsPP1157rbGvsrISgwcPxuLFizFy5Ehj32WXXYajjz56pzj1faD/s0gOgzIHoUdyD8zOmQ2rVVZokfLSqlgt7z/HutCrKgCbJwkdF78H9DkCsDkiuTpaNiXQZgg4EtrjwP1vw4FcelOGhmhQAkpACUQDgbAKxd9++w1jxoyp4bzvvvti5syZuwjF+Pj4mnMKCgpkaFAAyeI6hoHbjz32GL755hsjrdNPPx02m63mfG7Q0ujzySonEtxurq/8Z2D8cAezDHyPs8fhjP5nYFnBMlR5q8R5jhV++ccafZGQgFW2VFyUZ0dqeQXipGvajBvuOkRS/uaPAZON+R5JZWxNZSHPYIbB262pHiEpq7CwiEg07hohuHeQZVTzDEmjyBKoeo2GiOSfyZjXpfn+5xHd2hsCkfhcqv29qa9ezSoU2V28fr04iq4VKPIOPfRQlJSU7NRNzP3cV1/gOMYrrrgCRx11FHr16mWcduaZZxrij8fuu+8+LF261OiODk7ju+++w3/+8x/jZsLubgpHjo1kiIQvAcvDcrEsbLgRSSNwRd8r8OOmH7GlfKsUUlZ+EFEoi7agQPxt50hdf7KMxdCiSrhsFRFRh2De4d4mw7KyMjgcDlTJ+ruR0MbhZtKU/MmTln/DRZP84IoGnqyz/JfvHv9TzDWFYN1xS0tLjR+wVhmPrKFpBKLxGm0asd3HDn4u7f5sPWN3BCLxucQy0YjGezu36wvNKhTZPfzLL7/sVAA+ZDp27GgIxbi4OFRUVNSUjWIpNrZuVxOMR5FIa+G9995bE4fd1WbgJJaTTz4ZN910ExITE83dmDBhAgYMGGB8pnC98847jeMNgamJ3AIbbCiWxbSScvvsUWejb8e+uHrq1Qj4nbDIWEUH3KiyuvGBdTwybIdgYlqKzMSUB5kIzWZ5krVA3ZsrCzKkUIyJiWmuLKIqXQpu3kyCv1dtFYBVrh2LxYoyDvOwWiAjPuTHJQVyaGvMexp51u4BCW0u0ZNaNF2jLdGqtZ9LLZFnW88jEp9L/EFgt9sbNAA0q1A877zzwFd9YciQIfj4449rDs+bNw8jRoyo+Wxu8IZK8Zebm4v33ntvJ+FpnsN3VpiBjREc2HVtdl/T8shf8OYr+LxwbrM8tR8YHRM7ItYRgyJfubjjkKeURXwpimVxkTMP19unIm7JCqD9YCCzXziLHpF5mzxrM43IwraCQpEjHxzRwLNCRqe8NGsOflj/G2JssThu4CQcNain3DNC21B6jYaWZzRdo6ElV39q5jVa/xl6ZE8ImDwj6T5KvVRbM9WuU7MKxdqZ1f5MdzhPPfWU0WVMyw8tkPfff79x2j333GO833DDDfjvf/+LBx98EP/+97+NiS8UhMcccwySkpKM/X379jXc47B7+aSTTkKCjOWrL/Bh11pCdnw2DuxyED5c+aGIRK9IRSvi4UKGfbmI5tnwv1sAa1waMErE+P7Xq+Pf1tKwWs6IJvDSnG/x0oL/Q5o9F3myKtL9Mz5HStz92L+Vud2KaMhaOCWgBFoNgbAKxczMTGOW87PPPmtYK9566y106tTJgEd/iGagEKRIZNfXtm3bjN2clEL3ONnZ2fj555+Nff/4xz9wwgknmNFa/btNlgi7YvgVWF+yHovzFiM1JhXtywrFsFiGFTLb2WtxwFkuk3t+ewoWWhYHHd/q66wVUALhJFApS2f+tv4VdLJtQzLkB6fFj63+hfhxzQciFP8RzqJp3kpACSiBsBAIq1BkjXv37m1YFGvX/rDDDqvZdcQRR4CvukJDXdt1nd/a9qXFpGFy98nYVp6Dromd4S/KR570rFfb7JgeF4uxlUBSVTGwaY4KxdbWuFreyCPgq4DVu0m8DMTKcA+r4XXALqODLZ41kVdWLZESUAJKoAUIhHjUTQuUOAqzOKjrQRieNQwby7ZgpbTYRllObIOsBX1LZiqub
Zchn8Wn2/KvgM1zo5COVlkJhI5AjKzbPDw+CwWoln8BiC8CyDQejEnqHLpMNCUloASUQCsioEKxFTRWiisFN4+9WV7/RGJSD3hljGaZzYpymZH5Y1wMrmvfDtsK1wEfXQJsX9oKaqRFVAKRScAiQzrOGHQODrTHyzJ9pbD7S3FaYg9M6n9SZBZYS6UElIASaGYCYe96bub6tZnkY+wx2KfDaKTFJ2J5sQVOmQQtOhEumRE+N8aFt5NTcVnhdjiWfQZk9W8z9daKKIGWIlBRLQuuyE/ndFnx6J/xmdi6dqq4WIpHRu/JQFr3liqG5qMElIASiCgCKhQjqjkaLoxMYkePlG74efOP4nybTbfDsZtN3j9OjMW4Eh9GVFZBG7VhjnpUCQQTKK8O4EVxhzN9/SzwB9nxAybhmCGj0KHjqODTdFsJKAElEJUEtOu5lTX7iX1ORGZcJrwBb03JbaIXt9steCnZghmBnjX7dUMJKIHdE3hhzlQ8u/AGFFU8jY3Fj+Dfv9yE6at08sruyekZSkAJRAMBFYqtrJX7pvXFDfvcYKwJ7TeWirDCEZA5miIWFyRYZcnE54GfH5ZFrctaWc20uEqg5QlUiDucGRteRy/bRnTzW9BLvlPpgbmYtvaTli+M5qgElIASiEACKhQjsFF2V6Qjex6JS4ZeIsuL2WU1Gru48bAhUR5yyX4bVnmKUP6TiMU5L+4uGT2uBJSAX9ZL962Xb1AM3OKX1GtxwSmfHN5VykYJKAEloASEgArFVnoZnCMzMw/pejBixDVOp4AdmR6gVGa3bBIn3b94Y1C5/AcuAt1Ka6fFVgItQyBOZjmPdCYiD97/ucPxoxJ+jBPn9hqUgBJQAkpAhWKrvQYcVgcuHXYpRncYDr84317v9MNrA7Y6PXg6swK/lm0Vh4vihFuDElAC9ROQyStn9DwKB/kDIhUrYQ+U43R7Mvbv9Zf64+gRJaAElEAUEdAJsq24sXul9MI9E+/Bdd/ficqNn6GHWBBdATfybD58aS/G/p9eBdv4K4DhZ7TiWmrRlUDzEkgbegb+KT+8tq3+TtzhxCF94AlA5zHNm6mmrgSUgBJoJQRUKLaShqqvmCmuZHRLa485GzqgVDrPXMiX5WljUOCME4fBybDNfUXWSTwUSMiqLwndrwSijoBPRmVUiFucxFhxRmqXlY2Gn4n2Q0+TPhYxy2tQAkpACSiBGgIqFGtQtN6NsdnD8MWa71BaaUWsx4Ec8Rrs97rxdVkVDnO4EVNVpEKx9TavljzEBGavK8Yrs+ZiS1kuBmb1wEXjh6BTmohFFYkhJq3JKQEl0BYIqFBsA604sdN+uHjYZny6/BOsLNgGm8UYbYVnErZjHbrjwthOiGsD9dQqKIGmEthS5MYdU5/E1urPkWArx5frMlDsvgD3H3USnHo3bCpeja8ElEAbJKCznttAo3Jiy5kDzsLh3c5BtTcR3T0udPAGEGOJxzsyXnEDJ7V4ZX0yDUogygm8vXAq1nlegz0mF7BVIdO1Fsvyn8P6QvmsQQkoASWgBHYhoEJxFyStd0eyjLeqQjzWO/pgvasXLJYYZAa2IOH7a4EPLwGKNrTeymnJlUATCeSXV+CbdS/IhBW34aS+2mJBgc2JQGAzfN7NTUxdoysBJaAE2iYBFYptqF1HdhiITimJyPeUw+IphT9QAI/4VvzBX44Nyz4BPrgYKFzfhmqsVVECjSMg3m8wfdliVFQtlbXQpY9ZPsuoRMN3Yor4Iu0gk780KAEloASUwK4EVCjuyqTV7umW3A03jr8SY7q2gzOmFFtsFpkB7cFbMQH8PSsD87bOBL69FZB9GpRANBH4bNF6/HfmS6iwe+GxWFEhP6AqxaLo81fjuKzRSE7pEU04tK5KQAkogUYTUKHYaFSt48QDOh+IRw96GCMzxiPLXYUUv18eilaskpH692RkYNPqr4BVU1tHZbSUSiAEBDw+4KOlLyDO+h0yfRZZoC8AeYPV58FfLMk4Z
tTVMl5RZ7KEALUmoQSUQBskoEKxDTZqvCMe5c40VEjn2jqHDeViOaF3uMUuJx5KcMG7+H2xKsrTU4MSiAIC7/++GmsKpxsrGDllbfSungBSvF6Mtqfipgn/h5h2A6OAglZRCSgBJbB3BFQo7h23iI+1b+fR2GJPlEXJLNLVJpNc5OWUcVlTY2Pwsztfu58jvgW1gKEgsLXIg7fnv4wyx3YU2C3Y6vCj2GKHy+3HiI5jkdBvSiiy0TSUgBJQAm2WgArFNtq0U3ofgBP7nYhSrmEr3WwuEYkxAQvkP962ueA2bIxttPJaLSXwPwLP/DoN6/0fSz+zjEeUH0t+SwC5dh9S7TYc0vkA5aQElIASUAK7IaBCcTeAWuth+la8Yvhl6Bg7QHqZfYYs9MrYLKc/Hmt9TrjluAYl0JYJTF22FZ+sfRzemCq58q2wB2RsoljYYzzVOLn7ZHTqPbktV1/rpgSUgBIICQEViiHBGJmJZCbE47zB56DCE48ybxK8/o4osaRjH3scEmY8BPz4ALB9aWQWXkulBJpAYEuhFw/8/BKqY3h9O2T4hficF5HoF8mYJb4TRw04SSaw6I+lJiDWqEpACUQJAZ3q18Yb+sQBh2JDyRZ8vvprsaYAE+VROWzzt7hr08eIEUvjYXOfw8CjngZ67N/GSWj1oonAF0sXYIP3Y1hjbCIPDbeJxhAMrlB0evfj0LnDqGjCoXVVAkpACew1ARWKe42udUSMsTvxz30vwjlDjoVHHpJLpl2H5yqlK9qWIs64gVnVZbh99pPo222CjOPSy6F1tKqWcncEtpR9B7stD7aAy3CFw/M9AT/G25Nx3LCL5FqnHwANSkAJKAElsDsCqgx2R6iNHO+YlCl9b5V4pHorSkUQekUk0pecx+nCVF8h+vq8KhTbSFtHezU25LlRWrQayTKBiyuc22RsYrXYFePlh9JlA8+DM71XtCPS+isBJaAEGk1AhWKjUbX+E6tkSkuO+FfMt21GnIhEDlCtkDFbq1I7yzCumNZfQa1B1BOYs64Yt335A/z+HOl2tiPVG0CZGA89fjdOiu+JkcPOi3pGCkAJKAElsCcEqBU0RAsBvxPF3vbGzE+LuAnxyStgcWJlSQWqK4uihYLWs40S8PmB1+d9i3zcB4dztazobEWRLGOZ6fXjdCTj4n3+AcSltdHaa7WUgBJQAs1DQIVi83CNyFSdMskz3dEFlb5EFPvTUOVPRbI/Bu0L/4D7PbG0zH9TRv3rii0R2XhaqN0SqJbVVrZXfIBsSy7S5EdRJ48LqR4vxiQNxN8nv4DkXofsNg09QQkoASWgBHYmEBFCsaysDKWlpTuXrNYnv6xZ7JUHAV8+ma1bOxQUFEh3k5gUNNRLQHwO49i+R8BS3QlueYjGEqO1ChscAVxXtQIf/3oP/Gt+qDe+HlACkUzAZa1GgjVHfuzIBBaxJtJ7qMVnhSWtE5A9JJKLrmVTAkpACUQsgbCPUXz66afxxhtvICADzo899lhcc801u8CaN28eLr30UqSl7eg26tu3L/7zn//AIist5Obm4sorr0ReXh4cDgcefPBBDBgwYJc0dMcOAn/pPwS5Jbfgi+Vfwuv5CBUBG5J8dmyVjronrWXIXD8V43seqLiUQKshIIsP4fMFufh+8XpYi+PhdubALr8ZvTK0wmvxYHRCdqupixZUCSgBJRBpBMIqFCkAKew++eQTxMTEYPLkyRg1ahQmTpy4EyeKQYrAV155xbAaWq1iJRCRyHDrrbciKSkJFJwvvviiIRq//vpr2O1hrdpO5Y+kD7JyGS4cPwJHD+iDGz/7QSy5WxEjpsZYedi6ZX2/ad5SjI+kAmtZlMBuCLw+cy1un/48/LGL0cVZgFT50WmzVMmSlR4cl9wHE/uJc20NSkAJKAElsFcEwtr1/OmnnxqisF+/fujWrRsOP/xwfPzxx7tUhKKQ3c2bN2+Gx+NBRkaGcU5FRQV++uknXHbZZUhOTsa5556LLVu2Y
O3atbukYe5wuVyGyDSFprk/nO8sS0uXJzEhFoW2jtI955alzdwyE7oauS4Lpub8hqd+uxtl7pJwImly3uFg2uRCR3ACkcqzWrw6vbPiJbgy30S75CUoi8nDOqsfoxxD8cjoG3D6lGdgT+sekWQjlWlEwmpEoZRnIyDtwSnKcw9gNfLUSGTKMu0uNKvZjeKOXcq1Awtms9mwYcMGdO3ateYwxeKMGTNqPpsbiYmJoBXx3nvvxYoVKwxB+X//938oKSlBdXU1srKyjFPj4+PB19atW9G7d28zuiEmn3/+eUOMMY75YjnqKl9NxBbaIKfy8vKWFYsBKwbEH445ZcVwOcrht/nQTvZluWVlv0VfINGfhKMGnQW/OClubYHtyjGvtELTUh0JbdzaGAaXlzyrqqqMH2lkGSk8HeI0O1e+N7G2FegV6CwL9cntTO55ZfLjZ1u7PnD0PUMmbYknxeLi4OpEzDavUY6r5r1QQ9MIROo12rRahTd2WJ5L4a1ys+Yeic8llontzLkf3K4vNKtQvP/++/HNN9/sVAA+ZCgIX3jhBWN/cOG4XddDaOTIkfjhhx9qxOWkSZNwxBFHoH///sb5Zhrme+00evXqhTPPPNNgQIsju7BpWaT4rH1ufaCac785UYeipqWCTbqbzxl9BHKnJmFtxdNw+jYj1SdM6F1RuuxmbvwGx444Fy5bLAL+XScPtVQ59yYfXge0PKtQ3Bt6u8Yxv1cUNJEivK1y/a7Lrcb9U3/EVs8GuZe44ZKhEzbxC1rs98CJLdL+dLYtf0RCRmLgj9zY2FjjPhSJ5WtNZYrEa7Q18aurrOF4LtVVjrayLxKfSywTheLutFCzCsXzzz8fJ5206/ggp9NptH12djY2bdpUcx1s3LgRHTp0qPlsbvCBb4YuXboYk1X++OMPjBkzxhB8nMjCeOyK5qtdu3bm6cY7j5npMo+3337biGfeXHY6OQwf+IXkQ4PitSVDrw5O3Hfs4Xjoiy8xM+8PedgmwiMFKBeLjLV8E3LXfo2uvY8AXC0nYENVf15j5itUaUZzOvxBxe+L+d0NN4sKsXzfNe19LCp6BDGxhTKMwiYiUYShvPzeMhwclyW+FP+8b4S7vHXlT5bmD9a6juu+PSMQadfonpU+8s4O13Mp8kiErkTmMylS7qOsGb83FIoNhYaPNhSzEccyMzPRo0ePXV6dOom7Cgm0CtJSyLGHnLDy1Vdf4S9/+YtxbPbs2eCLgeKusLDQUL6cALNw4UIMGTLEmLDCyS+cxOJ2u/Huu+8aM6N79uxpxKvrD88jGL4iJYSzPMmyREuvlANleTO7OCeuwnabB1udfuRKl92NP92Ed7+/HgFvVaSganQ5wsm00YVsRSdGGs9nflwmP25ehiWhQGY224zJWLyZZXncuMveGZMG7ehBiGTEkcY0klk1pmzKszGUGn+O8mw8q8aeGYlMWabdhWa1KO4u87Fjx+Kcc87BySefbCjaY445BgcffLARzZzUMnr0aHzxxRd48803kZCQAPpLvP766w1rIk/kWMWLL74YRx55JCorK/HQQw8ZXY67y1uP/0lgcJ8DZVziZQi4v8DapLVo73bAabEiz27FMxu+xKCNR6J/9x3t8mcs3VIC4SGwMqcSL/3xJCzJa1HJbmVRiLyRxYlIPCy5PyYf+iiQ+ufY5/CUUnNVAkpACbQNAmEVikR4ww034LzzzjMGdbdv376GKt3emOGiiy7C8ccfDzrm5oxnCkYzdO7cGZ999pnRhc1JLS05zs8sQ2t/H9olCWcfdQ6+/D0b2wv/hWrxP5fj8EDeUC4P4RcXv4h/dxwnXXnxrb2qWv42QGB+ziKUx3yPGEMe7qiQTH4W0ejD0K77qUhsA22sVVACSiByCDRr13Njq0mBFywSGY/jEs2xiRwbRYHISTDBItFMn/3rHLuoItEksufvE3om4LwDxsFrS0O+iERZrMWYAhBrcWD2ttn4XRxxa1ACkUCg0rtceiAqp
Sh/ztLzyOz8ga4MTOhzbCQUUcugBJSAEmgzBCJCKLYZmq28InZbBizew4zxmx55BrtFoLNfr0jckDyw+Dl8vPpj+HQt6Fbeyq23+KVVfjz8/XJ8PneeuMTh/OYdge8BXxVO7XMCXGn1j09uvTXXkisBJaAEwkdAhWL42Edczh1THBjV7nAkl2YhSVziOP1WmSggM0llskBRZQmenPcEZm6dFXHl1gJFB4GXZHLbs8tuF8c3v8rklR0zmr0iF72eSpxmb4fDB5weHSC0lkpACSiBFiSgQrEFYUd6Vja5Gi46YCgy7OfCXt4RbqtMm4df3KIEUF69HVtL1uGj5e+KxbH1OeGOdPZavoYJVHv9+GHzO+hqn4c+4v6mvzuAdl4f9nFb8ETCYPzrwIdgS8puOBE9qgSUgBJQAntMQIXiHiNr2xGyxap40UHHINF2I9Kr2olQFMui9O2JxxwZs2jFnPXfY0XOvLYNQWsXcQRKZWUYl28pksXK7bG44LfEIdHjR9+O++PAk96DrcvYiCuzFkgJKAEl0BYIqFBsC60Y4joc1C8Nj580CsfJrFKviERZ3AeVshJGnCzx56gqxdrtC0OcoyanBOom4JMfKG/P2Yi/vfcZqsuKUCJjE/1yTVaJ769SGUPb1xUnYyPC7ryh7sLrXiWgBJRAGyCgd9g20IjNUYWU1CT0QjYGVS3Felc8UmQVvwS/FyV+Csa05shS01QCuxD4afVm3DfzPtgcC5FsK0eJVVYxkvGzMTIcYv/qAA5NH7JLHN2hBJSAElACoSOgQjF0LNtUSjZxtl3Q7QwMnvMHHFnbUCouiGhZTK8aLLOhh7WpumplIpNAaaUfL879SNYbn4YeYkaMkR8r+TY/0sWaeJWsS96/6xQ4+h8VmYXXUikBJaAE2ggBFYptpCFDXQ2OSRi/7374ZNXtqNz0I1Lic1GNDqiM2xc9O+68lnao89b0lAAJvDJrGeblf4XYWC/WyCznJK8NST4PAvHpGLD/vbB3GiXdzpG9nrO2pBJQAkqgtRNQodjaW7AZy98+0Y4bT94Xj3zfUcYlliErKQaXjO+E7ukxzZirJq0EZAyiJ4AZW95BvGOlWLJtMlrWh3y7jEsUOEc4E2BVkaiXiRJQAkqgRQioUGwRzK03k75Z8Xjk5IHIL6tGYowDsU5b662MlrzVELBZ3SIO56CruMUplEkrRTKJhT4TOwTcOD1loKzMotdhq2lMLagSUAKtmoAKxVbdfC1TeJvMeKY1sXbYWLoJuZV56JrUBekxOsGlNh/9vHcE8kp9WJ+TizRPNdaJK5wuHh+ypMt5s0xkOSYmG91HXiir93FwhAYloASUgBJobgIqFJubcBtM3y+TCd5e/iY+XPIaqquKkZrcGReNuBrjs8e1wdpqlVqSwA8rN+L+X95ETuVKdLZUwCPdzdtELAYCXvSRFVgOH3UukNmnJYukeSkBJaAEopqACsWobv69q/zygj/w1uxHEV9aiHhxw11SMhcv+v6DQUe8gCRn4t4lqrGinkBRhRv3z3oMudVT0c5uEz+JkJnONhwszhQ72ZzYt9thaNf/uKjnpACUgBJQAi1JQIViS9JuI3mt3jZfnB/nIdmaJKtkWBDrtyMndwVySzYiKWNAG6mlVqOlCawv2Yyiil/Q318Bu59jEANYIy6ZuvU5E0f3PwRoN1BnObd0o2h+SkAJRD0BHegT9ZfAngNIKReHdlwdQ8Yu8gKqtFmQJOPIkso4J1WDEtg7AknWCqRZSuCGU3x1uox3e0CW7mvfDcgepiJx77BqLCWgBJRAkwioUGwSvuiM3CFpGCZUJsKDcnGA7EWhzQ2P1Yvnfv4nFvz+jIhIWXdNgxLYQwJdrA6Ml6EMW8WpdpHFhy3y3k0mrYyx6AznPUSppysBJaAEQkZAhWLIUEZPQh16Dsbgzn/FiduzMayyGvaAD9tlJZf3qmVN3jn3YvbCV6MHhtY0ZARsMinqspRBOKfajT4yRnGKz
4ubbe2Q2n54yPLQhJSAElACSmDPCKhQ3DNeerYQiHNYcOAxZyF7yqPYEpOOCll3t0K6oW1WO7Y4nbhn7kMo3jxHWSmBPSPgiEXS/jfh7I6T8IAvCdckDUSXA24B0rrvWTp6thJQAkpACYSMgArFkKGMroSSnMCEwT3hiU1ChQxY5DrQbpnY4pBuwiX+crw6635A3JloUAJ7RICub45+DDj1DeDEF4FeB+1RdD1ZCSgBJaAEQktAhWJoeUZVajabC11TD0a13wevuDLhBBe+uaxOfCsudLYXr4sqHlrZEBHgqisJsp64WBg1KAEloASUQHgJqFAML/9Wnbv0NuPMkech0zYKPllaTQyKsIlYTJC5LJU+CwqKtwPe6lZdRy28ElACSkAJKIFoJqBCMZpbPwR179c+BZcO/zts1cmI8buR5vMhQVZuGegNoNO3twHvngusnhaCnDQJJaAElIASUAJKoKUJqMPtlibeBvM7ZtAozN38d0zd8CYqLQXoFyhGjLjI+ZelGB3ycnDy1NvRNb0nkNKlDdZeq7QnBNbmVeLNubOxuXQ7RmT3wUnDBiExVn+v7glDPVcJKAEl0JIEVCi2JO02mpfLAdx6yMnYb/W+KNz4O5auvg2zXTI7Wrqgl8Y4sNq9Gfeu/xkpKae1UQJarcYQ2FZShcs+vwcrKr6EXXxvfrMtA2uKL8Edh54gM+Ybk4KeowSUgBJQAi1NQG/PLU28jeYXJ8JwyoCO2G94byy0e5HpsyLT70F7XxVWOS2YNvNxYPl3bbT2Wq3GEPhw6VSsqnoHWTElyHC4kRC7Ed9sfhHrCvMaE13PUQJKQAkogTAQUKEYBuhtOcsySyxKrClw+quw3ebBMqcdeQ4nnvHl4oufHgAqitpy9bVuDRBYVjQDcWJJdAZs4qTdgjhZANKHjajwbGsglh5SAkpACSiBcBJQoRhO+m0w795pnTGw40FYbBHn27K6hkemQsf7rbDZ4/GadxO2F25og7XWKjWGQEcUizCkONwRZJ48Mi1+dBKH7RqUgBJQAkogMgmoUIzMdmm1pXLYbLhjv79iTPcp8Hu9SPfZ0NnrQIYvgALxj5cn4lFDdBKY4spAR68HxeJXqUhe1bL042mBRKQmdIhOIFprJaAElEArIKBCsRU0UmsrYkZ8Mi4beS6SbOnIcMtYNPGlWOSrRGJaP2SndG5t1dHyhojAwD5H4x5/Gk4sKsRhJUW4u6AEZ/Y/HYjPCFEOmowSUAJKQAmEmkDYZz1Pnz4djz/+OHzif+/CCy/EEUccsUsdP/74Y3z44Yfi0NlivPx+v3HuhAkT8Pzzz+Onn34y9gfEf99ZZ52FAw88cJc0dEfLEuif3htn7XMZ3ln0BrZWFyMtqROu2OcqpLiSWrYgmlvkEMgejuF/eQLDF74NVJUAvQ8DBh8fOeXTkigBJaAElMAuBMIqFNeuXYuLL74Y9913H2JiYnDVVVeha9euGDBgwE4FHTp0KJKSkmC1WrF9+3ZcdNFFuOaaa4xzvv76a/Ts2ROTJ082xGafPrJWrIawE7BYrDh30Bk4sMt+yK3MQ7ekLsiIDbIccR3oahELtCbJ+tAaooRAl7EAXxqUgBJQAkqgVRAIq1CklXDYsGE4+uijDVj77bcf3n77bdx+++07wevWrRv4Ynj11VcxYsQIDBkyxPhst9tBcTh69GhDbBo7G/jjdDprLJMNnNaih2gpbauhqwhEvnYKq75DYOYz8JbmwtGuLzBRRH+GvIcwmNbnECYZ1Ukpz+Zp/rb83W8eYvWnqtdo/Wz25ohem3tDreE4kXiNNqadwyoUly9fjn79+tWQ7d+/P+bMmVPzufYGu5ZffPFFnHbaaYbY4/GEhAQ888wzeP3115GcnIwHH3ywRlSa8XNycrBixQojzqZNm+CWcXN80ULJNMMd2O3u8XiMMoW7LM2av1gZLcUbUP3t/diUXyGruMQjJXcu2nv+A9uRDwJ2F6RBmlwEXvhsX4ZIa
eMmVyqMCZg8yZTXaX3fGatwr/IEMGttITYVlWNQh1QM7pQoFmN/KJo1jASaJ2uyrK6uhk0mgGloGoHGXqNNyyW6YkfNc6mFmtW8RpldpDyXWCa2M4fzNRSaVSg+9dRT+Pnnn3fJv1OnTrjnnnvglVmxDocs6/G/QGuf+YA39wW/L126FBs2bKixQPLYbbfdhri4OKOitEReeuml+Oyzz3a6+S5btswQk4RSUlKCyspK48XP9T30gvNt7m02UlVVFSoqKpo7q7Cmb7E54Fk9F+tKvPg1NQsldj86VHfChE0bkbV9A6zp3RGQiS9NDWxX8uT1RbaR0MZNrVM44wfzpAW/Lp6CHNUe4PGff8NvOT/AL8s3Opd0wPF9jsCJw/vAJ0s6atiZAK9R8uRDQ0PTCDTmGm1aDtEXO1qeSy3VssHXaKQ8l1gmUyhyu77QrEJx7NixoCisXYDERLEySEhPT0de3p+rMuTm5hr76ivsm2++iXHjxiEzM7PmlOzs7Jrtv/71rzj00ENRWloqy8Wl1Ozff//9wRfDxo0bDTFJ62PtctVEaOENXjQsS3CZW7gILZbdMmscXnduwpKYPDhEOwRivfAVV+Lsec/Buf8VQGqPkJSFPF0ul/EKSYJRngjHENP6xbHC9YWf1mzAp9ueRTvbSiSLYXi7tMEb6woxecRdaJ8cU1+0qN1Pwc37kArF0FwCjblGQ5NTdKQSTc+llmrRSHwu8T5UnwHA5NKsQpHjD/mqL0yaNAk333yz0Z3Fgn7//fe44goRCxK2bt1qvHfosMPHGn99c/Yzu5bNQCVMC2RsbKyxi7Of2RVNC2N9gecTDF+RIhRZlmgJyxJiMScuFr09XrjEj57Y/fBVYhyGLPgS/bavRtLpz8MSIncp0cS1ua+fxrBcX7oAKYE/kO1zIAAb4uHGevyK7ZWb0T6pZ3MXsVWm3xiurbJiYSi0sgwtdOUZWp5mapHGtTHlaVahaIKp7/2QQw7Bu+++iyOPPNLogqb18ZhjjjFOf+CBB4x3UxjOmDHDsA5NnDixJjl21Z5//vmGMKRoXLBggdGlzS5sDZFJICbWjWpXJrZWBdA5IONFEYtSmxVb4rPQYeNK2Nf8hvjBf4nMwmupGiTQ2V4Eh6y7Ui0SUfxpwx1wIdlShlSUNxhPDyoBJaAElEDkEgirUOT4RPpB/O2334x+cnYr07LIcP311+9EjZZJWhTZvWAGdmHff//9WLVqlWEd5EzojIwM87C+RyCBwRkD0SMjCwVFRSgoc2KzDFH1Wf14N6kUXpkIMTZnC1z9vTXXQQRWQYsURCCn2IPpy3NQ4QOGeuPQ1+7AUhGLsbKWc7nViyNs8ejorL+7Oigp3VQCSkAJKIEIJBBWoUgeHJ8zfvz4XdBkZWXttI/jGesK9LvIl4bWQaBjQidcM/JKvLrkNfy2aSvivFXo6nFIB3QF3kkLoP2cZxBY8wPSD7gE8T3HyQUS9ku0dYANQylXb6/EDZ9/gBVl0xGw+dDL0gN/D3TDQusabJDlGgfL7JYjux0Na3LnMJROs1QCSkAJKIFQENCncCgoahp7RGBMh7FIi8nA6sJ1cMhEFpuvSCY/VGO9w4U8WyH2Xf8h/K9/IY6ZxwCT7wfaD96j9PXk5iewZns1rv3kfazyPIZOMcVIlolJmy0L8EHcYXggWyaOFa8FOo4QM+PJIvbV/Uvzt4jmoASUgBJoHgIqFJuHq6a6GwKpMSlIj0+BP74drMX5qCpaJf72AtJlWQ6/dF/a/OLaZuV3sHx2DSxnfwo4dkxY2k2yergFCPhl7tULs3/BcvcLSIgrQlHABo/Phg6eMmyUGc9lE99CgkPFYQs0hWahBJSAEmh2AurAq9kRawZ1EeByfkf2OBIVnnyUyoSHjY4Ayi0+vJychFsy07Ha6ZJe53gEtsxHoGBNXUnovjARqPL4sbL4PaQ6N4qLIxtiZNZ+mc2P7eI4OitQKLPZK8NUMs1WCSgBJaAEQk1Ah
WKoiWp6jSZwUr+TcOu4W7Ffr8kIOOLR1e2BU0THD3ExuEnE4qIYJ8q8Mcjz/DmBqdGJ64nNRsCCSsRbVqKdV1bakVyqISscyb8qiweHJXSBw67t1WzwNWEloASUQAsTUKHYwsA1uz8JWERmDM8ajmFZw5CSlA1/cncsddhRJBOcFohIvLh9Mh5pfwD8iV3+jKRbYScQI0sx9rE5UW2xoYvMVI+XVVd8MlTgHCRh8pDzdQJS2FtIC6AElIASCB0BFYqhY6kp7SWBBGcCAuJ4L1fcX5aJSKSVyiGWxTyLA790cCMlQQbFaYgYAhYZL3paxwPQT8RhpfjATPb5cYrbinPH3wRrtwkRU04tiBJQAkpACTSdgE5maTpDTaGJBIZmDsXYDmPwzvK3xR+mJCZ/LNKdGSNj3nIr1mBL2SZ0T+7RxFw0eigJdBl5Ae7yuLFqw49w2VzoNeAEWAccG8osNC0loASUgBKIAAIqFCOgEaK9CBQafx35V2NZxRcWv4A4R5ys7GGF2+dGnD0FDquutBNx10hMMmIn3YjBlZfu6Gp2JURcEbVASkAJKAEl0HQC2vXcdIaaQggIxNhicNHQizAoYxAqvZWo9FSKVdGCKT2mIDshOwQ5aBLNQiA2BVCR2CxoNVEloASUQCQQUItiJLSClsEgkB6TjkcPehTvr3hfupu3GBNdjux5pGFdVERKQAkoASWgBJRAyxNQodjyzDXHBgh0kiX+rh5xdQNn/O+QzwMs+RBY9a3MfIkDBp8AdJu4+3h6xm4JyDyiHWNFd3umnqAElIASUAJtnYAKxbbewm21fgvfQvl39yPPnwCbrBSdtX4WnMc+JsvGjWyrNW72em0r8uD1uXOxLG81uqZk49Tho9EzM77Z89UMlIASUAJKIHIJqFCM3LbRktVHQCa5VC36FCvLY1HlSIRfnD2X521Dr6VfwKZCsT5qDe5fkVOCq794GCuqPofTXomf8+MxP/cveOrovyMjwdVgXD2oBJSAElACbZeATmZpu23bZmvm8QWwpVhmRIvvRZfdIjOjrbBLd+mmEu//6iwfrPIbiC8NuyVQVhXAIzNfxlrvG2gXU4w0uweJMYVYV/ouft0wd7fx9QQloASUgBJouwT0Sdp227bN1sxtdeHX2AlwV72IWYkexMrKIKMDdlkhpA+6/vYEsPZHIK4bMPosIHtAm+UQiop9sWwRnpn3BtaWf49YhxU2v0W68iFWWnlZSuGuXiFb40KRlaahBJSAElACrZCACsVW2GjRXuR4uWr/6J6Gz7xJaG+zQKa1YFpGIu7Mf18muPwAuMvEZUtHYPXXwFH3A70OjnZku9Tf4wNemzcd/51/D6yBjYi1BVAhTs4rxOF5vMxmIdN4cXw+0Bm7S1zdoQT+v71zgY6qOvf4P8lMMpkk5MWbJLwfF+QNggpaHlJAQVu0UuutV6pSwaKiVVnXKnK5Syu6bsVVWK20aNWrLhQpoBdQkZeIFQQrFMI7SJBACHlNksm87vedOHHAJJDkJDmT/HfWyZw5j332+e19zv7Pt/f+NgmQAAm0HAJsem45ed1s7tQjI56LIr9CckI6iqO6w2PviaSEOBx37wa85dLkHI2AzYFAUQ6w4k7gM7EyitWR4XsCr+/ZgOd3Po7IyGPSfB+AVxyc66/GchGKedKkD58bd9nboXfHkd+fxDUSIAESIIEWR0DrBgYSCC8ComN0Ruj0ZAcS7EkidiJwqqQUUV6xg0n3RGk9NUSPPdIOlJxFYMN/ytESrpoVXvfZAKndm3MMHx/bifcPvoYOUXkohh3RwiwQEYBNtHREwIcfuUowxdEJV46aDyRnNEAqGCUJkAAJkEC4EKBQDJecYjorCagAvL7zBCz9aim8AS98spSWR8Hh7YpXkg9jS6xTRFASbi7wYLhHetz5vCIWn0DEyS9ELM4G0oZVxtWSVl6XkeLPf/EsynxnYLPZZBBQNKKkmdktMtonCtsj67/wOjB3xFyg701AQ
ruWhIf3SgIkQAIkUAUBCsUqoHCT9QlM6jpJhI4TW7O3Ik4cbl/ZdhwWuZbhsPcsvFFRaBsViy0xrTEvUITJxSUiFj3w734dEUc3IWLqYuDfplj/Jk1K4eZvdmDF/lX49NRGsRwWChs7ykQcan/EJBGHNllKxJI4o1iam/vdBoy416QrMxoSIAESIIFwJ0ChGO452ELTHyl96sZkjDEWRZBTchrFjkwEXJFGU2qMiJ8iGYzxdOtk7HQ4cEdhMbqJOAoUn4H/7/cj8sQOYPTDMjo6pVkT3HFqD574aB6iy04j4JD7j7CJBRFwCB+PCMVysST2c5djMuJw0zCxtg7/VbPmwZsjARIgARKoHQEKxdrx4tEWJeD1++APeIyp5wJGj8QI6acYQH5kJN5ITMBnTgcePpePca5SFJfIYJjP30WH6AREjnncondkTrL+L3MtEktOGhbWUrEaFoo4lJ6cFQNXhNlUxOOxHj+Gs99PZVabIeZclLGQAAmQAAk0GwKRzeZOeCMtmkCHuA4Y2fEquGW0rhgSK0IgQga9QCyMARy32zC3XWs81SYZR+0pyPOnIj9zm4ySLmsW3E67TuOtzLfx132vYH/egcp7csp2f2SUDPCJRrr4I1dLYqkI6BJ/Oa4qK8M9Qx+Ac8JCisRKYlwhARIgARIIJUCLYigNroctAW2KfmjIQziYdxD/PPdPEUY6hFdG84pYVN2oM7fIFryTEI/djmjcVOjCQE9nOBEDh1/6LxbmIDI2SfwvxocVAxWFm7/ZjA3H16O0QPwhihB8P2EFHh75BK4W4TzR2R7b/V6cliHNMTLCOdVThhluP0Ym9kDfYb+AY8DtYXW/TCwJkAAJkEDjEqBQbFzevFoDEmgX1w5/HP9H/OmrP2HHsR0o9hQbYlE65hlXlR56MutIhIgmN96M8+Jo+yvQP+cAcj/8b7hzDsIW3xoJ186Cs/+NDZjK+kddJlbQAncB9p7biz/t/B9kF55AsVgIE6RBOVUcIZaVHMZ7+/8XV3YcgYF9b8ZTB9dhjStHBq9Iv04RixNHPQoMnC7qObr+iWEMJEACJEACzZoAhWKzzt6Wd3PaBD3/6vnY3HYz3jz4JjbnbJam6CioxVGMbSIVA/D74uGKSkVOdBbOrJkHz6ljKJUR0razuShf+1+wRcmAmJgYoHUvIDHdEhADkvgtJ7eI5XAD9uXtQ4mnRBYXkorzkSA62C99MT3iC/GszY5WngAKzh6Ugd6lsIkYHjL1JQz5eqV0TnQBPScAvX4sArpCPFvi5pgIEiABEiAByxKgULRs1jBh9SEwosMIDOkwBM/teQ5/P/x38R1YBlukDfG2BHRs1QM26bNoLxYrW24WdiYl4lBMGdp4YzG14ATc78zB+YhYmfklFvYeo4GMEYjoPVGapVvVJ0m1PtcjTeJZBVnIdmVjT85uvLn/DeS582ET4RslwldDuboC8or8la+ig1EaKYN65Kke6oqA3SdiUHyOo+PgikVPYCABEiABEiCBWhCgUKwFLB4aPgTKfeVwxDjw5FVPYnzn8dh5eic+yvrIsCw6Y/wo8RWKa51J+CD3Day3nxfhZYcvxofTMsvLSFcrlIjgGlpwFG12HETxF6uR1O8TOKe9CEQ1bHOtNiurb8iNJz5B5vkDOFl0Em7Z5pcpCCPFquiMipHPitlnNDfKxIqoLetJvgBybRHoIFMYDissx+B2V8Iv8zRztFr4lFmmlARIgASsSIBC0Yq5wjSZQkDFlSPSgevSrjMWddK97tg6lHhLjO89knrivuOrkJKXC6eILRkCjS3iRuczpw+x0ufvrylt0dHjkwEg4mPw8GYM2LgU58vLESlisc2gKXC07yaniFfC7C+N0dOBDgMQEZMopj2x6kXZRJTKrDAS1NKn0Rti7ztLoG5XMbvm6BrsytmFLq26oF9qP7yybzk+zd4OtSZGfWc5jPyumdgvn6HCTzQiYmWITqx8pnp9+PdzLnQubYucxInoc
M1MsaAaN6WXYiABEiABEiCBOhFoUqGo/a4KCgpw/PhxJCQkoHv37tXeRGlpKTZu3IgoaWobO3YsoqO/t+wcO3YMu3btQs+ePTFw4MBq4+COlk2gb2pf6BIMWYVZMtLZiYjWfeF2i3/FQBmKfDlI8/pRLqLsrFjocmUWk2SZDPmsFLfhe17C3mjpvyj+B685vBY3TpiPzAOvYK1YAL+JikB0QkekRIs10u9GSXQ8WidmwCWiMV8GniRIs7VL+gi2kv0D2gyUEclXyxSES7D6yGoRkfIn14u1xYrG9IjAjIBDLIehwRCassErK5o2n6wniEicnO/G2PxEuCMSca7NeOT96KcY3LMbuqVeeH5oXFwnARIgARIggcsl0KRC8ciRI/jlL3+JU6dOYdCgQVi1alWV6Xa5XLjtttsQFxeHcrHoLF++HH/7298QIwMOtm3bhtmzZ2PUqFF45plnMGfOHNx5551VxsONJBBKQAe+9E7pJX4HM9EmOR15xSeB85Gw+UUg2v2Ik+bccrHKtRVr3TfypBxJdaKNzy4OqwPI8p/C4Y8fwNroIpyNtcloarEclp2Cv/SkIfyi5JhA7m75X2HVk29Gs7eOqNkormxei22NUyU5hiCscOAjswyKGxsddFOVHVBd+zjFQppe7jFc4PQVYTu0yA1vzA3IGzcX8YlJGNMjHW2doTbH0LvlOgmQAAmQAAnUnkCTCsWOHTtixYoVWLt2LVavXl1t6t99912oWNTj1Ao5cuRIrF+/HlOnTsXChQsxa9YszJw5E9u3b8e9996LadOmIT6+an94apFkIAElEC1NyLMHz8are1/FkYIj6JKQAZ9YE13lUSj3nJLm33LYRTS6xQNhkXQMbOuLQoJ8V/FXKANjVtqL4Y60I06afmWz9BWUsiVLUOipuNMQ/F6xLoJRyvBZcVcTKSOVgyKxYp/GXRH0Uxcdj6LxdPR5cce5MgwvtInfRx/OIwNHOk1Fn4m/wtBubStO4n8SIAESIAESMJlAkwpFp9MJXUKbkau6vy1btmDMmDHGLm2iu+aaa/Dpp59i4sSJyMrKwrhx44x9gwcPNirfQ4cOQdeDQZutCwsLja85OTnw+/3GEtzf1J/B9OgngzkELpdpWlwa5l05D4XlhYiR5t73j7yPlYdXorQoDsXlQHxUa+THxiDK+y3iPW7xVKiPjIg9v9gIpVlahpbI9whDzAVzL2jTC34PvSMVjXY5P0ZWdC7qKFGDui0oCtVhdpQsZbIvQSyIPcrLMK7YhV5F0dgdcyve7TMFKXFR6NMtAxN7piM1VoRkI5Sby+UZeq9cr5lAkKm+0xjqTyDIszGeh/qn1voxkKf5eWRFppomNcDVFBpUKB48eBBnzpz5wfVVHGpTs1pUNFwqkfn5+RcIv5SUFEMgFhcXw+v1Gv0bNR4VnNocrceHhq1bt+KFF14w+oGpZVKvV1RUZBxyqWuHxtNQ65pRmi5WGOYQVo5aNux2u/Q9dF+yfOnxkfLnlr/r21+PAQkDcL7sPL4p/gbZxdnSnCsWOxGGGzJXAiUFYj2MQFpKV6T5WyGzaJe42rFXCsUK2VhhBQzIaBO/+K0JWg31UVRJoGIwWcqtXz5P2mIMa6Se11vSOqnAJVbKCCRKX8VEtw15vu44buuHTRlDMXz4cPwsPVGeG01vQMbRnEdeac0PuBlElY9y9Hik0V3SbIVnxoz7auo4tIwqy+B7sKnTE87XZxk1P/dYL5nLtLb1krlXrzo2TZPP5zN0lK5XFxpUKK5btw6ffPLJDwRQRkYG+vfvf9kvSBWAZTIvbTBopaWCULfrS1b7LWrQgq03fbGFUi2Oo0ePNtJx9OhRPPLII2jVqtVlXz943Yb61DTrfSQnJzfUJVpcvMpTy4HD4aj1vSeKX0UNI+QvGLRCT2uXga3i9DrW7sQN3SYjwpOCh9YtxOGiHYiWnovx3mgU2UXERXrR2hOJzmXSUzEmH2cdFeUzSfogipMbOS4C/
YsSMajUhRNxueLWxo50KdP93fE41Gkmikr9yLbH4lhGfySn98LI9Pbo3TYWTsOYqb0dGz/o86fPmT43DOYRUJ7sDmMOT5ZRczgGY2G9FCRh3md96iXzUnFhTKqb1KhSkwGgQYWiDizR5VJBE6kAqwt9+vTB3r17K3fv27fPaG7WwS06WloHxaSnpyM3N9ewzHXp0qXyWF3RF3FsrLTRSVDhoMo5uBgbm/hfMC01KfomTmLYXd5sphrf6LRRxhIKY/nNz+OTw4dxvkgs1WUR2JiVh28LipEgfRjLxFpYnnsczvKTSIvKRRqK0TqyABHOVCT0uwWFSe3R4cvluMqdKSOvM1A+9D9w7YBrUOaRciqOsuOr6k6rZTc0AY20bjbPRkq2pS9DpuZmD3mSp7kEzI/NimVU03Sp0KBC8VIXVwuFWhx1EIq6yNEBLdokrRbHJUuWGKfrQJXp06djypQpxsAX7W944MABLF682BB7um/+/PlYsGAB/vKXvxijnzt16lTtpWtSzdWexB0kUA2BNq1s+NmQPpV77xa/NUVun5RN6Qoh7nWOygCU0wXl8mNFGrelydgR4UdKfAy6tnEYvR3zRwyWKQUDSHDITCrfxVKlQKy8AldIgARIgARIoPEINLlQ3LRpk2H2VN+IH374Idq2bWsIxdTU1EoKvXr1wrJly/Dyyy8blsdXX30VaWlpxv4HH3zQOH/p0qWGH8ZHH3208jyukEBjE7CJFTDZ+b0p8Ir2TuhSXUhSdSjN1QwkQAIkQAIkYEUCTSoU1YWN+j6sKqjfxNCgfQx1uTjYbDY88MADxnLxPn4nARIgARIgARIgARKoO4HqOwbWPU6eSQIkQAIkQAIkQAIk0AwIUCg2g0zkLZAACZAACZAACZBAQxCgUGwIqoyTBEiABEiABEiABJoBAQrFZpCJvAUSIAESIAESIAESaAgCFIoNQZVxkgAJkAAJkAAJkEAzIECh2AwykbdAAiRAAiRAAiRAAg1BgEKxIagyThIgARIgARIgARJoBgQoFJtBJvIWSIAESIAESIAESKAhCDSpw+2GuKHLiVPnla5pbunLicPMY3Qu6piYGDOjbPFxKc/o6OgWz8EsAMrzcuYENet6LSEeZarPPoM5BFhGzeEYjIX1UpCEeZ9WrJcuRw+1OKGocz1nZ2djzZo1lhCLwbmnPR4PhY15zyOUpz4ArIjNgerz+eD3+43pMs2JkbHoXPd2e8UM3xTh9S8PLKP1ZxiMgfVSkIS5n1asl/S9rpoomOdV3XGLE4qJiYkYOXKkIRSt8nIuKCjAtm3bMHnyZFptqiqltdym+bp161a0b98ePXv2rPEBqGXULfJw5Xno0CGcPn3amEazphdKiwRUh5tWhh988AFGjRoFfScx1I8Ay2j9+FV1NuulqqjUfZtV6yV9F6kmquk9FCEHBep+6zzTDAInT57EfffdZ4hXM+JjHMDDDz9sVMI/+clPiMMEAu+9957xY+aFF14wITZGoQSmTJmCpUuXIi0tjUBMIMAyagLEkChYL4XAMGk1XOslDmYxqQDUJ5qysjKjqVRNwAzmENBmPeXKYA4BZalMGcwhoM+6NkOxjJrDU2NhGTWPZZCnllHWS+ZxDdd6iULRvDJQ55h00AWtCnXGV+WJ7dq1Q6tWrarcx421J6AslSmDeQT0meeAK/N4soyax1JjYr1kLk+NLVzrJTY9m18Wah2j/mJzu92IjY2t9bk8oWoCylMHsthsLa4bbtVA6rnV6/VCBwvoqD0GcwiUlpYaPK3kgcGcO2uaWFhGzeXOeslcnhpbuNZLFIrmlwXGSAIkQAIkQAIkQALNggCbnptFNvImSIAESIAESIAESMB8AhSK5jOtV4xqmj5+/DgOHDiAkpKSesXFkysIqJsH5ZmTk0MkJhHQZtOioiK6HqoHz8OHD7NM1oPfxacWFxfD5XJdvJnf60BAB11o+dS6iKH+BLTbzrfffov9+/fj7Nmz9Y+wkWNg0
3MjA7/U5ebOnWsUJu1IfObMGSxcuBDjxo271GncXw2Bt956C0uWLEFKSopRKU+dOhWPPfaYJZytV5NkS2/WUZAPPvggPvvsM8NZ9EcffYSEhARLp9lqiVNBM3PmTKPCKCwsxN13320sVktnuKTn1KlTmD17NjIzMzFkyBC8/vrr4ZJ0S6ZTf1Sru7b4+HjDWKEDMPQdmpSUZMn0hkOi3njjDSxfvtxgqv5ob7nlFsOFm/pWDIdAoWixXNJfG8nJycYgjD//+c9YtmwZduzYQWFTx3z617/+ZTycGRkZOHr0KMaOHQv1tzZ48OA6xtiyT1OhuHnzZsOa+Pjjj+Pzzz9nBVLLIvGHP/wB69evx+rVqw1xo74+VXB37ty5ljHxcCVw/vx57Ny5E3v27MG6devw8ccfE0w9CGjLiy5XXHGF4XJI/X1OmjQJjzzySD1ibdmn5ufnG6PInU6nYQi6/vrrsWXLFnTr1i0swHBIqMWyqU2bNkaK1FStQb+Hy68OI8EW+9e3b9/KFOlDqb+S1YrDUDcCOuXc+PHjcfDgQU6PWDeEhkj8+c9/blhktTLu2rWr4cycQrFuQPWHtVa8eXl5fFfWDeEFZ6kFMegKS4WNls9z585dcAy/1I5A0BqrI8l1dL62cIWTBwkKxdrld72O1gKivyyqCipgHA6Hsev3v/89NmzYgCNHjmDVqlV8+VUF7Ltt2k9O+3VeHNQ1jj6coSJbZ8FQzsOGDbv4cH7/joA6Ldam0YuDclSewbmztSwz1J6AToSllW6HDh0qT9ZKWedaZagfgeCP6/rFwrNDCXz99dfYuHGj0QoTup3rtSegrQiLFy82LN9PPvkkOnXqVPtImugMCsVGBJ+VlWX0TdLmu4uDNuOpeV/DjBkzcPPNN+O1117DggUL8PbbbxvWh4vP4XfgpZdeMpqbQgWhctFfwS+//HIlt7Vr1+LFF1/EihUrEBcXR3TVENAm0EWLFv1gr/aZ1a4QypWh7gS0nOoSOtuFrgcFeN1j5pkkYC6BEydO4I477sDvfvc7DBw40NzIW2BsV199NXr16oUvvvjCeMdOmDAhbN6nFIqNWGC7dOmClStX/uCKamVQE38waHOzLjqwZdCgQYYFon379sHd/Awh8NBDDxkd2UOFovLUilebSTVs2rTJGMCiwrt///4hZ3P1YgL68ho9evQFltjgMWqNDQblq8zp0DxI5PI/1Zp47NixyhN0Tt1p06ZVfudK3QgEy2TdzuZZoQR0IOX06dNxzz334K677grdxfU6EtBBf7roj23tp6yCMVx+eFMo1jHT63KavshqmlZO3eF8+eWX6N27txG9WnDUPK19cBiqJqCz2dQ0o40OBLr99tvx9NNPQ8W2NudrRR0qzKuOuWVuVcuhLjUFdfOgQkeb/Q8dOmRMPxnsW1vTedxXQeDWW281mqBuvPFGfPXVV8jNzTXEOfnUjYBaZFVs66KusHTQWuvWrWt819btSi3jLC2POoBF+8+qlwh9Z2q3k9TU1JYBoAHuUushrXe0/v/HP/5huMpRvuESOOrZQjmlPsDmzJkDdfegVjF9OOfPn48+ffpYKJXhlRR166AuclTIaJO/cn3qqafYT7Ee2fjEE09g9+7dRl9G/YWsVkgttwyXR0D7d6rbq61btxpW2Xnz5tEF1uWhq/Io7VN7//33G5Wv+v/TMqnuXYJdeao8iRurJbBr1y6Dp/6w1n6f+s686aab6MKpWmKX3vHcc88Zng10uk5dfv3rXxsi/NJnWuMICkVr5MMFqVDLoj6c7Et3ARZ+IYFmRUCfcx35qC0NDCRAAs2bgP6I0YGX+kMm3AKFYrjlGNNLAiRAAiRAAiRAAo1EgFP4NRJoXoYESIAESIAESIAEwo0AhWK45RjTSwIkQAIkQAIkQAKNRIBCsZFA8zIkQAIkQAIkQAIkEG4EKBTDLceYXhIgARIgARIgARJoJAIUio0EmpchARIgASWgvlJzcnIqYajbp
u3bt1c5FWXlQVwhARIggSYiQKHYROB5WRIggZZJQEWhTo0WnC9bfaw9//zzhn+1lkmEd00CJGBlAnSPY+XcYdpIgASaHQF1YnzDDTcYDnfHjRtnODNet24ddIpPBhIgARKwGgEKRavlCNNDAiTQ7AlkZmYaU0vq9HO/+c1vMGPGjGZ/z7xBEiCB8CRAoRie+cZUkwAJhDkBnet5586dyM7O5uwsYZ6XTD4JNGcC7KPYnHOX90YCJGBJAhs2bMCZM2fQq1cvLF++3JJpZKJIgARIQAnQoshyQAIkQAKNSKCwsBDXXXcdFi1ahLS0NKOPogrHzp07N2IqeCkSIAESuDwCFIqXx4lHkQAJkEC9CWifxFmzZhkjnpctW2bEt2DBAqMJ+p133kF0dHS9r8EISIAESMBMAmx6NpMm4yIBEiCBGgi4XC50794dzz77bOVRv/3tb3HttdciLy+vchtXSIAESMAqBGhRtEpOMB0kQAIkQAIkQAIkYDECtChaLEOYHBIgARIgARIgARKwCgEKRavkBNNBAiRAAiRAAiRAAhYjQKFosQxhckiABEiABEiABEjAKgQoFK2SE0wHCZAACZAACZAACViMAIWixTKEySEBEiABEiABEiABqxCgULRKTjAdJEACJEACJEACJGAxAhSKFssQJocESIAESIAESIAErEKAQtEqOcF0kAAJkAAJkAAJkIDFCPw/z1xi1rjjsjcAAAAASUVORK5CYII=", "text/plain": [ "" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os\n", "import glob\n", "from PIL import Image \n", "\n", "exp_dir = os.path.join(args.output_dir, args.experiment_name)\n", "png_files = glob.glob(os.path.join(exp_dir, \"*.png\"))\n", "latest_png = max(png_files, key=os.path.getmtime)\n", "img=Image.open(latest_png)\n", "resized_img = img.resize((img.width // 4, img.height // 4))\n", "resized_img" ] }, { "cell_type": "markdown", "metadata": { "id": "hTEnpDwZbjGA" }, "source": [ "Let's now simulate the situation when you want to load a trained model. This is usually done to either evaluate or resume training.\n", "\n", "In our case, we will load the previously trained model from disk and run our training loop on it again.\n", "\n", "Since JAX is built on functional programming and we never returned the trained states from our first `train_loop()` call, our global `opt_state` and `ema_state` variables contain the initial, randomly initialized networks. 
It is as if we were starting the program from scratch, except that we first load an existing model using Orbax, from disk or, as in the logs below, from a GCS bucket.\n", "\n", "As we did when saving the checkpoints, we filter out the rngs so that the types and shapes of the restore target match what was saved. Because our states are annotated with shardings, Orbax restores each array directly into its shards, so the full state is never materialized on a single device. Orbax also ignores the sharding recorded in the checkpoint and lays out the arrays according to our current sharding instead. This lets you train a model on one cluster configuration and then load it on a different one. Note that this behavior requires `ocp.args.StandardRestore()`, not `ocp.args.PyTreeRestore()`.\n", "\n", "After loading the states and rngs separately, we merge them. If you inspect the dropout rng stream, its `count` is `Array(5000, dtype=uint32)`. This matches the number of training steps taken so far, so training resumes from the same point in the rng sequence. The corresponding counter in the EMA model is `Array(0, dtype=uint32)` because the EMA model never draws dropout keys."
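 ] }, { "cell_type": "markdown", "metadata": { "id": "rngResumeSketch" }, "source": [ "The resume behavior described above can be sketched in plain JAX. This is a minimal illustration that rests on one assumption: each draw from an NNX rng stream folds the stream's `count` into its base `key` and then increments `count`. That is how we understand Flax NNX's `RngStream` to work. The `draw` helper is hypothetical and is not part of the NNX API. The sketch also mirrors the `jax.random.wrap_key_data` call visible in the restore log: keys are checkpointed as raw `uint32` data and re-wrapped after loading.\n", "\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# Hypothetical stand-in for one NNX rng stream: a base key plus a uint32\n",
"# counter. Each draw folds the counter into the key and then increments it.\n",
"def draw(key, count):\n",
"    return jax.random.fold_in(key, count), count + 1\n",
"\n",
"base_key = jax.random.key(0)\n",
"\n",
"# Uninterrupted run: 5000 training steps, then the draw for step 5001.\n",
"count = jnp.uint32(0)\n",
"for _ in range(5001):\n",
"    next_key, count = draw(base_key, count)\n",
"\n",
"# Checkpoint only the raw key data plus the counter (5000 draws done).\n",
"saved_key_data = jax.random.key_data(base_key)\n",
"saved_count = jnp.uint32(5000)\n",
"\n",
"# Restore: re-wrap the raw key data, then draw the next key.\n",
"restored_key = jax.random.wrap_key_data(saved_key_data)\n",
"resumed_key, _ = draw(restored_key, saved_count)\n",
"\n",
"same = bool(jnp.array_equal(jax.random.key_data(next_key),\n",
"                            jax.random.key_data(resumed_key)))\n",
"print(same)  # True\n",
"```\n",
"\n",
"Each key depends only on the pair `(key, count)`, so restoring those two values is enough to continue the exact dropout sequence after a restart. This is why the restored counter of 5000 matters."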
] }, { "cell_type": "code", "execution_count": 33, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4hKD6yCWeLLZ", "outputId": "fa341ca3-42d1-40b7-d201-448584baf3d1" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-10-05 20:33:16,340 - INFO - Restoring checkpoint from gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-10-05 20:33:16,575 - INFO - [process=0] /jax/checkpoint/read/bytes_per_sec: 619 Bytes/s (total bytes: 96 Bytes) (time elapsed: 154 milliseconds) (per-host)\n", "2025-10-05 20:33:16,817 - INFO - [process=0] /jax/checkpoint/read/bytes_per_sec: 24.0 MiB/s (total bytes: 4.1 MiB) (time elapsed: 169 milliseconds) (per-host)\n", "2025-10-05 20:33:17,026 - INFO - [process=0] /jax/checkpoint/read/bytes_per_sec: 666 Bytes/s (total bytes: 96 Bytes) (time elapsed: 144 milliseconds) (per-host)\n", "2025-10-05 20:33:17,335 - INFO - [process=0] /jax/checkpoint/read/bytes_per_sec: 51.8 MiB/s (total bytes: 12.2 MiB) (time elapsed: 235 milliseconds) (per-host)\n", "2025-10-05 20:33:17,335 - INFO - Finished restoring checkpoint in 1.12 seconds from gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_5000.\n", "2025-10-05 20:33:17,335 - INFO - {'step': 5000, 'event_type': 'restore', 'directory': 'gs://solaris-east5/georgy/fsdp-jax/checkpoints', 'checkpointer_start_time': 1759696396.2143493, 'checkpointer_duration_secs': 1.1216151714324951, 'checkpoint_manager_start_time': 1759696396.1660664, 'checkpoint_manager_duration_secs': 1.1698994636535645}\n", "/tmp/ipykernel_37871/985055546.py:24: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).\n", " opt_rngs = jax.tree_map(jax.random.wrap_key_data, opt_rngs_keys)\n", "2025-10-05 20:33:17,337 - INFO - Checkpoint restored successfully\n", "2025-10-05 20:33:17,337 - INFO - ── Shard ↦ device map: Opt state sharding after 
restore ──\n", "2025-10-05 20:33:17,338 - INFO - model/dropout/rngs/default/count () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,338 - INFO - model/dropout/rngs/default/count () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,338 - INFO - model/dropout/rngs/default/count () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,338 - INFO - model/dropout/rngs/default/count () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,339 - INFO - model/dropout/rngs/default/key () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,339 - INFO - model/dropout/rngs/default/key () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,339 - INFO - model/dropout/rngs/default/key () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,339 - INFO - model/dropout/rngs/default/key () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,340 - INFO - model/dropout/rngs/dropout/count () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,340 - INFO - model/dropout/rngs/dropout/count () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,340 - INFO - model/dropout/rngs/dropout/count () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,340 - INFO - model/dropout/rngs/dropout/count () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,341 - INFO - model/dropout/rngs/dropout/key () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,341 - INFO - model/dropout/rngs/dropout/key () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,341 - INFO - model/dropout/rngs/dropout/key () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,341 - INFO - model/dropout/rngs/dropout/key () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,342 - INFO - model/dropout/rngs/noise/count () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,342 - INFO - model/dropout/rngs/noise/count () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,342 - INFO - model/dropout/rngs/noise/count () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,342 - INFO - 
model/dropout/rngs/noise/count () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,343 - INFO - model/dropout/rngs/noise/key () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,343 - INFO - model/dropout/rngs/noise/key () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,343 - INFO - model/dropout/rngs/noise/key () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,343 - INFO - model/dropout/rngs/noise/key () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,344 - INFO - model/fc1/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,344 - INFO - model/fc1/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,344 - INFO - model/fc1/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,344 - INFO - model/fc1/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,344 - INFO - model/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,345 - INFO - model/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,345 - INFO - model/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,345 - INFO - model/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,345 - INFO - model/fc2/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,346 - INFO - model/fc2/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,346 - INFO - model/fc2/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,346 - INFO - model/fc2/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,346 - INFO - model/fc2/kernel (slice(0, 256, None), slice(None, None, None)) → 
TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,347 - INFO - model/fc2/kernel (slice(256, 512, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,347 - INFO - model/fc2/kernel (slice(512, 768, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,347 - INFO - model/fc2/kernel (slice(768, 1024, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,347 - INFO - model/fc3/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,347 - INFO - model/fc3/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,348 - INFO - model/fc3/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,348 - INFO - model/fc3/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,348 - INFO - model/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,348 - INFO - model/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,349 - INFO - model/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,349 - INFO - model/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,349 - INFO - opt_state/0/count () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,349 - INFO - opt_state/0/count () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,349 - INFO - opt_state/0/count () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,350 - INFO - opt_state/0/count () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,350 - INFO - opt_state/0/mu/fc1/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,350 - INFO - opt_state/0/mu/fc1/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", 
"2025-10-05 20:33:17,350 - INFO - opt_state/0/mu/fc1/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,351 - INFO - opt_state/0/mu/fc1/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,351 - INFO - opt_state/0/mu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,351 - INFO - opt_state/0/mu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,351 - INFO - opt_state/0/mu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,351 - INFO - opt_state/0/mu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,352 - INFO - opt_state/0/mu/fc2/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,352 - INFO - opt_state/0/mu/fc2/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,352 - INFO - opt_state/0/mu/fc2/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,352 - INFO - opt_state/0/mu/fc2/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,353 - INFO - opt_state/0/mu/fc2/kernel (slice(0, 256, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,353 - INFO - opt_state/0/mu/fc2/kernel (slice(256, 512, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,353 - INFO - opt_state/0/mu/fc2/kernel (slice(512, 768, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,353 - INFO - opt_state/0/mu/fc2/kernel (slice(768, 1024, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,354 - INFO - opt_state/0/mu/fc3/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,354 - 
INFO - opt_state/0/mu/fc3/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,354 - INFO - opt_state/0/mu/fc3/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,354 - INFO - opt_state/0/mu/fc3/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,354 - INFO - opt_state/0/mu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,355 - INFO - opt_state/0/mu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,355 - INFO - opt_state/0/mu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,355 - INFO - opt_state/0/mu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,355 - INFO - opt_state/0/nu/fc1/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,355 - INFO - opt_state/0/nu/fc1/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,356 - INFO - opt_state/0/nu/fc1/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,356 - INFO - opt_state/0/nu/fc1/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,356 - INFO - opt_state/0/nu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,356 - INFO - opt_state/0/nu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,357 - INFO - opt_state/0/nu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,357 - INFO - opt_state/0/nu/fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,357 - INFO - 
opt_state/0/nu/fc2/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,357 - INFO - opt_state/0/nu/fc2/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,357 - INFO - opt_state/0/nu/fc2/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,358 - INFO - opt_state/0/nu/fc2/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,358 - INFO - opt_state/0/nu/fc2/kernel (slice(0, 256, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,358 - INFO - opt_state/0/nu/fc2/kernel (slice(256, 512, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,358 - INFO - opt_state/0/nu/fc2/kernel (slice(512, 768, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,359 - INFO - opt_state/0/nu/fc2/kernel (slice(768, 1024, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,359 - INFO - opt_state/0/nu/fc3/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,359 - INFO - opt_state/0/nu/fc3/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,359 - INFO - opt_state/0/nu/fc3/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,359 - INFO - opt_state/0/nu/fc3/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,360 - INFO - opt_state/0/nu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,360 - INFO - opt_state/0/nu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,360 - INFO - opt_state/0/nu/fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,361 - INFO - opt_state/0/nu/fc3/kernel (slice(None, None, None), 
slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,361 - INFO - step () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,361 - INFO - step () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,361 - INFO - step () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,361 - INFO - step () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,362 - INFO - ── Shard ↦ device map: EMA state sharding after restore ──\n", "2025-10-05 20:33:17,362 - INFO - dropout/rngs/default/count () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,362 - INFO - dropout/rngs/default/count () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,362 - INFO - dropout/rngs/default/count () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,363 - INFO - dropout/rngs/default/count () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,363 - INFO - dropout/rngs/default/key () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,363 - INFO - dropout/rngs/default/key () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,363 - INFO - dropout/rngs/default/key () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,364 - INFO - dropout/rngs/default/key () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,364 - INFO - dropout/rngs/dropout/count () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,364 - INFO - dropout/rngs/dropout/count () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,364 - INFO - dropout/rngs/dropout/count () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,365 - INFO - dropout/rngs/dropout/count () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,365 - INFO - dropout/rngs/dropout/key () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,365 - INFO - dropout/rngs/dropout/key () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,365 - INFO - dropout/rngs/dropout/key () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,366 - INFO - dropout/rngs/dropout/key () → TPU_3(process=0,(1,1,0,0))\n", 
"2025-10-05 20:33:17,366 - INFO - dropout/rngs/noise/count () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,366 - INFO - dropout/rngs/noise/count () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,366 - INFO - dropout/rngs/noise/count () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,367 - INFO - dropout/rngs/noise/count () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,367 - INFO - dropout/rngs/noise/key () → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,367 - INFO - dropout/rngs/noise/key () → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,367 - INFO - dropout/rngs/noise/key () → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,368 - INFO - dropout/rngs/noise/key () → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,368 - INFO - fc1/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,368 - INFO - fc1/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,368 - INFO - fc1/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,369 - INFO - fc1/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,369 - INFO - fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,369 - INFO - fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,369 - INFO - fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,370 - INFO - fc1/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,370 - INFO - fc2/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,370 - INFO - fc2/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,370 - INFO - fc2/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", 
"2025-10-05 20:33:17,370 - INFO - fc2/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,371 - INFO - fc2/kernel (slice(0, 256, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,371 - INFO - fc2/kernel (slice(256, 512, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,371 - INFO - fc2/kernel (slice(512, 768, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,371 - INFO - fc2/kernel (slice(768, 1024, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,372 - INFO - fc3/bias (slice(None, None, None),) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,372 - INFO - fc3/bias (slice(None, None, None),) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,372 - INFO - fc3/bias (slice(None, None, None),) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,372 - INFO - fc3/bias (slice(None, None, None),) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,372 - INFO - fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_0(process=0,(0,0,0,0))\n", "2025-10-05 20:33:17,373 - INFO - fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_2(process=0,(0,1,0,0))\n", "2025-10-05 20:33:17,373 - INFO - fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_1(process=0,(1,0,0,0))\n", "2025-10-05 20:33:17,373 - INFO - fc3/kernel (slice(None, None, None), slice(None, None, None)) → TPU_3(process=0,(1,1,0,0))\n", "2025-10-05 20:33:17,389 - INFO - Opt state after restore: \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'model'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'dropout'\u001b[0m\u001b[38;2;212;212;212m: 
\u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'rngs'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'default'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngCount\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(6, dtype=uint32),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (8 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngKey\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray((), dtype=key) overlaying:\n", " [0 0],\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'dropout'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " 
\u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngCount\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(5000, dtype=uint32),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'dropout'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (8 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngKey\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray((), dtype=key) overlaying:\n", " [0 1],\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'dropout'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'noise'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngCount\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=uint32),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'noise'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (8 B)\u001b[0m\n", " 
\u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngKey\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray((), dtype=key) overlaying:\n", " [0 2],\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'noise'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc1'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([-0.04030257, -0.05037591, 0.07226934, ..., -0.05214077,\n", " 0.05906752, -0.05086785], dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[ 1.0710918 , -0.9678744 , -0.8126575 , ..., 0.17166048,\n", " 0.36044577, -0.6832711 ]], dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc2'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " 
\u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([-0.00861618, -0.04522035, 0. , ..., 0.04507154,\n", " 0.0422631 , 0. ], dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,048,576 (4.2 MB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[ 0.03979657, 0.04424966, -0.0138542 , ..., -0.05651323,\n", " -0.01515646, -0.02998248],\n", " [-0.01223963, 0.00112585, -0.02565034, ..., -0.04619034,\n", " -0.0092434 , 0.00962243],\n", " [-0.01330583, -0.03468521, 0.01838579, ..., 0.02772252,\n", " -0.02609745, -0.05185288],\n", " ...,\n", " [-0.01194615, -0.05189555, 0.00999935, ..., -0.0643564 ,\n", " 0.01144396, -0.02076687],\n", " [-0.01237209, -0.00924738, -0.03677849, ..., 0.00366568,\n", " -0.01869369, 0.05434604],\n", " [ 0.02957141, 0.0040535 , 0.02555469, ..., -0.00918872,\n", " 0.0008155 , 0.0077408 ]], dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc3'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " 
\u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([0.0070804], dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[-0.03213442],\n", " [ 0.02066209],\n", " [-0.04244323],\n", " ...,\n", " [-0.0067385 ],\n", " [-0.06366555],\n", " [-0.04282695]], dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'opt_state'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m0\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptArray\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(5000, dtype=int32)\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'mu'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'fc1'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " 
\u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([-2.3567886e-06, -3.8756975e-06, 5.0598974e-06, ...,\n", " 2.5492732e-06, -1.8280221e-06, 2.3327759e-05], dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[-1.2811696e-05, 1.8371884e-05, -2.8107081e-05, ...,\n", " -7.7499708e-06, 1.2058406e-05, -2.7912660e-05]], dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc2'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " 
\u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([-2.4946366e-05, -5.1071429e-06, 0.0000000e+00, ...,\n", " -6.0458387e-07, -1.6253693e-05, 0.0000000e+00], dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,048,576 (4.2 MB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n", " 1.10227113e-07, -9.35575190e-06, 0.00000000e+00],\n", " [-1.04507053e-05, -2.12880859e-05, 0.00000000e+00, ...,\n", " 3.34136939e-06, -5.51595440e-06, 0.00000000e+00],\n", " [-1.16010815e-05, -1.84936525e-05, 0.00000000e+00, ...,\n", " 2.57636566e-06, -6.29322312e-06, 0.00000000e+00],\n", " ...,\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n", " 0.00000000e+00, -7.95417463e-07, 0.00000000e+00],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n", " 1.59682756e-07, -3.07845085e-06, 0.00000000e+00],\n", " [-7.00975761e-06, -1.49751568e-05, 0.00000000e+00, ...,\n", " 2.39362407e-06, -3.57614044e-06, 0.00000000e+00]], dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc3'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: 
\u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([0.0003306], dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[ 2.5945314e-05],\n", " [-6.1757506e-05],\n", " [ 0.0000000e+00],\n", " ...,\n", " [-2.2441673e-05],\n", " [-3.4819095e-06],\n", " [ 0.0000000e+00]], dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'nu'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'fc1'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " 
\u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([6.5946330e-07, 4.2301590e-08, 1.9096869e-07, ..., 8.1455914e-07,\n", " 2.5145167e-07, 3.8475892e-07], dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[2.9782741e-06, 2.5221141e-07, 1.2813606e-06, ..., 4.2202855e-06,\n", " 1.4535699e-06, 1.6272666e-06]], dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc2'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([6.7727728e-07, 1.5464612e-07, 0.0000000e+00, ..., 2.4988884e-08,\n", " 2.2302729e-06, 0.0000000e+00], dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " 
\u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,048,576 (4.2 MB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[7.1847148e-07, 0.0000000e+00, 0.0000000e+00, ..., 1.9970895e-11,\n", " 1.9745292e-08, 0.0000000e+00],\n", " [2.3998948e-06, 7.2880732e-07, 0.0000000e+00, ..., 9.9980468e-08,\n", " 8.5483980e-06, 0.0000000e+00],\n", " [1.8887340e-06, 5.6486834e-07, 0.0000000e+00, ..., 7.8932914e-08,\n", " 6.7499118e-06, 0.0000000e+00],\n", " ...,\n", " [1.7248375e-08, 0.0000000e+00, 0.0000000e+00, ..., 3.9913174e-13,\n", " 2.2283823e-10, 0.0000000e+00],\n", " [8.6669672e-08, 4.9999395e-17, 0.0000000e+00, ..., 2.8976873e-12,\n", " 3.6252299e-09, 0.0000000e+00],\n", " [1.1716367e-06, 3.5670226e-07, 0.0000000e+00, ..., 4.8790799e-08,\n", " 4.1711064e-06, 0.0000000e+00]], dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc3'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([0.00091323], dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", 
" \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptVariable\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[5.6587835e-04],\n", " [2.0583873e-04],\n", " [0.0000000e+00],\n", " ...,\n", " [3.7648741e-04],\n", " [1.6187840e-05],\n", " [0.0000000e+00]], dtype=float32),\n", " \u001b[38;2;156;220;254msource_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'step'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mOptState\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(5000, dtype=uint32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", "\u001b[38;2;255;213;3m})\u001b[0m\n", "2025-10-05 20:33:17,394 - INFO - EMA state after restore: \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'dropout'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'rngs'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'default'\u001b[0m\u001b[38;2;212;212;212m: 
\u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngCount\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=uint32),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (8 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngKey\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray((), dtype=key) overlaying:\n", " [0 0],\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'dropout'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngCount\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=uint32),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'dropout'\u001b[0m\n", " 
\u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (8 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngKey\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray((), dtype=key) overlaying:\n", " [0 0],\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'dropout'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'noise'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'count'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngCount\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=uint32),\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'noise'\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'key'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (8 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngKey\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray((), dtype=key) overlaying:\n", " [0 0],\n", " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'noise'\u001b[0m\n", " 
\u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc1'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([-0.03922425, -0.04768839, 0.06973939, ..., -0.05001258,\n", " 0.05696585, -0.04945442], dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[ 1.063954 , -0.9617063 , -0.80732816, ..., 0.170594 ,\n", " 0.35816 , -0.6788137 ]], dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc2'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " 
\u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([-0.00952129, -0.04323541, 0. , ..., 0.04311211,\n", " 0.04140445, 0. ], dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,048,576 (4.2 MB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[ 0.03955881, 0.04395215, -0.01376141, ..., -0.05614801,\n", " -0.01507475, -0.02978184],\n", " [-0.01205855, 0.00108851, -0.02547839, ..., -0.04580969,\n", " -0.00915899, 0.00955787],\n", " [-0.01322271, -0.03437575, 0.01826279, ..., 0.02750572,\n", " -0.02593077, -0.05150444],\n", " ...,\n", " [-0.01188646, -0.05154517, 0.00993251, ..., -0.0631559 ,\n", " 0.01273686, -0.0206273 ],\n", " [-0.01222327, -0.00915908, -0.03653297, ..., 0.00275638,\n", " -0.01871123, 0.05397995],\n", " [ 0.02949264, 0.00397549, 0.02538221, ..., -0.00903669,\n", " 0.00084104, 0.00768876]], dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m'fc3'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([0.00718139], dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: 
\u001b[0m\u001b[38;2;79;201;177mVariableState\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1,024 (4.1 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m,\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[-0.03202476],\n", " [ 0.02044498],\n", " [-0.0421575 ],\n", " ...,\n", " [-0.0065837 ],\n", " [-0.06370619],\n", " [-0.04253863]], dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", "\u001b[38;2;255;213;3m})\u001b[0m\n" ] } ], "source": [ "latest_step = args.steps\n", "\n", "# Separate the typed RNG keys from the rest of each state and convert them to raw key data for Orbax\n", "opt_rngs, opt_state_no_rngs = nnx.filter_state(opt_state, nnx.RngKey, ...)\n", "opt_rng_keys = jax.tree.map(jax.random.key_data, opt_rngs)\n", "\n", "ema_rngs, ema_state_no_rngs = nnx.filter_state(ema_state, nnx.RngKey, ...)\n", "ema_rng_keys = jax.tree.map(jax.random.key_data, ema_rngs)\n", "\n", "state_restored = ckpt_mngr.restore(\n", " latest_step,\n", " args=ocp.args.Composite(\n", " opt_state=ocp.args.StandardRestore(opt_state_no_rngs),\n", " ema_state=ocp.args.StandardRestore(ema_state_no_rngs),\n", " opt_rngs=ocp.args.StandardRestore(opt_rng_keys),\n", " ema_rngs=ocp.args.StandardRestore(ema_rng_keys),\n", " ),\n", ")\n", "opt_state_no_rngs, ema_state_no_rngs, opt_rng_keys, ema_rng_keys = (\n", " state_restored.opt_state,\n", " state_restored.ema_state,\n", " state_restored.opt_rngs,\n", " state_restored.ema_rngs,\n", ")\n", "# Wrap the restored raw key data back into typed PRNG keys and merge them into the states\n", "opt_rngs = jax.tree.map(jax.random.wrap_key_data, opt_rng_keys)\n", "ema_rngs = jax.tree.map(jax.random.wrap_key_data, ema_rng_keys)\n", "opt_state = nnx.merge_state(opt_state_no_rngs, opt_rngs)\n", "ema_state = nnx.merge_state(ema_state_no_rngs, ema_rngs)\n", "if jax.process_index() == 0:\n", " logging.info(\"Checkpoint restored successfully\")\n", " log_shard_map(\"Opt state sharding after restore\", opt_state)\n", " log_shard_map(\"EMA state sharding after restore\", 
ema_state)\n", " logging.info(f\"Opt state after restore: {opt_state}\")\n", " logging.info(f\"EMA state after restore: {ema_state}\")\n" ] }, { "cell_type": "markdown", "metadata": { "id": "VelX4aqxex_A" }, "source": [ "Now we run `train_loop()` once again using the states we just loaded and starting from the last step." ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "XdwElpIsmRsz", "outputId": "21011bd5-50b8-41c1-97db-ca4181011655" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-10-05 20:33:18,505 - INFO - Step 5100, Train Loss: 0.000881\n", "2025-10-05 20:33:19,607 - INFO - Step 5200, Train Loss: 0.001370\n", "2025-10-05 20:33:20,711 - INFO - Step 5300, Train Loss: 0.000773\n", "2025-10-05 20:33:21,810 - INFO - Step 5400, Train Loss: 0.004323\n", "2025-10-05 20:33:22,912 - INFO - Step 5500, Train Loss: 0.001996\n", "2025-10-05 20:33:24,012 - INFO - Step 5600, Train Loss: 0.002029\n", "2025-10-05 20:33:25,284 - INFO - Step 5700, Train Loss: 0.003011\n", "2025-10-05 20:33:26,391 - INFO - Step 5800, Train Loss: 0.000485\n", "2025-10-05 20:33:27,494 - INFO - Step 5900, Train Loss: 0.000653\n", "2025-10-05 20:33:28,596 - INFO - Step 6000, Train Loss: 0.000697\n", "2025-10-05 20:33:28,982 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_6000.png\n", "2025-10-05 20:33:28,983 - INFO - Step 6000, Test Loss: 0.000865, EMA Test Loss: 0.000275\n", "2025-10-05 20:33:30,091 - INFO - Step 6100, Train Loss: 0.001780\n", "2025-10-05 20:33:31,201 - INFO - Step 6200, Train Loss: 0.000800\n", "2025-10-05 20:33:32,310 - INFO - Step 6300, Train Loss: 0.000672\n", "2025-10-05 20:33:33,419 - INFO - Step 6400, Train Loss: 0.000786\n", "2025-10-05 20:33:34,525 - INFO - Step 6500, Train Loss: 0.000826\n", "2025-10-05 20:33:35,634 - INFO - Step 6600, Train Loss: 0.000914\n", "2025-10-05 20:33:36,745 - INFO - Step 6700, Train Loss: 0.000973\n", "2025-10-05 
20:33:37,858 - INFO - Step 6800, Train Loss: 0.002495\n", "2025-10-05 20:33:38,979 - INFO - Step 6900, Train Loss: 0.000930\n", "2025-10-05 20:33:40,244 - INFO - Step 7000, Train Loss: 0.000603\n", "2025-10-05 20:33:40,627 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_7000.png\n", "2025-10-05 20:33:40,628 - INFO - Step 7000, Test Loss: 0.000240, EMA Test Loss: 0.000210\n", "2025-10-05 20:33:41,729 - INFO - Step 7100, Train Loss: 0.000742\n", "2025-10-05 20:33:42,832 - INFO - Step 7200, Train Loss: 0.000778\n", "2025-10-05 20:33:43,945 - INFO - Step 7300, Train Loss: 0.001267\n", "2025-10-05 20:33:45,061 - INFO - Step 7400, Train Loss: 0.000588\n", "2025-10-05 20:33:46,161 - INFO - Step 7500, Train Loss: 0.001284\n", "2025-10-05 20:33:46,161 - INFO - Saving checkpoint at step 7500\n", "2025-10-05 20:33:46,163 - INFO - [process=0][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.\n", "2025-10-05 20:33:46,163 - INFO - [process=0] Created OpTracker for checkpoint_manager_save with operation id 6\n", "2025-10-05 20:33:46,223 - INFO - [process=0] Saving checkpoint at step 7500\n", "2025-10-05 20:33:46,223 - INFO - [process=0] Started saving checkpoint to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500.\n", "2025-10-05 20:33:46,332 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500\n", "2025-10-05 20:33:46,553 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/opt_rngs\n", "2025-10-05 20:33:46,554 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/ema_rngs\n", "2025-10-05 20:33:46,555 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/ema_state\n", "2025-10-05 20:33:46,558 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/opt_state\n", "2025-10-05 20:33:46,565 - INFO - Transferring arrays to host memory with 
options: use_replica_parallel=True, enable_pinned_host_transfer=False\n", "2025-10-05 20:33:46,566 - INFO - [process=0][thread=MainThread] Initiated \"orbax.checkpoint._src.serialization.type_handlers.ArrayHandler\".serialize. Time taken: 0.001162s\n", "2025-10-05 20:33:46,566 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False\n", "2025-10-05 20:33:46,572 - INFO - [process=0][thread=MainThread] Initiated \"orbax.checkpoint._src.serialization.type_handlers.ArrayHandler\".serialize. Time taken: 0.006226s\n", "2025-10-05 20:33:46,575 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False\n", "2025-10-05 20:33:46,578 - INFO - [process=0][thread=MainThread] Initiated \"orbax.checkpoint._src.serialization.type_handlers.ArrayHandler\".serialize. Time taken: 0.002278s\n", "2025-10-05 20:33:46,578 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False\n", "2025-10-05 20:33:46,589 - INFO - [process=0][thread=MainThread] Initiated \"orbax.checkpoint._src.serialization.type_handlers.ArrayHandler\".serialize. Time taken: 0.010375s\n", "2025-10-05 20:33:46,593 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 700 Bytes/s (total bytes: 24 Bytes) (time elapsed: 34 milliseconds) (per-host)\n", "2025-10-05 20:33:46,594 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.035200s (batch_requests_ready=0.000461s, total_serialization_initiated=0.031193s, others=0.003546s)\n", "2025-10-05 20:33:46,596 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 111.4 MiB/s (total bytes: 4.0 MiB) (time elapsed: 36 milliseconds) (per-host)\n", "2025-10-05 20:33:46,597 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. 
Time taken: 0.036603s (batch_requests_ready=0.000614s, total_serialization_initiated=0.034220s, others=0.001769s)\n", "2025-10-05 20:33:46,599 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 643 Bytes/s (total bytes: 24 Bytes) (time elapsed: 37 milliseconds) (per-host)\n", "2025-10-05 20:33:46,599 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.038188s (batch_requests_ready=0.000318s, total_serialization_initiated=0.036442s, others=0.001429s)\n", "2025-10-05 20:33:46,601 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 313.0 MiB/s (total bytes: 12.0 MiB) (time elapsed: 38 milliseconds) (per-host)\n", "2025-10-05 20:33:46,602 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.039559s (batch_requests_ready=0.001333s, total_serialization_initiated=0.036562s, others=0.001664s)\n", "2025-10-05 20:33:46,602 - INFO - [process=0][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. 
Time taken: 0.105153s (all_items=0.000026s, per_item={'ema_rngs': '0.00001717', 'ema_state': '0.00000381', 'opt_rngs': '0.00000238', 'opt_state': '0.00000262'}, temp_paths=0.105127)\n", "2025-10-05 20:33:46,673 - INFO - [process=0][thread=array_type_handler] Wrote 3 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/ema_rngs/array_metadatas/process_0\n", "2025-10-05 20:33:46,677 - INFO - [process=0][thread=array_type_handler] Wrote 3 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/opt_rngs/array_metadatas/process_0\n", "2025-10-05 20:33:46,684 - INFO - [process=0][thread=array_type_handler] Wrote 9 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/ema_state/array_metadatas/process_0\n", "2025-10-05 20:33:46,692 - INFO - [process=0][thread=array_type_handler] Wrote 23 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/opt_state/array_metadatas/process_0\n", "2025-10-05 20:33:47,207 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.611453s (commit=0.420513s, array_metadata_write=0.190940s)\n", "2025-10-05 20:33:47,208 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 37 Bytes/s (total bytes: 24 Bytes) (time elapsed: 648 milliseconds) (per-host)\n", "2025-10-05 20:33:47,340 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.740908s (commit=0.551089s, array_metadata_write=0.189819s)\n", "2025-10-05 20:33:47,445 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. 
Time taken: 0.847461s (commit=0.641053s, array_metadata_write=0.206408s)\n", "2025-10-05 20:33:47,445 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 4.5 MiB/s (total bytes: 4.0 MiB) (time elapsed: 885 milliseconds) (per-host)\n", "2025-10-05 20:33:47,446 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 27 Bytes/s (total bytes: 24 Bytes) (time elapsed: 884 milliseconds) (per-host)\n", "2025-10-05 20:33:47,469 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.867791s (commit=0.645247s, array_metadata_write=0.222543s)\n", "2025-10-05 20:33:47,470 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 13.3 MiB/s (total bytes: 12.0 MiB) (time elapsed: 907 milliseconds) (per-host)\n", "2025-10-05 20:33:47,936 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.\n", "2025-10-05 20:33:48,081 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.222106s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/ema_rngs\n", "2025-10-05 20:33:48,082 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/ema_rngs\n", "2025-10-05 20:33:48,248 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.\n", "2025-10-05 20:33:48,414 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.239519s. 
use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/ema_state\n", "2025-10-05 20:33:48,414 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/ema_state\n", "2025-10-05 20:33:48,602 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.\n", "2025-10-05 20:33:48,732 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.224504s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/opt_rngs\n", "2025-10-05 20:33:48,733 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/opt_rngs\n", "2025-10-05 20:33:48,900 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.\n", "2025-10-05 20:33:49,052 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.223419s. 
use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/opt_state\n", "2025-10-05 20:33:49,052 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500/opt_state\n", "2025-10-05 20:33:49,141 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500\n", "2025-10-05 20:33:49,469 - INFO - [process=0][thread=MainThread] Finished saving checkpoint (finalized tmp dir) to `gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500`.\n", "2025-10-05 20:33:49,469 - INFO - Finished synchronous save in 3.25 seconds to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_7500\n", "2025-10-05 20:33:49,630 - INFO - Deleted step 2500.\n", "2025-10-05 20:33:49,631 - INFO - [process=0][thread=MainThread][step=7500] CheckpointManager Save Finalize is syncing with other hosts...\n", "2025-10-05 20:33:49,631 - INFO - [process=0][thread=MainThread][step=7500] CheckpointManager Save Finalize is done on all hosts.\n", "2025-10-05 20:33:49,631 - INFO - [process=0][thread=MainThread][step=7500] Finished synchronous save.\n", "2025-10-05 20:33:49,631 - INFO - {'step': 7500, 'event_type': 'save', 'directory': 'gs://solaris-east5/georgy/fsdp-jax/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': True, 'wait_for_prev_start_time': 1759696426.1633513, 'wait_for_prev_duration_secs': 0.00028967857360839844, 'checkpointer_blocking_start_time': 1759696426.223418, 'checkpointer_blocking_duration_secs': 3.2464306354522705, 'get_old_steps_start_time': 1759696429.4698665, 'get_old_steps_duration_secs': 9.059906005859375e-05, 'checkpoint_manager_blocking_start_time': 1759696426.162882, 'checkpoint_manager_blocking_duration_secs': 3.4689133167266846}\n", "2025-10-05 20:33:49,632 - INFO - Checkpoint saved successfully\n", "2025-10-05 20:33:50,739 - INFO - Step 7600, Train Loss: 0.000762\n", "2025-10-05 20:33:51,843 - INFO - Step 7700, Train Loss: 0.000554\n", "2025-10-05 
20:33:52,947 - INFO - Step 7800, Train Loss: 0.000492\n", "2025-10-05 20:33:54,050 - INFO - Step 7900, Train Loss: 0.000939\n", "2025-10-05 20:33:55,155 - INFO - Step 8000, Train Loss: 0.000522\n", "2025-10-05 20:33:55,540 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_8000.png\n", "2025-10-05 20:33:55,541 - INFO - Step 8000, Test Loss: 0.000316, EMA Test Loss: 0.000152\n", "2025-10-05 20:33:56,804 - INFO - Step 8100, Train Loss: 0.001092\n", "2025-10-05 20:33:57,912 - INFO - Step 8200, Train Loss: 0.000585\n", "2025-10-05 20:33:59,020 - INFO - Step 8300, Train Loss: 0.001435\n", "2025-10-05 20:34:00,132 - INFO - Step 8400, Train Loss: 0.001997\n", "2025-10-05 20:34:01,242 - INFO - Step 8500, Train Loss: 0.001370\n", "2025-10-05 20:34:02,355 - INFO - Step 8600, Train Loss: 0.000390\n", "2025-10-05 20:34:03,460 - INFO - Step 8700, Train Loss: 0.000709\n", "2025-10-05 20:34:04,571 - INFO - Step 8800, Train Loss: 0.001060\n", "2025-10-05 20:34:05,679 - INFO - Step 8900, Train Loss: 0.000630\n", "2025-10-05 20:34:06,789 - INFO - Step 9000, Train Loss: 0.000743\n", "2025-10-05 20:34:07,168 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_9000.png\n", "2025-10-05 20:34:07,169 - INFO - Step 9000, Test Loss: 0.000382, EMA Test Loss: 0.000085\n", "2025-10-05 20:34:08,279 - INFO - Step 9100, Train Loss: 0.000958\n", "2025-10-05 20:34:09,400 - INFO - Step 9200, Train Loss: 0.000821\n", "2025-10-05 20:34:10,507 - INFO - Step 9300, Train Loss: 0.000445\n", "2025-10-05 20:34:11,772 - INFO - Step 9400, Train Loss: 0.001565\n", "2025-10-05 20:34:12,879 - INFO - Step 9500, Train Loss: 0.000859\n", "2025-10-05 20:34:13,987 - INFO - Step 9600, Train Loss: 0.001134\n", "2025-10-05 20:34:15,096 - INFO - Step 9700, Train Loss: 0.000996\n", "2025-10-05 20:34:16,199 - INFO - Step 9800, Train Loss: 0.001920\n", "2025-10-05 20:34:17,310 - INFO - Step 9900, Train Loss: 0.000885\n", "2025-10-05 20:34:18,420 - INFO - Step 10000, Train Loss: 
0.000495\n", "2025-10-05 20:34:18,799 - INFO - Plot saved to /home/georgy/fsdp-in-jax-nnx/outputs/fsdp/eval_10000.png\n", "2025-10-05 20:34:18,799 - INFO - Step 10000, Test Loss: 0.000113, EMA Test Loss: 0.000071\n", "2025-10-05 20:34:18,800 - INFO - Saving checkpoint at step 10000\n", "2025-10-05 20:34:18,804 - INFO - [process=0][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.\n", "2025-10-05 20:34:18,804 - INFO - [process=0] Created OpTracker for checkpoint_manager_save with operation id 8\n", "2025-10-05 20:34:18,858 - INFO - [process=0] Saving checkpoint at step 10000\n", "2025-10-05 20:34:18,859 - INFO - [process=0] Started saving checkpoint to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000.\n", "2025-10-05 20:34:18,967 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000\n", "2025-10-05 20:34:19,157 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/ema_rngs\n", "2025-10-05 20:34:19,163 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/ema_state\n", "2025-10-05 20:34:19,164 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/opt_rngs\n", "2025-10-05 20:34:19,169 - INFO - Creating tmp directory gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/opt_state\n", "2025-10-05 20:34:19,176 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False\n", "2025-10-05 20:34:19,177 - INFO - [process=0][thread=MainThread] Initiated \"orbax.checkpoint._src.serialization.type_handlers.ArrayHandler\".serialize. 
Time taken: 0.001225s\n", "2025-10-05 20:34:19,179 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False\n", "2025-10-05 20:34:19,184 - INFO - [process=0][thread=MainThread] Initiated \"orbax.checkpoint._src.serialization.type_handlers.ArrayHandler\".serialize. Time taken: 0.005068s\n", "2025-10-05 20:34:19,185 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False\n", "2025-10-05 20:34:19,188 - INFO - [process=0][thread=MainThread] Initiated \"orbax.checkpoint._src.serialization.type_handlers.ArrayHandler\".serialize. Time taken: 0.003032s\n", "2025-10-05 20:34:19,191 - INFO - Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=False\n", "2025-10-05 20:34:19,205 - INFO - [process=0][thread=MainThread] Initiated \"orbax.checkpoint._src.serialization.type_handlers.ArrayHandler\".serialize. Time taken: 0.014571s\n", "2025-10-05 20:34:19,211 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 587 Bytes/s (total bytes: 24 Bytes) (time elapsed: 40 milliseconds) (per-host)\n", "2025-10-05 20:34:19,212 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.041266s (batch_requests_ready=0.000434s, total_serialization_initiated=0.039421s, others=0.001411s)\n", "2025-10-05 20:34:19,213 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 95.3 MiB/s (total bytes: 4.0 MiB) (time elapsed: 42 milliseconds) (per-host)\n", "2025-10-05 20:34:19,214 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. 
Time taken: 0.042618s (batch_requests_ready=0.000598s, total_serialization_initiated=0.040283s, others=0.001736s)\n", "2025-10-05 20:34:19,215 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 558 Bytes/s (total bytes: 24 Bytes) (time elapsed: 42 milliseconds) (per-host)\n", "2025-10-05 20:34:19,216 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.043751s (batch_requests_ready=0.000308s, total_serialization_initiated=0.041589s, others=0.001853s)\n", "2025-10-05 20:34:19,218 - INFO - [process=0] /jax/checkpoint/write/blocking_bytes_per_sec: 272.1 MiB/s (total bytes: 12.0 MiB) (time elapsed: 44 milliseconds) (per-host)\n", "2025-10-05 20:34:19,219 - INFO - [process=0][thread=MainThread] Initiated Pytree async_save. Time taken: 0.045207s (batch_requests_ready=0.001276s, total_serialization_initiated=0.041893s, others=0.002038s)\n", "2025-10-05 20:34:19,219 - INFO - [process=0][thread=MainThread] Initiated CompositeCheckpointHandler.async_save. 
Time taken: 0.112145s (all_items=0.000021s, per_item={'ema_rngs': '0.00001311', 'ema_state': '0.00000334', 'opt_rngs': '0.00000262', 'opt_state': '0.00000215'}, temp_paths=0.112123)\n", "2025-10-05 20:34:19,278 - INFO - [process=0][thread=array_type_handler] Wrote 9 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/ema_state/array_metadatas/process_0\n", "2025-10-05 20:34:19,280 - INFO - [process=0][thread=array_type_handler] Wrote 3 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/ema_rngs/array_metadatas/process_0\n", "2025-10-05 20:34:19,282 - INFO - [process=0][thread=array_type_handler] Wrote 3 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/opt_rngs/array_metadatas/process_0\n", "2025-10-05 20:34:19,324 - INFO - [process=0][thread=array_type_handler] Wrote 23 array_metadata.ArrayMetadata to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/opt_state/array_metadatas/process_0\n", "2025-10-05 20:34:19,821 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.608847s (commit=0.419599s, array_metadata_write=0.189248s)\n", "2025-10-05 20:34:19,822 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 36 Bytes/s (total bytes: 24 Bytes) (time elapsed: 651 milliseconds) (per-host)\n", "2025-10-05 20:34:19,980 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.763334s (commit=0.564349s, array_metadata_write=0.198985s)\n", "2025-10-05 20:34:19,993 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. 
Time taken: 0.778045s (commit=0.586967s, array_metadata_write=0.191078s)\n", "2025-10-05 20:34:19,993 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 4.9 MiB/s (total bytes: 4.0 MiB) (time elapsed: 821 milliseconds) (per-host)\n", "2025-10-05 20:34:19,994 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 29 Bytes/s (total bytes: 24 Bytes) (time elapsed: 821 milliseconds) (per-host)\n", "2025-10-05 20:34:20,073 - INFO - [process=0][thread=write_metadata_after_commits] Commit + Array metadata written. Time taken: 0.854778s (commit=0.669826s, array_metadata_write=0.184951s)\n", "2025-10-05 20:34:20,074 - INFO - [process=0] /jax/checkpoint/write/bytes_per_sec: 13.4 MiB/s (total bytes: 12.0 MiB) (time elapsed: 900 milliseconds) (per-host)\n", "2025-10-05 20:34:20,528 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.\n", "2025-10-05 20:34:20,706 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.244012s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/ema_rngs\n", "2025-10-05 20:34:20,706 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/ema_rngs\n", "2025-10-05 20:34:20,866 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.\n", "2025-10-05 20:34:21,034 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.241950s. 
use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/ema_state\n", "2025-10-05 20:34:21,034 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/ema_state\n", "2025-10-05 20:34:21,211 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.\n", "2025-10-05 20:34:21,360 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.232543s. use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/opt_rngs\n", "2025-10-05 20:34:21,361 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/opt_rngs\n", "2025-10-05 20:34:21,538 - INFO - [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.\n", "2025-10-05 20:34:21,685 - INFO - [process=0][thread=MainThread] Pytree save finalize (merge_ocdbt + ArrayMetadata validation) completed. Time taken: 0.231616s. 
use_zarr3=False, enable_post_merge_validation=True, directory=gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/opt_state\n", "2025-10-05 20:34:21,686 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000/opt_state\n", "2025-10-05 20:34:21,771 - INFO - Finalizing gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000\n", "2025-10-05 20:34:22,091 - INFO - [process=0][thread=MainThread] Finished saving checkpoint (finalized tmp dir) to `gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000`.\n", "2025-10-05 20:34:22,092 - INFO - Finished synchronous save in 3.23 seconds to gs://solaris-east5/georgy/fsdp-jax/checkpoints/fsdp_10000\n", "2025-10-05 20:34:22,198 - INFO - Deleted step 5000.\n", "2025-10-05 20:34:22,199 - INFO - [process=0][thread=MainThread][step=10000] CheckpointManager Save Finalize is syncing with other hosts...\n", "2025-10-05 20:34:22,199 - INFO - [process=0][thread=MainThread][step=10000] CheckpointManager Save Finalize is done on all hosts.\n", "2025-10-05 20:34:22,199 - INFO - [process=0][thread=MainThread][step=10000] Finished synchronous save.\n", "2025-10-05 20:34:22,200 - INFO - {'step': 10000, 'event_type': 'save', 'directory': 'gs://solaris-east5/georgy/fsdp-jax/checkpoints', 'reached_preemption': False, 'preemption_received_at': None, 'synchronous': True, 'wait_for_prev_start_time': 1759696458.8044248, 'wait_for_prev_duration_secs': 0.00032973289489746094, 'checkpointer_blocking_start_time': 1759696458.858993, 'checkpointer_blocking_duration_secs': 3.2333545684814453, 'get_old_steps_start_time': 1759696462.0923612, 'get_old_steps_duration_secs': 9.632110595703125e-05, 'checkpoint_manager_blocking_start_time': 1759696458.803941, 'checkpoint_manager_blocking_duration_secs': 3.396230697631836}\n", "2025-10-05 20:34:22,200 - INFO - Checkpoint saved successfully\n" ] } ], "source": [ "start_step = latest_step\n", "train_loop(start_step, opt_state, ema_state)" ] }, { "cell_type": "markdown", "metadata": { 
"id": "lu8_2Y3pe89P" }, "source": [ "That's it. Our model has now trained for another `5000` steps." ] }, { "cell_type": "markdown", "metadata": { "id": "gJv0M6mThO5K" }, "source": [ "As a final note, I want to share a few things I learned while working with JAX that might be helpful to whoever is reading this.\n", "\n", "When you call a JITed function, as we do with `train_step_fn()` inside our `train_loop()` function, JAX dispatches it asynchronously. The call returns right away with arrays whose values are still being computed, and your host Python code continues running. This lets the host queue up further work while the accelerator is still busy. JAX only waits when the host needs a concrete value, for example when you print an array, convert it to NumPy, or call `float()` on it. This means that if you measure the execution time of `train_step_fn()` by calling `time.time()` before and after it, you will see an unrealistically small number, because you have only timed the dispatch. To get an accurate measurement, block on the output with `jax.block_until_ready(opt_state)` before taking the second timestamp. However, do this only for temporary debugging. Your production code should not have any `block_until_ready()` calls in the training loop, because they force the host to wait for the device on every step, which removes the overlap between dispatch and computation and makes the loop slower overall.\n", "\n", "You cannot use `logging.info()` or `print()` to inspect values inside your JITed functions, because the compiled function does not run as regular Python. Such calls execute only once, while the function is being traced, and show abstract tracers rather than actual values. To print debugging info at runtime, use `jax.debug.print()` instead. There are two things to note, however. First, if your function fails somewhere after the print and never finishes, you won't see that print: the JITed function needs to finish for you to see all of its prints. Second, any `jax.debug.*` calls in your JITed function can slow down its execution significantly, so make sure they never make it into your production code." ] }, { "cell_type": "markdown", "metadata": { "id": "P_pBALH3kaFg" }, "source": [ "This concludes our JAX FSDP tutorial. I hope it was useful. Happy JAXing!"
] } ], "metadata": { "accelerator": "TPU", "colab": { "collapsed_sections": [ "1mr5W7P-Ybhq" ], "gpuType": "V5E1", "provenance": [] }, "kernelspec": { "display_name": "fsdp-jax-notebook", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" } }, "nbformat": 4, "nbformat_minor": 0 }