support gemini & claude domain setting (#3638)

This commit is contained in:
rinfx
2026-03-26 11:12:30 +08:00
committed by GitHub
parent 3fc01913cf
commit 231ba1cd23
4 changed files with 71 additions and 8 deletions

View File

@@ -323,7 +323,8 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities)
util.OverwriteRequestHostHeader(headers, claudeDomain)
domain := c.config.resolveDomain("", claudeDomain)
util.OverwriteRequestHostHeader(headers, domain)
if c.config.apiVersion == "" {
c.config.apiVersion = claudeDefaultVersion

View File

@@ -62,11 +62,12 @@ func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string {
func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(g.DefaultCapabilities())
domain := config.resolveDomain("", geminiDomain)
return &geminiProvider{
config: config,
contextCache: createContextCache(&config),
client: wrapper.NewClusterClient(wrapper.RouteCluster{
Host: geminiDomain,
Host: domain,
}),
}, nil
}
@@ -89,7 +90,8 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
}
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestHostHeader(headers, geminiDomain)
domain := g.config.resolveDomain("", geminiDomain)
util.OverwriteRequestHostHeader(headers, domain)
headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
util.OverwriteRequestAuthorizationHeader(headers, "")
}

View File

@@ -476,6 +476,9 @@ type ProviderConfig struct {
// @Title zh-CN 合并连续同角色消息
// @Description zh-CN 开启后,若请求的 messages 中存在连续的同角色消息(如连续两条 user 消息将其内容合并为一条以满足要求严格轮流交替user→assistant→user→...)的模型服务商的要求。
mergeConsecutiveMessages bool `required:"false" yaml:"mergeConsecutiveMessages" json:"mergeConsecutiveMessages"`
// @Title zh-CN 通用 Provider 域名
// @Description zh-CN 通用的 Provider 服务域名配置,适用于所有 Provider。当配置此字段时将优先使用此域名覆盖默认的硬编码域名。常用于代理服务器场景
providerDomain string `required:"false" yaml:"providerDomain" json:"providerDomain"`
// @Title zh-CN 空内容时提升思考为正文
// @Description zh-CN 开启后,若模型响应只包含 reasoning_content/thinking 而没有正文内容,将 reasoning 内容提升为正文内容返回,避免客户端收到空回复。
promoteThinkingOnEmpty bool `required:"false" yaml:"promoteThinkingOnEmpty" json:"promoteThinkingOnEmpty"`
@@ -496,6 +499,20 @@ func (c *ProviderConfig) GetProtocol() string {
return c.protocol
}
// resolveDomain resolves the domain to use based on priority:
// 1. providerDomain (generic override for all providers)
// 2. provider-specific domain config (e.g., geminiDomain, doubaoDomain)
// 3. default hardcoded domain
func (c *ProviderConfig) resolveDomain(providerSpecificDomain, defaultDomain string) string {
if c.providerDomain != "" {
return c.providerDomain
}
if providerSpecificDomain != "" {
return providerSpecificDomain
}
return defaultDomain
}
func (c *ProviderConfig) GetVllmCustomUrl() string {
return c.vllmCustomUrl
}
@@ -707,6 +724,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
}
}
c.mergeConsecutiveMessages = json.Get("mergeConsecutiveMessages").Bool()
c.providerDomain = json.Get("providerDomain").String()
c.promoteThinkingOnEmpty = json.Get("promoteThinkingOnEmpty").Bool()
c.hiclawMode = json.Get("hiclawMode").Bool()
if c.hiclawMode {

View File

@@ -4,6 +4,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
)
func TestIsStatefulAPI(t *testing.T) {
@@ -132,11 +133,11 @@ func TestIsStatefulAPI(t *testing.T) {
func TestGetTokenWithConsumerAffinity(t *testing.T) {
tests := []struct {
name string
apiTokens []string
consumer string
wantEmpty bool
wantToken string // If not empty, expected specific token (for single token case)
name string
apiTokens []string
consumer string
wantEmpty bool
wantToken string // If not empty, expected specific token (for single token case)
}{
{
name: "no_tokens_returns_empty",
@@ -273,3 +274,44 @@ func TestGetTokenWithConsumerAffinity_HashDistribution(t *testing.T) {
})
}
}
func TestProviderDomain_Config(t *testing.T) {
t.Run("providerDomain_field_exists", func(t *testing.T) {
config := ProviderConfig{}
config.FromJson(gjson.Result{})
assert.Equal(t, "", config.providerDomain)
})
t.Run("providerDomain_parsed_from_json", func(t *testing.T) {
config := ProviderConfig{}
jsonStr := `{"providerDomain": "universal-proxy.example.com"}`
config.FromJson(gjson.Parse(jsonStr))
assert.Equal(t, "universal-proxy.example.com", config.providerDomain)
})
}
func TestResolveDomain_Priority(t *testing.T) {
t.Run("providerDomain_takes_priority", func(t *testing.T) {
config := ProviderConfig{
providerDomain: "universal-proxy.com",
}
result := config.resolveDomain("specific-domain.com", "default.com")
assert.Equal(t, "universal-proxy.com", result)
})
t.Run("providerSpecificDomain_when_providerDomain_empty", func(t *testing.T) {
config := ProviderConfig{
providerDomain: "",
}
result := config.resolveDomain("specific-domain.com", "default.com")
assert.Equal(t, "specific-domain.com", result)
})
t.Run("defaultDomain_when_both_empty", func(t *testing.T) {
config := ProviderConfig{
providerDomain: "",
}
result := config.resolveDomain("", "default.com")
assert.Equal(t, "default.com", result)
})
}