From 66c58929778a02f26ba0bafdd2512294b32c5fc4 Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Lespiau
Date: Thu, 25 Jun 2026 01:50:30 -0700
Subject: [PATCH] No public description

PiperOrigin-RevId: 937837495
---
 .../checkpoint/_src/path/storage_backend.py   | 64 ++++++++++++++++++-
 1 file changed, 63 insertions(+), 1 deletion(-)

diff --git a/checkpoint/orbax/checkpoint/_src/path/storage_backend.py b/checkpoint/orbax/checkpoint/_src/path/storage_backend.py
index 961994c824..96deceadb3 100644
--- a/checkpoint/orbax/checkpoint/_src/path/storage_backend.py
+++ b/checkpoint/orbax/checkpoint/_src/path/storage_backend.py
@@ -26,6 +26,7 @@
 from absl import logging
 from etils import epath
 from orbax.checkpoint._src.path import atomicity_types
+from orbax.checkpoint._src.path import step as step_lib
 
 
 @dataclasses.dataclass(frozen=True)
@@ -76,7 +77,7 @@ class CheckpointReadOptions:
   enable_strong_reads: bool = False
 
 
-class StorageBackend(abc.ABC):
+class StorageBackend[MetadataT: step_lib.Metadata](abc.ABC):
   """An abstract base class for a storage backend.
 
   This class defines a common interface for managing checkpoint paths in
@@ -91,6 +92,35 @@ def list_checkpoints(
     """Lists checkpoints for a given base path and version pattern."""
     raise NotImplementedError('Subclasses must provide implementation')
 
+  @abc.abstractmethod
+  def list_metadata(
+      self,
+      *,
+      base_path: epath.PathLike,
+      step_prefix: str,
+      include_uncommitted: bool,
+      step: int | None = None,
+      index: int | None = None,
+      glob_pattern: str | None = None,
+      options: CheckpointReadOptions | None = None,
+  ) -> list[MetadataT]:
+    """Low-level method to return metadata based on steps, index and glob.
+
+    Args:
+      base_path: Base path of a checkpoint sequence.
+      step_prefix: Prefix of the step name. Usually `step`.
+      include_uncommitted: Some backends distinguish committed from uncommitted
+        checkpoints. If supported, this flag controls whether we skip or include
+        uncommitted checkpoints.
+      step: If provided, filters the results by step.
+      index: If provided, filters the results by index.
+      glob_pattern: A glob pattern to filter the results by, based on the
+        convention on file-naming. Glob can be e.g. '*.step_*' and these are not
+        regex.
+      options: Options for reading checkpoints.
+    """
+    raise NotImplementedError('Subclasses must provide implementation')
+
   @abc.abstractmethod
   def get_temporary_path_class(self) -> type[atomicity_types.TemporaryPath]:
     """Returns a TemporaryPath class for the storage backend."""
@@ -125,6 +155,22 @@ def list_checkpoints(
         'list_checkpoints is not yet implemented for GCSStorageBackend.'
     )
 
+  def list_metadata(
+      self,
+      base_path: epath.PathLike,
+      step_prefix: str,
+      *,
+      include_uncommitted: bool,
+      step: int | None = None,
+      index: int | None = None,
+      glob_pattern: str | None = None,
+      options: CheckpointReadOptions | None = None,
+  ) -> list[step_lib.Metadata]:
+    """Returns a list of IndexPrefixMetadata for a given path."""
+    raise NotImplementedError(
+        'list_metadata is not yet implemented for GCSStorageBackend.'
+    )
+
   def delete_checkpoint(
       self,
       checkpoint_path: str | epath.PathLike,
@@ -157,6 +203,22 @@ def list_checkpoints(
         'list_checkpoints is not yet implemented for LocalStorageBackend.'
     )
 
+  def list_metadata(
+      self,
+      base_path: epath.PathLike,
+      step_prefix: str,
+      *,
+      include_uncommitted: bool,
+      step: int | None = None,
+      index: int | None = None,
+      glob_pattern: str | None = None,
+      options: CheckpointReadOptions | None = None,
+  ) -> list[step_lib.Metadata]:
+    """Returns a list of IndexPrefixMetadata for a given path."""
+    raise NotImplementedError(
+        'list_metadata is not yet implemented for LocalStorageBackend.'
+    )
+
   def delete_checkpoint(
       self,
       checkpoint_path: str | epath.PathLike,