From 39c007d0450a2f15f9339c39e1abde48afd2d192 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BE=84=E6=BD=AD?= Date: Thu, 19 Dec 2024 16:22:35 +0800 Subject: [PATCH] optimize ai proxy (#1603) --- plugins/wasm-go/extensions/ai-proxy/main.go | 20 ++++++++++++------- .../extensions/ai-proxy/provider/ai360.go | 4 +--- .../extensions/ai-proxy/provider/azure.go | 2 +- .../extensions/ai-proxy/provider/claude.go | 6 ++---- .../ai-proxy/provider/cloudflare.go | 2 -- .../extensions/ai-proxy/provider/deepl.go | 2 -- .../extensions/ai-proxy/provider/gemini.go | 4 +--- .../extensions/ai-proxy/provider/github.go | 2 -- .../extensions/ai-proxy/provider/hunyuan.go | 7 ++----- .../extensions/ai-proxy/provider/qwen.go | 4 ---- .../extensions/ai-proxy/provider/spark.go | 2 -- 11 files changed, 20 insertions(+), 35 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index aa9cb032c..3f4dc49ba 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -89,16 +89,21 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf } if apiName == "" { - log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path) - // _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path) - log.Debugf("[onHttpRequestHeader] no send response") + log.Warnf("[onHttpRequestHeader] unsupported path: %s", path.Path) return types.ActionContinue } + // Disable the route re-calculation since the plugin may modify some headers related to the chosen route. + ctx.DisableReroute() + ctx.SetContext(ctxKeyApiName, apiName) + _, needHandleBody := activeProvider.(provider.ResponseBodyHandler) + _, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler) + if needHandleBody || needHandleStreamingBody { + proxywasm.RemoveHttpRequestHeader("Accept-Encoding") + } + if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok { - // Disable the route re-calculation since the plugin may modify some headers related to the chosen route. - ctx.DisableReroute() // Set the apiToken for the current request. providerConfig.SetApiTokenInUse(ctx, log) @@ -106,11 +111,12 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf err := handler.OnRequestHeaders(ctx, apiName, log) if err == nil { if hasRequestBody { + proxywasm.RemoveHttpRequestHeader("Content-Length") ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) - // Always return types.HeaderStopIteration to support fallback routing, - // as long as onHttpRequestBody can be called. + // Delay the header processing to allow changing in OnRequestBody return types.HeaderStopIteration } + ctx.DontReadRequestBody() return types.ActionContinue } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index b762a0a58..fa5f1362c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -58,7 +58,5 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { util.OverwriteRequestHostHeader(headers, ai360Domain) - util.OverwriteRequestAuthorizationHeader(headers, "Authorization "+m.config.GetApiTokenInUse(ctx)) - headers.Del("Accept-Encoding") - headers.Del("Content-Length") + util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx)) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index e08013437..9e02d0fd9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -86,6 +86,6 @@ func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI()) } util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host) - util.OverwriteRequestAuthorizationHeader(headers, "api-key "+m.config.GetApiTokenInUse(ctx)) + headers.Set("api-key", m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 5f99d0293..994346974 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -114,15 +114,13 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath) util.OverwriteRequestHostHeader(headers, claudeDomain) - headers.Add("x-api-key", c.config.GetApiTokenInUse(ctx)) + headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx)) if c.config.claudeVersion == "" { c.config.claudeVersion = defaultVersion } - headers.Add("anthropic-version", c.config.claudeVersion) - headers.Del("Accept-Encoding") - headers.Del("Content-Length") + headers.Set("anthropic-version", c.config.claudeVersion) } func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index e9663b0da..4340183ee 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -61,6 +61,4 @@ func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, ap util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1)) util.OverwriteRequestHostHeader(headers, cloudflareDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx)) - headers.Del("Accept-Encoding") - headers.Del("Content-Length") } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index 82998ee1e..345a70c94 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -87,8 +87,6 @@ func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath) util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx)) - headers.Del("Content-Length") - headers.Del("Accept-Encoding") } func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index abb6268ea..7a9b0a3dd 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -62,9 +62,7 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { util.OverwriteRequestHostHeader(headers, geminiDomain) - headers.Add(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx)) - headers.Del("Accept-Encoding") - headers.Del("Content-Length") + headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx)) } func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/github.go b/plugins/wasm-go/extensions/ai-proxy/provider/github.go index 1d5c53dc4..348134c0a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/github.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/github.go @@ -68,8 +68,6 @@ func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam util.OverwriteRequestPathHeader(headers, githubEmbeddingPath) } util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx)) - headers.Del("Accept-Encoding") - headers.Del("Content-Length") } func (m *githubProvider) GetApiName(path string) ApiName { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index bcd598830..4b10a4d7c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -128,11 +128,8 @@ func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa util.OverwriteRequestPathHeader(headers, hunyuanRequestPath) // 添加 hunyuan 需要的自定义字段 - headers.Add(actionKey, hunyuanChatCompletionTCAction) - headers.Add(versionKey, versionValue) - - headers.Del("Accept-Encoding") - headers.Del("Content-Length") + headers.Set(actionKey, hunyuanChatCompletionTCAction) + headers.Set(versionKey, versionValue) } // hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法 diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index dbba80355..95fe28e4b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -80,9 +80,6 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName } else if apiName == ApiNameEmbeddings { util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath) } - - headers.Del("Accept-Encoding") - headers.Del("Content-Length") } func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { @@ -109,7 +106,6 @@ func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName return nil } - // Delay the header processing to allow changing streaming mode in OnRequestBody return nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index 1bdea9d67..f44b9e3c0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -177,6 +177,4 @@ func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath) util.OverwriteRequestHostHeader(headers, sparkHost) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx)) - headers.Del("Accept-Encoding") - headers.Del("Content-Length") }