diff --git a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index 98aea2be5f..09f665d482 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -640,9 +640,6 @@ async def async_save( the data from its source will be awaited in this function. """ start_time = time.time() - initial_ts_metrics = ts.experimental_collect_matching_metrics( - '/tensorstore/' - ) item = args.item # Reject only zero-leaf items (empty containers, None). A single falsy leaf # (0, '', False, zero-array) is a valid one-leaf tree and must be allowed. @@ -700,9 +697,6 @@ async def async_save( else: requests_to_save = batch_requests - tree_memory_size = async_io_engine.compute_save_memory_size( - requests_to_save - ) commit_futures = await self._async_io_engine.execute_save(requests_to_save) # Flatten to List[future.Future]. commit_futures, _ = jax.tree.flatten(commit_futures) @@ -730,11 +724,14 @@ async def async_save( else: save_futures += commit_futures - + tree_memory_size = async_io_engine.compute_memory_size( + [req.values for req in requests_to_save] + ) async_io_engine.log_io_metrics( tree_memory_size, start_time, '/jax/orbax/write/blocking_gbytes_per_sec', + primary_host=self._primary_host, ) chained_futures = [ future.ChainedFuture( @@ -745,7 +742,7 @@ async def async_save( start_time, '/jax/orbax/write/gbytes_per_sec', '/jax/orbax/write/gbytes', - initial_ts_metrics=initial_ts_metrics, + primary_host=self._primary_host, ), ) ] @@ -792,7 +789,7 @@ async def _maybe_deserialize( metadata: PyTree, param_infos: PyTree, restore_args: PyTree, - ) -> Tuple[int, PyTree]: + ) -> PyTree: """Deserializes values or skips.""" flat_metadata = tree_utils.to_flat_dict(metadata) byte_limiter = limits.get_byte_limiter(self._restore_concurrent_bytes) @@ -810,9 +807,6 @@ async def _maybe_deserialize( deserialized_batches = await self._async_io_engine.execute_restore( batch_requests ) - tree_memory_size = async_io_engine.compute_restore_memory_size( - batch_requests, deserialized_batches - ) flat_restored = {} for request, deserialized in zip(batch_requests, deserialized_batches): @@ -832,9 +826,7 @@ async def _maybe_deserialize( # Restore using `item` as the target structure. If there are any custom # nodes (e.g. optax.EmptyState), these will replace None values in # flat_restored. - return tree_memory_size, tree_utils.from_flat_dict( - flat_restored, target=item - ) + return tree_utils.from_flat_dict(flat_restored, target=item) def _partial_restore_with_omission( self, @@ -1087,7 +1079,7 @@ class TrainState: raise_array_data_missing_error=raise_array_data_missing_error, ) # Begin restore. - tree_memory_size, restored_item = asyncio_utils.run_sync( + restored_item = asyncio_utils.run_sync( self._maybe_deserialize( item, value_metadata_tree, param_infos, restore_args ) @@ -1111,12 +1103,13 @@ class TrainState: json.dumps(ts.experimental_collect_matching_metrics('/tensorstore/')), ) - + tree_memory_size = async_io_engine.compute_memory_size(restored_item) async_io_engine.log_io_metrics( tree_memory_size, start_time, '/jax/checkpoint/read/gbytes_per_sec', '/jax/checkpoint/read/gbytes', # device memory usage + primary_host=self._primary_host, ) return restored_item diff --git a/checkpoint/orbax/checkpoint/_src/serialization/async_io_engine.py b/checkpoint/orbax/checkpoint/_src/serialization/async_io_engine.py index 22a0ea94f7..aa5b0f0122 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/async_io_engine.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/async_io_engine.py @@ -39,6 +39,7 @@ import asyncio import contextlib import dataclasses +import math import sys import threading import time @@ -50,11 +51,11 @@ from orbax.checkpoint._src.futures import future from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.serialization import memory_regulator -from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils from orbax.checkpoint._src.serialization import type_handlers from orbax.checkpoint._src.serialization import types -import tensorstore as ts +from orbax.checkpoint._src.tree import types as tree_types +PyTree = tree_types.PyTree TypeHandler = types.TypeHandler ParamInfo = types.ParamInfo SaveArgs = type_handlers.SaveArgs @@ -66,27 +67,43 @@ MemorySizes = Tuple[int, int] -def _default_sizeof_values(values: BatchOfInts) -> BatchOfInts: - return [sys.getsizeof(v) for v in values] +def _get_memory_size(value: Any) -> int: + """Gets memory size for a leaf value. + The value is expected to be symmetric for save and load and represents the + total memory allocated across all devices. -def get_batch_memory_size( - handler: TypeHandler, values: BatchOfLeaves -) -> MemorySizes: - """Gets memory size for a batch of leaf values.""" - try: - write_sizes, read_sizes = zip(*handler.memory_size(values)) - except NotImplementedError: - logging.warning( - '`memory_size` is not implemented for `TypeHandler` of type: %s. Using' - ' the a default implementation to measure value memory consumption that' - ' may result in inaccurate estimation.', - type(handler), - ) - write_sizes = read_sizes = _default_sizeof_values(values) - assert len(write_sizes) == len(values) - assert len(read_sizes) == len(values) - return sum(write_sizes), sum(read_sizes) + Args: + value: The leaf object to inspect. + + Returns: + The estimated memory footprint in bytes. + """ + if hasattr(value, 'nbytes'): + return int(value.nbytes) + if hasattr(value, 'shape') and hasattr(value, 'dtype'): + itemsize = getattr(value.dtype, 'itemsize', 1) + return int(math.prod(value.shape) * itemsize) + if isinstance(value, (int, float, complex)): + return sys.getsizeof(value) + if isinstance(value, bytes): + return len(value) + if isinstance(value, str): + return len(value.encode('utf-8')) + return sys.getsizeof(value) + + +def compute_memory_size(values: PyTree) -> int: + """Computes the total memory size for a sequence of batch requests. + + Args: + values: Pytree of leaves or values to compute size for. + + Returns: + Total memory size in bytes. + """ + leaves = jax.tree.leaves(values) + return sum(_get_memory_size(v) for v in leaves) def log_io_metrics( @@ -94,51 +111,32 @@ def log_io_metrics( start_time: float, gbytes_per_sec_metric: str, gbytes_metric: str | None = None, - initial_ts_metrics: Sequence[dict[str, Any]] | None = None, + *, + primary_host: int | None, ): """Logs the bytes per second metric.""" time_elapsed = time.time() - start_time bytes_per_sec = ( float('nan') if time_elapsed == 0 else float(size) / time_elapsed ) - note = 'per-host' logging.info( - '[process=%d] %s: %s/s (total size: %s) (time elapsed: %s s) (%s)', + '[process=%d] %s: %s/s (total size: %s) (time elapsed: %s s) (global)', multihost.process_index(), gbytes_per_sec_metric, humanize.naturalsize(bytes_per_sec, binary=True, format='%.3f'), humanize.naturalsize(size, binary=True), time_elapsed, - note, ) - jax.monitoring.record_scalar( - gbytes_per_sec_metric, value=bytes_per_sec / (1024**3) - ) - if gbytes_metric is not None: - jax.monitoring.record_scalar(gbytes_metric, value=size / (1024**3)) - if initial_ts_metrics is not None: - final_ts_metrics = ts.experimental_collect_matching_metrics('/tensorstore/') - initial_bytes = ts_utils.get_total_bytes_from_tensorstore( - initial_ts_metrics, types.IoDirection.WRITE + if primary_host is None: + logging.warning( + 'Global object size logging disabled for `primary_host=None`.' ) - final_bytes = ts_utils.get_total_bytes_from_tensorstore( - final_ts_metrics, types.IoDirection.WRITE + elif multihost.is_primary_host(primary_host): + jax.monitoring.record_scalar( + gbytes_per_sec_metric, value=bytes_per_sec / (1024**3) ) - compressed_bytes = final_bytes - initial_bytes - - if compressed_bytes > 0 and size > 0: - ratio = float(compressed_bytes) / size - logging.info( - '[process=%d] Compression ratio: %.3f (%s / %s)', - multihost.process_index(), - ratio, - humanize.naturalsize(compressed_bytes, binary=True), - humanize.naturalsize(size, binary=True), - ) - jax.monitoring.record_scalar('/jax/orbax/write/compression_ratio', ratio) - jax.monitoring.record_scalar( - '/jax/orbax/write/compressed_gbytes', compressed_bytes / (1024**3) - ) + if gbytes_metric is not None: + jax.monitoring.record_scalar(gbytes_metric, value=size / (1024**3)) async def logging_serialize( @@ -202,27 +200,6 @@ def memory_profiler_context(): memory_regulator.profiler_end() -def compute_save_memory_size(batch_requests: BatchRequests) -> int: - """Computes the total write memory size for a sequence of batch requests.""" - tree_memory_size = 0 - for request in batch_requests: - write_size, _ = get_batch_memory_size(request.handler, request.values) - tree_memory_size += write_size - return tree_memory_size - - -def compute_restore_memory_size( - batch_requests: BatchRequests, - deserialized_batches: Batches, -) -> int: - """Computes the total read memory size for deserialized batches.""" - tree_memory_size = 0 - for request, deserialized in zip(batch_requests, deserialized_batches): - _, read_size = get_batch_memory_size(request.handler, deserialized) - tree_memory_size += read_size - return tree_memory_size - - class AsyncIoEngine: """Encapsulates concurrency, thread-pooling, and I/O telemetry logic.""" diff --git a/checkpoint/orbax/checkpoint/_src/serialization/async_io_engine_test.py b/checkpoint/orbax/checkpoint/_src/serialization/async_io_engine_test.py deleted file mode 100644 index 37bf88d370..0000000000 --- a/checkpoint/orbax/checkpoint/_src/serialization/async_io_engine_test.py +++ /dev/null @@ -1,294 +0,0 @@ -# Copyright 2026 The Orbax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -import unittest -from unittest import mock - -from absl.testing import absltest -import numpy as np -from orbax.checkpoint._src.serialization import async_io_engine -from orbax.checkpoint._src.serialization import types - -AsyncIoEngine = async_io_engine.AsyncIoEngine -BatchRequest = async_io_engine.BatchRequest - - -class AsyncIoEngineTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase): - - def test_get_batch_memory_size_success(self): - handler = mock.create_autospec(types.TypeHandler, instance=True) - handler.memory_size.return_value = [(10, 20), (30, 40)] - - write_size, read_size = async_io_engine.get_batch_memory_size( - handler, ['a', 'b'] - ) - self.assertEqual(write_size, 40) - self.assertEqual(read_size, 60) - - def test_get_batch_memory_size_not_implemented(self): - handler = mock.create_autospec(types.TypeHandler, instance=True) - handler.memory_size.side_effect = NotImplementedError() - - values = ['dummy1', 'dummy2'] - expected_size = sum(sys.getsizeof(v) for v in values) - - write_size, read_size = async_io_engine.get_batch_memory_size( - handler, values - ) - self.assertEqual(write_size, expected_size) - self.assertEqual(read_size, expected_size) - - def test_batch_request_validation_success(self): - handler = mock.create_autospec(types.TypeHandler, instance=True) - req = BatchRequest( - handler=handler, - keys=['k1', 'k2'], - values=['v1', 'v2'], - infos=[mock.Mock(), mock.Mock()], - args=[mock.Mock(), mock.Mock()], - ) - self.assertLen(req.values, 2) - - def test_batch_request_validation_mismatch(self): - handler = mock.create_autospec(types.TypeHandler, instance=True) - with self.assertRaises(AssertionError): - BatchRequest( - handler=handler, - keys=['k1'], - values=['v1', 'v2'], - infos=[mock.Mock(), mock.Mock()], - args=[mock.Mock(), mock.Mock()], - ) - - def test_compute_save_memory_size(self): - handler1 = mock.create_autospec(types.TypeHandler, instance=True) - handler2 = mock.create_autospec(types.TypeHandler, instance=True) - - # memory_size returns a list of (write_size, read_size) tuples - handler1.memory_size.return_value = [(100, 0)] - handler2.memory_size.return_value = [(200, 0)] - - req1 = BatchRequest( - handler=handler1, - keys=['k1'], - values=['v1'], - infos=[mock.Mock()], - args=[mock.Mock()], - ) - req2 = BatchRequest( - handler=handler2, - keys=['k2'], - values=['v2'], - infos=[mock.Mock()], - args=[mock.Mock()], - ) - - tree_memory_size = async_io_engine.compute_save_memory_size([req1, req2]) - self.assertEqual(tree_memory_size, 300) - - def test_compute_restore_memory_size(self): - handler1 = mock.create_autospec(types.TypeHandler, instance=True) - handler2 = mock.create_autospec(types.TypeHandler, instance=True) - - # memory_size returns a list of (write_size, read_size) tuples - handler1.memory_size.return_value = [(0, 50)] - handler2.memory_size.return_value = [(0, 150)] - - req1 = BatchRequest( - handler=handler1, - keys=['k1'], - values=['v1'], - infos=[mock.Mock()], - args=[mock.Mock()], - ) - req2 = BatchRequest( - handler=handler2, - keys=['k2'], - values=['v2'], - infos=[mock.Mock()], - args=[mock.Mock()], - ) - - deserialized_batches = [['restored1'], ['restored2']] - - tree_memory_size = async_io_engine.compute_restore_memory_size( - [req1, req2], deserialized_batches - ) - self.assertEqual(tree_memory_size, 200) - - async def test_execute_save(self): - engine = AsyncIoEngine() - - handler1 = mock.create_autospec(types.TypeHandler, instance=True) - handler2 = mock.create_autospec(types.TypeHandler, instance=True) - - async def dummy_serialize1(*args, **kwargs): - del args, kwargs - return ['fut1', 'fut2'] - - async def dummy_serialize2(*args, **kwargs): - del args, kwargs - return ['fut3'] - - handler1.serialize.side_effect = dummy_serialize1 - handler2.serialize.side_effect = dummy_serialize2 - - req1 = BatchRequest( - handler=handler1, - keys=['k1'], - values=['v1'], - infos=[mock.Mock()], - args=[mock.Mock()], - ) - req2 = BatchRequest( - handler=handler2, - keys=['k2'], - values=['v2'], - infos=[mock.Mock()], - args=[mock.Mock()], - ) - - commit_futures = await engine.execute_save([req1, req2]) - self.assertEqual(commit_futures, [['fut1', 'fut2'], ['fut3']]) - - # Test the standalone memory size function - handler1.memory_size.return_value = [(100, 0)] - handler2.memory_size.return_value = [(200, 0)] - tree_memory_size = async_io_engine.compute_save_memory_size([req1, req2]) - self.assertEqual(tree_memory_size, 300) - - async def test_execute_restore(self): - engine = AsyncIoEngine() - - handler1 = mock.create_autospec(types.TypeHandler, instance=True) - handler2 = mock.create_autospec(types.TypeHandler, instance=True) - - async def dummy_deserialize1(*args, **kwargs): - del args, kwargs - return ['restored1'] - - async def dummy_deserialize2(*args, **kwargs): - del args, kwargs - return ['restored2'] - - handler1.deserialize.side_effect = dummy_deserialize1 - handler2.deserialize.side_effect = dummy_deserialize2 - - req1 = BatchRequest( - handler=handler1, - keys=['k1'], - values=['v1'], - infos=[mock.Mock()], - args=[mock.Mock()], - ) - req2 = BatchRequest( - handler=handler2, - keys=['k2'], - values=['v2'], - infos=[mock.Mock()], - args=[mock.Mock()], - ) - - deserialized_batches = await engine.execute_restore([req1, req2]) - self.assertEqual(deserialized_batches, [['restored1'], ['restored2']]) - - # Test the standalone memory size function - handler1.memory_size.return_value = [(0, 50)] - handler2.memory_size.return_value = [(0, 150)] - tree_memory_size = async_io_engine.compute_restore_memory_size( - [req1, req2], deserialized_batches - ) - self.assertEqual(tree_memory_size, 200) - - @mock.patch.object(async_io_engine.jax.monitoring, 'record_scalar') - def test_log_io_metrics_compression_ratio(self, mock_record_scalar): - initial_ts_metrics = ( - async_io_engine.ts.experimental_collect_matching_metrics( - '/tensorstore/' - ) - ) - - # Perform actual TensorStore write to increment bytes_written. - ts_spec = async_io_engine.ts.Spec({ - 'driver': 'zarr', - 'kvstore': { - 'driver': 'file', - 'path': self.create_tempdir().full_path, - }, - 'metadata': { - 'compressor': {'id': 'zstd'}, - }, - }) - ts_store = async_io_engine.ts.open( - ts_spec, - create=True, - delete_existing=True, - dtype=np.int32, - shape=(10000,), - ).result() - ts_store.write(np.ones((10000,), dtype=np.int32)).result() - - async_io_engine.log_io_metrics( - size=40000, - start_time=12345.0, - gbytes_per_sec_metric='/jax/orbax/write/gbytes_per_sec', - initial_ts_metrics=initial_ts_metrics, - ) - - ratio_calls = [ - call - for call in mock_record_scalar.call_args_list - if call[0][0] == '/jax/orbax/write/compression_ratio' - ] - self.assertNotEmpty(ratio_calls) - ratio = ratio_calls[0][0][1] - # Verifies that compression actually reduced size - self.assertGreater(ratio, 0.0) - self.assertLess(ratio, 1.0) - - compressed_calls = [ - call - for call in mock_record_scalar.call_args_list - if call[0][0] == '/jax/orbax/write/compressed_gbytes' - ] - self.assertNotEmpty(compressed_calls) - self.assertGreater(compressed_calls[0][0][1], 0.0) - - @mock.patch.object(async_io_engine.jax.monitoring, 'record_scalar') - def test_log_io_metrics_compression_ratio_no_compression( - self, mock_record_scalar - ): - # Capture initial_ts_metrics, but perform no TensorStore writes, - # so compressed_bytes will be 0. - initial_ts_metrics = ( - async_io_engine.ts.experimental_collect_matching_metrics( - '/tensorstore/' - ) - ) - - async_io_engine.log_io_metrics( - size=4000, - start_time=12345.0, - gbytes_per_sec_metric='/jax/orbax/write/gbytes_per_sec', - initial_ts_metrics=initial_ts_metrics, - ) - - # Ensure compression_ratio was NOT recorded - for call in mock_record_scalar.call_args_list: - self.assertNotEqual(call[0][0], '/jax/orbax/write/compression_ratio') - - -if __name__ == '__main__': - absltest.main() diff --git a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py index 1f0a7ed095..a4ef38475a 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py @@ -21,7 +21,7 @@ import functools import os import time -from typing import Any, Callable, Dict, Sequence, Set, Tuple, TypeAlias, Union, cast +from typing import Any, Callable, Dict, Sequence, Set, TypeAlias, Union, cast import warnings from absl import logging @@ -340,6 +340,12 @@ def _record_raw_metrics( ratio, storage_type=storage_type, ) + if direction == types.IoDirection.WRITE: + jax.monitoring.record_scalar( + '/jax/orbax/write/worker/io/compressed_gbytes', + raw_bytes / (1024**3), + storage_type=storage_type, + ) def _log_io_metrics( @@ -1574,27 +1580,6 @@ async def deserialize( return ret # pytype: disable=bad-return-type - def memory_size( - self, values: Sequence[jax.Array] - ) -> Sequence[Tuple[int, int]]: - write_sizes = [] - read_sizes = [] - shard_size = lambda shard: shard.data.size * shard.data.dtype.itemsize - for v in values: - write_sizes.append( - replica_slices.get_replica_slices( - v, - replica_id=self._replica_id, - use_replica_parallel=self._use_replica_parallel, - min_slice_bytes_for_replica_parallel=self._min_slice_bytes_for_replica_parallel, - max_replicas_for_replica_parallel=self._max_replicas_for_replica_parallel, - ).nbytes - ) - read_sizes.append( - sum(shard_size(shard) for shard in v.addressable_shards) - ) - return list(zip(write_sizes, read_sizes)) - def _is_host_for_primary_replica(primary_replica_ids: set[int]) -> bool: return multihost.process_index() in primary_replica_ids @@ -1925,9 +1910,3 @@ async def deserialize( ret = _wrap_random_key_data(array_metadatas, infos, list(ret)) return ret - - # TODO(b/370396118): Calculation overestimates bytes read. - def memory_size( # pylint: disable=useless-parent-delegation - self, values: Sequence[jax.Array] - ) -> Sequence[Tuple[int, int]]: - return super().memory_size(values) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py index cc5727f2a3..2a2c8b44cf 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py @@ -18,8 +18,7 @@ import asyncio import copy -import sys -from typing import Any, Dict, Optional, Sequence, Tuple, TypeAlias, Union +from typing import Any, Dict, Optional, Sequence, TypeAlias, Union from absl import logging import jax @@ -213,17 +212,6 @@ async def deserialize( return ret - def memory_size( - self, values: Sequence[np.ndarray] - ) -> Sequence[Tuple[int, int]]: - actual_sizes = [v.size * v.dtype.itemsize for v in values] - if multihost.process_index() == 0: - write_sizes = actual_sizes - else: - write_sizes = [0 for _ in values] - read_sizes = actual_sizes - return list(zip(write_sizes, read_sizes)) - class ScalarHandler(NumpyHandler): """A wrapper around NumpyHandler to deal with scalar types (int, float, etc.).""" @@ -269,15 +257,6 @@ async def deserialize( ] return results - def memory_size(self, values: Sequence[Scalar]) -> Sequence[Tuple[int, int]]: # pytype: disable=signature-mismatch - actual_sizes = [sys.getsizeof(v) for v in values] - if multihost.process_index() == 0: - write_sizes = actual_sizes - else: - write_sizes = [0 for _ in values] - read_sizes = actual_sizes - return list(zip(write_sizes, read_sizes)) - class StringHandler(types.TypeHandler): """TypeHandler for strings.""" @@ -384,15 +363,6 @@ async def deserialize( read_ops = [self._convert_to_string(t) for t in tensorstores] return await asyncio.gather(*read_ops) - def memory_size(self, values: Sequence[str]) -> Sequence[Tuple[int, int]]: - actual_sizes = [len(v.encode('utf-8')) for v in values] - if multihost.process_index() == 0: - write_sizes = actual_sizes - else: - write_sizes = [0 for _ in values] - read_sizes = actual_sizes - return list(zip(write_sizes, read_sizes)) - def is_placeholder(value: Any) -> bool: return value is PLACEHOLDER diff --git a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers_test.py b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers_test.py index 9480995078..b19544b0ff 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers_test.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers_test.py @@ -14,7 +14,6 @@ import asyncio import dataclasses -import sys import threading from typing import Any, Optional import unittest @@ -118,58 +117,6 @@ def get_replica_pids(rep_id: int, mesh: jax.sharding.Mesh): return ids, pids -def per_host_write_size(value: Any) -> int: - if not isinstance(value, jax.Array) and multihost.process_index() != 0: - return 0 - - if isinstance(value, np.ndarray): - return value.size * value.dtype.itemsize - elif isinstance(value, jax.Array): - return sum( - shard.data.size * shard.data.dtype.itemsize - for shard in value.addressable_shards - if shard.replica_id == 0 - ) - elif isinstance(value, str): - return len(value) - else: - return sys.getsizeof(value) - - -def per_host_read_size(value: Any) -> int: - if isinstance(value, np.ndarray): - return value.size * value.dtype.itemsize - elif isinstance(value, jax.Array): - return sum( - shard.data.size * shard.data.dtype.itemsize - for shard in value.addressable_shards - ) - elif isinstance(value, str): - return len(value) - else: - return sys.getsizeof(value) - - -def per_host_size(value: Any) -> int: - if isinstance(value, np.ndarray): - return ( - value.size * value.dtype.itemsize - if multihost.process_index() == 0 - else 0 - ) - elif isinstance(value, jax.Array): - shards = value.addressable_shards - total = 0 - for shard in shards: - if shard.replica_id == 0: - total += shard.data.size * shard.data.dtype.itemsize - return total - elif isinstance(value, str): - return len(value) if multihost.process_index() == 0 else 0 - else: - return sys.getsizeof(value) if multihost.process_index() == 0 else 0 - - class SerializationTest( unittest.IsolatedAsyncioTestCase, multiprocess_test.MultiProcessTest, @@ -694,20 +641,6 @@ class NumpyHandlerTest( ): """Test class.""" - def test_memory_size(self): - handler = type_handlers.NumpyHandler() - if multihost.process_index() == 0: - values = [np.arange(8, dtype=np.int32)] - else: - values = [np.arange(16, dtype=np.int32)] - write_sizes, read_sizes = zip(*handler.memory_size(values)) - self.assertSequenceEqual( - write_sizes, [per_host_write_size(v) for v in values] - ) - self.assertSequenceEqual( - read_sizes, [per_host_read_size(v) for v in values] - ) - async def test_metadata(self): if multihost.process_index() != 0: self.skipTest('Only run on host 0') @@ -753,32 +686,10 @@ async def test_metadata(self): class ScalarHandlerTest(parameterized.TestCase): """Test class.""" - def test_memory_size(self): - handler = type_handlers.ScalarHandler() - values = [3] - write_sizes, read_sizes = zip(*handler.memory_size(values)) - self.assertSequenceEqual( - write_sizes, [per_host_write_size(v) for v in values] - ) - self.assertSequenceEqual( - read_sizes, [per_host_read_size(v) for v in values] - ) - class StringHandlerTest(parameterized.TestCase): """Test class.""" - def test_memory_size(self): - handler = type_handlers.StringHandler() - values = ['a', 'foobar'] - write_sizes, read_sizes = zip(*handler.memory_size(values)) - self.assertSequenceEqual( - write_sizes, [per_host_write_size(v) for v in values] - ) - self.assertSequenceEqual( - read_sizes, [per_host_read_size(v) for v in values] - ) - class ArrayHandlerTest(parameterized.TestCase): """Test class.""" @@ -787,13 +698,6 @@ def setUp(self): super().setUp() self.pytree, _, _ = test_utils.setup_sharded_pytree() - def test_memory_size(self): - handler = type_handlers.ArrayHandler(use_replica_parallel=True) - values = jax.tree.leaves(self.pytree) - write_sizes, read_sizes = zip(*handler.memory_size(values)) - self.assertSequenceEqual(write_sizes, [32, 64, 32, 64]) - self.assertSequenceEqual(read_sizes, [256, 64, 32, 64]) - class ArrayHandlerCallbackTest( parameterized.TestCase, unittest.IsolatedAsyncioTestCase diff --git a/checkpoint/orbax/checkpoint/_src/serialization/types.py b/checkpoint/orbax/checkpoint/_src/serialization/types.py index 6994b47ddd..8f237510a9 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/types.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/types.py @@ -20,7 +20,7 @@ import copy import dataclasses import enum -from typing import Any, Callable, Optional, Protocol, Sequence, Tuple +from typing import Any, Callable, Optional, Protocol, Sequence from absl import logging from etils import epath @@ -472,27 +472,6 @@ def finalize(self, directory: epath.Path): """ pass - def memory_size(self, values: Sequence[Any]) -> Sequence[Tuple[int, int]]: - """For a batch of values, returns the size of each value in bytes. - - Note that the default implementation uses `sys.getsizeof`, which is not - likely to be accurate for many types. - - The value returned is intended to be per-host. - - Args: - values: A batch of values. - - Returns: - A sequence of elements corresponding to `values`. Each element is a tuple - of [write_size, read_size]. In many cases these values may be the same. - - Raises: - NotImplementedError: Raises error by default since we will rely on a - backup implementation. - """ - raise NotImplementedError() - class TypeHandlerRegistry(Protocol): """A registry for TypeHandlers. diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py index 8670850c49..e03f8a7519 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py @@ -15,7 +15,7 @@ """Compatibility wrapper to help leaf handlers to work as V0 type_handlers.""" import dataclasses -from typing import Any, Generic, Sequence, Tuple, Type, cast, get_args +from typing import Any, Generic, Sequence, Type, cast, get_args from absl import logging import jax @@ -393,21 +393,6 @@ async def metadata( ) return ret - def memory_size( - self, values: Sequence[types.Leaf] - ) -> Sequence[Tuple[int, int]]: - # this only works for leaf handler that based on V0 TypeHandlers and stored - # it in self._leaf_handler._handler_impl. - if hasattr(self._leaf_handler, '_handler_impl'): - v0_handler = self._leaf_handler._handler_impl # pylint: disable=protected-access - - return v0_handler.memory_size(values) - - raise NotImplementedError( - 'Cannot resolve memory_size for this v1 leaf handler, ' - f' {self._leaf_handler!r}.' - ) - @property def _array_metadata_store(self): # as the array_metadata_store.resolve_array_metadata_store read the metadata