CheckpointManager crashes when jax_enable_preemption_service=False #3406

Description

@marsunique

CheckpointManager.save() unconditionally calls multihost_utils.reached_preemption_sync_point() (through should_save() and reached_preemption(), as shown in the stack trace), which raises RuntimeError when JAX's preemption service is disabled via jax.config.update("jax_enable_preemption_service", False).

Reproduction

import jax
import orbax.checkpoint as ocp

jax.config.update("jax_enable_preemption_service", False)
jax.distributed.initialize(coordinator_address=..., num_processes=..., process_id=...)

manager = ocp.CheckpointManager(directory="/tmp/ckpt", options=ocp.CheckpointManagerOptions(save_interval_steps=10))
manager.save(10, args=ocp.args.PyTreeSave({"x": 1}))  # raises RuntimeError

Error

RuntimeError: Preemption sync manager has not been initialized. Make sure the 'jax_enable_preemption_service' config is enabled.

Stack trace

checkpoint_manager.py:1236 in should_save
  → reached_preemption = self.reached_preemption(step)
checkpoint_manager.py:1210 in reached_preemption
  → return utils.reached_preemption(step)
multihost.py:401
  → multihost_utils.reached_preemption_sync_point(step)
jax/experimental/multihost_utils.py:229
  → raise RuntimeError("Preemption sync manager has not been initialized...")

Root cause

jax_enable_preemption_service is a documented JAX config option. When disabled, jax.distributed.initialize() skips creating the PreemptionSyncManager, and reached_preemption_sync_point() raises a RuntimeError as a documented precondition error. Orbax's multihost.reached_preemption() does not check this precondition before calling, nor handle the raised exception, causing CheckpointManager to crash when a supported JAX config is used.
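To make the failure mode concrete without needing JAX or a distributed setup, here is a minimal stand-in sketch. Every name in it (FakeGlobalState, fake_reached_preemption_sync_point, unguarded_reached_preemption, demo_crash) is hypothetical and only mirrors the shape of the calls described above; it is not the JAX or Orbax implementation.

```python
# Stand-in for the behaviour described in this issue; no JAX required.
# All names are hypothetical and only mirror the shape of the real calls.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeGlobalState:
    # None models the preemption service being disabled at initialize() time.
    preemption_sync_manager: Optional[object] = None


def fake_reached_preemption_sync_point(state: FakeGlobalState, step: int) -> bool:
    """Mimics the precondition check: raises if no manager was created."""
    if state.preemption_sync_manager is None:
        raise RuntimeError("Preemption sync manager has not been initialized.")
    return False  # this stand-in never reports a preemption


def unguarded_reached_preemption(state: FakeGlobalState, step: int) -> bool:
    """Mirrors the reported call path: no precondition check, no handling."""
    return fake_reached_preemption_sync_point(state, step)


def demo_crash() -> str:
    """Returns the error message produced when the manager is missing."""
    try:
        unguarded_reached_preemption(FakeGlobalState(), step=10)
    except RuntimeError as err:
        return str(err)
    return "no error"
```

With a manager present (any non-None object in this sketch) the unguarded call returns normally. With the default None it raises, which is the crash reported here.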

Possible fixes

Two options:

  1. Check before calling: guard against the uninitialized manager before invoking reached_preemption_sync_point(). Note that this reads a private JAX attribute (under jax._src), which may change between releases:
def reached_preemption(step: int) -> bool:
    if jax._src.distributed.global_state.preemption_sync_manager is None:
        return False
    preemption_sync_point_reached = multihost_utils.reached_preemption_sync_point(step)
    _maybe_log_reached_preemption(step, preemption_sync_point_reached)
    return preemption_sync_point_reached
  2. Gracefully handle the exception: wrap the call and treat RuntimeError as "preemption not available":
def reached_preemption(step: int) -> bool:
    try:
        preemption_sync_point_reached = multihost_utils.reached_preemption_sync_point(step)
    except RuntimeError:
        return False
    _maybe_log_reached_preemption(step, preemption_sync_point_reached)
    return preemption_sync_point_reached

Please let me know if I am missing anything here, or if this is an intentional design choice.

Labels: type:bug
