Networks#

The networks module contains various neural network architectures for diffusion models, including transformers, encoders, and decoders.

Overview#

This module provides modular network architectures that can be combined to build complex diffusion models. The main components are:

  • Transformers: DiT and LightningDiT implementations

  • Encoders: Pretrained vision encoders (SD-VAE, DINOv2, RAE)

  • Decoders: Trained decoders for reconstruction

  • Utilities: Helper functions for model conversion and initialization

Architecture Support#

The library supports both traditional Flax and modern NNX implementations:

  • Flax Linen: Traditional JAX neural network library

  • NNX: Next-generation neural network library with PyTorch-like syntax

Performance Considerations#

  • NNX vs Flax: NNX provides more intuitive syntax but may have different performance characteristics

  • Memory Usage: Use gradient checkpointing for large models

  • Compilation: JIT compile models for better performance

Best Practices#

  1. Use NNX for new code: Prefer NNX implementations for new projects

  2. Profile your models: Use JAX profiling tools to identify bottlenecks

  3. Consider model size: Larger models require more memory and computation

  4. Use appropriate encoders: Choose encoders based on your data and requirements