mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
optimize ai proxy (#1603)
This commit is contained in:
@@ -89,16 +89,21 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
|||||||
}
|
}
|
||||||
|
|
||||||
if apiName == "" {
|
if apiName == "" {
|
||||||
log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path)
|
log.Warnf("[onHttpRequestHeader] unsupported path: %s", path.Path)
|
||||||
// _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path)
|
|
||||||
log.Debugf("[onHttpRequestHeader] no send response")
|
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
|
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
|
||||||
|
ctx.DisableReroute()
|
||||||
|
|
||||||
ctx.SetContext(ctxKeyApiName, apiName)
|
ctx.SetContext(ctxKeyApiName, apiName)
|
||||||
|
|
||||||
|
_, needHandleBody := activeProvider.(provider.ResponseBodyHandler)
|
||||||
|
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
|
||||||
|
if needHandleBody || needHandleStreamingBody {
|
||||||
|
proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||||
|
}
|
||||||
|
|
||||||
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
|
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
|
||||||
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
|
|
||||||
ctx.DisableReroute()
|
|
||||||
// Set the apiToken for the current request.
|
// Set the apiToken for the current request.
|
||||||
providerConfig.SetApiTokenInUse(ctx, log)
|
providerConfig.SetApiTokenInUse(ctx, log)
|
||||||
|
|
||||||
@@ -106,11 +111,12 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
|||||||
err := handler.OnRequestHeaders(ctx, apiName, log)
|
err := handler.OnRequestHeaders(ctx, apiName, log)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if hasRequestBody {
|
if hasRequestBody {
|
||||||
|
proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||||
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
|
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
|
||||||
// Always return types.HeaderStopIteration to support fallback routing,
|
// Delay the header processing to allow changing in OnRequestBody
|
||||||
// as long as onHttpRequestBody can be called.
|
|
||||||
return types.HeaderStopIteration
|
return types.HeaderStopIteration
|
||||||
}
|
}
|
||||||
|
ctx.DontReadRequestBody()
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,5 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
|||||||
|
|
||||||
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||||
util.OverwriteRequestHostHeader(headers, ai360Domain)
|
util.OverwriteRequestHostHeader(headers, ai360Domain)
|
||||||
util.OverwriteRequestAuthorizationHeader(headers, "Authorization "+m.config.GetApiTokenInUse(ctx))
|
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
|
||||||
headers.Del("Accept-Encoding")
|
|
||||||
headers.Del("Content-Length")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -86,6 +86,6 @@ func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
|
|||||||
util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI())
|
util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI())
|
||||||
}
|
}
|
||||||
util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
|
util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
|
||||||
util.OverwriteRequestAuthorizationHeader(headers, "api-key "+m.config.GetApiTokenInUse(ctx))
|
headers.Set("api-key", m.config.GetApiTokenInUse(ctx))
|
||||||
headers.Del("Content-Length")
|
headers.Del("Content-Length")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -114,15 +114,13 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
|
|||||||
util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath)
|
util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath)
|
||||||
util.OverwriteRequestHostHeader(headers, claudeDomain)
|
util.OverwriteRequestHostHeader(headers, claudeDomain)
|
||||||
|
|
||||||
headers.Add("x-api-key", c.config.GetApiTokenInUse(ctx))
|
headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx))
|
||||||
|
|
||||||
if c.config.claudeVersion == "" {
|
if c.config.claudeVersion == "" {
|
||||||
c.config.claudeVersion = defaultVersion
|
c.config.claudeVersion = defaultVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
headers.Add("anthropic-version", c.config.claudeVersion)
|
headers.Set("anthropic-version", c.config.claudeVersion)
|
||||||
headers.Del("Accept-Encoding")
|
|
||||||
headers.Del("Content-Length")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||||
|
|||||||
@@ -61,6 +61,4 @@ func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, ap
|
|||||||
util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
|
util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
|
||||||
util.OverwriteRequestHostHeader(headers, cloudflareDomain)
|
util.OverwriteRequestHostHeader(headers, cloudflareDomain)
|
||||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx))
|
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx))
|
||||||
headers.Del("Accept-Encoding")
|
|
||||||
headers.Del("Content-Length")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -87,8 +87,6 @@ func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
|
|||||||
func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||||
util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath)
|
util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath)
|
||||||
util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx))
|
util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx))
|
||||||
headers.Del("Content-Length")
|
|
||||||
headers.Del("Accept-Encoding")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||||
|
|||||||
@@ -62,9 +62,7 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
|||||||
|
|
||||||
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||||
util.OverwriteRequestHostHeader(headers, geminiDomain)
|
util.OverwriteRequestHostHeader(headers, geminiDomain)
|
||||||
headers.Add(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
|
headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
|
||||||
headers.Del("Accept-Encoding")
|
|
||||||
headers.Del("Content-Length")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||||
|
|||||||
@@ -68,8 +68,6 @@ func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
|
|||||||
util.OverwriteRequestPathHeader(headers, githubEmbeddingPath)
|
util.OverwriteRequestPathHeader(headers, githubEmbeddingPath)
|
||||||
}
|
}
|
||||||
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
|
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
|
||||||
headers.Del("Accept-Encoding")
|
|
||||||
headers.Del("Content-Length")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *githubProvider) GetApiName(path string) ApiName {
|
func (m *githubProvider) GetApiName(path string) ApiName {
|
||||||
|
|||||||
@@ -128,11 +128,8 @@ func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
|
|||||||
util.OverwriteRequestPathHeader(headers, hunyuanRequestPath)
|
util.OverwriteRequestPathHeader(headers, hunyuanRequestPath)
|
||||||
|
|
||||||
// 添加 hunyuan 需要的自定义字段
|
// 添加 hunyuan 需要的自定义字段
|
||||||
headers.Add(actionKey, hunyuanChatCompletionTCAction)
|
headers.Set(actionKey, hunyuanChatCompletionTCAction)
|
||||||
headers.Add(versionKey, versionValue)
|
headers.Set(versionKey, versionValue)
|
||||||
|
|
||||||
headers.Del("Accept-Encoding")
|
|
||||||
headers.Del("Content-Length")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法
|
// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法
|
||||||
|
|||||||
@@ -80,9 +80,6 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
|
|||||||
} else if apiName == ApiNameEmbeddings {
|
} else if apiName == ApiNameEmbeddings {
|
||||||
util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath)
|
util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
headers.Del("Accept-Encoding")
|
|
||||||
headers.Del("Content-Length")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||||
@@ -109,7 +106,6 @@ func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -177,6 +177,4 @@ func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
|
|||||||
util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath)
|
util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath)
|
||||||
util.OverwriteRequestHostHeader(headers, sparkHost)
|
util.OverwriteRequestHostHeader(headers, sparkHost)
|
||||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx))
|
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx))
|
||||||
headers.Del("Accept-Encoding")
|
|
||||||
headers.Del("Content-Length")
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user