Source code for utils.gcloud_utils
"""File containing helper functions for accessing Google Cloud Storage."""
# built-in libs
# external libs
# https://stackoverflow.com/a/59008580
from google.api_core import page_iterator
from google.cloud import storage
# deps
def _item_to_value(iterator, item):
    """:meta private:"""
    return item
[docs]
def list_directories(bucket_name, prefix):
    """Return the "directory" prefixes found directly under ``prefix``.

    The bucket's object listing is requested with ``/`` as delimiter and the
    entries of the response's ``prefixes`` key are collected into a list.

    Args:
        bucket_name: Name of the bucket to query.
        prefix: Object-name prefix to list under. A trailing ``/`` is
            appended when a non-empty prefix lacks one.

    Returns:
        A list with every item yielded by the paged ``prefixes`` listing.
    """
    # Make sure a non-empty prefix addresses a "directory", i.e. ends in '/'.
    needs_slash = bool(prefix) and not prefix.endswith('/')
    if needs_slash:
        prefix += '/'

    query = dict(projection="noAcl", prefix=prefix, delimiter='/')

    client = storage.Client()
    objects_path = "/b/" + bucket_name + "/o"
    pages = page_iterator.HTTPIterator(
        client=client,
        api_request=client._connection.api_request,
        path=objects_path,
        items_key='prefixes',
        item_to_value=_item_to_value,
        extra_params=query,
    )
    return list(pages)
[docs]
def count_directories(bucket_name, prefix):
    """
    Return how many directories ``list_directories`` finds under ``prefix``.
    Used to obtain the numeral prefix for the checkpoint.
    """
    found = list_directories(bucket_name, prefix)
    return len(found)
[docs]
def directory_exists(bucket_name, prefix, directory):
    """Check whether the given directory exists under the given bucket.

    Each listed path is reduced to its last path component, the first four
    characters of that component (the index part) are dropped, and the
    remainder is compared against ``directory``.

    Returns:
        True if some listed directory matches ``directory``, else False.
    """
    names_without_index = [
        # Component before the trailing '/', minus its 4-character index part.
        path.split('/')[-2][4:]
        for path in list_directories(bucket_name, prefix)
    ]
    return directory in names_without_index
[docs]
def get_directory_index(bucket_name, prefix, directory):
    """Return the numeric index of ``directory`` under the given bucket.

    Listed paths are reduced to their last component; the component whose
    text after the first four characters equals ``directory`` is selected and
    its first three characters are parsed as the index.

    Raises:
        ValueError: If no listed directory matches ``directory`` (raised by
            ``list.index``) or the leading characters are not an integer.
    """
    indexed_names = [
        path.split('/')[-2] for path in list_directories(bucket_name, prefix)
    ]
    # Same names with the 4-character index part stripped, kept in the same
    # order so positions line up with ``indexed_names``.
    bare_names = [name[4:] for name in indexed_names]
    position = bare_names.index(directory)
    matched_name = indexed_names[position]
    return int(matched_name[:3])
[docs]
def list_checkpoints(bucket_name, prefix, workdir):
    """List all checkpoints in the given directory.

    Args:
        bucket_name: Name of the bucket to query.
        prefix: Object-name prefix under which the indexed work directories
            live. It may be empty and may or may not end with ``/``.
        workdir: Name of the work directory without its numeric index.

    Returns:
        A tuple ``(checkpoints, workdir_uri)`` where ``checkpoints`` is the
        list returned by ``list_directories`` for the indexed work directory
        and ``workdir_uri`` is its ``gs://<bucket>/<prefix>/<NNN>_<workdir>``
        location.

    Raises:
        ValueError: If ``workdir`` is not found under ``prefix``.
    """
    if not directory_exists(bucket_name, prefix, workdir):
        raise ValueError(f"Directory {workdir} does not exist.")

    index = get_directory_index(bucket_name, prefix, workdir)

    # Normalise the parent prefix exactly like list_directories does, so the
    # path queried here is the one the directory was found under. Blindly
    # inserting "/" produced "a//003_x" for prefix "a/" and "/003_x" for an
    # empty prefix, neither of which matches the listed directory.
    parent = prefix or ""
    if parent and not parent.endswith('/'):
        parent += '/'
    work_prefix = parent + f"{index:03d}_{workdir}"

    workdir = "gs://" + bucket_name + "/" + work_prefix
    checkpoints = list(list_directories(bucket_name, work_prefix))
    return checkpoints, workdir
[docs]
def get_checkpoint_steps(bucket_name, prefix, workdir):
    """Return the sorted step numbers of all checkpoints in ``workdir``.

    The step is taken from the text after the last ``_`` in each checkpoint
    directory name and parsed as an integer.

    Returns:
        A tuple ``(steps, workdir)`` with the steps in ascending order and
        the work directory location as returned by ``list_checkpoints``.
    """
    checkpoint_paths, workdir = list_checkpoints(bucket_name, prefix, workdir)
    steps = sorted(
        # Last path component, then the part after its final underscore.
        int(path.split('/')[-2].split('_')[-1])
        for path in checkpoint_paths
    )
    return steps, workdir