diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 54ffa05b8..2e4af4349 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -34,6 +34,7 @@ description: AI 代理插件配置参考 | `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以 "gpt-3-" 开头的模型;
2. 支持使用 "*" 为键来配置通用兜底映射关系;
3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 | | `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) | | `context` | object | 非必填 | - | 配置 AI 对话上下文信息 | +| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 | `context`的配置字段说明如下: @@ -43,6 +44,33 @@ description: AI 代理插件配置参考 | `serviceName` | string | 必填 | - | URL 所对应的 Higress 后端服务完整名称 | | `servicePort` | number | 必填 | - | URL 所对应的 Higress 后端服务访问端口 | + +`customSettings`的配置字段说明如下: + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +| ----------- | --------------------- | -------- | ------ | ---------------------------------------------------------------------------------------------------------------------------- | +| `name` | string | 必填 | - | 想要设置的参数的名称,例如`max_tokens` | +| `value` | string/int/float/bool | 必填 | - | 想要设置的参数的值,例如0 | +| `mode` | string | 非必填 | "auto" | 参数设置的模式,可以设置为"auto"或者"raw",如果为"auto"则会自动根据协议对参数名做改写,如果为"raw"则不会有任何改写和限制检查 | +| `overwrite` | bool | 非必填 | true | 如果为false则只在用户没有设置这个参数时填充参数,否则会直接覆盖用户原有的参数设置 | + + +custom-setting会遵循如下表格,根据`name`和协议来替换对应的字段,用户需要填写表格中`settingName`列中存在的值。例如用户将`name`设置为`max_tokens`,在openai协议中会替换`max_tokens`,在gemini中会替换`maxOutputTokens`。 +`none`表示该协议不支持此参数。如果`name`不在此表格中或者对应协议不支持此参数,同时没有设置raw模式,则配置不会生效。 + + +| settingName | openai | baidu | spark | qwen | gemini | hunyuan | claude | minimax | +| ----------- | ----------- | ----------------- | ----------- | ----------- | --------------- | ----------- | ----------- | ------------------ | +| max_tokens | max_tokens | max_output_tokens | max_tokens | max_tokens | maxOutputTokens | none | max_tokens | tokens_to_generate | +| temperature | temperature | temperature | temperature | temperature | temperature | Temperature | temperature | temperature | +| top_p | top_p | top_p | none | top_p | topP | TopP | top_p | top_p | +| top_k | none | none | top_k | none | topK | none | top_k | none | +| seed | seed | none | none | seed | none | none | none | none | + 
+如果启用了raw模式,custom-setting会直接用输入的`name`和`value`去更改请求中的json内容,而不对参数名称做任何限制和修改。 +对于大多数协议,custom-setting都会在json内容的根路径修改或者填充参数。对于`qwen`协议,ai-proxy会在json的`parameters`子路径下做配置。对于`gemini`协议,则会在`generation_config`子路径下做配置。 + + ### 提供商特有配置 #### OpenAI diff --git a/plugins/wasm-go/extensions/ai-proxy/config/config.go b/plugins/wasm-go/extensions/ai-proxy/config/config.go index 01248ef58..e1bba6402 100644 --- a/plugins/wasm-go/extensions/ai-proxy/config/config.go +++ b/plugins/wasm-go/extensions/ai-proxy/config/config.go @@ -50,3 +50,7 @@ func (c *PluginConfig) Complete() error { func (c *PluginConfig) GetProvider() provider.Provider { return c.provider } + +func (c *PluginConfig) GetProviderConfig() provider.ProviderConfig { + return c.providerConfig +} diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index f09e0d4af..807b48bce 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -75,15 +75,15 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf // Disable the route re-calculation since the plugin may modify some headers related to the chosen route. ctx.DisableReroute() - action, err := handler.OnRequestHeaders(ctx, apiName, log) + _, err := handler.OnRequestHeaders(ctx, apiName, log) if err == nil { - if contentType, err := proxywasm.GetHttpRequestHeader("Content-Type"); err == nil && contentType != "" { + if wrapper.HasRequestBody() { ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) // Always return types.HeaderStopIteration to support fallback routing, // as long as onHttpRequestBody can be called. 
return types.HeaderStopIteration } - return action + return types.ActionContinue } _ = util.SendResponse(500, "ai-proxy.proc_req_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request headers: %v", err)) return types.ActionContinue @@ -104,12 +104,20 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig if handler, ok := activeProvider.(provider.RequestBodyHandler); ok { apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) + + newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body) + if settingErr != nil { + _ = util.SendResponse(500, "ai-proxy.proc_req_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to rewrite request body by custom settings: %v", settingErr)) + return types.ActionContinue + } + + log.Debugf("[onHttpRequestBody] newBody=%s", newBody) + body = newBody action, err := handler.OnRequestBody(ctx, apiName, body, log) if err == nil { return action } _ = util.SendResponse(500, "ai-proxy.proc_req_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request body: %v", err)) - return types.ActionContinue } return types.ActionContinue } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/custom_setting.go b/plugins/wasm-go/extensions/ai-proxy/provider/custom_setting.go new file mode 100644 index 000000000..a59100311 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/custom_setting.go @@ -0,0 +1,137 @@ +package provider + +import ( + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + nameMaxTokens = "max_tokens" + nameTemperature = "temperature" + nameTopP = "top_p" + nameTopK = "top_k" + nameSeed = "seed" +) + +var maxTokensMapping = map[string]string{ + "openai": "max_tokens", + "baidu": "max_output_tokens", + "spark": "max_tokens", + "qwen": "max_tokens", + "gemini": "maxOutputTokens", + "claude": "max_tokens", + "minimax": "tokens_to_generate", +} + +var temperatureMapping = map[string]string{ + 
"openai": "temperature", + "baidu": "temperature", + "spark": "temperature", + "qwen": "temperature", + "gemini": "temperature", + "hunyuan": "Temperature", + "claude": "temperature", + "minimax": "temperature", +} + +var topPMapping = map[string]string{ + "openai": "top_p", + "baidu": "top_p", + "qwen": "top_p", + "gemini": "topP", + "hunyuan": "TopP", + "claude": "top_p", + "minimax": "top_p", +} + +var topKMapping = map[string]string{ + "spark": "top_k", + "gemini": "topK", + "claude": "top_k", +} + +var seedMapping = map[string]string{ + "openai": "seed", + "qwen": "seed", +} + +var settingMapping = map[string]map[string]string{ + nameMaxTokens: maxTokensMapping, + nameTemperature: temperatureMapping, + nameTopP: topPMapping, + nameTopK: topKMapping, + nameSeed: seedMapping, +} + +type CustomSetting struct { + // @Title zh-CN 参数名称 + // @Description zh-CN 想要设置的参数的名称,例如max_tokens + name string + // @Title zh-CN 参数值 + // @Description zh-CN 想要设置的参数的值,例如0 + value string + // @Title zh-CN 设置模式 + // @Description zh-CN 参数设置的模式,可以设置为"auto"或者"raw",如果为"auto"则会根据 /plugins/wasm-go/extensions/ai-proxy/README.md中关于custom-setting部分的表格自动按照协议对参数名做改写,如果为"raw"则不会有任何改写和限制检查 + mode string + // @Title zh-CN json edit 模式 + // @Description zh-CN 如果为false则只在用户没有设置这个参数时填充参数,否则会直接覆盖用户原有的参数设置 + overwrite bool +} + +func (c *CustomSetting) FromJson(json gjson.Result) { + c.name = json.Get("name").String() + c.value = json.Get("value").Raw + if obj := json.Get("mode"); obj.Exists() { + c.mode = obj.String() + } else { + c.mode = "auto" + } + if obj := json.Get("overwrite"); obj.Exists() { + c.overwrite = obj.Bool() + } else { + c.overwrite = true + } +} + +func (c *CustomSetting) Validate() bool { + return c.name != "" +} + +func (c *CustomSetting) setInvalid() { + c.name = "" // set empty to represent invalid +} + +func (c *CustomSetting) AdjustWithProtocol(protocol string) { + if !(c.mode == "raw") { + mapping, ok := settingMapping[c.name] + if ok { + c.name, ok = mapping[protocol] + } + 
if !ok { + c.setInvalid() + return + } + } + + if protocol == providerTypeQwen { + c.name = "parameters." + c.name + } + if protocol == providerTypeGemini { + c.name = "generation_config." + c.name + } +} + +func ReplaceByCustomSettings(body []byte, settings []CustomSetting) ([]byte, error) { + var err error + strBody := string(body) + for _, setting := range settings { + if !setting.overwrite && gjson.Get(strBody, setting.name).Exists() { + continue + } + strBody, err = sjson.SetRaw(strBody, setting.name, setting.value) + if err != nil { + break + } + } + return []byte(strBody), err +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 3d1d0217f..76b488e7a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -184,6 +184,9 @@ type ProviderConfig struct { // @Title zh-CN 指定服务返回的响应需满足的JSON Schema // @Description zh-CN 目前仅适用于OpenAI部分模型服务。参考:https://platform.openai.com/docs/guides/structured-outputs responseJsonSchema map[string]interface{} `required:"false" yaml:"responseJsonSchema" json:"responseJsonSchema"` + // @Title zh-CN 自定义大模型参数配置 + // @Description zh-CN 用于填充或者覆盖大模型调用时的参数 + customSettings []CustomSetting } func (c *ProviderConfig) FromJson(json gjson.Result) { @@ -239,6 +242,25 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.responseJsonSchema = nil } + + c.customSettings = make([]CustomSetting, 0) + customSettingsJson := json.Get("customSettings") + if customSettingsJson.Exists() { + protocol := protocolOpenAI + if c.protocol == protocolOriginal { + // use provider name to represent original protocol name + protocol = c.typ + } + for _, settingJson := range customSettingsJson.Array() { + setting := CustomSetting{} + setting.FromJson(settingJson) + // use protocol info to rewrite setting + setting.AdjustWithProtocol(protocol) + if setting.Validate() { + c.customSettings = 
append(c.customSettings, setting) + } + } + } } func (c *ProviderConfig) Validate() error { @@ -324,3 +346,7 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper. return "" } + +func (c ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) { + return ReplaceByCustomSettings(body, c.customSettings) +} diff --git a/test/e2e/conformance/tests/go-wasm-ai-proxy.go b/test/e2e/conformance/tests/go-wasm-ai-proxy.go new file mode 100644 index 000000000..fc341b0be --- /dev/null +++ b/test/e2e/conformance/tests/go-wasm-ai-proxy.go @@ -0,0 +1,115 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package tests + +import ( + "testing" + + "github.com/alibaba/higress/test/e2e/conformance/utils/http" + "github.com/alibaba/higress/test/e2e/conformance/utils/suite" +) + +func init() { + Register(WasmPluginsAiProxy) +} + +var WasmPluginsAiProxy = suite.ConformanceTest{ + ShortName: "WasmPluginAiProxy", + Description: "The Ingress in the higress-conformance-infra namespace test the ai-proxy WASM plugin.", + Features: []suite.SupportedFeature{suite.WASMGoConformanceFeature}, + Manifests: []string{"tests/go-wasm-ai-proxy.yaml"}, + Test: func(t *testing.T, suite *suite.ConformanceTestSuite) { + testcases := []http.Assertion{ + { + Meta: http.AssertionMeta{ + TestCaseName: "case 1: openai", + TargetBackend: "infra-backend-v1", + TargetNamespace: "higress-conformance-infra", + }, + Request: http.AssertionRequest{ + ActualRequest: http.Request{ + Host: "openai.ai.com", + Path: "/v1/chat/completions", + Method:"POST", + ContentType: http.ContentTypeApplicationJson, + Body: []byte(`{ + "model": "gpt-3", + "messages": [{"role":"user","content":"hi"}]}`), + }, + ExpectedRequest: &http.ExpectedRequest{ + Request: http.Request{ + Host: "api.openai.com", + Path: "/v1/chat/completions", + Method: "POST", + ContentType: http.ContentTypeApplicationJson, + Body: []byte(`{ + "model": "gpt-3", + "messages": [{"role":"user","content":"hi"}], + "max_tokens": 123, + "temperature": 0.66}`), + }, + }, + }, + Response: http.AssertionResponse{ + ExpectedResponse: http.Response{ + StatusCode: 200, + }, + }, + }, + { + Meta: http.AssertionMeta{ + TestCaseName: "case 2: qwen", + TargetBackend: "infra-backend-v1", + TargetNamespace: "higress-conformance-infra", + }, + Request: http.AssertionRequest{ + ActualRequest: http.Request{ + Host: "qwen.ai.com", + Path: "/v1/chat/completions", + Method:"POST", + ContentType: http.ContentTypeApplicationJson, + Body: []byte(`{ + "model": "qwen-long", + "input": {"messages": [{"role":"user","content":"hi"}]}, + "parameters": {"max_tokens": 321, 
"temperature": 0.7}}`), + }, + ExpectedRequest: &http.ExpectedRequest{ + Request: http.Request{ + Host: "dashscope.aliyuncs.com", + Path: "/api/v1/services/aigc/text-generation/generation", + Method: "POST", + ContentType: http.ContentTypeApplicationJson, + Body: []byte(`{ + "model": "qwen-long", + "input": {"messages": [{"role":"user","content":"hi"}]}, + "parameters": {"max_tokens": 321, "temperature": 0.66}}`), + }, + }, + }, + Response: http.AssertionResponse{ + ExpectedResponse: http.Response{ + StatusCode: 500, + }, + }, + }, + + } + t.Run("WasmPlugins ai-proxy", func(t *testing.T) { + for _, testcase := range testcases { + http.MakeRequestAndExpectEventuallyConsistentResponse(t, suite.RoundTripper, suite.TimeoutConfig, suite.GatewayAddress, testcase) + } + }) + }, +} diff --git a/test/e2e/conformance/tests/go-wasm-ai-proxy.yaml b/test/e2e/conformance/tests/go-wasm-ai-proxy.yaml new file mode 100644 index 000000000..ab2de2df0 --- /dev/null +++ b/test/e2e/conformance/tests/go-wasm-ai-proxy.yaml @@ -0,0 +1,87 @@ +# Copyright (c) 2022 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + annotations: + name: wasmplugin-ai-proxy-openai + namespace: higress-conformance-infra +spec: + ingressClassName: higress + rules: + - host: "openai.ai.com" + http: + paths: + - pathType: Prefix + path: "/" + backend: + service: + name: infra-backend-v1 + port: + number: 8080 +--- +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + annotations: + name: wasmplugin-ai-proxy-qwen + namespace: higress-conformance-infra +spec: + ingressClassName: higress + rules: + - host: "qwen.ai.com" + http: + paths: + - pathType: Prefix + path: "/" + backend: + service: + name: infra-backend-v1 + port: + number: 8080 +--- +apiVersion: extensions.higress.io/v1alpha1 +kind: WasmPlugin +metadata: + name: ai-proxy + namespace: higress-system +spec: + priority: 200 + matchRules: + - config: + provider: + type: "openai" + customSettings: + - name: "max_tokens" + value: 123 + overwrite: false + - name: "temperature" + value: 0.66 + overwrite: true + ingress: + - higress-conformance-infra/wasmplugin-ai-proxy-openai + - config: + provider: + type: "qwen" + apiTokens: "fake-token" + customSettings: + - name: "max_tokens" + value: 123 + overwrite: false + - name: "temperature" + value: 0.66 + overwrite: true + ingress: + - higress-conformance-infra/wasmplugin-ai-proxy-qwen + url: file:///opt/plugins/wasm-go/extensions/ai-proxy/plugin.wasm