Configuration

The configs module contains configuration files for experiments and training runs. Each file builds an ml_collections.ConfigDict through a get_config(preset) function.

Overview

This module provides configuration management for:

  • Model configurations: Architecture hyperparameters and network presets

  • Training configurations: Learning rates, batch sizes, total steps

  • Data configurations: Dataset settings and preprocessing

  • Interface configurations: Diffusion/flow formulation parameters

  • Sampler configurations: Evaluation and sampling settings

  • Sharding configurations: Distributed training setup

Available Configurations

DiT ImageNet Configuration

Main configuration for DiT training on ImageNet:

from configs.dit_imagenet import get_config

# Load configuration with preset
config = get_config('imagenet_64-B_2')

# Access configuration sections
print(config.network.hidden_size)  # 1152
print(config.network.depth)        # 12
print(config.network.num_heads)    # 16
print(config.data.batch_size)      # 64
print(config.interface.train_time_dist_type)  # 'logitnormal'
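The preset name appears to follow the DiT naming convention, in which B_2 corresponds to DiT-B/2 (the Base model size with patch size 2). Assuming names take the form <dataset>_<resolution>-<size>_<patch>, a preset string can be split into its parts as in the sketch below. The parse_preset function and Preset class are hypothetical helpers for illustration. The repository's own preset definitions (configs/common_specs.py holds the shared presets) remain authoritative.

```python
import re
from dataclasses import dataclass

# Hypothetical: assumes preset names look like "<dataset>_<resolution>-<size>_<patch>",
# e.g. "imagenet_64-B_2". The repository's preset table is the real source of truth.
_PRESET_RE = re.compile(
    r"^(?P<dataset>[a-z]+)_(?P<resolution>\d+)-(?P<size>[A-Z]+)_(?P<patch>\d+)$"
)


@dataclass(frozen=True)
class Preset:
    dataset: str
    resolution: int
    model_size: str  # e.g. "B", as in DiT-B
    patch_size: int  # e.g. 2, as in DiT-B/2


def parse_preset(name: str) -> Preset:
    """Split a preset name into dataset, resolution, model size and patch size."""
    match = _PRESET_RE.match(name)
    if match is None:
        raise ValueError(f"unrecognized preset name: {name!r}")
    return Preset(
        dataset=match["dataset"],
        resolution=int(match["resolution"]),
        model_size=match["size"],
        patch_size=int(match["patch"]),
    )


print(parse_preset("imagenet_64-B_2"))
# Preset(dataset='imagenet', resolution=64, model_size='B', patch_size=2)
```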

Lightning DiT Configuration

Configuration for Lightning DiT with continuous time embeddings:

from configs.lightning_dit_imagenet import get_config

config = get_config('imagenet_64-B_2')

# Lightning-specific settings
print(config.network.rope)         # True
print(config.network.swiglu)       # True
print(config.learning_rate)        # 2e-4

REPA Configuration

Configuration for REPA (Representation Alignment) training:

from configs.dit_imagenet_repa import get_config

config = get_config('imagenet_64-B_2')

# REPA-specific settings
print(config.repa.repa_loss_weight)  # 0.1
print(config.repa.feature_dim)       # 512
print(config.sampler.sampler_class)  # 'euler-maruyama'
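In the REPA formulation, an alignment term is added to the usual denoising loss with a weight. That term rewards patch-wise similarity between projected hidden states of the diffusion transformer and features from a frozen pretrained encoder. The repa_loss_weight field presumably plays the role of that weight. The minimal sketch below uses plain Python lists and assumes cosine similarity as the similarity measure. The function names are hypothetical, and the sketch is not the repository's loss code.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def cosine_similarity(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norms = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / max(norms, 1e-12)  # guard against zero vectors


def repa_alignment_loss(projected: Sequence[Vector], target: Sequence[Vector]) -> float:
    """Negative mean patch-wise cosine similarity (lower means better aligned)."""
    if len(projected) != len(target):
        raise ValueError("projected and target must have the same number of patches")
    sims = [cosine_similarity(p, t) for p, t in zip(projected, target)]
    return -sum(sims) / len(sims)


def repa_total_loss(denoising_loss: float, projected: Sequence[Vector],
                    target: Sequence[Vector], repa_loss_weight: float = 0.1) -> float:
    """Denoising loss plus the weighted alignment term."""
    return denoising_loss + repa_loss_weight * repa_alignment_loss(projected, target)


# Two patches whose projected features point the same way as the encoder features.
print(repa_total_loss(0.5, [[1.0, 0.0], [0.0, 2.0]], [[3.0, 0.0], [0.0, 1.0]]))  # 0.4
```

With perfectly aligned features the alignment term reaches its minimum of -1, so the weight bounds how much the auxiliary objective can pull on the total loss.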

MeanFlow Configuration

Configuration for MeanFlow training:

from configs.mf_imagenet import get_config

config = get_config('imagenet_64-B_2')

# MeanFlow-specific settings
print(config.interface.interface_class)  # 'mean_flow'
print(config.sampler.sampler_class)      # 'euler_jump'
print(config.network.take_dt)            # True
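MeanFlow models an average velocity u(z_t, r, t) over an interval [r, t] rather than an instantaneous velocity. A sampler can therefore move from time t to an earlier time r in a single update, z_r = z_t - (t - r) * u(z_t, r, t). It seems likely that take_dt makes the network receive the interval length t - r as an extra input, and that euler_jump applies this update, but both readings are assumptions. The sketch below implements the jump update for scalar states with a caller-supplied average-velocity function. The names are hypothetical, and the toy ODE dz/dt = z is used only because its average velocity has a closed form.

```python
import math
from typing import Callable

# u(z_t, r, t): average velocity over [r, t], supplied by the caller.
AverageVelocity = Callable[[float, float, float], float]


def euler_jump_sample(z: float, u: AverageVelocity, num_steps: int = 1,
                      t_start: float = 1.0, t_end: float = 0.0) -> float:
    """Jump from t_start to t_end in num_steps updates z_r = z_t - (t - r) * u(z_t, r, t)."""
    times = [t_start + (t_end - t_start) * i / num_steps for i in range(num_steps + 1)]
    for t, r in zip(times[:-1], times[1:]):
        z = z - (t - r) * u(z, r, t)
    return z


def exact_average_velocity(z_t: float, r: float, t: float) -> float:
    # Toy ODE dz/dt = z: the average of v over [r, t] is z_t * (1 - e^(r - t)) / (t - r).
    return z_t * (1.0 - math.exp(r - t)) / (t - r)


print(euler_jump_sample(2.0, exact_average_velocity))  # close to 2 / e
```

With an exact average velocity, one jump and several jumps land on the same point. A learned network only approximates u, which is why few-step sampling can still differ from one-step sampling in practice.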

RAE Configuration

Configuration for RAE (Regularized Autoencoder) training:

from configs.rae_imagenet import get_config

config = get_config('imagenet_64-B_2')

# RAE-specific settings
print(config.encoder.encoder)        # 'RAE'
print(config.visualize.reconstruction)  # True
print(config.sampler.num_sampling_steps)  # 50

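Configuration Structure

All configurations follow a consistent structure built from the main sections below. Each section is shown as a plain dictionary, and comments mark fields that only one model family uses.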

Network Configuration

config.network = {
    'hidden_size': 1152,
    'depth': 12,
    'num_heads': 16,
    'patch_size': 2,
    'num_patches': 256,
    'class_dropout_prob': 0.1,
    'rope': False,           # Lightning DiT specific
    'swiglu': False,         # Lightning DiT specific
    'take_dt': False         # MeanFlow specific
}
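Several network fields constrain each other. Under the usual ViT/DiT patchification, num_patches equals (input height / patch_size) × (input width / patch_size). For example, 256 patches with patch_size 2 corresponds to a 32 × 32 network input. In addition, hidden_size must divide evenly by num_heads (1152 / 16 = 72 per head). class_dropout_prob is the label-dropout rate used for classifier-free guidance. In the original DiT code, a dropped label is replaced by an extra "null" class index equal to the number of classes. The helpers below are hypothetical sketches of these relations, assuming square inputs and 1000 ImageNet classes.

```python
import random
from typing import List, Optional, Sequence


def num_patches(input_size: int, patch_size: int) -> int:
    """Patch count for a square input split into non-overlapping patches."""
    if input_size % patch_size != 0:
        raise ValueError("input_size must be divisible by patch_size")
    return (input_size // patch_size) ** 2


def head_dim(hidden_size: int, num_heads: int) -> int:
    """Per-head width of multi-head attention."""
    if hidden_size % num_heads != 0:
        raise ValueError("hidden_size must be divisible by num_heads")
    return hidden_size // num_heads


def drop_labels(labels: Sequence[int], class_dropout_prob: float,
                num_classes: int = 1000, rng: Optional[random.Random] = None) -> List[int]:
    """Replace each label with the null class index (num_classes) with the given probability."""
    if rng is None:
        rng = random.Random()
    return [num_classes if rng.random() < class_dropout_prob else y for y in labels]


print(num_patches(32, 2))           # 256
print(head_dim(1152, 16))           # 72
print(drop_labels([1, 2, 3], 1.0))  # [1000, 1000, 1000]
```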

Data Configuration

config.data = {
    'data_dir': '/path/to/imagenet',
    'stat_dir': '/path/to/stats',
    'batch_size': 64,
    'image_size': 64,
    'latent_dataset': False,
    'num_train_samples': 1281167,
    'num_workers': 8
}
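Together, num_train_samples and batch_size determine how many optimizer steps make up one pass over the data, which helps when choosing total_steps. The small sketch below assumes that batch_size is the global batch (summed over all devices) and that the final partial batch is dropped. This page confirms neither assumption.

```python
import math


def steps_per_epoch(num_train_samples: int, batch_size: int, drop_remainder: bool = True) -> int:
    """Optimizer steps needed for one pass over the training set."""
    if drop_remainder:
        return num_train_samples // batch_size
    return math.ceil(num_train_samples / batch_size)


def epochs_for_steps(total_steps: int, num_train_samples: int, batch_size: int) -> float:
    """Number of passes over the data after total_steps steps."""
    return total_steps * batch_size / num_train_samples


print(steps_per_epoch(1_281_167, 64))                        # 20018
print(round(epochs_for_steps(1_000_000, 1_281_167, 32), 2))  # 24.98
```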

Interface Configuration

config.interface = {
    'interface_class': 'sit',
    'train_time_dist_type': 'logitnormal',
}
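With train_time_dist_type set to 'logitnormal', training times are commonly drawn by sampling a Gaussian and passing it through the logistic sigmoid. This concentrates t around the middle of (0, 1). The sketch below assumes a location of 0 and a scale of 1; these are not values read from this config. sample_train_times is a hypothetical helper.

```python
import math
import random
from typing import List, Optional


def sample_train_times(n: int, loc: float = 0.0, scale: float = 1.0,
                       rng: Optional[random.Random] = None) -> List[float]:
    """Draw n times t = sigmoid(x) with x ~ Normal(loc, scale)."""
    if rng is None:
        rng = random.Random()
    return [1.0 / (1.0 + math.exp(-rng.gauss(loc, scale))) for _ in range(n)]


times = sample_train_times(5, rng=random.Random(0))
print([round(t, 3) for t in times])  # every value lies strictly between 0 and 1
```

Shifting loc moves the bulk of the samples toward one end of the interval, and a larger scale spreads them toward both ends.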

Sampler Configuration

config.sampler = {
    'sampler_class': 'heun',
    'num_sampling_steps': 32,
    'sampling_time_dist': 'uniform',
    'sampling_time_kwargs': {}
}
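The 'heun' sampler class refers to Heun's method, a second-order ODE integrator. Each step first takes an Euler predictor step, then re-does the step using the average of the slopes at its start and at the predicted end. The sketch below integrates a generic scalar ODE dx/dt = f(x, t) over num_sampling_steps evenly spaced steps. It assumes that sampling_time_dist 'uniform' means an evenly spaced time grid. The time direction and the velocity function are placeholders, not the repository's sampler.

```python
import math
from typing import Callable

Velocity = Callable[[float, float], float]  # f(x, t)


def heun_sample(x: float, f: Velocity, num_sampling_steps: int = 32,
                t_start: float = 0.0, t_end: float = 1.0) -> float:
    """Integrate dx/dt = f(x, t) from t_start to t_end with Heun's method."""
    h = (t_end - t_start) / num_sampling_steps
    for i in range(num_sampling_steps):
        t = t_start + i * h
        slope = f(x, t)
        x_pred = x + h * slope                        # Euler predictor
        x = x + 0.5 * h * (slope + f(x_pred, t + h))  # average of both slopes
    return x


# dx/dt = x starting from x = 1, so the exact value at t = 1 is e.
print(heun_sample(1.0, lambda x, t: x))  # close to math.e
```

Each step calls f twice, so 32 Heun steps cost about as many network evaluations as 64 Euler steps. In exchange, the error shrinks quadratically rather than linearly as the step size decreases.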

Sharding Configuration

config.sharding = {
    'mesh': [('data', -1)],
    'data_axis': 'data',
    'strategy': [('.*', 'replicate')],
    'rules': [('act_batch', 'data')]
}
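The sharding fields read as a device-mesh specification plus name-matching rules. One plausible reading, which this page does not confirm, is as follows:

  • mesh lists (axis name, size) pairs, where -1 absorbs all remaining devices

  • strategy maps regular expressions over parameter paths to a placement, and the first match wins

  • rules map logical activation axes such as 'act_batch' to mesh axes

The hypothetical helpers below resolve the mesh shape and look up a strategy using only the standard library. A real setup would then reshape the device list to the resolved shape and build a jax.sharding.Mesh from it.

```python
import math
import re
from typing import Dict, Sequence, Tuple


def resolve_mesh_shape(mesh: Sequence[Tuple[str, int]], num_devices: int) -> Dict[str, int]:
    """Replace a single -1 axis size with the number of devices that remain."""
    known = math.prod(size for _, size in mesh if size != -1)
    wildcards = [name for name, size in mesh if size == -1]
    if len(wildcards) > 1:
        raise ValueError("at most one mesh axis may be -1")
    if num_devices % known != 0:
        raise ValueError("fixed mesh axes do not divide the device count")
    shape = dict(mesh)
    if wildcards:
        shape[wildcards[0]] = num_devices // known
    elif known != num_devices:
        raise ValueError("mesh size does not match the device count")
    return shape


def lookup_strategy(param_path: str, strategy: Sequence[Tuple[str, str]]) -> str:
    """Return the placement of the first regex that fully matches param_path."""
    for pattern, placement in strategy:
        if re.fullmatch(pattern, param_path):
            return placement
    raise KeyError(f"no sharding strategy matches {param_path!r}")


print(resolve_mesh_shape([("data", -1)], num_devices=8))                  # {'data': 8}
print(lookup_strategy("blocks_0/attn/qkv/kernel", [(".*", "replicate")]))  # replicate
```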

Usage Examples

Loading Configurations

from configs.dit_imagenet import get_config

# Load with preset
config = get_config('imagenet_64-B_2')

# Override specific parameters
config.network.hidden_size = 512
config.data.batch_size = 32
config.learning_rate = 2e-4

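Command Line Overrides

Configurations can be overridden from the command line. The text after the colon in --config selects the preset passed to get_config, and each --config.<field>=<value> flag overrides one field: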

python main.py \
  --config=configs/dit_imagenet.py:imagenet_64-B_2 \
  --config.network.hidden_size=512 \
  --config.data.batch_size=32 \
  --config.learning_rate=2e-4
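Conceptually, each --config.<field>=<value> flag walks the dotted path into the nested configuration and replaces an existing field. The new string is parsed according to the type of the value it replaces. The standard-library sketch below imitates that behavior on plain nested dictionaries. apply_overrides is a hypothetical helper, not the ml_collections implementation. It handles only bool, int, float and str fields, and the starting values in the example are placeholders.

```python
from typing import Any, Dict, Iterable


def _parse_like(raw: str, current: Any) -> Any:
    """Parse raw using the type of the value being replaced."""
    if isinstance(current, bool):  # check bool first: bool is a subclass of int
        lowered = raw.lower()
        if lowered not in ("true", "false"):
            raise ValueError(f"expected true/false, got {raw!r}")
        return lowered == "true"
    if isinstance(current, int):
        return int(raw)
    if isinstance(current, float):
        return float(raw)
    if isinstance(current, str):
        return raw
    raise TypeError(f"unsupported field type: {type(current).__name__}")


def apply_overrides(config: Dict[str, Any], flags: Iterable[str],
                    prefix: str = "--config.") -> Dict[str, Any]:
    """Apply --config.a.b=value flags in place; unknown fields raise KeyError."""
    for flag in flags:
        if not flag.startswith(prefix) or "=" not in flag:
            continue  # not a field override, e.g. --config=path.py:preset
        path, raw = flag[len(prefix):].split("=", 1)
        *parents, leaf = path.split(".")
        node = config
        for key in parents:
            node = node[key]
        if leaf not in node:
            raise KeyError(f"unknown config field: {path}")
        node[leaf] = _parse_like(raw, node[leaf])
    return config


config = {"network": {"hidden_size": 1152}, "data": {"batch_size": 64}, "learning_rate": 1e-4}
apply_overrides(config, [
    "--config=configs/dit_imagenet.py:imagenet_64-B_2",
    "--config.network.hidden_size=512",
    "--config.learning_rate=2e-4",
])
print(config["network"]["hidden_size"], config["learning_rate"])  # 512 0.0002
```

Refusing unknown fields means a misspelled flag fails loudly instead of silently adding a new setting.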

Creating Custom Configurations

from configs.dit_imagenet import get_config

def get_custom_config():
    # Load base configuration
    config = get_config('imagenet_64-B_2')

    # Modify for smaller model
    config.network.hidden_size = 512
    config.network.depth = 8

    # Modify for faster training
    config.data.batch_size = 32
    config.total_steps = 1_000_000

    # Add custom settings
    config.custom_setting = 'value'

    return config

Configuration Files

The following configuration files are available:

  • configs/dit_imagenet.py - DiT ImageNet configuration

  • configs/lightning_dit_imagenet.py - Lightning DiT configuration

  • configs/lightning_ddt_imagenet.py - Lightning DDT configuration

  • configs/dit_imagenet_repa.py - REPA configuration

  • configs/mf_imagenet.py - MeanFlow configuration

  • configs/rae_imagenet.py - RAE configuration

  • configs/common_specs.py - Shared building blocks and presets
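Each experiment file above is expected to expose get_config(preset), as the earlier examples show. To load a config by path outside the command-line flag machinery, the standard library can import the file and forward the text after the colon, mirroring the --config=path.py:preset syntax. load_config_file is a hypothetical helper. If a config file imports other modules under configs/, the repository root must be on the import path.

```python
import importlib.util
from pathlib import Path
from typing import Any


def load_config_file(spec: str) -> Any:
    """Load 'path/to/file.py[:preset]' and return get_config(preset) or get_config()."""
    path, sep, preset = spec.rpartition(":")
    if not sep or not path.endswith(".py"):
        path, sep, preset = spec, "", ""  # no preset given
    module_spec = importlib.util.spec_from_file_location(Path(path).stem, path)
    if module_spec is None or module_spec.loader is None:
        raise ImportError(f"cannot load config file {path!r}")
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)  # runs the file's top-level code
    return module.get_config(preset) if sep else module.get_config()


# Usage, from the repository root:
# config = load_config_file("configs/dit_imagenet.py:imagenet_64-B_2")
```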