From 8d2e62bf4d0f448222b6c5a8d924012127f2f0ae Mon Sep 17 00:00:00 2001 From: Abhishek Agrawal Date: Sun, 14 Jun 2026 23:49:57 -0700 Subject: [PATCH] Remove emergency directory from Orbax copybara exclusions. PiperOrigin-RevId: 932263086 --- .../_src/testing/oss/multiprocess_test.py | 19 + .../checkpoint/_src/testing/oss/run_tests.py | 30 +- .../testing/oss/tagged_tests_presubmit.yaml | 17 +- .../checkpoint_manager_slice_test.py | 150 ++++++++ .../emergency/checkpoint_manager_test.py | 1 + .../replicator_checkpoint_manager_test.py | 1 + .../single_slice_checkpoint_manager_test.py | 1 + .../experimental/v1/_src/emergency/deleter.py | 107 ++++++ .../v1/_src/emergency/deleter_test.py | 85 +++++ .../v1/_src/emergency/mesh_test_utils.py | 76 ++++ .../v1/_src/emergency/path_utils.py | 164 +++++++++ .../v1/_src/emergency/path_utils_test.py | 276 +++++++++++++++ .../v1/_src/synchronization/multihost_test.py | 325 ++++++++++++++++++ .../checkpoint/testing/local_path_test.py | 82 +++++ 14 files changed, 1318 insertions(+), 16 deletions(-) create mode 100644 checkpoint/orbax/checkpoint/checkpoint_manager_slice_test.py create mode 100644 checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/deleter.py create mode 100644 checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/deleter_test.py create mode 100644 checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/mesh_test_utils.py create mode 100644 checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/path_utils.py create mode 100644 checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/path_utils_test.py create mode 100644 checkpoint/orbax/checkpoint/experimental/v1/_src/synchronization/multihost_test.py create mode 100644 checkpoint/orbax/checkpoint/testing/local_path_test.py diff --git a/checkpoint/orbax/checkpoint/_src/testing/oss/multiprocess_test.py b/checkpoint/orbax/checkpoint/_src/testing/oss/multiprocess_test.py index fbbc9235ac..55a6d6b741 100644 --- 
a/checkpoint/orbax/checkpoint/_src/testing/oss/multiprocess_test.py +++ b/checkpoint/orbax/checkpoint/_src/testing/oss/multiprocess_test.py @@ -32,6 +32,7 @@ from absl.testing import absltest import jax from jax import config +from orbax.checkpoint._src.futures import synchronization from orbax.checkpoint._src.multihost import multihost import portpicker @@ -299,6 +300,22 @@ def _main(argv): assert retval == 0, f"process {i} failed, return value: {retval}" +def _sync_operation_id(client, sync_key: str): + """Synchronizes the OperationIdGenerator across processes.""" + if jax.process_index() == 0: + client.key_value_set( + sync_key, + synchronization.OperationIdGenerator.get_current_operation_id(), + allow_overwrite=True, + ) + target = int(client.blocking_key_value_get(sync_key, 10000)) + while ( + int(synchronization.OperationIdGenerator.get_current_operation_id()) + < target + ): + synchronization.OperationIdGenerator.next_operation_id() + + class MultiProcessTest(absltest.TestCase): # TODO(b/378138653) Support TPUless MultiProcessTest. 
@@ -318,6 +335,8 @@ def setUp(self): f"multiprocess_test_ensure_all_processes_arrive_at_test_case_{self._testMethodName}", 10000, ) + sync_key = f"sync_op_id_{self._testMethodName}" + _sync_operation_id(client, sync_key) def multiprocess_create_tempdir(self, name: str | None = None) -> str: """Creates a temporary directory for the test.""" diff --git a/checkpoint/orbax/checkpoint/_src/testing/oss/run_tests.py b/checkpoint/orbax/checkpoint/_src/testing/oss/run_tests.py index 8e66f33f04..450a1d474e 100755 --- a/checkpoint/orbax/checkpoint/_src/testing/oss/run_tests.py +++ b/checkpoint/orbax/checkpoint/_src/testing/oss/run_tests.py @@ -22,6 +22,9 @@ from absl import app from absl import flags from absl import logging +import jax +from orbax.checkpoint._src.futures import synchronization +from orbax.checkpoint._src.multihost import multihost import pytest import yaml @@ -80,6 +83,29 @@ def _find_test_path(test_file_yaml): return None +def _sync_op_id_generator(test_file_yaml: str) -> None: + """Synchronizes the OperationIdGenerator across processes.""" + try: + client = multihost.get_jax_distributed_client() + if client is not None: + normalized_name = test_file_yaml.replace('/', '_').replace(':', '_') + sync_key = f'sync_op_id_file_{normalized_name}' + operation_id_generator = synchronization.OperationIdGenerator + if jax.process_index() == 0: + client.key_value_set( + sync_key, + operation_id_generator.get_current_operation_id(), + allow_overwrite=True, + ) + target = int(client.blocking_key_value_get(sync_key, 10000)) + while int(operation_id_generator.get_current_operation_id()) < target: + operation_id_generator.next_operation_id() + except Exception as sync_e: # pylint: disable=broad-exception-caught + logging.warning( + 'Could not synchronize OperationIdGenerator for file: %s', sync_e + ) + + def main(argv: Sequence[str]) -> None: if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') @@ -130,7 +156,9 @@ def main(argv: Sequence[str]) -> None: 
logging.info('Running test: %s (found from %s)', test_path, test_file_yaml) try: - exit_code = pytest.main([test_path]) + _sync_op_id_generator(test_file_yaml) + + exit_code = pytest.main(['--import-mode=importlib', test_path]) if exit_code == 0: results[test_file_yaml] = 'PASSED' logging.info('%s: PASSED', test_path) diff --git a/checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests_presubmit.yaml b/checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests_presubmit.yaml index 179181c9d6..096bb546fd 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests_presubmit.yaml +++ b/checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests_presubmit.yaml @@ -11,23 +11,10 @@ processes:1: - orbax/checkpoint/_src/serialization:serialization_test - orbax/checkpoint/experimental/emergency/multi_tier_checkpointing:pathways_process_metadata_checkpoint_handler_test - orbax/checkpoint/experimental/emergency/multi_tier_checkpointing:pathways_replicator_checkpoint_manager_test +- orbax/checkpoint/experimental/v1/_src/emergency:deleter_test +- orbax/checkpoint/experimental/v1/_src/emergency:path_utils_test - orbax/checkpoint:single_host_test processes:2: - orbax/checkpoint/_src/handlers:array_checkpoint_handler_test -- orbax/checkpoint/_src/handlers:pytree_checkpoint_handler_test -- orbax/checkpoint/_src/handlers:standard_checkpoint_handler_test -- orbax/checkpoint/_src/serialization:local_type_handlers_test -- orbax/checkpoint/_src/serialization:type_handlers_test -- orbax/checkpoint/experimental/emergency/p2p:checkpoint_manager_multiprocess_test -- orbax/checkpoint/experimental/emergency/p2p:local_multiprocess_test -- orbax/checkpoint/experimental/emergency/p2p:persistent_multiprocess_test processes:4: - orbax/checkpoint/_src/multihost:multihost_test -- orbax/checkpoint/_src/testing/tree_verity:checkpoint_manager_test -- orbax/checkpoint/experimental/emergency/multi_tier_checkpointing:process_metadata_checkpoint_handler_test -- 
orbax/checkpoint/experimental/emergency:local_checkpoint_data_debugging_test -- orbax/checkpoint/experimental/emergency:local_checkpoint_manager_test -- orbax/checkpoint/experimental/emergency:single_slice_checkpoint_manager_test -- orbax/checkpoint/testing:local_path_test -- orbax/checkpoint:checkpoint_manager_slice_test -- orbax/checkpoint:checkpoint_manager_test diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager_slice_test.py b/checkpoint/orbax/checkpoint/checkpoint_manager_slice_test.py new file mode 100644 index 0000000000..d153650fdf --- /dev/null +++ b/checkpoint/orbax/checkpoint/checkpoint_manager_slice_test.py @@ -0,0 +1,150 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time + +from absl import flags +from absl.testing import flagsaver +from absl.testing import parameterized +from etils import epath +import jax +import numpy as np +from orbax.checkpoint import args +from orbax.checkpoint import checkpoint_manager +from orbax.checkpoint import checkpoint_utils +from orbax.checkpoint import test_utils +from orbax.checkpoint import utils +from orbax.checkpoint._src.handlers import handler_registration +from orbax.checkpoint._src.handlers import pytree_checkpoint_handler +from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib +from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint._src.serialization import type_handler_registry +from orbax.checkpoint._src.serialization import type_handlers +from orbax.checkpoint._src.testing import multiprocess_test + + +FLAGS = flags.FLAGS +PyTreeCheckpointHandler = pytree_checkpoint_handler.PyTreeCheckpointHandler +CheckpointManager = checkpoint_manager.CheckpointManager +CheckpointManagerOptions = checkpoint_manager.CheckpointManagerOptions + + +@test_utils.barrier_compatible_test +class CheckpointManagerSliceTest( + parameterized.TestCase, multiprocess_test.MultiProcessTest +): + """Structure allows test to run as subclasses, not base class.""" + + def setUp(self): + super().setUp() + + if not multihost.is_runtime_to_distributed_ids_initialized(): + multihost.initialize_runtime_to_distributed_ids() + + self.assertEqual(jax.device_count(), 8) + self.assertEqual(jax.process_count(), 4) + self.assertEqual(jax.local_device_count(), 2) + + self.directory = epath.Path( + self.multiprocess_create_tempdir(name='checkpoint_manager_slice_test') + ) + test_utils.set_tensorstore_driver_for_test() + + test_utils.sync_global_processes( + 'CheckpointManagerSliceTest:setup_complete' + ) + + def tearDown(self): + test_utils.sync_global_processes( + 'CheckpointManagerSliceTest:tests_complete' + ) + super().tearDown() + + def wait_if_async(self, 
manager): + manager.wait_until_finished() # no-op if no async checkpointers. + + @parameterized.product( + enable_async_checkpointing=(False, True), + array_metadata_store=(None, array_metadata_store_lib.Store()), + ) + def test_slice( + self, + enable_async_checkpointing: bool, + array_metadata_store: array_metadata_store_lib.Store | None, + ): + """Test slice.""" + self.enter_context( + flagsaver.flagsaver(experimental_orbax_use_distributed_process_id=True) + ) + global_mesh = test_utils.get_fake_global_mesh_for_slices([{0, 1}, {2, 3}]) + + mesh_axes = jax.sharding.PartitionSpec('data') + arrays = [ + test_utils.create_sharded_array(arr, global_mesh, mesh_axes) + for arr in [np.arange(8), np.arange(16)] + ] + assert len(global_mesh.devices[0]) == 4 + assert jax.process_count() == 4 + active_processes = {0, 1} + primary_host = 0 + if multihost.process_index() in active_processes: + single_slice_arrays = test_utils.select_single_replica( + arrays, global_mesh + ) + options = CheckpointManagerOptions( + create=False, + enable_async_checkpointing=enable_async_checkpointing, + multiprocessing_options=checkpoint_manager.MultiprocessingOptions( + primary_host=primary_host, + active_processes=active_processes, + ), + ) + registry = type_handler_registry.create_type_handler_registry( + ( + jax.Array, + type_handlers.ArrayHandler( + primary_host=None, + replica_id=None, + use_replica_parallel=False, + array_metadata_store=array_metadata_store, + ), + ), + ) + handler = PyTreeCheckpointHandler( + multiprocessing_options=options.multiprocessing_options, + type_handler_registry=registry, + ) + registry = handler_registration.DefaultCheckpointHandlerRegistry() + registry.add(None, args.PyTreeSave, handler) + registry.add(None, args.PyTreeRestore, handler) + with CheckpointManager( + self.directory, + options=options, + handler_registry=registry, + ) as manager: + self.assertTrue(manager.save(0, args=args.PyTreeSave(arrays))) + time.sleep(10) + self.wait_if_async(manager) + 
abstract_target = jax.tree.map( + utils.to_shape_dtype_struct, single_slice_arrays + ) + restore_args = checkpoint_utils.construct_restore_args(abstract_target) + restored = manager.restore( + 0, args=args.PyTreeRestore(restore_args=restore_args) + ) + test_utils.assert_tree_equal(self, single_slice_arrays, restored) + + +if __name__ == '__main__': + multiprocess_test.main() diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py index dfd3360614..7cc25a5148 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py @@ -99,6 +99,7 @@ def setUp(self): ) if not multihost.is_runtime_to_distributed_ids_initialized(): multihost.initialize_runtime_to_distributed_ids() + if not multihost.is_distributed_to_device_ids_initialized(): multihost.initialize_distributed_to_device_ids() # make sure each process is working on different directories diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/replicator_checkpoint_manager_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/replicator_checkpoint_manager_test.py index 7110a235a0..2fd0bc3bf6 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/replicator_checkpoint_manager_test.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/replicator_checkpoint_manager_test.py @@ -209,6 +209,7 @@ def setUp(self): ) if not multihost.is_runtime_to_distributed_ids_initialized(): multihost.initialize_runtime_to_distributed_ids() + if not multihost.is_distributed_to_device_ids_initialized(): multihost.initialize_distributed_to_device_ids() self.global_mesh = self.make_global_mesh() diff --git 
a/checkpoint/orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py index dcd71ab48d..15b32754eb 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py @@ -58,6 +58,7 @@ def setUp(self): ) if not multihost.is_runtime_to_distributed_ids_initialized(): multihost.initialize_runtime_to_distributed_ids() + if not multihost.is_distributed_to_device_ids_initialized(): multihost.initialize_distributed_to_device_ids() self.local_directory = epath.Path( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/deleter.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/deleter.py new file mode 100644 index 0000000000..229e9b472d --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/deleter.py @@ -0,0 +1,107 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Deleter that dispatches to Pathways workers with Remote Python.""" + +from typing import Sequence +import jax +from orbax.checkpoint._src.multihost import dispatchers +from orbax.checkpoint._src.path import deleter as deleter_lib +from orbax.checkpoint._src.path import step as step_lib +from orbax.checkpoint.experimental.v1._src.path import types as path_types + + +CheckpointDeleter = deleter_lib.CheckpointDeleter + + +class _PathwaysDeleter(deleter_lib.CheckpointDeleter): + """Deleter that dispatches to Pathways workers with Remote Python.""" + + def __init__( + self, + deleter: deleter_lib.StandardCheckpointDeleter, + global_mesh: jax.sharding.Mesh | None, + ): + self._global_mesh = global_mesh or jax.sharding.Mesh(jax.devices(), 'x') + self._deleter = deleter + self._dispatcher = dispatchers.RemotePythonDispatcher() + + def delete(self, step: int) -> None: + """Deletes a step. + + Args: + step: The step to delete. + """ + + def _delete( + input_arrays: jax.Array, + step: int, + ): + del input_arrays + self._deleter.delete(step) + + jax.block_until_ready( + self._dispatcher.dispatch( + _delete, + input_arrays=dispatchers.get_dummy_input_array( + self._global_mesh.devices.flatten().tolist(), + ), + func_kwargs={'step': step}, + ) + ) + + def delete_steps(self, steps: Sequence[int]) -> None: + """Deletes a sequence of steps. + + Args: + steps: The steps to delete. 
+ """ + def _delete( + input_arrays: jax.Array, + steps: Sequence[int], + ): + del input_arrays + self._deleter.delete_steps(steps) + + jax.block_until_ready( + self._dispatcher.dispatch( + _delete, + input_arrays=dispatchers.get_dummy_input_array( + self._global_mesh.devices.flatten().tolist(), + ), + func_kwargs={'steps': steps}, + ) + ) + + def close(self) -> None: + """Performs any cleanup before closing this deleter.""" + self._deleter.close() + + +def create_checkpoint_deleter( + directory: path_types.Path, + *, + global_mesh: jax.sharding.Mesh | None = None, + name_format: step_lib.NameFormat[step_lib.Metadata], + todelete_subdir: str | None = None, +) -> CheckpointDeleter: + return _PathwaysDeleter( + deleter_lib.StandardCheckpointDeleter( + directory, + name_format=name_format, + todelete_subdir=todelete_subdir, + primary_host=None, + ), + global_mesh, + ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/deleter_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/deleter_test.py new file mode 100644 index 0000000000..df3bae2ed2 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/deleter_test.py @@ -0,0 +1,85 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Pathways deleter. + +Note: It is important not to pass `self` to the dispatched function. 
+"""
+
+from absl import flags
+import jax
+from orbax.checkpoint._src.multihost import dispatchers
+from orbax.checkpoint._src.multihost import multihost
+from orbax.checkpoint._src.multihost import pathways as multihost_pathways
+from orbax.checkpoint._src.path import step as step_lib
+from orbax.checkpoint.experimental.v1._src.emergency import deleter as deleter_lib
+from orbax.checkpoint.testing import local_path as local_path_test_lib
+
+from .pyglib.contrib.g3_multiprocessing import g3_multiprocessing
+from absl.testing import absltest
+from .testing.pybase import parameterized
+
+
+FLAGS = flags.FLAGS
+
+jax.config.update('jax_enable_x64', True)
+
+
+class DeleterTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._directory = local_path_test_lib.create_local_path_base(self)
+    self.assertTrue(self._directory.exists())
+    self.assertTrue(multihost.is_pathways_backend())
+    self._dispatcher = dispatchers.RemotePythonDispatcher()
+
+  def test_dispatch_function(self):
+    local_path = local_path_test_lib.LocalPath(self._directory)
+    jax.block_until_ready(self._dispatcher.dispatch(local_path.mkdir))
+
+    for i in range(multihost_pathways.worker_count(None)):
+      path = self._directory / f'local_{i}'
+      self.assertTrue(path.exists(), f'Path {path} does not exist.')
+
+  def test_delete_local_step(self):
+    name_format = step_lib.standard_name_format()
+    local_path = local_path_test_lib.LocalPath(self._directory)
+    deleter = deleter_lib.create_checkpoint_deleter(
+        local_path, name_format=name_format  # pytype: disable=wrong-arg-types
+    )
+
+    def _make_step_path():
+      path = local_path / name_format.build_name(1)
+      path.mkdir(parents=True, exist_ok=False)
+
+    jax.block_until_ready(self._dispatcher.dispatch(_make_step_path))
+
+    worker_count = multihost_pathways.worker_count(None)
+    self.assertEqual(worker_count, 4)
+    for worker_id in range(worker_count):
+      path = self._directory / f'local_{worker_id}' / name_format.build_name(1)
+      self.assertTrue(path.exists())
+
+    deleter.delete(1)
+    for worker_id in range(worker_count):
+      path = self._directory / f'local_{worker_id}' / name_format.build_name(1)
+      self.assertFalse(
+          path.exists(), f'Path {path} still exists on worker {worker_id}'
+      )
+
+
+if __name__ == '__main__':
+  jax.config.parse_flags_with_absl()
+  g3_multiprocessing.handle_test_main(absltest.main)
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/mesh_test_utils.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/mesh_test_utils.py
new file mode 100644
index 0000000000..fc771fa17c
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/mesh_test_utils.py
@@ -0,0 +1,76 @@
+# Copyright 2026 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test utils for mesh configuration."""
+
+import dataclasses
+from absl import logging
+import jax
+import numpy as np
+
+
+@dataclasses.dataclass
+class MeshConfig:
+  """Configuration facilitating mesh generation."""
+
+  replica_count: int
+  replica_axis_index: int
+  use_device_count: int | None = None
+
+  def __post_init__(self):
+    if self.replica_axis_index not in [0, 1]:
+      raise ValueError(
+          f'replica_axis_index must be 0 or 1. Got: {self.replica_axis_index}'
+      )
+    if (
+        self.use_device_count is not None
+        and self.use_device_count < self.replica_count
+    ):
+      raise ValueError(
+          'use_device_count must be greater than or equal to replica_count.'
+          f' Got: {self.use_device_count} and {self.replica_count}'
+      )
+    if (
+        self.use_device_count is not None
+        and self.use_device_count == self.replica_count
+    ):
+      raise ValueError(
+          'use_device_count must be greater than replica_count. Got:'
+          f' {self.use_device_count} and {self.replica_count}'
+      )
+
+  @property
+  def mesh(self) -> jax.sharding.Mesh:
+    """Generates a JAX mesh based on the configuration."""
+    if jax.device_count() != 8:
+      raise ValueError(f'Device count must be 8. Got: {jax.device_count()}')
+    if jax.device_count() % self.replica_count != 0:
+      raise ValueError(
+          'Device count must be divisible by replica count. Got:'
+          f' {jax.device_count()} and {self.replica_count}'
+      )
+    use_device_count = self.use_device_count or jax.device_count()
+    devices_per_replica = use_device_count // self.replica_count
+    axes = (self.replica_count, devices_per_replica)
+    axis_names = ('replica', 'data')
+    device_array = np.asarray(jax.devices()[:use_device_count]).reshape(axes)
+    if self.replica_axis_index == 1:
+      axes = axes[::-1]
+      axis_names = axis_names[::-1]
+      device_array = np.swapaxes(device_array, 0, 1)
+    assert (
+        device_array.shape == axes
+    ), f'Devices: {device_array.shape}, axes: {axes}'
+    logging.info('Devices: %s', device_array)
+    return jax.sharding.Mesh(device_array, axis_names)
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/path_utils.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/path_utils.py
new file mode 100644
index 0000000000..08b7521131
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/path_utils.py
@@ -0,0 +1,164 @@
+# Copyright 2026 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for working with local paths in Pathways.
+
+TODO(b/448471028): Rework using Dispatcher.
+"""
+
+import jax
+from jax import numpy as jnp
+import numpy as np
+from orbax.checkpoint._src.multihost import multislice
+from orbax.checkpoint._src.path import step as step_lib
+from orbax.checkpoint.experimental.v1._src.path import types as path_types
+from .learning.deepmind.jax.ocean.remote_python import rp
+
+
+def _get_max_num_steps(
+    local_directory: path_types.Path,
+    *,
+    step_name_format: step_lib.NameFormat[step_lib.Metadata],
+    global_mesh: jax.sharding.Mesh,
+) -> int:
+  """Returns the maximum number of steps present on any worker."""
+  devices = global_mesh.devices.flatten()
+  device_count = len(devices)
+  fully_sharded_sharding = jax.sharding.NamedSharding(
+      jax.sharding.Mesh(global_mesh.devices.flatten(), 'x'),
+      jax.sharding.PartitionSpec(
+          'x',
+      ),
+  )
+
+  @rp.stateless_fn
+  def _max_steps_per_device(_) -> jax.Array:
+    local_steps = set(
+        m.step for m in step_name_format.find_all(local_directory)
+    )
+    local_data = [len(local_steps)] * jax.local_device_count()
+    local_data = jnp.asarray(local_data, dtype=np.int64)
+    assert local_data.shape == (jax.local_device_count(),)
+    return jax.make_array_from_process_local_data(
+        fully_sharded_sharding, local_data, (device_count,)
+    )
+
+  dummy_input = rp.make_dummy_array(global_mesh.devices.flatten().tolist())
+  _max_steps_per_device.register_shape_fn(
+      lambda _: jax.ShapeDtypeStruct(
+          (device_count,),
+          dtype=np.int64,
+          sharding=fully_sharded_sharding,
+      )
+  )
+  result = _max_steps_per_device(rp.to_remote_python(dummy_input))
+  step_count_per_device = rp.from_remote_python(jax.block_until_ready(result))
+  return int(max(step_count_per_device))
+
+
+def _get_steps_per_device(
+    local_directory: path_types.Path,
+    max_steps_per_process: int,
+    *,
+    global_mesh: jax.sharding.Mesh,
+    step_name_format: step_lib.NameFormat[step_lib.Metadata],
+) -> dict[jax.Device, set[int]]:
+  """Returns a mapping of addressable device to the steps found for it."""
+  devices = global_mesh.devices.flatten()
+  device_count = len(devices)
+  fully_replicated_sharding = jax.sharding.NamedSharding(
+      jax.sharding.Mesh(devices, 'x'),
+      jax.sharding.PartitionSpec(),
+  )
+  fully_sharded_sharding = jax.sharding.NamedSharding(
+      jax.sharding.Mesh(devices, 'x'),
+      jax.sharding.PartitionSpec('x', None),
+  )
+  global_shape = (device_count, max_steps_per_process)
+
+  @rp.stateless_fn
+  def _padded_steps_per_device(_) -> jax.Array:
+    local_steps = list(
+        m.step for m in step_name_format.find_all(local_directory)
+    )
+    num_local_steps = len(local_steps)
+    padded_local_steps = list(local_steps) + [-1] * (
+        max_steps_per_process - num_local_steps
+    )
+    local_data = jnp.asarray(padded_local_steps, dtype=np.int64)
+    local_data = jnp.tile(local_data, (jax.local_device_count(), 1))
+    assert local_data.shape == (jax.local_device_count(), max_steps_per_process)
+    return jax.make_array_from_process_local_data(
+        fully_sharded_sharding, local_data, global_shape
+    )
+
+  dummy_input = jax.device_put(
+      jnp.asarray(0, dtype=np.int64), device=fully_replicated_sharding
+  )
+  _padded_steps_per_device.register_shape_fn(
+      lambda _: jax.ShapeDtypeStruct(
+          global_shape,
+          dtype=np.int64,
+          sharding=fully_sharded_sharding,
+      )
+  )
+  result = _padded_steps_per_device(rp.to_remote_python(dummy_input))
+  steps_per_device = rp.from_remote_python(jax.block_until_ready(result))
+  device_to_steps = {}
+  for shard in steps_per_device.addressable_shards:
+    data = np.asarray(shard.data)[0].tolist()
+    device_to_steps[shard.device] = set(v for v in data if v != -1)
+  return device_to_steps
+
+
+def per_replica_local_steps(
+    local_directory: path_types.Path,
+    *,
+    step_name_format: step_lib.NameFormat[step_lib.Metadata],
+    global_mesh: jax.sharding.Mesh,
+    replica_axis_index: int,
+) -> dict[int, set[int]]:
+  """Returns a mapping of replica index to local steps present on that replica."""
+  max_steps_per_process = _get_max_num_steps(
+      local_directory,
+      global_mesh=global_mesh,
+      step_name_format=step_name_format,
+  )
+  # Maps each device to its set of steps (-1 padding entries already removed).
+  steps_per_device = _get_steps_per_device(
+      local_directory,
+      max_steps_per_process,
+      global_mesh=global_mesh,
+      step_name_format=step_name_format,
+  )
+
+  num_replicas = multislice.replica_count(
+      global_mesh, replica_axis_index=replica_axis_index
+  )
+  result: dict[int, set[int]] = {}
+  for replica_id in range(num_replicas):
+    replica_devices = multislice.replica_devices(
+        global_mesh,
+        replica_id=replica_id,
+        replica_axis_index=replica_axis_index,
+    )
+    replica_steps = set()
+    for i, d in enumerate(replica_devices):
+      if i == 0:
+        replica_steps = set(steps_per_device[d])
+      else:
+        replica_steps &= steps_per_device[d]
+    result[replica_id] = replica_steps
+
+  return result
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/path_utils_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/path_utils_test.py
new file mode 100644
index 0000000000..95364be390
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/emergency/path_utils_test.py
@@ -0,0 +1,276 @@
+# Copyright 2026 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for `path_utils.per_replica_local_steps` and its helpers."""
+
+import itertools
+from absl import flags
+import jax
+import numpy as np
+from orbax.checkpoint._src.multihost import dispatchers
+from orbax.checkpoint._src.multihost import multihost
+from orbax.checkpoint._src.multihost import pathways as multihost_pathways
+from orbax.checkpoint._src.path import step as step_lib
+from orbax.checkpoint.experimental.v1._src.emergency import mesh_test_utils
+from orbax.checkpoint.experimental.v1._src.emergency import path_utils
+from orbax.checkpoint.testing import local_path as local_path_test_utils
+# NOTE(review): this relative import looks like an artifact of the export
+# transform; no `pyglib` package is visible next to this file. Confirm the
+# intended open-source entry point before relying on it.
+from .pyglib.contrib.g3_multiprocessing import g3_multiprocessing
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+FLAGS = flags.FLAGS
+LocalPath = local_path_test_utils.LocalPath
+
+jax.config.update('jax_enable_x64', True)
+
+MeshConfig = mesh_test_utils.MeshConfig
+
+
+def _mesh_configs() -> list[MeshConfig]:
+  """Returns every constructible `MeshConfig` from the parameter grid below.
+
+  Combinations for which `MeshConfig` raises `ValueError` are skipped.
+  """
+  result = []
+  for replica_count, replica_axis_index, use_device_count in itertools.product(
+      [8, 4, 2, 1],
+      [0],  # TODO(b/448471028): Fix replica_axis_index=1 behavior.
+      [8, 4],
+  ):
+    try:
+      cfg = MeshConfig(
+          replica_count=replica_count,
+          replica_axis_index=replica_axis_index,
+          use_device_count=use_device_count,
+      )
+    except ValueError:
+      continue
+    result.append(cfg)
+  return result
+
+
+class PerReplicaLocalLatestStepsTest(
+    parameterized.TestCase,
+):
+  """Checks which steps each replica reports for various on-disk layouts."""
+
+  def setUp(self):
+    """Creates the local path base and asserts the expected test topology."""
+    super().setUp()
+    self.directory = local_path_test_utils.create_local_path_base(self)
+    self.assertTrue(self.directory.exists())
+    self.assertTrue(multihost.is_pathways_backend())
+    self.local_directory = LocalPath(self.directory)
+    self.assertEqual(jax.device_count(), 8)
+    self.assertEqual(multihost_pathways.worker_count(None), 4)
+
+  def assertLocalPathsExist(
+      self, steps: list[int], global_mesh: jax.sharding.Mesh
+  ):
+    """Asserts that every worker's `local_<w>` dir has a dir for each step."""
+    name_format = step_lib.standard_name_format()
+    for w in range(multihost_pathways.worker_count(global_mesh)):
+      for step in steps:
+        path = self.directory / f'local_{w}' / name_format.build_name(step)
+        self.assertTrue(path.exists(), f'Path {path} does not exist.')
+
+  def save_local_steps(self, steps: list[int], global_mesh: jax.sharding.Mesh):
+    """Creates a directory per step via the remote dispatcher, then verifies."""
+    name_format = step_lib.standard_name_format()
+    local_directory = self.local_directory
+
+    def _mkdirs():
+      for step in steps:
+        (local_directory / name_format.build_name(step)).mkdir(
+            parents=True, exist_ok=False
+        )
+
+    jax.block_until_ready(
+        dispatchers.RemotePythonDispatcher().dispatch(
+            _mkdirs
+        )
+    )
+    self.assertLocalPathsExist(steps, global_mesh)
+
+  @parameterized.product(mesh_cfg=_mesh_configs())
+  def test_worker_count(self, mesh_cfg: MeshConfig):
+    """The worker count of a mesh is half its number of devices."""
+    mesh = mesh_cfg.mesh
+    self.assertEqual(
+        multihost_pathways.worker_count(mesh), len(mesh.devices.flatten()) // 2
+    )
+
+  @parameterized.product(
+      steps=([0, 1, 2], [5, 0], [0], []), mesh_cfg=_mesh_configs()
+  )
+  def test_same_steps(self, steps: list[int], mesh_cfg: MeshConfig):
+    """When all workers hold the same steps, every replica reports them all."""
+    replica_count = mesh_cfg.replica_count
+    replica_axis_index = mesh_cfg.replica_axis_index
+    mesh = mesh_cfg.mesh
+    self.save_local_steps(steps, mesh)
+    max_num_steps = path_utils._get_max_num_steps(
+        self.local_directory,
+        global_mesh=mesh,
+        step_name_format=step_lib.standard_name_format(),
+    )
+    self.assertLen(steps, max_num_steps)
+    per_replica_local_steps = path_utils.per_replica_local_steps(
+        self.local_directory,
+        step_name_format=step_lib.standard_name_format(),
+        global_mesh=mesh,
+        replica_axis_index=replica_axis_index,
+    )
+    self.assertLen(per_replica_local_steps.keys(), replica_count)
+    self.assertEqual(
+        per_replica_local_steps,
+        {replica_id: set(steps) for replica_id in range(replica_count)},
+    )
+
+  @parameterized.product(
+      steps=([0, 1, 2], [0, 5], [0]), mesh_cfg=_mesh_configs()
+  )
+  def test_different_steps(self, steps: list[int], mesh_cfg: MeshConfig):
+    """Removing step 0 from worker 0 drops it from replica 0 only."""
+    replica_count = mesh_cfg.replica_count
+    replica_axis_index = mesh_cfg.replica_axis_index
+    mesh = mesh_cfg.mesh
+    self.save_local_steps(steps, mesh)
+    name_format = step_lib.standard_name_format()
+
+    # Delete worker 0 step.
+    (self.directory / 'local_0' / name_format.build_name(0)).rmtree()
+
+    per_replica_local_steps = path_utils.per_replica_local_steps(
+        self.local_directory,
+        step_name_format=name_format,
+        global_mesh=mesh,
+        replica_axis_index=replica_axis_index,
+    )
+    self.assertLen(per_replica_local_steps.keys(), replica_count)
+    expected_steps = {}
+    expected_steps[0] = set(steps) - {0}
+    for replica_id in range(1, replica_count):
+      expected_steps[replica_id] = set(steps)
+    self.assertEqual(
+        per_replica_local_steps,
+        expected_steps,
+    )
+
+  @parameterized.product(steps=([0, 1, 2], [0, 5], [0], []))
+  def test_processes_split_between_replicas(self, steps: list[int]):
+    """Uses a 2x3 mesh over six devices, then removes step 0 from worker 2."""
+    devices = jax.devices()
+    mesh = jax.sharding.Mesh(
+        np.asarray([
+            [devices[0], devices[1], devices[2]],
+            [devices[3], devices[4], devices[5]],
+        ]),
+        ('replica', 'data'),
+    )
+    replica_count = 2
+    self.assertLen(mesh.devices, replica_count)
+    self.assertEqual(multihost_pathways.worker_count(mesh), 3)
+    self.save_local_steps(steps, mesh)
+    name_format = step_lib.standard_name_format()
+
+    per_replica_local_steps = path_utils.per_replica_local_steps(
+        self.local_directory,
+        step_name_format=name_format,
+        global_mesh=mesh,
+        replica_axis_index=0,
+    )
+    self.assertLen(per_replica_local_steps.keys(), replica_count)
+    self.assertEqual(
+        per_replica_local_steps,
+        {replica_id: set(steps) for replica_id in range(replica_count)},
+    )
+
+    # Nothing to delete when no steps were saved.
+    if not steps:
+      return
+
+    # Delete worker 2 step.
+    (self.directory / 'local_2' / name_format.build_name(0)).rmtree()
+
+    per_replica_local_steps = path_utils.per_replica_local_steps(
+        self.local_directory,
+        step_name_format=name_format,
+        global_mesh=mesh,
+        replica_axis_index=0,
+    )
+    self.assertLen(per_replica_local_steps.keys(), replica_count)
+    expected_steps = {
+        0: set(steps),
+        1: set(steps) - {0},
+    }
+    self.assertEqual(
+        per_replica_local_steps,
+        expected_steps,
+    )
+
+  # Each case is (steps per worker, replica count, expected steps per replica).
+  # `{}` in the expectation is an empty dict literal; it is converted with
+  # `set(...)` in the test body and therefore stands for "no steps".
+  @parameterized.parameters(
+      # 2 replicas
+      ([[], [], [], []], 2, [{}, {}]),
+      ([[1], [1], [0], [0]], 2, [{1}, {0}]),
+      ([[], [], [0], [0]], 2, [{}, {0}]),
+      ([[], [], [0], []], 2, [{}, {}]),
+      ([[], [], [0], [1]], 2, [{}, {}]),
+      ([[], [0], [], [0]], 2, [{}, {}]),
+      ([[0], [1], [0], [1]], 2, [{}, {}]),
+      ([[1, 2], [1, 2], [4], [4, 5]], 2, [{1, 2}, {4}]),
+      (
+          [[-1, 0], [-1, 0], [1, 2], [1, 2]],
+          2,
+          [{0}, {1, 2}],
+      ),
+      (
+          [[-1, 0, 1], [-1, 0, 1], [0, 1, -1], [1, 0, -1]],
+          2,
+          [{0, 1}, {0, 1}],
+      ),
+      # 4 replicas
+      ([[], [], [], []], 4, [{}, {}, {}, {}]),
+      ([[1], [1], [0], [0]], 4, [{1}, {1}, {0}, {0}]),
+      ([[], [], [0], [0]], 4, [{}, {}, {0}, {0}]),
+      ([[], [], [0], []], 4, [{}, {}, {0}, {}]),
+      ([[], [], [0], [1]], 4, [{}, {}, {0}, {1}]),
+      ([[], [0], [], [0]], 4, [{}, {0}, {}, {0}]),
+      ([[0], [1], [0], [1]], 4, [{0}, {1}, {0}, {1}]),
+      ([[1, 2], [1, 2], [4], [4, 5]], 4, [{1, 2}, {1, 2}, {4}, {4, 5}]),
+      (
+          [[-1, 0], [-1, 0], [1, 2], [1, 2]],
+          4,
+          [{0}, {0}, {1, 2}, {1, 2}],
+      ),
+      (
+          [[-1, 0, 1], [-1, 0, 1], [0, 1, -1], [1, 0, -1]],
+          4,
+          [{0, 1}, {0, 1}, {0, 1}, {0, 1}],
+      ),
+  )
+  def test_per_replica_local_steps(
+      self, worker_steps, num_replicas, expectation
+  ):
+    """Creates per-worker step dirs directly and checks the replica mapping."""
+    expected_dict = {i: set(steps) for i, steps in enumerate(expectation)}
+    for replica_axis_index in [0]:
+      with self.subTest(f'replica_axis_index={replica_axis_index}'):
+        mesh_cfg = MeshConfig(
+            replica_count=num_replicas,
+            replica_axis_index=replica_axis_index,
+        )
+        for w, steps in enumerate(worker_steps):
+          for step in steps:
+            (self.directory / f'local_{w}' / str(step)).mkdir(
+                parents=True, exist_ok=False
+            )
+
+        per_replica_local_steps = path_utils.per_replica_local_steps(
+            self.local_directory,
+            step_name_format=step_lib.standard_name_format(),
+            global_mesh=mesh_cfg.mesh,
+            replica_axis_index=mesh_cfg.replica_axis_index,
+        )
+        self.assertDictEqual(per_replica_local_steps, expected_dict)
+
+
+if __name__ == '__main__':
+  jax.config.parse_flags_with_absl()
+  # `absltest` is the test runner module this file imports; the previous
+  # `googletest` name was never defined here.
+  g3_multiprocessing.handle_test_main(absltest.main)
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/synchronization/multihost_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/synchronization/multihost_test.py
new file mode 100644
index 0000000000..2e868160ee
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/synchronization/multihost_test.py
@@ -0,0 +1,325 @@
+# Copyright 2026 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import time
+import unittest
+
+from absl.testing import parameterized
+from etils import epath
+import jax
+from orbax.checkpoint import test_utils
+from orbax.checkpoint._src.path import async_path
+from orbax.checkpoint._src.testing import multiprocess_test
+from orbax.checkpoint.experimental.v1._src.synchronization import multihost
+
+
+async def primary_host_sleep_and_mkdir(path: epath.Path, seconds: int = 2):
+  """On process 0 only, sleeps `seconds` and then creates `path`."""
+  if multihost.process_index() == 0:
+    await asyncio.sleep(seconds)
+    await async_path.mkdir(path, parents=False, exist_ok=False)
+
+
+class MultihostTest(
+    parameterized.TestCase,
+    multiprocess_test.MultiProcessTest,
+    unittest.IsolatedAsyncioTestCase,
+):
+  """Exercises `multihost.sync_global_processes` across several processes."""
+
+  def setUp(self):
+    """Asserts the 8-device / 4-process topology and creates the temp dir."""
+    super().setUp()
+    self.assertEqual(jax.device_count(), 8)
+    self.assertEqual(jax.process_count(), 4)
+    self.assertEqual(jax.local_device_count(), 2)
+
+    self.tmpdir = epath.Path(self.multiprocess_create_tempdir('multihost_test'))
+    test_utils.sync_global_processes('setUp')
+
+  def tearDown(self):
+    """Synchronizes all processes before running the base-class teardown."""
+    test_utils.sync_global_processes('tearDown')
+    super().tearDown()
+
+  async def test_process_errors(self):
+    """Process 1 expects ValueError for a barrier restricted to `{0}`."""
+    if multihost.process_index() == 1:
+      with self.assertRaises(ValueError):
+        await multihost.sync_global_processes(
+            'test_process_errors_1', operation_id='op', processes={0}
+        )
+
+  async def test_sync_global_processes(self):
+    """After the barrier, every process sees the dir made by process 0."""
+    path = self.tmpdir / 'dummy'
+    if multihost.process_index() == 0:
+      await asyncio.sleep(2)
+      await async_path.mkdir(path, parents=False, exist_ok=False)
+    else:
+      self.assertFalse(await async_path.exists(path))
+    await multihost.sync_global_processes(
+        'test_sync_global_processes', operation_id='op'
+    )
+    self.assertTrue(await async_path.exists(path))
+
+  async def test_sync_global_processes_partially_async(self):
+    """Barrier check where process 0 uses blocking sleep and mkdir calls."""
+    path = self.tmpdir / 'dummy'
+    if multihost.process_index() == 0:
+      time.sleep(2)
+      path.mkdir(parents=False, exist_ok=False)
+    else:
+      self.assertFalse(path.exists())
+    await multihost.sync_global_processes(
+        'test_sync_global_processes', operation_id='op'
+    )
+    self.assertTrue(path.exists())
+
+  async def test_reused_barrier_key(self):
+    """The same barrier key and operation id can be awaited twice in a row."""
+    await multihost.sync_global_processes(
+        'test_reused_barrier_key', operation_id='op'
+    )
+    await multihost.sync_global_processes(
+        'test_reused_barrier_key', operation_id='op'
+    )
+
+  async def test_interlocking_sequential(self):
+    """Sleeps on either side of one shared barrier add up to at least 4s."""
+    async def foo():
+      await multihost.sync_global_processes(
+          'test_interlocking', operation_id='op'
+      )
+      await asyncio.sleep(2)
+
+    async def bar():
+      await asyncio.sleep(2)
+      await multihost.sync_global_processes(
+          'test_interlocking', operation_id='op'
+      )
+
+    start = time.time()
+    if multihost.process_index() == 0:
+      await foo()
+    else:
+      await bar()
+    await multihost.sync_global_processes(
+        'test_interlocking_final', operation_id='op'
+    )
+    end = time.time()
+    self.assertGreaterEqual(end - start, 4)
+
+  async def test_interlocking_different_barrier_names(self):
+    """With one extra barrier call per side, the run finishes in under 3s."""
+    async def foo():
+      await multihost.sync_global_processes(
+          'test_interlocking', operation_id='op'
+      )
+      await asyncio.sleep(2)
+
+    async def bar():
+      await asyncio.sleep(2)
+      await multihost.sync_global_processes(
+          'test_interlocking', operation_id='op'
+      )
+
+    start = time.time()
+    if multihost.process_index() == 0:
+      await foo()
+      # Need to unlock the other processes, otherwise they will get stuck.
+      await multihost.sync_global_processes(
+          'test_interlocking', operation_id='op'
+      )
+    else:
+      # Unlock the other process before proceeding.
+      await multihost.sync_global_processes(
+          'test_interlocking', operation_id='op'
+      )
+      await bar()
+    await multihost.sync_global_processes(
+        'test_interlocking_final', operation_id='op'
+    )
+    end = time.time()
+    self.assertLess(end - start, 3)
+
+  async def test_not_all_processes_arrived_at_barrier(self):
+    """Process 0 waiting alone with `timeout=2` expects TimeoutError."""
+    if multihost.process_index() == 0:
+      with self.assertRaises(TimeoutError):
+        await multihost.sync_global_processes(
+            'test_timeout', timeout=2, operation_id='op'
+        )
+
+  @parameterized.parameters(
+      (1,),
+      (5,),
+  )
+  async def test_sync_global_processes_background_tasks(self, num_tasks):
+    """Runs `num_tasks` mkdir+barrier coroutines in one background task."""
+    paths = [self.tmpdir / f'dummy_{t}' for t in range(num_tasks)]
+
+    async def background_fn(t):
+      path = paths[t]
+      await primary_host_sleep_and_mkdir(path)
+      await multihost.sync_global_processes(
+          f'test_sync_global_processes_{t}', operation_id='op'
+      )
+
+    async def background_fns():
+      return await asyncio.gather(*[background_fn(t) for t in range(num_tasks)])
+
+    task = asyncio.create_task(background_fns())
+    # Checked before the background task is awaited: not all paths exist yet.
+    exists = await asyncio.gather(*[async_path.exists(path) for path in paths])
+    self.assertFalse(all(exists))
+    await task
+    exists = await asyncio.gather(*[async_path.exists(path) for path in paths])
+    self.assertTrue(all(exists))
+
+  async def test_sync_global_processes_background_tasks_sequential(self):
+    """Two concurrent barrier chains are observed at 2.5s checkpoints."""
+    async def fn1():
+      await primary_host_sleep_and_mkdir(self.tmpdir / 'dummy1a')
+      await multihost.sync_global_processes(
+          'test_sync_global_processes_1', operation_id='op'
+      )
+      await primary_host_sleep_and_mkdir(self.tmpdir / 'dummy1b')
+      await multihost.sync_global_processes(
+          'test_sync_global_processes_1', operation_id='op'
+      )
+
+    async def fn2():
+      path = self.tmpdir / 'dummy2'
+      await primary_host_sleep_and_mkdir(path)
+      await multihost.sync_global_processes(
+          'test_sync_global_processes_2', operation_id='op'
+      )
+
+    async def background_fns():
+      return await asyncio.gather(*[fn1(), fn2()])
+
+    task = asyncio.create_task(background_fns())
+    self.assertFalse(await async_path.exists(self.tmpdir / 'dummy1a'))
+    self.assertFalse(await async_path.exists(self.tmpdir / 'dummy1b'))
+    self.assertFalse(await async_path.exists(self.tmpdir / 'dummy2'))
+    await asyncio.sleep(2.5)
+    self.assertTrue(await async_path.exists(self.tmpdir / 'dummy1a'))
+    self.assertFalse(await async_path.exists(self.tmpdir / 'dummy1b'))
+    self.assertTrue(await async_path.exists(self.tmpdir / 'dummy2'))
+    await asyncio.sleep(2.5)
+    self.assertTrue(await async_path.exists(self.tmpdir / 'dummy1a'))
+    self.assertTrue(await async_path.exists(self.tmpdir / 'dummy1b'))
+    self.assertTrue(await async_path.exists(self.tmpdir / 'dummy2'))
+    await task
+
+  async def test_sequential_execution(self):
+    """Two sleep+barrier rounds awaited back to back take at least 4s."""
+
+    async def sleep_and_sync(i):
+      if multihost.process_index() == 0:
+        await asyncio.sleep(2)
+      await multihost.sync_global_processes(
+          f'test_sleep_and_sync_{i}', operation_id='op'
+      )
+
+    start = time.time()
+    await sleep_and_sync(0)
+    await sleep_and_sync(1)
+    end = time.time()
+    self.assertGreaterEqual(end - start, 4)
+
+  async def test_parallel_execution(self):
+    """Two sleep+barrier rounds run via `asyncio.gather` take under 3s."""
+
+    async def sleep_and_sync(i):
+      if multihost.process_index() == 0:
+        await asyncio.sleep(2)
+      await multihost.sync_global_processes(
+          f'test_sleep_and_sync_{i}', operation_id='op'
+      )
+
+    start = time.time()
+    await asyncio.gather(*[sleep_and_sync(0), sleep_and_sync(1)])
+    end = time.time()
+    self.assertLess(end - start, 3)
+
+  async def test_sync_global_processes_partial(self):
+    """Barriers limited to `{0, 2}` do not hold back the other processes."""
+    participating_processes = {0, 2}
+    primary_process = 0
+    non_primary_process = 1
+
+    directory = self.tmpdir / 'testdir'
+    if multihost.process_index() == primary_process:
+      directory.mkdir(parents=False, exist_ok=False)
+    test_utils.sync_global_processes('test_sync_global_processes_partial_setup')
+
+    if multihost.process_index() == primary_process:
+      time.sleep(2)
+      (directory / 'dummy').mkdir(parents=False, exist_ok=False)
+    if multihost.process_index() in participating_processes:
+      await multihost.sync_global_processes(
+          'test_sync_global_processes_partial_one',
+          processes=participating_processes,
+          operation_id='op',
+      )
+    if multihost.process_index() in participating_processes:
+      self.assertTrue((directory / 'dummy').exists())
+    else:
+      self.assertFalse((directory / 'dummy').exists())
+
+    if multihost.process_index() == primary_process:
+      time.sleep(2)
+      (directory / 'foo').mkdir(parents=False, exist_ok=False)
+    if multihost.process_index() in participating_processes:
+      await multihost.sync_global_processes(
+          'test_sync_global_processes_partial_two',
+          processes=participating_processes,
+          operation_id='op',
+      )
+    if multihost.process_index() in participating_processes:
+      self.assertTrue((directory / 'foo').exists())
+    else:
+      self.assertFalse((directory / 'foo').exists())
+
+    await multihost.sync_global_processes(
+        'test_sync_global_processes_partial_all', operation_id='op'
+    )
+    # If non-primary processes get past the above barrier without waiting for
+    # all, then an error would happen for the primary process when trying to
+    # create subdirectories.
+    if multihost.process_index() == non_primary_process:
+      directory.rmtree()
+
+  async def test_different_barriers(self):
+    """Process groups `{0, 2}` and `{1, 3}` wait on separate barriers."""
+    slice1 = {0, 2}
+    slice2 = {1, 3}
+    primary_processes = [0, 1]
+
+    if multihost.process_index() in primary_processes:
+      # Don't sleep for slice1, but do sleep for slice2, so that when slice1
+      # finishes waiting at the barrier, one file exists but the other does
+      # not.
+      time.sleep(3 * multihost.process_index())
+      (self.tmpdir / f'dummy_{multihost.process_index()}').mkdir(
+          parents=False, exist_ok=False
+      )
+
+    if multihost.process_index() in slice1:
+      await multihost.sync_global_processes(
+          'test_different_barriers_slice1',
+          operation_id='op',
+          processes=slice1,
+      )
+    else:
+      await multihost.sync_global_processes(
+          'test_different_barriers_slice2',
+          operation_id='op',
+          processes=slice2,
+      )
+    if multihost.process_index() in slice1:
+      self.assertTrue((self.tmpdir / 'dummy_0').exists())
+      self.assertFalse((self.tmpdir / 'dummy_1').exists())
+    else:
+      self.assertTrue((self.tmpdir / 'dummy_0').exists())
+      self.assertTrue((self.tmpdir / 'dummy_1').exists())
+
+
+if __name__ == '__main__':
+  multiprocess_test.main()
diff --git a/checkpoint/orbax/checkpoint/testing/local_path_test.py b/checkpoint/orbax/checkpoint/testing/local_path_test.py
new file mode 100644
index 0000000000..59d386a8f7
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/testing/local_path_test.py
@@ -0,0 +1,82 @@
+# Copyright 2026 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import parameterized
+from etils import epath
+from orbax.checkpoint import test_utils
+from orbax.checkpoint._src.multihost import multihost
+from orbax.checkpoint._src.testing import multiprocess_test
+from orbax.checkpoint.testing import local_path
+
+LocalPath = local_path.LocalPath
+
+
+class LocalPathTest(parameterized.TestCase, multiprocess_test.MultiProcessTest):
+  """Checks that `LocalPath` resolves to a per-process `local_<index>` dir."""
+
+  def setUp(self):
+    """Creates a plain temp dir and a local path base, then syncs processes."""
+    super().setUp()
+    self.invalid_path = self.multiprocess_create_tempdir(name="foo")
+    self.base_path = local_path.create_local_path_base(self)
+    self.assertGreater(multihost.process_count(), 1)
+    test_utils.sync_global_processes("LocalPathTest:setup_complete")
+
+  def tearDown(self):
+    """Syncs all processes before the base-class teardown runs."""
+    test_utils.sync_global_processes("LocalPathTest:tests_complete")
+    super().tearDown()
+
+  def _process_dir(self):
+    """Returns `<base_path>/local_<process index>` as an `epath.Path`."""
+    return epath.Path(self.base_path) / f"local_{multihost.process_index()}"
+
+  def assertPathEqual(self, p1, p2):
+    """Asserts `p1 == p2`, reporting both paths on failure."""
+    failure_message = f"{p1} != {p2}"
+    self.assertEqual(p1, p2, failure_message)
+
+  def assertPathExists(self, p):
+    """Asserts that `p.exists()` is true."""
+    failure_message = f"{p} does not exist."
+    self.assertTrue(p.exists(), failure_message)
+
+  @parameterized.product(input_cls=[epath.Path, str])
+  def test_construction(self, input_cls):
+    """Both `.path` and `epath.Path(...)` give the per-process directory."""
+    local = LocalPath(input_cls(self.base_path))
+    self.assertPathEqual(local.path, self._process_dir())
+    as_epath = epath.Path(local)
+    self.assertPathEqual(as_epath, self._process_dir())
+
+  def test_mkdir(self):
+    """`mkdir` on a `LocalPath` creates the per-process directory."""
+    local = LocalPath(epath.Path(self.base_path))
+    local.mkdir(parents=False, exist_ok=False)
+    self.assertPathExists(self._process_dir())
+
+  def test_join(self):
+    """Joining a child name stays under the per-process directory."""
+    local = LocalPath(epath.Path(self.base_path))
+    local.mkdir(parents=False, exist_ok=False)
+    self.assertPathExists(self._process_dir())
+    local /= "foobar"
+    local.mkdir(parents=False, exist_ok=False)
+    self.assertPathExists(self._process_dir() / "foobar")
+
+  def test_invalid_path(self):
+    """A base dir not made by `create_local_path_base` raises ValueError."""
+    expected_regex = f"must contain {local_path._LOCAL_PATH_BASE_NAME}"
+    outside_base = LocalPath(self.invalid_path)
+    with self.assertRaisesRegex(ValueError, expected_regex):
+      outside_base.exists()
+
+
+if __name__ == "__main__":
+  multiprocess_test.main()