Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 6 additions & 16 deletions checkpoint/orbax/checkpoint/_src/serialization/async_io_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.serialization import memory_regulator
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.serialization import types
import tensorstore as ts
Expand Down Expand Up @@ -88,19 +89,6 @@ def get_batch_memory_size(
return sum(write_sizes), sum(read_sizes)


def _get_total_bytes_written_from_tensorstore(
metrics: Sequence[dict[str, Any]],
) -> int:
total = 0
for m in metrics:
if m['name'].startswith('/tensorstore/kvstore/') and m['name'].endswith(
'/bytes_written'
):
for val in m['values']:
total += val['value']
return total


def log_io_metrics(
size: int,
start_time: float,
Expand Down Expand Up @@ -130,10 +118,12 @@ def log_io_metrics(
jax.monitoring.record_scalar(gbytes_metric, value=size / (1024**3))
if initial_ts_metrics is not None:
final_ts_metrics = ts.experimental_collect_matching_metrics('/tensorstore/')
initial_bytes = _get_total_bytes_written_from_tensorstore(
initial_ts_metrics
initial_bytes = ts_utils.get_total_bytes_from_tensorstore(
initial_ts_metrics, types.IoDirection.WRITE
)
final_bytes = ts_utils.get_total_bytes_from_tensorstore(
final_ts_metrics, types.IoDirection.WRITE
)
final_bytes = _get_total_bytes_written_from_tensorstore(final_ts_metrics)
compressed_bytes = final_bytes - initial_bytes

if compressed_bytes > 0 and size > 0:
Expand Down
219 changes: 184 additions & 35 deletions checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import warnings

from absl import logging
from etils import epath
import humanize
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -236,6 +237,137 @@ def _get_replica_slices(
]


def _record_logical_metrics(
    direction: types.IoDirection,
    logical_bytes: int,
    duration: float,
    storage_type: str,
):
  """Reports requested-byte IO telemetry for a single worker operation.

  Emits one per-host log line and records three JAX monitoring values under
  `/jax/orbax/<direction>/worker/`: the total duration, the requested gbytes
  and the requested throughput in gbytes per second.

  Args:
    direction: IO direction; its `value` is interpolated into metric names.
    logical_bytes: Number of bytes requested by the operation.
    duration: Elapsed time of the operation, in seconds.
    storage_type: Label attached to every recorded metric.
  """
  bytes_per_gib = 1024**3

  # A non-positive duration yields a throughput of 0 instead of dividing.
  if duration > 0:
    bytes_per_sec = logical_bytes / duration
  else:
    bytes_per_sec = 0

  process_idx = multihost.process_index()
  metric_label = f'/jax/orbax/{direction.value}/worker/io/requested'
  pretty_rate = humanize.naturalsize(bytes_per_sec, binary=True, format='%.3f')
  pretty_total = humanize.naturalsize(logical_bytes, binary=True)
  logging.info(
      '[process=%d] %s throughput: %s/s (total gbytes: %s) (time elapsed: %s s)'
      ' (per-host)',
      process_idx,
      metric_label,
      pretty_rate,
      pretty_total,
      duration,
  )

  jax.monitoring.record_event_duration_secs(
      f'/jax/orbax/{direction.value}/worker/total_duration_secs',
      duration,
      storage_type=storage_type,
  )

  gbytes_requested = logical_bytes / bytes_per_gib
  jax.monitoring.record_scalar(
      f'/jax/orbax/{direction.value}/worker/io/requested/gbytes',
      gbytes_requested,
      storage_type=storage_type,
  )

  gbytes_per_sec = bytes_per_sec / bytes_per_gib
  jax.monitoring.record_scalar(
      f'/jax/orbax/{direction.value}/worker/io/requested/throughput/gbytes_per_sec',
      gbytes_per_sec,
      storage_type=storage_type,
  )


def _record_raw_metrics(
    direction: types.IoDirection,
    logical_bytes: int,
    duration: float,
    storage_type: str,
    initial_ts_metrics: Sequence[dict[str, Any]] | None = None,
):
  """Records raw IO metrics derived from TensorStore kvstore counters.

  The raw byte count is the difference between the TensorStore counters
  collected now and the snapshot in `initial_ts_metrics`. Nothing is recorded
  when no initial snapshot was provided, when the final snapshot cannot be
  collected, or when the computed difference is not positive.

  Args:
    direction: IO direction; selects which counters are summed and is
      interpolated into the metric names.
    logical_bytes: Number of bytes requested by the operation; used as the
      denominator of the raw/logical ratio.
    duration: Elapsed time of the operation, in seconds.
    storage_type: Label attached to every recorded metric.
    initial_ts_metrics: TensorStore metrics snapshot taken before the
      operation started, or None to skip raw metrics entirely.
  """
  if initial_ts_metrics is None:
    return

  try:
    final_ts_metrics = ts.experimental_collect_matching_metrics('/tensorstore/')
  except Exception as e:  # pylint: disable=broad-except
    # Telemetry is best-effort and must never fail the IO operation, but the
    # failure is logged so that missing raw metrics can be diagnosed.
    logging.warning(
        '[process=%d] Failed to collect TensorStore metrics; skipping raw %s'
        ' IO metrics: %s',
        multihost.process_index(),
        direction.value,
        e,
    )
    return

  if final_ts_metrics is None:
    return

  initial_bytes = ts_utils.get_total_bytes_from_tensorstore(
      initial_ts_metrics, direction
  )
  final_bytes = ts_utils.get_total_bytes_from_tensorstore(
      final_ts_metrics, direction
  )
  raw_bytes = final_bytes - initial_bytes

  # A non-positive delta carries no usable information; skip reporting.
  if raw_bytes <= 0:
    return

  raw_throughput = raw_bytes / duration if duration > 0 else 0
  logging.info(
      '[process=%d] Raw %s throughput: %s/s (total gbytes: %s) (time elapsed:'
      ' %s s) (per-host)',
      multihost.process_index(),
      f'/jax/orbax/{direction.value}/worker/io/raw',
      humanize.naturalsize(raw_throughput, binary=True, format='%.3f'),
      humanize.naturalsize(raw_bytes, binary=True),
      duration,
  )
  jax.monitoring.record_scalar(
      f'/jax/orbax/{direction.value}/worker/io/raw/gbytes',
      raw_bytes / (1024**3),
      storage_type=storage_type,
  )
  jax.monitoring.record_scalar(
      f'/jax/orbax/{direction.value}/worker/io/raw/throughput/gbytes_per_sec',
      raw_throughput / (1024**3),
      storage_type=storage_type,
  )

  # The ratio is only defined when some logical bytes were requested.
  if logical_bytes > 0:
    ratio = float(raw_bytes) / logical_bytes
    logging.info(
        '[process=%d] %s ratio (raw/logical): %.3f (%s / %s)',
        multihost.process_index(),
        direction.value.capitalize(),
        ratio,
        humanize.naturalsize(raw_bytes, binary=True),
        humanize.naturalsize(logical_bytes, binary=True),
    )
    jax.monitoring.record_scalar(
        f'/jax/orbax/{direction.value}/worker/io/compression_ratio',
        ratio,
        storage_type=storage_type,
    )


def _log_io_metrics(
    direction: types.IoDirection,
    logical_bytes: int,
    start_time: float,
    parent_dir: epath.Path,
    initial_ts_metrics: Sequence[dict[str, Any]] | None = None,
):
  """Reports logical and raw IO telemetry for an array (de)serialization.

  Args:
    direction: IO direction of the operation being reported.
    logical_bytes: Number of bytes requested by the operation.
    start_time: `time.time()` timestamp taken when the operation started.
    parent_dir: Directory used to derive the storage type label.
    initial_ts_metrics: TensorStore metrics snapshot taken before the
      operation, forwarded to the raw-metrics recorder.
  """
  elapsed_secs = time.time() - start_time
  storage_kind = path_utils.get_storage_type(parent_dir)

  # Both recorders share the same leading positional arguments.
  shared_args = (direction, logical_bytes, elapsed_secs, storage_kind)
  _record_logical_metrics(*shared_args)
  _record_raw_metrics(*shared_args, initial_ts_metrics=initial_ts_metrics)


def _worker_serialize_arrays(
arrays: Sequence[jax.Array],
infos: Sequence[types.ParamInfo],
Expand All @@ -251,6 +383,13 @@ def _worker_serialize_arrays(
ext_metadata: Dict[str, Any],
):
"""Worker function to serialize arrays."""
try:
initial_ts_metrics = ts.experimental_collect_matching_metrics(
'/tensorstore/'
)
except Exception: # pylint: disable=broad-except
initial_ts_metrics = None
total_start_time = time.time()
rslices_per_array = _get_replica_slices(
arrays,
replica_id,
Expand All @@ -272,6 +411,15 @@ def _worker_serialize_arrays(
ext_metadata=ext_metadata,
)
)
if infos:
total_io_bytes = sum(v.nbytes for v in rslices_per_array)
_log_io_metrics(
direction=types.IoDirection.WRITE,
logical_bytes=total_io_bytes,
start_time=total_start_time,
parent_dir=infos[0].parent_dir,
initial_ts_metrics=initial_ts_metrics,
)


def _get_deprioritized_batches_to_serialize(
Expand Down Expand Up @@ -380,7 +528,19 @@ def _serialize_arrays_batches_without_dispatcher(
)

async def _serialize_without_dispatcher():
if not prioritized and not deprioritized:
return
try:
initial_ts_metrics = ts.experimental_collect_matching_metrics(
'/tensorstore/'
)
except Exception: # pylint: disable=broad-except
initial_ts_metrics = None
total_start_time = time.time()
total_io_bytes = 0

if prioritized_values_on_host:
total_io_bytes += sum(v.nbytes for v in prioritized_values_on_host)
await async_serialize_replica_slices_batch(
prioritized_values_on_host,
prioritized_infos,
Expand All @@ -404,13 +564,23 @@ async def _serialize_without_dispatcher():
):
b_arrays_on_host = replica_slices_transfer_arrays_to_host(b_arrays)
_on_batch_callback(b_infos, callback.on_transfer_end)
total_io_bytes += sum(v.nbytes for v in b_arrays_on_host)
await async_serialize_replica_slices_batch(
b_arrays_on_host,
b_infos,
b_args,
)
_on_batch_callback(b_infos, callback.on_write_end)

info_sample = prioritized[0][1] if prioritized else deprioritized[0][1]
_log_io_metrics(
direction=types.IoDirection.WRITE,
logical_bytes=total_io_bytes,
start_time=total_start_time,
parent_dir=info_sample.parent_dir,
initial_ts_metrics=initial_ts_metrics,
)

return future.CommitFutureAwaitingContractedSignals(
_serialize_without_dispatcher(),
name='array_type_handler',
Expand Down Expand Up @@ -787,6 +957,12 @@ async def _deserialize_arrays(
array_metadata_store: array_metadata_store_lib.Store | None,
) -> Sequence[jax.Array]:
"""Deserializes arrays and applies array_metadata if available."""
try:
initial_ts_metrics = ts.experimental_collect_matching_metrics(
'/tensorstore/'
)
except Exception: # pylint: disable=broad-except
initial_ts_metrics = None
total_start_time = time.time()

async def _async_deserialize(
Expand Down Expand Up @@ -884,41 +1060,14 @@ async def _async_deserialize(
metadata_key=metadata_key,
)

total_duration = time.time() - total_start_time
io_throughput = total_io_bytes / total_duration if total_duration > 0 else 0

storage_type = path_utils.get_storage_type(infos[0].parent_dir)

logging.info(
'[process=%d] %s throughput: %s/s (total gbytes: %s) (time elapsed: %s s)'
' (per-host)',
multihost.process_index(),
'/jax/orbax/read/worker/io/requested',
humanize.naturalsize(io_throughput, binary=True, format='%.3f'),
humanize.naturalsize(total_io_bytes, binary=True),
total_duration,
)

# Record total duration of the read operation. Note that for McJAX, it
# includes IO time and H2D transfer time. For Pathways Remote Python,
# it includes only IO time.
jax.monitoring.record_event_duration_secs(
'/jax/orbax/read/worker/total_duration_secs',
total_duration,
storage_type=storage_type,
)

# record total bytes requested to be read from IO
jax.monitoring.record_scalar(
'/jax/orbax/read/worker/io/requested/gbytes',
total_io_bytes / (1024**3),
storage_type=storage_type,
)
jax.monitoring.record_scalar(
'/jax/orbax/read/worker/io/requested/throughput/gbytes_per_sec',
io_throughput / (1024**3),
storage_type=storage_type,
)
if infos:
_log_io_metrics(
direction=types.IoDirection.READ,
logical_bytes=total_io_bytes,
start_time=total_start_time,
parent_dir=infos[0].parent_dir,
initial_ts_metrics=initial_ts_metrics,
)
return ret


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,13 @@ class SingleReplicaArrayRestoreArgs(ArrayRestoreArgs):

def __post_init__(self):
  """Runs base-class post-init, then warns if the deprecated field is set.

  The deprecation warning for `single_replica_sharding` is issued through
  `log_first_n` with n=1, and only when the caller actually supplied a value,
  so users who never set the field are not warned.
  """
  super().__post_init__()
  if self.single_replica_sharding is not None:
    logging.log_first_n(
        logging.WARNING,
        '`single_replica_sharding` is deprecated and will be removed in a'
        ' future version. It is not needed, as Orbax code will automatically'
        ' construct a single-replica sharding used for restoring before'
        ' broadcasting.',
        1,
    )

Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,29 @@ def array_metadata_from_tensorstore(
)


def get_total_bytes_from_tensorstore(
    metrics: Sequence[dict[str, Any]], direction: types.IoDirection
) -> int:
  """Totals the kvstore byte counters matching the given IO direction.

  Only metrics named `/tensorstore/kvstore/.../bytes_written` (for WRITE) or
  `/tensorstore/kvstore/.../bytes_read` (for READ) contribute. Entries that
  are not dicts, or that lack the expected keys, are ignored.

  Args:
    metrics: Metric records, each expected to look like
      `{'name': str, 'values': [{'value': int}, ...]}`.
    direction: Which counter family to sum.

  Returns:
    The sum of all matching counter values; 0 if none match.

  Raises:
    ValueError: If `direction` is neither WRITE nor READ.
  """
  if direction == types.IoDirection.WRITE:
    counter_suffix = '/bytes_written'
  elif direction == types.IoDirection.READ:
    counter_suffix = '/bytes_read'
  else:
    raise ValueError(f'Invalid direction: {direction}')

  byte_count = 0
  for metric in metrics:
    if not isinstance(metric, dict):
      continue
    metric_name = metric.get('name', '')
    if not metric_name.startswith('/tensorstore/kvstore/'):
      continue
    if not metric_name.endswith(counter_suffix):
      continue
    for entry in metric.get('values', []):
      if not isinstance(entry, dict):
        continue
      byte_count += entry.get('value', 0)
  return byte_count


def print_ts_debug_data(key: str | None, infos: Sequence[types.ParamInfo]):
"""Log Tensorstore related metrics."""
ts_metrics = ts.experimental_collect_matching_metrics('/tensorstore')
Expand Down
Loading
Loading