Source code for utils.gcloud_utils
"""File containing helper functions for accessing Google Cloud Storage."""
# built-in libs
# external libs
# https://stackoverflow.com/a/59008580
from google.api_core import page_iterator
from google.cloud import storage
# deps
def _item_to_value(iterator, item):
""":meta private:"""
return item
[docs]
def list_directories(bucket_name, prefix):
    """Return the directory-like prefixes found directly under ``prefix``.

    The bucket's object listing is requested with ``'/'`` as delimiter and
    the entries under the response's ``prefixes`` key are collected.

    Args:
        bucket_name: Name of the bucket to query.
        prefix: Path inside the bucket to list. A trailing ``'/'`` is
            appended when a non-empty prefix lacks one.

    Returns:
        A list with every item yielded by the listing, unmodified.
    """
    # Normalise so that "a/b" and "a/b/" query the same location.
    if prefix and not prefix.endswith('/'):
        prefix += '/'

    query = {"projection": "noAcl", "prefix": prefix, "delimiter": '/'}

    client = storage.Client()
    objects_path = "/b/" + bucket_name + "/o"
    # NOTE: relies on the client's private ``_connection`` attribute to issue
    # the raw API request.
    pages = page_iterator.HTTPIterator(
        client=client,
        api_request=client._connection.api_request,
        path=objects_path,
        items_key='prefixes',
        item_to_value=_item_to_value,
        extra_params=query,
    )
    return list(pages)
[docs]
def count_directories(bucket_name, prefix):
    """Return how many directories ``list_directories`` finds under ``prefix``.

    Used to obtain the numeral prefix for the checkpoint.

    Args:
        bucket_name: Name of the bucket to query.
        prefix: Path inside the bucket whose directories are counted.

    Returns:
        The number of listed directories as an ``int``.
    """
    found = list_directories(bucket_name, prefix)
    return len(found)
[docs]
def directory_exists(bucket_name, prefix, directory):
    """Check whether the given directory exists under the given bucket.

    Each listed path is reduced to its last directory component and the
    first four characters of that component (the index part) are dropped
    before comparing against ``directory``.

    Args:
        bucket_name: Name of the bucket to query.
        prefix: Path inside the bucket whose directories are inspected.
        directory: Directory name to look for, without its index part.

    Returns:
        ``True`` if a listed directory matches ``directory``, else ``False``.
    """
    unindexed_names = [
        # [-2] is the final component because listed paths end with '/';
        # [4:] removes the index part.
        path.split('/')[-2][4:]
        for path in list_directories(bucket_name, prefix)
    ]
    return directory in unindexed_names
[docs]
def get_directory_index(bucket_name, prefix, directory):
    """Get the index of the given directory under the given bucket.

    Args:
        bucket_name: Name of the bucket to query.
        prefix: Path inside the bucket whose directories are inspected.
        directory: Directory name to look for, without its index part.

    Returns:
        The first three characters of the matching directory's name,
        converted to ``int``.

    Raises:
        ValueError: If no listed directory matches ``directory``.
    """
    listed_paths = list_directories(bucket_name, prefix)
    # Last path component of each entry; listed paths end with '/'.
    indexed_names = [path.split('/')[-2] for path in listed_paths]
    # Position of the match once the four-character index part is removed.
    position = [name[4:] for name in indexed_names].index(directory)
    matched_name = indexed_names[position]
    return int(matched_name[:3])
[docs]
def list_checkpoints(bucket_name, prefix, workdir):
    """List all checkpoints in the given working directory.

    The working directory is looked up under ``prefix`` by its un-indexed
    name and then addressed as ``<index>_<workdir>`` with a zero-padded
    three-digit index.

    Args:
        bucket_name: Name of the bucket to query.
        prefix: Path inside the bucket that holds the indexed working
            directories. It may be empty and may or may not end with ``'/'``.
        workdir: Name of the working directory, without its index part.

    Returns:
        A tuple ``(checkpoints, workdir_uri)`` where ``checkpoints`` is the
        list returned by ``list_directories`` for the working directory and
        ``workdir_uri`` is its ``gs://<bucket>/<path>`` location.

    Raises:
        ValueError: If ``workdir`` does not exist under ``prefix``.
    """
    if not directory_exists(bucket_name, prefix, workdir):
        raise ValueError(f"Directory {workdir} does not exist.")
    index = get_directory_index(bucket_name, prefix, workdir)
    indexed_name = f"{index:03d}_{workdir}"

    # Join with the same trailing-slash rule that list_directories applies to
    # its prefix, so the path built here is the one the lookup above actually
    # searched. Blindly inserting '/' produced "a//001_x" for a prefix ending
    # in '/' and a leading '/' for an empty prefix.
    base = prefix or ""
    if base and not base.endswith('/'):
        base += '/'
    work_prefix = base + indexed_name

    workdir_uri = "gs://" + bucket_name + "/" + work_prefix
    return list_directories(bucket_name, work_prefix), workdir_uri
[docs]
def get_checkpoint_steps(bucket_name, prefix, workdir):
    """Get the postfixed steps of all checkpoints in the given directory.

    Args:
        bucket_name: Name of the bucket to query.
        prefix: Path inside the bucket that holds the working directories.
        workdir: Name of the working directory, without its index part.

    Returns:
        A tuple ``(steps, workdir)``: the checkpoint steps as ``int`` in
        ascending order, and the second value returned by
        ``list_checkpoints`` unchanged.
    """
    checkpoint_paths, workdir = list_checkpoints(bucket_name, prefix, workdir)
    # Last path component of each entry; listed paths end with '/'.
    checkpoint_names = [path.split('/')[-2] for path in checkpoint_paths]
    # The step is whatever follows the final underscore of the name.
    steps = [int(name.split('_')[-1]) for name in checkpoint_names]
    steps.sort()
    return steps, workdir