Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 35 additions & 10 deletions checkpoint/orbax/checkpoint/_src/path/deleter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
import jax
from orbax.checkpoint._src.logging import event_tracking
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import gcs_utils
from orbax.checkpoint._src.path import step as step_lib
from orbax.checkpoint._src.path import storage_backend
from orbax.checkpoint._src.path import step as step_lib # pylint: disable=g-bad-import-order



urlparse = parse.urlparse
Expand Down Expand Up @@ -95,11 +96,16 @@ def _rmtree(self, path: epath.Path):
Args:
path: the path to delete.
"""
# TODO(b/493110683): Cleanup with refactoring of HNS GCS logic into
# StorageBackend.
if gcs_utils.is_gcs_path(path):
gcs_utils.rmtree(path)
else:
try:
backend = storage_backend.resolve_storage_backend(path)
backend.delete_checkpoint(path)
except Exception as e: # pylint: disable=broad-except
logging.warning(
'Failed to delete %s via storage backend: %s. Falling back to'
' standard rmtree.',
path,
e,
)
path.rmtree()

def delete(self, step: int) -> None:
Expand Down Expand Up @@ -141,18 +147,19 @@ def delete(self, step: int) -> None:
)
return

backend = storage_backend.resolve_storage_backend(self._directory)
# Attempt to rename using GCS HNS API if configured.
if self._todelete_full_path is not None:
if gcs_utils.is_gcs_path(self._directory):
if isinstance(backend, storage_backend.GCSStorageBackend):
# This is recommended for GCS buckets with HNS enabled and requires
# `_todelete_full_path` to be specified.
self._gcs_rename_step(step, delete_target)
else:
raise NotImplementedError()
# Attempt to rename to local subdirectory using `todelete_subdir`
# if configured.
elif self._todelete_subdir is not None and not gcs_utils.is_gcs_path(
self._directory
elif self._todelete_subdir is not None and not isinstance(
backend, storage_backend.GCSStorageBackend
):
self._rename_step_to_subdir(step, delete_target)
# The final case: fall back to permanent deletion.
Expand Down Expand Up @@ -228,6 +235,24 @@ def _rename_step_to_subdir(self, step: int, delete_target: epath.Path):

def _delete_step_permanently(self, step: int, delete_target: epath.Path):
  """Removes a step directory for good, trying the storage backend first.

  The storage backend resolved for `delete_target` is asked to delete the
  checkpoint. If that attempt raises, the failure is logged and deletion
  falls back to `self._rmtree`.

  Args:
    step: the step number; used only in log messages.
    delete_target: the directory of the step to delete.
  """
  try:
    storage_backend.resolve_storage_backend(delete_target).delete_checkpoint(
        delete_target
    )
  except NotImplementedError:
    # The backend does not support deletion; this is expected, so log at
    # info level only.
    logging.info(
        'Storage backend delete not implemented for %s. Falling back to'
        ' standard deletion.',
        delete_target,
    )
  except Exception as err:  # pylint: disable=broad-except
    # Any other backend failure is unexpected, but must not prevent deletion.
    logging.warning(
        'Failed to delete step %d via storage backend: %s. Falling back to'
        ' standard deletion.',
        step,
        err,
    )
  else:
    logging.info('Deleted step %d via storage backend.', step)
    return

  # Reached only when the backend attempt raised.
  self._rmtree(delete_target)
  logging.info('Deleted step %d.', step)

Expand Down
35 changes: 32 additions & 3 deletions checkpoint/orbax/checkpoint/_src/path/storage_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import abc
import dataclasses
import enum
from typing import Callable

from absl import logging
from etils import epath
Expand Down Expand Up @@ -130,9 +131,8 @@ def delete_checkpoint(
checkpoint_path: str | epath.PathLike,
) -> None:
"""Deletes the checkpoint at the given path."""
raise NotImplementedError(
'delete_checkpoint is not yet implemented for GCSStorageBackend.'
)
from orbax.checkpoint._src.path import gcs_utils # pylint: disable=g-import-not-at-top
gcs_utils.rmtree(epath.Path(checkpoint_path))


class LocalStorageBackend(StorageBackend):
Expand Down Expand Up @@ -167,3 +167,32 @@ def delete_checkpoint(
logging.info('Removed old checkpoint (%s)', checkpoint_path)
except OSError:
logging.exception('Failed to remove checkpoint (%s)', checkpoint_path)


# Signature of a pluggable resolver: it receives a checkpoint path and returns
# the StorageBackend that should handle that path.
ResolverFn = Callable[[str | epath.PathLike], StorageBackend]

# The custom resolver installed through `register_resolver`, or None when no
# custom resolver has been registered. `resolve_storage_backend` consults it
# before falling back to its built-in GCS/local selection.
_RESOLVER_FN: ResolverFn | None = None


def register_resolver(resolver_fn: ResolverFn) -> None:
  """Installs `resolver_fn` as the module-wide custom backend resolver.

  Any previously registered resolver is replaced.

  Args:
    resolver_fn: callable that maps a path to a `StorageBackend`.
  """
  # Rebind the module-level hook rather than creating a function-local name.
  global _RESOLVER_FN
  _RESOLVER_FN = resolver_fn


def resolve_storage_backend(
    path: str | epath.PathLike,
) -> StorageBackend:
  """Picks the `StorageBackend` responsible for `path`.

  A resolver registered via `register_resolver` is consulted first. If it
  raises `ValueError` or `NotImplementedError`, or if no resolver is
  registered, the built-in choice applies: `GCSStorageBackend` for GCS paths
  and `LocalStorageBackend` for everything else.

  Args:
    path: the path whose storage backend should be determined.

  Returns:
    The backend returned by the registered resolver, or otherwise a new
    `GCSStorageBackend` or `LocalStorageBackend` instance.
  """
  custom_resolver = _RESOLVER_FN
  if custom_resolver is not None:
    try:
      return custom_resolver(path)
    except (ValueError, NotImplementedError):
      # The registered resolver does not support this path; use the default
      # resolution below instead.
      pass

  from orbax.checkpoint._src.path import gcs_utils  # pylint: disable=g-import-not-at-top

  is_gcs = gcs_utils.is_gcs_path(epath.Path(path))
  return GCSStorageBackend() if is_gcs else LocalStorageBackend()
Loading