Testing

This page describes the testing framework and guidelines for the DiffuseNNX library.

Overview

The library's tests aim to ensure code quality and correctness at several levels:

  • Unit Tests: Testing individual components in isolation

  • Integration Tests: Testing end-to-end functionality

  • Performance Tests: Benchmarking and checking performance budgets

  • Regression Tests: Guarding against previously working behavior breaking again
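A common pattern for the regression category is the golden-value test. Outputs are recorded once from a known-good version and checked in. Every run then compares against them with explicit tolerances. In the sketch below, `toy_schedule` is a hypothetical stand-in for a deterministic component. The function, the reference values, and the tolerances are all illustrative assumptions, not part of DiffuseNNX.

```python
import unittest

import numpy as np


def toy_schedule(t):
    """Hypothetical deterministic component whose outputs should not drift."""
    return np.cos(0.5 * np.pi * t) ** 2


# Reference values recorded once from a known-good version and checked in.
GOLDEN_T = np.array([0.0, 0.5, 1.0])
GOLDEN_OUT = np.array([1.0, 0.5, 0.0])


class TestToyScheduleRegression(unittest.TestCase):
    def test_matches_golden_values(self):
        # atol is needed for the reference value 0.0, where rtol alone would fail.
        np.testing.assert_allclose(
            toy_schedule(GOLDEN_T), GOLDEN_OUT, rtol=1e-6, atol=1e-12
        )
```

When a change legitimately alters outputs, regenerate the reference values deliberately and review that diff like any other code change.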

Running Tests

Run All Tests

# Run the complete test suite
python tests/runner.py

Run Specific Test Categories

# Run interface tests
python tests/interface_tests/meanflow_tests.py

# Run network tests
python tests/network_tests/dit_tests.py

# Run sampler tests
python tests/sampler_tests/sampler_tests.py

Test Structure

The test suite is organized as follows:

tests/
├── __init__.py
├── runner.py              # Main test runner
├── interface_tests/       # Interface module tests
│   ├── meanflow_tests.py
│   └── sit_tests.py
├── network_tests/         # Network architecture tests
│   ├── dit_tests.py
│   └── encoder_tests.py
├── sampler_tests/         # Sampler tests
│   └── sampler_tests.py
└── utils_tests/           # Utility function tests
    └── checkpoint_tests.py

Writing Tests

Test Naming Convention

  • Test files: *_tests.py

  • Test classes: TestClassName

  • Test methods: test_method_name
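These conventions matter to the tooling. `unittest` only collects methods whose names start with `test`. Its discovery default file pattern is `test*.py`, so `*_tests.py` files need an explicit pattern, for example `python -m unittest discover -s tests -p "*_tests.py"`. The throwaway class below shows which methods the default loader picks up.

```python
import unittest


class TestNamingExample(unittest.TestCase):
    def test_collected(self):
        # Starts with "test", so the loader turns it into a test case.
        self.assertTrue(True)

    def helper_not_collected(self):
        # No "test" prefix: callable from tests, but never run on its own.
        return 42


# getTestCaseNames returns the sorted method names the loader would run.
collected = unittest.TestLoader().getTestCaseNames(TestNamingExample)
print(collected)  # → ['test_collected']
```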

Example Test Structure

import unittest
import jax
import jax.numpy as jnp
from interfaces.continuous import SiT

class TestSiT(unittest.TestCase):
    def setUp(self):
        """Set up test fixtures."""
        self.model = SiT(
            input_dim=1152,
            hidden_dim=1152,
            num_layers=4,
            num_heads=8
        )
        self.key = jax.random.PRNGKey(0)

    def test_forward_pass(self):
        """Test forward pass of SiT model."""
        x = jnp.ones((2, 1152))
        t = jnp.ones((2,))

        params = self.model.init(self.key, x, t)
        output = self.model.apply(params, x, t)

        self.assertEqual(output.shape, x.shape)

    def test_parameter_count(self):
        """Test that model has expected number of parameters."""
        x = jnp.ones((1, 1152))
        t = jnp.ones((1,))

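        params = self.model.init(self.key, x, t)
        # Use jax.tree_util.tree_leaves; the top-level jax.tree_leaves alias is deprecated.
        param_count = sum(p.size for p in jax.tree_util.tree_leaves(params))

        self.assertGreater(param_count, 0)


if __name__ == "__main__":
    unittest.main()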

Test Guidelines

Deterministic Testing

Always use fixed random seeds for reproducible tests:

import jax
import numpy as np

def test_deterministic_sampling(self):
    """Test that sampling is deterministic with same seed."""
    # self.sampler, self.model and self.params are assumed to be created in setUp().
    key1 = jax.random.PRNGKey(42)
    key2 = jax.random.PRNGKey(42)

    samples1 = self.sampler.sample(self.model, self.params, key1, batch_size=4)
    samples2 = self.sampler.sample(self.model, self.params, key2, batch_size=4)

    np.testing.assert_array_equal(samples1, samples2)
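In JAX, determinism comes from explicit PRNG keys rather than a global seed. The same key always yields the same draw. Independent draws should use keys produced by `jax.random.split`, not a reused key. The helper `initial_noise` below is a hypothetical stand-in for the noise a sampler draws at its first step.

```python
import jax
import numpy as np


def initial_noise(key, batch_size=4, dim=8):
    """Hypothetical stand-in for a sampler's initial Gaussian noise."""
    return jax.random.normal(key, (batch_size, dim))


key = jax.random.PRNGKey(42)

# Same seed, same numbers: two separately constructed keys with seed 42.
a = np.asarray(initial_noise(key))
b = np.asarray(initial_noise(jax.random.PRNGKey(42)))
np.testing.assert_array_equal(a, b)

# Independent draws: split the key instead of reusing it.
key_x, key_y = jax.random.split(key)
x = np.asarray(initial_noise(key_x))
y = np.asarray(initial_noise(key_y))
assert not np.array_equal(x, y)
```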

Fast Tests

Keep tests fast and focused: use the smallest configuration that still exercises the code path under test.

def test_small_model(self):
    """Test with small model for speed."""
    model = SiT(
        input_dim=64,      # Small input
        hidden_dim=64,     # Small hidden dim
        num_layers=2,      # Few layers
        num_heads=4        # Few heads
    )
    # ... test implementation

Comprehensive Coverage

Test edge cases and error conditions:

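def test_invalid_input_shapes(self):
    """Test that invalid inputs raise appropriate errors."""
    # self.model and self.params are assumed to be created in setUp().
    invalid_input = jnp.ones((2, 7))  # Wrong feature dimension
    t = jnp.ones((2,))
    # Adjust the expected exception type to what the model actually raises.
    with self.assertRaises(ValueError):
        self.model.apply(self.params, invalid_input, t)

def test_boundary_conditions(self):
    """Test boundary conditions."""
    # Test with the minimum valid batch size (a single example)
    min_input = jnp.ones((1, 1152))
    t = jnp.ones((1,))
    output = self.model.apply(self.params, min_input, t)
    self.assertEqual(output.shape, min_input.shape)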
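Whether a malformed input raises `ValueError` depends on where the mismatch surfaces. Deep inside JAX operations it may appear as a different exception type. One way to make the contract explicit and testable is to validate shapes at the model boundary. Edge cases can then be looped over with `subTest`, so each one is reported separately. `validate_batch` is a hypothetical helper, not a DiffuseNNX function.

```python
import unittest

import numpy as np


def validate_batch(x, expected_dim):
    """Hypothetical boundary check: raise ValueError on malformed batches."""
    if x.ndim != 2:
        raise ValueError(f"expected a 2-D batch, got ndim={x.ndim}")
    if x.shape[1] != expected_dim:
        raise ValueError(f"expected feature dim {expected_dim}, got {x.shape[1]}")
    if x.shape[0] < 1:
        raise ValueError("batch must contain at least one example")
    return x


class TestValidateBatch(unittest.TestCase):
    def test_rejects_bad_shapes(self):
        # Wrong rank, wrong feature dimension, and an empty batch.
        for x in (np.ones(8), np.ones((2, 4)), np.ones((0, 8))):
            with self.subTest(shape=x.shape):
                with self.assertRaises(ValueError):
                    validate_batch(x, expected_dim=8)

    def test_accepts_boundary_batch_sizes(self):
        for n in (1, 2, 64):
            with self.subTest(batch_size=n):
                self.assertEqual(validate_batch(np.ones((n, 8)), 8).shape, (n, 8))
```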

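Performance Testing

Benchmark Tests

Check that key operations stay within a time budget. Jitted JAX functions compile on their first call, and JAX dispatches work asynchronously. To measure steady-state speed, exclude the first call and block on the result before stopping the clock:

import time
import jax

def test_sampling_performance(self):
    """Test that sampling completes within reasonable time."""
    # self.sampler, self.model, self.params and self.key are assumed setUp() fixtures.
    def run():
        return self.sampler.sample(
            self.model, self.params, self.key,
            batch_size=16, num_steps=32
        )

    jax.block_until_ready(run())  # Warm-up: exclude one-time compilation

    start_time = time.perf_counter()
    jax.block_until_ready(run())  # Wait for asynchronous dispatch to finish
    elapsed_time = time.perf_counter() - start_time

    self.assertLess(elapsed_time, 10.0)  # Should complete in < 10 seconds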
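A reusable timing helper can make such budgets less noisy. The sketch below uses a hypothetical jitted `toy_step` in place of a sampler. It excludes compilation with a warm-up call, waits for asynchronous dispatch with `jax.block_until_ready`, and reports the best of several runs. The repeat count is an arbitrary choice.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_step(x):
    """Hypothetical stand-in for a jitted sampling step."""
    return jnp.tanh(x @ x.T).sum()


def time_jitted(fn, *args, repeats=5):
    """Best wall-clock time in seconds over `repeats` runs, excluding compilation."""
    jax.block_until_ready(fn(*args))  # Warm-up: triggers compilation once.
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))  # Wait for asynchronous dispatch.
        best = min(best, time.perf_counter() - start)
    return best


x = jnp.ones((256, 256))
elapsed = time_jitted(toy_step, x)
print(f"best of 5: {elapsed * 1e3:.2f} ms")
```

Taking the minimum rather than the mean filters out interference from other processes. Inside a test, the result would be compared against a budget with `self.assertLess`.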

Memory Tests

Test memory usage:

def test_memory_usage(self):
    """Test that a large batch can be sampled without running out of memory."""
    # This only checks that the call completes; an out-of-memory error
    # would typically raise. For detailed numbers, use a memory profiling tool.
    samples = self.sampler.sample(self.model, self.params, self.key, batch_size=64)
    self.assertIsNotNone(samples)
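A cheap, deterministic complement is to bound the memory taken by the model parameters. This can be computed from the parameter pytree without a profiler. The helper `pytree_nbytes` and the toy parameter dict are illustrative assumptions. The budget to assert against is project-specific.

```python
import jax
import numpy as np


def pytree_nbytes(tree):
    """Total bytes held by all array leaves of a pytree such as model params."""
    return sum(leaf.nbytes for leaf in jax.tree_util.tree_leaves(tree))


# Toy parameters shaped like one 64 -> 64 dense layer in float32.
params = {
    "dense": {
        "kernel": np.zeros((64, 64), dtype=np.float32),
        "bias": np.zeros((64,), dtype=np.float32),
    }
}

# 64*64*4 + 64*4 = 16384 + 256 bytes
assert pytree_nbytes(params) == 16640
```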

Continuous Integration

GitHub Actions

The project uses GitHub Actions for continuous integration:

name: Tests
on: [push, pull_request]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install dependencies
        run: pip install -r requirements.txt
      - name: Run tests
        run: python tests/runner.py

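Test Coverage

Coverage Reporting

Generate test coverage reports:

# Install coverage tools
pip install pytest coverage pytest-cov

# Run tests with coverage. pytest does not collect *_tests.py files by default,
# so override python_files unless the project's pytest config already sets it.
pytest -o python_files="*_tests.py" --cov=interfaces --cov=networks --cov=samplers tests/

# Generate HTML coverage report (written to htmlcov/ by default)
pytest -o python_files="*_tests.py" --cov=interfaces --cov-report=html tests/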

Coverage Goals

Target line coverage by area:

  • Core interfaces: 90%+ coverage

  • Network architectures: 85%+ coverage

  • Samplers: 90%+ coverage

  • Utilities: 80%+ coverage
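`pytest-cov` can enforce a single global threshold with `--cov-fail-under`. Per-area goals like those above need a small script instead. The sketch below assumes a report produced by `coverage json`, which lists each file under `"files"` with a `"summary"` containing `"covered_lines"` and `"num_statements"`. It also assumes POSIX-style relative paths. The `utils/` prefix is a guess at the utilities package name.

```python
from collections import defaultdict

# Per-area targets from the list above; the directory prefixes are assumptions.
TARGETS = {"interfaces/": 90.0, "networks/": 85.0, "samplers/": 90.0, "utils/": 80.0}


def coverage_by_area(report, targets=TARGETS):
    """Aggregate line coverage (%) per directory prefix from a `coverage json` report."""
    covered = defaultdict(int)
    statements = defaultdict(int)
    for path, data in report["files"].items():
        for prefix in targets:
            if path.startswith(prefix):
                covered[prefix] += data["summary"]["covered_lines"]
                statements[prefix] += data["summary"]["num_statements"]
    # Areas with no measured statements are omitted rather than reported as 0%.
    return {p: 100.0 * covered[p] / statements[p] for p in targets if statements[p]}


def failing_areas(report, targets=TARGETS):
    """Sorted list of area prefixes whose coverage is below their target."""
    percents = coverage_by_area(report, targets)
    return sorted(p for p, pct in percents.items() if pct < targets[p])
```

In practice, the report would come from `coverage json -o coverage.json` followed by `json.load`. A CI step could then fail when `failing_areas(report)` is non-empty.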

Debugging Tests

Verbose Output

Run tests with verbose output:

python -m unittest -v tests.interface_tests.meanflow_tests

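Debug Specific Tests

Run a single test method from within its module, which is convenient when stepping through it in a debugger:

if __name__ == "__main__":
    # argv[0] is ignored; the remaining entries name the tests to run.
    unittest.main(argv=["", "TestSiT.test_forward_pass"], exit=False, verbosity=2)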
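To select tests by substring rather than by full name, Python 3.7+ supports `python -m unittest -v -k forward <module>`. Programmatically, the same filter is `TestLoader.testNamePatterns`: a list of `fnmatch`-style patterns matched against the full `module.Class.method` name. `TestToyModel` below is a throwaway example.

```python
import unittest


class TestToyModel(unittest.TestCase):
    def test_forward_pass(self):
        self.assertEqual(2 * 3, 6)

    def test_parameter_count(self):
        self.assertGreater(10, 0)


def run_matching(test_case_cls, pattern, verbosity=2):
    """Run only the tests of `test_case_cls` whose full name matches `pattern`."""
    loader = unittest.TestLoader()
    # Same mechanism as the -k flag; -k wraps a bare substring in "*...*" itself.
    loader.testNamePatterns = [pattern]
    suite = loader.loadTestsFromTestCase(test_case_cls)
    return unittest.TextTestRunner(verbosity=verbosity).run(suite)


result = run_matching(TestToyModel, "*forward*")
```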

Best Practices

  1. Write tests first: Use test-driven development

  2. Keep tests simple: One concept per test

  3. Use descriptive names: Test names should explain what they test

  4. Test edge cases: Include boundary conditions and error cases

  5. Maintain tests: Update tests when code changes

  6. Use fixtures: Reuse common test setup code

  7. Mock external dependencies: Isolate units under test
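For the last point, one lightweight pattern is to pass the external dependency in as a parameter and substitute a `unittest.mock.Mock` in tests. `mock.patch` is the alternative when the dependency is looked up by module path. `fetch_checkpoint` and `restore_step` are hypothetical functions, not DiffuseNNX APIs.

```python
import unittest
from unittest import mock


def fetch_checkpoint(path):
    """Hypothetical I/O-bound dependency, e.g. reading weights from storage."""
    raise RuntimeError("unit tests should not touch real storage")


def restore_step(path, fetch=fetch_checkpoint):
    """Hypothetical code under test: return the training step stored in a checkpoint."""
    ckpt = fetch(path)
    return int(ckpt["step"])


class TestRestoreStep(unittest.TestCase):
    def test_reads_step_without_real_io(self):
        fake_fetch = mock.Mock(return_value={"step": "1200"})
        self.assertEqual(restore_step("ckpt/latest", fetch=fake_fetch), 1200)
        # The mock also records how it was called.
        fake_fetch.assert_called_once_with("ckpt/latest")

    def test_propagates_storage_errors(self):
        failing_fetch = mock.Mock(side_effect=OSError("storage unavailable"))
        with self.assertRaises(OSError):
            restore_step("ckpt/latest", fetch=failing_fetch)
```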