diff --git a/checkpoint/orbax/checkpoint/_src/path/deleter.py b/checkpoint/orbax/checkpoint/_src/path/deleter.py index 5e19800f4..e534759bd 100644 --- a/checkpoint/orbax/checkpoint/_src/path/deleter.py +++ b/checkpoint/orbax/checkpoint/_src/path/deleter.py @@ -27,8 +27,9 @@ import jax from orbax.checkpoint._src.logging import event_tracking from orbax.checkpoint._src.multihost import multihost -from orbax.checkpoint._src.path import gcs_utils -from orbax.checkpoint._src.path import step as step_lib +from orbax.checkpoint._src.path import storage_backend +from orbax.checkpoint._src.path import step as step_lib # pylint: disable=g-bad-import-order + urlparse = parse.urlparse @@ -95,11 +96,16 @@ def _rmtree(self, path: epath.Path): Args: path: the path to delete. """ - # TODO(b/493110683): Cleanup with refactoring of HNS GCS logic into - # StorageBackend. - if gcs_utils.is_gcs_path(path): - gcs_utils.rmtree(path) - else: + try: + backend = storage_backend.resolve_storage_backend(path) + backend.delete_checkpoint(path) + except Exception as e: # pylint: disable=broad-except + logging.warning( + 'Failed to delete %s via storage backend: %s. Falling back to' + ' standard rmtree.', + path, + e, + ) path.rmtree() def delete(self, step: int) -> None: @@ -141,9 +147,10 @@ def delete(self, step: int) -> None: ) return + backend = storage_backend.resolve_storage_backend(self._directory) # Attempt to rename using GCS HNS API if configured. if self._todelete_full_path is not None: - if gcs_utils.is_gcs_path(self._directory): + if isinstance(backend, storage_backend.GCSStorageBackend): # This is recommended for GCS buckets with HNS enabled and requires # `_todelete_full_path` to be specified. self._gcs_rename_step(step, delete_target) @@ -151,8 +158,8 @@ def delete(self, step: int) -> None: raise NotImplementedError() # Attempt to rename to local subdirectory using `todelete_subdir` # if configured. - elif self._todelete_subdir is not None and not gcs_utils.is_gcs_path( - self._directory + elif self._todelete_subdir is not None and not isinstance( + backend, storage_backend.GCSStorageBackend ): self._rename_step_to_subdir(step, delete_target) # The final case: fall back to permanent deletion. @@ -228,6 +235,24 @@ def _rename_step_to_subdir(self, step: int, delete_target: epath.Path): def _delete_step_permanently(self, step: int, delete_target: epath.Path): """Permanently deletes a step directory.""" + try: + backend = storage_backend.resolve_storage_backend(delete_target) + backend.delete_checkpoint(delete_target) + logging.info('Deleted step %d via storage backend.', step) + return + except NotImplementedError: + logging.info( + 'Storage backend delete not implemented for %s. Falling back to' + ' standard deletion.', + delete_target, + ) + except Exception as e: # pylint: disable=broad-except + logging.warning( + 'Failed to delete step %d via storage backend: %s. Falling back to' + ' standard deletion.', + step, + e, + ) self._rmtree(delete_target) logging.info('Deleted step %d.', step) diff --git a/checkpoint/orbax/checkpoint/_src/path/storage_backend.py b/checkpoint/orbax/checkpoint/_src/path/storage_backend.py index 961994c82..7bd4c8442 100644 --- a/checkpoint/orbax/checkpoint/_src/path/storage_backend.py +++ b/checkpoint/orbax/checkpoint/_src/path/storage_backend.py @@ -22,6 +22,7 @@ import abc import dataclasses import enum +from typing import Callable from absl import logging from etils import epath @@ -130,9 +131,8 @@ def delete_checkpoint( checkpoint_path: str | epath.PathLike, ) -> None: """Deletes the checkpoint at the given path.""" - raise NotImplementedError( - 'delete_checkpoint is not yet implemented for GCSStorageBackend.' - ) + from orbax.checkpoint._src.path import gcs_utils # pylint: disable=g-import-not-at-top + gcs_utils.rmtree(epath.Path(checkpoint_path)) class LocalStorageBackend(StorageBackend): @@ -167,3 +167,32 @@ def delete_checkpoint( logging.info('Removed old checkpoint (%s)', checkpoint_path) except OSError: logging.exception('Failed to remove checkpoint (%s)', checkpoint_path) + + +ResolverFn = Callable[[str | epath.PathLike], StorageBackend] + +_RESOLVER_FN: ResolverFn | None = None + + +def register_resolver(resolver_fn: ResolverFn) -> None: + """Registers a custom storage backend resolver.""" + global _RESOLVER_FN + _RESOLVER_FN = resolver_fn + + +def resolve_storage_backend( + path: str | epath.PathLike, +) -> StorageBackend: + """Returns a StorageBackend object based on the given path.""" + if _RESOLVER_FN is not None: + try: + return _RESOLVER_FN(path) + except (ValueError, NotImplementedError): + # If the registered resolver doesn't support the path, fall back to + # default resolver. + pass + from orbax.checkpoint._src.path import gcs_utils # pylint: disable=g-import-not-at-top + if gcs_utils.is_gcs_path(epath.Path(path)): + return GCSStorageBackend() + else: + return LocalStorageBackend()