From d6089093e7f625e70fd88c61e99abccbf77eca1b Mon Sep 17 00:00:00 2001 From: Damian Momot Date: Thu, 28 May 2026 05:10:59 -0700 Subject: [PATCH] fix: route HITL confirmation back to originating sub-agent in workflow agents When an LlmAgent that uses a tool requiring Human-in-the-Loop confirmation is wrapped inside a non-LlmAgent workflow agent (e.g. SequentialAgent, ParallelAgent, LoopAgent), the runner used to fall back to the root agent on confirmation resumption. This caused 'VerifyException: Tool not found' because the root agent does not have the sub-agent's tools registered. Runner.findAgentToRun now first checks whether the last event is a function response and, if so, routes it back to the agent that emitted the matching function call (looked up by id), regardless of whether that agent's parent chain is fully transferable. This mirrors the Python ADK behaviour in Runner._find_agent_to_run via find_matching_function_call. PiperOrigin-RevId: 922707673 --- .../java/com/google/adk/runner/Runner.java | 51 ++++++++- .../com/google/adk/runner/RunnerTest.java | 102 ++++++++++++++++++ 2 files changed, 148 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 26a523fdd..1ab101398 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -16,6 +16,8 @@ package com.google.adk.runner; +import static com.google.common.collect.ImmutableSet.toImmutableSet; + import com.google.adk.agents.ActiveStreamingTool; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.ContextCacheConfig; @@ -45,6 +47,8 @@ import com.google.adk.utils.CollectionUtils; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; import com.google.common.collect.MapMaker; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.AudioTranscriptionConfig; @@ -64,6 +68,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -772,12 +777,15 @@ private boolean isTransferableAcrossAgentTree(BaseAgent agentToRun) { return true; } - /** - * Returns the agent that should handle the next request based on session history. - * - * @return agent to run. - */ + /** Returns the agent that should handle the next request based on session history. */ private BaseAgent findAgentToRun(Session session, BaseAgent rootAgent) { + // Route function responses back to the originating function-call author so HITL tool + // confirmations resume the sub-agent even through non-LlmAgent ancestors. + Optional functionCallAuthor = findFunctionCallAuthor(session, rootAgent); + if (functionCallAuthor.isPresent()) { + return functionCallAuthor.get(); + } + List events = new ArrayList<>(session.events()); Collections.reverse(events); @@ -808,6 +816,39 @@ private BaseAgent findAgentToRun(Session session, BaseAgent rootAgent) { return rootAgent; } + /** + * If the last event is a function response, returns the agent that emitted the matching function + * call (by id), or empty if no match is found in the agent tree. + */ + private static Optional findFunctionCallAuthor(Session session, BaseAgent rootAgent) { + List events = session.events(); + if (events.isEmpty()) { + return Optional.empty(); + } + ImmutableSet functionResponseIds = + Iterables.getLast(events).functionResponses().stream() + .map(fr -> fr.id().orElse(null)) + .filter(Objects::nonNull) + .collect(toImmutableSet()); + + // Iterate in reverse to prefer the most recent matching call, mirroring Python ADK's + // find_event_by_function_call_id. Function call IDs are unique in normal flows, so this + // is defense-in-depth and not covered by mutation testing. + List precedingEvents = new ArrayList<>(events.subList(0, events.size() - 1)); + Collections.reverse(precedingEvents); + for (Event event : precedingEvents) { + boolean matches = + event.functionCalls().stream() + .map(fc -> fc.id().orElse(null)) + .filter(Objects::nonNull) + .anyMatch(functionResponseIds::contains); + if (matches && event.author() != null) { + return rootAgent.findAgent(event.author()); + } + } + return Optional.empty(); + } + private void addActiveStreamingTools(InvocationContext invocationContext, List tools) { tools.stream() .filter(FunctionTool.class::isInstance) diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 00d5d63bf..95718e3e0 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -42,6 +42,7 @@ import com.google.adk.agents.LiveRequestQueue; import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; +import com.google.adk.agents.SequentialAgent; import com.google.adk.apps.App; import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; @@ -1604,6 +1605,107 @@ public void runAsync_withToolConfirmation() { .inOrder(); } + // HITL tool confirmation must resume the originating sub-agent even when wrapped inside a + // non-LlmAgent workflow agent (e.g. SequentialAgent). + @Test + public void runAsync_withToolConfirmation_inSequentialAgentSubAgent_resumesSubAgent() { + TestLlm childTestLlm = + createTestLlm( + createFunctionCallLlmResponse( + "tool_call_id", "echoTool", ImmutableMap.of("message", "hello")), + createTextLlmResponse("Response after observing tool needs confirmation."), + createTextLlmResponse("Response after user confirmed.")); + LlmAgent childAgent = + createTestAgentBuilder(childTestLlm) + .name("child_agent") + .tools(FunctionTool.create(Tools.class, "echoTool", /* requireConfirmation= */ true)) + .build(); + SequentialAgent workflowAgent = + SequentialAgent.builder() + .name("workflow_agent") + .subAgents(ImmutableList.of(childAgent)) + .build(); + // Root transfers to workflow_agent to mirror the bug report's control flow. + TestLlm rootTestLlm = + createTestLlm( + createLlmResponse( + Content.fromParts( + Part.fromFunctionCall( + "transfer_to_agent", ImmutableMap.of("agent_name", "workflow_agent"))))); + LlmAgent rootAgent = + createTestAgentBuilder(rootTestLlm) + .name("root_agent") + .subAgents(ImmutableList.of(workflowAgent)) + .build(); + Runner runner = + Runner.builder().app(App.builder().name("test").rootAgent(rootAgent).build()).build(); + Session session = runner.sessionService().createSession("test", "user").blockingGet(); + + List eventsBeforeConfirmation = + runner + .runAsync("user", session.id(), Content.fromParts(Part.fromText("from user"))) + .toList() + .blockingGet(); + FunctionCall askUserConfirmationFunctionCall = + Iterables.getOnlyElement( + eventsBeforeConfirmation.stream() + .map(Functions::getAskUserConfirmationFunctionCalls) + .filter(functionCalls -> !functionCalls.isEmpty()) + .findFirst() + .get()); + List eventsAfterConfirmation = + runner + .runAsync( + "user", + session.id(), + Content.fromParts( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id(askUserConfirmationFunctionCall.id().get()) + .name(askUserConfirmationFunctionCall.name().get()) + .response(ImmutableMap.of("confirmed", true))) + .build())) + .toList() + .blockingGet(); + + // The originating child agent (not the root agent) must execute the tool. + assertThat(simplifyEvents(eventsAfterConfirmation)) + .containsExactly( + "child_agent: FunctionResponse(name=echoTool, response={message=hello})", + "child_agent: Response after user confirmed.") + .inOrder(); + } + + // Orphan function responses (id not matching any prior call) should fall back to the root agent. + @Test + public void runAsync_withFunctionResponseNotMatchingAnyCall_fallsBackToRootAgent() { + TestLlm rootLlm = createTestLlm(createTextLlmResponse("after function response")); + LlmAgent rootAgent = createTestAgentBuilder(rootLlm).name("root_agent").build(); + Runner runner = + Runner.builder().app(App.builder().name("test").rootAgent(rootAgent).build()).build(); + Session session = runner.sessionService().createSession("test", "user").blockingGet(); + + // Function response with id that does not match any prior function call. + List events = + runner + .runAsync( + "user", + session.id(), + Content.fromParts( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id("non_existent_id") + .name("orphanFn") + .response(ImmutableMap.of("x", 1))) + .build())) + .toList() + .blockingGet(); + + assertThat(simplifyEvents(events)).containsExactly("root_agent: after function response"); + } + @Test public void close_closesPluginsAndCodeExecutors() { BasePlugin plugin = mockPlugin("close_test_plugin");