# Quick Start

This guide gets you up and running with DiffuseNNX. We'll cover basic usage, training a simple model, and generating samples.
## Basic Usage

Import the necessary modules:

```python
import jax
import jax.numpy as jnp
from flax import nnx

from networks.transformers.dit_nnx import DiT
from interfaces.continuous import SiTInterface
from samplers.samplers import HeunSampler
```
Create a simple DiT model:

```python
# Initialize DiT model
model = DiT(
    input_size=32,           # Spatial size of the input (e.g., ImageNet latents)
    hidden_size=1152,        # Hidden dimension
    depth=28,                # Number of transformer layers
    num_heads=16,            # Number of attention heads
    patch_size=2,            # Patch size for the vision transformer
    num_classes=1000,        # Number of class labels
    class_dropout_prob=0.1,  # Class-label dropout (for classifier-free guidance)
    rngs=nnx.Rngs(0)
)
```
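As a sanity check on these settings, the transformer's sequence length follows standard ViT patchify arithmetic. The helper below is purely illustrative, not a DiffuseNNX API:

```python
def num_patches(input_size: int, patch_size: int) -> int:
    # A ViT-style patchify splits an S x S input into (S / p) ** 2 tokens.
    assert input_size % patch_size == 0, "input must divide evenly into patches"
    return (input_size // patch_size) ** 2

# With input_size=32 and patch_size=2 as above:
print(num_patches(32, 2))  # 256 tokens per image
```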
Create an interface and a sampler, then generate samples:

```python
# Create SiT interface
interface = SiTInterface(
    network=model,
    train_time_dist_type='uniform'
)

# Create Heun sampler with 32 steps
sampler = HeunSampler(num_steps=32)

# Generate samples from Gaussian noise
key = jax.random.PRNGKey(42)
rngs = nnx.Rngs(0)
x = jax.random.normal(key, (4, 32, 32, 4))  # (batch, height, width, channels)
samples = sampler.sample(rngs, interface, x)
```
## Training a Model

For training, you'll need to use the main training script with configuration files. Here's how to run training:

```shell
# Run training with the DiT configuration
python main.py \
    --config=configs/dit_imagenet.py:imagenet_64-B_2 \
    --bucket=$GCS_BUCKET \
    --workdir=my_experiment

# Run training with the LightningDiT configuration
python main.py \
    --config=configs/lightning_dit_imagenet.py:imagenet_64-B_2 \
    --bucket=$GCS_BUCKET \
    --workdir=my_lightning_experiment
```
## Configuration

DiffuseNNX uses configuration files to manage hyperparameters and settings. Configuration files are located in the configs/ directory:

```python
from configs.dit_imagenet import get_config

# Get default configuration
config = get_config('imagenet_64-B_2')

# Modify configuration
config.network.hidden_size = 1024
config.total_steps = 1000000
config.data.batch_size = 64

# Configuration structure
print(config.network)    # Model architecture settings
print(config.data)       # Dataset settings
print(config.interface)  # Interface settings
```
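The dotted-attribute overrides above compose top to bottom: each assignment replaces a default that get_config loaded. A stdlib-only mock of the same access pattern (using types.SimpleNamespace as a hypothetical stand-in for the real config object; the field defaults shown here are made up for illustration):

```python
from types import SimpleNamespace

# Hypothetical stand-in for the object returned by get_config.
config = SimpleNamespace(
    network=SimpleNamespace(hidden_size=1152),
    data=SimpleNamespace(batch_size=256),
    total_steps=400_000,
)

config.network.hidden_size = 1024  # override a nested field
config.total_steps = 1_000_000     # override a top-level field
print(config.network.hidden_size)  # 1024
```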
## Available Models

The library supports several model architectures:

- **DiT** (Diffusion Transformer): The main diffusion transformer
- **LightningDiT**: Faster-training variant of DiT with optimizations
- **LightningDDT**: Diffusion-decoder transformer variant
- **REPA**: Representation-alignment wrapper for any interface
## Available Interfaces

The library supports several diffusion and flow-matching interfaces:

- **SiTInterface**: Stochastic-interpolant transport with linear interpolation
- **EDMInterface**: EDM-style diffusion (Karras et al. preconditioning)
- **MeanFlowInterface**: MeanFlow average-velocity flow matching with guidance
- **sCTInterface**/**sCDInterface**: Continuous-time consistency training and distillation
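For intuition, the linear-interpolation path used by SiT-style interfaces can be written down independently of the library. This is the generic stochastic-interpolant formula; the time convention (x0 at t=0, x1 at t=1) is an assumption here, not necessarily the library's:

```python
def linear_interpolant(x0: float, x1: float, t: float) -> float:
    # x_t = (1 - t) * x0 + t * x1: pure x0 at t=0, pure x1 at t=1.
    return (1.0 - t) * x0 + t * x1

# Halfway along the path between a noise value (0.0) and a data value (1.0):
print(linear_interpolant(0.0, 1.0, 0.5))  # 0.5
```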
## Available Samplers

Multiple sampling strategies are supported:

- **HeunSampler**: Second-order deterministic (ODE) sampler
- **EulerSampler**: First-order deterministic (ODE) sampler
- **EulerMaruyamaSampler**: Stochastic (SDE) sampler
- **EulerJumpSampler**: For two-time-variable models (e.g., MeanFlow)
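The first-order versus second-order distinction can be illustrated on a toy ODE dx/dt = -x. This is generic numerical integration, not the library's sampler internals:

```python
import math

def euler_step(f, x, t, dt):
    # First-order: one slope evaluation per step.
    return x + dt * f(x, t)

def heun_step(f, x, t, dt):
    # Second-order: average the slope at the start and at an Euler predictor.
    x_pred = x + dt * f(x, t)
    return x + 0.5 * dt * (f(x, t) + f(x_pred, t + dt))

f = lambda x, t: -x     # exact solution: x(t) = x(0) * exp(-t)
exact = math.exp(-0.1)  # ~0.9048
print(euler_step(f, 1.0, 0.0, 0.1))  # 0.9
print(heun_step(f, 1.0, 0.0, 0.1))   # ~0.905, closer to the exact value
```

At the same step count, the Heun update lands visibly closer to the true solution, which is why second-order samplers usually need fewer steps for comparable quality.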
## Example: Complete Training Script

Here's a complete example of training a DiT model:

```shell
#!/bin/bash

# Set up environment
export GCS_BUCKET="your-bucket-name"
export WORKDIR="my_dit_experiment"

# Run training with the DiT configuration and a few overrides
python main.py \
    --config=configs/dit_imagenet.py:imagenet_64-B_2 \
    --bucket=$GCS_BUCKET \
    --workdir=$WORKDIR \
    --config.total_steps=1000000 \
    --config.data.batch_size=64 \
    --config.log_every_steps=100
```
## Next Steps

Now that you have the basics, explore:

- Interfaces - Core interfaces and algorithms
- Networks - Model architectures
- Samplers - Sampling strategies

For advanced usage, see the Contributing guide for extending the library.