Evaluation API#

The eval module provides evaluation metrics and tools for generated samples.

Fréchet Inception Distance#

File containing evaluation code for calculating the FID score.

eval.fid.calculate_stats_for_iterable(image_iter: Iterable[Array] | Array, detector: Callable[[dict, Array], Array], detector_params: dict, batch_size: int = 64, num_eval_images: int | None = None, verbose: bool = False) dict[str, ndarray][source]#

Calculate feature statistics for an iterable of images. This function is DDP-agnostic: it produces the same statistics whether it runs on a single process or across distributed data-parallel workers.

Parameters:
  • image_iter – Iterable or Array of images to calculate statistics for.

  • detector – Function to extract features. Note: detector is assumed to be pmap / pjit'd.

  • detector_params – Parameters for the detector. Note: detector_params is assumed to be processed to match detector.

  • batch_size – Batch size for processing images.

  • num_eval_images – Total number of images to evaluate.

  • verbose – Whether to print verbose output.

Returns:

stats, Inception statistics for the images.

Return type:

  • dict[str, np.ndarray]
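
A minimal usage sketch follows. The toy detector (a random projection wrapped in jax.pmap) is a stand-in for the real Inception network, and how batches are sharded across devices is assumed to be handled internally, per the DDP-agnostic note above.

    import jax
    import jax.numpy as jnp
    from eval import fid

    def _toy_features(params, images):
        # Stand-in for Inception: flatten each image, project to 2048 dims.
        flat = images.reshape(images.shape[0], -1)
        return flat @ params["proj"]

    detector = jax.pmap(_toy_features)  # the API assumes a pmap / pjit'd detector
    key = jax.random.PRNGKey(0)
    proj = jax.random.normal(key, (32 * 32 * 3, 2048)) * 0.01
    detector_params = jax.device_put_replicated({"proj": proj}, jax.local_devices())

    images = jax.random.uniform(key, (256, 32, 32, 3))  # dummy image array
    stats = fid.calculate_stats_for_iterable(
        images, detector, detector_params, batch_size=64, num_eval_images=256)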

eval.fid.calculate_real_stats(config: ConfigDict, dataset: Dataset, detector: Callable[[dict, Array], Array], detector_params: dict, verbose: bool = False) dict[str, ndarray][source]#

Calculate the statistics for real images.

Parameters:
  • config – Overall config for the experiment.

  • dataset – Image dataset to calculate statistics for.

  • detector – Function to extract features. Note: detector is assumed to be pmap / pjit'd.

  • detector_params – Parameters for the detector. Note: detector_params is assumed to be processed to match detector.

  • verbose – Whether to print verbose output.

Returns:

stats, Inception statistics for the images.

Return type:

  • dict[str, np.ndarray]
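
A hedged sketch of computing and caching the reference statistics; config and dataset come from the experiment setup, and the np.savez caching shown here is illustrative, not part of this API.

    import numpy as np
    from eval import fid, utils

    detector_params, detector = utils.get_detector(config)  # (params, forward)
    real_stats = fid.calculate_real_stats(
        config, dataset, detector, detector_params, verbose=True)
    np.savez("real_stats.npz", **real_stats)  # cache for later eval runs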

eval.fid.calculate_cls_fake_stats(config: ConfigDict, rngs: Rngs, sampler: Samplers, generator: Module, encoder: Module, detector: Callable[[dict, Array], Array], detector_params: dict, guide_generator: Module | None = None, guidance_scale: float = 1.0, all_eval_sample_nums: list[int] = [50000], save_samples_path: str | None = None, mesh: Mesh | None = None) dict[str, ndarray][source]#

Extract and calculate the statistics for class-conditioned synthesized images.

Parameters:
  • config – Overall config for the experiment.

  • rngs – nnx Rngs stream for random number generation.

  • sampler – Sampler used to generate images.

  • generator – Generator module.

  • encoder – Encoder module.

  • detector – Function to extract features. Note: detector is assumed to be pmap / pjit'd.

  • detector_params – Parameters for the detector. Note: detector_params is assumed to be processed to match detector.

  • guide_generator – Guiding generator used when applying guidance.

  • guidance_scale – Scale for generation guidance.

  • all_eval_sample_nums – List of total sample counts to generate.

  • save_samples_path – Path to save the generated samples.

  • mesh – Mesh for distributed sampling.

Returns:

stats, Inception statistics for the images.

Return type:

  • dict[str, np.ndarray]
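
A hedged call sketch. The sampler, generator, encoder, guide_generator, and mesh objects are placeholders for whatever the training setup provides; only the argument order and keywords follow the signature above.

    from flax import nnx
    from eval import fid, utils

    rngs = nnx.Rngs(0)  # nnx random-stream container
    detector_params, detector = utils.get_detector(config)
    fake_stats = fid.calculate_cls_fake_stats(
        config, rngs, sampler, generator, encoder, detector, detector_params,
        guide_generator=guide_generator, guidance_scale=1.5,
        all_eval_sample_nums=[10000, 50000],
        save_samples_path="/tmp/samples", mesh=mesh)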

eval.fid.calculate_fid(config: ConfigDict, dataset: Dataset, sampler: Samplers, generator: Module, encoder: Module, guidance_scale: float = 1.0, guide_generator: Module | None = None, sample_sizes: list[int] = [10000], step: int = 0, mesh: Mesh | None = None) dict[str, float][source]#

Calculate the FID score between the synthesized images and the real dataset.
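
A hedged usage sketch; the docstring only promises a dict[str, float], so the per-sample-size key names printed below are an assumption.

    scores = fid.calculate_fid(
        config, dataset, sampler, generator, encoder,
        guidance_scale=1.0, sample_sizes=[10000, 50000], step=step, mesh=mesh)
    for name, value in scores.items():
        print(f"step {step}: {name} = {value:.3f}")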

Utility#

File containing the util functions for evaluation.

eval.utils.get(dictionary, key)[source]#

Get a value from a dictionary, defaulting to None if the key is not present.

Returns:

value, the value from the dictionary or None if not present.

Return type:

  • Any
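
Behaviorally this matches dict.get with its default of None; a quick usage check:

    from eval import utils

    cfg = {"batch_size": 64}
    assert utils.get(cfg, "batch_size") == 64
    assert utils.get(cfg, "ema_decay") is None  # missing key -> None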

eval.utils.download(url, ckpt_dir=None)[source]#

Download a file from a URL to ckpt_dir.

Returns:

ckpt_file, path to the downloaded checkpoint file.

Return type:

  • str
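
Hedged usage sketch; the URL is a placeholder, not a real checkpoint location.

    from eval import utils

    ckpt_file = utils.download(
        "https://example.com/detector.pkl", ckpt_dir="/tmp/ckpts")
    print(ckpt_file)  # e.g. /tmp/ckpts/detector.pkl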

eval.utils.all_gather(x: Array) Array[source]#

Convenience wrapper around jax.lax.all_gather.

Returns:

all_gathered, the gathered array from all devices.

Return type:

  • jnp.ndarray
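
A sketch of how such a wrapper is typically built and used; the axis name "batch" is an assumption about the library's pmap convention.

    import functools
    import jax
    import jax.numpy as jnp

    @functools.partial(jax.pmap, axis_name="batch")
    def _gather(x):
        # Every device receives the shards from all devices, stacked on axis 0.
        return jax.lax.all_gather(x, axis_name="batch")

    n = jax.local_device_count()
    sharded = jnp.arange(float(n * 2)).reshape(n, 2)  # one row per device
    gathered = _gather(sharded)  # shape (n, n, 2)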

eval.utils.lock()[source]#

Hold the lock until all processes sync up.
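
A hedged usage sketch: the single-writer pattern this barrier enables, where readers must not race ahead of the write.

    import jax
    import numpy as np
    from eval import utils

    if jax.process_index() == 0:
        np.savez("real_stats.npz", **real_stats)  # only process 0 writes
    utils.lock()  # every process blocks here until all have arrived
    cached = dict(np.load("real_stats.npz"))  # now safe to read everywhere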

eval.utils.build_keep_indices(item_subset: list[int], batch_size: int, len_dataset: int)[source]#

Simulate the behavior of a DataLoader driven by the item_subset sampler in order to locate, and mark for removal, the positions of images that would be processed twice (e.g. through sampler padding), since duplicated images would bias the FID statistics.

Returns:

final_indices, list of keep indices for each batch.

Return type:

  • list
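
A minimal sketch of the deduplication idea, assuming item_subset wraps around to pad itself to a multiple of the batch size the way distributed samplers do (the actual implementation may differ):

    def build_keep_indices_sketch(item_subset, batch_size, len_dataset):
        seen = set()
        keep_per_batch = []
        for start in range(0, len(item_subset), batch_size):
            batch = item_subset[start:start + batch_size]
            keep = []
            for pos, idx in enumerate(batch):
                # Keep a position only the first time its dataset index appears.
                if idx < len_dataset and idx not in seen:
                    seen.add(idx)
                    keep.append(pos)
            keep_per_batch.append(keep)
        return keep_per_batch

    # 5-image dataset padded by wrapping to a multiple of 4:
    # build_keep_indices_sketch([0, 1, 2, 3, 4, 0, 1, 2], 4, 5) -> [[0, 1, 2, 3], [0]]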

eval.utils.build_eval_loader(dataset: Dataset, batch_size: int, num_workers: int = 8) DataLoader[source]#

Build the dataloader for evaluation.

Returns:

(loader, keep_indices), the dataloader and keep indices.

Return type:

  • tuple[torch.utils.data.DataLoader, list]
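
Hedged usage sketch pairing the loader with its keep indices; the batch structure (image tensor first) is an assumption.

    from eval import utils

    loader, keep_indices = utils.build_eval_loader(dataset, batch_size=64)
    for batch, keep in zip(loader, keep_indices):
        images = batch[0]      # assumes (images, labels) batches
        images = images[keep]  # drop padded duplicates before feature extraction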

eval.utils.get_detector(config: ConfigDict)[source]#

Get the feature detector for FID evaluation.

Returns:

(params, forward), detector parameters and forward function.

Return type:

  • tuple[dict, Callable]
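
The returned pair plugs directly into the fid functions above; image_batch below is a placeholder for an appropriately sharded batch.

    from eval import utils

    detector_params, detector = utils.get_detector(config)
    features = detector(detector_params, image_batch)  # pmap / pjit'd forward pass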

eval.utils.calculate_fid(stats: dict[str, ndarray], ref_stats: dict[str, ndarray]) float[source]#

Calculate the FID score between stats and ref_stats.

Returns:

fid_score, the calculated FID score.

Return type:

  • float
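
For reference, the Fréchet distance between Gaussians fitted to the two feature sets is FID = ||mu_1 - mu_2||^2 + Tr(Sigma_1 + Sigma_2 - 2 (Sigma_1 Sigma_2)^(1/2)). A sketch assuming the stats dicts carry "mu" and "sigma" keys (the library's actual keys may differ):

    import numpy as np
    from scipy import linalg

    def fid_from_stats(stats, ref_stats):
        mu1, sigma1 = stats["mu"], stats["sigma"]
        mu2, sigma2 = ref_stats["mu"], ref_stats["sigma"]
        diff = mu1 - mu2
        # Matrix square root of the covariance product; numerical error can
        # introduce a tiny imaginary component, which is discarded.
        covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))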