CheckpointManager crashes when jax_enable_preemption_service=False
Description
CheckpointManager.save() unconditionally calls multihost_utils.reached_preemption_sync_point() (through should_save() and reached_preemption()), which raises RuntimeError when JAX's preemption service is disabled via jax.config.update("jax_enable_preemption_service", False).
Reproduction
```python
import jax
import orbax.checkpoint as ocp

jax.config.update("jax_enable_preemption_service", False)
jax.distributed.initialize(coordinator_address=..., num_processes=..., process_id=...)

manager = ocp.CheckpointManager(directory="/tmp/ckpt", options=ocp.CheckpointManagerOptions(save_interval_steps=10))
manager.save(10, args=ocp.args.PyTreeSave({"x": 1}))  # raises RuntimeError
```
Error
```
RuntimeError: Preemption sync manager has not been initialized. Make sure the 'jax_enable_preemption_service' config is enabled.
```
Stack trace
```
checkpoint_manager.py:1236 in should_save
  → reached_preemption = self.reached_preemption(step)
checkpoint_manager.py:1210 in reached_preemption
  → return utils.reached_preemption(step)
multihost.py:401
  → multihost_utils.reached_preemption_sync_point(step)
jax/experimental/multihost_utils.py:229
  → raise RuntimeError("Preemption sync manager has not been initialized...")
```
Root cause
jax_enable_preemption_service is a documented JAX config option. When it is disabled, jax.distributed.initialize() skips creating the PreemptionSyncManager, and reached_preemption_sync_point() raises a RuntimeError as a documented precondition error. Orbax's multihost.reached_preemption() neither checks this precondition before calling nor handles the raised exception, so CheckpointManager crashes when this supported JAX config is used.
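To make the precondition concrete without a multi-host setup, here is a minimal single-process sketch. `_FakeGlobalState`, `fake_reached_preemption_sync_point`, and `guarded_reached_preemption` are hypothetical stand-ins, not JAX or Orbax APIs; they only mimic the behavior described above, where a missing sync manager makes the unguarded call raise while a guarded caller can return False instead.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class _FakeGlobalState:
  # Hypothetical stand-in for JAX's distributed global state.
  # None models jax_enable_preemption_service=False.
  preemption_sync_manager: Optional[Any] = None


global_state = _FakeGlobalState()


def fake_reached_preemption_sync_point(step: int) -> bool:
  # Mimics the precondition check: raise if no sync manager was created.
  if global_state.preemption_sync_manager is None:
    raise RuntimeError(
        "Preemption sync manager has not been initialized. Make sure the"
        " 'jax_enable_preemption_service' config is enabled."
    )
  return global_state.preemption_sync_manager.reached_sync_point(step)


def guarded_reached_preemption(step: int) -> bool:
  # Check the precondition first; a disabled service can never signal preemption.
  if global_state.preemption_sync_manager is None:
    return False
  return fake_reached_preemption_sync_point(step)


try:
  fake_reached_preemption_sync_point(10)
except RuntimeError as e:
  print("unguarded:", e)
print("guarded:", guarded_reached_preemption(10))
```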
Possible fixes
A few options:
- Check before calling: guard against the uninitialized manager before invoking reached_preemption_sync_point():
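```python
from jax._src import distributed  # private JAX module; path may change across releases
from jax.experimental import multihost_utils

def reached_preemption(step: int) -> bool:
  # No sync manager means the preemption service is disabled, so nothing can signal preemption.
  if distributed.global_state.preemption_sync_manager is None:
    return False
  preemption_sync_point_reached = multihost_utils.reached_preemption_sync_point(step)
  _maybe_log_reached_preemption(step, preemption_sync_point_reached)
  return preemption_sync_point_reached
```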
- Gracefully handle the exception by wrapping the call and treating RuntimeError as "preemption not available":
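```python
from jax.experimental import multihost_utils

def reached_preemption(step: int) -> bool:
  try:
    preemption_sync_point_reached = multihost_utils.reached_preemption_sync_point(step)
  except RuntimeError:
    # Raised when the preemption sync manager was never created; treat as "no preemption".
    return False
  _maybe_log_reached_preemption(step, preemption_sync_point_reached)
  return preemption_sync_point_reached
```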
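One trade-off between the two options: the try/except version returns False for any RuntimeError raised inside the call, not only for the disabled-service case, so an unrelated failure would be silently treated as "no preemption". The sketch below illustrates this with hypothetical stand-ins (`state`, `sync_point`, `option_check`, `option_catch`); the second error message is invented for illustration and is not a JAX message.

```python
from types import SimpleNamespace

# Hypothetical stand-in for the distributed global state.
state = SimpleNamespace(preemption_sync_manager=None, healthy=True)


def sync_point(step: int) -> bool:
  # Stand-in for reached_preemption_sync_point(); raises for two distinct reasons.
  if state.preemption_sync_manager is None:
    raise RuntimeError("Preemption sync manager has not been initialized.")
  if not state.healthy:
    raise RuntimeError("hypothetical unrelated failure")
  return False  # Healthy and initialized: no sync point reached in this sketch.


def option_check(step: int) -> bool:
  # Option 1: only the disabled-service case maps to False.
  if state.preemption_sync_manager is None:
    return False
  return sync_point(step)


def option_catch(step: int) -> bool:
  # Option 2: every RuntimeError maps to False.
  try:
    return sync_point(step)
  except RuntimeError:
    return False


state.preemption_sync_manager = object()
state.healthy = False
print(option_catch(10))  # prints False: the unrelated error is swallowed
try:
  option_check(10)
except RuntimeError as e:
  print("propagated:", e)
```

Under these assumptions, checking the precondition keeps unrelated errors visible, which may be preferable to a blanket except.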
Please let me know if I am missing anything here or if this is an intentional design choice.