Networks
The networks module contains neural network architectures for diffusion models, including transformers, encoders, and decoders.
Overview
This module provides modular network architectures that can be combined to build complex diffusion models. The main components are:
Transformers: DiT and LightningDiT implementations
Encoders: Pretrained vision encoders (SD-VAE, DINOv2, RAE)
Decoders: Trained decoders for reconstruction
Utilities: Helper functions for model conversion and initialization
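As a rough illustration of how the components listed above usually fit together in a latent diffusion setup, the sketch below chains an encoder, a denoiser, and a decoder. Every name in it (toy_encoder, toy_denoiser, toy_decoder, the params layout) is a hypothetical stand-in written in plain JAX, not this library's API; the real encoders and transformers are far larger and are pretrained or trained models.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins: none of these names come from the networks module.

def toy_encoder(params, images):
    # Flatten each image and project it into a small latent vector.
    flat = images.reshape(images.shape[0], -1)
    return flat @ params["enc"]

def toy_denoiser(params, latents, t):
    # Predict noise from the latent plus one scalar timestep per sample.
    h = jnp.concatenate([latents, t[:, None]], axis=-1)
    return jnp.tanh(h @ params["den"])

def toy_decoder(params, latents, image_shape):
    # Project latents back to pixel space and restore the image shape.
    flat = latents @ params["dec"]
    return flat.reshape((latents.shape[0],) + image_shape)

image_shape = (8, 8, 3)
pixels = 8 * 8 * 3
latent_dim = 16

k_enc, k_den, k_dec, k_img = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "enc": jax.random.normal(k_enc, (pixels, latent_dim)) * 0.02,
    "den": jax.random.normal(k_den, (latent_dim + 1, latent_dim)) * 0.02,
    "dec": jax.random.normal(k_dec, (latent_dim, pixels)) * 0.02,
}
images = jax.random.normal(k_img, (4,) + image_shape)
t = jnp.full((4,), 0.5)

latents = toy_encoder(params, images)               # images -> latents
predicted_noise = toy_denoiser(params, latents, t)  # denoiser works on latents
reconstruction = toy_decoder(params, latents, image_shape)  # latents -> images
```

The point of the sketch is the data flow: the diffusion transformer only ever sees latents, so the encoder and decoder can be swapped independently as long as the latent shape stays the same.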
Architecture Support
The library supports both Flax APIs:
Flax Linen: The established Flax API, built around functional modules whose parameters are created and passed in explicitly
NNX: The newer Flax API, with stateful modules and PyTorch-like syntax
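Performance Considerations
NNX vs Linen: NNX provides more intuitive syntax, but its performance characteristics may differ from Linen's, so benchmark both on your workload
Memory Usage: Use gradient checkpointing (for example jax.checkpoint) for large models, trading extra recomputation for lower activation memory
Compilation: JIT-compile models (for example with jax.jit) for better performance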
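The memory and compilation points can be sketched in plain JAX as follows. This is a toy stack of layers rather than one of the library's models; block and make_loss are hypothetical helper names. jax.checkpoint marks each block for rematerialization during the backward pass, and jax.jit compiles the whole gradient function.

```python
import jax
import jax.numpy as jnp

def block(w, x):
    # One toy layer; a real transformer block would go here.
    return jnp.tanh(x @ w)

def make_loss(use_checkpoint):
    # With checkpointing, activations inside each block are recomputed
    # in the backward pass instead of being stored.
    layer = jax.checkpoint(block) if use_checkpoint else block

    def loss(ws, x):
        for w in ws:
            x = layer(w, x)
        return jnp.sum(x ** 2)

    return loss

# JIT-compile the gradient functions once; later calls reuse the compiled code.
grad_plain = jax.jit(jax.grad(make_loss(False)))
grad_remat = jax.jit(jax.grad(make_loss(True)))

ws = [jnp.eye(4) * 0.5 for _ in range(3)]
x = jnp.ones((2, 4))

grads_plain = grad_plain(ws, x)
grads_remat = grad_remat(ws, x)
```

Both versions compute the same gradients; checkpointing only changes what is kept in memory between the forward and backward passes, at the cost of extra compute.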
Best Practices
Use NNX for new code: Prefer NNX implementations for new projects
Profile your models: Use JAX profiling tools (such as jax.profiler) to identify bottlenecks
Consider model size: Larger models require more memory and computation, so estimate parameter count and memory before training
Use appropriate encoders: Choose encoders based on your data and requirements
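Before reaching for a full profiler, two quick checks often help with the size and profiling points above: counting parameters and timing a jitted call. The sketch below uses a toy MLP in plain JAX; init_params and forward are hypothetical helpers, not part of this library. Note the block_until_ready calls: JAX dispatches work asynchronously, so timings are only meaningful once the result has actually been computed.

```python
import time
import jax
import jax.numpy as jnp

def init_params(key, sizes):
    # One weight matrix and bias per consecutive pair of layer sizes.
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        {"w": jax.random.normal(k, (m, n)) * 0.02, "b": jnp.zeros((n,))}
        for k, m, n in zip(keys, sizes[:-1], sizes[1:])
    ]

def forward(params, x):
    for layer in params:
        x = jnp.tanh(x @ layer["w"] + layer["b"])
    return x

params = init_params(jax.random.PRNGKey(0), [64, 128, 64])

# Model size: total parameter count and bytes across the parameter pytree.
leaves = jax.tree_util.tree_leaves(params)
n_params = sum(leaf.size for leaf in leaves)
n_bytes = sum(leaf.size * leaf.dtype.itemsize for leaf in leaves)

jitted = jax.jit(forward)
x = jnp.ones((32, 64))

# First call includes tracing and compilation.
start = time.perf_counter()
jitted(params, x).block_until_ready()
first_call_s = time.perf_counter() - start

# Later calls with the same shapes reuse the compiled executable.
start = time.perf_counter()
jitted(params, x).block_until_ready()
steady_call_s = time.perf_counter() - start

print(f"{n_params} parameters, {n_bytes / 1024:.1f} KiB")
print(f"first call: {first_call_s:.4f}s, steady state: {steady_call_s:.4f}s")
```

Parameter bytes are only a lower bound on training memory, since optimizer state and activations usually add a multiple of that figure.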