Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -640,9 +640,6 @@ async def async_save(
the data from its source will be awaited in this function.
"""
start_time = time.time()
initial_ts_metrics = ts.experimental_collect_matching_metrics(
'/tensorstore/'
)
item = args.item
# Reject only zero-leaf items (empty containers, None). A single falsy leaf
# (0, '', False, zero-array) is a valid one-leaf tree and must be allowed.
Expand Down Expand Up @@ -700,9 +697,6 @@ async def async_save(
else:
requests_to_save = batch_requests

tree_memory_size = async_io_engine.compute_save_memory_size(
requests_to_save
)
commit_futures = await self._async_io_engine.execute_save(requests_to_save)
# Flatten to List[future.Future].
commit_futures, _ = jax.tree.flatten(commit_futures)
Expand Down Expand Up @@ -730,11 +724,14 @@ async def async_save(
else:
save_futures += commit_futures


tree_memory_size = async_io_engine.compute_memory_size(
[req.values for req in requests_to_save]
)
async_io_engine.log_io_metrics(
tree_memory_size,
start_time,
'/jax/orbax/write/blocking_gbytes_per_sec',
primary_host=self._primary_host,
)
chained_futures = [
future.ChainedFuture(
Expand All @@ -745,7 +742,7 @@ async def async_save(
start_time,
'/jax/orbax/write/gbytes_per_sec',
'/jax/orbax/write/gbytes',
initial_ts_metrics=initial_ts_metrics,
primary_host=self._primary_host,
),
)
]
Expand Down Expand Up @@ -792,7 +789,7 @@ async def _maybe_deserialize(
metadata: PyTree,
param_infos: PyTree,
restore_args: PyTree,
) -> Tuple[int, PyTree]:
) -> PyTree:
"""Deserializes values or skips."""
flat_metadata = tree_utils.to_flat_dict(metadata)
byte_limiter = limits.get_byte_limiter(self._restore_concurrent_bytes)
Expand All @@ -810,9 +807,6 @@ async def _maybe_deserialize(
deserialized_batches = await self._async_io_engine.execute_restore(
batch_requests
)
tree_memory_size = async_io_engine.compute_restore_memory_size(
batch_requests, deserialized_batches
)

flat_restored = {}
for request, deserialized in zip(batch_requests, deserialized_batches):
Expand All @@ -832,9 +826,7 @@ async def _maybe_deserialize(
# Restore using `item` as the target structure. If there are any custom
# nodes (e.g. optax.EmptyState), these will replace None values in
# flat_restored.
return tree_memory_size, tree_utils.from_flat_dict(
flat_restored, target=item
)
return tree_utils.from_flat_dict(flat_restored, target=item)

def _partial_restore_with_omission(
self,
Expand Down Expand Up @@ -1087,7 +1079,7 @@ class TrainState:
raise_array_data_missing_error=raise_array_data_missing_error,
)
# Begin restore.
tree_memory_size, restored_item = asyncio_utils.run_sync(
restored_item = asyncio_utils.run_sync(
self._maybe_deserialize(
item, value_metadata_tree, param_infos, restore_args
)
Expand All @@ -1111,12 +1103,13 @@ class TrainState:
json.dumps(ts.experimental_collect_matching_metrics('/tensorstore/')),
)


tree_memory_size = async_io_engine.compute_memory_size(restored_item)
async_io_engine.log_io_metrics(
tree_memory_size,
start_time,
'/jax/checkpoint/read/gbytes_per_sec',
'/jax/checkpoint/read/gbytes', # device memory usage
primary_host=self._primary_host,
)
return restored_item

Expand Down
121 changes: 49 additions & 72 deletions checkpoint/orbax/checkpoint/_src/serialization/async_io_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import asyncio
import contextlib
import dataclasses
import math
import sys
import threading
import time
Expand All @@ -50,11 +51,11 @@
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.serialization import memory_regulator
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.serialization import types
import tensorstore as ts
from orbax.checkpoint._src.tree import types as tree_types

PyTree = tree_types.PyTree
TypeHandler = types.TypeHandler
ParamInfo = types.ParamInfo
SaveArgs = type_handlers.SaveArgs
Expand All @@ -66,79 +67,76 @@
MemorySizes = Tuple[int, int]


def _default_sizeof_values(values: BatchOfInts) -> BatchOfInts:
return [sys.getsizeof(v) for v in values]
def _get_memory_size(value: Any) -> int:
"""Gets memory size for a leaf value.

The value is expected to be symmetric for save and load and represents the
total memory allocated across all devices.

def get_batch_memory_size(
handler: TypeHandler, values: BatchOfLeaves
) -> MemorySizes:
"""Gets memory size for a batch of leaf values."""
try:
write_sizes, read_sizes = zip(*handler.memory_size(values))
except NotImplementedError:
logging.warning(
'`memory_size` is not implemented for `TypeHandler` of type: %s. Using'
' the a default implementation to measure value memory consumption that'
' may result in inaccurate estimation.',
type(handler),
)
write_sizes = read_sizes = _default_sizeof_values(values)
assert len(write_sizes) == len(values)
assert len(read_sizes) == len(values)
return sum(write_sizes), sum(read_sizes)
Args:
value: The leaf object to inspect.

Returns:
The estimated memory footprint in bytes.
"""
if hasattr(value, 'nbytes'):
return int(value.nbytes)
if hasattr(value, 'shape') and hasattr(value, 'dtype'):
itemsize = getattr(value.dtype, 'itemsize', 1)
return int(math.prod(value.shape) * itemsize)
if isinstance(value, (int, float, complex)):
return sys.getsizeof(value)
if isinstance(value, bytes):
return len(value)
if isinstance(value, str):
return len(value.encode('utf-8'))
return sys.getsizeof(value)


def compute_memory_size(values: PyTree) -> int:
"""Computes the total memory size for a sequence of batch requests.

Args:
values: Pytree of leaves or values to compute size for.

Returns:
Total memory size in bytes.
"""
leaves = jax.tree.leaves(values)
return sum(_get_memory_size(v) for v in leaves)


def log_io_metrics(
size: int,
start_time: float,
gbytes_per_sec_metric: str,
gbytes_metric: str | None = None,
initial_ts_metrics: Sequence[dict[str, Any]] | None = None,
*,
primary_host: int | None,
):
"""Logs the bytes per second metric."""
time_elapsed = time.time() - start_time
bytes_per_sec = (
float('nan') if time_elapsed == 0 else float(size) / time_elapsed
)
note = 'per-host'
logging.info(
'[process=%d] %s: %s/s (total size: %s) (time elapsed: %s s) (%s)',
'[process=%d] %s: %s/s (total size: %s) (time elapsed: %s s) (global)',
multihost.process_index(),
gbytes_per_sec_metric,
humanize.naturalsize(bytes_per_sec, binary=True, format='%.3f'),
humanize.naturalsize(size, binary=True),
time_elapsed,
note,
)
jax.monitoring.record_scalar(
gbytes_per_sec_metric, value=bytes_per_sec / (1024**3)
)
if gbytes_metric is not None:
jax.monitoring.record_scalar(gbytes_metric, value=size / (1024**3))
if initial_ts_metrics is not None:
final_ts_metrics = ts.experimental_collect_matching_metrics('/tensorstore/')
initial_bytes = ts_utils.get_total_bytes_from_tensorstore(
initial_ts_metrics, types.IoDirection.WRITE
if primary_host is None:
logging.warning(
'Global object size logging disabled for `primary_host=None`.'
)
final_bytes = ts_utils.get_total_bytes_from_tensorstore(
final_ts_metrics, types.IoDirection.WRITE
elif multihost.is_primary_host(primary_host):
jax.monitoring.record_scalar(
gbytes_per_sec_metric, value=bytes_per_sec / (1024**3)
)
compressed_bytes = final_bytes - initial_bytes

if compressed_bytes > 0 and size > 0:
ratio = float(compressed_bytes) / size
logging.info(
'[process=%d] Compression ratio: %.3f (%s / %s)',
multihost.process_index(),
ratio,
humanize.naturalsize(compressed_bytes, binary=True),
humanize.naturalsize(size, binary=True),
)
jax.monitoring.record_scalar('/jax/orbax/write/compression_ratio', ratio)
jax.monitoring.record_scalar(
'/jax/orbax/write/compressed_gbytes', compressed_bytes / (1024**3)
)
if gbytes_metric is not None:
jax.monitoring.record_scalar(gbytes_metric, value=size / (1024**3))


async def logging_serialize(
Expand Down Expand Up @@ -202,27 +200,6 @@ def memory_profiler_context():
memory_regulator.profiler_end()


def compute_save_memory_size(batch_requests: BatchRequests) -> int:
"""Computes the total write memory size for a sequence of batch requests."""
tree_memory_size = 0
for request in batch_requests:
write_size, _ = get_batch_memory_size(request.handler, request.values)
tree_memory_size += write_size
return tree_memory_size


def compute_restore_memory_size(
batch_requests: BatchRequests,
deserialized_batches: Batches,
) -> int:
"""Computes the total read memory size for deserialized batches."""
tree_memory_size = 0
for request, deserialized in zip(batch_requests, deserialized_batches):
_, read_size = get_batch_memory_size(request.handler, deserialized)
tree_memory_size += read_size
return tree_memory_size


class AsyncIoEngine:
"""Encapsulates concurrency, thread-pooling, and I/O telemetry logic."""

Expand Down
Loading
Loading