Installation#
Prerequisites#
Before installing DiffuseNNX, ensure you have the following:
Python 3.11 or earlier: The codebase requires Python ≤ 3.11
Google Cloud Storage access: Training and evaluation jobs stream checkpoints to a bucket
Weights & Biases authentication: For logging and experiment tracking
Environment Setup#
Google Cloud Storage: Ensure you have access to a Google Cloud Storage Bucket. Once established, run:
gcloud auth application-default login
Weights & Biases: Visit https://wandb.ai/authorize to get your API key, then export it:
export WANDB_API_KEY="your_api_key_here" export WANDB_ENTITY="your_team_name"
Environment Variables: Create a .env file (do not commit this):
# .env (do not commit) export WANDB_API_KEY="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" export WANDB_ENTITY="my-team" export GOOGLE_APPLICATION_CREDENTIALS="$HOME/.config/gcloud/application_default_credentials.json" export GCS_BUCKET="my-gcs-bucket"
Then source it in every shell:
source .env
Installation Methods#
# Clone the repository
git clone https://github.com/willisma/diffuse_nnx.git
cd diffuse_nnx
# Create a virtual environment (recommended)
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install dependencies
pip install -r requirements.txt
# Install the package in development mode
pip install -e .
Dependencies#
The main dependencies include:
JAX: The major dependency for distributed training and autograd.
Flax NNX: For neural network definitions.
PyTorch (CPU-only): Mostly for dataloading and validate NNX’s correctness with PyTorch.
TensorFlow (CPU-only): Dependencies for some logging library & TPU utilities.
For GPU users, replace jax[tpu] with jax[cuda12] in the requirements file.
Verification#
Test your installation:
# Test environment variables
python -c "
import os
print('WANDB token loaded:', bool(os.getenv('WANDB_API_KEY')))
print('GCP creds set:', bool(os.getenv('GOOGLE_APPLICATION_CREDENTIALS')))
"
# Test GCS access
gcloud storage ls gs://$GCS_BUCKET
# Test Python imports
python -c "
import jax
import flax
from interfaces.continuous import SiT
print('Installation successful!')
"
Troubleshooting#
Common Issues#
JAX Installation Issues: Make sure you’re using the correct JAX version for your platform (TPU vs GPU vs CPU)
Permission Errors: Ensure your GCS bucket has proper permissions and your credentials are correctly set
Import Errors: Make sure you’ve installed the package in development mode with pip install -e .
Getting Help#
If you encounter issues:
Check the GitHub Issues
Review the troubleshooting section in the README
Ensure all environment variables are properly set
Verify your Python version is ≤ 3.11