diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md
index 54ffa05b8..2e4af4349 100644
--- a/plugins/wasm-go/extensions/ai-proxy/README.md
+++ b/plugins/wasm-go/extensions/ai-proxy/README.md
@@ -34,6 +34,7 @@ description: AI 代理插件配置参考
| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以 "gpt-3-" 开头的模型;
2. 支持使用 "*" 为键来配置通用兜底映射关系;
3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
+| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
`context`的配置字段说明如下:
@@ -43,6 +44,33 @@ description: AI 代理插件配置参考
| `serviceName` | string | 必填 | - | URL 所对应的 Higress 后端服务完整名称 |
| `servicePort` | number | 必填 | - | URL 所对应的 Higress 后端服务访问端口 |
+
+`customSettings`的配置字段说明如下:
+
+| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
+| ----------- | --------------------- | -------- | ------ | ---------------------------------------------------------------------------------------------------------------------------- |
+| `name` | string | 必填 | - | 想要设置的参数的名称,例如`max_tokens` |
+| `value` | string/int/float/bool | 必填 | - | 想要设置的参数的值,例如0 |
+| `mode` | string | 非必填 | "auto" | 参数设置的模式,可以设置为"auto"或者"raw",如果为"auto"则会自动根据协议对参数名做改写,如果为"raw"则不会有任何改写和限制检查 |
+| `overwrite` | bool | 非必填 | true | 如果为false则只在用户没有设置这个参数时填充参数,否则会直接覆盖用户原有的参数设置 |
+
+
+custom-setting会遵循如下表格,根据`name`和协议来替换对应的字段,用户需要填写表格中`settingName`列中存在的值。例如用户将`name`设置为`max_tokens`,在openai协议中会替换`max_tokens`,在gemini中会替换`maxOutputTokens`。
+`none`表示该协议不支持此参数。如果`name`不在此表格中或者对应协议不支持此参数,同时没有设置raw模式,则配置不会生效。
+
+
+| settingName | openai | baidu | spark | qwen | gemini | hunyuan | claude | minimax |
+| ----------- | ----------- | ----------------- | ----------- | ----------- | --------------- | ----------- | ----------- | ------------------ |
+| max_tokens | max_tokens | max_output_tokens | max_tokens | max_tokens | maxOutputTokens | none | max_tokens | tokens_to_generate |
+| temperature | temperature | temperature | temperature | temperature | temperature | Temperature | temperature | temperature |
+| top_p | top_p | top_p | none | top_p | topP | TopP | top_p | top_p |
+| top_k | none | none | top_k | none | topK | none | top_k | none |
+| seed | seed | none | none | seed | none | none | none | none |
+
+如果启用了raw模式,custom-setting会直接用输入的`name`和`value`去更改请求中的json内容,不做参数名映射和协议支持性检查(但下一段描述的子路径规则仍然生效)。
+对于大多数协议,custom-setting都会在json内容的根路径修改或者填充参数。对于`qwen`协议,ai-proxy会在json的`parameters`子路径下做配置。对于`gemini`协议,则会在`generation_config`子路径下做配置。
+
+
### 提供商特有配置
#### OpenAI
diff --git a/plugins/wasm-go/extensions/ai-proxy/config/config.go b/plugins/wasm-go/extensions/ai-proxy/config/config.go
index 01248ef58..e1bba6402 100644
--- a/plugins/wasm-go/extensions/ai-proxy/config/config.go
+++ b/plugins/wasm-go/extensions/ai-proxy/config/config.go
@@ -50,3 +50,7 @@ func (c *PluginConfig) Complete() error {
func (c *PluginConfig) GetProvider() provider.Provider {
return c.provider
}
+
+func (c *PluginConfig) GetProviderConfig() provider.ProviderConfig {
+ return c.providerConfig
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go
index f09e0d4af..807b48bce 100644
--- a/plugins/wasm-go/extensions/ai-proxy/main.go
+++ b/plugins/wasm-go/extensions/ai-proxy/main.go
@@ -75,15 +75,15 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()
- action, err := handler.OnRequestHeaders(ctx, apiName, log)
+ _, err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
- if contentType, err := proxywasm.GetHttpRequestHeader("Content-Type"); err == nil && contentType != "" {
+ if wrapper.HasRequestBody() {
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
// Always return types.HeaderStopIteration to support fallback routing,
// as long as onHttpRequestBody can be called.
return types.HeaderStopIteration
}
- return action
+ return types.ActionContinue
}
_ = util.SendResponse(500, "ai-proxy.proc_req_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request headers: %v", err))
return types.ActionContinue
@@ -104,12 +104,20 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
+
+ newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body)
+ if settingErr != nil {
+ _ = util.SendResponse(500, "ai-proxy.proc_req_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to rewrite request body by custom settings: %v", settingErr))
+ return types.ActionContinue
+ }
+
+ log.Debugf("[onHttpRequestBody] newBody=%s", newBody)
+ body = newBody
action, err := handler.OnRequestBody(ctx, apiName, body, log)
if err == nil {
return action
}
_ = util.SendResponse(500, "ai-proxy.proc_req_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request body: %v", err))
- return types.ActionContinue
}
return types.ActionContinue
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/custom_setting.go b/plugins/wasm-go/extensions/ai-proxy/provider/custom_setting.go
new file mode 100644
index 000000000..a59100311
--- /dev/null
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/custom_setting.go
@@ -0,0 +1,137 @@
+package provider
+
+import (
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+const (
+ nameMaxTokens = "max_tokens"
+ nameTemperature = "temperature"
+ nameTopP = "top_p"
+ nameTopK = "top_k"
+ nameSeed = "seed"
+)
+
+var maxTokensMapping = map[string]string{
+ "openai": "max_tokens",
+ "baidu": "max_output_tokens",
+ "spark": "max_tokens",
+ "qwen": "max_tokens",
+ "gemini": "maxOutputTokens",
+ "claude": "max_tokens",
+ "minimax": "tokens_to_generate",
+}
+
+var temperatureMapping = map[string]string{
+ "openai": "temperature",
+ "baidu": "temperature",
+ "spark": "temperature",
+ "qwen": "temperature",
+ "gemini": "temperature",
+ "hunyuan": "Temperature",
+ "claude": "temperature",
+ "minimax": "temperature",
+}
+
+var topPMapping = map[string]string{
+ "openai": "top_p",
+ "baidu": "top_p",
+ "qwen": "top_p",
+ "gemini": "topP",
+ "hunyuan": "TopP",
+ "claude": "top_p",
+ "minimax": "top_p",
+}
+
+var topKMapping = map[string]string{
+ "spark": "top_k",
+ "gemini": "topK",
+ "claude": "top_k",
+}
+
+var seedMapping = map[string]string{
+ "openai": "seed",
+ "qwen": "seed",
+}
+
+var settingMapping = map[string]map[string]string{
+ nameMaxTokens: maxTokensMapping,
+ nameTemperature: temperatureMapping,
+ nameTopP: topPMapping,
+ nameTopK: topKMapping,
+ nameSeed: seedMapping,
+}
+
+type CustomSetting struct {
+ // @Title zh-CN 参数名称
+ // @Description zh-CN 想要设置的参数的名称,例如max_tokens
+ name string
+ // @Title zh-CN 参数值
+ // @Description zh-CN 想要设置的参数的值,例如0
+ value string
+ // @Title zh-CN 设置模式
+ // @Description zh-CN 参数设置的模式,可以设置为"auto"或者"raw",如果为"auto"则会根据 /plugins/wasm-go/extensions/ai-proxy/README.md中关于custom-setting部分的表格自动按照协议对参数名做改写,如果为"raw"则不会有任何改写和限制检查
+ mode string
+ // @Title zh-CN json edit 模式
+ // @Description zh-CN 如果为false则只在用户没有设置这个参数时填充参数,否则会直接覆盖用户原有的参数设置
+ overwrite bool
+}
+
+func (c *CustomSetting) FromJson(json gjson.Result) {
+ c.name = json.Get("name").String()
+ c.value = json.Get("value").Raw
+ if obj := json.Get("mode"); obj.Exists() {
+ c.mode = obj.String()
+ } else {
+ c.mode = "auto"
+ }
+ if obj := json.Get("overwrite"); obj.Exists() {
+ c.overwrite = obj.Bool()
+ } else {
+ c.overwrite = true
+ }
+}
+
+func (c *CustomSetting) Validate() bool {
+ return c.name != ""
+}
+
+func (c *CustomSetting) setInvalid() {
+ c.name = "" // set empty to represent invalid
+}
+
+func (c *CustomSetting) AdjustWithProtocol(protocol string) {
+	if c.mode != "raw" {
+ mapping, ok := settingMapping[c.name]
+ if ok {
+ c.name, ok = mapping[protocol]
+ }
+ if !ok {
+ c.setInvalid()
+ return
+ }
+ }
+
+ if protocol == providerTypeQwen {
+ c.name = "parameters." + c.name
+ }
+ if protocol == providerTypeGemini {
+ c.name = "generation_config." + c.name
+ }
+}
+
+func ReplaceByCustomSettings(body []byte, settings []CustomSetting) ([]byte, error) {
+ var err error
+ strBody := string(body)
+ for _, setting := range settings {
+ if !setting.overwrite && gjson.Get(strBody, setting.name).Exists() {
+ continue
+ }
+ strBody, err = sjson.SetRaw(strBody, setting.name, setting.value)
+ if err != nil {
+ break
+ }
+ }
+ return []byte(strBody), err
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
index 3d1d0217f..76b488e7a 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
@@ -184,6 +184,9 @@ type ProviderConfig struct {
// @Title zh-CN 指定服务返回的响应需满足的JSON Schema
// @Description zh-CN 目前仅适用于OpenAI部分模型服务。参考:https://platform.openai.com/docs/guides/structured-outputs
responseJsonSchema map[string]interface{} `required:"false" yaml:"responseJsonSchema" json:"responseJsonSchema"`
+ // @Title zh-CN 自定义大模型参数配置
+ // @Description zh-CN 用于填充或者覆盖大模型调用时的参数
+ customSettings []CustomSetting
}
func (c *ProviderConfig) FromJson(json gjson.Result) {
@@ -239,6 +242,25 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.responseJsonSchema = nil
}
+
+ c.customSettings = make([]CustomSetting, 0)
+ customSettingsJson := json.Get("customSettings")
+ if customSettingsJson.Exists() {
+ protocol := protocolOpenAI
+ if c.protocol == protocolOriginal {
+ // use provider name to represent original protocol name
+ protocol = c.typ
+ }
+ for _, settingJson := range customSettingsJson.Array() {
+ setting := CustomSetting{}
+ setting.FromJson(settingJson)
+ // use protocol info to rewrite setting
+ setting.AdjustWithProtocol(protocol)
+ if setting.Validate() {
+ c.customSettings = append(c.customSettings, setting)
+ }
+ }
+ }
}
func (c *ProviderConfig) Validate() error {
@@ -324,3 +346,7 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
return ""
}
+
+func (c ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) {
+ return ReplaceByCustomSettings(body, c.customSettings)
+}
diff --git a/test/e2e/conformance/tests/go-wasm-ai-proxy.go b/test/e2e/conformance/tests/go-wasm-ai-proxy.go
new file mode 100644
index 000000000..fc341b0be
--- /dev/null
+++ b/test/e2e/conformance/tests/go-wasm-ai-proxy.go
@@ -0,0 +1,115 @@
+// Copyright (c) 2022 Alibaba Group Holding Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+
+ "github.com/alibaba/higress/test/e2e/conformance/utils/http"
+ "github.com/alibaba/higress/test/e2e/conformance/utils/suite"
+)
+
+func init() {
+ Register(WasmPluginsAiProxy)
+}
+
+var WasmPluginsAiProxy = suite.ConformanceTest{
+ ShortName: "WasmPluginAiProxy",
+ Description: "The Ingress in the higress-conformance-infra namespace test the ai-proxy WASM plugin.",
+ Features: []suite.SupportedFeature{suite.WASMGoConformanceFeature},
+ Manifests: []string{"tests/go-wasm-ai-proxy.yaml"},
+ Test: func(t *testing.T, suite *suite.ConformanceTestSuite) {
+ testcases := []http.Assertion{
+ {
+ Meta: http.AssertionMeta{
+ TestCaseName: "case 1: openai",
+ TargetBackend: "infra-backend-v1",
+ TargetNamespace: "higress-conformance-infra",
+ },
+ Request: http.AssertionRequest{
+ ActualRequest: http.Request{
+ Host: "openai.ai.com",
+ Path: "/v1/chat/completions",
+						Method:      "POST",
+ ContentType: http.ContentTypeApplicationJson,
+ Body: []byte(`{
+ "model": "gpt-3",
+ "messages": [{"role":"user","content":"hi"}]}`),
+ },
+ ExpectedRequest: &http.ExpectedRequest{
+ Request: http.Request{
+ Host: "api.openai.com",
+ Path: "/v1/chat/completions",
+ Method: "POST",
+ ContentType: http.ContentTypeApplicationJson,
+ Body: []byte(`{
+ "model": "gpt-3",
+ "messages": [{"role":"user","content":"hi"}],
+ "max_tokens": 123,
+ "temperature": 0.66}`),
+ },
+ },
+ },
+ Response: http.AssertionResponse{
+ ExpectedResponse: http.Response{
+ StatusCode: 200,
+ },
+ },
+ },
+ {
+ Meta: http.AssertionMeta{
+ TestCaseName: "case 2: qwen",
+ TargetBackend: "infra-backend-v1",
+ TargetNamespace: "higress-conformance-infra",
+ },
+ Request: http.AssertionRequest{
+ ActualRequest: http.Request{
+ Host: "qwen.ai.com",
+ Path: "/v1/chat/completions",
+						Method:      "POST",
+ ContentType: http.ContentTypeApplicationJson,
+ Body: []byte(`{
+ "model": "qwen-long",
+ "input": {"messages": [{"role":"user","content":"hi"}]},
+ "parameters": {"max_tokens": 321, "temperature": 0.7}}`),
+ },
+ ExpectedRequest: &http.ExpectedRequest{
+ Request: http.Request{
+ Host: "dashscope.aliyuncs.com",
+ Path: "/api/v1/services/aigc/text-generation/generation",
+ Method: "POST",
+ ContentType: http.ContentTypeApplicationJson,
+ Body: []byte(`{
+ "model": "qwen-long",
+ "input": {"messages": [{"role":"user","content":"hi"}]},
+ "parameters": {"max_tokens": 321, "temperature": 0.66}}`),
+ },
+ },
+ },
+ Response: http.AssertionResponse{
+ ExpectedResponse: http.Response{
+ StatusCode: 500,
+ },
+ },
+ },
+
+ }
+ t.Run("WasmPlugins ai-proxy", func(t *testing.T) {
+ for _, testcase := range testcases {
+ http.MakeRequestAndExpectEventuallyConsistentResponse(t, suite.RoundTripper, suite.TimeoutConfig, suite.GatewayAddress, testcase)
+ }
+ })
+ },
+}
diff --git a/test/e2e/conformance/tests/go-wasm-ai-proxy.yaml b/test/e2e/conformance/tests/go-wasm-ai-proxy.yaml
new file mode 100644
index 000000000..ab2de2df0
--- /dev/null
+++ b/test/e2e/conformance/tests/go-wasm-ai-proxy.yaml
@@ -0,0 +1,87 @@
+# Copyright (c) 2022 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+apiVersion: networking.k8s.io/v1
+kind: Ingress
+metadata:
+ annotations:
+ name: wasmplugin-ai-proxy-openai
+ namespace: higress-conformance-infra
+spec:
+ ingressClassName: higress
+ rules:
+ - host: "openai.ai.com"
+ http:
+ paths:
+ - pathType: Prefix
+ path: "/"
+ backend:
+ service:
+ name: infra-backend-v1
+ port:
+ number: 8080
+---
+apiVersion: networking.k8s.io/v1
+kind: Ingress
+metadata:
+ annotations:
+ name: wasmplugin-ai-proxy-qwen
+ namespace: higress-conformance-infra
+spec:
+ ingressClassName: higress
+ rules:
+ - host: "qwen.ai.com"
+ http:
+ paths:
+ - pathType: Prefix
+ path: "/"
+ backend:
+ service:
+ name: infra-backend-v1
+ port:
+ number: 8080
+---
+apiVersion: extensions.higress.io/v1alpha1
+kind: WasmPlugin
+metadata:
+ name: ai-proxy
+ namespace: higress-system
+spec:
+ priority: 200
+ matchRules:
+ - config:
+ provider:
+ type: "openai"
+ customSettings:
+ - name: "max_tokens"
+ value: 123
+ overwrite: false
+ - name: "temperature"
+ value: 0.66
+ overwrite: true
+ ingress:
+ - higress-conformance-infra/wasmplugin-ai-proxy-openai
+ - config:
+ provider:
+ type: "qwen"
+ apiTokens: "fake-token"
+ customSettings:
+ - name: "max_tokens"
+ value: 123
+ overwrite: false
+ - name: "temperature"
+ value: 0.66
+ overwrite: true
+ ingress:
+ - higress-conformance-infra/wasmplugin-ai-proxy-qwen
+ url: file:///opt/plugins/wasm-go/extensions/ai-proxy/plugin.wasm