From bd84f91a56ff563c23a73a0a28a80722d840dc8e Mon Sep 17 00:00:00 2001
From: "David J. M. Karlsen"
Date: Wed, 10 Jun 2026 17:26:57 +0200
Subject: [PATCH] fix: honor ModelConfig.spec.tls in the embedding path

The chat/LLM path threads spec.tls (disableVerify / caCertSecretRef /
disableSystemCAs) into a custom httpx client, but the embedding path
ignored TLS config entirely: EmbeddingConfig carried no TLS fields and
_embed_openai built the OpenAI client with no http_client override.
As a result, embeddings (session-memory auto-save) fail behind a
TLS-inspecting proxy even when the embedding ModelConfig sets
spec.tls.disableVerify: true.

This is a full-chain fix: the TLS settings only reach the Python agent
via the Go-serialized config, so a Python-only change would be a no-op.

Go:
- Add TLSInsecureSkipVerify / TLSCACertPath / TLSDisableSystemCAs to
  EmbeddingConfig (+ UnmarshalJSON parsing).
- ModelToEmbeddingConfig copies the TLS pointers from each model
  variant's BaseModel. (translateEmbeddingConfig already runs
  deriveTLSFields via translateModel, so the source model carries them.)

Python:
- Add tls_disable_verify / tls_ca_cert_path / tls_disable_system_cas to
  EmbeddingConfig (accepting the Go wire name tls_insecure_skip_verify
  via AliasChoices).
- _embed_openai builds
  http_client=httpx.AsyncClient(verify=create_ssl_context(...))
  for both AsyncOpenAI and AsyncAzureOpenAI, only when a TLS field is
  set so the default path is unchanged.

Adds Go and Python unit tests covering propagation and client wiring.

Fixes #1992

