diff --git a/pyproject.toml b/pyproject.toml index 523221176..16ad9e76d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "uipath-langchain" -version = "0.13.4" +version = "0.13.5" description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform" readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.11" diff --git a/src/uipath_langchain/agent/tools/a2a/a2a_tool.py b/src/uipath_langchain/agent/tools/a2a/a2a_tool.py index 068bd952e..12009eb64 100644 --- a/src/uipath_langchain/agent/tools/a2a/a2a_tool.py +++ b/src/uipath_langchain/agent/tools/a2a/a2a_tool.py @@ -10,6 +10,7 @@ import asyncio import json +import os from contextlib import AsyncExitStack, asynccontextmanager from logging import getLogger from typing import AsyncGenerator @@ -33,7 +34,9 @@ from pydantic import BaseModel, Field from uipath._utils._ssl_context import get_httpx_client_kwargs from uipath.agent.models.agent import AgentA2aResourceConfig +from uipath.platform.common import UiPathConfig +from uipath_langchain._utils import get_current_span_and_trace_ids from uipath_langchain.agent.react.types import AgentGraphState from uipath_langchain.agent.tools.base_uipath_structured_tool import ( BaseUiPathStructuredTool, @@ -47,6 +50,63 @@ logger = getLogger(__name__) +def _normalize_trace_id(value: str) -> str: + """Normalize a trace id (UUID or hex form) to lowercase 32-char hex.""" + normalized = value.replace("-", "").lower() + if len(normalized) != 32: + raise ValueError(f"Invalid trace ID format: {value}") + return normalized + + +def _coerce_span_id(value: str | None) -> str | None: + """Return ``value`` as a 16-char lowercase hex span id, or ``None``.""" + if not value: + return None + candidate = value.replace("-", "").lower() + if len(candidate) == 16 and all(c in "0123456789abcdef" for c in candidate): + return candidate + return None + + +def _build_traceparent() -> str | None: + """Build a W3C ``traceparent`` carrying the running job's trace id. + + The remote A2A proxy adopts this trace id so the spans it emits for the + call share the calling job's trace. Returns ``None`` when no job trace id + is available, in which case the proxy self-roots the session as before. + """ + raw_trace_id = os.environ.get("UIPATH_TRACE_ID") + if not raw_trace_id: + return None + try: + trace_id = _normalize_trace_id(raw_trace_id) + except ValueError: + logger.warning("Ignoring invalid UIPATH_TRACE_ID: %s", raw_trace_id) + return None + + parent_span_id = ( + _coerce_span_id(os.environ.get("UIPATH_PARENT_SPAN_ID")) + or _coerce_span_id(get_current_span_and_trace_ids()[0]) + or uuid4().hex[:16] + ) + return f"00-{trace_id}-{parent_span_id}-01" + + +def _build_client_headers(secret: str) -> dict[str, str]: + """Build the A2A client's default headers: bearer auth plus the caller's + trace and job context so the remote proxy can correlate the call's spans + with the calling job's trace and group them under that job. + """ + headers = {"Authorization": f"Bearer {secret}"} + traceparent = _build_traceparent() + if traceparent: + headers["traceparent"] = traceparent + job_key = UiPathConfig.job_key + if job_key: + headers["X-UiPath-JobKey"] = job_key + return headers + + class A2aToolInput(BaseModel): """Input schema for A2A agent tool.""" @@ -81,7 +141,7 @@ async def get(self) -> Client: sdk = UiPath() client_kwargs = get_httpx_client_kwargs( - headers={"Authorization": f"Bearer {sdk._config.secret}"}, + headers=_build_client_headers(sdk._config.secret), ) client_kwargs["timeout"] = httpx.Timeout(300.0, connect=10.0) self._http_client = httpx.AsyncClient(**client_kwargs) diff --git a/tests/agent/tools/test_a2a_tool.py b/tests/agent/tools/test_a2a_tool.py index de90967ef..ce6350b94 100644 --- a/tests/agent/tools/test_a2a_tool.py +++ b/tests/agent/tools/test_a2a_tool.py @@ -1,15 +1,19 @@ -"""Tests for A2A tool URL resolution. - -Focuses on the behavior of preferring the UiPath-hosted proxy URL -(``a2a_url``) over any URL cached in the agent card. -""" +"""Tests for the A2A tool: URL resolution and trace-context propagation.""" +import re from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch import pytest +from a2a.types import AgentCard from uipath.agent.models.agent import AgentA2aResourceConfig from uipath_langchain.agent.tools.a2a.a2a_tool import ( + A2aClient, + _build_client_headers, + _build_traceparent, + _coerce_span_id, + _normalize_trace_id, _resolve_a2a_url, create_a2a_tools_and_clients, ) @@ -88,3 +92,125 @@ def test_create_tools_skips_disabled_resource() -> None: assert tools == [] assert clients == [] + + +_GUID = "12345678-9abc-def0-1234-56789abcdef0" +_HEX32 = "123456789abcdef0123456789abcdef0" +_SPAN16 = "fedcba9876543210" +_TRACEPARENT_RE = re.compile(r"^00-[0-9a-f]{32}-[0-9a-f]{16}-01$") + + +class TestNormalizeTraceId: + def test_strips_dashes_and_lowercases_uuid(self): + assert _normalize_trace_id(_GUID.upper()) == _HEX32 + + def test_passes_through_hex(self): + assert _normalize_trace_id(_HEX32) == _HEX32 + + def test_rejects_wrong_length(self): + with pytest.raises(ValueError): + _normalize_trace_id("deadbeef") + + +class TestCoerceSpanId: + def test_accepts_16_hex(self): + assert _coerce_span_id(_SPAN16) == _SPAN16 + + def test_lowercases(self): + assert _coerce_span_id(_SPAN16.upper()) == _SPAN16 + + def test_rejects_none_and_empty(self): + assert _coerce_span_id(None) is None + assert _coerce_span_id("") is None + + def test_rejects_wrong_length(self): + assert _coerce_span_id("abc") is None + + def test_rejects_non_hex(self): + assert _coerce_span_id("zzzzzzzzzzzzzzzz") is None + + +class TestBuildTraceparent: + def test_none_without_trace_id(self, monkeypatch): + monkeypatch.delenv("UIPATH_TRACE_ID", raising=False) + assert _build_traceparent() is None + + def test_none_with_invalid_trace_id(self, monkeypatch): + monkeypatch.setenv("UIPATH_TRACE_ID", "not-a-trace-id") + assert _build_traceparent() is None + + def test_uses_trace_id_and_parent_span(self, monkeypatch): + monkeypatch.setenv("UIPATH_TRACE_ID", _HEX32) + monkeypatch.setenv("UIPATH_PARENT_SPAN_ID", _SPAN16) + assert _build_traceparent() == f"00-{_HEX32}-{_SPAN16}-01" + + def test_normalizes_uuid_trace_id(self, monkeypatch): + monkeypatch.setenv("UIPATH_TRACE_ID", _GUID) + monkeypatch.setenv("UIPATH_PARENT_SPAN_ID", _SPAN16) + assert _build_traceparent() == f"00-{_HEX32}-{_SPAN16}-01" + + def test_mints_parent_when_absent(self, monkeypatch): + monkeypatch.setenv("UIPATH_TRACE_ID", _HEX32) + monkeypatch.delenv("UIPATH_PARENT_SPAN_ID", raising=False) + result = _build_traceparent() + assert result is not None + assert _TRACEPARENT_RE.match(result) + assert result.startswith(f"00-{_HEX32}-") + + +class TestBuildClientHeaders: + def test_authorization_only_without_job_context(self, monkeypatch): + monkeypatch.delenv("UIPATH_TRACE_ID", raising=False) + monkeypatch.delenv("UIPATH_JOB_KEY", raising=False) + headers = _build_client_headers("tok") + assert headers["Authorization"] == "Bearer tok" + assert "traceparent" not in headers + assert "X-UiPath-JobKey" not in headers + + def test_includes_traceparent_and_job_key(self, monkeypatch): + monkeypatch.setenv("UIPATH_TRACE_ID", _HEX32) + monkeypatch.setenv("UIPATH_PARENT_SPAN_ID", _SPAN16) + monkeypatch.setenv("UIPATH_JOB_KEY", "job-123") + headers = _build_client_headers("tok") + assert headers["Authorization"] == "Bearer tok" + assert headers["traceparent"] == f"00-{_HEX32}-{_SPAN16}-01" + assert headers["X-UiPath-JobKey"] == "job-123" + + +class TestA2aClientGet: + async def test_get_applies_trace_and_job_headers(self, monkeypatch): + monkeypatch.setenv("UIPATH_TRACE_ID", _HEX32) + monkeypatch.setenv("UIPATH_PARENT_SPAN_ID", _SPAN16) + monkeypatch.setenv("UIPATH_JOB_KEY", "job-123") + + card = AgentCard( + url="https://example.test/a2a", + name="agent", + description="", + version="1.0.0", + skills=[], + capabilities={}, + default_input_modes=["text/plain"], + default_output_modes=["text/plain"], + ) + client = A2aClient(card) + sdk = MagicMock() + sdk._config.secret = "tok" + + with ( + patch("uipath.platform.UiPath", return_value=sdk), + patch( + "a2a.client.ClientFactory.connect", + new=AsyncMock(return_value=MagicMock()), + ), + ): + await client.get() + + try: + assert client._http_client is not None + headers = client._http_client.headers + assert headers["authorization"] == "Bearer tok" + assert headers["traceparent"] == f"00-{_HEX32}-{_SPAN16}-01" + assert headers["x-uipath-jobkey"] == "job-123" + finally: + await client.dispose() diff --git a/uv.lock b/uv.lock index e983af347..6246580ac 100644 --- a/uv.lock +++ b/uv.lock @@ -4413,7 +4413,7 @@ wheels = [ [[package]] name = "uipath-langchain" -version = "0.13.4" +version = "0.13.5" source = { editable = "." } dependencies = [ { name = "a2a-sdk" },