mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 20:57:32 +08:00
support gemini & claude domain setting (#3638)
This commit is contained in:
@@ -323,7 +323,8 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
|||||||
|
|
||||||
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities)
|
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities)
|
||||||
util.OverwriteRequestHostHeader(headers, claudeDomain)
|
domain := c.config.resolveDomain("", claudeDomain)
|
||||||
|
util.OverwriteRequestHostHeader(headers, domain)
|
||||||
|
|
||||||
if c.config.apiVersion == "" {
|
if c.config.apiVersion == "" {
|
||||||
c.config.apiVersion = claudeDefaultVersion
|
c.config.apiVersion = claudeDefaultVersion
|
||||||
|
|||||||
@@ -62,11 +62,12 @@ func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string {
|
|||||||
|
|
||||||
func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||||
config.setDefaultCapabilities(g.DefaultCapabilities())
|
config.setDefaultCapabilities(g.DefaultCapabilities())
|
||||||
|
domain := config.resolveDomain("", geminiDomain)
|
||||||
return &geminiProvider{
|
return &geminiProvider{
|
||||||
config: config,
|
config: config,
|
||||||
contextCache: createContextCache(&config),
|
contextCache: createContextCache(&config),
|
||||||
client: wrapper.NewClusterClient(wrapper.RouteCluster{
|
client: wrapper.NewClusterClient(wrapper.RouteCluster{
|
||||||
Host: geminiDomain,
|
Host: domain,
|
||||||
}),
|
}),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -89,7 +90,8 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||||
util.OverwriteRequestHostHeader(headers, geminiDomain)
|
domain := g.config.resolveDomain("", geminiDomain)
|
||||||
|
util.OverwriteRequestHostHeader(headers, domain)
|
||||||
headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
|
headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
|
||||||
util.OverwriteRequestAuthorizationHeader(headers, "")
|
util.OverwriteRequestAuthorizationHeader(headers, "")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -476,6 +476,9 @@ type ProviderConfig struct {
|
|||||||
// @Title zh-CN 合并连续同角色消息
|
// @Title zh-CN 合并连续同角色消息
|
||||||
// @Description zh-CN 开启后,若请求的 messages 中存在连续的同角色消息(如连续两条 user 消息),将其内容合并为一条,以满足要求严格轮流交替(user→assistant→user→...)的模型服务商的要求。
|
// @Description zh-CN 开启后,若请求的 messages 中存在连续的同角色消息(如连续两条 user 消息),将其内容合并为一条,以满足要求严格轮流交替(user→assistant→user→...)的模型服务商的要求。
|
||||||
mergeConsecutiveMessages bool `required:"false" yaml:"mergeConsecutiveMessages" json:"mergeConsecutiveMessages"`
|
mergeConsecutiveMessages bool `required:"false" yaml:"mergeConsecutiveMessages" json:"mergeConsecutiveMessages"`
|
||||||
|
// @Title zh-CN 通用 Provider 域名
|
||||||
|
// @Description zh-CN 通用的 Provider 服务域名配置,适用于所有 Provider。当配置此字段时,将优先使用此域名覆盖默认的硬编码域名。常用于代理服务器场景
|
||||||
|
providerDomain string `required:"false" yaml:"providerDomain" json:"providerDomain"`
|
||||||
// @Title zh-CN 空内容时提升思考为正文
|
// @Title zh-CN 空内容时提升思考为正文
|
||||||
// @Description zh-CN 开启后,若模型响应只包含 reasoning_content/thinking 而没有正文内容,将 reasoning 内容提升为正文内容返回,避免客户端收到空回复。
|
// @Description zh-CN 开启后,若模型响应只包含 reasoning_content/thinking 而没有正文内容,将 reasoning 内容提升为正文内容返回,避免客户端收到空回复。
|
||||||
promoteThinkingOnEmpty bool `required:"false" yaml:"promoteThinkingOnEmpty" json:"promoteThinkingOnEmpty"`
|
promoteThinkingOnEmpty bool `required:"false" yaml:"promoteThinkingOnEmpty" json:"promoteThinkingOnEmpty"`
|
||||||
@@ -496,6 +499,20 @@ func (c *ProviderConfig) GetProtocol() string {
|
|||||||
return c.protocol
|
return c.protocol
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveDomain resolves the domain to use based on priority:
|
||||||
|
// 1. providerDomain (generic override for all providers)
|
||||||
|
// 2. provider-specific domain config (e.g., geminiDomain, doubaoDomain)
|
||||||
|
// 3. default hardcoded domain
|
||||||
|
func (c *ProviderConfig) resolveDomain(providerSpecificDomain, defaultDomain string) string {
|
||||||
|
if c.providerDomain != "" {
|
||||||
|
return c.providerDomain
|
||||||
|
}
|
||||||
|
if providerSpecificDomain != "" {
|
||||||
|
return providerSpecificDomain
|
||||||
|
}
|
||||||
|
return defaultDomain
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) GetVllmCustomUrl() string {
|
func (c *ProviderConfig) GetVllmCustomUrl() string {
|
||||||
return c.vllmCustomUrl
|
return c.vllmCustomUrl
|
||||||
}
|
}
|
||||||
@@ -707,6 +724,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.mergeConsecutiveMessages = json.Get("mergeConsecutiveMessages").Bool()
|
c.mergeConsecutiveMessages = json.Get("mergeConsecutiveMessages").Bool()
|
||||||
|
c.providerDomain = json.Get("providerDomain").String()
|
||||||
c.promoteThinkingOnEmpty = json.Get("promoteThinkingOnEmpty").Bool()
|
c.promoteThinkingOnEmpty = json.Get("promoteThinkingOnEmpty").Bool()
|
||||||
c.hiclawMode = json.Get("hiclawMode").Bool()
|
c.hiclawMode = json.Get("hiclawMode").Bool()
|
||||||
if c.hiclawMode {
|
if c.hiclawMode {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestIsStatefulAPI(t *testing.T) {
|
func TestIsStatefulAPI(t *testing.T) {
|
||||||
@@ -132,11 +133,11 @@ func TestIsStatefulAPI(t *testing.T) {
|
|||||||
|
|
||||||
func TestGetTokenWithConsumerAffinity(t *testing.T) {
|
func TestGetTokenWithConsumerAffinity(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
apiTokens []string
|
apiTokens []string
|
||||||
consumer string
|
consumer string
|
||||||
wantEmpty bool
|
wantEmpty bool
|
||||||
wantToken string // If not empty, expected specific token (for single token case)
|
wantToken string // If not empty, expected specific token (for single token case)
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "no_tokens_returns_empty",
|
name: "no_tokens_returns_empty",
|
||||||
@@ -273,3 +274,44 @@ func TestGetTokenWithConsumerAffinity_HashDistribution(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProviderDomain_Config(t *testing.T) {
|
||||||
|
t.Run("providerDomain_field_exists", func(t *testing.T) {
|
||||||
|
config := ProviderConfig{}
|
||||||
|
config.FromJson(gjson.Result{})
|
||||||
|
assert.Equal(t, "", config.providerDomain)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("providerDomain_parsed_from_json", func(t *testing.T) {
|
||||||
|
config := ProviderConfig{}
|
||||||
|
jsonStr := `{"providerDomain": "universal-proxy.example.com"}`
|
||||||
|
config.FromJson(gjson.Parse(jsonStr))
|
||||||
|
assert.Equal(t, "universal-proxy.example.com", config.providerDomain)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveDomain_Priority(t *testing.T) {
|
||||||
|
t.Run("providerDomain_takes_priority", func(t *testing.T) {
|
||||||
|
config := ProviderConfig{
|
||||||
|
providerDomain: "universal-proxy.com",
|
||||||
|
}
|
||||||
|
result := config.resolveDomain("specific-domain.com", "default.com")
|
||||||
|
assert.Equal(t, "universal-proxy.com", result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("providerSpecificDomain_when_providerDomain_empty", func(t *testing.T) {
|
||||||
|
config := ProviderConfig{
|
||||||
|
providerDomain: "",
|
||||||
|
}
|
||||||
|
result := config.resolveDomain("specific-domain.com", "default.com")
|
||||||
|
assert.Equal(t, "specific-domain.com", result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("defaultDomain_when_both_empty", func(t *testing.T) {
|
||||||
|
config := ProviderConfig{
|
||||||
|
providerDomain: "",
|
||||||
|
}
|
||||||
|
result := config.resolveDomain("", "default.com")
|
||||||
|
assert.Equal(t, "default.com", result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user