Configuration#
The configs module contains the configuration files for experiments and training runs, each defined as an ml_collections.ConfigDict.
Overview#
This module provides configuration management for:
Model configurations: Architecture hyperparameters and network presets
Training configurations: Learning rates, batch sizes, total steps
Data configurations: Dataset settings and preprocessing
Interface configurations: Diffusion/flow formulation parameters
Sampler configurations: Evaluation and sampling settings
Sharding configurations: Distributed training setup
Available Configurations#
DiT ImageNet Configuration#
Main configuration for DiT training on ImageNet:
from configs.dit_imagenet import get_config
# Load configuration with preset
config = get_config('imagenet_64-B_2')
# Access configuration sections
print(config.network.hidden_size) # 1152
print(config.network.depth) # 12
print(config.network.num_heads) # 16
print(config.data.batch_size) # 64
print(config.interface.train_time_dist_type) # 'logitnormal'
Lightning DiT Configuration#
Configuration for Lightning DiT with continuous time embeddings:
from configs.lightning_dit_imagenet import get_config
config = get_config('imagenet_64-B_2')
# Lightning-specific settings
print(config.network.rope) # True
print(config.network.swiglu) # True
print(config.learning_rate) # 2e-4
REPA Configuration#
Configuration for REPA (Representation Alignment) training:
from configs.dit_imagenet_repa import get_config
config = get_config('imagenet_64-B_2')
# REPA-specific settings
print(config.repa.repa_loss_weight) # 0.1
print(config.repa.feature_dim) # 512
print(config.sampler.sampler_class) # 'euler-maruyama'
MeanFlow Configuration#
Configuration for MeanFlow training:
from configs.mf_imagenet import get_config
config = get_config('imagenet_64-B_2')
# MeanFlow-specific settings
print(config.interface.interface_class) # 'mean_flow'
print(config.sampler.sampler_class) # 'euler_jump'
print(config.network.take_dt) # True
RAE Configuration#
Configuration for RAE (Regularized Autoencoder) training:
from configs.rae_imagenet import get_config
config = get_config('imagenet_64-B_2')
# RAE-specific settings
print(config.encoder.encoder) # 'RAE'
print(config.visualize.reconstruction) # True
print(config.sampler.num_sampling_steps) # 50
Configuration Structure#
All configurations follow a consistent structure with these main sections:
Network Configuration#
config.network = {
    'hidden_size': 1152,
    'depth': 12,
    'num_heads': 16,
    'patch_size': 2,
    'num_patches': 256,
    'class_dropout_prob': 0.1,
    'rope': False,     # Lightning DiT specific
    'swiglu': False,   # Lightning DiT specific
    'take_dt': False,  # MeanFlow specific
}
Data Configuration#
config.data = {
    'data_dir': '/path/to/imagenet',
    'stat_dir': '/path/to/stats',
    'batch_size': 64,
    'image_size': 64,
    'latent_dataset': False,
    'num_train_samples': 1281167,
    'num_workers': 8,
}
Interface Configuration#
config.interface = {
    'interface_class': 'sit',
    'train_time_dist_type': 'logitnormal',
}
Sampler Configuration#
config.sampler = {
    'sampler_class': 'heun',
    'num_sampling_steps': 32,
    'sampling_time_dist': 'uniform',
    'sampling_time_kwargs': {},
}
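Because every configuration shares these nested sections, it can be handy to flatten one into dotted names for logging or for comparing two runs. The sketch below is only an illustration: `flatten` is a hypothetical helper, not part of the configs module, and it operates on a plain nested dict (for a ConfigDict, `config.to_dict()` should give a suitable input).

```python
from typing import Any, Iterator


def flatten(section: dict, prefix: str = "") -> Iterator[tuple[str, Any]]:
    """Yield (dotted_name, value) pairs for every leaf of a nested dict."""
    for key, value in section.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict) and value:
            # Non-empty sub-section: recurse and extend the dotted prefix.
            yield from flatten(value, prefix=f"{name}.")
        else:
            # Leaf value (an empty dict is kept as a leaf so it is not lost).
            yield name, value


# Stand-in config using the interface and sampler values shown above.
config = {
    "interface": {
        "interface_class": "sit",
        "train_time_dist_type": "logitnormal",
    },
    "sampler": {
        "sampler_class": "heun",
        "num_sampling_steps": 32,
        "sampling_time_dist": "uniform",
        "sampling_time_kwargs": {},
    },
}

flat = dict(flatten(config))
for name, value in flat.items():
    print(f"{name} = {value!r}")
```

The dotted names produced here have the same shape as the paths used when overriding individual fields, such as `sampler.num_sampling_steps`.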
Usage Examples#
Loading Configurations#
from configs.dit_imagenet import get_config
# Load with preset
config = get_config('imagenet_64-B_2')
# Override specific parameters
config.network.hidden_size = 512
config.data.batch_size = 32
config.learning_rate = 2e-4
Command Line Overrides#
Configurations can be overridden from the command line:
python main.py \
    --config=configs/dit_imagenet.py:imagenet_64-B_2 \
    --config.network.hidden_size=512 \
    --config.data.batch_size=32 \
    --config.learning_rate=2e-4
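To make the override syntax concrete, here is a minimal sketch of how a dotted flag can be resolved against a nested configuration. It uses plain dicts and the standard library only; `parse_value` and `apply_overrides` are hypothetical helpers written for this illustration and do not reproduce the actual flag parser in ml_collections, which handles value types itself.

```python
import ast
from typing import Any


def parse_value(text: str) -> Any:
    """Interpret an override value as a Python literal, else keep the string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        return text


def apply_overrides(config: dict, args: list[str], prefix: str = "--config.") -> dict:
    """Apply flags such as --config.network.hidden_size=512 to a nested dict."""
    for arg in args:
        if not arg.startswith(prefix):
            continue  # e.g. --config=<file>:<preset> selects the file itself
        path, _, raw = arg[len(prefix):].partition("=")
        *parents, leaf = path.split(".")
        node = config
        for key in parents:
            node = node[key]  # KeyError here means the section does not exist
        if leaf not in node:
            raise KeyError(f"unknown config field: {path}")
        node[leaf] = parse_value(raw)
    return config


# Stand-in config; the learning_rate default is a placeholder, not the
# value used by configs/dit_imagenet.py.
config = {
    "network": {"hidden_size": 1152},
    "data": {"batch_size": 64},
    "learning_rate": 1e-4,
}
apply_overrides(config, [
    "--config=configs/dit_imagenet.py:imagenet_64-B_2",
    "--config.network.hidden_size=512",
    "--config.data.batch_size=32",
    "--config.learning_rate=2e-4",
])
print(config)
```

Each dotted path walks one section per segment, so `--config.network.hidden_size=512` touches only that field and leaves the rest of the preset unchanged. In this sketch, a path that names a missing section or field raises an error instead of silently creating it.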
Creating Custom Configurations#
import ml_collections

from configs.dit_imagenet import get_config

def get_custom_config() -> ml_collections.ConfigDict:
    # Load base configuration
    config = get_config('imagenet_64-B_2')
    # Modify for smaller model
    config.network.hidden_size = 512
    config.network.depth = 8
    # Modify for faster training
    config.data.batch_size = 32
    config.total_steps = 1_000_000
    # Add custom settings (new fields can only be added while the config is unlocked)
    config.custom_setting = 'value'
    return config
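When shrinking a model this way, a quick consistency check can catch combinations that would fail later. The sketch below assumes the network uses standard multi-head attention, where `hidden_size` is split evenly across `num_heads`. `check_network` is a hypothetical helper, not part of the configs module, and a `SimpleNamespace` stands in for `config.network`.

```python
from types import SimpleNamespace


def check_network(network) -> None:
    """Raise ValueError if the attention heads cannot split hidden_size evenly."""
    if network.hidden_size % network.num_heads != 0:
        raise ValueError(
            f"hidden_size={network.hidden_size} is not divisible by "
            f"num_heads={network.num_heads}"
        )


# Stand-in for config.network after the custom edits; num_heads keeps the
# value of 16 shown earlier in this document.
network = SimpleNamespace(hidden_size=512, depth=8, num_heads=16)
check_network(network)  # passes: each head gets 512 / 16 = 32 channels
```

Because the helper only reads attributes, it could be called on `config.network` directly at the end of `get_custom_config()`.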
Configuration Files#
The following configuration files are available:
configs/dit_imagenet.py - DiT ImageNet configuration
configs/lightning_dit_imagenet.py - Lightning DiT configuration
configs/lightning_ddt_imagenet.py - Lightning DDT configuration
configs/dit_imagenet_repa.py - REPA configuration
configs/mf_imagenet.py - MeanFlow configuration
configs/rae_imagenet.py - RAE configuration
configs/common_specs.py - Shared building blocks and presets