# Source code for eval.utils

"""File containing the util functions for evaluation."""

# built-in libs
import math
import os
import requests
import tempfile

# external libs
from absl import logging
import flax
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import scipy
import torch
from tqdm import tqdm

# deps
from eval import inception
from samplers import samplers


def get(dictionary, key):
    """Look up ``key`` in ``dictionary``, tolerating a missing mapping or key.

    Args:
        dictionary: a mapping supporting ``in`` and ``[]``, or None.
        key: the key to look up.

    Returns:
        - Any: value, the stored value, or None when ``dictionary`` is None
          or does not contain ``key``.
    """
    # Membership is tested before indexing so mappings with default factories
    # (e.g. defaultdict) are not mutated by the lookup.
    has_key = dictionary is not None and key in dictionary
    return dictionary[key] if has_key else None
def download(url, ckpt_dir=None):
    """Download a file from a URL into ``<ckpt_dir>/jax_fid`` unless already cached.

    The local file name is the last path component of ``url`` with any query
    string (``?...``) removed. The payload is streamed into a ``.temp`` file
    first and only moved to its final name once the number of received bytes
    matches the server's ``content-length`` header (when one is sent).

    Args:
        url: URL of the file to fetch.
        ckpt_dir: parent directory for the cache; defaults to the system temp
            directory. A ``jax_fid`` subdirectory is always appended.

    Returns:
        - str: ckpt_file, path to the downloaded checkpoint file. If the
          download was incomplete, the temp file is removed, a message is
          printed, and this path will not exist.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    # Strip the query string before taking the basename: str.rfind('?') returns
    # -1 for URLs without a query, which used to chop off the last character.
    base_url = url.split('?', 1)[0]
    name = base_url.rsplit('/', 1)[-1]
    if ckpt_dir is None:
        ckpt_dir = tempfile.gettempdir()
    ckpt_dir = os.path.join(ckpt_dir, 'jax_fid')
    ckpt_file = os.path.join(ckpt_dir, name)
    if os.path.exists(ckpt_file):
        return ckpt_file

    print(f'Downloading: "{base_url}" to {ckpt_file}')
    # exist_ok avoids a race when several processes create the directory.
    os.makedirs(ckpt_dir, exist_ok=True)
    # first create temp file, in case the download fails
    ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
    with requests.get(url, stream=True) as response:
        # Fail loudly instead of caching an HTTP error page as the checkpoint.
        response.raise_for_status()
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:
            with open(ckpt_file_temp, 'wb') as file:
                for data in response.iter_content(chunk_size=1024):
                    progress_bar.update(len(data))
                    file.write(data)
            received_bytes = progress_bar.n

    if total_size_in_bytes != 0 and received_bytes != total_size_in_bytes:
        print('An error occured while downloading, please try again.')
        if os.path.exists(ckpt_file_temp):
            os.remove(ckpt_file_temp)
    else:
        # if download was successful, move the temp file into place
        os.replace(ckpt_file_temp, ckpt_file)
    return ckpt_file
def _all_gather_over_data_axis(x: jnp.ndarray) -> jnp.ndarray:
    """Per-device body: concatenate ``x`` from every device on the 'data' axis."""
    return jax.lax.all_gather(x, axis_name='data', tiled=True)


# Built once so every all_gather() call reuses the same pmapped function and
# therefore JAX's compilation cache; creating a fresh lambda + pmap per call
# forced a retrace/recompile on every invocation.
_pmapped_all_gather = jax.pmap(_all_gather_over_data_axis, axis_name='data')


def all_gather(x: jnp.ndarray) -> jnp.ndarray:
    """Convenient wrapper for ``jax.lax.all_gather`` over the 'data' pmap axis.

    Args:
        x: array whose leading dimension equals ``jax.local_device_count()``;
            row ``i`` is mapped onto local device ``i``.

    Returns:
        - jnp.ndarray: all_gathered, the per-device rows from all devices
          concatenated along the first axis (the copy from local device 0).

    Raises:
        ValueError: if ``x.shape[0]`` differs from the local device count.
    """
    num_local_devices = jax.local_device_count()
    # Explicit check (not assert) so validation survives `python -O`.
    if x.shape[0] != num_local_devices:
        raise ValueError(
            f"Expected first dimension to be the number of local devices, "
            f"got {x.shape[0]} != {num_local_devices}"
        )
    return _pmapped_all_gather(x)[0]
def lock():
    """Wait until all processes sync up before returning.

    Runs a tiny cross-device ``all_gather`` on a ones array (one row per
    local device) and blocks on its result.
    """
    token = jnp.ones((jax.local_device_count(), 1))
    gathered = all_gather(token)
    gathered.block_until_ready()
def build_keep_indices(
    item_subset: list[int], batch_size: int, len_dataset: int
):
    """Compute, per batch, which sampled items are real rather than padding.

    Simulates a DataLoader walking ``item_subset`` in chunks of
    ``batch_size``. An entry is kept only when its index is ``< len_dataset``;
    larger indices are padding (presumably wrapped back into range by the
    caller, so those images would be processed twice and bias FID). Each
    chunk is reshaped to one row per local device and all-gathered, so every
    returned mask covers that batch across all devices.

    Returns:
        - list: final_indices, one flattened boolean keep mask per batch.
    """
    valid = jnp.array(item_subset) < len_dataset
    devices_here = jax.local_device_count()
    return [
        all_gather(
            valid[start:start + batch_size].reshape((devices_here, -1))
        ).reshape(-1)
        for start in range(0, valid.shape[0], batch_size)
    ]
def build_eval_loader(
    dataset: torch.utils.data.Dataset,
    batch_size: int,
    num_workers: int = 8,
) -> tuple[torch.utils.data.DataLoader, list]:
    """Build this process's evaluation DataLoader and its per-batch keep masks.

    The index space is padded up to a multiple of ``batch_size`` (multiplied
    by ``jax.local_device_count()`` when ``batch_size // jax.process_count()``
    is not divisible by it), then dealt round-robin across processes: process
    ``p`` gets indices ``p, p + n, p + 2n, ...``. ``build_keep_indices`` marks
    which of those are padding (``>= len(dataset)``), and padded indices are
    then wrapped back into range with ``% len(dataset)`` so the loader can
    fetch them.

    Args:
        dataset: dataset to evaluate, indexed by the integer sampler.
        batch_size: batch size passed to the DataLoader on this process.
        num_workers: number of DataLoader worker processes; 0 loads data in
            the main process.

    Returns:
        - tuple[torch.utils.data.DataLoader, list]: (loader, keep_indices),
          the dataloader and keep indices.
    """
    dataset_len = len(dataset)
    n = jax.process_count()
    pad_factor = batch_size
    # pad the dataset to be divisible by the batch size and local_device_count:
    if (pad_factor // n) % jax.local_device_count() != 0:
        pad_factor *= jax.local_device_count()
    dataset_len = int(math.ceil(dataset_len / pad_factor)) * pad_factor
    item_subset = [(i * n + jax.process_index()) for i in range((dataset_len - 1) // n + 1)]
    keep_indices = build_keep_indices(item_subset, batch_size, len(dataset))
    # wrap padded indices back into range so the loader can fetch them
    item_subset = [i % len(dataset) for i in item_subset]
    uses_workers = num_workers > 0
    loader = torch.utils.data.DataLoader(
        dataset,
        sampler=item_subset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        shuffle=False,  # important
        worker_init_fn=None,
        # DataLoader rejects persistent_workers=True with num_workers == 0, and
        # its single-process iterator requires timeout == 0.
        persistent_workers=uses_workers,
        timeout=60.0 if uses_workers else 0,
    )
    return loader, keep_indices
def get_detector(config: ml_collections.ConfigDict):
    """Get the feature detector used for FID evaluation.

    Only ``config.eval.detector == 'inception'`` is supported. It builds a
    pretrained InceptionV3, initializes and replicates its parameters, and
    returns a pmapped forward function. That function rescales inputs with
    ``x / 127.5 - 1``, resizes them to 299x299, extracts features, and
    all-gathers the features over the 'data' axis.

    Returns:
        - tuple[dict, Callable]: (params, forward), detector parameters and forward function.

    Raises:
        NotImplementedError: if ``config.eval.detector`` is not 'inception'.
    """
    detector_name = config.eval.detector
    if detector_name != 'inception':
        # TODO: add DINOv2
        raise NotImplementedError(
            f"Unsupported FID detector: {detector_name!r} (only 'inception' is implemented)"
        )

    logging.info('Loading InceptionV3 model for FID calculation...')
    detector = inception.InceptionV3(pretrained=True)

    def inception_forward(
        renormalize_data: bool = False,
        run_all_gather: bool = True
    ):
        """Forward pass of the inception model to extract features.

        Returns replicated params and a pmapped ``forward(params, x)``.
        """
        params = detector.init(jax.random.PRNGKey(0), jnp.ones((1, 299, 299, 3)))
        params = flax.jax_utils.replicate(params)

        def forward(params, x):
            if renormalize_data:
                # NOTE(review): assumes pixel values in [0, 255]; maps them to [-1, 1].
                x = x.astype(jnp.float32) / 127.5 - 1
            # TODO: ablate following resize choices
            # x is resized as (batch, H, W, C) -> (batch, 299, 299, C).
            x = jax.image.resize(x, (x.shape[0], 299, 299, x.shape[-1]), method='bilinear')
            # presumably drops size-1 spatial dims of the pooled output — confirm.
            features = detector.apply(params, x, train=False).squeeze(axis=(1, 2))
            if run_all_gather:
                features = jax.lax.all_gather(features, axis_name='data', tiled=True)
            return features

        return params, jax.pmap(forward, axis_name='data')

    params, forward = inception_forward(renormalize_data=True, run_all_gather=True)
    logging.info('InceptionV3 model loaded.')
    return params, forward
def calculate_fid(
    stats: dict[str, np.ndarray],
    ref_stats: dict[str, np.ndarray],
    eps: float = 1e-6,
) -> float:
    """Calculate the FID score between stats and ref_stats.

    FID = ||mu_1 - mu_2||^2 + Tr(sigma_1 + sigma_2 - 2 * sqrtm(sigma_1 @ sigma_2))

    Args:
        stats: dict with 'mu' (mean vector) and 'sigma' (covariance matrix).
        ref_stats: reference statistics in the same format.
        eps: diagonal offset added to both covariances, used only when the
            matrix square root of their product is not finite (e.g. singular
            covariances estimated from few samples).

    Returns:
        - float: fid_score, the calculated FID score.
    """
    # Explicit submodule import: a bare `import scipy` does not load
    # scipy.linalg on older SciPy versions.
    import scipy.linalg

    mu1, sigma1 = stats['mu'], stats['sigma']
    mu2, sigma2 = ref_stats['mu'], ref_stats['sigma']
    mean_term = np.square(mu1 - mu2).sum()
    covmean, _ = scipy.linalg.sqrtm(np.dot(sigma1, sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Near-singular product: regularize the diagonal and retry instead of
        # returning NaN/inf.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = scipy.linalg.sqrtm(
            np.dot(sigma1 + offset, sigma2 + offset), disp=False
        )
    # sqrtm may return tiny imaginary parts from numerical error; keep the real part.
    return float(np.real(mean_term + np.trace(sigma1 + sigma2 - covmean * 2)))