Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions go/api/adk/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,20 +375,32 @@ type EmbeddingConfig struct {
Provider string `json:"provider"`
Model string `json:"model"`
BaseUrl string `json:"base_url,omitempty"`

// TLS/SSL configuration (mirrors BaseModel) so the embedding HTTP client
// honours the same ModelConfig.spec.tls as the chat path.
TLSInsecureSkipVerify *bool `json:"tls_insecure_skip_verify,omitempty"`
TLSCACertPath *string `json:"tls_ca_cert_path,omitempty"`
TLSDisableSystemCAs *bool `json:"tls_disable_system_cas,omitempty"`
}

func (e *EmbeddingConfig) UnmarshalJSON(data []byte) error {
var tmp struct {
Type string `json:"type"`
Provider string `json:"provider"`
Model string `json:"model"`
BaseUrl string `json:"base_url"`
Type string `json:"type"`
Provider string `json:"provider"`
Model string `json:"model"`
BaseUrl string `json:"base_url"`
TLSInsecureSkipVerify *bool `json:"tls_insecure_skip_verify"`
TLSCACertPath *string `json:"tls_ca_cert_path"`
TLSDisableSystemCAs *bool `json:"tls_disable_system_cas"`
}
if err := json.Unmarshal(data, &tmp); err != nil {
return err
}
e.Model = tmp.Model
e.BaseUrl = tmp.BaseUrl
e.TLSInsecureSkipVerify = tmp.TLSInsecureSkipVerify
e.TLSCACertPath = tmp.TLSCACertPath
e.TLSDisableSystemCAs = tmp.TLSDisableSystemCAs
if tmp.Provider != "" {
e.Provider = tmp.Provider
} else {
Expand All @@ -404,28 +416,45 @@ func ModelToEmbeddingConfig(m Model) *EmbeddingConfig {
return nil
}
e := &EmbeddingConfig{Provider: m.GetType()}
// copyTLS copies the TLS pointer fields from a model's embedded BaseModel
// onto the EmbeddingConfig so the Python embedding client honours the same
// ModelConfig.spec.tls as the chat path.
copyTLS := func(b BaseModel) {
e.TLSInsecureSkipVerify = b.TLSInsecureSkipVerify
e.TLSCACertPath = b.TLSCACertPath
e.TLSDisableSystemCAs = b.TLSDisableSystemCAs
}
Comment on lines +422 to +426
switch v := m.(type) {
case *OpenAI:
e.Model = v.Model
e.BaseUrl = v.BaseUrl
copyTLS(v.BaseModel)
case *AzureOpenAI:
e.Model = v.Model
copyTLS(v.BaseModel)
case *Anthropic:
e.Model = v.Model
e.BaseUrl = v.BaseUrl
copyTLS(v.BaseModel)
case *GeminiVertexAI:
e.Model = v.Model
copyTLS(v.BaseModel)
case *GeminiAnthropic:
e.Model = v.Model
copyTLS(v.BaseModel)
case *Ollama:
e.Model = v.Model
copyTLS(v.BaseModel)
case *Gemini:
e.Model = v.Model
copyTLS(v.BaseModel)
case *Bedrock:
e.Model = v.Model
copyTLS(v.BaseModel)
case *SAPAICore:
e.Model = v.Model
e.BaseUrl = v.BaseUrl
copyTLS(v.BaseModel)
default:
e.Model = ""
}
Expand Down
83 changes: 83 additions & 0 deletions go/api/adk/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,89 @@ func TestEmbeddingConfig_UnmarshalJSON_ProviderOverridesType(t *testing.T) {
}
}

// TestModelToEmbeddingConfig_PropagatesTLS asserts that ModelToEmbeddingConfig
// copies the TLS fields from a model's embedded BaseModel onto the
// EmbeddingConfig, so the embedding HTTP client in the Python runtime honours
// the same ModelConfig.spec.tls as the chat/LLM path (upstream issue #1992).
func TestModelToEmbeddingConfig_PropagatesTLS(t *testing.T) {
base := BaseModel{
Model: "text-embedding-3-small",
TLSInsecureSkipVerify: new(true),
TLSCACertPath: new("/etc/ssl/certs/custom/corp-ca/ca.crt"),
TLSDisableSystemCAs: new(false),
}

tests := []struct {
name string
model Model
}{
{name: "OpenAI", model: &OpenAI{BaseModel: base, BaseUrl: "https://litellm.internal.corp:8080"}},
{name: "AzureOpenAI", model: &AzureOpenAI{BaseModel: base}},
{name: "Anthropic", model: &Anthropic{BaseModel: base}},
{name: "GeminiVertexAI", model: &GeminiVertexAI{BaseModel: base}},
{name: "GeminiAnthropic", model: &GeminiAnthropic{BaseModel: base}},
{name: "Ollama", model: &Ollama{BaseModel: base}},
{name: "Gemini", model: &Gemini{BaseModel: base}},
{name: "Bedrock", model: &Bedrock{BaseModel: base}},
{name: "SAPAICore", model: &SAPAICore{BaseModel: base}},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := ModelToEmbeddingConfig(tt.model)
if e == nil {
t.Fatal("ModelToEmbeddingConfig() returned nil")
}
if e.TLSInsecureSkipVerify == nil || !*e.TLSInsecureSkipVerify {
t.Errorf("TLSInsecureSkipVerify = %v, want pointer to true", e.TLSInsecureSkipVerify)
}
if e.TLSCACertPath == nil || *e.TLSCACertPath != "/etc/ssl/certs/custom/corp-ca/ca.crt" {
t.Errorf("TLSCACertPath = %v, want pointer to %q", e.TLSCACertPath, "/etc/ssl/certs/custom/corp-ca/ca.crt")
}
if e.TLSDisableSystemCAs == nil || *e.TLSDisableSystemCAs {
t.Errorf("TLSDisableSystemCAs = %v, want pointer to false", e.TLSDisableSystemCAs)
}
})
}
}

// TestModelToEmbeddingConfig_NoTLS asserts that a model without TLS config
// yields an EmbeddingConfig with nil TLS pointers, so the Python embedding
// client keeps its default (no custom httpx client) behaviour.
func TestModelToEmbeddingConfig_NoTLS(t *testing.T) {
e := ModelToEmbeddingConfig(&OpenAI{BaseModel: BaseModel{Model: "text-embedding-3-small"}})
if e == nil {
t.Fatal("ModelToEmbeddingConfig() returned nil")
}
if e.TLSInsecureSkipVerify != nil || e.TLSCACertPath != nil || e.TLSDisableSystemCAs != nil {
t.Errorf("expected nil TLS pointers, got %v %v %v", e.TLSInsecureSkipVerify, e.TLSCACertPath, e.TLSDisableSystemCAs)
}
}

func TestEmbeddingConfig_UnmarshalJSON_TLSFields(t *testing.T) {
data := []byte(`{
"provider":"openai",
"model":"text-embedding-3-small",
"base_url":"https://litellm.internal.corp:8080",
"tls_insecure_skip_verify":true,
"tls_ca_cert_path":"/etc/ssl/certs/custom/corp-ca/ca.crt",
"tls_disable_system_cas":false
}`)
var cfg EmbeddingConfig
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("UnmarshalJSON() error = %v", err)
}
if cfg.TLSInsecureSkipVerify == nil || !*cfg.TLSInsecureSkipVerify {
t.Errorf("TLSInsecureSkipVerify = %v, want pointer to true", cfg.TLSInsecureSkipVerify)
}
if cfg.TLSCACertPath == nil || *cfg.TLSCACertPath != "/etc/ssl/certs/custom/corp-ca/ca.crt" {
t.Errorf("TLSCACertPath = %v, want pointer to ca.crt path", cfg.TLSCACertPath)
}
if cfg.TLSDisableSystemCAs == nil || *cfg.TLSDisableSystemCAs {
t.Errorf("TLSDisableSystemCAs = %v, want pointer to false", cfg.TLSDisableSystemCAs)
}
}

