Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions checkpoint/orbax/checkpoint/_src/asyncio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,14 @@ def run_coroutine(self, coro: Coroutine[Any, Any, _T]) -> futures.Future[_T]:
raise RuntimeError('AsyncRunner has been shut down.')
return asyncio.run_coroutine_threadsafe(coro, self._loop)

def shutdown(self) -> None:
def shutdown(self, wait: bool = True) -> None:
"""Stops the event loop, waiting for tasks to complete.

See the note in the class docstring regarding thread-safety.

Args:
wait: If True, wait for all tasks to complete before shutting down.
Otherwise, the shutdown will be non-blocking.
"""
if self._is_closed:
return
Expand All @@ -149,11 +153,19 @@ async def _shutdown_tasks():
logging.info('AsyncRunner: All tasks finished.')
else:
logging.info('AsyncRunner: No active tasks to wait for.')
logging.info('AsyncRunner: Stopping event loop.')

asyncio.run_coroutine_threadsafe(_shutdown_tasks(), self._loop).result()

# Place the stop command gracefully at the end of the event queue.
self._loop.call_soon_threadsafe(self._loop.stop)
# Wait for the thread to exit.
self._thread.join()
logging.info('AsyncRunner: Shutting down (wait=%s)...', wait)
shutdown_future = asyncio.run_coroutine_threadsafe(
_shutdown_tasks(), self._loop
)

if wait:
shutdown_future.result()
# Place the stop command gracefully at the end of the event queue.
self._loop.call_soon_threadsafe(self._loop.stop)
# Wait for the thread to exit.
self._thread.join()
else:
# Only signal the event loop to stop, but do not wait for it to exit.
self._loop.call_soon_threadsafe(self._loop.stop)
logging.info('AsyncRunner: Stopped.')
18 changes: 18 additions & 0 deletions checkpoint/orbax/checkpoint/_src/asyncio_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,24 @@ async def long_running_coro() -> None:
self.assertTrue(future.done())
self.assertIsNone(future.result(timeout=self._TIMEOUT))

def test_shutdown_non_blocking(self):
"""Test that shutdown(wait=False) does not wait for tasks and thread."""
task_can_complete = threading.Event()

async def long_running_coro() -> None:
await asyncio.to_thread(task_can_complete.wait)

future = self.runner.run_coroutine(long_running_coro())
self.assertFalse(future.done())

# Verify by that the shutdown call returns immediately.
executor = self.enter_context(futures.ThreadPoolExecutor(max_workers=1))
shutdown_future = executor.submit(self.runner.shutdown, wait=False)
shutdown_future.result(timeout=self._TIMEOUT)
# Allow the coroutine to complete, so that the other thread where
# `task_can_complete.wait` is run (not the event loop thread) can exit.
task_can_complete.set()

def test_submit_after_shutdown(self):
"""Test that submitting a coroutine after shutdown raises RuntimeError."""
self.runner.shutdown()
Expand Down
Loading