feat(ai-proxy): 新增 providerBasePath 配置并优化 providerDomain 处理方式 (#3686)

This commit is contained in:
rinfx
2026-04-08 15:23:15 +08:00
committed by GitHub
parent 1c9e981bf2
commit 228eb27e6a
7 changed files with 428 additions and 42 deletions

View File

@@ -1,6 +1,7 @@
package provider
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
@@ -290,28 +291,390 @@ func TestProviderDomain_Config(t *testing.T) {
})
}
func TestResolveDomain_Priority(t *testing.T) {
t.Run("providerDomain_takes_priority", func(t *testing.T) {
config := ProviderConfig{
providerDomain: "universal-proxy.com",
}
result := config.resolveDomain("specific-domain.com", "default.com")
assert.Equal(t, "universal-proxy.com", result)
func TestProviderBasePath_Config(t *testing.T) {
t.Run("providerBasePath_field_exists", func(t *testing.T) {
config := ProviderConfig{}
config.FromJson(gjson.Result{})
assert.Equal(t, "", config.providerBasePath)
})
t.Run("providerSpecificDomain_when_providerDomain_empty", func(t *testing.T) {
config := ProviderConfig{
providerDomain: "",
}
result := config.resolveDomain("specific-domain.com", "default.com")
assert.Equal(t, "specific-domain.com", result)
t.Run("providerBasePath_parsed_from_json", func(t *testing.T) {
config := ProviderConfig{}
jsonStr := `{"providerBasePath": "/api/ai"}`
config.FromJson(gjson.Parse(jsonStr))
assert.Equal(t, "/api/ai", config.providerBasePath)
})
t.Run("defaultDomain_when_both_empty", func(t *testing.T) {
config := ProviderConfig{
providerDomain: "",
}
result := config.resolveDomain("", "default.com")
assert.Equal(t, "default.com", result)
t.Run("providerBasePath_with_other_config", func(t *testing.T) {
config := ProviderConfig{}
jsonStr := `{
"type": "openai",
"apiToken": "sk-test",
"providerBasePath": "/api/v1",
"providerDomain": "proxy.example.com"
}`
config.FromJson(gjson.Parse(jsonStr))
assert.Equal(t, "openai", config.typ)
assert.Equal(t, "/api/v1", config.providerBasePath)
assert.Equal(t, "proxy.example.com", config.providerDomain)
})
}
func TestApplyProviderBasePath(t *testing.T) {
tests := []struct {
name string
providerBasePath string
originalPath string
expectedPath string
}{
{
name: "no_base_path_configured",
providerBasePath: "",
originalPath: "/v1/chat/completions",
expectedPath: "/v1/chat/completions",
},
{
name: "base_path_prepended",
providerBasePath: "/api/ai",
originalPath: "/v1/chat/completions",
expectedPath: "/api/ai/v1/chat/completions",
},
{
name: "path_already_has_base_path",
providerBasePath: "/api/ai",
originalPath: "/api/ai/v1/chat/completions",
expectedPath: "/api/ai/v1/chat/completions",
},
{
name: "base_path_with_trailing_slash",
providerBasePath: "/api/ai/",
originalPath: "/v1/chat/completions",
expectedPath: "/api/ai//v1/chat/completions",
},
{
name: "deep_base_path",
providerBasePath: "/internal/services/ai",
originalPath: "/v1/models",
expectedPath: "/internal/services/ai/v1/models",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := &ProviderConfig{
providerBasePath: tt.providerBasePath,
}
result := config.applyProviderBasePath(tt.originalPath)
assert.Equal(t, tt.expectedPath, result)
})
}
}
func TestHandleRequestHeaders_PathHandling(t *testing.T) {
// This test verifies the path handling logic in handleRequestHeaders
// including basePathHandling and providerBasePath
t.Run("basePath_removePrefix_only", func(t *testing.T) {
config := &ProviderConfig{
basePath: "/gateway",
basePathHandling: basePathHandlingRemovePrefix,
}
// Simulate the logic - actual test would need mock provider
originPath := "/gateway/v1/chat"
expectedPath := "/v1/chat"
result := strings.TrimPrefix(originPath, config.basePath)
assert.Equal(t, expectedPath, result)
})
t.Run("basePath_prepend_only", func(t *testing.T) {
config := &ProviderConfig{
basePath: "/api",
basePathHandling: basePathHandlingPrepend,
}
currentPath := "/v1/chat"
// basePath preprend + providerBasePath (not set) = just basePath effect
// Note: applyProviderBasePath only handles providerBasePath, not basePath
// So this test just verifies that applyProviderBasePath doesn't modify path when providerBasePath is empty
expectedPath := "/v1/chat" // applyProviderBasePath doesn't change path without providerBasePath configured
result := config.applyProviderBasePath(currentPath)
assert.Equal(t, expectedPath, result)
})
t.Run("providerBasePath_only", func(t *testing.T) {
config := &ProviderConfig{
providerBasePath: "/ai-proxy",
}
currentPath := "/v1/chat"
expectedPath := "/ai-proxy/v1/chat"
result := config.applyProviderBasePath(currentPath)
assert.Equal(t, expectedPath, result)
})
t.Run("both_basePath_and_providerBasePath", func(t *testing.T) {
config := &ProviderConfig{
basePath: "/gateway",
basePathHandling: basePathHandlingRemovePrefix,
providerBasePath: "/ai",
}
// First removePrefix, then apply providerBasePath
originPath := "/gateway/v1/chat"
afterRemovePrefix := strings.TrimPrefix(originPath, config.basePath)
finalPath := config.applyProviderBasePath(afterRemovePrefix)
assert.Equal(t, "/ai/v1/chat", finalPath)
})
}
func TestProviderConfig_IsOriginal(t *testing.T) {
tests := []struct {
name string
protocol string
expected bool
}{
{"openai_protocol", protocolOpenAI, false},
{"original_protocol", protocolOriginal, true},
{"empty_protocol", "", false},
{"unknown_protocol", "unknown", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := &ProviderConfig{
protocol: tt.protocol,
}
result := config.IsOriginal()
assert.Equal(t, tt.expected, result)
})
}
}
func TestProviderConfig_GetPromoteThinkingOnEmpty(t *testing.T) {
tests := []struct {
name string
promoteThinkingOnEmpty bool
expected bool
}{
{"enabled", true, true},
{"disabled", false, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := &ProviderConfig{
promoteThinkingOnEmpty: tt.promoteThinkingOnEmpty,
}
result := config.GetPromoteThinkingOnEmpty()
assert.Equal(t, tt.expected, result)
})
}
}
// ============ Failover Tests ============
func TestFailover_FromJson_Defaults(t *testing.T) {
t.Run("default_failure_threshold", func(t *testing.T) {
f := &failover{}
jsonStr := `{"enabled": true}`
f.FromJson(gjson.Parse(jsonStr))
assert.Equal(t, int64(3), f.failureThreshold)
})
t.Run("default_success_threshold", func(t *testing.T) {
f := &failover{}
jsonStr := `{"enabled": true}`
f.FromJson(gjson.Parse(jsonStr))
assert.Equal(t, int64(1), f.successThreshold)
})
t.Run("default_health_check_interval", func(t *testing.T) {
f := &failover{}
jsonStr := `{"enabled": true}`
f.FromJson(gjson.Parse(jsonStr))
assert.Equal(t, int64(5000), f.healthCheckInterval)
})
t.Run("default_health_check_timeout", func(t *testing.T) {
f := &failover{}
jsonStr := `{"enabled": true}`
f.FromJson(gjson.Parse(jsonStr))
assert.Equal(t, int64(5000), f.healthCheckTimeout)
})
t.Run("custom_values", func(t *testing.T) {
f := &failover{}
jsonStr := `{
"enabled": true,
"failureThreshold": 5,
"successThreshold": 3,
"healthCheckInterval": 10000,
"healthCheckTimeout": 8000,
"healthCheckModel": "test-model"
}`
f.FromJson(gjson.Parse(jsonStr))
assert.Equal(t, true, f.enabled)
assert.Equal(t, int64(5), f.failureThreshold)
assert.Equal(t, int64(3), f.successThreshold)
assert.Equal(t, int64(10000), f.healthCheckInterval)
assert.Equal(t, int64(8000), f.healthCheckTimeout)
assert.Equal(t, "test-model", f.healthCheckModel)
})
}
func TestFailover_FromJson_FailoverOnStatus(t *testing.T) {
t.Run("parse_failoverOnStatus_array", func(t *testing.T) {
f := &failover{}
jsonStr := `{
"enabled": true,
"failoverOnStatus": ["401", "403", "5[0-9][0-9]"]
}`
f.FromJson(gjson.Parse(jsonStr))
assert.Equal(t, 3, len(f.failoverOnStatus))
assert.Contains(t, f.failoverOnStatus, "401")
assert.Contains(t, f.failoverOnStatus, "403")
assert.Contains(t, f.failoverOnStatus, "5[0-9][0-9]")
})
t.Run("empty_failoverOnStatus", func(t *testing.T) {
f := &failover{}
jsonStr := `{"enabled": true}`
f.FromJson(gjson.Parse(jsonStr))
// When failoverOnStatus is not specified, it keeps default values
// Default regex patterns may be set elsewhere
assert.True(t, f.enabled)
assert.Equal(t, int64(3), f.failureThreshold)
})
}
func TestHealthCheckEndpoint_Struct(t *testing.T) {
t.Run("health_check_endpoint_fields", func(t *testing.T) {
endpoint := HealthCheckEndpoint{
Host: "api.example.com",
Path: "/v1/chat/completions",
Cluster: "ai-provider-cluster",
}
assert.Equal(t, "api.example.com", endpoint.Host)
assert.Equal(t, "/v1/chat/completions", endpoint.Path)
assert.Equal(t, "ai-provider-cluster", endpoint.Cluster)
})
}
func TestLease_Struct(t *testing.T) {
t.Run("lease_fields", func(t *testing.T) {
lease := Lease{
VMID: "vm-12345",
Timestamp: 1234567890,
}
assert.Equal(t, "vm-12345", lease.VMID)
assert.Equal(t, int64(1234567890), lease.Timestamp)
})
}
func TestFailover_Constants(t *testing.T) {
t.Run("cas_max_retries_value", func(t *testing.T) {
assert.Equal(t, 10, casMaxRetries)
})
t.Run("operation_constants", func(t *testing.T) {
assert.Equal(t, "addApiToken", addApiTokenOperation)
assert.Equal(t, "removeApiToken", removeApiTokenOperation)
assert.Equal(t, "addApiTokenRequestCount", addApiTokenRequestCountOperation)
assert.Equal(t, "resetApiTokenRequestCount", resetApiTokenRequestCountOperation)
})
t.Run("context_key_constants", func(t *testing.T) {
assert.Equal(t, "requestHost", CtxRequestHost)
assert.Equal(t, "requestPath", CtxRequestPath)
assert.Equal(t, "requestBody", CtxRequestBody)
})
}
func TestProviderConfig_TransformRequestHeadersAndBody_PathHandling(t *testing.T) {
// Test that providerBasePath is applied in transformRequestHeadersAndBody
t.Run("providerBasePath_applied", func(t *testing.T) {
config := &ProviderConfig{
providerBasePath: "/api/ai",
}
// Test the applyProviderBasePath logic used in transformRequestHeadersAndBody
testPath := "/v1/chat/completions"
expectedPath := "/api/ai/v1/chat/completions"
result := config.applyProviderBasePath(testPath)
assert.Equal(t, expectedPath, result)
})
t.Run("providerBasePath_already_present", func(t *testing.T) {
config := &ProviderConfig{
providerBasePath: "/api/ai",
}
testPath := "/api/ai/v1/chat/completions"
result := config.applyProviderBasePath(testPath)
// Should not duplicate the prefix
assert.Equal(t, "/api/ai/v1/chat/completions", result)
})
}
func TestProviderConfig_IsSupportedAPI(t *testing.T) {
t.Run("supported_api", func(t *testing.T) {
config := &ProviderConfig{
capabilities: map[string]string{
string(ApiNameChatCompletion): "/v1/chat/completions",
string(ApiNameEmbeddings): "/v1/embeddings",
},
}
result := config.IsSupportedAPI(ApiNameChatCompletion)
assert.True(t, result)
})
t.Run("unsupported_api", func(t *testing.T) {
config := &ProviderConfig{
capabilities: map[string]string{
string(ApiNameChatCompletion): "/v1/chat/completions",
},
}
result := config.IsSupportedAPI(ApiNameEmbeddings)
assert.False(t, result)
})
t.Run("empty_capabilities", func(t *testing.T) {
config := &ProviderConfig{
capabilities: map[string]string{},
}
result := config.IsSupportedAPI(ApiNameChatCompletion)
assert.False(t, result)
})
}
func TestProviderConfig_SetDefaultCapabilities(t *testing.T) {
t.Run("set_when_nil", func(t *testing.T) {
config := &ProviderConfig{
capabilities: nil,
}
defaultCaps := map[string]string{
string(ApiNameChatCompletion): "/v1/chat/completions",
}
config.setDefaultCapabilities(defaultCaps)
assert.NotNil(t, config.capabilities)
assert.Equal(t, "/v1/chat/completions", config.capabilities[string(ApiNameChatCompletion)])
})
t.Run("merge_with_existing", func(t *testing.T) {
config := &ProviderConfig{
capabilities: map[string]string{
string(ApiNameEmbeddings): "/v1/embeddings",
},
}
defaultCaps := map[string]string{
string(ApiNameChatCompletion): "/v1/chat/completions",
}
config.setDefaultCapabilities(defaultCaps)
assert.Equal(t, "/v1/embeddings", config.capabilities[string(ApiNameEmbeddings)])
assert.Equal(t, "/v1/chat/completions", config.capabilities[string(ApiNameChatCompletion)])
})
}