From 5e787b32580b854c021dc5b07e1ba760779ee117 Mon Sep 17 00:00:00 2001 From: rinfx Date: Tue, 13 Jan 2026 20:14:29 +0800 Subject: [PATCH] Replace model-router and model-mapper with Go implementation (#3317) --- .../wasm-go/extensions/model-mapper/Makefile | 2 + .../wasm-go/extensions/model-mapper/README.md | 61 ++++ .../extensions/model-mapper/README_EN.md | 61 ++++ .../wasm-go/extensions/model-mapper/go.mod | 24 ++ .../wasm-go/extensions/model-mapper/go.sum | 30 ++ .../wasm-go/extensions/model-mapper/main.go | 192 ++++++++++++ .../extensions/model-mapper/main_test.go | 250 +++++++++++++++ .../wasm-go/extensions/model-router/Makefile | 2 + .../wasm-go/extensions/model-router/README.md | 98 ++++++ .../extensions/model-router/README_EN.md | 97 ++++++ .../wasm-go/extensions/model-router/go.mod | 24 ++ .../wasm-go/extensions/model-router/go.sum | 30 ++ .../wasm-go/extensions/model-router/main.go | 259 ++++++++++++++++ .../extensions/model-router/main_test.go | 288 ++++++++++++++++++ 14 files changed, 1418 insertions(+) create mode 100644 plugins/wasm-go/extensions/model-mapper/Makefile create mode 100644 plugins/wasm-go/extensions/model-mapper/README.md create mode 100644 plugins/wasm-go/extensions/model-mapper/README_EN.md create mode 100644 plugins/wasm-go/extensions/model-mapper/go.mod create mode 100644 plugins/wasm-go/extensions/model-mapper/go.sum create mode 100644 plugins/wasm-go/extensions/model-mapper/main.go create mode 100644 plugins/wasm-go/extensions/model-mapper/main_test.go create mode 100644 plugins/wasm-go/extensions/model-router/Makefile create mode 100644 plugins/wasm-go/extensions/model-router/README.md create mode 100644 plugins/wasm-go/extensions/model-router/README_EN.md create mode 100644 plugins/wasm-go/extensions/model-router/go.mod create mode 100644 plugins/wasm-go/extensions/model-router/go.sum create mode 100644 plugins/wasm-go/extensions/model-router/main.go create mode 100644 plugins/wasm-go/extensions/model-router/main_test.go diff 
--git a/plugins/wasm-go/extensions/model-mapper/Makefile b/plugins/wasm-go/extensions/model-mapper/Makefile new file mode 100644 index 000000000..c6e7aac48 --- /dev/null +++ b/plugins/wasm-go/extensions/model-mapper/Makefile @@ -0,0 +1,2 @@ +build-go: + GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o main.wasm main.go \ No newline at end of file diff --git a/plugins/wasm-go/extensions/model-mapper/README.md b/plugins/wasm-go/extensions/model-mapper/README.md new file mode 100644 index 000000000..569e8d768 --- /dev/null +++ b/plugins/wasm-go/extensions/model-mapper/README.md @@ -0,0 +1,61 @@ +# 功能说明 +`model-mapper`插件实现了基于LLM协议中的model参数映射的功能 + +# 配置字段 + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- | +| `modelKey` | string | 选填 | model | 请求body中model参数的位置 | +| `modelMapping` | map of string | 选填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;
2. 支持使用 "*" 为键来配置通用兜底映射关系;
3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 | +| `enableOnPathSuffix` | array of string | 选填 | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | 只对这些特定路径后缀的请求生效 | + + +## 效果说明 + +如下配置 + +```yaml +modelMapping: + 'gpt-4-*': "qwen-max" + 'gpt-4o': "qwen-vl-plus" + '*': "qwen-turbo" +``` + +开启后,`gpt-4-` 开头的模型参数会被改写为 `qwen-max`, `gpt-4o` 会被改写为 `qwen-vl-plus`,其他所有模型会被改写为 `qwen-turbo` + +例如原本的请求是: + +```json +{ + "model": "gpt-4o", + "frequency_penalty": 0, + "max_tokens": 800, + "stream": false, + "messages": [{ + "role": "user", + "content": "higress项目主仓库的github地址是什么" + }], + "presence_penalty": 0, + "temperature": 0.7, + "top_p": 0.95 +} +``` + + +经过这个插件后,原始的 LLM 请求体将被改成: + +```json +{ + "model": "qwen-vl-plus", + "frequency_penalty": 0, + "max_tokens": 800, + "stream": false, + "messages": [{ + "role": "user", + "content": "higress项目主仓库的github地址是什么" + }], + "presence_penalty": 0, + "temperature": 0.7, + "top_p": 0.95 +} +``` diff --git a/plugins/wasm-go/extensions/model-mapper/README_EN.md b/plugins/wasm-go/extensions/model-mapper/README_EN.md new file mode 100644 index 000000000..8ecde8683 --- /dev/null +++ b/plugins/wasm-go/extensions/model-mapper/README_EN.md @@ -0,0 +1,61 @@ +# Function Description +The `model-mapper` plugin implements model parameter mapping functionality based on the LLM protocol. + +# Configuration Fields + +| Name | Type | Requirement | Default Value | Description | +| --- | --- | --- | --- | --- | +| `modelKey` | string | Optional | model | The position of the model parameter in the request body. | +| `modelMapping` | map of string | Optional | - | AI model mapping table, used to map the model name in the request to the model name supported by the service provider.
1. Supports prefix matching. For example, use "gpt-3-*" to match all names starting with "gpt-3-";
2. Supports using "*" as a key to configure a generic fallback mapping;
3. If the target mapping name is an empty string "", it indicates keeping the original model name. | +| `enableOnPathSuffix` | array of string | Optional | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | Only effective for requests with these specific path suffixes. | + + +## Effect Description + +Configuration example: + +```yaml +modelMapping: + 'gpt-4-*': "qwen-max" + 'gpt-4o': "qwen-vl-plus" + '*': "qwen-turbo" +``` + +After enabling, model parameters starting with `gpt-4-` will be replaced with `qwen-max`, `gpt-4o` will be replaced with `qwen-vl-plus`, and all other models will be replaced with `qwen-turbo`. + +For example, the original request is: + +```json +{ + "model": "gpt-4o", + "frequency_penalty": 0, + "max_tokens": 800, + "stream": false, + "messages": [{ + "role": "user", + "content": "What is the github address of the main repository of the higress project" + }], + "presence_penalty": 0, + "temperature": 0.7, + "top_p": 0.95 +} +``` + + +After processing by this plugin, the original LLM request body will be modified to: + +```json +{ + "model": "qwen-vl-plus", + "frequency_penalty": 0, + "max_tokens": 800, + "stream": false, + "messages": [{ + "role": "user", + "content": "What is the github address of the main repository of the higress project" + }], + "presence_penalty": 0, + "temperature": 0.7, + "top_p": 0.95 +} +``` diff --git a/plugins/wasm-go/extensions/model-mapper/go.mod b/plugins/wasm-go/extensions/model-mapper/go.mod new file mode 100644 index 000000000..a9a8f6462 --- /dev/null +++ b/plugins/wasm-go/extensions/model-mapper/go.mod @@ -0,0 +1,24 @@ +module github.com/alibaba/higress/plugins/wasm-go/extensions/model-mapper + +go 1.24.1 + +toolchain go1.24.7 + +require ( + github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 + github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c + 
github.com/stretchr/testify v1.9.0 + github.com/tidwall/gjson v1.18.0 + github.com/tidwall/sjson v1.2.5 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/tetratelabs/wazero v1.7.2 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/resp v0.1.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/plugins/wasm-go/extensions/model-mapper/go.sum b/plugins/wasm-go/extensions/model-mapper/go.sum new file mode 100644 index 000000000..9d45243f7 --- /dev/null +++ b/plugins/wasm-go/extensions/model-mapper/go.sum @@ -0,0 +1,30 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= +github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c h1:DdVPyaMHSYBqO5jwB9Wl3PqsBGIf4u29BHMI0uIVB1Y= +github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tetratelabs/wazero v1.7.2 
h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc= +github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= +github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/wasm-go/extensions/model-mapper/main.go b/plugins/wasm-go/extensions/model-mapper/main.go new file mode 100644 index 000000000..2278b8e11 --- /dev/null +++ b/plugins/wasm-go/extensions/model-mapper/main.go @@ -0,0 +1,192 @@ +package main + +import ( + "errors" + "sort" + "strings" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/log" 
+ "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + DefaultMaxBodyBytes = 100 * 1024 * 1024 // 100MB +) + +func main() {} + +func init() { + wrapper.SetCtx( + "model-mapper", + wrapper.ParseConfig(parseConfig), + wrapper.ProcessRequestHeaders(onHttpRequestHeaders), + wrapper.ProcessRequestBody(onHttpRequestBody), + wrapper.WithRebuildAfterRequests[Config](1000), + wrapper.WithRebuildMaxMemBytes[Config](200*1024*1024), + ) +} + +type ModelMapping struct { + Prefix string + Target string +} + +type Config struct { + modelKey string + exactModelMapping map[string]string + prefixModelMapping []ModelMapping + defaultModel string + enableOnPathSuffix []string +} + +func parseConfig(json gjson.Result, config *Config) error { + config.modelKey = json.Get("modelKey").String() + if config.modelKey == "" { + config.modelKey = "model" + } + + modelMapping := json.Get("modelMapping") + if modelMapping.Exists() && !modelMapping.IsObject() { + return errors.New("modelMapping must be an object") + } + + config.exactModelMapping = make(map[string]string) + config.prefixModelMapping = make([]ModelMapping, 0) + + // To replicate C++ behavior (nlohmann::json iterates keys alphabetically), + // we collect entries and sort them by key. 
+ type mappingEntry struct { + key string + value string + } + var entries []mappingEntry + modelMapping.ForEach(func(key, value gjson.Result) bool { + entries = append(entries, mappingEntry{ + key: key.String(), + value: value.String(), + }) + return true + }) + sort.Slice(entries, func(i, j int) bool { + return entries[i].key < entries[j].key + }) + + for _, entry := range entries { + key := entry.key + value := entry.value + if key == "*" { + config.defaultModel = value + } else if strings.HasSuffix(key, "*") { + prefix := strings.TrimSuffix(key, "*") + config.prefixModelMapping = append(config.prefixModelMapping, ModelMapping{ + Prefix: prefix, + Target: value, + }) + } else { + config.exactModelMapping[key] = value + } + } + + enableOnPathSuffix := json.Get("enableOnPathSuffix") + if enableOnPathSuffix.Exists() { + if !enableOnPathSuffix.IsArray() { + return errors.New("enableOnPathSuffix must be an array") + } + for _, item := range enableOnPathSuffix.Array() { + config.enableOnPathSuffix = append(config.enableOnPathSuffix, item.String()) + } + } else { + config.enableOnPathSuffix = []string{ + "/completions", + "/embeddings", + "/images/generations", + "/audio/speech", + "/fine_tuning/jobs", + "/moderations", + "/image-synthesis", + "/video-synthesis", + "/rerank", + "/messages", + } + } + + return nil +} + +func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action { + // Check path suffix + path, err := proxywasm.GetHttpRequestHeader(":path") + if err != nil { + return types.ActionContinue + } + + // Strip query parameters + if idx := strings.Index(path, "?"); idx != -1 { + path = path[:idx] + } + + matched := false + for _, suffix := range config.enableOnPathSuffix { + if strings.HasSuffix(path, suffix) { + matched = true + break + } + } + if !matched { + return types.ActionContinue + } + + if !ctx.HasRequestBody() { + ctx.DontReadRequestBody() + return types.ActionContinue + } + + // Prepare for body processing + 
proxywasm.RemoveHttpRequestHeader("content-length") + // 100MB buffer limit + ctx.SetRequestBodyBufferLimit(DefaultMaxBodyBytes) + + return types.HeaderStopIteration +} + +func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action { + if len(body) == 0 { + return types.ActionContinue + } + + oldModel := gjson.GetBytes(body, config.modelKey).String() + + newModel := config.defaultModel + if newModel == "" { + newModel = oldModel + } + + // Exact match + if target, ok := config.exactModelMapping[oldModel]; ok { + newModel = target + } else { + // Prefix match + for _, mapping := range config.prefixModelMapping { + if strings.HasPrefix(oldModel, mapping.Prefix) { + newModel = mapping.Target + break + } + } + } + + if newModel != "" && newModel != oldModel { + newBody, err := sjson.SetBytes(body, config.modelKey, newModel) + if err != nil { + log.Errorf("failed to update model: %v", err) + return types.ActionContinue + } + proxywasm.ReplaceHttpRequestBody(newBody) + log.Debugf("model mapped, before: %s, after: %s", oldModel, newModel) + } + + return types.ActionContinue +} diff --git a/plugins/wasm-go/extensions/model-mapper/main_test.go b/plugins/wasm-go/extensions/model-mapper/main_test.go new file mode 100644 index 000000000..7be9c7851 --- /dev/null +++ b/plugins/wasm-go/extensions/model-mapper/main_test.go @@ -0,0 +1,250 @@ +package main + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +// Basic configs for wasm test host +var ( + basicConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "modelKey": "model", + "modelMapping": map[string]string{ + "gpt-3.5-turbo": "gpt-4", + }, + "enableOnPathSuffix": []string{ + "/v1/chat/completions", + }, + }) + return data + }() + + customConfig = func() json.RawMessage { + data, 
_ := json.Marshal(map[string]interface{}{ + "modelKey": "request.model", + "modelMapping": map[string]string{ + "*": "gpt-4o", + "gpt-3.5*": "gpt-4-mini", + "gpt-3.5-t": "gpt-4-turbo", + "gpt-3.5-t1": "gpt-4-turbo-1", + }, + "enableOnPathSuffix": []string{ + "/v1/chat/completions", + "/v1/embeddings", + }, + }) + return data + }() +) + +func TestParseConfig(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + t.Run("basic config with defaults", func(t *testing.T) { + var cfg Config + jsonData := []byte(`{ + "modelMapping": { + "gpt-3.5-turbo": "gpt-4", + "gpt-4*": "gpt-4o-mini", + "*": "gpt-4o" + } + }`) + err := parseConfig(gjson.ParseBytes(jsonData), &cfg) + require.NoError(t, err) + + // default modelKey + require.Equal(t, "model", cfg.modelKey) + // exact mapping + require.Equal(t, "gpt-4", cfg.exactModelMapping["gpt-3.5-turbo"]) + // prefix mapping + require.Len(t, cfg.prefixModelMapping, 1) + require.Equal(t, "gpt-4", cfg.prefixModelMapping[0].Prefix) + // default model + require.Equal(t, "gpt-4o", cfg.defaultModel) + // default enabled path suffixes + require.Contains(t, cfg.enableOnPathSuffix, "/completions") + require.Contains(t, cfg.enableOnPathSuffix, "/embeddings") + }) + + t.Run("custom modelKey and enableOnPathSuffix", func(t *testing.T) { + var cfg Config + jsonData := []byte(`{ + "modelKey": "request.model", + "modelMapping": { + "gpt-3.5-turbo": "gpt-4", + "gpt-3.5*": "gpt-4-mini" + }, + "enableOnPathSuffix": ["/v1/chat/completions", "/v1/embeddings"] + }`) + err := parseConfig(gjson.ParseBytes(jsonData), &cfg) + require.NoError(t, err) + + require.Equal(t, "request.model", cfg.modelKey) + require.Equal(t, "gpt-4", cfg.exactModelMapping["gpt-3.5-turbo"]) + require.Len(t, cfg.prefixModelMapping, 1) + require.Equal(t, "gpt-3.5", cfg.prefixModelMapping[0].Prefix) + require.Equal(t, "gpt-4-mini", cfg.prefixModelMapping[0].Target) + require.Equal(t, 2, len(cfg.enableOnPathSuffix)) + require.Contains(t, cfg.enableOnPathSuffix, 
"/v1/chat/completions") + require.Contains(t, cfg.enableOnPathSuffix, "/v1/embeddings") + }) + + t.Run("modelMapping must be object", func(t *testing.T) { + var cfg Config + jsonData := []byte(`{ + "modelMapping": "invalid" + }`) + err := parseConfig(gjson.ParseBytes(jsonData), &cfg) + require.Error(t, err) + }) + + t.Run("enableOnPathSuffix must be array", func(t *testing.T) { + var cfg Config + jsonData := []byte(`{ + "enableOnPathSuffix": "not-array" + }`) + err := parseConfig(gjson.ParseBytes(jsonData), &cfg) + require.Error(t, err) + }) + }) +} + +func TestOnHttpRequestHeaders(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("skip when path not matched", func(t *testing.T) { + host, status := test.NewTestHost(basicConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + originalHeaders := [][2]string{ + {":authority", "example.com"}, + {":path", "/v1/other"}, + {":method", "POST"}, + {"content-type", "application/json"}, + {"content-length", "123"}, + } + action := host.CallOnHttpRequestHeaders(originalHeaders) + require.Equal(t, types.ActionContinue, action) + + newHeaders := host.GetRequestHeaders() + // content-length should still exist because path is not enabled + foundContentLength := false + for _, h := range newHeaders { + if strings.ToLower(h[0]) == "content-length" { + foundContentLength = true + break + } + } + require.True(t, foundContentLength) + }) + + t.Run("process when path and content-type match", func(t *testing.T) { + host, status := test.NewTestHost(basicConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + originalHeaders := [][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + {"content-length", "123"}, + } + action := host.CallOnHttpRequestHeaders(originalHeaders) + require.Equal(t, types.HeaderStopIteration, action) + + newHeaders := host.GetRequestHeaders() + // 
content-length should be removed + for _, h := range newHeaders { + require.NotEqual(t, strings.ToLower(h[0]), "content-length") + } + }) + }) +} + +func TestOnHttpRequestBody_ModelMapping(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("exact mapping", func(t *testing.T) { + host, status := test.NewTestHost(basicConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + }) + + origBody := []byte(`{ + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "hello"}] + }`) + action := host.CallOnHttpRequestBody(origBody) + require.Equal(t, types.ActionContinue, action) + + processed := host.GetRequestBody() + require.NotNil(t, processed) + require.Equal(t, "gpt-4", gjson.GetBytes(processed, "model").String()) + }) + + t.Run("default model when key missing", func(t *testing.T) { + // use customConfig where default model is set with "*" + host, status := test.NewTestHost(customConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + }) + + origBody := []byte(`{ + "request": { + "messages": [{"role": "user", "content": "hello"}] + } + }`) + action := host.CallOnHttpRequestBody(origBody) + require.Equal(t, types.ActionContinue, action) + + processed := host.GetRequestBody() + require.NotNil(t, processed) + // default model should be set at request.model + require.Equal(t, "gpt-4o", gjson.GetBytes(processed, "request.model").String()) + }) + + t.Run("prefix mapping takes effect", func(t *testing.T) { + host, status := test.NewTestHost(customConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + 
host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + }) + + origBody := []byte(`{ + "request": { + "model": "gpt-3.5-turbo-16k", + "messages": [{"role": "user", "content": "hello"}] + } + }`) + action := host.CallOnHttpRequestBody(origBody) + require.Equal(t, types.ActionContinue, action) + + processed := host.GetRequestBody() + require.NotNil(t, processed) + require.Equal(t, "gpt-4-mini", gjson.GetBytes(processed, "request.model").String()) + }) + }) +} diff --git a/plugins/wasm-go/extensions/model-router/Makefile b/plugins/wasm-go/extensions/model-router/Makefile new file mode 100644 index 000000000..c6e7aac48 --- /dev/null +++ b/plugins/wasm-go/extensions/model-router/Makefile @@ -0,0 +1,2 @@ +build-go: + GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o main.wasm main.go \ No newline at end of file diff --git a/plugins/wasm-go/extensions/model-router/README.md b/plugins/wasm-go/extensions/model-router/README.md new file mode 100644 index 000000000..a63165e2f --- /dev/null +++ b/plugins/wasm-go/extensions/model-router/README.md @@ -0,0 +1,98 @@ +## 功能说明 +`model-router`插件实现了基于LLM协议中的model参数路由的功能 + +## 配置字段 + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- | +| `modelKey` | string | 选填 | model | 请求body中model参数的位置 | +| `addProviderHeader` | string | 选填 | - | 从model参数中解析出的provider名字放到哪个请求header中 | +| `modelToHeader` | string | 选填 | - | 直接将model参数放到哪个请求header中 | +| `enableOnPathSuffix` | array of string | 选填 | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | 只对这些特定路径后缀的请求生效,可以配置为 "*" 以匹配所有路径 | + +## 运行属性 + +插件执行阶段:认证阶段 +插件执行优先级:900 + +## 效果说明 + +### 基于 model 参数进行路由 + +需要做如下配置: + +```yaml +modelToHeader: x-higress-llm-model 
+``` + +插件会将请求中 model 参数提取出来,设置到 x-higress-llm-model 这个请求 header 中,用于后续路由,举例来说,原生的 LLM 请求体是: + +```json +{ + "model": "qwen-long", + "frequency_penalty": 0, + "max_tokens": 800, + "stream": false, + "messages": [{ + "role": "user", + "content": "higress项目主仓库的github地址是什么" + }], + "presence_penalty": 0, + "temperature": 0.7, + "top_p": 0.95 +} +``` + +经过这个插件后,将添加下面这个请求头(可以用于路由匹配): + +x-higress-llm-model: qwen-long + +### 提取 model 参数中的 provider 字段用于路由 + +> 注意这种模式需要客户端在 model 参数中通过`/`分隔的方式,来指定 provider + +需要做如下配置: + +```yaml +addProviderHeader: x-higress-llm-provider +``` + +插件会将请求中 model 参数的 provider 部分(如果有)提取出来,设置到 x-higress-llm-provider 这个请求 header 中,用于后续路由,并将 model 参数重写为模型名称部分。举例来说,原生的 LLM 请求体是: + +```json +{ + "model": "dashscope/qwen-long", + "frequency_penalty": 0, + "max_tokens": 800, + "stream": false, + "messages": [{ + "role": "user", + "content": "higress项目主仓库的github地址是什么" + }], + "presence_penalty": 0, + "temperature": 0.7, + "top_p": 0.95 +} +``` + +经过这个插件后,将添加下面这个请求头(可以用于路由匹配): + +x-higress-llm-provider: dashscope + +原始的 LLM 请求体将被改成: + +```json +{ + "model": "qwen-long", + "frequency_penalty": 0, + "max_tokens": 800, + "stream": false, + "messages": [{ + "role": "user", + "content": "higress项目主仓库的github地址是什么" + }], + "presence_penalty": 0, + "temperature": 0.7, + "top_p": 0.95 +} +``` diff --git a/plugins/wasm-go/extensions/model-router/README_EN.md b/plugins/wasm-go/extensions/model-router/README_EN.md new file mode 100644 index 000000000..217528a8c --- /dev/null +++ b/plugins/wasm-go/extensions/model-router/README_EN.md @@ -0,0 +1,97 @@ +## Feature Description +The `model-router` plugin implements routing functionality based on the model parameter in LLM protocols. 
+ +## Configuration Fields + +| Name | Data Type | Requirement | Default Value | Description | +| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- | +| `modelKey` | string | Optional | model | Location of the model parameter in the request body | +| `addProviderHeader` | string | Optional | - | Which request header to add the provider name parsed from the model parameter | +| `modelToHeader` | string | Optional | - | Which request header to directly add the model parameter to | +| `enableOnPathSuffix` | array of string | Optional | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | Only effective for requests with these specific path suffixes, can be configured as "*" to match all paths | + +## Runtime Properties + +Plugin execution phase: Authentication phase +Plugin execution priority: 900 + +## Effect Description + +### Routing Based on Model Parameter + +The following configuration is needed: + +```yaml +modelToHeader: x-higress-llm-model +``` + +The plugin extracts the model parameter from the request and sets it to the x-higress-llm-model request header for subsequent routing. For example, the original LLM request body is: + +```json +{ + "model": "qwen-long", + "frequency_penalty": 0, + "max_tokens": 800, + "stream": false, + "messages": [{ + "role": "user", + "content": "What is the GitHub address of the Higress project's main repository?" 
+ }], + "presence_penalty": 0, + "temperature": 0.7, + "top_p": 0.95 +} +``` + +After processing by this plugin, the following request header will be added (can be used for route matching): + +x-higress-llm-model: qwen-long + +### Extracting Provider Field from Model Parameter for Routing + +> Note that this mode requires the client to specify the provider in the model parameter using the `/` delimiter + +The following configuration is needed: + +```yaml +addProviderHeader: x-higress-llm-provider +``` + +The plugin extracts the provider part (if any) from the model parameter in the request, sets it to the x-higress-llm-provider request header for subsequent routing, and rewrites the model parameter to only contain the model name part. For example, the original LLM request body is: + +```json +{ + "model": "dashscope/qwen-long", + "frequency_penalty": 0, + "max_tokens": 800, + "stream": false, + "messages": [{ + "role": "user", + "content": "What is the GitHub address of the Higress project's main repository?" + }], + "presence_penalty": 0, + "temperature": 0.7, + "top_p": 0.95 +} +``` + +After processing by this plugin, the following request header will be added (can be used for route matching): + +x-higress-llm-provider: dashscope + +The original LLM request body will be changed to: + +```json +{ + "model": "qwen-long", + "frequency_penalty": 0, + "max_tokens": 800, + "stream": false, + "messages": [{ + "role": "user", + "content": "What is the GitHub address of the Higress project's main repository?" 
+ }], + "presence_penalty": 0, + "temperature": 0.7, + "top_p": 0.95 +} diff --git a/plugins/wasm-go/extensions/model-router/go.mod b/plugins/wasm-go/extensions/model-router/go.mod new file mode 100644 index 000000000..3a0d7c3dc --- /dev/null +++ b/plugins/wasm-go/extensions/model-router/go.mod @@ -0,0 +1,24 @@ +module model-router + +go 1.24.1 + +toolchain go1.24.7 + +require ( + github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 + github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c + github.com/stretchr/testify v1.9.0 + github.com/tidwall/gjson v1.18.0 + github.com/tidwall/sjson v1.2.5 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/tetratelabs/wazero v1.7.2 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/resp v0.1.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/plugins/wasm-go/extensions/model-router/go.sum b/plugins/wasm-go/extensions/model-router/go.sum new file mode 100644 index 000000000..9d45243f7 --- /dev/null +++ b/plugins/wasm-go/extensions/model-router/go.sum @@ -0,0 +1,30 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= +github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c h1:DdVPyaMHSYBqO5jwB9Wl3PqsBGIf4u29BHMI0uIVB1Y= 
+github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc= +github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= +github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod 
h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/plugins/wasm-go/extensions/model-router/main.go b/plugins/wasm-go/extensions/model-router/main.go
new file mode 100644
index 000000000..d67b21956
--- /dev/null
+++ b/plugins/wasm-go/extensions/model-router/main.go
@@ -0,0 +1,286 @@
+package main
+
+import (
+	"bytes"
+	"io"
+	"mime"
+	"mime/multipart"
+	"net/textproto"
+	"strings"
+
+	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
+	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+	"github.com/higress-group/wasm-go/pkg/log"
+	"github.com/higress-group/wasm-go/pkg/wrapper"
+	"github.com/tidwall/gjson"
+	"github.com/tidwall/sjson"
+)
+
+const (
+	// DefaultMaxBodyBytes is the request body buffer limit (100MB) set on
+	// requests whose body this plugin inspects.
+	DefaultMaxBodyBytes = 100 * 1024 * 1024
+)
+
+func main() {}
+
+// init registers the config parser and the request header/body callbacks
+// under the plugin name "model-router".
+func init() {
+	wrapper.SetCtx(
+		"model-router",
+		wrapper.ParseConfig(parseConfig),
+		wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
+		wrapper.ProcessRequestBody(onHttpRequestBody),
+		wrapper.WithRebuildAfterRequests[ModelRouterConfig](1000),
+		wrapper.WithRebuildMaxMemBytes[ModelRouterConfig](200*1024*1024),
+	)
+}
+
+// ModelRouterConfig holds the parsed plugin configuration.
+type ModelRouterConfig struct {
+	// modelKey is the gjson path (JSON bodies) or form field name (multipart
+	// bodies) of the model value; defaults to "model".
+	modelKey string
+	// addProviderHeader, when non-empty, names the request header that
+	// receives the provider part of a "provider/model" value.
+	addProviderHeader string
+	// modelToHeader, when non-empty, names the request header that receives
+	// the full, unmodified model value.
+	modelToHeader string
+	// enableOnPathSuffix lists the path suffixes the plugin acts on; the
+	// entry "*" matches every path.
+	enableOnPathSuffix []string
+}
+
+// parseConfig fills config from the plugin's JSON configuration. modelKey
+// falls back to "model"; when enableOnPathSuffix is absent or not an array,
+// a built-in list of path suffixes is used instead.
+func parseConfig(json gjson.Result, config *ModelRouterConfig) error {
+	config.modelKey = json.Get("modelKey").String()
+	if config.modelKey == "" {
+		config.modelKey = "model"
+	}
+	config.addProviderHeader = json.Get("addProviderHeader").String()
+	config.modelToHeader = json.Get("modelToHeader").String()
+
+	enableOnPathSuffix := json.Get("enableOnPathSuffix")
+	if enableOnPathSuffix.Exists() && enableOnPathSuffix.IsArray() {
+		for _, item := range enableOnPathSuffix.Array() {
+			config.enableOnPathSuffix = append(config.enableOnPathSuffix, item.String())
+		}
+		return nil
+	}
+
+	// Default suffixes if not provided
+	config.enableOnPathSuffix = []string{
+		"/completions",
+		"/embeddings",
+		"/images/generations",
+		"/audio/speech",
+		"/fine_tuning/jobs",
+		"/moderations",
+		"/image-synthesis",
+		"/video-synthesis",
+		"/rerank",
+		"/messages",
+	}
+	return nil
+}
+
+// onHttpRequestHeaders decides whether the request body needs to be buffered.
+// Paths (query string excluded) that match no configured suffix skip body
+// processing. Matching requests that carry a body get their content-length
+// header removed, the buffer limit raised, and HeaderStopIteration returned.
+func onHttpRequestHeaders(ctx wrapper.HttpContext, config ModelRouterConfig) types.Action {
+	path, err := proxywasm.GetHttpRequestHeader(":path")
+	if err != nil {
+		return types.ActionContinue
+	}
+
+	// Remove query parameters for suffix check
+	path, _, _ = strings.Cut(path, "?")
+
+	enable := false
+	for _, suffix := range config.enableOnPathSuffix {
+		if suffix == "*" || strings.HasSuffix(path, suffix) {
+			enable = true
+			break
+		}
+	}
+
+	if !enable {
+		ctx.DontReadRequestBody()
+		return types.ActionContinue
+	}
+
+	if !ctx.HasRequestBody() {
+		return types.ActionContinue
+	}
+
+	// Prepare for body processing
+	proxywasm.RemoveHttpRequestHeader("content-length")
+	// 100MB buffer limit
+	ctx.SetRequestBodyBufferLimit(DefaultMaxBodyBytes)
+
+	return types.HeaderStopIteration
+}
+
+// onHttpRequestBody dispatches the buffered body to the JSON or multipart
+// handler according to the content-type header; other types pass through.
+func onHttpRequestBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []byte) types.Action {
+	contentType, err := proxywasm.GetHttpRequestHeader("content-type")
+	if err != nil {
+		return types.ActionContinue
+	}
+
+	switch {
+	case strings.Contains(contentType, "application/json"):
+		return handleJsonBody(ctx, config, body)
+	case strings.Contains(contentType, "multipart/form-data"):
+		return handleMultipartBody(ctx, config, body, contentType)
+	default:
+		return types.ActionContinue
+	}
+}
+
+// handleJsonBody reads the model value at config.modelKey from a JSON body.
+// The full value is copied into the modelToHeader header when configured.
+// When addProviderHeader is configured and the value contains "/", the text
+// before the first "/" goes into that header and the body is rewritten so
+// the model field holds only the text after it.
+func handleJsonBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []byte) types.Action {
+	modelValue := gjson.GetBytes(body, config.modelKey).String()
+	if modelValue == "" {
+		return types.ActionContinue
+	}
+
+	// Header updates are best-effort; a failure must not block the request.
+	if config.modelToHeader != "" {
+		_ = proxywasm.ReplaceHttpRequestHeader(config.modelToHeader, modelValue)
+	}
+
+	if config.addProviderHeader == "" {
+		return types.ActionContinue
+	}
+
+	provider, model, found := strings.Cut(modelValue, "/")
+	if !found {
+		log.Debugf("model route to provider not work, model: %s", modelValue)
+		return types.ActionContinue
+	}
+
+	// Build the new body before touching the provider header so a failure
+	// cannot leave the header set while the body still says "provider/model".
+	newBody, err := sjson.SetBytes(body, config.modelKey, model)
+	if err != nil {
+		log.Errorf("failed to update model in json body: %v", err)
+		return types.ActionContinue
+	}
+	if err := proxywasm.ReplaceHttpRequestBody(newBody); err != nil {
+		log.Errorf("failed to replace json body: %v", err)
+		return types.ActionContinue
+	}
+	_ = proxywasm.ReplaceHttpRequestHeader(config.addProviderHeader, provider)
+	log.Debugf("model route to provider: %s, model: %s", provider, model)
+
+	return types.ActionContinue
+}
+
+// handleMultipartBody applies the same routing rules as handleJsonBody to a
+// multipart/form-data body, using the form field named config.modelKey. The
+// other parts are copied through as-is, and the request body is replaced
+// only when the model field was actually rewritten.
+func handleMultipartBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []byte, contentType string) types.Action {
+	_, params, err := mime.ParseMediaType(contentType)
+	if err != nil {
+		log.Errorf("failed to parse content type: %v", err)
+		return types.ActionContinue
+	}
+	boundary, ok := params["boundary"]
+	if !ok {
+		log.Errorf("no boundary in content type")
+		return types.ActionContinue
+	}
+
+	reader := multipart.NewReader(bytes.NewReader(body), boundary)
+	var newBody bytes.Buffer
+	writer := multipart.NewWriter(&newBody)
+	// The content-type header keeps announcing the original boundary, so the
+	// rewritten body must use it too; give up if the writer rejects it.
+	if err := writer.SetBoundary(boundary); err != nil {
+		log.Errorf("failed to set multipart boundary: %v", err)
+		return types.ActionContinue
+	}
+
+	modified := false
+
+	for {
+		part, err := reader.NextPart()
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			log.Errorf("failed to read multipart part: %v", err)
+			return types.ActionContinue
+		}
+
+		// Read part content
+		partContent, err := io.ReadAll(part)
+		if err != nil {
+			log.Errorf("failed to read part content: %v", err)
+			return types.ActionContinue
+		}
+
+		if part.FormName() == config.modelKey {
+			modelValue := string(partContent)
+
+			if config.modelToHeader != "" {
+				_ = proxywasm.ReplaceHttpRequestHeader(config.modelToHeader, modelValue)
+			}
+
+			if config.addProviderHeader != "" {
+				provider, model, found := strings.Cut(modelValue, "/")
+				if found {
+					_ = proxywasm.ReplaceHttpRequestHeader(config.addProviderHeader, provider)
+					// Emit this part with only the model name as content.
+					partContent = []byte(model)
+					modified = true
+					log.Debugf("model route to provider: %s, model: %s", provider, model)
+				} else {
+					log.Debugf("model route to provider not work, model: %s", modelValue)
+				}
+			}
+		}
+
+		if err := writePart(writer, part.Header, partContent); err != nil {
+			log.Errorf("failed to write part: %v", err)
+			return types.ActionContinue
+		}
+	}
+
+	if !modified {
+		return types.ActionContinue
+	}
+
+	// Close writes the trailing boundary; without it the body is truncated.
+	if err := writer.Close(); err != nil {
+		log.Errorf("failed to finish multipart body: %v", err)
+		return types.ActionContinue
+	}
+	if err := proxywasm.ReplaceHttpRequestBody(newBody.Bytes()); err != nil {
+		log.Errorf("failed to replace multipart body: %v", err)
+	}
+
+	return types.ActionContinue
+}
+
+// writePart appends one part with the given headers and content to w.
+func writePart(w *multipart.Writer, header textproto.MIMEHeader, content []byte) error {
+	pw, err := w.CreatePart(header)
+	if err != nil {
+		return err
+	}
+	_, err = pw.Write(content)
+	return err
+}
diff --git a/plugins/wasm-go/extensions/model-router/main_test.go b/plugins/wasm-go/extensions/model-router/main_test.go
new file mode 100644
index 000000000..9d6263ac8
--- /dev/null
+++ b/plugins/wasm-go/extensions/model-router/main_test.go
@@ -0,0 +1,301 @@
+package main
+
+import (
+	"bytes"
+	"encoding/json"
+	"io"
+	"mime/multipart"
+	"strings"
+	"testing"
+
+	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+	"github.com/higress-group/wasm-go/pkg/test"
+	"github.com/stretchr/testify/require"
+	"github.com/tidwall/gjson"
+)
+
+// Basic configs for wasm test host
+var (
+	basicConfig = func() json.RawMessage {
+		data, _ := json.Marshal(map[string]any{
+			"modelKey":          "model",
+			"addProviderHeader": "x-provider",
+			"modelToHeader":     "x-model",
+			"enableOnPathSuffix": []string{
+				"/v1/chat/completions",
+			},
+		})
+		return data
+	}()
+
+	defaultSuffixConfig = func() json.RawMessage {
+		data, _ := json.Marshal(map[string]any{
+			"modelKey":          "model",
+			"addProviderHeader": "x-provider",
+			"modelToHeader":     "x-model",
+		})
+		return data
+	}()
+)
+
+// getHeader returns the value of the first header whose name equals key,
+// ignoring case, and reports whether such a header exists.
+func getHeader(headers [][2]string, key string) (string, bool) {
+	for _, h := range headers {
+		if strings.EqualFold(h[0], key) {
+			return h[1], true
+		}
+	}
+	return "", false
+}
+
+// TestParseConfig covers the defaults applied by parseConfig and explicit
+// overrides of every field.
+func TestParseConfig(t *testing.T) {
+	test.RunGoTest(t, func(t *testing.T) {
+		t.Run("basic config with defaults", func(t *testing.T) {
+			var cfg ModelRouterConfig
+			err := parseConfig(gjson.ParseBytes(defaultSuffixConfig), &cfg)
+			require.NoError(t, err)
+
+			// default modelKey
+			require.Equal(t, "model", cfg.modelKey)
+			// headers
+			require.Equal(t, "x-provider", cfg.addProviderHeader)
+			require.Equal(t, "x-model", cfg.modelToHeader)
+			// default enabled path suffixes should contain common openai paths
+			require.Contains(t, cfg.enableOnPathSuffix, "/completions")
+			require.Contains(t, cfg.enableOnPathSuffix, "/embeddings")
+		})
+
+		t.Run("custom enableOnPathSuffix", func(t *testing.T) {
+			jsonData := []byte(`{
+				"modelKey": "my_model",
+				"addProviderHeader": "x-prov",
+				"modelToHeader": "x-mod",
+				"enableOnPathSuffix": ["/foo", "/bar"]
+			}`)
+			var cfg ModelRouterConfig
+			err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
+			require.NoError(t, err)
+
+			require.Equal(t, "my_model", cfg.modelKey)
+			require.Equal(t, "x-prov", cfg.addProviderHeader)
+			require.Equal(t, "x-mod", cfg.modelToHeader)
+			require.Equal(t, []string{"/foo", "/bar"}, cfg.enableOnPathSuffix)
+		})
+	})
+}
+
+// TestOnHttpRequestHeaders checks when the header phase skips the request
+// and when it prepares the body for buffering.
+func TestOnHttpRequestHeaders(t *testing.T) {
+	test.RunTest(t, func(t *testing.T) {
+		t.Run("skip when path not matched", func(t *testing.T) {
+			host, status := test.NewTestHost(basicConfig)
+			defer host.Reset()
+			require.Equal(t, types.OnPluginStartStatusOK, status)
+
+			originalHeaders := [][2]string{
+				{":authority", "example.com"},
+				{":path", "/v1/other"},
+				{":method", "POST"},
+				{"content-type", "application/json"},
+				{"content-length", "123"},
+			}
+			action := host.CallOnHttpRequestHeaders(originalHeaders)
+			require.Equal(t, types.ActionContinue, action)
+
+			newHeaders := host.GetRequestHeaders()
+			_, found := getHeader(newHeaders, "content-length")
+			require.True(t, found, "content-length should be kept when path not enabled")
+		})
+
+		t.Run("process when path and content-type match", func(t *testing.T) {
+			host, status := test.NewTestHost(basicConfig)
+			defer host.Reset()
+			require.Equal(t, types.OnPluginStartStatusOK, status)
+
+			originalHeaders := [][2]string{
+				{":authority", "example.com"},
+				{":path", "/v1/chat/completions"},
+				{":method", "POST"},
+				{"content-type", "application/json"},
+				{"content-length", "123"},
+			}
+			action := host.CallOnHttpRequestHeaders(originalHeaders)
+			require.Equal(t, types.HeaderStopIteration, action)
+
+			newHeaders := host.GetRequestHeaders()
+			_, found := getHeader(newHeaders, "content-length")
+			require.False(t, found, "content-length should be removed when buffering body")
+		})
+
+		// The header phase does not look at content-type, so an unsupported
+		// type is still buffered; it is only ignored later in the body phase.
+		t.Run("buffer body regardless of content-type", func(t *testing.T) {
+			host, status := test.NewTestHost(basicConfig)
+			defer host.Reset()
+			require.Equal(t, types.OnPluginStartStatusOK, status)
+
+			originalHeaders := [][2]string{
+				{":authority", "example.com"},
+				{":path", "/v1/chat/completions"},
+				{":method", "POST"},
+				{"content-type", "text/plain"},
+				{"content-length", "123"},
+			}
+			action := host.CallOnHttpRequestHeaders(originalHeaders)
+			require.Equal(t, types.HeaderStopIteration, action)
+
+			newHeaders := host.GetRequestHeaders()
+			_, found := getHeader(newHeaders, "content-length")
+			require.False(t, found, "content-length should be removed even for unsupported content-type")
+		})
+	})
+}
+
+// TestOnHttpRequestBody_JSON checks header extraction and model rewriting
+// for application/json bodies.
+func TestOnHttpRequestBody_JSON(t *testing.T) {
+	test.RunTest(t, func(t *testing.T) {
+		t.Run("set headers and rewrite model when provider/model format", func(t *testing.T) {
+			host, status := test.NewTestHost(basicConfig)
+			defer host.Reset()
+			require.Equal(t, types.OnPluginStartStatusOK, status)
+
+			host.CallOnHttpRequestHeaders([][2]string{
+				{":authority", "example.com"},
+				{":path", "/v1/chat/completions"},
+				{":method", "POST"},
+				{"content-type", "application/json"},
+			})
+
+			origBody := []byte(`{
+				"model": "openai/gpt-4o",
+				"messages": [{"role": "user", "content": "hello"}]
+			}`)
+			action := host.CallOnHttpRequestBody(origBody)
+			require.Equal(t, types.ActionContinue, action)
+
+			processed := host.GetRequestBody()
+			require.NotNil(t, processed)
+			// model should be rewritten to only the model part
+			require.Equal(t, "gpt-4o", gjson.GetBytes(processed, "model").String())
+
+			headers := host.GetRequestHeaders()
+			hv, found := getHeader(headers, "x-model")
+			require.True(t, found)
+			require.Equal(t, "openai/gpt-4o", hv)
+			pv, found := getHeader(headers, "x-provider")
+			require.True(t, found)
+			require.Equal(t, "openai", pv)
+		})
+
+		t.Run("no change when model not provided", func(t *testing.T) {
+			host, status := test.NewTestHost(basicConfig)
+			defer host.Reset()
+			require.Equal(t, types.OnPluginStartStatusOK, status)
+
+			host.CallOnHttpRequestHeaders([][2]string{
+				{":authority", "example.com"},
+				{":path", "/v1/chat/completions"},
+				{":method", "POST"},
+				{"content-type", "application/json"},
+			})
+
+			origBody := []byte(`{
+				"messages": [{"role": "user", "content": "hello"}]
+			}`)
+			action := host.CallOnHttpRequestBody(origBody)
+			require.Equal(t, types.ActionContinue, action)
+
+			processed := host.GetRequestBody()
+			// body should remain nil or unchanged as plugin does nothing
+			if processed != nil {
+				require.JSONEq(t, string(origBody), string(processed))
+			}
+			_, found := getHeader(host.GetRequestHeaders(), "x-provider")
+			require.False(t, found)
+		})
+	})
+}
+
+// TestOnHttpRequestBody_Multipart checks that the model form field is
+// rewritten and every other field is preserved for multipart bodies.
+func TestOnHttpRequestBody_Multipart(t *testing.T) {
+	test.RunTest(t, func(t *testing.T) {
+		host, status := test.NewTestHost(basicConfig)
+		defer host.Reset()
+		require.Equal(t, types.OnPluginStartStatusOK, status)
+
+		var buf bytes.Buffer
+		writer := multipart.NewWriter(&buf)
+
+		// model field
+		modelWriter, err := writer.CreateFormField("model")
+		require.NoError(t, err)
+		_, err = modelWriter.Write([]byte("openai/gpt-4o"))
+		require.NoError(t, err)
+
+		// another field to ensure others are preserved
+		promptWriter, err := writer.CreateFormField("prompt")
+		require.NoError(t, err)
+		_, err = promptWriter.Write([]byte("hello"))
+		require.NoError(t, err)
+
+		err = writer.Close()
+		require.NoError(t, err)
+
+		contentType := "multipart/form-data; boundary=" + writer.Boundary()
+
+		host.CallOnHttpRequestHeaders([][2]string{
+			{":authority", "example.com"},
+			{":path", "/v1/chat/completions"},
+			{":method", "POST"},
+			{"content-type", contentType},
+		})
+
+		action := host.CallOnHttpRequestBody(buf.Bytes())
+		require.Equal(t, types.ActionContinue, action)
+
+		processed := host.GetRequestBody()
+		require.NotNil(t, processed)
+
+		// Parse multipart body again to verify fields
+		reader := multipart.NewReader(bytes.NewReader(processed), writer.Boundary())
+
+		foundModel := false
+		foundPrompt := false
+		for {
+			part, err := reader.NextPart()
+			if err == io.EOF {
+				break
+			}
+			require.NoError(t, err)
+			name := part.FormName()
+			data, err := io.ReadAll(part)
+			require.NoError(t, err)
+
+			switch name {
+			case "model":
+				foundModel = true
+				require.Equal(t, "gpt-4o", string(data))
+			case "prompt":
+				foundPrompt = true
+				require.Equal(t, "hello", string(data))
+			}
+		}
+
+		require.True(t, foundModel)
+		require.True(t, foundPrompt)
+
+		headers := host.GetRequestHeaders()
+		hv, found := getHeader(headers, "x-model")
+		require.True(t, found)
+		require.Equal(t, "openai/gpt-4o", hv)
+		pv, found := getHeader(headers, "x-provider")
+		require.True(t, found)
+		require.Equal(t, "openai", pv)
+	})
+}