Signed-off-by: David J. M.
Karlsen --- go/api/adk/types.go | 37 ++++- go/api/adk/types_test.go | 83 ++++++++++ .../src/kagent/adk/models/_embedding.py | 36 ++++- .../kagent-adk/src/kagent/adk/types.py | 10 ++ .../unittests/models/test_embedding_tls.py | 143 ++++++++++++++++++ 5 files changed, 303 insertions(+), 6 deletions(-) create mode 100644 python/packages/kagent-adk/tests/unittests/models/test_embedding_tls.py diff --git a/go/api/adk/types.go b/go/api/adk/types.go index 4f15f9f899..cd93c4486d 100644 --- a/go/api/adk/types.go +++ b/go/api/adk/types.go @@ -375,20 +375,32 @@ type EmbeddingConfig struct { Provider string `json:"provider"` Model string `json:"model"` BaseUrl string `json:"base_url,omitempty"` + + // TLS/SSL configuration (mirrors BaseModel) so the embedding HTTP client + // honours the same ModelConfig.spec.tls as the chat path. + TLSInsecureSkipVerify *bool `json:"tls_insecure_skip_verify,omitempty"` + TLSCACertPath *string `json:"tls_ca_cert_path,omitempty"` + TLSDisableSystemCAs *bool `json:"tls_disable_system_cas,omitempty"` } func (e *EmbeddingConfig) UnmarshalJSON(data []byte) error { var tmp struct { - Type string `json:"type"` - Provider string `json:"provider"` - Model string `json:"model"` - BaseUrl string `json:"base_url"` + Type string `json:"type"` + Provider string `json:"provider"` + Model string `json:"model"` + BaseUrl string `json:"base_url"` + TLSInsecureSkipVerify *bool `json:"tls_insecure_skip_verify"` + TLSCACertPath *string `json:"tls_ca_cert_path"` + TLSDisableSystemCAs *bool `json:"tls_disable_system_cas"` } if err := json.Unmarshal(data, &tmp); err != nil { return err } e.Model = tmp.Model e.BaseUrl = tmp.BaseUrl + e.TLSInsecureSkipVerify = tmp.TLSInsecureSkipVerify + e.TLSCACertPath = tmp.TLSCACertPath + e.TLSDisableSystemCAs = tmp.TLSDisableSystemCAs if tmp.Provider != "" { e.Provider = tmp.Provider } else { @@ -404,28 +416,45 @@ func ModelToEmbeddingConfig(m Model) *EmbeddingConfig { return nil } e := &EmbeddingConfig{Provider: m.GetType()} + // 
copyTLS copies the TLS pointer fields from a model's embedded BaseModel + // onto the EmbeddingConfig so the Python embedding client honours the same + // ModelConfig.spec.tls as the chat path. + copyTLS := func(b BaseModel) { + e.TLSInsecureSkipVerify = b.TLSInsecureSkipVerify + e.TLSCACertPath = b.TLSCACertPath + e.TLSDisableSystemCAs = b.TLSDisableSystemCAs + } switch v := m.(type) { case *OpenAI: e.Model = v.Model e.BaseUrl = v.BaseUrl + copyTLS(v.BaseModel) case *AzureOpenAI: e.Model = v.Model + copyTLS(v.BaseModel) case *Anthropic: e.Model = v.Model e.BaseUrl = v.BaseUrl + copyTLS(v.BaseModel) case *GeminiVertexAI: e.Model = v.Model + copyTLS(v.BaseModel) case *GeminiAnthropic: e.Model = v.Model + copyTLS(v.BaseModel) case *Ollama: e.Model = v.Model + copyTLS(v.BaseModel) case *Gemini: e.Model = v.Model + copyTLS(v.BaseModel) case *Bedrock: e.Model = v.Model + copyTLS(v.BaseModel) case *SAPAICore: e.Model = v.Model e.BaseUrl = v.BaseUrl + copyTLS(v.BaseModel) default: e.Model = "" } diff --git a/go/api/adk/types_test.go b/go/api/adk/types_test.go index d1d9c8c06b..7fd9122e00 100644 --- a/go/api/adk/types_test.go +++ b/go/api/adk/types_test.go @@ -874,6 +874,89 @@ func TestEmbeddingConfig_UnmarshalJSON_ProviderOverridesType(t *testing.T) { } } +// TestModelToEmbeddingConfig_PropagatesTLS asserts that ModelToEmbeddingConfig +// copies the TLS fields from a model's embedded BaseModel onto the +// EmbeddingConfig, so the embedding HTTP client in the Python runtime honours +// the same ModelConfig.spec.tls as the chat/LLM path (upstream issue #1992). 
+func TestModelToEmbeddingConfig_PropagatesTLS(t *testing.T) { + base := BaseModel{ + Model: "text-embedding-3-small", + TLSInsecureSkipVerify: new(true), + TLSCACertPath: new("/etc/ssl/certs/custom/corp-ca/ca.crt"), + TLSDisableSystemCAs: new(false), + } + + tests := []struct { + name string + model Model + }{ + {name: "OpenAI", model: &OpenAI{BaseModel: base, BaseUrl: "https://litellm.internal.corp:8080"}}, + {name: "AzureOpenAI", model: &AzureOpenAI{BaseModel: base}}, + {name: "Anthropic", model: &Anthropic{BaseModel: base}}, + {name: "GeminiVertexAI", model: &GeminiVertexAI{BaseModel: base}}, + {name: "GeminiAnthropic", model: &GeminiAnthropic{BaseModel: base}}, + {name: "Ollama", model: &Ollama{BaseModel: base}}, + {name: "Gemini", model: &Gemini{BaseModel: base}}, + {name: "Bedrock", model: &Bedrock{BaseModel: base}}, + {name: "SAPAICore", model: &SAPAICore{BaseModel: base}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := ModelToEmbeddingConfig(tt.model) + if e == nil { + t.Fatal("ModelToEmbeddingConfig() returned nil") + } + if e.TLSInsecureSkipVerify == nil || !*e.TLSInsecureSkipVerify { + t.Errorf("TLSInsecureSkipVerify = %v, want pointer to true", e.TLSInsecureSkipVerify) + } + if e.TLSCACertPath == nil || *e.TLSCACertPath != "/etc/ssl/certs/custom/corp-ca/ca.crt" { + t.Errorf("TLSCACertPath = %v, want pointer to %q", e.TLSCACertPath, "/etc/ssl/certs/custom/corp-ca/ca.crt") + } + if e.TLSDisableSystemCAs == nil || *e.TLSDisableSystemCAs { + t.Errorf("TLSDisableSystemCAs = %v, want pointer to false", e.TLSDisableSystemCAs) + } + }) + } +} + +// TestModelToEmbeddingConfig_NoTLS asserts that a model without TLS config +// yields an EmbeddingConfig with nil TLS pointers, so the Python embedding +// client keeps its default (no custom httpx client) behaviour. 
+func TestModelToEmbeddingConfig_NoTLS(t *testing.T) { + e := ModelToEmbeddingConfig(&OpenAI{BaseModel: BaseModel{Model: "text-embedding-3-small"}}) + if e == nil { + t.Fatal("ModelToEmbeddingConfig() returned nil") + } + if e.TLSInsecureSkipVerify != nil || e.TLSCACertPath != nil || e.TLSDisableSystemCAs != nil { + t.Errorf("expected nil TLS pointers, got %v %v %v", e.TLSInsecureSkipVerify, e.TLSCACertPath, e.TLSDisableSystemCAs) + } +} + +func TestEmbeddingConfig_UnmarshalJSON_TLSFields(t *testing.T) { + data := []byte(`{ + "provider":"openai", + "model":"text-embedding-3-small", + "base_url":"https://litellm.internal.corp:8080", + "tls_insecure_skip_verify":true, + "tls_ca_cert_path":"/etc/ssl/certs/custom/corp-ca/ca.crt", + "tls_disable_system_cas":false + }`) + var cfg EmbeddingConfig + if err := json.Unmarshal(data, &cfg); err != nil { + t.Fatalf("UnmarshalJSON() error = %v", err) + } + if cfg.TLSInsecureSkipVerify == nil || !*cfg.TLSInsecureSkipVerify { + t.Errorf("TLSInsecureSkipVerify = %v, want pointer to true", cfg.TLSInsecureSkipVerify) + } + if cfg.TLSCACertPath == nil || *cfg.TLSCACertPath != "/etc/ssl/certs/custom/corp-ca/ca.crt" { + t.Errorf("TLSCACertPath = %v, want pointer to ca.crt path", cfg.TLSCACertPath) + } + if cfg.TLSDisableSystemCAs == nil || *cfg.TLSDisableSystemCAs { + t.Errorf("TLSDisableSystemCAs = %v, want pointer to false", cfg.TLSDisableSystemCAs) + } +} + func TestAgentConfig_ScanAndValue(t *testing.T) { original := AgentConfig{ Model: &OpenAI{BaseModel: BaseModel{Model: "gpt-4o"}}, diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_embedding.py b/python/packages/kagent-adk/src/kagent/adk/models/_embedding.py index 09810e1e5d..d94350d15f 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/_embedding.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/_embedding.py @@ -14,8 +14,10 @@ import os from typing import Any, List, Union +import httpx import numpy as np +from kagent.adk.models._ssl import 
create_ssl_context from kagent.adk.types import EmbeddingConfig logger = logging.getLogger(__name__) @@ -142,9 +144,33 @@ def _normalize_l2(self, x: Union[List[float], np.ndarray]) -> np.ndarray: norm = np.linalg.norm(x, 2, axis=1, keepdims=True) return np.where(norm == 0, x, x / norm) + def _has_tls_config(self) -> bool: + """Return True if any TLS field is set on the embedding config.""" + return bool( + getattr(self.config, "tls_disable_verify", None) + or getattr(self.config, "tls_ca_cert_path", None) + or getattr(self.config, "tls_disable_system_cas", None) + ) + + def _tls_http_client(self) -> "httpx.AsyncClient | None": + """Build a TLS-aware httpx client from the embedding config. + + Returns None when no TLS field is set so the OpenAI SDK keeps its + default http client (preserving the pre-existing default behaviour). + """ + if not self._has_tls_config(): + return None + verify = create_ssl_context( + disable_verify=bool(getattr(self.config, "tls_disable_verify", None)), + ca_cert_path=getattr(self.config, "tls_ca_cert_path", None), + disable_system_cas=bool(getattr(self.config, "tls_disable_system_cas", None)), + ) + return httpx.AsyncClient(verify=verify) + async def _embed_openai(self, texts: List[str]) -> List[List[float]]: """Embed using the OpenAI or Azure OpenAI SDK.""" provider = self.config.provider.lower() + http_client = self._tls_http_client() if provider == "azure_openai": from openai import AsyncAzureOpenAI @@ -153,11 +179,17 @@ async def _embed_openai(self, texts: List[str]) -> List[List[float]]: api_base = self.config.base_url or os.environ.get("AZURE_OPENAI_ENDPOINT") if not api_base: raise ValueError("Azure OpenAI endpoint must be set via base_url or AZURE_OPENAI_ENDPOINT env var") - client = AsyncAzureOpenAI(api_version=api_version, azure_endpoint=api_base) + azure_kwargs: dict[str, Any] = {"api_version": api_version, "azure_endpoint": api_base} + if http_client is not None: + azure_kwargs["http_client"] = http_client + client = 
AsyncAzureOpenAI(**azure_kwargs) else: from openai import AsyncOpenAI - client = AsyncOpenAI(base_url=self.config.base_url or None) + openai_kwargs: dict[str, Any] = {"base_url": self.config.base_url or None} + if http_client is not None: + openai_kwargs["http_client"] = http_client + client = AsyncOpenAI(**openai_kwargs) response = await client.embeddings.create( model=self.config.model, diff --git a/python/packages/kagent-adk/src/kagent/adk/types.py b/python/packages/kagent-adk/src/kagent/adk/types.py index 889ed73e0b..f56b1d18df 100644 --- a/python/packages/kagent-adk/src/kagent/adk/types.py +++ b/python/packages/kagent-adk/src/kagent/adk/types.py @@ -349,6 +349,16 @@ class EmbeddingConfig(BaseModel): provider: str base_url: str | None = None + # TLS/SSL configuration mirrors BaseLLM so the embedding client honours the + # same ModelConfig.spec.tls as the chat path. The Go controller serialises + # the disable-verify flag as "tls_insecure_skip_verify"; accept both names. + tls_disable_verify: bool | None = Field( + default=None, + validation_alias=AliasChoices("tls_disable_verify", "tls_insecure_skip_verify"), + ) + tls_ca_cert_path: str | None = None + tls_disable_system_cas: bool | None = None + class MemoryConfig(BaseModel): """Memory configuration. Its presence signals that memory is enabled.""" diff --git a/python/packages/kagent-adk/tests/unittests/models/test_embedding_tls.py b/python/packages/kagent-adk/tests/unittests/models/test_embedding_tls.py new file mode 100644 index 0000000000..f76e8b5010 --- /dev/null +++ b/python/packages/kagent-adk/tests/unittests/models/test_embedding_tls.py @@ -0,0 +1,143 @@ +"""Unit tests for embedding-client TLS configuration. 
+ +These tests verify that KAgentEmbedding._embed_openai honours the TLS fields +carried on EmbeddingConfig (upstream issue #1992): the OpenAI / Azure OpenAI +clients must receive an httpx client whose ``verify`` reflects the configured +ModelConfig.spec.tls, and the default (no-TLS) path must be unchanged. +""" + +import ssl +from unittest import mock + +import pytest + +from kagent.adk.models._embedding import KAgentEmbedding +from kagent.adk.types import EmbeddingConfig + + +@pytest.mark.asyncio +async def test_embed_openai_disable_verify_builds_client_with_verify_false(): + """disable_verify=True → create_ssl_context returns False → httpx verify=False.""" + config = EmbeddingConfig( + provider="openai", + model="text-embedding-3-small", + base_url="https://litellm.internal.corp:8080", + tls_insecure_skip_verify=True, + ) + embedding = KAgentEmbedding(config) + + with mock.patch("kagent.adk.models._embedding.create_ssl_context") as mock_create_ssl: + with mock.patch("kagent.adk.models._embedding.httpx.AsyncClient") as mock_httpx: + with mock.patch("openai.AsyncOpenAI") as mock_openai: + mock_create_ssl.return_value = False + mock_client = mock.MagicMock() + mock_httpx.return_value = mock_client + mock_openai.return_value.embeddings.create = mock.AsyncMock( + return_value=mock.MagicMock(data=[]) + ) + + await embedding._embed_openai(["hello"]) + + # SSL context built from the embedding TLS config + mock_create_ssl.assert_called_once_with( + disable_verify=True, + ca_cert_path=None, + disable_system_cas=False, + ) + # httpx client created with verify reflecting the config + mock_httpx.assert_called_once_with(verify=False) + # OpenAI client received the custom http_client + openai_kwargs = mock_openai.call_args[1] + assert openai_kwargs["http_client"] is mock_client + + +@pytest.mark.asyncio +async def test_embed_openai_custom_ca_builds_client_with_ssl_context(): + """A custom CA cert path is threaded into create_ssl_context and httpx verify.""" + config = 
EmbeddingConfig( + provider="openai", + model="text-embedding-3-small", + tls_ca_cert_path="/etc/ssl/certs/custom/corp-ca/ca.crt", + tls_disable_system_cas=True, + ) + embedding = KAgentEmbedding(config) + + with mock.patch("kagent.adk.models._embedding.create_ssl_context") as mock_create_ssl: + with mock.patch("kagent.adk.models._embedding.httpx.AsyncClient") as mock_httpx: + with mock.patch("openai.AsyncOpenAI") as mock_openai: + ssl_context = mock.MagicMock(spec=ssl.SSLContext) + mock_create_ssl.return_value = ssl_context + mock_client = mock.MagicMock() + mock_httpx.return_value = mock_client + mock_openai.return_value.embeddings.create = mock.AsyncMock( + return_value=mock.MagicMock(data=[]) + ) + + await embedding._embed_openai(["hello"]) + + mock_create_ssl.assert_called_once_with( + disable_verify=False, + ca_cert_path="/etc/ssl/certs/custom/corp-ca/ca.crt", + disable_system_cas=True, + ) + mock_httpx.assert_called_once_with(verify=ssl_context) + openai_kwargs = mock_openai.call_args[1] + assert openai_kwargs["http_client"] is mock_client + + +@pytest.mark.asyncio +async def test_embed_openai_no_tls_keeps_default_client(): + """Without any TLS field, no custom http client is built (default path).""" + config = EmbeddingConfig( + provider="openai", + model="text-embedding-3-small", + ) + embedding = KAgentEmbedding(config) + + with mock.patch("kagent.adk.models._embedding.create_ssl_context") as mock_create_ssl: + with mock.patch("kagent.adk.models._embedding.httpx.AsyncClient") as mock_httpx: + with mock.patch("openai.AsyncOpenAI") as mock_openai: + mock_openai.return_value.embeddings.create = mock.AsyncMock( + return_value=mock.MagicMock(data=[]) + ) + + await embedding._embed_openai(["hello"]) + + # No TLS config → no SSL context / httpx client built + mock_create_ssl.assert_not_called() + mock_httpx.assert_not_called() + # OpenAI client built without an explicit http_client + openai_kwargs = mock_openai.call_args[1] + assert "http_client" not in 
openai_kwargs + + +@pytest.mark.asyncio +async def test_embed_azure_openai_threads_tls_http_client(): + """Azure OpenAI client also receives the TLS-aware http client.""" + config = EmbeddingConfig( + provider="azure_openai", + model="text-embedding-3-small", + base_url="https://my-azure.openai.azure.com", + tls_insecure_skip_verify=True, + ) + embedding = KAgentEmbedding(config) + + with mock.patch("kagent.adk.models._embedding.create_ssl_context") as mock_create_ssl: + with mock.patch("kagent.adk.models._embedding.httpx.AsyncClient") as mock_httpx: + with mock.patch("openai.AsyncAzureOpenAI") as mock_azure: + mock_create_ssl.return_value = False + mock_client = mock.MagicMock() + mock_httpx.return_value = mock_client + mock_azure.return_value.embeddings.create = mock.AsyncMock( + return_value=mock.MagicMock(data=[]) + ) + + await embedding._embed_openai(["hello"]) + + mock_create_ssl.assert_called_once_with( + disable_verify=True, + ca_cert_path=None, + disable_system_cas=False, + ) + azure_kwargs = mock_azure.call_args[1] + assert azure_kwargs["http_client"] is mock_client