From 228eb27e6a3e6ddf673870c43aa3176fc613e61d Mon Sep 17 00:00:00 2001 From: rinfx Date: Wed, 8 Apr 2026 15:23:15 +0800 Subject: [PATCH] =?UTF-8?q?feat(ai-proxy):=20=E6=96=B0=E5=A2=9E=20provider?= =?UTF-8?q?BasePath=20=E9=85=8D=E7=BD=AE=E5=B9=B6=E4=BC=98=E5=8C=96=20prov?= =?UTF-8?q?iderDomain=20=E5=A4=84=E7=90=86=E6=96=B9=E5=BC=8F=20(#3686)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- plugins/wasm-go/extensions/ai-proxy/main.go | 3 +- .../extensions/ai-proxy/provider/claude.go | 3 +- .../extensions/ai-proxy/provider/failover.go | 9 +- .../extensions/ai-proxy/provider/gemini.go | 6 +- .../extensions/ai-proxy/provider/generic.go | 2 + .../extensions/ai-proxy/provider/provider.go | 46 +- .../ai-proxy/provider/provider_test.go | 401 +++++++++++++++++- 7 files changed, 428 insertions(+), 42 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index dc34346c9..bb987ee54 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -300,7 +300,8 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig log.Errorf("failed to replace request body by custom settings: %v", settingErr) } // 仅 /v1/chat/completions 和 /v1/completions 接口支持 stream_options 参数 - if providerConfig.IsOpenAIProtocol() && (apiName == provider.ApiNameChatCompletion || apiName == provider.ApiNameCompletion) { + // generic provider 不做能力映射,不添加 stream_options + if providerConfig.IsOpenAIProtocol() && !providerConfig.IsGeneric() && (apiName == provider.ApiNameChatCompletion || apiName == provider.ApiNameCompletion) { newBody = normalizeOpenAiRequestBody(newBody) } log.Debugf("[onHttpRequestBody] newBody=%s", newBody) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 5c759892a..4b763ce75 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -323,8 +323,7 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) { util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities) - domain := c.config.resolveDomain("", claudeDomain) - util.OverwriteRequestHostHeader(headers, domain) + util.OverwriteRequestHostHeader(headers, claudeDomain) if c.config.apiVersion == "" { c.config.apiVersion = claudeDefaultVersion diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 9c1925f61..4397c31e4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -10,11 +10,11 @@ import ( "time" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" - "github.com/higress-group/wasm-go/pkg/log" - "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/google/uuid" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" ) @@ -198,6 +198,11 @@ func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext, handler.TransformRequestHeaders(ctx, ApiNameChatCompletion, modifiedHeaders) } + // Apply providerBasePath if configured + if c.providerBasePath != "" { + modifiedHeaders.Set(":path", c.applyProviderBasePath(modifiedHeaders.Get(":path"))) + } + var err error if handler, ok := activeProvider.(TransformRequestBodyHandler); ok { body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 1d45328f9..70a873e40 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -62,12 +62,11 @@ func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string { func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { config.setDefaultCapabilities(g.DefaultCapabilities()) - domain := config.resolveDomain("", geminiDomain) return &geminiProvider{ config: config, contextCache: createContextCache(&config), client: wrapper.NewClusterClient(wrapper.RouteCluster{ - Host: domain, + Host: geminiDomain, }), }, nil } @@ -90,8 +89,7 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) { - domain := g.config.resolveDomain("", geminiDomain) - util.OverwriteRequestHostHeader(headers, domain) + util.OverwriteRequestHostHeader(headers, geminiDomain) headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "") } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/generic.go b/plugins/wasm-go/extensions/ai-proxy/provider/generic.go index ee35f65f7..494a5e9d9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/generic.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/generic.go @@ -52,6 +52,8 @@ func (m *genericProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN } func (m *genericProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { + // buffer original request body + _ = proxywasm.ReplaceHttpRequestBody(body) return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 72be85c72..e82007ca0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -487,6 +487,9 @@ type ProviderConfig struct { // @Title zh-CN HiClaw 模式 // @Description zh-CN 开启后同时启用 mergeConsecutiveMessages 和 promoteThinkingOnEmpty,适用于 HiClaw 多 Agent 协作场景。 hiclawMode bool `required:"false" yaml:"hiclawMode" json:"hiclawMode"` + // @Title zh-CN Provider 基础路径 + // @Description zh-CN 当配置了此值时,各个 Provider 在改写请求路径时会将其添加到路径前面,例如配置"/api/ai"后,请求路径"/v1/chat/completions"会被改写为"/api/ai/v1/chat/completions" + providerBasePath string `required:"false" yaml:"providerBasePath" json:"providerBasePath"` } func (c *ProviderConfig) GetId() string { @@ -501,20 +504,6 @@ func (c *ProviderConfig) GetProtocol() string { return c.protocol } -// resolveDomain resolves the domain to use based on priority: -// 1. providerDomain (generic override for all providers) -// 2. provider-specific domain config (e.g., geminiDomain, doubaoDomain) -// 3. default hardcoded domain -func (c *ProviderConfig) resolveDomain(providerSpecificDomain, defaultDomain string) string { - if c.providerDomain != "" { - return c.providerDomain - } - if providerSpecificDomain != "" { - return providerSpecificDomain - } - return defaultDomain -} - func (c *ProviderConfig) GetVllmCustomUrl() string { return c.vllmCustomUrl } @@ -733,6 +722,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.mergeConsecutiveMessages = true c.promoteThinkingOnEmpty = true } + c.providerBasePath = json.Get("providerBasePath").String() } func (c *ProviderConfig) Validate() error { @@ -867,6 +857,10 @@ func (c *ProviderConfig) IsOriginal() bool { return c.protocol == protocolOriginal } +func (c *ProviderConfig) IsGeneric() bool { + return c.typ == providerTypeGeneric +} + func (c *ProviderConfig) GetPromoteThinkingOnEmpty() bool { return c.promoteThinkingOnEmpty } @@ -883,6 +877,14 @@ func CreateProvider(pc ProviderConfig) (Provider, error) { return initializer.CreateProvider(pc) } +// applyProviderBasePath prepends the ProviderBasePath to the given path if configured. +func (c *ProviderConfig) applyProviderBasePath(path string) string { + if c.providerBasePath != "" && !strings.HasPrefix(path, c.providerBasePath) { + return c.providerBasePath + path + } + return path +} + func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte) error { switch req := request.(type) { case *chatCompletionRequest: @@ -1220,6 +1222,10 @@ func (c *ProviderConfig) handleRequestBody( } else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok { headers := util.GetRequestHeaders() body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers) + // Apply providerBasePath if configured + if c.providerBasePath != "" { + headers.Set(":path", c.applyProviderBasePath(headers.Get(":path"))) + } util.ReplaceRequestHeaders(headers) } else { body, err = c.defaultTransformRequestBody(ctx, apiName, body) @@ -1276,6 +1282,18 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt if c.basePath != "" && c.basePathHandling == basePathHandlingPrepend && !strings.HasPrefix(headers.Get(":path"), c.basePath) { headers.Set(":path", path.Join(c.basePath, headers.Get(":path"))) } + + // Apply providerBasePath if configured + currentPath := headers.Get(":path") + if c.providerBasePath != "" { + headers.Set(":path", c.applyProviderBasePath(currentPath)) + } + + // Apply providerDomain if configured (overrides any domain set by the provider) + if c.providerDomain != "" { + util.OverwriteRequestHostHeader(headers, c.providerDomain) + } + util.ReplaceRequestHeaders(headers) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go index 2c591716a..d2359a53e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go @@ -1,6 +1,7 @@ package provider import ( + "strings" "testing" "github.com/stretchr/testify/assert" @@ -290,28 +291,390 @@ func TestProviderDomain_Config(t *testing.T) { }) } -func TestResolveDomain_Priority(t *testing.T) { - t.Run("providerDomain_takes_priority", func(t *testing.T) { - config := ProviderConfig{ - providerDomain: "universal-proxy.com", - } - result := config.resolveDomain("specific-domain.com", "default.com") - assert.Equal(t, "universal-proxy.com", result) +func TestProviderBasePath_Config(t *testing.T) { + t.Run("providerBasePath_field_exists", func(t *testing.T) { + config := ProviderConfig{} + config.FromJson(gjson.Result{}) + assert.Equal(t, "", config.providerBasePath) }) - t.Run("providerSpecificDomain_when_providerDomain_empty", func(t *testing.T) { - config := ProviderConfig{ - providerDomain: "", - } - result := config.resolveDomain("specific-domain.com", "default.com") - assert.Equal(t, "specific-domain.com", result) + t.Run("providerBasePath_parsed_from_json", func(t *testing.T) { + config := ProviderConfig{} + jsonStr := `{"providerBasePath": "/api/ai"}` + config.FromJson(gjson.Parse(jsonStr)) + assert.Equal(t, "/api/ai", config.providerBasePath) }) - t.Run("defaultDomain_when_both_empty", func(t *testing.T) { - config := ProviderConfig{ - providerDomain: "", - } - result := config.resolveDomain("", "default.com") - assert.Equal(t, "default.com", result) + t.Run("providerBasePath_with_other_config", func(t *testing.T) { + config := ProviderConfig{} + jsonStr := `{ + "type": "openai", + "apiToken": "sk-test", + "providerBasePath": "/api/v1", + "providerDomain": "proxy.example.com" + }` + config.FromJson(gjson.Parse(jsonStr)) + assert.Equal(t, "openai", config.typ) + assert.Equal(t, "/api/v1", config.providerBasePath) + assert.Equal(t, "proxy.example.com", config.providerDomain) + }) +} + +func TestApplyProviderBasePath(t *testing.T) { + tests := []struct { + name string + providerBasePath string + originalPath string + expectedPath string + }{ + { + name: "no_base_path_configured", + providerBasePath: "", + originalPath: "/v1/chat/completions", + expectedPath: "/v1/chat/completions", + }, + { + name: "base_path_prepended", + providerBasePath: "/api/ai", + originalPath: "/v1/chat/completions", + expectedPath: "/api/ai/v1/chat/completions", + }, + { + name: "path_already_has_base_path", + providerBasePath: "/api/ai", + originalPath: "/api/ai/v1/chat/completions", + expectedPath: "/api/ai/v1/chat/completions", + }, + { + name: "base_path_with_trailing_slash", + providerBasePath: "/api/ai/", + originalPath: "/v1/chat/completions", + expectedPath: "/api/ai//v1/chat/completions", + }, + { + name: "deep_base_path", + providerBasePath: "/internal/services/ai", + originalPath: "/v1/models", + expectedPath: "/internal/services/ai/v1/models", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := &ProviderConfig{ + providerBasePath: tt.providerBasePath, + } + result := config.applyProviderBasePath(tt.originalPath) + assert.Equal(t, tt.expectedPath, result) + }) + } +} + +func TestHandleRequestHeaders_PathHandling(t *testing.T) { + // This test verifies the path handling logic in handleRequestHeaders + // including basePathHandling and providerBasePath + + t.Run("basePath_removePrefix_only", func(t *testing.T) { + config := &ProviderConfig{ + basePath: "/gateway", + basePathHandling: basePathHandlingRemovePrefix, + } + // Simulate the logic - actual test would need mock provider + originPath := "/gateway/v1/chat" + expectedPath := "/v1/chat" + result := strings.TrimPrefix(originPath, config.basePath) + assert.Equal(t, expectedPath, result) + }) + + t.Run("basePath_prepend_only", func(t *testing.T) { + config := &ProviderConfig{ + basePath: "/api", + basePathHandling: basePathHandlingPrepend, + } + currentPath := "/v1/chat" + // basePath preprend + providerBasePath (not set) = just basePath effect + // Note: applyProviderBasePath only handles providerBasePath, not basePath + // So this test just verifies that applyProviderBasePath doesn't modify path when providerBasePath is empty + expectedPath := "/v1/chat" // applyProviderBasePath doesn't change path without providerBasePath configured + result := config.applyProviderBasePath(currentPath) + assert.Equal(t, expectedPath, result) + }) + + t.Run("providerBasePath_only", func(t *testing.T) { + config := &ProviderConfig{ + providerBasePath: "/ai-proxy", + } + currentPath := "/v1/chat" + expectedPath := "/ai-proxy/v1/chat" + result := config.applyProviderBasePath(currentPath) + assert.Equal(t, expectedPath, result) + }) + + t.Run("both_basePath_and_providerBasePath", func(t *testing.T) { + config := &ProviderConfig{ + basePath: "/gateway", + basePathHandling: basePathHandlingRemovePrefix, + providerBasePath: "/ai", + } + // First removePrefix, then apply providerBasePath + originPath := "/gateway/v1/chat" + afterRemovePrefix := strings.TrimPrefix(originPath, config.basePath) + finalPath := config.applyProviderBasePath(afterRemovePrefix) + assert.Equal(t, "/ai/v1/chat", finalPath) + }) +} + +func TestProviderConfig_IsOriginal(t *testing.T) { + tests := []struct { + name string + protocol string + expected bool + }{ + {"openai_protocol", protocolOpenAI, false}, + {"original_protocol", protocolOriginal, true}, + {"empty_protocol", "", false}, + {"unknown_protocol", "unknown", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := &ProviderConfig{ + protocol: tt.protocol, + } + result := config.IsOriginal() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestProviderConfig_GetPromoteThinkingOnEmpty(t *testing.T) { + tests := []struct { + name string + promoteThinkingOnEmpty bool + expected bool + }{ + {"enabled", true, true}, + {"disabled", false, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := &ProviderConfig{ + promoteThinkingOnEmpty: tt.promoteThinkingOnEmpty, + } + result := config.GetPromoteThinkingOnEmpty() + assert.Equal(t, tt.expected, result) + }) + } +} + +// ============ Failover Tests ============ + +func TestFailover_FromJson_Defaults(t *testing.T) { + t.Run("default_failure_threshold", func(t *testing.T) { + f := &failover{} + jsonStr := `{"enabled": true}` + f.FromJson(gjson.Parse(jsonStr)) + assert.Equal(t, int64(3), f.failureThreshold) + }) + + t.Run("default_success_threshold", func(t *testing.T) { + f := &failover{} + jsonStr := `{"enabled": true}` + f.FromJson(gjson.Parse(jsonStr)) + assert.Equal(t, int64(1), f.successThreshold) + }) + + t.Run("default_health_check_interval", func(t *testing.T) { + f := &failover{} + jsonStr := `{"enabled": true}` + f.FromJson(gjson.Parse(jsonStr)) + assert.Equal(t, int64(5000), f.healthCheckInterval) + }) + + t.Run("default_health_check_timeout", func(t *testing.T) { + f := &failover{} + jsonStr := `{"enabled": true}` + f.FromJson(gjson.Parse(jsonStr)) + assert.Equal(t, int64(5000), f.healthCheckTimeout) + }) + + t.Run("custom_values", func(t *testing.T) { + f := &failover{} + jsonStr := `{ + "enabled": true, + "failureThreshold": 5, + "successThreshold": 3, + "healthCheckInterval": 10000, + "healthCheckTimeout": 8000, + "healthCheckModel": "test-model" + }` + f.FromJson(gjson.Parse(jsonStr)) + assert.Equal(t, true, f.enabled) + assert.Equal(t, int64(5), f.failureThreshold) + assert.Equal(t, int64(3), f.successThreshold) + assert.Equal(t, int64(10000), f.healthCheckInterval) + assert.Equal(t, int64(8000), f.healthCheckTimeout) + assert.Equal(t, "test-model", f.healthCheckModel) + }) +} + +func TestFailover_FromJson_FailoverOnStatus(t *testing.T) { + t.Run("parse_failoverOnStatus_array", func(t *testing.T) { + f := &failover{} + jsonStr := `{ + "enabled": true, + "failoverOnStatus": ["401", "403", "5[0-9][0-9]"] + }` + f.FromJson(gjson.Parse(jsonStr)) + assert.Equal(t, 3, len(f.failoverOnStatus)) + assert.Contains(t, f.failoverOnStatus, "401") + assert.Contains(t, f.failoverOnStatus, "403") + assert.Contains(t, f.failoverOnStatus, "5[0-9][0-9]") + }) + + t.Run("empty_failoverOnStatus", func(t *testing.T) { + f := &failover{} + jsonStr := `{"enabled": true}` + f.FromJson(gjson.Parse(jsonStr)) + // When failoverOnStatus is not specified, it keeps default values + // Default regex patterns may be set elsewhere + assert.True(t, f.enabled) + assert.Equal(t, int64(3), f.failureThreshold) + }) +} + +func TestHealthCheckEndpoint_Struct(t *testing.T) { + t.Run("health_check_endpoint_fields", func(t *testing.T) { + endpoint := HealthCheckEndpoint{ + Host: "api.example.com", + Path: "/v1/chat/completions", + Cluster: "ai-provider-cluster", + } + assert.Equal(t, "api.example.com", endpoint.Host) + assert.Equal(t, "/v1/chat/completions", endpoint.Path) + assert.Equal(t, "ai-provider-cluster", endpoint.Cluster) + }) +} + +func TestLease_Struct(t *testing.T) { + t.Run("lease_fields", func(t *testing.T) { + lease := Lease{ + VMID: "vm-12345", + Timestamp: 1234567890, + } + assert.Equal(t, "vm-12345", lease.VMID) + assert.Equal(t, int64(1234567890), lease.Timestamp) + }) +} + +func TestFailover_Constants(t *testing.T) { + t.Run("cas_max_retries_value", func(t *testing.T) { + assert.Equal(t, 10, casMaxRetries) + }) + + t.Run("operation_constants", func(t *testing.T) { + assert.Equal(t, "addApiToken", addApiTokenOperation) + assert.Equal(t, "removeApiToken", removeApiTokenOperation) + assert.Equal(t, "addApiTokenRequestCount", addApiTokenRequestCountOperation) + assert.Equal(t, "resetApiTokenRequestCount", resetApiTokenRequestCountOperation) + }) + + t.Run("context_key_constants", func(t *testing.T) { + assert.Equal(t, "requestHost", CtxRequestHost) + assert.Equal(t, "requestPath", CtxRequestPath) + assert.Equal(t, "requestBody", CtxRequestBody) + }) +} + +func TestProviderConfig_TransformRequestHeadersAndBody_PathHandling(t *testing.T) { + // Test that providerBasePath is applied in transformRequestHeadersAndBody + t.Run("providerBasePath_applied", func(t *testing.T) { + config := &ProviderConfig{ + providerBasePath: "/api/ai", + } + + // Test the applyProviderBasePath logic used in transformRequestHeadersAndBody + testPath := "/v1/chat/completions" + expectedPath := "/api/ai/v1/chat/completions" + result := config.applyProviderBasePath(testPath) + assert.Equal(t, expectedPath, result) + }) + + t.Run("providerBasePath_already_present", func(t *testing.T) { + config := &ProviderConfig{ + providerBasePath: "/api/ai", + } + + testPath := "/api/ai/v1/chat/completions" + result := config.applyProviderBasePath(testPath) + // Should not duplicate the prefix + assert.Equal(t, "/api/ai/v1/chat/completions", result) + }) +} + +func TestProviderConfig_IsSupportedAPI(t *testing.T) { + t.Run("supported_api", func(t *testing.T) { + config := &ProviderConfig{ + capabilities: map[string]string{ + string(ApiNameChatCompletion): "/v1/chat/completions", + string(ApiNameEmbeddings): "/v1/embeddings", + }, + } + + result := config.IsSupportedAPI(ApiNameChatCompletion) + assert.True(t, result) + }) + + t.Run("unsupported_api", func(t *testing.T) { + config := &ProviderConfig{ + capabilities: map[string]string{ + string(ApiNameChatCompletion): "/v1/chat/completions", + }, + } + + result := config.IsSupportedAPI(ApiNameEmbeddings) + assert.False(t, result) + }) + + t.Run("empty_capabilities", func(t *testing.T) { + config := &ProviderConfig{ + capabilities: map[string]string{}, + } + + result := config.IsSupportedAPI(ApiNameChatCompletion) + assert.False(t, result) + }) +} + +func TestProviderConfig_SetDefaultCapabilities(t *testing.T) { + t.Run("set_when_nil", func(t *testing.T) { + config := &ProviderConfig{ + capabilities: nil, + } + + defaultCaps := map[string]string{ + string(ApiNameChatCompletion): "/v1/chat/completions", + } + config.setDefaultCapabilities(defaultCaps) + + assert.NotNil(t, config.capabilities) + assert.Equal(t, "/v1/chat/completions", config.capabilities[string(ApiNameChatCompletion)]) + }) + + t.Run("merge_with_existing", func(t *testing.T) { + config := &ProviderConfig{ + capabilities: map[string]string{ + string(ApiNameEmbeddings): "/v1/embeddings", + }, + } + + defaultCaps := map[string]string{ + string(ApiNameChatCompletion): "/v1/chat/completions", + } + config.setDefaultCapabilities(defaultCaps) + + assert.Equal(t, "/v1/embeddings", config.capabilities[string(ApiNameEmbeddings)]) + assert.Equal(t, "/v1/chat/completions", config.capabilities[string(ApiNameChatCompletion)]) }) }