From 2a200cdd42b79c26175055e77cdd3d362ae24255 Mon Sep 17 00:00:00 2001 From: StarryNight Date: Mon, 16 Dec 2024 18:41:38 +0800 Subject: [PATCH] AI proxy return unified status in header phase (#1588) --- plugins/wasm-go/extensions/ai-proxy/main.go | 4 ++-- plugins/wasm-go/extensions/ai-proxy/provider/ai360.go | 6 +++--- plugins/wasm-go/extensions/ai-proxy/provider/azure.go | 6 +++--- .../wasm-go/extensions/ai-proxy/provider/baichuan.go | 6 +++--- plugins/wasm-go/extensions/ai-proxy/provider/baidu.go | 6 +++--- .../wasm-go/extensions/ai-proxy/provider/claude.go | 6 +++--- .../extensions/ai-proxy/provider/cloudflare.go | 6 +++--- .../wasm-go/extensions/ai-proxy/provider/cohere.go | 11 ++++++----- plugins/wasm-go/extensions/ai-proxy/provider/coze.go | 5 ++--- plugins/wasm-go/extensions/ai-proxy/provider/deepl.go | 6 +++--- .../wasm-go/extensions/ai-proxy/provider/deepseek.go | 9 +++++---- .../wasm-go/extensions/ai-proxy/provider/doubao.go | 11 ++++++----- .../wasm-go/extensions/ai-proxy/provider/gemini.go | 6 +++--- .../wasm-go/extensions/ai-proxy/provider/github.go | 11 ++++++----- plugins/wasm-go/extensions/ai-proxy/provider/groq.go | 6 +++--- .../wasm-go/extensions/ai-proxy/provider/hunyuan.go | 6 +++--- .../wasm-go/extensions/ai-proxy/provider/minimax.go | 6 +++--- .../wasm-go/extensions/ai-proxy/provider/mistral.go | 9 +++++---- .../wasm-go/extensions/ai-proxy/provider/moonshot.go | 6 +++--- .../wasm-go/extensions/ai-proxy/provider/ollama.go | 9 +++++---- .../wasm-go/extensions/ai-proxy/provider/openai.go | 4 ++-- .../wasm-go/extensions/ai-proxy/provider/provider.go | 2 +- plugins/wasm-go/extensions/ai-proxy/provider/qwen.go | 8 ++++---- plugins/wasm-go/extensions/ai-proxy/provider/spark.go | 6 +++--- .../wasm-go/extensions/ai-proxy/provider/stepfun.go | 9 +++++---- plugins/wasm-go/extensions/ai-proxy/provider/yi.go | 6 +++--- .../wasm-go/extensions/ai-proxy/provider/zhipuai.go | 6 +++--- 27 files changed, 94 insertions(+), 88 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 0bc62175e..aa9cb032c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -103,7 +103,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf providerConfig.SetApiTokenInUse(ctx, log) hasRequestBody := wrapper.HasRequestBody() - action, err := handler.OnRequestHeaders(ctx, apiName, log) + err := handler.OnRequestHeaders(ctx, apiName, log) if err == nil { if hasRequestBody { ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) @@ -111,7 +111,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf // as long as onHttpRequestBody can be called. return types.HeaderStopIteration } - return action + return types.ActionContinue } util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err)) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index 6f42d570d..b762a0a58 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -40,13 +40,13 @@ func (m *ai360Provider) GetProviderType() string { return providerTypeAi360 } -func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody - return types.HeaderStopIteration, nil + return nil } func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index b09cdd095..e08013437 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -53,12 +53,12 @@ func (m *azureProvider) GetProviderType() string { return providerTypeAzure } -func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index b43ba8ee2..759c2dd03 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -42,12 +42,12 @@ func (m *baichuanProvider) GetProviderType() string { return providerTypeBaichuan } -func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index 090883629..595ef3d4f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -63,12 +63,12 @@ func (g *baiduProvider) GetProviderType() string { return providerTypeBaidu } -func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 8b98d62d6..5f99d0293 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -102,12 +102,12 @@ func (c *claudeProvider) GetProviderType() string { return providerTypeClaude } -func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } c.config.handleRequestHeaders(c, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index 2f6108b0d..e9663b0da 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -42,12 +42,12 @@ func (c *cloudflareProvider) GetProviderType() string { return providerTypeCloudflare } -func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } c.config.handleRequestHeaders(c, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index 72dbaf280..a3b930e7f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -3,11 +3,12 @@ package provider import ( "encoding/json" "errors" + "net/http" + "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" - "strings" ) const ( @@ -54,12 +55,12 @@ func (m *cohereProvider) GetProviderType() string { return providerTypeCohere } -func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/coze.go b/plugins/wasm-go/extensions/ai-proxy/provider/coze.go index 878bbb9f9..43cdca60f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/coze.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/coze.go @@ -6,7 +6,6 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) const ( @@ -38,9 +37,9 @@ func (m *cozeProvider) GetProviderType() string { return providerTypeCoze } -func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *cozeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index bafe6b3dd..82998ee1e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -76,12 +76,12 @@ func (d *deeplProvider) GetProviderType() string { return providerTypeDeepl } -func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } d.config.handleRequestHeaders(d, ctx, apiName, log) - return types.HeaderStopIteration, nil + return nil } func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index 9cad3928f..7d240f09a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -2,10 +2,11 @@ package provider import ( "errors" + "net/http" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" ) // deepseekProvider is the provider for deepseek Ai service. @@ -41,12 +42,12 @@ func (m *deepseekProvider) GetProviderType() string { return providerTypeDeepSeek } -func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go index 651b98320..96a4aab54 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go @@ -2,11 +2,12 @@ package provider import ( "errors" + "net/http" + "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" - "strings" ) const ( @@ -39,12 +40,12 @@ func (m *doubaoProvider) GetProviderType() string { return providerTypeDoubao } -func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index a4c1ef2cd..abb6268ea 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -51,13 +51,13 @@ func (g *geminiProvider) GetProviderType() string { return providerTypeGemini } -func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody - return types.HeaderStopIteration, nil + return nil } func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/github.go b/plugins/wasm-go/extensions/ai-proxy/provider/github.go index 0a2b0c84d..1d5c53dc4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/github.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/github.go @@ -2,11 +2,12 @@ package provider import ( "errors" + "net/http" + "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" - "strings" ) // githubProvider is the provider for GitHub OpenAI service. @@ -42,13 +43,13 @@ func (m *githubProvider) GetProviderType() string { return providerTypeGithub } -func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody - return types.HeaderStopIteration, nil + return nil } func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go index dfbd97126..5f2734519 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -41,12 +41,12 @@ func (g *groqProvider) GetProviderType() string { return providerTypeGroq } -func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index b6a49eb55..bcd598830 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -114,13 +114,13 @@ func (m *hunyuanProvider) GetProviderType() string { return providerTypeHunyuan } -func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody - return types.HeaderStopIteration, nil + return nil } func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index 56e36441a..9531edcf1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -65,13 +65,13 @@ func (m *minimaxProvider) GetProviderType() string { return providerTypeMinimax } -func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody - return types.HeaderStopIteration, nil + return nil } func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index 3e5323a60..041665f9d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -2,10 +2,11 @@ package provider import ( "errors" + "net/http" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" ) const ( @@ -37,12 +38,12 @@ func (m *mistralProvider) GetProviderType() string { return providerTypeMistral } -func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index 38d99ae0e..733cc038b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -56,12 +56,12 @@ func (m *moonshotProvider) GetProviderType() string { return providerTypeMoonshot } -func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go index 533908381..1bed639f3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go @@ -3,10 +3,11 @@ package provider import ( "errors" "fmt" + "net/http" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" ) // ollamaProvider is the provider for Ollama service. @@ -48,12 +49,12 @@ func (m *ollamaProvider) GetProviderType() string { return providerTypeOllama } -func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 60c835cd4..480fdda57 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -57,9 +57,9 @@ func (m *openaiProvider) GetProviderType() string { return providerTypeOpenAI } -func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index ea1503b1e..0f482732a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -118,7 +118,7 @@ type ApiNameHandler interface { } type RequestHeadersHandler interface { - OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) + OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error } type TransformRequestHeadersHandler interface { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index a4a727724..e2498b9a8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -95,20 +95,20 @@ func (m *qwenProvider) GetProviderType() string { return providerTypeQwen } -func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) if m.config.protocol == protocolOriginal { ctx.DontReadRequestBody() - return types.ActionContinue, nil + return nil } // Delay the header processing to allow changing streaming mode in OnRequestBody - return types.HeaderStopIteration, nil + return nil } func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index c2e013643..1bdea9d67 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -67,12 +67,12 @@ func (p *sparkProvider) GetProviderType() string { return providerTypeSpark } -func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } p.config.handleRequestHeaders(p, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index 1ee01abe6..4f642c5f6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -2,10 +2,11 @@ package provider import ( "errors" + "net/http" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" ) const ( @@ -39,12 +40,12 @@ func (m *stepfunProvider) GetProviderType() string { return providerTypeStepfun } -func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index 7cb05a938..e80148ca0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -40,12 +40,12 @@ func (m *yiProvider) GetProviderType() string { return providerTypeYi } -func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go index 40fbe4ef8..9c30adb10 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ -40,12 +40,12 @@ func (m *zhipuAiProvider) GetProviderType() string { return providerTypeZhipuAi } -func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {