mirror of
https://github.com/alibaba/higress.git
synced 2026-05-26 05:37:25 +08:00
feat(ai-proxy): 新增 providerBasePath 配置并优化 providerDomain 处理方式 (#3686)
This commit is contained in:
@@ -487,6 +487,9 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN HiClaw 模式
|
||||
// @Description zh-CN 开启后同时启用 mergeConsecutiveMessages 和 promoteThinkingOnEmpty,适用于 HiClaw 多 Agent 协作场景。
|
||||
hiclawMode bool `required:"false" yaml:"hiclawMode" json:"hiclawMode"`
|
||||
// @Title zh-CN Provider 基础路径
|
||||
// @Description zh-CN 当配置了此值时,各个 Provider 在改写请求路径时会将其添加到路径前面,例如配置"/api/ai"后,请求路径"/v1/chat/completions"会被改写为"/api/ai/v1/chat/completions"
|
||||
providerBasePath string `required:"false" yaml:"providerBasePath" json:"providerBasePath"`
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetId() string {
|
||||
@@ -501,20 +504,6 @@ func (c *ProviderConfig) GetProtocol() string {
|
||||
return c.protocol
|
||||
}
|
||||
|
||||
// resolveDomain resolves the domain to use based on priority:
|
||||
// 1. providerDomain (generic override for all providers)
|
||||
// 2. provider-specific domain config (e.g., geminiDomain, doubaoDomain)
|
||||
// 3. default hardcoded domain
|
||||
func (c *ProviderConfig) resolveDomain(providerSpecificDomain, defaultDomain string) string {
|
||||
if c.providerDomain != "" {
|
||||
return c.providerDomain
|
||||
}
|
||||
if providerSpecificDomain != "" {
|
||||
return providerSpecificDomain
|
||||
}
|
||||
return defaultDomain
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetVllmCustomUrl() string {
|
||||
return c.vllmCustomUrl
|
||||
}
|
||||
@@ -733,6 +722,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.mergeConsecutiveMessages = true
|
||||
c.promoteThinkingOnEmpty = true
|
||||
}
|
||||
c.providerBasePath = json.Get("providerBasePath").String()
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) Validate() error {
|
||||
@@ -867,6 +857,10 @@ func (c *ProviderConfig) IsOriginal() bool {
|
||||
return c.protocol == protocolOriginal
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) IsGeneric() bool {
|
||||
return c.typ == providerTypeGeneric
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetPromoteThinkingOnEmpty() bool {
|
||||
return c.promoteThinkingOnEmpty
|
||||
}
|
||||
@@ -883,6 +877,14 @@ func CreateProvider(pc ProviderConfig) (Provider, error) {
|
||||
return initializer.CreateProvider(pc)
|
||||
}
|
||||
|
||||
// applyProviderBasePath prepends the ProviderBasePath to the given path if configured.
|
||||
func (c *ProviderConfig) applyProviderBasePath(path string) string {
|
||||
if c.providerBasePath != "" && !strings.HasPrefix(path, c.providerBasePath) {
|
||||
return c.providerBasePath + path
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte) error {
|
||||
switch req := request.(type) {
|
||||
case *chatCompletionRequest:
|
||||
@@ -1220,6 +1222,10 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
|
||||
headers := util.GetRequestHeaders()
|
||||
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
||||
// Apply providerBasePath if configured
|
||||
if c.providerBasePath != "" {
|
||||
headers.Set(":path", c.applyProviderBasePath(headers.Get(":path")))
|
||||
}
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
} else {
|
||||
body, err = c.defaultTransformRequestBody(ctx, apiName, body)
|
||||
@@ -1276,6 +1282,18 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt
|
||||
if c.basePath != "" && c.basePathHandling == basePathHandlingPrepend && !strings.HasPrefix(headers.Get(":path"), c.basePath) {
|
||||
headers.Set(":path", path.Join(c.basePath, headers.Get(":path")))
|
||||
}
|
||||
|
||||
// Apply providerBasePath if configured
|
||||
currentPath := headers.Get(":path")
|
||||
if c.providerBasePath != "" {
|
||||
headers.Set(":path", c.applyProviderBasePath(currentPath))
|
||||
}
|
||||
|
||||
// Apply providerDomain if configured (overrides any domain set by the provider)
|
||||
if c.providerDomain != "" {
|
||||
util.OverwriteRequestHostHeader(headers, c.providerDomain)
|
||||
}
|
||||
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user