From 231ba1cd23734bf5d4886a56eb6382162717770e Mon Sep 17 00:00:00 2001 From: rinfx Date: Thu, 26 Mar 2026 11:12:30 +0800 Subject: [PATCH] support gemini & claude domain setting (#3638) --- .../extensions/ai-proxy/provider/claude.go | 3 +- .../extensions/ai-proxy/provider/gemini.go | 6 ++- .../extensions/ai-proxy/provider/provider.go | 18 +++++++ .../ai-proxy/provider/provider_test.go | 52 +++++++++++++++++-- 4 files changed, 71 insertions(+), 8 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 4b763ce75..5c759892a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -323,7 +323,8 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) { util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities) - util.OverwriteRequestHostHeader(headers, claudeDomain) + domain := c.config.resolveDomain("", claudeDomain) + util.OverwriteRequestHostHeader(headers, domain) if c.config.apiVersion == "" { c.config.apiVersion = claudeDefaultVersion diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 70a873e40..1d45328f9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -62,11 +62,12 @@ func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string { func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { config.setDefaultCapabilities(g.DefaultCapabilities()) + domain := config.resolveDomain("", geminiDomain) return &geminiProvider{ config: config, contextCache: createContextCache(&config), client: wrapper.NewClusterClient(wrapper.RouteCluster{ - Host: geminiDomain, + Host: domain, }), }, nil } @@ -89,7 +90,8 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) { - util.OverwriteRequestHostHeader(headers, geminiDomain) + domain := g.config.resolveDomain("", geminiDomain) + util.OverwriteRequestHostHeader(headers, domain) headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "") } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 54e9203e4..351a96a29 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -476,6 +476,9 @@ type ProviderConfig struct { // @Title zh-CN 合并连续同角色消息 // @Description zh-CN 开启后,若请求的 messages 中存在连续的同角色消息(如连续两条 user 消息),将其内容合并为一条,以满足要求严格轮流交替(user→assistant→user→...)的模型服务商的要求。 mergeConsecutiveMessages bool `required:"false" yaml:"mergeConsecutiveMessages" json:"mergeConsecutiveMessages"` + // @Title zh-CN 通用 Provider 域名 + // @Description zh-CN 通用的 Provider 服务域名配置,适用于所有 Provider。当配置此字段时,将优先使用此域名覆盖默认的硬编码域名。常用于代理服务器场景 + providerDomain string `required:"false" yaml:"providerDomain" json:"providerDomain"` // @Title zh-CN 空内容时提升思考为正文 // @Description zh-CN 开启后,若模型响应只包含 reasoning_content/thinking 而没有正文内容,将 reasoning 内容提升为正文内容返回,避免客户端收到空回复。 promoteThinkingOnEmpty bool `required:"false" yaml:"promoteThinkingOnEmpty" json:"promoteThinkingOnEmpty"` @@ -496,6 +499,20 @@ func (c *ProviderConfig) GetProtocol() string { return c.protocol } +// resolveDomain resolves the domain to use based on priority: +// 1. providerDomain (generic override for all providers) +// 2. provider-specific domain config (e.g., geminiDomain, doubaoDomain) +// 3. default hardcoded domain +func (c *ProviderConfig) resolveDomain(providerSpecificDomain, defaultDomain string) string { + if c.providerDomain != "" { + return c.providerDomain + } + if providerSpecificDomain != "" { + return providerSpecificDomain + } + return defaultDomain +} + func (c *ProviderConfig) GetVllmCustomUrl() string { return c.vllmCustomUrl } @@ -707,6 +724,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { } } c.mergeConsecutiveMessages = json.Get("mergeConsecutiveMessages").Bool() + c.providerDomain = json.Get("providerDomain").String() c.promoteThinkingOnEmpty = json.Get("promoteThinkingOnEmpty").Bool() c.hiclawMode = json.Get("hiclawMode").Bool() if c.hiclawMode { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go index c061c438d..2c591716a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/tidwall/gjson" ) func TestIsStatefulAPI(t *testing.T) { @@ -132,11 +133,11 @@ func TestIsStatefulAPI(t *testing.T) { func TestGetTokenWithConsumerAffinity(t *testing.T) { tests := []struct { - name string - apiTokens []string - consumer string - wantEmpty bool - wantToken string // If not empty, expected specific token (for single token case) + name string + apiTokens []string + consumer string + wantEmpty bool + wantToken string // If not empty, expected specific token (for single token case) }{ { name: "no_tokens_returns_empty", @@ -273,3 +274,44 @@ func TestGetTokenWithConsumerAffinity_HashDistribution(t *testing.T) { }) } } + +func TestProviderDomain_Config(t *testing.T) { + t.Run("providerDomain_field_exists", func(t *testing.T) { + config := ProviderConfig{} + config.FromJson(gjson.Result{}) + assert.Equal(t, "", config.providerDomain) + }) + + t.Run("providerDomain_parsed_from_json", func(t *testing.T) { + config := ProviderConfig{} + jsonStr := `{"providerDomain": "universal-proxy.example.com"}` + config.FromJson(gjson.Parse(jsonStr)) + assert.Equal(t, "universal-proxy.example.com", config.providerDomain) + }) +} + +func TestResolveDomain_Priority(t *testing.T) { + t.Run("providerDomain_takes_priority", func(t *testing.T) { + config := ProviderConfig{ + providerDomain: "universal-proxy.com", + } + result := config.resolveDomain("specific-domain.com", "default.com") + assert.Equal(t, "universal-proxy.com", result) + }) + + t.Run("providerSpecificDomain_when_providerDomain_empty", func(t *testing.T) { + config := ProviderConfig{ + providerDomain: "", + } + result := config.resolveDomain("specific-domain.com", "default.com") + assert.Equal(t, "specific-domain.com", result) + }) + + t.Run("defaultDomain_when_both_empty", func(t *testing.T) { + config := ProviderConfig{ + providerDomain: "", + } + result := config.resolveDomain("", "default.com") + assert.Equal(t, "default.com", result) + }) +}