Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions checkpoint/orbax/checkpoint/_src/testing/oss/multiprocess_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from absl.testing import absltest
import jax
from jax import config
from orbax.checkpoint._src.futures import synchronization
from orbax.checkpoint._src.multihost import multihost
import portpicker

Expand Down Expand Up @@ -299,6 +300,22 @@ def _main(argv):
assert retval == 0, f"process {i} failed, return value: {retval}"


def _sync_operation_id(client, sync_key: str):
  """Advances this process's operation id to the one published by process 0.

  Process 0 writes its current operation id under `sync_key`. Every process
  (process 0 included) then reads the value back and steps its own
  `OperationIdGenerator` forward until it is no longer behind. A process whose
  id is already at or past the published value is left unchanged.

  Args:
    client: Object exposing `key_value_set` and `blocking_key_value_get`.
    sync_key: Key under which process 0 publishes its operation id.
  """
  id_generator = synchronization.OperationIdGenerator

  if jax.process_index() == 0:
    published_id = id_generator.get_current_operation_id()
    client.key_value_set(sync_key, published_id, allow_overwrite=True)

  # The second argument is the timeout for the blocking read (presumably
  # milliseconds -- confirm against the client API).
  target_id = int(client.blocking_key_value_get(sync_key, 10000))

  # Only ever moves forward: the generator is stepped one id at a time.
  while int(id_generator.get_current_operation_id()) < target_id:
    id_generator.next_operation_id()


class MultiProcessTest(absltest.TestCase):
# TODO(b/378138653) Support TPUless MultiProcessTest.

Expand All @@ -318,6 +335,8 @@ def setUp(self):
f"multiprocess_test_ensure_all_processes_arrive_at_test_case_{self._testMethodName}",
10000,
)
sync_key = f"sync_op_id_{self._testMethodName}"
_sync_operation_id(client, sync_key)

def multiprocess_create_tempdir(self, name: str | None = None) -> str:
"""Creates a temporary directory for the test."""
Expand Down
30 changes: 29 additions & 1 deletion checkpoint/orbax/checkpoint/_src/testing/oss/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from absl import app
from absl import flags
from absl import logging
import jax
from orbax.checkpoint._src.futures import synchronization
from orbax.checkpoint._src.multihost import multihost
import pytest
import yaml

Expand Down Expand Up @@ -80,6 +83,29 @@ def _find_test_path(test_file_yaml):
return None


def _sync_op_id_generator(test_file_yaml: str) -> None:
  """Synchronizes the OperationIdGenerator across processes for a test file.

  Process 0 publishes its current operation id under a key derived from
  `test_file_yaml`; every process then reads that value and advances its own
  generator until it is no longer behind. This is best-effort: if no
  distributed client is available nothing happens, and any failure is logged
  as a warning rather than raised.

  Args:
    test_file_yaml: Test target string; used to build a per-file sync key and
      to identify the file in the warning emitted on failure.
  """
  try:
    client = multihost.get_jax_distributed_client()
    if client is None:
      return

    # '/' and ':' are replaced so the target string becomes a flat key name.
    normalized_name = test_file_yaml.replace('/', '_').replace(':', '_')
    sync_key = f'sync_op_id_file_{normalized_name}'
    operation_id_generator = synchronization.OperationIdGenerator

    if jax.process_index() == 0:
      client.key_value_set(
          sync_key,
          operation_id_generator.get_current_operation_id(),
          allow_overwrite=True,
      )

    target = int(client.blocking_key_value_get(sync_key, 10000))
    while int(operation_id_generator.get_current_operation_id()) < target:
      operation_id_generator.next_operation_id()
  except Exception as sync_e:  # pylint: disable=broad-exception-caught
    # Deliberately non-fatal: a failed sync should not stop the test run.
    # Log both the file and the error so the failure can be attributed.
    logging.warning(
        'Could not synchronize OperationIdGenerator for file %s: %s',
        test_file_yaml,
        sync_e,
    )


def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
Expand Down Expand Up @@ -130,7 +156,9 @@ def main(argv: Sequence[str]) -> None:

logging.info('Running test: %s (found from %s)', test_path, test_file_yaml)
try:
exit_code = pytest.main([test_path])
_sync_op_id_generator(test_file_yaml)