func TestAgentConfig_ScanAndValue(t *testing.T) {
original := AgentConfig{
Model: &OpenAI{BaseModel: BaseModel{Model: "gpt-4o"}},
Expand Down
36 changes: 34 additions & 2 deletions python/packages/kagent-adk/src/kagent/adk/models/_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import os
from typing import Any, List, Union

import httpx
import numpy as np

from kagent.adk.models._ssl import create_ssl_context
from kagent.adk.types import EmbeddingConfig

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -142,9 +144,33 @@ def _normalize_l2(self, x: Union[List[float], np.ndarray]) -> np.ndarray:
norm = np.linalg.norm(x, 2, axis=1, keepdims=True)
return np.where(norm == 0, x, x / norm)

def _has_tls_config(self) -> bool:
"""Return True if any TLS field is set on the embedding config."""
return bool(
getattr(self.config, "tls_disable_verify", None)
or getattr(self.config, "tls_ca_cert_path", None)
or getattr(self.config, "tls_disable_system_cas", None)
)

def _tls_http_client(self) -> "httpx.AsyncClient | None":
"""Build a TLS-aware httpx client from the embedding config.

Returns None when no TLS field is set so the OpenAI SDK keeps its
default http client (preserving the pre-existing default behaviour).
"""
if not self._has_tls_config():
return None
verify = create_ssl_context(
disable_verify=bool(getattr(self.config, "tls_disable_verify", None)),
ca_cert_path=getattr(self.config, "tls_ca_cert_path", None),
disable_system_cas=bool(getattr(self.config, "tls_disable_system_cas", None)),
)
return httpx.AsyncClient(verify=verify)

async def _embed_openai(self, texts: List[str]) -> List[List[float]]:
"""Embed using the OpenAI or Azure OpenAI SDK."""
provider = self.config.provider.lower()
http_client = self._tls_http_client()

if provider == "azure_openai":
from openai import AsyncAzureOpenAI
Expand All @@ -153,11 +179,17 @@ async def _embed_openai(self, texts: List[str]) -> List[List[float]]:
api_base = self.config.base_url or os.environ.get("AZURE_OPENAI_ENDPOINT")
if not api_base:
raise ValueError("Azure OpenAI endpoint must be set via base_url or AZURE_OPENAI_ENDPOINT env var")
client = AsyncAzureOpenAI(api_version=api_version, azure_endpoint=api_base)
azure_kwargs: dict[str, Any] = {"api_version": api_version, "azure_endpoint": api_base}
if http_client is not None:
azure_kwargs["http_client"] = http_client
client = AsyncAzureOpenAI(**azure_kwargs)
else:
from openai import AsyncOpenAI

client = AsyncOpenAI(base_url=self.config.base_url or None)
openai_kwargs: dict[str, Any] = {"base_url": self.config.base_url or None}
if http_client is not None:
openai_kwargs["http_client"] = http_client
client = AsyncOpenAI(**openai_kwargs)
Comment on lines 170 to +192

response = await client.embeddings.create(
model=self.config.model,
Expand Down
10 changes: 10 additions & 0 deletions python/packages/kagent-adk/src/kagent/adk/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,16 @@ class EmbeddingConfig(BaseModel):
provider: str
base_url: str | None = None

# TLS/SSL configuration mirrors BaseLLM so the embedding client honours the
# same ModelConfig.spec.tls as the chat path. The Go controller serialises
# the disable-verify flag as "tls_insecure_skip_verify"; accept both names.
tls_disable_verify: bool | None = Field(
default=None,
validation_alias=AliasChoices("tls_disable_verify", "tls_insecure_skip_verify"),
)
tls_ca_cert_path: str | None = None
tls_disable_system_cas: bool | None = None


class MemoryConfig(BaseModel):
"""Memory configuration. Its presence signals that memory is enabled."""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Unit tests for embedding-client TLS configuration.

These tests verify that KAgentEmbedding._embed_openai honours the TLS fields
carried on EmbeddingConfig (upstream issue #1992): the OpenAI / Azure OpenAI
clients must receive an httpx client whose ``verify`` reflects the configured
ModelConfig.spec.tls, and the default (no-TLS) path must be unchanged.
"""

import ssl
from unittest import mock

import pytest

from kagent.adk.models._embedding import KAgentEmbedding
from kagent.adk.types import EmbeddingConfig


@pytest.mark.asyncio
async def test_embed_openai_disable_verify_builds_client_with_verify_false():
"""disable_verify=True → create_ssl_context returns False → httpx verify=False."""
config = EmbeddingConfig(
provider="openai",
model="text-embedding-3-small",
base_url="https://litellm.internal.corp:8080",
tls_insecure_skip_verify=True,
)
embedding = KAgentEmbedding(config)

with mock.patch("kagent.adk.models._embedding.create_ssl_context") as mock_create_ssl:
with mock.patch("kagent.adk.models._embedding.httpx.AsyncClient") as mock_httpx:
with mock.patch("openai.AsyncOpenAI") as mock_openai:
mock_create_ssl.return_value = False
mock_client = mock.MagicMock()
mock_httpx.return_value = mock_client
mock_openai.return_value.embeddings.create = mock.AsyncMock(
return_value=mock.MagicMock(data=[])
)

await embedding._embed_openai(["hello"])

# SSL context built from the embedding TLS config
mock_create_ssl.assert_called_once_with(
disable_verify=True,
ca_cert_path=None,
disable_system_cas=False,
)
# httpx client created with verify reflecting the config
mock_httpx.assert_called_once_with(verify=False)
# OpenAI client received the custom http_client
openai_kwargs = mock_openai.call_args[1]
assert openai_kwargs["http_client"] is mock_client


@pytest.mark.asyncio
async def test_embed_openai_custom_ca_builds_client_with_ssl_context():
"""A custom CA cert path is threaded into create_ssl_context and httpx verify."""
config = EmbeddingConfig(
provider="openai",
model="text-embedding-3-small",
tls_ca_cert_path="/etc/ssl/certs/custom/corp-ca/ca.crt",
tls_disable_system_cas=True,
)
embedding = KAgentEmbedding(config)

with mock.patch("kagent.adk.models._embedding.create_ssl_context") as mock_create_ssl:
with mock.patch("kagent.adk.models._embedding.httpx.AsyncClient") as mock_httpx:
with mock.patch("openai.AsyncOpenAI") as mock_openai:
ssl_context = mock.MagicMock(spec=ssl.SSLContext)
mock_create_ssl.return_value = ssl_context
mock_client = mock.MagicMock()
mock_httpx.return_value = mock_client
mock_openai.return_value.embeddings.create = mock.AsyncMock(
return_value=mock.MagicMock(data=[])
)

await embedding._embed_openai(["hello"])

mock_create_ssl.assert_called_once_with(
disable_verify=False,
ca_cert_path="/etc/ssl/certs/custom/corp-ca/ca.crt",
disable_system_cas=True,
)
mock_httpx.assert_called_once_with(verify=ssl_context)
openai_kwargs = mock_openai.call_args[1]
assert openai_kwargs["http_client"] is mock_client


@pytest.mark.asyncio
async def test_embed_openai_no_tls_keeps_default_client():
"""Without any TLS field, no custom http client is built (default path)."""
config = EmbeddingConfig(
provider="openai",
model="text-embedding-3-small",
)
embedding = KAgentEmbedding(config)

with mock.patch("kagent.adk.models._embedding.create_ssl_context") as mock_create_ssl:
with mock.patch("kagent.adk.models._embedding.httpx.AsyncClient") as mock_httpx:
with mock.patch("openai.AsyncOpenAI") as mock_openai:
mock_openai.return_value.embeddings.create = mock.AsyncMock(
return_value=mock.MagicMock(data=[])
)

await embedding._embed_openai(["hello"])

# No TLS config → no SSL context / httpx client built
mock_create_ssl.assert_not_called()
mock_httpx.assert_not_called()
# OpenAI client built without an explicit http_client
openai_kwargs = mock_openai.call_args[1]
assert "http_client" not in openai_kwargs


@pytest.mark.asyncio
async def test_embed_azure_openai_threads_tls_http_client():
"""Azure OpenAI client also receives the TLS-aware http client."""
config = EmbeddingConfig(
provider="azure_openai",
model="text-embedding-3-small",
base_url="https://my-azure.openai.azure.com",
tls_insecure_skip_verify=True,
)
embedding = KAgentEmbedding(config)

with mock.patch("kagent.adk.models._embedding.create_ssl_context") as mock_create_ssl:
with mock.patch("kagent.adk.models._embedding.httpx.AsyncClient") as mock_httpx:
with mock.patch("openai.AsyncAzureOpenAI") as mock_azure:
mock_create_ssl.return_value = False
mock_client = mock.MagicMock()
mock_httpx.return_value = mock_client
mock_azure.return_value.embeddings.create = mock.AsyncMock(
return_value=mock.MagicMock(data=[])
)

await embedding._embed_openai(["hello"])

mock_create_ssl.assert_called_once_with(
disable_verify=True,
ca_cert_path=None,
disable_system_cas=False,
)
azure_kwargs = mock_azure.call_args[1]
assert azure_kwargs["http_client"] is mock_client
Comment on lines +125 to +143
Loading