mirror of
https://github.com/alibaba/higress.git
synced 2026-04-20 03:27:26 +08:00
support gemini & claude domain setting (#3638)
This commit is contained in:
@@ -323,7 +323,8 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
|
||||
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, claudeDomain)
|
||||
domain := c.config.resolveDomain("", claudeDomain)
|
||||
util.OverwriteRequestHostHeader(headers, domain)
|
||||
|
||||
if c.config.apiVersion == "" {
|
||||
c.config.apiVersion = claudeDefaultVersion
|
||||
|
||||
@@ -62,11 +62,12 @@ func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
|
||||
func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(g.DefaultCapabilities())
|
||||
domain := config.resolveDomain("", geminiDomain)
|
||||
return &geminiProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
client: wrapper.NewClusterClient(wrapper.RouteCluster{
|
||||
Host: geminiDomain,
|
||||
Host: domain,
|
||||
}),
|
||||
}, nil
|
||||
}
|
||||
@@ -89,7 +90,8 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
}
|
||||
|
||||
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestHostHeader(headers, geminiDomain)
|
||||
domain := g.config.resolveDomain("", geminiDomain)
|
||||
util.OverwriteRequestHostHeader(headers, domain)
|
||||
headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "")
|
||||
}
|
||||
|
||||
@@ -476,6 +476,9 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN 合并连续同角色消息
|
||||
// @Description zh-CN 开启后,若请求的 messages 中存在连续的同角色消息(如连续两条 user 消息),将其内容合并为一条,以满足要求严格轮流交替(user→assistant→user→...)的模型服务商的要求。
|
||||
mergeConsecutiveMessages bool `required:"false" yaml:"mergeConsecutiveMessages" json:"mergeConsecutiveMessages"`
|
||||
// @Title zh-CN 通用 Provider 域名
|
||||
// @Description zh-CN 通用的 Provider 服务域名配置,适用于所有 Provider。当配置此字段时,将优先使用此域名覆盖默认的硬编码域名。常用于代理服务器场景
|
||||
providerDomain string `required:"false" yaml:"providerDomain" json:"providerDomain"`
|
||||
// @Title zh-CN 空内容时提升思考为正文
|
||||
// @Description zh-CN 开启后,若模型响应只包含 reasoning_content/thinking 而没有正文内容,将 reasoning 内容提升为正文内容返回,避免客户端收到空回复。
|
||||
promoteThinkingOnEmpty bool `required:"false" yaml:"promoteThinkingOnEmpty" json:"promoteThinkingOnEmpty"`
|
||||
@@ -496,6 +499,20 @@ func (c *ProviderConfig) GetProtocol() string {
|
||||
return c.protocol
|
||||
}
|
||||
|
||||
// resolveDomain resolves the domain to use based on priority:
|
||||
// 1. providerDomain (generic override for all providers)
|
||||
// 2. provider-specific domain config (e.g., geminiDomain, doubaoDomain)
|
||||
// 3. default hardcoded domain
|
||||
func (c *ProviderConfig) resolveDomain(providerSpecificDomain, defaultDomain string) string {
|
||||
if c.providerDomain != "" {
|
||||
return c.providerDomain
|
||||
}
|
||||
if providerSpecificDomain != "" {
|
||||
return providerSpecificDomain
|
||||
}
|
||||
return defaultDomain
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetVllmCustomUrl() string {
|
||||
return c.vllmCustomUrl
|
||||
}
|
||||
@@ -707,6 +724,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
}
|
||||
}
|
||||
c.mergeConsecutiveMessages = json.Get("mergeConsecutiveMessages").Bool()
|
||||
c.providerDomain = json.Get("providerDomain").String()
|
||||
c.promoteThinkingOnEmpty = json.Get("promoteThinkingOnEmpty").Bool()
|
||||
c.hiclawMode = json.Get("hiclawMode").Bool()
|
||||
if c.hiclawMode {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestIsStatefulAPI(t *testing.T) {
|
||||
@@ -132,11 +133,11 @@ func TestIsStatefulAPI(t *testing.T) {
|
||||
|
||||
func TestGetTokenWithConsumerAffinity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
apiTokens []string
|
||||
consumer string
|
||||
wantEmpty bool
|
||||
wantToken string // If not empty, expected specific token (for single token case)
|
||||
name string
|
||||
apiTokens []string
|
||||
consumer string
|
||||
wantEmpty bool
|
||||
wantToken string // If not empty, expected specific token (for single token case)
|
||||
}{
|
||||
{
|
||||
name: "no_tokens_returns_empty",
|
||||
@@ -273,3 +274,44 @@ func TestGetTokenWithConsumerAffinity_HashDistribution(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderDomain_Config(t *testing.T) {
|
||||
t.Run("providerDomain_field_exists", func(t *testing.T) {
|
||||
config := ProviderConfig{}
|
||||
config.FromJson(gjson.Result{})
|
||||
assert.Equal(t, "", config.providerDomain)
|
||||
})
|
||||
|
||||
t.Run("providerDomain_parsed_from_json", func(t *testing.T) {
|
||||
config := ProviderConfig{}
|
||||
jsonStr := `{"providerDomain": "universal-proxy.example.com"}`
|
||||
config.FromJson(gjson.Parse(jsonStr))
|
||||
assert.Equal(t, "universal-proxy.example.com", config.providerDomain)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveDomain_Priority(t *testing.T) {
|
||||
t.Run("providerDomain_takes_priority", func(t *testing.T) {
|
||||
config := ProviderConfig{
|
||||
providerDomain: "universal-proxy.com",
|
||||
}
|
||||
result := config.resolveDomain("specific-domain.com", "default.com")
|
||||
assert.Equal(t, "universal-proxy.com", result)
|
||||
})
|
||||
|
||||
t.Run("providerSpecificDomain_when_providerDomain_empty", func(t *testing.T) {
|
||||
config := ProviderConfig{
|
||||
providerDomain: "",
|
||||
}
|
||||
result := config.resolveDomain("specific-domain.com", "default.com")
|
||||
assert.Equal(t, "specific-domain.com", result)
|
||||
})
|
||||
|
||||
t.Run("defaultDomain_when_both_empty", func(t *testing.T) {
|
||||
config := ProviderConfig{
|
||||
providerDomain: "",
|
||||
}
|
||||
result := config.resolveDomain("", "default.com")
|
||||
assert.Equal(t, "default.com", result)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user