diff --git a/go/api/adk/types.go b/go/api/adk/types.go index e3862000a7..0885f11016 100644 --- a/go/api/adk/types.go +++ b/go/api/adk/types.go @@ -124,6 +124,9 @@ func (o *OpenAI) GetType() string { type AzureOpenAI struct { BaseModel + MaxTokens *int `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` } func (a *AzureOpenAI) GetType() string { diff --git a/go/core/internal/controller/translator/agent/adk_api_translator.go b/go/core/internal/controller/translator/agent/adk_api_translator.go index 74490cd673..12b232048a 100644 --- a/go/core/internal/controller/translator/agent/adk_api_translator.go +++ b/go/core/internal/controller/translator/agent/adk_api_translator.go @@ -586,6 +586,9 @@ func (a *adkApiTranslator) translateModel(ctx context.Context, namespace, modelC Model: model.Spec.AzureOpenAI.DeploymentName, Headers: model.Spec.DefaultHeaders, }, + Temperature: utils.ParseStringToFloat64(model.Spec.AzureOpenAI.Temperature), + TopP: utils.ParseStringToFloat64(model.Spec.AzureOpenAI.TopP), + MaxTokens: model.Spec.AzureOpenAI.MaxTokens, } // Populate TLS fields in BaseModel populateTLSFields(&azureOpenAI.BaseModel, model.Spec.TLS) diff --git a/go/core/internal/controller/translator/agent/adk_api_translator_test.go b/go/core/internal/controller/translator/agent/adk_api_translator_test.go index f48337277e..506ebb1cbf 100644 --- a/go/core/internal/controller/translator/agent/adk_api_translator_test.go +++ b/go/core/internal/controller/translator/agent/adk_api_translator_test.go @@ -430,6 +430,49 @@ func Test_AdkApiTranslator_OllamaOptions(t *testing.T) { assert.Equal(t, "0.7", ollamaModel.Options["temperature"]) } +func Test_AdkApiTranslator_AzureOpenAIParams(t *testing.T) { + scheme := schemev1.Scheme + require.NoError(t, v1alpha2.AddToScheme(scheme)) + + maxTokens := 2048 + modelConfig := &v1alpha2.ModelConfig{ + ObjectMeta: metav1.ObjectMeta{Name: "m", Namespace: "ns"}, + Spec: v1alpha2.ModelConfigSpec{ + Model: "gpt-4o", + Provider: v1alpha2.ModelProviderAzureOpenAI, + APIKeyPassthrough: true, + AzureOpenAI: &v1alpha2.AzureOpenAIConfig{ + Temperature: "0.5", + TopP: "0.9", + MaxTokens: &maxTokens, + }, + }, + } + agent := &v1alpha2.Agent{ + ObjectMeta: metav1.ObjectMeta{Name: "a", Namespace: "ns"}, + Spec: v1alpha2.AgentSpec{ + Type: v1alpha2.AgentType_Declarative, + Declarative: &v1alpha2.DeclarativeAgentSpec{ + SystemMessage: "x", + ModelConfig: "m", + }, + }, + } + + ns := &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: "ns"}} + kubeClient := fake.NewClientBuilder().WithScheme(scheme).WithObjects(ns, modelConfig, agent).Build() + trans := translator.NewAdkApiTranslator(kubeClient, types.NamespacedName{Namespace: "ns", Name: "m"}, nil, "", nil) + + outputs, err := translator.TranslateAgent(context.Background(), trans, agent) + require.NoError(t, err) + + m, ok := outputs.Config.Model.(*adk.AzureOpenAI) + require.True(t, ok) + assert.Equal(t, new(0.5), m.Temperature) + assert.Equal(t, new(0.9), m.TopP) + assert.Equal(t, &maxTokens, m.MaxTokens) +} + func Test_AdkApiTranslator_ServiceAccountNameOverride(t *testing.T) { scheme := schemev1.Scheme require.NoError(t, v1alpha2.AddToScheme(scheme))