diff --git a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py index 1f0a7ed09..d06e792d4 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py @@ -635,6 +635,7 @@ def _serialize_arrays( deprioritized = prioritized_async + deprioritized + if dispatcher is None: return _serialize_arrays_batches_without_dispatcher( prioritized,