Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion checkpoint/orbax/checkpoint/_src/path/storage_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from absl import logging
from etils import epath
from orbax.checkpoint._src.path import atomicity_types
from orbax.checkpoint._src.path import step as step_lib


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -76,7 +77,7 @@ class CheckpointReadOptions:
enable_strong_reads: bool = False


class StorageBackend(abc.ABC):
class StorageBackend[MetadataT: step_lib.Metadata](abc.ABC):
"""An abstract base class for a storage backend.

This class defines a common interface for managing checkpoint paths in
Expand All @@ -91,6 +92,35 @@ def list_checkpoints(
"""Lists checkpoints for a given base path and version pattern."""
raise NotImplementedError('Subclasses must provide implementation')

@abc.abstractmethod
def list_metadata(
self,
*,
base_path: epath.PathLike,
step_prefix: str,
include_uncommitted: bool,
step: int | None = None,
index: int | None = None,
glob_pattern: str | None = None,
options: CheckpointReadOptions | None = None,
) -> list[MetadataT]:
"""Low-level method to return metadata based on steps, index and glob.

Args:
base_path: Base path of a checkpoint sequence.
step_prefix: Prefix of the step name. Usually `step`.
include_uncommitted: Some backends distinguish commited from uncommited
checkpoints. If supported, this flag controls whether we skip or include
uncommitted checkpoints.
step: If provided, filters the results by step.
index: If provided, filters the results by index.
glob_pattern: A glob pattern to filter the results by, based on the
convention on file-naming. Glob can be e.g. '*.step_*' and these are not
regex.
options: Options for reading checkpoints.
"""
raise NotImplementedError('Subclasses must provide implementation')

@abc.abstractmethod
def get_temporary_path_class(self) -> type[atomicity_types.TemporaryPath]:
"""Returns a TemporaryPath class for the storage backend."""
Expand Down Expand Up @@ -125,6 +155,22 @@ def list_checkpoints(
'list_checkpoints is not yet implemented for GCSStorageBackend.'
)

def list_metadata(
self,
base_path: epath.PathLike,
step_prefix: str,
*,
include_uncommitted: bool,
step: int | None = None,
index: int | None = None,
glob_pattern: str | None = None,
options: CheckpointReadOptions | None = None,
) -> list[step_lib.Metadata]:
"""Returns a list of IndexPrefixMetadata for a given path."""
raise NotImplementedError(
'list_metadata is not yet implemented for LocalStorageBackend.'
)

def delete_checkpoint(
self,
checkpoint_path: str | epath.PathLike,
Expand Down Expand Up @@ -157,6 +203,22 @@ def list_checkpoints(
'list_checkpoints is not yet implemented for LocalStorageBackend.'
)

def list_metadata(
self,
base_path: epath.PathLike,
step_prefix: str,
*,
include_uncommitted: bool,
step: int | None = None,
index: int | None = None,
glob_pattern: str | None = None,
options: CheckpointReadOptions | None = None,
) -> list[step_lib.Metadata]:
"""Returns a list of IndexPrefixMetadata for a given path."""
raise NotImplementedError(
'list_metadata is not yet implemented for LocalStorageBackend.'
)

def delete_checkpoint(
self,
checkpoint_path: str | epath.PathLike,
Expand Down
Loading