diff --git a/src/uipath/runtime/governance/native/guardrail_compensation.py b/src/uipath/runtime/governance/native/guardrail_compensation.py new file mode 100644 index 0000000..6e1752c --- /dev/null +++ b/src/uipath/runtime/governance/native/guardrail_compensation.py @@ -0,0 +1,353 @@ +"""Compensating governance for disabled centralized guardrails. + +When a ``guardrail_fallback`` rule fires (the guardrail is mapped to +UiPath but the centralized policy is disabled), the framework asks the +governance-server to run the real guardrail check via its +``/{org_id}/agenticgovernance_/api/v1/runtime/govern`` endpoint. + +This module owns only the **local concerns**: a bounded background +pool that schedules the call without blocking the agent hook, and a +trace-id capture that runs on the caller thread before the worker hop +(the worker has no OpenTelemetry context). + +The actual HTTP call — URL composition, auth, headers, JSON +serialisation, env-backed job-context auto-fill — is the +:class:`uipath.core.governance.GovernanceCompensationProvider`'s job. +Callers inject a concrete provider (typically +``uipath.platform.governance.UiPathPlatformGovernanceProvider``) and +this module just builds the :class:`GovernRequest` wire model and hands +it off. + +The call is **fire-and-forget**: the server runs the guardrail AND +writes the audit trace from its side. The agent doesn't inspect the +response — it only cares about whether the call reached the server. + +The compensator is **instance-scoped**: each :class:`GovernanceRuntime` +owns its own pool and semaphore. ``uipath eval`` parallel runtimes +don't share workers, queue slots, or saturation state — one runtime's +spam can't silently drop another's compensation calls. + +The compensator does **not** read host env vars. The trace id is +passed in by the wiring layer (uipath CLI → :class:`GovernanceRuntime` +→ :class:`GuardrailCompensator`). Inside the compensator, resolution +order is: constructor-supplied trace id → live OTel span on the caller +thread → per-call fallback. +""" + +from __future__ import annotations + +import atexit +import logging +import threading +import weakref +from concurrent.futures import ThreadPoolExecutor +from typing import Any + +from uipath.core.governance import ( + FiredRule, + GovernanceCompensationProvider, + GovernRequest, +) + +logger = logging.getLogger(__name__) + + +# ---------------------------------------------------------------------------- +# Process-wide cleanup machinery +# +# One ``atexit`` hook walks a ``WeakSet`` of live compensators on exit and +# closes each. Bounded atexit registrations (N runtimes → 1 hook, not N) and +# weakref tracking so a disposed compensator can be GC'd. Same pattern as +# :class:`uipath.runtime.governance._audit.base.AuditManager`. +# ---------------------------------------------------------------------------- + +_live_compensators: weakref.WeakSet[GuardrailCompensator] = weakref.WeakSet() +_atexit_registered = False +_atexit_lock = threading.Lock() + + +def _process_cleanup_compensators() -> None: + """Process-exit handler: close every live compensator.""" + for compensator in list(_live_compensators): + try: + compensator.close() + except Exception as exc: # noqa: BLE001 - exit cleanup must not raise + logger.debug("Compensator process cleanup error: %s", exc) + + +def _register_compensator_for_cleanup(compensator: GuardrailCompensator) -> None: + """Add ``compensator`` to the cleanup set + ensure atexit is wired once.""" + global _atexit_registered + _live_compensators.add(compensator) + if _atexit_registered: + return + with _atexit_lock: + if not _atexit_registered: + atexit.register(_process_cleanup_compensators) + _atexit_registered = True + + +# ---------------------------------------------------------------------------- +# Stateless helpers +# ---------------------------------------------------------------------------- + + +def disabled_guardrails(audit: Any, policy_index: Any) -> list[FiredRule]: + """Return per-rule metadata for each fired guardrail-fallback rule. + + A guardrail rule fires only when it is mapped to UiPath + (``mapped_to_uipath`` true) but disabled (``policy_enabled`` false) — + see the ``guardrail_fallback`` operator. The validator name (e.g. + ``pii_detection``) is read from the rule's ``guardrail_fallback`` + check config and used as the validator on the compensating call. + + One :class:`FiredRule` entry is emitted per matching + ``guardrail_fallback`` condition. Rules in this codebase declare a + single fallback condition each, so the returned list has one entry + per fired rule in practice; multi-condition rules would emit more + than one entry sharing the same ``rule_id``. + """ + out: list[FiredRule] = [] + for ev in audit.evaluations: + if not ev.matched: + continue + rule = policy_index.get_rule(ev.rule_id) + if rule is None: + continue + for check in rule.checks: + for cond in check.conditions: + if cond.operator != "guardrail_fallback": + continue + if not isinstance(cond.value, dict): + continue + # The ``guardrail_fallback`` operator at evaluation time + # only matches when ``mapped_to_uipath=True`` AND + # ``policy_enabled=False``. We re-check here defensively + # so a future code path that bypasses the evaluator (or + # a multi-condition rule that fired on a sibling check) + # can't trigger a compensation call for a guardrail + # that isn't actually disabled. + if not bool(cond.value.get("mapped_to_uipath", False)): + continue + if bool(cond.value.get("policy_enabled", True)): + continue + validator = str(cond.value.get("validator", "")) + if validator: + out.append( + FiredRule( + rule_id=ev.rule_id, + rule_name=ev.rule_name, + pack_name=getattr(rule, "pack_name", "") or "", + validator=validator, + ) + ) + return out + + +def _validators(rules: list[FiredRule]) -> list[str]: + """Distinct validator names from the fired rules, preserving order.""" + return list(dict.fromkeys(r.validator for r in rules if r.validator)) + + +def _resolve_trace_id(supplied: str | None, fallback: str) -> str: + """Resolve the agent's trace id while still on the caller thread. + + MUST be called before the background-pool hop in + :meth:`GuardrailCompensator.submit`: the worker thread that issues + the ``/govern`` call has no OpenTelemetry context, so resolving + there would fall back to a detached id — orphaning the + server-written compensation records from the agent's real trace. + + Resolution order: + + 1. ``supplied`` — the trace id the wiring layer passed into + :class:`GuardrailCompensator` at construction (typically read + from ``UIPATH_TRACE_ID`` by ``uipath`` CLI). Authoritative when + set: native governance audit spans are exported under that id + (the platform rebinds spans to the agent's run trace), so + server-written compensation records must land on the *same* id. + 2. Live OTel span trace id (32-char hex) — used when the wiring + layer didn't supply one and a current OTel context exists. + 3. ``fallback`` — the per-call value the caller passed to + ``submit``. Last resort. + + The function does **not** read host env vars. Env reading lives + in the wiring layer (per the boundary discipline applied across + the governance stack). + """ + if supplied: + return supplied + + try: + from opentelemetry import trace + + ctx = trace.get_current_span().get_span_context() + if ctx.is_valid: + return format(ctx.trace_id, "032x") + except Exception as exc: # noqa: BLE001 - tracing is best-effort; fall through + logger.debug("OTel trace-id lookup failed in _resolve_trace_id: %s", exc) + + return fallback + + +# ---------------------------------------------------------------------------- +# GuardrailCompensator +# ---------------------------------------------------------------------------- + + +class GuardrailCompensator: + """Instance-scoped compensating-governance dispatcher. + + Each :class:`GovernanceRuntime` constructs one. Owns: + + - A :class:`ThreadPoolExecutor` (default 4 workers) that runs the + ``/runtime/govern`` POST off the agent's hook thread. + - A :class:`threading.BoundedSemaphore` (default cap = workers × 4) + that bounds total in-flight submissions (running + queued) so a + misbehaving agent firing compensation faster than the server can + absorb can't grow memory without limit. Saturated submissions are + dropped with a warning. + + Process exit cancels queued work via a single process-level atexit + handler (see :func:`_process_cleanup_compensators`); running tasks + finish bounded by the provider's HTTP timeout. + + Fire-and-forget: :meth:`submit` returns immediately. The actual HTTP + work is delegated to :meth:`GovernanceCompensationProvider.compensate` + — this class never touches URL/headers/auth/JSON itself. + """ + + _DEFAULT_MAX_WORKERS = 4 + # Queue depth multiplier — total in-flight cap = max_workers × this. + _INFLIGHT_OVERSUBSCRIPTION = 4 + + def __init__( + self, + provider: GovernanceCompensationProvider, + *, + trace_id: str | None = None, + max_workers: int = _DEFAULT_MAX_WORKERS, + inflight_oversubscription: int = _INFLIGHT_OVERSUBSCRIPTION, + ) -> None: + """Construct a compensator bound to one provider. + + Args: + provider: The :class:`GovernanceCompensationProvider` that + actually fires the ``/runtime/govern`` POST. Typically + ``uipath.platform.governance.UiPathPlatformGovernanceProvider``. + trace_id: Trace id the wiring layer (uipath CLI) read from + ``UIPATH_TRACE_ID`` and propagated through + :class:`GovernanceRuntime`. Authoritative when set: + server-written compensation records land on the agent's + run trace. ``None`` (default) falls back to the live + OTel span / caller-supplied id at submit time. + max_workers: Concurrent worker threads in the pool. + inflight_oversubscription: How deep the work queue grows + before saturated submissions get dropped. Total cap is + ``max_workers * inflight_oversubscription``. + """ + self._provider = provider + self._trace_id = trace_id + self._inflight_cap = max_workers * inflight_oversubscription + self._pool = ThreadPoolExecutor( + max_workers=max_workers, + thread_name_prefix="governance-compensation", + ) + self._inflight = threading.BoundedSemaphore(self._inflight_cap) + _register_compensator_for_cleanup(self) + + def submit( + self, + rules: list[FiredRule], + data: dict[str, Any], + hook: str, + trace_id: str, + src_timestamp: str, + agent_name: str, + runtime_id: str, + ) -> None: + """Schedule a /runtime/govern call on the bounded background pool. + + Fire-and-forget. Returns immediately; the call runs on a worker + thread. When the in-flight queue is saturated the call is + dropped with a warning and the agent continues. + + ``rules`` is the per-rule metadata from :func:`disabled_guardrails`; + the validators sent to the guardrail API are derived from it. + + Never raises — including when the pool has already been shut down. + """ + if not rules: + return + + validators = _validators(rules) + if not validators: + return + + # Resolve the trace id HERE, on the caller (hook) thread where the + # agent's OTel span is still live. The provider.compensate call + # below runs on a background worker where that context is gone, + # so the resolved value is captured now and carried into the + # worker — ensuring the server writes compensation records under + # the agent's real trace, not a detached id. + trace_id = _resolve_trace_id(self._trace_id, trace_id) + + if not self._inflight.acquire(blocking=False): + logger.warning( + "Compensation pool saturated (>%d in flight); dropping call " + "(validators=[%s])", + self._inflight_cap, + ", ".join(validators), + ) + return + + request = GovernRequest( + validators=validators, + rules=rules, + data=data, + hook=hook, + trace_id=trace_id, + src_timestamp=src_timestamp, + agent_name=agent_name, + runtime_id=runtime_id, + ) + + provider = self._provider + inflight = self._inflight + + def _run() -> None: + try: + provider.compensate(request) + except Exception as exc: # noqa: BLE001 - fail-open by contract + logger.warning( + "Compensation worker failed (validators=[%s]): %s", + ", ".join(validators), + exc, + ) + finally: + inflight.release() + + try: + self._pool.submit(_run) + except RuntimeError as exc: + # Pool was shut down (atexit, dispose, or test teardown) — + # release the semaphore slot we took and log; never raise. + self._inflight.release() + logger.warning( + "Compensation pool unavailable (validators=[%s]): %s", + ", ".join(validators), + exc, + ) + + def close(self) -> None: + """Cancel queued tasks. Running tasks finish bounded by the provider HTTP timeout. + + ``wait=False`` returns immediately so caller / process shutdown + isn't held up; ``cancel_futures=True`` drops anything not yet + running. Idempotent — calling close on an already-closed pool + is a logged no-op. + """ + try: + self._pool.shutdown(wait=False, cancel_futures=True) + except Exception as exc: # noqa: BLE001 - shutdown must not raise + logger.debug("Compensator shutdown error: %s", exc) diff --git a/src/uipath/runtime/governance/runtime.py b/src/uipath/runtime/governance/runtime.py index c8f9dd9..be843c3 100644 --- a/src/uipath/runtime/governance/runtime.py +++ b/src/uipath/runtime/governance/runtime.py @@ -9,9 +9,9 @@ The wiring layer (uipath CLI) decides whether to construct ``GovernanceRuntime`` at all (feature flag, project config, etc.) and -passes ``is_conversational`` explicitly when it knows the agent type. -The runtime layer does not introspect the delegate's private attributes -to discover that. +passes ``is_conversational`` and ``trace_id`` explicitly. The runtime +layer does not introspect the delegate's private attributes nor read +env vars to discover those. **Staging caveat — policy loading only, no enforcement yet.** This module is the policy-loading scaffold: ``__init__`` constructs an @@ -19,7 +19,7 @@ prefetch. ``execute`` / ``stream`` / ``get_schema`` / ``dispose`` are pure passthroughs — no per-hook policy evaluation runs. The evaluator and framework adapter wiring that consumes the loader's policy index -lands in a follow-up slice. Customers constructing +and the ``trace_id`` lands in a follow-up slice. Customers constructing :class:`GovernanceRuntime` today get policy loading without policy enforcement; this is intentional and will change when the evaluator slice merges. @@ -68,6 +68,7 @@ def __init__( policy_provider: GovernancePolicyProvider | None, *, is_conversational: bool | None = None, + trace_id: str | None = None, ): """Initialize the governance runtime. @@ -83,8 +84,17 @@ def __init__( leaves the selector unset — the provider applies its default. The wiring layer (uipath CLI) is expected to pass the concrete value when it knows the agent type. + trace_id: Trace identifier the platform host has bound to + this run (typically read from ``UIPATH_TRACE_ID`` by + the wiring layer). The evaluator slice forwards this + into the :class:`GuardrailCompensator` so server-written + compensation records land on the agent's run trace + instead of a detached id. ``None`` (default) leaves + downstream consumers to fall back to the live OTel + span / caller-supplied value. """ self._delegate = delegate + self._trace_id = trace_id self._loader = PolicyLoader( policy_provider, is_conversational=is_conversational, @@ -100,6 +110,16 @@ def loader(self) -> PolicyLoader: """ return self._loader + @property + def trace_id(self) -> str | None: + """Trace id supplied by the wiring layer (or ``None``). + + Exposed so the evaluator slice can read it at hook-wire time + and pass it into the :class:`GuardrailCompensator` it + constructs. + """ + return self._trace_id + async def execute( self, input: dict[str, Any] | None = None, diff --git a/tests/test_governance_runtime.py b/tests/test_governance_runtime.py index 810a881..65286ce 100644 --- a/tests/test_governance_runtime.py +++ b/tests/test_governance_runtime.py @@ -211,6 +211,29 @@ def test_governance_runtime_with_none_provider_yields_empty_index() -> None: assert index.total_rules == 0 +def test_governance_runtime_stashes_trace_id() -> None: + """``trace_id`` constructor arg is exposed via the ``trace_id`` property. + + The wiring layer (uipath CLI) reads ``UIPATH_TRACE_ID`` from the + host env and passes the value in. The evaluator slice (future) + consumes it through :attr:`GovernanceRuntime.trace_id` and + forwards it into the :class:`GuardrailCompensator` constructor so + compensation records land on the agent's run trace. + """ + runtime = GovernanceRuntime( + _StubDelegate(), + policy_provider=None, + trace_id="wired-trace-0001", + ) + assert runtime.trace_id == "wired-trace-0001" + + +def test_governance_runtime_default_trace_id_is_none() -> None: + """Omitting ``trace_id`` leaves the property as ``None``.""" + runtime = GovernanceRuntime(_StubDelegate(), policy_provider=None) + assert runtime.trace_id is None + + async def test_governance_runtime_execute_delegates() -> None: delegate = _StubDelegate() runtime = GovernanceRuntime(delegate, policy_provider=None) diff --git a/tests/test_guardrail_compensation.py b/tests/test_guardrail_compensation.py new file mode 100644 index 0000000..c537fa7 --- /dev/null +++ b/tests/test_guardrail_compensation.py @@ -0,0 +1,576 @@ +"""Tests for the instance-scoped GuardrailCompensator. + +The runtime layer owns only the bounded background pool and the +trace-id capture; HTTP/auth/URL/header concerns live behind the +:class:`uipath.core.governance.GovernanceCompensationProvider` protocol +and are exercised in ``uipath-platform``'s own tests. + +These tests cover: + +- ``disabled_guardrails`` — distilling fired ``guardrail_fallback`` rules + into per-rule wire metadata. +- ``GuardrailCompensator.submit`` — pool routing, in-flight + backpressure, shutdown safety, wire-model assembly, and the + thread-boundary trace-id capture. +- ``_resolve_trace_id`` — env > live OTel span > fallback ordering. +- Cross-instance isolation — two compensators do not share a pool or + semaphore. +- Process-level cleanup — one ``atexit`` registration, weak refs only. +""" + +from __future__ import annotations + +import gc +import threading +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +from uipath.core.governance import ( + FiredRule, + GovernanceCompensationProvider, + GovernRequest, +) + +from uipath.runtime.governance.native import guardrail_compensation +from uipath.runtime.governance.native.guardrail_compensation import ( + GuardrailCompensator, + _resolve_trace_id, + disabled_guardrails, +) + +# Evaluator integration is not present on this branch — the evaluator +# module (which would consume the compensator) lands in a later slice. +# Tests that exercise the full dispatch path skip until then. +_HAS_EVALUATOR = False +try: + from uipath.runtime.governance.native.evaluator import ( # type: ignore[import-not-found] # noqa: F401 + GovernanceEvaluator, + ) + + _HAS_EVALUATOR = True +except ImportError: + pass + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _provider() -> MagicMock: + """Mock satisfying the GovernanceCompensationProvider protocol.""" + return MagicMock(spec=GovernanceCompensationProvider) + + +def _rules( + *validators: str, + rule_id: str = "R1", + rule_name: str = "n", + pack: str = "p", +) -> list[FiredRule]: + """Build a list of FiredRule wire models — one per validator.""" + return [ + FiredRule( + rule_id=rule_id, + rule_name=rule_name, + pack_name=pack, + validator=v, + ) + for v in validators + ] + + +def _run_inline(compensator: GuardrailCompensator) -> None: + """Replace the pool's ``submit`` with synchronous execution. + + Lets tests assert provider behavior deterministically without + relying on wait()/sleep(). + """ + + def _sync_submit(fn: Any, *args: Any, **kwargs: Any) -> None: + fn() + + compensator._pool.submit = _sync_submit # type: ignore[method-assign] + + +@pytest.fixture(autouse=True) +def _close_dangling_compensators() -> Any: + """Best-effort teardown: close any compensator weak-refs still in the set. + + Each test should call ``compensator.close()``, but a failing + assertion mid-test could leak. The sweep prevents pytest from + hanging at exit on a leftover worker pool. + """ + yield + for compensator in list(guardrail_compensation._live_compensators): + try: + compensator.close() + except Exception: # noqa: BLE001 - best-effort teardown + pass + guardrail_compensation._live_compensators.clear() + + +# --------------------------------------------------------------------------- +# disabled_guardrails +# --------------------------------------------------------------------------- + + +def test_disabled_guardrails_returns_fired_rule_for_matched_disabled_guardrail() -> None: + cond = SimpleNamespace( + operator="guardrail_fallback", + value={ + "validator": "pii_detection", + "mapped_to_uipath": True, + "policy_enabled": False, + }, + ) + rule = SimpleNamespace(checks=[SimpleNamespace(conditions=[cond])], pack_name="") + audit = SimpleNamespace( + evaluations=[ + SimpleNamespace(matched=True, rule_id="R1", rule_name="PII guardrail") + ] + ) + policy_index = SimpleNamespace( + get_rule=lambda rid: rule if rid == "R1" else None + ) + + out = disabled_guardrails(audit, policy_index) + + assert len(out) == 1 + fr = out[0] + assert isinstance(fr, FiredRule) + assert fr.rule_id == "R1" + assert fr.rule_name == "PII guardrail" + assert fr.pack_name == "" + assert fr.validator == "pii_detection" + + +def test_disabled_guardrails_skips_unmatched_evaluations() -> None: + audit = SimpleNamespace( + evaluations=[SimpleNamespace(matched=False, rule_id="R1", rule_name="x")] + ) + policy_index = SimpleNamespace(get_rule=lambda rid: None) + assert disabled_guardrails(audit, policy_index) == [] + + +def test_disabled_guardrails_skips_non_guardrail_conditions() -> None: + cond = SimpleNamespace(operator="regex", value="some-pattern") + rule = SimpleNamespace(checks=[SimpleNamespace(conditions=[cond])]) + audit = SimpleNamespace( + evaluations=[SimpleNamespace(matched=True, rule_id="R1", rule_name="x")] + ) + policy_index = SimpleNamespace(get_rule=lambda rid: rule) + assert disabled_guardrails(audit, policy_index) == [] + + +def test_disabled_guardrails_skips_enabled_guardrails() -> None: + """Mapped to UiPath AND enabled → no compensation needed.""" + cond = SimpleNamespace( + operator="guardrail_fallback", + value={ + "validator": "pii_detection", + "mapped_to_uipath": True, + "policy_enabled": True, + }, + ) + rule = SimpleNamespace(checks=[SimpleNamespace(conditions=[cond])], pack_name="") + audit = SimpleNamespace( + evaluations=[SimpleNamespace(matched=True, rule_id="R1", rule_name="x")] + ) + policy_index = SimpleNamespace(get_rule=lambda rid: rule) + assert disabled_guardrails(audit, policy_index) == [] + + +def test_disabled_guardrails_skips_unmapped_guardrails() -> None: + """Not mapped to UiPath → server can't fall back; skip.""" + cond = SimpleNamespace( + operator="guardrail_fallback", + value={ + "validator": "pii_detection", + "mapped_to_uipath": False, + "policy_enabled": False, + }, + ) + rule = SimpleNamespace(checks=[SimpleNamespace(conditions=[cond])], pack_name="") + audit = SimpleNamespace( + evaluations=[SimpleNamespace(matched=True, rule_id="R1", rule_name="x")] + ) + policy_index = SimpleNamespace(get_rule=lambda rid: rule) + assert disabled_guardrails(audit, policy_index) == [] + + +# --------------------------------------------------------------------------- +# GuardrailCompensator.submit — short-circuits + pool routing + backpressure +# --------------------------------------------------------------------------- + + +def test_submit_empty_rules_short_circuits() -> None: + """No rules → no pool submit, no provider call.""" + provider = _provider() + compensator = GuardrailCompensator(provider) + with patch.object(compensator, "_pool") as mock_pool: + compensator.submit([], {}, "before_model", "t", "ts", "a", "r") + mock_pool.submit.assert_not_called() + provider.compensate.assert_not_called() + + +def test_submit_no_validators_short_circuits() -> None: + """Rules with empty validator strings → no call (nothing to dispatch).""" + provider = _provider() + compensator = GuardrailCompensator(provider) + rules = [FiredRule(rule_id="R", rule_name="n", pack_name="p", validator="")] + with patch.object(compensator, "_pool") as mock_pool: + compensator.submit(rules, {}, "before_model", "t", "ts", "a", "r") + mock_pool.submit.assert_not_called() + provider.compensate.assert_not_called() + + +def test_submit_routes_through_pool() -> None: + """A non-empty rules list submits a single task to the pool.""" + provider = _provider() + compensator = GuardrailCompensator(provider) + with patch.object(compensator, "_pool") as mock_pool: + compensator.submit( + _rules("pii_detection"), + {"content": "x"}, + "before_model", + "trace-1", + "ts", + "agent", + "run", + ) + mock_pool.submit.assert_called_once() + + +def test_submit_drops_when_pool_saturated() -> None: + """When the in-flight semaphore is exhausted, the call is dropped.""" + provider = _provider() + compensator = GuardrailCompensator(provider) + + # Force the semaphore into "exhausted" state. + drained = threading.BoundedSemaphore(1) + drained.acquire() # next acquire(blocking=False) returns False + compensator._inflight = drained + + with patch.object(compensator, "_pool") as mock_pool: + compensator.submit( + _rules("pii_detection"), + {}, + "before_model", + "trace-1", + "ts", + "agent", + "run", + ) + + mock_pool.submit.assert_not_called() + provider.compensate.assert_not_called() + + +def test_submit_swallows_pool_shutdown_runtimeerror() -> None: + """If the pool was shut down, submit must not raise.""" + + class _ShutdownPool: + def submit(self, fn: Any, *args: Any, **kwargs: Any) -> None: + raise RuntimeError("cannot schedule new futures after shutdown") + + compensator = GuardrailCompensator(_provider()) + compensator._pool = _ShutdownPool() # type: ignore[assignment] + compensator._inflight = threading.BoundedSemaphore(4) + + # Must not raise. + compensator.submit(_rules("x"), {}, "before_model", "t", "ts", "a", "r") + + +# --------------------------------------------------------------------------- +# GuardrailCompensator.submit — wire-model assembly + provider invocation +# --------------------------------------------------------------------------- + + +def test_submit_invokes_provider_with_govern_request() -> None: + """The provider receives a GovernRequest carrying every wire field.""" + provider = _provider() + compensator = GuardrailCompensator(provider) + _run_inline(compensator) + rules = _rules("pii_detection", "harmful_content") + + compensator.submit( + rules, + {"content": "x"}, + "before_model", + "trace-1", + "2026-06-06T00:00:00Z", + "langchain", + "patch-langchain", + ) + + provider.compensate.assert_called_once() + (request,) = provider.compensate.call_args.args + assert isinstance(request, GovernRequest) + # distinct validators drive the guardrail API call + assert request.validators == ["pii_detection", "harmful_content"] + assert request.rules == rules + assert request.data == {"content": "x"} + assert request.hook == "before_model" + assert request.trace_id == "trace-1" + assert request.src_timestamp == "2026-06-06T00:00:00Z" + assert request.agent_name == "langchain" + assert request.runtime_id == "patch-langchain" + # Job-context fields are left for the provider to auto-fill from env. + assert request.folder_key is None + assert request.job_key is None + assert request.process_key is None + assert request.reference_id is None + assert request.agent_version is None + + +def test_submit_dedupes_validators() -> None: + """Multiple rules with the same validator collapse on the wire.""" + provider = _provider() + compensator = GuardrailCompensator(provider) + _run_inline(compensator) + rules = _rules("pii_detection") + _rules("pii_detection", rule_id="R2") + + compensator.submit(rules, {}, "before_model", "t", "ts", "a", "r") + + (request,) = provider.compensate.call_args.args + assert request.validators == ["pii_detection"] + # Per-rule metadata is preserved (one record per rule even with shared validator). + assert len(request.rules) == 2 + + +def test_submit_swallows_provider_errors() -> None: + """A provider exception must never propagate to the caller / agent.""" + provider = _provider() + provider.compensate.side_effect = RuntimeError("network down") + compensator = GuardrailCompensator(provider) + _run_inline(compensator) + + # Must not raise. + compensator.submit(_rules("x"), {}, "before_model", "t", "ts", "a", "r") + + provider.compensate.assert_called_once() + + +def test_submit_releases_semaphore_on_provider_error() -> None: + """Provider failure must not leak a semaphore slot.""" + provider = _provider() + provider.compensate.side_effect = RuntimeError("transient") + # 4 workers × 1 oversubscription = 4 slots total. + compensator = GuardrailCompensator(provider, inflight_oversubscription=1) + _run_inline(compensator) + + # Fire 8 — all 8 must reach the provider; the semaphore must release + # on each error so the next submit can acquire. + for _ in range(8): + compensator.submit(_rules("x"), {}, "before_model", "t", "ts", "a", "r") + + assert provider.compensate.call_count == 8, ( + "All 8 submissions should fire — semaphore must release on error" + ) + + +# --------------------------------------------------------------------------- +# _resolve_trace_id — must capture the live trace on the caller thread +# --------------------------------------------------------------------------- + + +def test_resolve_trace_id_prefers_supplied_over_active_span() -> None: + """Constructor-supplied trace id wins over a live span. + + The wiring layer (uipath CLI) reads ``UIPATH_TRACE_ID`` and passes + the value into :class:`GuardrailCompensator`. That id is + authoritative because native governance audit spans are exported + under it (platform rebinds spans to the agent's run trace) and + server-written compensation records must land on the same id. + """ + from opentelemetry.sdk.trace import TracerProvider + + tracer = TracerProvider().get_tracer("test") + with tracer.start_as_current_span("root"): + assert _resolve_trace_id("supplied-0001", "fallback-id") == "supplied-0001" + + +def test_resolve_trace_id_falls_back_to_active_span_when_not_supplied() -> None: + """No supplied id → the live span's trace id is used.""" + from opentelemetry.sdk.trace import TracerProvider + + tracer = TracerProvider().get_tracer("test") + with tracer.start_as_current_span("root") as span: + expected = format(span.get_span_context().trace_id, "032x") + result = _resolve_trace_id(None, "fallback-id") + assert result == expected + assert len(result) == 32 # dashless OTel hex, not a dashed uuid + + +def test_resolve_trace_id_uses_fallback_without_context() -> None: + """No supplied id and no active span → fallback wins.""" + assert _resolve_trace_id(None, "fallback-id") == "fallback-id" + + +def test_resolve_trace_id_does_not_read_env(monkeypatch: pytest.MonkeyPatch) -> None: + """Runtime layer must not read host env vars; only the wiring layer does. + + Pin radu's PR #121 boundary rule for this code path. Even when + ``UIPATH_TRACE_ID`` is set in the environment, ``_resolve_trace_id`` + ignores it — the wiring layer is solely responsible for env reads. + """ + monkeypatch.setenv("UIPATH_TRACE_ID", "env-should-be-ignored") + # No supplied, no active span → fallback should win, NOT the env value. + assert _resolve_trace_id(None, "fallback-id") == "fallback-id" + + +def test_compensator_trace_id_overrides_caller_supplied_value() -> None: + """A compensator constructed with ``trace_id`` stamps it on every dispatch. + + The wiring layer passes ``UIPATH_TRACE_ID`` into the compensator at + construction; per-call ``trace_id`` arguments become only a fallback + for the case where the constructor value is absent. + """ + provider = _provider() + compensator = GuardrailCompensator(provider, trace_id="wired-trace-0001") + _run_inline(compensator) + + compensator.submit( + _rules("pii_detection"), + {}, + "before_model", + "per-call-fallback", # must lose to the constructor value + "ts", + "agent", + "run", + ) + + (request,) = provider.compensate.call_args.args + assert request.trace_id == "wired-trace-0001" + + +def test_submit_captures_live_trace_before_thread_hop() -> None: + """End-to-end thread-boundary proof. + + ``submit`` runs on the caller (hook) thread, then hands the + compensation call to a background worker pool. The trace id must + be resolved on the caller (where the OTel span is live) and + carried into the worker — the worker has no live OTel context. + """ + from opentelemetry.sdk.trace import TracerProvider + + tracer = TracerProvider().get_tracer("test") + provider = _provider() + compensator = GuardrailCompensator(provider) + + done = threading.Event() + captured: dict[str, Any] = {} + + def _capture(request: GovernRequest) -> None: + # Runs on the background worker thread. + captured["trace_id"] = request.trace_id + # Prove the worker has NO live context: resolving here with no + # supplied id and no live span falls all the way through to the + # WORKER-MISS sentinel. + captured["worker_resolves_to"] = _resolve_trace_id(None, "WORKER-MISS") + done.set() + + provider.compensate.side_effect = _capture + + with tracer.start_as_current_span("agent-run") as span: + expected = format(span.get_span_context().trace_id, "032x") + compensator.submit( + _rules("pii_detection"), + {"content": "x"}, + "before_model", + "stale-fallback", # must be overridden by the live trace + "2026-06-06T00:00:00Z", + "agent", + "rt", + ) + assert done.wait(timeout=2.0), "compensation worker never ran" + + # (1) worker thread could not see the span — fell back to the sentinel + assert captured["worker_resolves_to"] == "WORKER-MISS" + # (2) the value the provider received is the live span trace, captured pre-hop + assert captured["trace_id"] == expected + assert captured["trace_id"] != "stale-fallback" + + +# --------------------------------------------------------------------------- +# Cross-instance isolation — the architectural motivation for the refactor +# --------------------------------------------------------------------------- + + +def test_two_compensators_do_not_share_pool_or_semaphore() -> None: + """Parallel runtimes cannot saturate each other's compensation pool.""" + p1 = _provider() + p2 = _provider() + c1 = GuardrailCompensator(p1) + c2 = GuardrailCompensator(p2) + + assert c1._pool is not c2._pool + assert c1._inflight is not c2._inflight + + # Drain c1's semaphore to its cap; c2 must remain unaffected. + drained = threading.BoundedSemaphore(1) + drained.acquire() + c1._inflight = drained + + _run_inline(c2) + c2.submit(_rules("pii_detection"), {}, "before_model", "t", "ts", "a", "r") + p2.compensate.assert_called_once() + p1.compensate.assert_not_called() + + +# --------------------------------------------------------------------------- +# Lifecycle — bounded atexit + weakref tracking (mirrors AuditManager pattern) +# --------------------------------------------------------------------------- + + +def test_three_compensators_register_one_process_atexit_hook() -> None: + """N compensators → 1 atexit registration, not N. + + Regression: a per-instance ``atexit.register(self.close)`` would + grow the atexit list linearly. The fix routes everyone through one + process-level cleanup hook keyed by a WeakSet. + """ + with patch.object(guardrail_compensation.atexit, "register") as mock_register: + guardrail_compensation._atexit_registered = False + GuardrailCompensator(_provider()) + GuardrailCompensator(_provider()) + GuardrailCompensator(_provider()) + assert mock_register.call_count == 1, ( + "Each compensator must NOT register its own atexit handler" + ) + + +def test_disposed_compensator_can_be_garbage_collected() -> None: + """The WeakSet must NOT keep a disposed compensator alive.""" + import weakref + + compensator = GuardrailCompensator(_provider()) + ref = weakref.ref(compensator) + + assert compensator in guardrail_compensation._live_compensators + + compensator.close() + del compensator + gc.collect() + + assert ref() is None, ( + "GuardrailCompensator kept alive — strong reference leak in cleanup machinery" + ) + + +def test_process_cleanup_handles_already_closed_compensator() -> None: + """If a compensator was explicitly closed, the process hook is a no-op for it.""" + c = GuardrailCompensator(_provider()) + c.close() + # Must not raise. + guardrail_compensation._process_cleanup_compensators() + + +def test_close_is_idempotent() -> None: + """Calling close() twice is a logged no-op, not a crash.""" + c = GuardrailCompensator(_provider()) + c.close() + c.close() # must not raise