Source code for utils.visualize

"""File containing utilities for generating visualization."""

# built-in libs
import functools

# external libs
from absl import logging
from flax import nnx
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import ml_collections

# deps
from utils import wandb_utils, sharding_utils
from samplers import samplers


@functools.partial(
    nnx.jit,
    static_argnums=(6, 7)
)
def sample_fn(net, g_net, encoder, rngs, n, c, guidance_scale, sampler):
    """:meta private:"""
    x = sampler.sample(
        rngs, net, n, y=c, g_net=g_net,
        guidance_scale=guidance_scale
    )
    return encoder.decode(x)


[docs] def visualize( config: ml_collections.ConfigDict, net: nnx.Module, ema_net: nnx.Module, encoder: nnx.Module, sampler: samplers.Samplers, step: int, g_net: nnx.Module | None = None, guidance_scale: float | None = None, mesh: Mesh | None = None, ): """Generate and log samples from the model. Args: - config: configuration for the training. - net: nnx.Module, the network for training. - ema_net: nnx.Module, the ema network. - encoder: nnx.Module, the encoder for training. - n: jnp.ndarray, the initial noise. - c: jnp.ndarray, the initial condition. - guidance_scale: float, the guidance weight for the guidance network. - sampler: samplers.Samplers, the sampler for the network. """ image_size = config.data.image_size // config.encoder.get('downsample_factor', 1) num_samples = config.visualize.num_samples // jax.process_count() # fix rngs for each visualization rngs = nnx.Rngs(config.eval.seed + jax.process_index()) # generate the noise n = jax.random.normal(rngs(), (num_samples, image_size, image_size, config.network.in_channels)) c = jax.random.randint(rngs(), (num_samples, ), 0, config.network.num_classes) n = sharding_utils.make_fsarray_from_local_slice(n, mesh.devices.flatten()) c = sharding_utils.make_fsarray_from_local_slice(c, mesh.devices.flatten()) logging.info("Generating model samples...") net.eval() x = sample_fn( net, g_net if g_net is not None else net, encoder, rngs, n, c, 1.0, sampler ) x = jax.experimental.multihost_utils.process_allgather(x, tiled=True) wandb_utils.log_images(x, 'network', step=step) net.train() logging.info("Generating EMA samples...") x = sample_fn( ema_net, g_net if g_net is not None else ema_net, encoder, rngs, n, c, 1.0, sampler ) x = jax.experimental.multihost_utils.process_allgather(x, tiled=True) wandb_utils.log_images(x, 'ema_network', step=step) if config.visualize.guidance_scale > 1.0 or guidance_scale is not None and guidance_scale > 1.0: logging.info("Generating EMA samples with guidance...") guidance_scale = config.visualize.guidance_scale if guidance_scale is None else guidance_scale x = sample_fn( ema_net, g_net if g_net is not None else ema_net, encoder, rngs, n, c, guidance_scale, sampler ) x = jax.experimental.multihost_utils.process_allgather(x, tiled=True) wandb_utils.log_images(x, f'ema_network_cfg={guidance_scale}', step=step)
[docs] def visualize_reconstruction( config: ml_collections.ConfigDict, encoder: nnx.Module, x: jnp.ndarray, mesh: Mesh | None = None, ): """Reconstruct and log samples from the encoder. Args: - config: the configuration for the training. - encoder: the encoder for the network. - x: the original samples. - mesh: the mesh for the distributed sampling. """ x = sharding_utils.make_fsarray_from_local_slice(x, mesh.devices.flatten()) wandb_utils.log_images( (x.astype(jnp.float32) * 127.5 + 128).clip(0, 255).astype(jnp.uint8), 'encoder_original', step=0 ) logging.info(f"Reconstructing images...") z = encoder.encode(x) x_rec = encoder.decode(z) wandb_utils.log_images(x_rec, 'encoder_reconstructed', step=0)