Evaluation API
The eval module provides evaluation metrics and tools for generated samples.
Frechet Inception Distance
File containing evaluation code for calculating the FID score.
- eval.fid.calculate_stats_for_iterable(image_iter: Iterable[Array] | Array, detector: Callable[[dict, Array], Array], detector_params: dict, batch_size: int = 64, num_eval_images: int | None = None, verbose: bool = False) -> dict[str, ndarray]
Calculate the statistics for an iterable of images. This function is ddp-agnostic.
- Parameters:
image_iter – Iterable or Array of images to calculate statistics for.
detector – Function to extract features. Note: detector is assumed to be pmap / pjit’d.
detector_params – Parameters for the detector. Note: detector_params is assumed to be processed to match detector.
batch_size – Batch size for processing images.
num_eval_images – Total number of images to evaluate.
verbose – Whether to print verbose output.
- Returns:
stats, Inception statistics for the images.
- Return type:
dict[str, np.ndarray]
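As a rough illustration of what "Inception statistics" usually means in FID pipelines, the sketch below reduces batched feature arrays to a mean vector and a covariance matrix. The helper name `stats_from_features` and the dictionary keys `mu` and `sigma` are assumptions made for this example, not the confirmed output layout of `calculate_stats_for_iterable`. Random arrays stand in for real detector features.

```python
import numpy as np


def stats_from_features(feature_batches, num_eval_images=None):
    """Reduce an iterable of (batch, dim) feature arrays to mean and covariance.

    Hypothetical helper: the keys "mu" and "sigma" are illustrative assumptions.
    """
    feats = np.concatenate(
        [np.asarray(f, dtype=np.float64) for f in feature_batches], axis=0
    )
    if num_eval_images is not None:
        feats = feats[:num_eval_images]  # cap the number of evaluated images
    mu = feats.mean(axis=0)
    sigma = np.cov(feats, rowvar=False)  # covariance across feature dimensions
    return {"mu": mu, "sigma": sigma}


# Stand-in features: 4 batches of 64 samples with 8 dimensions each.
rng = np.random.default_rng(0)
batches = [rng.normal(size=(64, 8)) for _ in range(4)]
stats = stats_from_features(batches, num_eval_images=200)
```

In a real run the batches would come from applying the detector to images, batch_size at a time.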
- eval.fid.calculate_real_stats(config: ConfigDict, dataset: Dataset, detector: Callable[[dict, Array], Array], detector_params: dict, verbose: bool = False) -> dict[str, ndarray]
Calculate the statistics for real images.
- Parameters:
config – Overall config for the experiment.
dataset – Image dataset to calculate statistics for.
detector – Function to extract features. Note: detector is assumed to be pmap / pjit’d.
detector_params – Parameters for the detector. Note: detector_params is assumed to be processed to match detector.
verbose – Whether to print verbose output.
- Returns:
stats, Inception statistics for the images.
- Return type:
dict[str, np.ndarray]
- eval.fid.calculate_cls_fake_stats(config: ConfigDict, rngs: Rngs, sampler: Samplers, generator: Module, encoder: Module, detector: Callable[[dict, Array], Array], detector_params: dict, guide_generator: Module | None = None, guidance_scale: float = 1.0, all_eval_sample_nums: list[int] = [50000], save_samples_path: str | None = None, mesh: Mesh | None = None) -> dict[str, ndarray]
Extract and calculate the statistics for class-conditioned synthesized images.
- Parameters:
config – Overall config for the experiment.
rngs – nnx Rngs stream for random number generation.
sampler – Sampler used to synthesize images.
generator – Generator.
encoder – Encoder.
detector – Function to extract features. Note: detector is assumed to be pmap / pjit’d.
detector_params – Parameters for the detector. Note: detector_params is assumed to be processed to match detector.
guide_generator – Guiding generator.
guidance_scale – Scale for generation guidance.
all_eval_sample_nums – List of total sample counts to generate.
save_samples_path – Path to save the samples.
mesh – Mesh for distributed sampling.
- Returns:
stats, Inception statistics for the images.
- Return type:
dict[str, np.ndarray]
- eval.fid.calculate_fid(config: ConfigDict, dataset: Dataset, sampler: Samplers, generator: Module, encoder: Module, guidance_scale: float = 1.0, guide_generator: Module | None = None, sample_sizes: list[int] = [10000], step: int = 0, mesh: Mesh | None = None) -> dict[str, float]
Calculate the FID score between the synthesized images and the real dataset.
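For orientation, the standard Frechet distance between two Gaussians fitted to feature statistics can be sketched as below. This is a minimal NumPy illustration of the usual FID definition, ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2)), and is not taken from the `eval.fid` source. The function name `frechet_distance` is hypothetical, and the matrix square root is replaced by an eigenvalue shortcut that assumes positive semi-definite covariances.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = np.asarray(mu1, dtype=np.float64) - np.asarray(mu2, dtype=np.float64)
    sigma1 = np.asarray(sigma1, dtype=np.float64)
    sigma2 = np.asarray(sigma2, dtype=np.float64)
    # Tr(sqrt(sigma1 @ sigma2)) equals the sum of the square roots of the
    # eigenvalues of sigma1 @ sigma2, which are real and non-negative when
    # both covariances are positive semi-definite (assumed here).
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)


mu = np.zeros(3)
sigma = np.eye(3)
same = frechet_distance(mu, sigma, mu, sigma)           # identical Gaussians
shifted = frechet_distance(mu, sigma, mu + 1.0, sigma)  # mean shifted by 1 per dim
```

With identical statistics the distance is zero. Shifting every mean component by 1 in three dimensions contributes a squared norm of 3.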
Utility
File containing the util functions for evaluation.
- eval.utils.get(dictionary, key)
Get a value from a dictionary, defaulting to None if the key is not present.
- Returns:
value, the value from the dictionary or None if not present.
- Return type:
Any
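The described behavior matches a plain `dict.get` lookup. The sketch below is an assumption about how such a helper could look. The actual implementation is not shown here and may treat config objects differently.

```python
def get(dictionary, key):
    """Return dictionary[key], or None when the key is absent (illustrative sketch)."""
    return dictionary.get(key)


cfg = {"batch_size": 64}
present = get(cfg, "batch_size")   # key exists
missing = get(cfg, "num_workers")  # key absent, falls back to None
```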
- eval.utils.download(url, ckpt_dir=None)
Download a file from a URL to ckpt_dir.
- Returns:
ckpt_file, path to the downloaded checkpoint file.
- Return type:
str
- eval.utils.all_gather(x: Array) -> Array
Convenience wrapper for jax.lax.all_gather.
- Returns:
all_gathered, the gathered array from all devices.
- Return type:
jnp.ndarray
- eval.utils.build_keep_indices(item_subset: list[int], batch_size: int, len_dataset: int)
Simulate the behavior of a DataLoader that uses the item_subset sampler. The intent is to find and remove the indices of images that would be processed twice, which would otherwise bias FID.
- Returns:
final_indices, list of keep indices for each batch.
- Return type:
list
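One way to picture the deduplication idea is the sketch below. It walks through `item_subset` in batches and records, for each batch, the positions whose dataset index has not been seen before. The function name `build_keep_indices_sketch`, the modulo wrap-around for padded indices, and the "positions within the batch" output layout are all assumptions for illustration. The real `build_keep_indices` may differ.

```python
def build_keep_indices_sketch(item_subset, batch_size, len_dataset):
    """Return, per batch, the in-batch positions of first-time dataset indices."""
    seen = set()
    keep_per_batch = []
    for start in range(0, len(item_subset), batch_size):
        batch = item_subset[start:start + batch_size]
        keep = []
        for pos, idx in enumerate(batch):
            idx = idx % len_dataset  # assumed: padded indices wrap around
            if idx not in seen:
                seen.add(idx)
                keep.append(pos)
        keep_per_batch.append(keep)
    return keep_per_batch


# A dataset of 6 images padded to 8 items so it splits evenly into batches of 4.
keep = build_keep_indices_sketch(list(range(8)), batch_size=4, len_dataset=6)
```

In this toy case the second batch keeps only its first two positions, because the last two items wrap around to images already processed in the first batch.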
- eval.utils.build_eval_loader(dataset: Dataset, batch_size: int, num_workers: int = 8) -> DataLoader
Build the dataloader for evaluation.
- Returns:
(loader, keep_indices), the dataloader and keep indices.
- Return type:
tuple[torch.utils.data.DataLoader, list]