diff --git a/checkpoint/orbax/checkpoint/_src/asyncio_utils.py b/checkpoint/orbax/checkpoint/_src/asyncio_utils.py index 7b7e80f8e..24caad861 100644 --- a/checkpoint/orbax/checkpoint/_src/asyncio_utils.py +++ b/checkpoint/orbax/checkpoint/_src/asyncio_utils.py @@ -122,10 +122,14 @@ def run_coroutine(self, coro: Coroutine[Any, Any, _T]) -> futures.Future[_T]: raise RuntimeError('AsyncRunner has been shut down.') return asyncio.run_coroutine_threadsafe(coro, self._loop) - def shutdown(self) -> None: + def shutdown(self, wait: bool = True) -> None: """Stops the event loop, waiting for tasks to complete. See the note in the class docstring regarding thread-safety. + + Args: + wait: If True, wait for all tasks to complete before shutting down. + Otherwise, the shutdown will be non-blocking. """ if self._is_closed: return @@ -149,11 +153,19 @@ async def _shutdown_tasks(): logging.info('AsyncRunner: All tasks finished.') else: logging.info('AsyncRunner: No active tasks to wait for.') - logging.info('AsyncRunner: Stopping event loop.') - asyncio.run_coroutine_threadsafe(_shutdown_tasks(), self._loop).result() - - # Place the stop command gracefully at the end of the event queue. - self._loop.call_soon_threadsafe(self._loop.stop) - # Wait for the thread to exit. - self._thread.join() + logging.info('AsyncRunner: Shutting down (wait=%s)...', wait) + shutdown_future = asyncio.run_coroutine_threadsafe( + _shutdown_tasks(), self._loop + ) + + if wait: + shutdown_future.result() + # Place the stop command gracefully at the end of the event queue. + self._loop.call_soon_threadsafe(self._loop.stop) + # Wait for the thread to exit. + self._thread.join() + else: + # Only signal the event loop to stop, but do not wait for it to exit. + self._loop.call_soon_threadsafe(self._loop.stop) + logging.info('AsyncRunner: Stopped.') diff --git a/checkpoint/orbax/checkpoint/_src/asyncio_utils_test.py b/checkpoint/orbax/checkpoint/_src/asyncio_utils_test.py index 2a6a677c2..5303dd6ed 100644 --- a/checkpoint/orbax/checkpoint/_src/asyncio_utils_test.py +++ b/checkpoint/orbax/checkpoint/_src/asyncio_utils_test.py @@ -368,6 +368,24 @@ async def long_running_coro() -> None: self.assertTrue(future.done()) self.assertIsNone(future.result(timeout=self._TIMEOUT)) + def test_shutdown_non_blocking(self): + """Test that shutdown(wait=False) does not wait for tasks and thread.""" + task_can_complete = threading.Event() + + async def long_running_coro() -> None: + await asyncio.to_thread(task_can_complete.wait) + + future = self.runner.run_coroutine(long_running_coro()) + self.assertFalse(future.done()) + + # Verify by that the shutdown call returns immediately. + executor = self.enter_context(futures.ThreadPoolExecutor(max_workers=1)) + shutdown_future = executor.submit(self.runner.shutdown, wait=False) + shutdown_future.result(timeout=self._TIMEOUT) + # Allow the coroutine to complete, so that the other thread where + # `task_can_complete.wait` is run (not the event loop thread) can exit. + task_can_complete.set() + def test_submit_after_shutdown(self): """Test that submitting a coroutine after shutdown raises RuntimeError.""" self.runner.shutdown()