mirror of
https://github.com/alibaba/higress.git
synced 2026-05-09 05:17:27 +08:00
feat(ai-proxy): 新增 providerBasePath 配置并优化 providerDomain 处理方式 (#3686)
This commit is contained in:
@@ -300,7 +300,8 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
||||
log.Errorf("failed to replace request body by custom settings: %v", settingErr)
|
||||
}
|
||||
// 仅 /v1/chat/completions 和 /v1/completions 接口支持 stream_options 参数
|
||||
if providerConfig.IsOpenAIProtocol() && (apiName == provider.ApiNameChatCompletion || apiName == provider.ApiNameCompletion) {
|
||||
// generic provider 不做能力映射,不添加 stream_options
|
||||
if providerConfig.IsOpenAIProtocol() && !providerConfig.IsGeneric() && (apiName == provider.ApiNameChatCompletion || apiName == provider.ApiNameCompletion) {
|
||||
newBody = normalizeOpenAiRequestBody(newBody)
|
||||
}
|
||||
log.Debugf("[onHttpRequestBody] newBody=%s", newBody)
|
||||
|
||||
@@ -323,8 +323,7 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
|
||||
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities)
|
||||
domain := c.config.resolveDomain("", claudeDomain)
|
||||
util.OverwriteRequestHostHeader(headers, domain)
|
||||
util.OverwriteRequestHostHeader(headers, claudeDomain)
|
||||
|
||||
if c.config.apiVersion == "" {
|
||||
c.config.apiVersion = claudeDefaultVersion
|
||||
|
||||
@@ -10,11 +10,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/google/uuid"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
@@ -198,6 +198,11 @@ func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext,
|
||||
handler.TransformRequestHeaders(ctx, ApiNameChatCompletion, modifiedHeaders)
|
||||
}
|
||||
|
||||
// Apply providerBasePath if configured
|
||||
if c.providerBasePath != "" {
|
||||
modifiedHeaders.Set(":path", c.applyProviderBasePath(modifiedHeaders.Get(":path")))
|
||||
}
|
||||
|
||||
var err error
|
||||
if handler, ok := activeProvider.(TransformRequestBodyHandler); ok {
|
||||
body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body)
|
||||
|
||||
@@ -62,12 +62,11 @@ func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
|
||||
func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(g.DefaultCapabilities())
|
||||
domain := config.resolveDomain("", geminiDomain)
|
||||
return &geminiProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
client: wrapper.NewClusterClient(wrapper.RouteCluster{
|
||||
Host: domain,
|
||||
Host: geminiDomain,
|
||||
}),
|
||||
}, nil
|
||||
}
|
||||
@@ -90,8 +89,7 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
}
|
||||
|
||||
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
domain := g.config.resolveDomain("", geminiDomain)
|
||||
util.OverwriteRequestHostHeader(headers, domain)
|
||||
util.OverwriteRequestHostHeader(headers, geminiDomain)
|
||||
headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "")
|
||||
}
|
||||
|
||||
@@ -52,6 +52,8 @@ func (m *genericProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
|
||||
}
|
||||
|
||||
func (m *genericProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
// buffer original request body
|
||||
_ = proxywasm.ReplaceHttpRequestBody(body)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -487,6 +487,9 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN HiClaw 模式
|
||||
// @Description zh-CN 开启后同时启用 mergeConsecutiveMessages 和 promoteThinkingOnEmpty,适用于 HiClaw 多 Agent 协作场景。
|
||||
hiclawMode bool `required:"false" yaml:"hiclawMode" json:"hiclawMode"`
|
||||
// @Title zh-CN Provider 基础路径
|
||||
// @Description zh-CN 当配置了此值时,各个 Provider 在改写请求路径时会将其添加到路径前面,例如配置"/api/ai"后,请求路径"/v1/chat/completions"会被改写为"/api/ai/v1/chat/completions"
|
||||
providerBasePath string `required:"false" yaml:"providerBasePath" json:"providerBasePath"`
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetId() string {
|
||||
@@ -501,20 +504,6 @@ func (c *ProviderConfig) GetProtocol() string {
|
||||
return c.protocol
|
||||
}
|
||||
|
||||
// resolveDomain resolves the domain to use based on priority:
|
||||
// 1. providerDomain (generic override for all providers)
|
||||
// 2. provider-specific domain config (e.g., geminiDomain, doubaoDomain)
|
||||
// 3. default hardcoded domain
|
||||
func (c *ProviderConfig) resolveDomain(providerSpecificDomain, defaultDomain string) string {
|
||||
if c.providerDomain != "" {
|
||||
return c.providerDomain
|
||||
}
|
||||
if providerSpecificDomain != "" {
|
||||
return providerSpecificDomain
|
||||
}
|
||||
return defaultDomain
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetVllmCustomUrl() string {
|
||||
return c.vllmCustomUrl
|
||||
}
|
||||
@@ -733,6 +722,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.mergeConsecutiveMessages = true
|
||||
c.promoteThinkingOnEmpty = true
|
||||
}
|
||||
c.providerBasePath = json.Get("providerBasePath").String()
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) Validate() error {
|
||||
@@ -867,6 +857,10 @@ func (c *ProviderConfig) IsOriginal() bool {
|
||||
return c.protocol == protocolOriginal
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) IsGeneric() bool {
|
||||
return c.typ == providerTypeGeneric
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetPromoteThinkingOnEmpty() bool {
|
||||
return c.promoteThinkingOnEmpty
|
||||
}
|
||||
@@ -883,6 +877,14 @@ func CreateProvider(pc ProviderConfig) (Provider, error) {
|
||||
return initializer.CreateProvider(pc)
|
||||
}
|
||||
|
||||
// applyProviderBasePath prepends the ProviderBasePath to the given path if configured.
|
||||
func (c *ProviderConfig) applyProviderBasePath(path string) string {
|
||||
if c.providerBasePath != "" && !strings.HasPrefix(path, c.providerBasePath) {
|
||||
return c.providerBasePath + path
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte) error {
|
||||
switch req := request.(type) {
|
||||
case *chatCompletionRequest:
|
||||
@@ -1220,6 +1222,10 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
|
||||
headers := util.GetRequestHeaders()
|
||||
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
||||
// Apply providerBasePath if configured
|
||||
if c.providerBasePath != "" {
|
||||
headers.Set(":path", c.applyProviderBasePath(headers.Get(":path")))
|
||||
}
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
} else {
|
||||
body, err = c.defaultTransformRequestBody(ctx, apiName, body)
|
||||
@@ -1276,6 +1282,18 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt
|
||||
if c.basePath != "" && c.basePathHandling == basePathHandlingPrepend && !strings.HasPrefix(headers.Get(":path"), c.basePath) {
|
||||
headers.Set(":path", path.Join(c.basePath, headers.Get(":path")))
|
||||
}
|
||||
|
||||
// Apply providerBasePath if configured
|
||||
currentPath := headers.Get(":path")
|
||||
if c.providerBasePath != "" {
|
||||
headers.Set(":path", c.applyProviderBasePath(currentPath))
|
||||
}
|
||||
|
||||
// Apply providerDomain if configured (overrides any domain set by the provider)
|
||||
if c.providerDomain != "" {
|
||||
util.OverwriteRequestHostHeader(headers, c.providerDomain)
|
||||
}
|
||||
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -290,28 +291,390 @@ func TestProviderDomain_Config(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveDomain_Priority(t *testing.T) {
|
||||
t.Run("providerDomain_takes_priority", func(t *testing.T) {
|
||||
config := ProviderConfig{
|
||||
providerDomain: "universal-proxy.com",
|
||||
}
|
||||
result := config.resolveDomain("specific-domain.com", "default.com")
|
||||
assert.Equal(t, "universal-proxy.com", result)
|
||||
func TestProviderBasePath_Config(t *testing.T) {
|
||||
t.Run("providerBasePath_field_exists", func(t *testing.T) {
|
||||
config := ProviderConfig{}
|
||||
config.FromJson(gjson.Result{})
|
||||
assert.Equal(t, "", config.providerBasePath)
|
||||
})
|
||||
|
||||
t.Run("providerSpecificDomain_when_providerDomain_empty", func(t *testing.T) {
|
||||
config := ProviderConfig{
|
||||
providerDomain: "",
|
||||
}
|
||||
result := config.resolveDomain("specific-domain.com", "default.com")
|
||||
assert.Equal(t, "specific-domain.com", result)
|
||||
t.Run("providerBasePath_parsed_from_json", func(t *testing.T) {
|
||||
config := ProviderConfig{}
|
||||
jsonStr := `{"providerBasePath": "/api/ai"}`
|
||||
config.FromJson(gjson.Parse(jsonStr))
|
||||
assert.Equal(t, "/api/ai", config.providerBasePath)
|
||||
})
|
||||
|
||||
t.Run("defaultDomain_when_both_empty", func(t *testing.T) {
|
||||
config := ProviderConfig{
|
||||
providerDomain: "",
|
||||
}
|
||||
result := config.resolveDomain("", "default.com")
|
||||
assert.Equal(t, "default.com", result)
|
||||
t.Run("providerBasePath_with_other_config", func(t *testing.T) {
|
||||
config := ProviderConfig{}
|
||||
jsonStr := `{
|
||||
"type": "openai",
|
||||
"apiToken": "sk-test",
|
||||
"providerBasePath": "/api/v1",
|
||||
"providerDomain": "proxy.example.com"
|
||||
}`
|
||||
config.FromJson(gjson.Parse(jsonStr))
|
||||
assert.Equal(t, "openai", config.typ)
|
||||
assert.Equal(t, "/api/v1", config.providerBasePath)
|
||||
assert.Equal(t, "proxy.example.com", config.providerDomain)
|
||||
})
|
||||
}
|
||||
|
||||
func TestApplyProviderBasePath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
providerBasePath string
|
||||
originalPath string
|
||||
expectedPath string
|
||||
}{
|
||||
{
|
||||
name: "no_base_path_configured",
|
||||
providerBasePath: "",
|
||||
originalPath: "/v1/chat/completions",
|
||||
expectedPath: "/v1/chat/completions",
|
||||
},
|
||||
{
|
||||
name: "base_path_prepended",
|
||||
providerBasePath: "/api/ai",
|
||||
originalPath: "/v1/chat/completions",
|
||||
expectedPath: "/api/ai/v1/chat/completions",
|
||||
},
|
||||
{
|
||||
name: "path_already_has_base_path",
|
||||
providerBasePath: "/api/ai",
|
||||
originalPath: "/api/ai/v1/chat/completions",
|
||||
expectedPath: "/api/ai/v1/chat/completions",
|
||||
},
|
||||
{
|
||||
name: "base_path_with_trailing_slash",
|
||||
providerBasePath: "/api/ai/",
|
||||
originalPath: "/v1/chat/completions",
|
||||
expectedPath: "/api/ai//v1/chat/completions",
|
||||
},
|
||||
{
|
||||
name: "deep_base_path",
|
||||
providerBasePath: "/internal/services/ai",
|
||||
originalPath: "/v1/models",
|
||||
expectedPath: "/internal/services/ai/v1/models",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
providerBasePath: tt.providerBasePath,
|
||||
}
|
||||
result := config.applyProviderBasePath(tt.originalPath)
|
||||
assert.Equal(t, tt.expectedPath, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleRequestHeaders_PathHandling(t *testing.T) {
|
||||
// This test verifies the path handling logic in handleRequestHeaders
|
||||
// including basePathHandling and providerBasePath
|
||||
|
||||
t.Run("basePath_removePrefix_only", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
basePath: "/gateway",
|
||||
basePathHandling: basePathHandlingRemovePrefix,
|
||||
}
|
||||
// Simulate the logic - actual test would need mock provider
|
||||
originPath := "/gateway/v1/chat"
|
||||
expectedPath := "/v1/chat"
|
||||
result := strings.TrimPrefix(originPath, config.basePath)
|
||||
assert.Equal(t, expectedPath, result)
|
||||
})
|
||||
|
||||
t.Run("basePath_prepend_only", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
basePath: "/api",
|
||||
basePathHandling: basePathHandlingPrepend,
|
||||
}
|
||||
currentPath := "/v1/chat"
|
||||
// basePath preprend + providerBasePath (not set) = just basePath effect
|
||||
// Note: applyProviderBasePath only handles providerBasePath, not basePath
|
||||
// So this test just verifies that applyProviderBasePath doesn't modify path when providerBasePath is empty
|
||||
expectedPath := "/v1/chat" // applyProviderBasePath doesn't change path without providerBasePath configured
|
||||
result := config.applyProviderBasePath(currentPath)
|
||||
assert.Equal(t, expectedPath, result)
|
||||
})
|
||||
|
||||
t.Run("providerBasePath_only", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
providerBasePath: "/ai-proxy",
|
||||
}
|
||||
currentPath := "/v1/chat"
|
||||
expectedPath := "/ai-proxy/v1/chat"
|
||||
result := config.applyProviderBasePath(currentPath)
|
||||
assert.Equal(t, expectedPath, result)
|
||||
})
|
||||
|
||||
t.Run("both_basePath_and_providerBasePath", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
basePath: "/gateway",
|
||||
basePathHandling: basePathHandlingRemovePrefix,
|
||||
providerBasePath: "/ai",
|
||||
}
|
||||
// First removePrefix, then apply providerBasePath
|
||||
originPath := "/gateway/v1/chat"
|
||||
afterRemovePrefix := strings.TrimPrefix(originPath, config.basePath)
|
||||
finalPath := config.applyProviderBasePath(afterRemovePrefix)
|
||||
assert.Equal(t, "/ai/v1/chat", finalPath)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProviderConfig_IsOriginal(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
protocol string
|
||||
expected bool
|
||||
}{
|
||||
{"openai_protocol", protocolOpenAI, false},
|
||||
{"original_protocol", protocolOriginal, true},
|
||||
{"empty_protocol", "", false},
|
||||
{"unknown_protocol", "unknown", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
protocol: tt.protocol,
|
||||
}
|
||||
result := config.IsOriginal()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderConfig_GetPromoteThinkingOnEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
promoteThinkingOnEmpty bool
|
||||
expected bool
|
||||
}{
|
||||
{"enabled", true, true},
|
||||
{"disabled", false, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
promoteThinkingOnEmpty: tt.promoteThinkingOnEmpty,
|
||||
}
|
||||
result := config.GetPromoteThinkingOnEmpty()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Failover Tests ============
|
||||
|
||||
func TestFailover_FromJson_Defaults(t *testing.T) {
|
||||
t.Run("default_failure_threshold", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
jsonStr := `{"enabled": true}`
|
||||
f.FromJson(gjson.Parse(jsonStr))
|
||||
assert.Equal(t, int64(3), f.failureThreshold)
|
||||
})
|
||||
|
||||
t.Run("default_success_threshold", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
jsonStr := `{"enabled": true}`
|
||||
f.FromJson(gjson.Parse(jsonStr))
|
||||
assert.Equal(t, int64(1), f.successThreshold)
|
||||
})
|
||||
|
||||
t.Run("default_health_check_interval", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
jsonStr := `{"enabled": true}`
|
||||
f.FromJson(gjson.Parse(jsonStr))
|
||||
assert.Equal(t, int64(5000), f.healthCheckInterval)
|
||||
})
|
||||
|
||||
t.Run("default_health_check_timeout", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
jsonStr := `{"enabled": true}`
|
||||
f.FromJson(gjson.Parse(jsonStr))
|
||||
assert.Equal(t, int64(5000), f.healthCheckTimeout)
|
||||
})
|
||||
|
||||
t.Run("custom_values", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
jsonStr := `{
|
||||
"enabled": true,
|
||||
"failureThreshold": 5,
|
||||
"successThreshold": 3,
|
||||
"healthCheckInterval": 10000,
|
||||
"healthCheckTimeout": 8000,
|
||||
"healthCheckModel": "test-model"
|
||||
}`
|
||||
f.FromJson(gjson.Parse(jsonStr))
|
||||
assert.Equal(t, true, f.enabled)
|
||||
assert.Equal(t, int64(5), f.failureThreshold)
|
||||
assert.Equal(t, int64(3), f.successThreshold)
|
||||
assert.Equal(t, int64(10000), f.healthCheckInterval)
|
||||
assert.Equal(t, int64(8000), f.healthCheckTimeout)
|
||||
assert.Equal(t, "test-model", f.healthCheckModel)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFailover_FromJson_FailoverOnStatus(t *testing.T) {
|
||||
t.Run("parse_failoverOnStatus_array", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
jsonStr := `{
|
||||
"enabled": true,
|
||||
"failoverOnStatus": ["401", "403", "5[0-9][0-9]"]
|
||||
}`
|
||||
f.FromJson(gjson.Parse(jsonStr))
|
||||
assert.Equal(t, 3, len(f.failoverOnStatus))
|
||||
assert.Contains(t, f.failoverOnStatus, "401")
|
||||
assert.Contains(t, f.failoverOnStatus, "403")
|
||||
assert.Contains(t, f.failoverOnStatus, "5[0-9][0-9]")
|
||||
})
|
||||
|
||||
t.Run("empty_failoverOnStatus", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
jsonStr := `{"enabled": true}`
|
||||
f.FromJson(gjson.Parse(jsonStr))
|
||||
// When failoverOnStatus is not specified, it keeps default values
|
||||
// Default regex patterns may be set elsewhere
|
||||
assert.True(t, f.enabled)
|
||||
assert.Equal(t, int64(3), f.failureThreshold)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHealthCheckEndpoint_Struct(t *testing.T) {
|
||||
t.Run("health_check_endpoint_fields", func(t *testing.T) {
|
||||
endpoint := HealthCheckEndpoint{
|
||||
Host: "api.example.com",
|
||||
Path: "/v1/chat/completions",
|
||||
Cluster: "ai-provider-cluster",
|
||||
}
|
||||
assert.Equal(t, "api.example.com", endpoint.Host)
|
||||
assert.Equal(t, "/v1/chat/completions", endpoint.Path)
|
||||
assert.Equal(t, "ai-provider-cluster", endpoint.Cluster)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLease_Struct(t *testing.T) {
|
||||
t.Run("lease_fields", func(t *testing.T) {
|
||||
lease := Lease{
|
||||
VMID: "vm-12345",
|
||||
Timestamp: 1234567890,
|
||||
}
|
||||
assert.Equal(t, "vm-12345", lease.VMID)
|
||||
assert.Equal(t, int64(1234567890), lease.Timestamp)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFailover_Constants(t *testing.T) {
|
||||
t.Run("cas_max_retries_value", func(t *testing.T) {
|
||||
assert.Equal(t, 10, casMaxRetries)
|
||||
})
|
||||
|
||||
t.Run("operation_constants", func(t *testing.T) {
|
||||
assert.Equal(t, "addApiToken", addApiTokenOperation)
|
||||
assert.Equal(t, "removeApiToken", removeApiTokenOperation)
|
||||
assert.Equal(t, "addApiTokenRequestCount", addApiTokenRequestCountOperation)
|
||||
assert.Equal(t, "resetApiTokenRequestCount", resetApiTokenRequestCountOperation)
|
||||
})
|
||||
|
||||
t.Run("context_key_constants", func(t *testing.T) {
|
||||
assert.Equal(t, "requestHost", CtxRequestHost)
|
||||
assert.Equal(t, "requestPath", CtxRequestPath)
|
||||
assert.Equal(t, "requestBody", CtxRequestBody)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProviderConfig_TransformRequestHeadersAndBody_PathHandling(t *testing.T) {
|
||||
// Test that providerBasePath is applied in transformRequestHeadersAndBody
|
||||
t.Run("providerBasePath_applied", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
providerBasePath: "/api/ai",
|
||||
}
|
||||
|
||||
// Test the applyProviderBasePath logic used in transformRequestHeadersAndBody
|
||||
testPath := "/v1/chat/completions"
|
||||
expectedPath := "/api/ai/v1/chat/completions"
|
||||
result := config.applyProviderBasePath(testPath)
|
||||
assert.Equal(t, expectedPath, result)
|
||||
})
|
||||
|
||||
t.Run("providerBasePath_already_present", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
providerBasePath: "/api/ai",
|
||||
}
|
||||
|
||||
testPath := "/api/ai/v1/chat/completions"
|
||||
result := config.applyProviderBasePath(testPath)
|
||||
// Should not duplicate the prefix
|
||||
assert.Equal(t, "/api/ai/v1/chat/completions", result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProviderConfig_IsSupportedAPI(t *testing.T) {
|
||||
t.Run("supported_api", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
capabilities: map[string]string{
|
||||
string(ApiNameChatCompletion): "/v1/chat/completions",
|
||||
string(ApiNameEmbeddings): "/v1/embeddings",
|
||||
},
|
||||
}
|
||||
|
||||
result := config.IsSupportedAPI(ApiNameChatCompletion)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("unsupported_api", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
capabilities: map[string]string{
|
||||
string(ApiNameChatCompletion): "/v1/chat/completions",
|
||||
},
|
||||
}
|
||||
|
||||
result := config.IsSupportedAPI(ApiNameEmbeddings)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("empty_capabilities", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
capabilities: map[string]string{},
|
||||
}
|
||||
|
||||
result := config.IsSupportedAPI(ApiNameChatCompletion)
|
||||
assert.False(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProviderConfig_SetDefaultCapabilities(t *testing.T) {
|
||||
t.Run("set_when_nil", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
capabilities: nil,
|
||||
}
|
||||
|
||||
defaultCaps := map[string]string{
|
||||
string(ApiNameChatCompletion): "/v1/chat/completions",
|
||||
}
|
||||
config.setDefaultCapabilities(defaultCaps)
|
||||
|
||||
assert.NotNil(t, config.capabilities)
|
||||
assert.Equal(t, "/v1/chat/completions", config.capabilities[string(ApiNameChatCompletion)])
|
||||
})
|
||||
|
||||
t.Run("merge_with_existing", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
capabilities: map[string]string{
|
||||
string(ApiNameEmbeddings): "/v1/embeddings",
|
||||
},
|
||||
}
|
||||
|
||||
defaultCaps := map[string]string{
|
||||
string(ApiNameChatCompletion): "/v1/chat/completions",
|
||||
}
|
||||
config.setDefaultCapabilities(defaultCaps)
|
||||
|
||||
assert.Equal(t, "/v1/embeddings", config.capabilities[string(ApiNameEmbeddings)])
|
||||
assert.Equal(t, "/v1/chat/completions", config.capabilities[string(ApiNameChatCompletion)])
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user