Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions core/src/main/java/com/google/adk/runner/Runner.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package com.google.adk.runner;

import static com.google.common.collect.ImmutableSet.toImmutableSet;

import com.google.adk.agents.ActiveStreamingTool;
import com.google.adk.agents.BaseAgent;
import com.google.adk.agents.ContextCacheConfig;
Expand Down Expand Up @@ -45,6 +47,8 @@
import com.google.adk.utils.CollectionUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.MapMaker;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.genai.types.AudioTranscriptionConfig;
Expand All @@ -64,6 +68,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
Expand Down Expand Up @@ -766,12 +771,15 @@ private boolean isTransferableAcrossAgentTree(BaseAgent agentToRun) {
return true;
}

/**
* Returns the agent that should handle the next request based on session history.
*
* @return agent to run.
*/
/** Returns the agent that should handle the next request based on session history. */
private BaseAgent findAgentToRun(Session session, BaseAgent rootAgent) {
// Route function responses back to the originating function-call author so HITL tool
// confirmations resume the sub-agent even through non-LlmAgent ancestors.
Optional<BaseAgent> functionCallAuthor = findFunctionCallAuthor(session, rootAgent);
if (functionCallAuthor.isPresent()) {
return functionCallAuthor.get();
}

List<Event> events = new ArrayList<>(session.events());
Collections.reverse(events);

Expand Down Expand Up @@ -802,6 +810,39 @@ private BaseAgent findAgentToRun(Session session, BaseAgent rootAgent) {
return rootAgent;
}

/**
* If the last event is a function response, returns the agent that emitted the matching function
* call (by id), or empty if no match is found in the agent tree.
*/
private static Optional<BaseAgent> findFunctionCallAuthor(Session session, BaseAgent rootAgent) {
List<Event> events = session.events();
if (events.isEmpty()) {
return Optional.empty();
}
ImmutableSet<String> functionResponseIds =
Iterables.getLast(events).functionResponses().stream()
.map(fr -> fr.id().orElse(null))
.filter(Objects::nonNull)
.collect(toImmutableSet());

// Iterate in reverse to prefer the most recent matching call, mirroring Python ADK's
// find_event_by_function_call_id. Function call IDs are unique in normal flows, so this
// is defense-in-depth and not covered by mutation testing.
List<Event> precedingEvents = new ArrayList<>(events.subList(0, events.size() - 1));
Collections.reverse(precedingEvents);
for (Event event : precedingEvents) {
boolean matches =
event.functionCalls().stream()
.map(fc -> fc.id().orElse(null))
.filter(Objects::nonNull)
.anyMatch(functionResponseIds::contains);
if (matches && event.author() != null) {
return rootAgent.findAgent(event.author());
}
}
return Optional.empty();
}

private void addActiveStreamingTools(InvocationContext invocationContext, List<BaseTool> tools) {
tools.stream()
.filter(FunctionTool.class::isInstance)
Expand Down
102 changes: 102 additions & 0 deletions core/src/test/java/com/google/adk/runner/RunnerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import com.google.adk.agents.LiveRequestQueue;
import com.google.adk.agents.LlmAgent;
import com.google.adk.agents.RunConfig;
import com.google.adk.agents.SequentialAgent;
import com.google.adk.apps.App;
import com.google.adk.artifacts.BaseArtifactService;
import com.google.adk.events.Event;
Expand Down Expand Up @@ -1568,6 +1569,107 @@ public void runAsync_withToolConfirmation() {
.inOrder();
}

// HITL tool confirmation must resume the originating sub-agent even when wrapped inside a
// non-LlmAgent workflow agent (e.g. SequentialAgent).
@Test
public void runAsync_withToolConfirmation_inSequentialAgentSubAgent_resumesSubAgent() {
TestLlm childTestLlm =
createTestLlm(
createFunctionCallLlmResponse(
"tool_call_id", "echoTool", ImmutableMap.of("message", "hello")),
createTextLlmResponse("Response after observing tool needs confirmation."),
createTextLlmResponse("Response after user confirmed."));
LlmAgent childAgent =
createTestAgentBuilder(childTestLlm)
.name("child_agent")
.tools(FunctionTool.create(Tools.class, "echoTool", /* requireConfirmation= */ true))
.build();
SequentialAgent workflowAgent =
SequentialAgent.builder()
.name("workflow_agent")
.subAgents(ImmutableList.of(childAgent))
.build();
// Root transfers to workflow_agent to mirror the bug report's control flow.
TestLlm rootTestLlm =
createTestLlm(
createLlmResponse(
Content.fromParts(
Part.fromFunctionCall(
"transfer_to_agent", ImmutableMap.of("agent_name", "workflow_agent")))));
LlmAgent rootAgent =
createTestAgentBuilder(rootTestLlm)
.name("root_agent")
.subAgents(ImmutableList.of(workflowAgent))
.build();
Runner runner =
Runner.builder().app(App.builder().name("test").rootAgent(rootAgent).build()).build();
Session session = runner.sessionService().createSession("test", "user").blockingGet();

List<Event> eventsBeforeConfirmation =
runner
.runAsync("user", session.id(), Content.fromParts(Part.fromText("from user")))
.toList()
.blockingGet();
FunctionCall askUserConfirmationFunctionCall =
Iterables.getOnlyElement(
eventsBeforeConfirmation.stream()
.map(Functions::getAskUserConfirmationFunctionCalls)
.filter(functionCalls -> !functionCalls.isEmpty())
.findFirst()
.get());
List<Event> eventsAfterConfirmation =
runner
.runAsync(
"user",
session.id(),
Content.fromParts(
Part.builder()
.functionResponse(
FunctionResponse.builder()
.id(askUserConfirmationFunctionCall.id().get())
.name(askUserConfirmationFunctionCall.name().get())
.response(ImmutableMap.of("confirmed", true)))
.build()))
.toList()
.blockingGet();

// The originating child agent (not the root agent) must execute the tool.
assertThat(simplifyEvents(eventsAfterConfirmation))
.containsExactly(
"child_agent: FunctionResponse(name=echoTool, response={message=hello})",
"child_agent: Response after user confirmed.")
.inOrder();
}

// Orphan function responses (id not matching any prior call) should fall back to the root agent.
@Test
public void runAsync_withFunctionResponseNotMatchingAnyCall_fallsBackToRootAgent() {
TestLlm rootLlm = createTestLlm(createTextLlmResponse("after function response"));
LlmAgent rootAgent = createTestAgentBuilder(rootLlm).name("root_agent").build();
Runner runner =
Runner.builder().app(App.builder().name("test").rootAgent(rootAgent).build()).build();
Session session = runner.sessionService().createSession("test", "user").blockingGet();

// Function response with id that does not match any prior function call.
List<Event> events =
runner
.runAsync(
"user",
session.id(),
Content.fromParts(
Part.builder()
.functionResponse(
FunctionResponse.builder()
.id("non_existent_id")
.name("orphanFn")
.response(ImmutableMap.of("x", 1)))
.build()))
.toList()
.blockingGet();

assertThat(simplifyEvents(events)).containsExactly("root_agent: after function response");
}

@Test
public void close_closesPluginsAndCodeExecutors() {
BasePlugin plugin = mockPlugin("close_test_plugin");
Expand Down