From 8fa1224cbaa883106dfc580702629dd70c010793 Mon Sep 17 00:00:00 2001 From: rinfx <893383980@qq.com> Date: Thu, 15 Aug 2024 18:52:49 +0800 Subject: [PATCH] support qwen compatible mode (#1205) --- plugins/wasm-go/extensions/ai-proxy/go.mod | 1 + plugins/wasm-go/extensions/ai-proxy/go.sum | 3 ++ .../extensions/ai-proxy/provider/provider.go | 4 +++ .../extensions/ai-proxy/provider/qwen.go | 29 +++++++++++++++++-- 4 files changed, 35 insertions(+), 2 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/go.mod b/plugins/wasm-go/extensions/ai-proxy/go.mod index 6f1fb1af0..e2c671d98 100644 --- a/plugins/wasm-go/extensions/ai-proxy/go.mod +++ b/plugins/wasm-go/extensions/ai-proxy/go.mod @@ -22,5 +22,6 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/resp v0.1.1 // indirect + github.com/tidwall/sjson v1.2.5 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/plugins/wasm-go/extensions/ai-proxy/go.sum b/plugins/wasm-go/extensions/ai-proxy/go.sum index e726b100a..e5b8b7917 100644 --- a/plugins/wasm-go/extensions/ai-proxy/go.sum +++ b/plugins/wasm-go/extensions/ai-proxy/go.sum @@ -12,6 +12,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= @@ -20,6 +21,8 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 419b18d6f..b747c8da1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -140,6 +140,9 @@ type ProviderConfig struct { // @Title zh-CN 启用通义千问搜索服务 // @Description zh-CN 仅适用于通义千问服务,表示是否启用通义千问的互联网搜索功能。 qwenEnableSearch bool `required:"false" yaml:"qwenEnableSearch" json:"qwenEnableSearch"` + // @Title zh-CN 开启通义千问兼容模式 + // @Description zh-CN 启用通义千问兼容模式后,将调用千问的兼容模式接口,同时对请求/响应不做修改。 + qwenEnableCompatible bool `required:"false" yaml:"qwenEnableCompatible" json:"qwenEnableCompatible"` // @Title zh-CN Ollama Server IP/Domain // @Description zh-CN 仅适用于 Ollama 服务。Ollama 服务器的主机地址。 ollamaServerHost string `required:"false" yaml:"ollamaServerHost" json:"ollamaServerHost"` @@ -193,6 +196,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.qwenFileIds = append(c.qwenFileIds, fileId.String()) } c.qwenEnableSearch = json.Get("qwenEnableSearch").Bool() + c.qwenEnableCompatible = json.Get("qwenEnableCompatible").Bool() c.ollamaServerHost = json.Get("ollamaServerHost").String() c.ollamaServerPort = uint32(json.Get("ollamaServerPort").Uint()) c.modelMapping = make(map[string]string) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index 65b301526..9424749f3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -13,6 +13,8 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) // qwenProvider is the provider for Qwen service. @@ -23,6 +25,7 @@ const ( qwenDomain = "dashscope.aliyuncs.com" qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation" qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding" + qwenCompatiblePath = "/compatible-mode/v1/chat/completions" qwenTopPMin = 0.000001 qwenTopPMax = 0.999999 @@ -63,7 +66,9 @@ func (m *qwenProvider) GetProviderType() string { } func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - if apiName == ApiNameChatCompletion { + if m.config.qwenEnableCompatible { + _ = util.OverwriteRequestPath(qwenCompatiblePath) + } else if apiName == ApiNameChatCompletion { _ = util.OverwriteRequestPath(qwenChatCompletionPath) } else if apiName == ApiNameEmbeddings { _ = util.OverwriteRequestPath(qwenTextEmbeddingPath) @@ -85,6 +90,23 @@ func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName } func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { + if m.config.qwenEnableCompatible { + if gjson.GetBytes(body, "model").Exists() { + rawModel := gjson.GetBytes(body, "model").String() + mappedModel := getMappedModel(rawModel, m.config.modelMapping, log) + newBody, err := sjson.SetBytes(body, "model", mappedModel) + if err != nil { + log.Errorf("Replace model error: %v", err) + return types.ActionContinue, err + } + err = proxywasm.ReplaceHttpRequestBody(newBody) + if err != nil { + log.Errorf("Replace request body error: %v", err) + return types.ActionContinue, err + } + } + return types.ActionContinue, nil + } if apiName == ApiNameChatCompletion { return m.onChatCompletionRequestBody(ctx, body, log) } @@ -220,7 +242,7 @@ func (m *qwenProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiNam } func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { - if name != ApiNameChatCompletion { + if m.config.qwenEnableCompatible || name != ApiNameChatCompletion { return chunk, nil } @@ -305,6 +327,9 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api } func (m *qwenProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { + if m.config.qwenEnableCompatible { + return types.ActionContinue, nil + } if apiName == ApiNameChatCompletion { return m.onChatCompletionResponseBody(ctx, body, log) }