exit_code = pytest.main(['--import-mode=importlib', test_path])
if exit_code == 0:
results[test_file_yaml] = 'PASSED'
logging.info('%s: PASSED', test_path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,10 @@ processes:1:
- orbax/checkpoint/_src/serialization:serialization_test
- orbax/checkpoint/experimental/emergency/multi_tier_checkpointing:pathways_process_metadata_checkpoint_handler_test
- orbax/checkpoint/experimental/emergency/multi_tier_checkpointing:pathways_replicator_checkpoint_manager_test
- orbax/checkpoint/experimental/v1/_src/emergency:deleter_test
- orbax/checkpoint/experimental/v1/_src/emergency:path_utils_test
- orbax/checkpoint:single_host_test
processes:2:
- orbax/checkpoint/_src/handlers:array_checkpoint_handler_test
- orbax/checkpoint/_src/handlers:pytree_checkpoint_handler_test
- orbax/checkpoint/_src/handlers:standard_checkpoint_handler_test
- orbax/checkpoint/_src/serialization:local_type_handlers_test
- orbax/checkpoint/_src/serialization:type_handlers_test
- orbax/checkpoint/experimental/emergency/p2p:checkpoint_manager_multiprocess_test
- orbax/checkpoint/experimental/emergency/p2p:local_multiprocess_test
- orbax/checkpoint/experimental/emergency/p2p:persistent_multiprocess_test
processes:4:
- orbax/checkpoint/_src/multihost:multihost_test
- orbax/checkpoint/_src/testing/tree_verity:checkpoint_manager_test
- orbax/checkpoint/experimental/emergency/multi_tier_checkpointing:process_metadata_checkpoint_handler_test
- orbax/checkpoint/experimental/emergency:local_checkpoint_data_debugging_test
- orbax/checkpoint/experimental/emergency:local_checkpoint_manager_test
- orbax/checkpoint/experimental/emergency:single_slice_checkpoint_manager_test
- orbax/checkpoint/testing:local_path_test
- orbax/checkpoint:checkpoint_manager_slice_test
- orbax/checkpoint:checkpoint_manager_test
150 changes: 150 additions & 0 deletions checkpoint/orbax/checkpoint/checkpoint_manager_slice_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
from etils import epath
import jax
import numpy as np
from orbax.checkpoint import args
from orbax.checkpoint import checkpoint_manager
from orbax.checkpoint import checkpoint_utils
from orbax.checkpoint import test_utils
from orbax.checkpoint import utils
from orbax.checkpoint._src.handlers import handler_registration
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.serialization import type_handler_registry
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.testing import multiprocess_test


# Handle to absl's global flag values.
FLAGS = flags.FLAGS
# Short module-level aliases for the classes exercised by the test below.
PyTreeCheckpointHandler = pytree_checkpoint_handler.PyTreeCheckpointHandler
CheckpointManager = checkpoint_manager.CheckpointManager
CheckpointManagerOptions = checkpoint_manager.CheckpointManagerOptions


@test_utils.barrier_compatible_test
class CheckpointManagerSliceTest(
    parameterized.TestCase, multiprocess_test.MultiProcessTest
):
  """Saves and restores via a CheckpointManager limited to some processes.

  `setUp` asserts the expected topology: 4 processes with 2 local devices
  each, 8 devices in total.
  """

  def setUp(self):
    """Checks the topology, creates the temp directory and syncs processes."""
    super().setUp()

    if not multihost.is_runtime_to_distributed_ids_initialized():
      multihost.initialize_runtime_to_distributed_ids()

    self.assertEqual(jax.device_count(), 8)
    self.assertEqual(jax.process_count(), 4)
    self.assertEqual(jax.local_device_count(), 2)

    self.directory = epath.Path(
        self.multiprocess_create_tempdir(name='checkpoint_manager_slice_test')
    )
    test_utils.set_tensorstore_driver_for_test()

    test_utils.sync_global_processes(
        'CheckpointManagerSliceTest:setup_complete'
    )

  def tearDown(self):
    """Waits for every process to finish before the base-class teardown."""
    test_utils.sync_global_processes(
        'CheckpointManagerSliceTest:tests_complete'
    )
    super().tearDown()

  def wait_if_async(self, manager):
    """Blocks until `manager` has finished any outstanding work."""
    manager.wait_until_finished()  # no-op if no async checkpointers.

  @parameterized.product(
      enable_async_checkpointing=(False, True),
      array_metadata_store=(None, array_metadata_store_lib.Store()),
  )
  def test_slice(
      self,
      enable_async_checkpointing: bool,
      array_metadata_store: array_metadata_store_lib.Store | None,
  ):
    """Saves on processes {0, 1} only and checks the restored replica.

    Args:
      enable_async_checkpointing: Forwarded to `CheckpointManagerOptions`.
      array_metadata_store: Forwarded to the `ArrayHandler` used for
        `jax.Array` leaves.
    """
    self.enter_context(
        flagsaver.flagsaver(experimental_orbax_use_distributed_process_id=True)
    )
    global_mesh = test_utils.get_fake_global_mesh_for_slices([{0, 1}, {2, 3}])

    mesh_axes = jax.sharding.PartitionSpec('data')
    arrays = [
        test_utils.create_sharded_array(arr, global_mesh, mesh_axes)
        for arr in [np.arange(8), np.arange(16)]
    ]
    # Use test-framework assertions rather than bare `assert`: they are not
    # stripped under `python -O` and report the offending values on failure.
    self.assertLen(global_mesh.devices[0], 4)
    self.assertEqual(jax.process_count(), 4)

    active_processes = {0, 1}
    primary_host = 0
    if multihost.process_index() not in active_processes:
      # Inactive processes take no part in the save/restore below.
      return

    single_slice_arrays = test_utils.select_single_replica(arrays, global_mesh)
    options = CheckpointManagerOptions(
        create=False,
        enable_async_checkpointing=enable_async_checkpointing,
        multiprocessing_options=checkpoint_manager.MultiprocessingOptions(
            primary_host=primary_host,
            active_processes=active_processes,
        ),
    )
    array_type_handler_registry = (
        type_handler_registry.create_type_handler_registry(
            (
                jax.Array,
                type_handlers.ArrayHandler(
                    primary_host=None,
                    replica_id=None,
                    use_replica_parallel=False,
                    array_metadata_store=array_metadata_store,
                ),
            ),
        )
    )
    handler = PyTreeCheckpointHandler(
        multiprocessing_options=options.multiprocessing_options,
        type_handler_registry=array_type_handler_registry,
    )
    # The same handler instance serves both the save and the restore args.
    handler_registry = handler_registration.DefaultCheckpointHandlerRegistry()
    handler_registry.add(None, args.PyTreeSave, handler)
    handler_registry.add(None, args.PyTreeRestore, handler)

    with CheckpointManager(
        self.directory,
        options=options,
        handler_registry=handler_registry,
    ) as manager:
      self.assertTrue(manager.save(0, args=args.PyTreeSave(arrays)))
      # NOTE(review): fixed 10s pause ahead of the explicit wait below; its
      # purpose is not evident from this file -- confirm it is still needed.
      time.sleep(10)
      self.wait_if_async(manager)

      abstract_target = jax.tree.map(
          utils.to_shape_dtype_struct, single_slice_arrays
      )
      restore_args = checkpoint_utils.construct_restore_args(abstract_target)
      restored = manager.restore(
          0, args=args.PyTreeRestore(restore_args=restore_args)
      )
      test_utils.assert_tree_equal(self, single_slice_arrays, restored)


# Script entry point: hand control to the multiprocess test runner's main().
if __name__ == '__main__':
  multiprocess_test.main()
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def setUp(self):
)
if not multihost.is_runtime_to_distributed_ids_initialized():
multihost.initialize_runtime_to_distributed_ids()
if not multihost.is_distributed_to_device_ids_initialized():
multihost.initialize_distributed_to_device_ids()

# make sure each process is working on different directories
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def setUp(self):
)
if not multihost.is_runtime_to_distributed_ids_initialized():
multihost.initialize_runtime_to_distributed_ids()
if not multihost.is_distributed_to_device_ids_initialized():
multihost.initialize_distributed_to_device_ids()

self.global_mesh = self.make_global_mesh()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def setUp(self):
)
if not multihost.is_runtime_to_distributed_ids_initialized():
multihost.initialize_runtime_to_distributed_ids()
if not multihost.is_distributed_to_device_ids_initialized():
multihost.initialize_distributed_to_device_ids()

self.local_directory = epath.Path(
Expand Down
107 changes: 107 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/deleter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Deleter that dispatches to Pathways workers with Remote Python."""

from typing import Sequence
import jax
from orbax.checkpoint._src.multihost import dispatchers
from orbax.checkpoint._src.path import deleter as deleter_lib
from orbax.checkpoint._src.path import step as step_lib
from orbax.checkpoint.experimental.v1._src.path import types as path_types


# Alias for the base deleter interface from the shared path library.
CheckpointDeleter = deleter_lib.CheckpointDeleter


class _PathwaysDeleter(deleter_lib.CheckpointDeleter):
  """Deleter that dispatches to Pathways workers with Remote Python.

  Wraps a `StandardCheckpointDeleter` and routes its `delete` and
  `delete_steps` calls through a `RemotePythonDispatcher`, blocking until the
  dispatched call has completed.
  """

  def __init__(
      self,
      deleter: deleter_lib.StandardCheckpointDeleter,
      global_mesh: jax.sharding.Mesh | None,
  ):
    """Initializes the deleter.

    Args:
      deleter: The deleter whose methods are run via the dispatcher.
      global_mesh: Mesh whose devices receive the dispatched call. When None,
        a one-axis mesh ('x') over `jax.devices()` is used.
    """
    # Explicit None check rather than relying on the mesh object's truthiness.
    if global_mesh is None:
      global_mesh = jax.sharding.Mesh(jax.devices(), 'x')
    self._global_mesh = global_mesh
    self._deleter = deleter
    self._dispatcher = dispatchers.RemotePythonDispatcher()

  def _dispatch_and_wait(self, func, func_kwargs: dict[str, object]) -> None:
    """Dispatches `func` over the mesh's devices and blocks until it is done.

    Args:
      func: Callable taking `input_arrays` followed by the keyword arguments
        in `func_kwargs`.
      func_kwargs: Keyword arguments forwarded to `func` by the dispatcher.
    """
    jax.block_until_ready(
        self._dispatcher.dispatch(
            func,
            # The dummy array built from the mesh's flattened device list is
            # the dispatcher's input; `func` itself discards it.
            input_arrays=dispatchers.get_dummy_input_array(
                self._global_mesh.devices.flatten().tolist(),
            ),
            func_kwargs=func_kwargs,
        )
    )

  def delete(self, step: int) -> None:
    """Deletes a step.

    Args:
      step: The step to delete.
    """

    def _delete(
        input_arrays: jax.Array,
        step: int,
    ):
      del input_arrays
      self._deleter.delete(step)

    self._dispatch_and_wait(_delete, {'step': step})

  def delete_steps(self, steps: Sequence[int]) -> None:
    """Deletes a sequence of steps.

    Args:
      steps: The steps to delete.
    """

    def _delete(
        input_arrays: jax.Array,
        steps: Sequence[int],
    ):
      del input_arrays
      self._deleter.delete_steps(steps)

    self._dispatch_and_wait(_delete, {'steps': steps})

  def close(self) -> None:
    """Performs any cleanup before closing this deleter."""
    self._deleter.close()


def create_checkpoint_deleter(
    directory: path_types.Path,
    *,
    global_mesh: jax.sharding.Mesh | None = None,
    name_format: step_lib.NameFormat[step_lib.Metadata],
    todelete_subdir: str | None = None,
) -> CheckpointDeleter:
  """Builds a deleter that runs a standard deleter through the dispatcher.

  Args:
    directory: Forwarded to `StandardCheckpointDeleter`.
    global_mesh: Mesh handed to the wrapping deleter; may be None.
    name_format: Forwarded to `StandardCheckpointDeleter`.
    todelete_subdir: Forwarded to `StandardCheckpointDeleter`.

  Returns:
    A `_PathwaysDeleter` wrapping a `StandardCheckpointDeleter` constructed
    with `primary_host=None`.
  """
  standard_deleter = deleter_lib.StandardCheckpointDeleter(
      directory,
      name_format=name_format,
      todelete_subdir=todelete_subdir,
      primary_host=None,
  )
  return _PathwaysDeleter(standard_deleter, global_mesh)
Loading
Loading