"""File containing the util functions for evaluation."""
# built-in libs
import math
import os
import requests
import tempfile
# external libs
from absl import logging
import flax
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import scipy
import torch
from tqdm import tqdm
# deps
from eval import inception
from samplers import samplers
def get(dictionary, key):
    """Get a value from a dictionary. If the key is not present, default to None.
    Returns:
        - Any: value, the value from the dictionary or None if not present.
    """
    if dictionary is None or key not in dictionary:
        return None
    return dictionary[key]
def download(url, ckpt_dir=None):
    """Download a file from a URL to ckpt_dir.
    Returns:
        - str: ckpt_file, path to the downloaded checkpoint file.
    """
    # strip an optional query string ("?...") from the URL to recover the file name
    query_pos = url.rfind('?')
    if query_pos == -1:
        query_pos = len(url)
    name = url[url.rfind('/') + 1 : query_pos]
    if ckpt_dir is None:
        ckpt_dir = tempfile.gettempdir()
    ckpt_dir = os.path.join(ckpt_dir, 'jax_fid')
    ckpt_file = os.path.join(ckpt_dir, name)
    if not os.path.exists(ckpt_file):
        print(f'Downloading: "{url[:query_pos]}" to {ckpt_file}')
        os.makedirs(ckpt_dir, exist_ok=True)
        response = requests.get(url, stream=True)
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
        # write to a temp file first, so a failed download never leaves a partial checkpoint
        ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
        with open(ckpt_file_temp, 'wb') as file:
            for data in response.iter_content(chunk_size=1024):
                progress_bar.update(len(data))
                file.write(data)
        progress_bar.close()
        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
            if os.path.exists(ckpt_file_temp):
                os.remove(ckpt_file_temp)
            raise IOError('An error occurred while downloading, please try again.')
        # the download was successful, so rename the temp file into place
        os.rename(ckpt_file_temp, ckpt_file)
    return ckpt_file
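# Usage sketch (the URL below is a placeholder, not a real checkpoint location):
#   ckpt_file = download('https://example.com/pt_inception-2015-12-05.ckpt?dl=1')
#   # a second call with the same URL hits the cached file and skips the download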
def all_gather(x: jnp.ndarray) -> jnp.ndarray:
    """Convenience wrapper for jax.lax.all_gather.
    Returns:
        - jnp.ndarray: all_gathered, the gathered array from all devices.
    """
    assert x.shape[0] == jax.local_device_count(), (
        f'Expected first dimension to be the number of local devices, '
        f'got {x.shape[0]} != {jax.local_device_count()}')
    all_gather_fn = lambda x: jax.lax.all_gather(x, axis_name='data', tiled=True)
    # every device returns an identical gathered array, so keep the first copy
    all_gathered = jax.pmap(all_gather_fn, axis_name='data')(x)[0]
    return all_gathered
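# Usage sketch: gather per-device feature shards into one array, e.g.
#   x = jnp.zeros((jax.local_device_count(), 8, 2048))
#   gathered = all_gather(x)  # shape: (jax.device_count() * 8, 2048)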
def lock():
    """Block until all processes sync up (a cross-process barrier built on all_gather)."""
    all_gather(jnp.ones((jax.local_device_count(), 1))).block_until_ready()
def build_keep_indices(
    item_subset: list[int],
    batch_size: int,
    len_dataset: int
):
    """
    Simulate the behavior of a DataLoader driven by the item_subset sampler.
    The intent is to find the indices of images that would be processed twice
    (due to padding) and mark them for removal, to avoid biasing FID.
    Returns:
        - list: final_indices, list of keep masks (one boolean array per batch).
    """
    # an index is genuine iff it falls inside the un-padded dataset
    keep_indices = jnp.array(item_subset) < len_dataset
    final_indices = []
    for batch_start in range(0, keep_indices.shape[0], batch_size):
        indices_in = keep_indices[batch_start:batch_start + batch_size].reshape(
            (jax.local_device_count(), -1))
        # gather the per-process masks so every process sees the global mask for this batch
        batch_keep_indices = all_gather(indices_in).reshape(-1)
        final_indices.append(batch_keep_indices)
    return final_indices
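# Illustration (hypothetical numbers): with len_dataset=10 padded to 12 on a single
# process, item_subset runs 0..11 and entries 10 and 11 wrap around to images 0 and 1;
# the returned masks mark those two slots False, so the duplicated features are dropped
# before computing FID statistics.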
def build_eval_loader(
    dataset: torch.utils.data.Dataset,
    batch_size: int,
    num_workers: int = 8,
) -> tuple[torch.utils.data.DataLoader, list]:
    """Build the dataloader for evaluation.
    Returns:
        - tuple[torch.utils.data.DataLoader, list]: (loader, keep_indices), the dataloader and keep indices.
    """
    dataset_len = len(dataset)
    n = jax.process_count()
    pad_factor = batch_size
    # pad the dataset to be divisible by the batch size and local_device_count
    if (pad_factor // n) % jax.local_device_count() != 0:
        pad_factor *= jax.local_device_count()
    dataset_len = int(math.ceil(dataset_len / pad_factor)) * pad_factor
    # round-robin shard the (padded) index range across processes
    item_subset = [(i * n + jax.process_index()) for i in range((dataset_len - 1) // n + 1)]
    keep_indices = build_keep_indices(item_subset, batch_size, len(dataset))
    # wrap padded indices back into range; the duplicates are masked out via keep_indices
    item_subset = [i % len(dataset) for i in item_subset]
    loader = torch.utils.data.DataLoader(
        dataset,
        sampler=item_subset,  # a plain list is a valid Iterable sampler
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        shuffle=False,  # important: keep_indices assumes this exact order
        worker_init_fn=None,
        persistent_workers=True,
        timeout=60.0,
    )
    return loader, keep_indices
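# Usage sketch (hypothetical extraction loop; `shard` is a stand-in for reshaping a
# host batch to (local_device_count, per_device_batch, ...)):
#   loader, keep = build_eval_loader(dataset, batch_size=256)
#   for batch, mask in zip(loader, keep):
#       feats = forward(params, shard(batch))[0]  # gathered over all devices
#       all_feats.append(np.asarray(feats)[np.asarray(mask)])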
def get_detector(config: ml_collections.ConfigDict):
    """Get the feature detector for FID evaluation.
    Returns:
        - tuple[dict, Callable]: (params, forward), detector parameters and forward function.
    """
    if config.eval.detector == 'inception':
        logging.info('Loading InceptionV3 model for FID calculation...')
        detector = inception.InceptionV3(pretrained=True)
        def inception_forward(
            renormalize_data: bool = False,
            run_all_gather: bool = True
        ):
            """Forward pass of the inception model to extract features."""
            params = detector.init(jax.random.PRNGKey(0), jnp.ones((1, 299, 299, 3)))
            params = flax.jax_utils.replicate(params)
            def forward(params, x):
                if renormalize_data:
                    # map uint8 pixels in [0, 255] to [-1, 1]
                    x = x.astype(jnp.float32) / 127.5 - 1
                # TODO: ablate the following resize choices
                x = jax.image.resize(x, (x.shape[0], 299, 299, x.shape[-1]), method='bilinear')
                features = detector.apply(params, x, train=False).squeeze(axis=(1, 2))
                if run_all_gather:
                    features = jax.lax.all_gather(features, axis_name='data', tiled=True)
                return features
            return params, jax.pmap(forward, axis_name='data')
        params, forward = inception_forward(renormalize_data=True, run_all_gather=True)
        logging.info('InceptionV3 model loaded.')
        return params, forward
    else:
        # TODO: add DINOv2
        raise NotImplementedError(f'Unknown detector: {config.eval.detector}')
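# Usage sketch (assumes config.eval.detector == 'inception'):
#   params, forward = get_detector(config)
#   # batch: uint8 images of shape (local_device_count, per_device_batch, H, W, 3)
#   feats = forward(params, batch)[0]  # (global_batch, 2048) pooled features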
def calculate_fid(
    stats: dict[str, np.ndarray],
    ref_stats: dict[str, np.ndarray]
) -> float:
    """Calculate the FID score between stats and ref_stats.
    Returns:
        - float: fid_score, the calculated FID score.
    """
    # FID = ||mu - mu_ref||^2 + Tr(sigma + sigma_ref - 2 * (sigma @ sigma_ref)^(1/2))
    m = np.square(stats['mu'] - ref_stats['mu']).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(stats['sigma'], ref_stats['sigma']), disp=False)
    # sqrtm can return a complex matrix with tiny imaginary parts; keep only the real part
    return float(np.real(m + np.trace(stats['sigma'] + ref_stats['sigma'] - s * 2)))
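# Usage sketch (feats / ref_feats: hypothetical float arrays of shape (N, 2048)):
#   stats = {'mu': feats.mean(axis=0), 'sigma': np.cov(feats, rowvar=False)}
#   ref_stats = {'mu': ref_feats.mean(axis=0), 'sigma': np.cov(ref_feats, rowvar=False)}
#   fid = calculate_fid(stats, ref_stats)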