Quick Start#

This guide will help you get started with DiffuseNNX. It covers basic usage, training a simple model, and generating samples.

Basic Usage#

Import the necessary modules:

import jax
import jax.numpy as jnp
from flax import nnx
from networks.transformers.dit_nnx import DiT
from interfaces.continuous import SiTInterface
from samplers.samplers import HeunSampler

Create a simple DiT model:

# Initialize DiT model
model = DiT(
    input_size=32,           # Input size (e.g., for ImageNet latents)
    hidden_size=1152,        # Hidden dimension
    depth=28,                # Number of transformer layers
    num_heads=16,            # Number of attention heads
    patch_size=2,            # Patch size for vision transformer
    num_classes=1000,        # Number of classes
    class_dropout_prob=0.1,  # Class-label dropout for classifier-free guidance
    rngs=nnx.Rngs(0)         # Flax NNX random number generators
)
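With these settings, the model operates on a token sequence obtained by patchifying the input: a 32×32 latent cut into 2×2 patches gives 16×16 = 256 tokens. The arithmetic below is a back-of-the-envelope check (assuming 4-channel latents, as in the sampling example in this guide), not a library call:

```python
# Illustrative sequence-length arithmetic for the DiT config above.
# Assumes 4-channel latents; these are plain numbers, not library calls.
input_size = 32   # spatial size of the latent
patch_size = 2    # side length of each square patch
channels   = 4    # latent channels

tokens_per_side = input_size // patch_size
num_tokens = tokens_per_side ** 2                # transformer sequence length
patch_dim  = patch_size * patch_size * channels  # raw values per token before projection

print(num_tokens, patch_dim)  # 256 tokens of 16 raw values each
```

Doubling `patch_size` quarters the token count, which is the usual trade-off between compute and spatial resolution in vision transformers.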

Create an interface and sampler:

# Create SiT interface
interface = SiTInterface(
    network=model,
    train_time_dist_type='uniform'
)

# Create Heun sampler with 32 steps
sampler = HeunSampler(num_steps=32)

# Generate samples
key = jax.random.PRNGKey(42)
rngs = nnx.Rngs(0)
x = jax.random.normal(key, (4, 32, 32, 4))
samples = sampler.sample(rngs, interface, x)
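Conceptually, a deterministic sampler integrates the learned velocity field from noise (t = 1) back toward data (t = 0). The toy sketch below illustrates that loop only; it is not DiffuseNNX's sampler implementation. For a single known data point `x_star`, the linear-interpolant velocity is constant along the exact trajectory, so even plain Euler steps land on `x_star`:

```python
# Toy illustration of what a deterministic sampler does: integrate a
# velocity field from noise at t=1 down to data at t=0. With one known
# data point x_star and the linear interpolant x_t = (1-t)*x_star + t*noise,
# the exact velocity is v(x, t) = (x - x_star) / t.
# Illustration only -- not the library's sampler code.
import random

x_star = 2.0                  # the "data" we want to reach
x = random.gauss(0.0, 1.0)    # start from Gaussian noise at t = 1
num_steps = 32
h = 1.0 / num_steps

for n in range(num_steps):
    t = 1.0 - n * h           # current time, from 1 down to h (never 0)
    v = (x - x_star) / t      # exact velocity for this toy interpolant
    x = x - h * v             # Euler step backward in time toward the data

print(x)  # reaches x_star = 2.0
```

In the real library the exact velocity is replaced by the trained network's prediction, which is where sampler order and step count start to matter.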

Training a Model#

For training, use the main training script together with a configuration file:

# Run training with DiT configuration
python main.py \
  --config=configs/dit_imagenet.py:imagenet_64-B_2 \
  --bucket=$GCS_BUCKET \
  --workdir=my_experiment

# Run training with LightningDiT configuration
python main.py \
  --config=configs/lightning_dit_imagenet.py:imagenet_64-B_2 \
  --bucket=$GCS_BUCKET \
  --workdir=my_lightning_experiment

Configuration#

DiffuseNNX uses configuration files to manage hyperparameters and settings. Configuration files are located in the configs/ directory:

from configs.dit_imagenet import get_config

# Get default configuration
config = get_config('imagenet_64-B_2')

# Modify configuration
config.network.hidden_size = 1024
config.total_steps = 1000000
config.data.batch_size = 64

# Configuration structure
print(config.network)      # Model architecture settings
print(config.data)         # Dataset settings
print(config.interface)    # Interface settings

Available Models#

The library supports several model architectures:

  • DiT (Diffusion Transformer): The main diffusion transformer

  • LightningDiT: Faster training variant of DiT with optimizations

  • LightningDDT: Diffusion-decoder transformer variant

  • REPA: Representation alignment wrapper for any interface

Available Interfaces#

The library supports several diffusion and flow matching interfaces:

  • SiTInterface: Stochastic interpolant transport with linear interpolation

  • EDMInterface: EDM-style diffusion (Karras et al. preconditioning and noise schedule)

  • MeanFlowInterface: MeanFlow average-velocity flow matching with guidance

  • sCTInterface/sCDInterface: Simplified continuous-time consistency training and distillation
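For intuition about the linear-interpolation transport: training mixes a data sample x0 and a noise sample x1 as x_t = (1 − t)·x0 + t·x1, and the network regresses the constant path velocity x1 − x0. The sketch below uses one common convention (t = 0 is data, t = 1 is noise) and plain Python; the library's interfaces may differ in convention and API:

```python
# Minimal sketch of a linear-interpolant (flow matching) training pair.
# Convention here: t = 0 is data, t = 1 is noise; the library may differ.
import random

def make_training_pair(x0, t):
    """Return (x_t, target velocity) for scalar data x0 at time t."""
    x1 = random.gauss(0.0, 1.0)       # noise endpoint
    x_t = (1.0 - t) * x0 + t * x1     # linear interpolant between data and noise
    v_target = x1 - x0                # constant velocity along the straight path
    return x_t, v_target

x_t, v = make_training_pair(x0=1.5, t=0.5)
# At t = 0.5, x_t sits halfway between data and noise,
# and v points from the data endpoint toward the noise endpoint.
```

The loss is then a simple regression of the network's output against `v_target` at sampled times t.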

Available Samplers#

Multiple sampling strategies are supported:

  • HeunSampler: Second-order deterministic sampler

  • EulerSampler: First-order deterministic sampler

  • EulerMaruyamaSampler: Stochastic sampler

  • EulerJumpSampler: For two-time variable models (MeanFlow)
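To see why a second-order sampler can get away with fewer steps, compare Euler and Heun on the toy ODE dy/dt = −y, whose exact solution from y(0) = 1 is e^(−t). Heun's one extra velocity evaluation per step averages the slope at both ends of the step, sharply reducing the error. This is a pure-Python illustration, unrelated to the library's sampler API:

```python
# Compare first-order Euler with second-order Heun on dy/dt = -y,
# exact solution y(t) = exp(-t) from y(0) = 1.
# Illustration only -- not the library's sampler implementation.
import math

def f(y):
    return -y

def euler(y, h, steps):
    for _ in range(steps):
        y = y + h * f(y)               # one slope evaluation per step
    return y

def heun(y, h, steps):
    for _ in range(steps):
        k1 = f(y)                      # slope at the current point
        k2 = f(y + h * k1)             # slope at the Euler-predicted endpoint
        y = y + h * 0.5 * (k1 + k2)    # average the two slopes
    return y

steps, h = 10, 0.1
exact = math.exp(-1.0)
err_euler = abs(euler(1.0, h, steps) - exact)
err_heun = abs(heun(1.0, h, steps) - exact)
print(err_euler, err_heun)  # Heun is far more accurate at equal step count
```

The same trade-off applies to diffusion sampling: Heun costs two network evaluations per step but typically needs far fewer steps than Euler for the same sample quality.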

Example: Complete Training Script#

Here’s a complete example of training a DiT model:

#!/bin/bash

# Set up environment
export GCS_BUCKET="your-bucket-name"
export WORKDIR="my_dit_experiment"

# Run training with DiT configuration
python main.py \
  --config=configs/dit_imagenet.py:imagenet_64-B_2 \
  --bucket=$GCS_BUCKET \
  --workdir=$WORKDIR \
  --config.total_steps=1000000 \
  --config.data.batch_size=64 \
  --config.log_every_steps=100

Next Steps#

Now that you have the basics, explore the configuration files in the configs/ directory to see the available models and training setups. For advanced usage, see the Contributing guide for extending the library.