mirror of
https://github.com/alibaba/higress.git
synced 2026-03-17 00:40:48 +08:00
Replace model-router and model-mapper with Go implementation (#3317)
This commit is contained in:
2
plugins/wasm-go/extensions/model-mapper/Makefile
Normal file
2
plugins/wasm-go/extensions/model-mapper/Makefile
Normal file
@@ -0,0 +1,2 @@
|
||||
build-go:
|
||||
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o main.wasm main.go
|
||||
61
plugins/wasm-go/extensions/model-mapper/README.md
Normal file
61
plugins/wasm-go/extensions/model-mapper/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# 功能说明
|
||||
`model-mapper`插件实现了基于LLM协议中的model参数路由的功能
|
||||
|
||||
# 配置字段
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
|
||||
| `modelKey` | string | 选填 | model | 请求body中model参数的位置 |
|
||||
| `modelMapping` | map of string | 选填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
|
||||
| `enableOnPathSuffix` | array of string | 选填 | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | 只对这些特定路径后缀的请求生效 |
|
||||
|
||||
|
||||
## 效果说明
|
||||
|
||||
如下配置
|
||||
|
||||
```yaml
|
||||
modelMapping:
|
||||
'gpt-4-*': "qwen-max"
|
||||
'gpt-4o': "qwen-vl-plus"
|
||||
'*': "qwen-turbo"
|
||||
```
|
||||
|
||||
开启后,`gpt-4-` 开头的模型参数会被改写为 `qwen-max`, `gpt-4o` 会被改写为 `qwen-vl-plus`,其他所有模型会被改写为 `qwen-turbo`
|
||||
|
||||
例如原本的请求是:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4o",
|
||||
"frequency_penalty": 0,
|
||||
"max_tokens": 800,
|
||||
"stream": false,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "higress项目主仓库的github地址是什么"
|
||||
}],
|
||||
"presence_penalty": 0,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.95
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
经过这个插件后,原始的 LLM 请求体将被改成:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-vl-plus",
|
||||
"frequency_penalty": 0,
|
||||
"max_tokens": 800,
|
||||
"stream": false,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "higress项目主仓库的github地址是什么"
|
||||
}],
|
||||
"presence_penalty": 0,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.95
|
||||
}
|
||||
```
|
||||
61
plugins/wasm-go/extensions/model-mapper/README_EN.md
Normal file
61
plugins/wasm-go/extensions/model-mapper/README_EN.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# Function Description
|
||||
The `model-mapper` plugin implements model parameter mapping functionality based on the LLM protocol.
|
||||
|
||||
# Configuration Fields
|
||||
|
||||
| Name | Type | Requirement | Default Value | Description |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| `modelKey` | string | Optional | model | The position of the model parameter in the request body. |
|
||||
| `modelMapping` | map of string | Optional | - | AI model mapping table, used to map the model name in the request to the model name supported by the service provider.<br/>1. Supports prefix matching. For example, use "gpt-3-*" to match all names starting with "gpt-3-";<br/>2. Supports using "*" as a key to configure a generic fallback mapping;<br/>3. If the target mapping name is an empty string "", it indicates keeping the original model name. |
|
||||
| `enableOnPathSuffix` | array of string | Optional | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | Only effective for requests with these specific path suffixes. |
|
||||
|
||||
|
||||
## Effect Description
|
||||
|
||||
Configuration example:
|
||||
|
||||
```yaml
|
||||
modelMapping:
|
||||
'gpt-4-*': "qwen-max"
|
||||
'gpt-4o': "qwen-vl-plus"
|
||||
'*': "qwen-turbo"
|
||||
```
|
||||
|
||||
After enabling, model parameters starting with `gpt-4-` will be replaced with `qwen-max`, `gpt-4o` will be replaced with `qwen-vl-plus`, and all other models will be replaced with `qwen-turbo`.
|
||||
|
||||
For example, the original request is:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4o",
|
||||
"frequency_penalty": 0,
|
||||
"max_tokens": 800,
|
||||
"stream": false,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "What is the github address of the main repository of the higress project"
|
||||
}],
|
||||
"presence_penalty": 0,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.95
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
After processing by this plugin, the original LLM request body will be modified to:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-vl-plus",
|
||||
"frequency_penalty": 0,
|
||||
"max_tokens": 800,
|
||||
"stream": false,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "What is the github address of the main repository of the higress project"
|
||||
}],
|
||||
"presence_penalty": 0,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.95
|
||||
}
|
||||
```
|
||||
24
plugins/wasm-go/extensions/model-mapper/go.mod
Normal file
24
plugins/wasm-go/extensions/model-mapper/go.mod
Normal file
@@ -0,0 +1,24 @@
|
||||
module github.com/alibaba/higress/plugins/wasm-go/extensions/model-mapper
|
||||
|
||||
go 1.24.1
|
||||
|
||||
toolchain go1.24.7
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
30
plugins/wasm-go/extensions/model-mapper/go.sum
Normal file
30
plugins/wasm-go/extensions/model-mapper/go.sum
Normal file
@@ -0,0 +1,30 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c h1:DdVPyaMHSYBqO5jwB9Wl3PqsBGIf4u29BHMI0uIVB1Y=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
|
||||
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
192
plugins/wasm-go/extensions/model-mapper/main.go
Normal file
192
plugins/wasm-go/extensions/model-mapper/main.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMaxBodyBytes = 100 * 1024 * 1024 // 100MB
|
||||
)
|
||||
|
||||
func main() {}
|
||||
|
||||
func init() {
|
||||
wrapper.SetCtx(
|
||||
"model-mapper",
|
||||
wrapper.ParseConfig(parseConfig),
|
||||
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBody(onHttpRequestBody),
|
||||
wrapper.WithRebuildAfterRequests[Config](1000),
|
||||
wrapper.WithRebuildMaxMemBytes[Config](200*1024*1024),
|
||||
)
|
||||
}
|
||||
|
||||
type ModelMapping struct {
|
||||
Prefix string
|
||||
Target string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
modelKey string
|
||||
exactModelMapping map[string]string
|
||||
prefixModelMapping []ModelMapping
|
||||
defaultModel string
|
||||
enableOnPathSuffix []string
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *Config) error {
|
||||
config.modelKey = json.Get("modelKey").String()
|
||||
if config.modelKey == "" {
|
||||
config.modelKey = "model"
|
||||
}
|
||||
|
||||
modelMapping := json.Get("modelMapping")
|
||||
if modelMapping.Exists() && !modelMapping.IsObject() {
|
||||
return errors.New("modelMapping must be an object")
|
||||
}
|
||||
|
||||
config.exactModelMapping = make(map[string]string)
|
||||
config.prefixModelMapping = make([]ModelMapping, 0)
|
||||
|
||||
// To replicate C++ behavior (nlohmann::json iterates keys alphabetically),
|
||||
// we collect entries and sort them by key.
|
||||
type mappingEntry struct {
|
||||
key string
|
||||
value string
|
||||
}
|
||||
var entries []mappingEntry
|
||||
modelMapping.ForEach(func(key, value gjson.Result) bool {
|
||||
entries = append(entries, mappingEntry{
|
||||
key: key.String(),
|
||||
value: value.String(),
|
||||
})
|
||||
return true
|
||||
})
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].key < entries[j].key
|
||||
})
|
||||
|
||||
for _, entry := range entries {
|
||||
key := entry.key
|
||||
value := entry.value
|
||||
if key == "*" {
|
||||
config.defaultModel = value
|
||||
} else if strings.HasSuffix(key, "*") {
|
||||
prefix := strings.TrimSuffix(key, "*")
|
||||
config.prefixModelMapping = append(config.prefixModelMapping, ModelMapping{
|
||||
Prefix: prefix,
|
||||
Target: value,
|
||||
})
|
||||
} else {
|
||||
config.exactModelMapping[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
enableOnPathSuffix := json.Get("enableOnPathSuffix")
|
||||
if enableOnPathSuffix.Exists() {
|
||||
if !enableOnPathSuffix.IsArray() {
|
||||
return errors.New("enableOnPathSuffix must be an array")
|
||||
}
|
||||
for _, item := range enableOnPathSuffix.Array() {
|
||||
config.enableOnPathSuffix = append(config.enableOnPathSuffix, item.String())
|
||||
}
|
||||
} else {
|
||||
config.enableOnPathSuffix = []string{
|
||||
"/completions",
|
||||
"/embeddings",
|
||||
"/images/generations",
|
||||
"/audio/speech",
|
||||
"/fine_tuning/jobs",
|
||||
"/moderations",
|
||||
"/image-synthesis",
|
||||
"/video-synthesis",
|
||||
"/rerank",
|
||||
"/messages",
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action {
|
||||
// Check path suffix
|
||||
path, err := proxywasm.GetHttpRequestHeader(":path")
|
||||
if err != nil {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
// Strip query parameters
|
||||
if idx := strings.Index(path, "?"); idx != -1 {
|
||||
path = path[:idx]
|
||||
}
|
||||
|
||||
matched := false
|
||||
for _, suffix := range config.enableOnPathSuffix {
|
||||
if strings.HasSuffix(path, suffix) {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
if !ctx.HasRequestBody() {
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
// Prepare for body processing
|
||||
proxywasm.RemoveHttpRequestHeader("content-length")
|
||||
// 100MB buffer limit
|
||||
ctx.SetRequestBodyBufferLimit(DefaultMaxBodyBytes)
|
||||
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
|
||||
if len(body) == 0 {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
oldModel := gjson.GetBytes(body, config.modelKey).String()
|
||||
|
||||
newModel := config.defaultModel
|
||||
if newModel == "" {
|
||||
newModel = oldModel
|
||||
}
|
||||
|
||||
// Exact match
|
||||
if target, ok := config.exactModelMapping[oldModel]; ok {
|
||||
newModel = target
|
||||
} else {
|
||||
// Prefix match
|
||||
for _, mapping := range config.prefixModelMapping {
|
||||
if strings.HasPrefix(oldModel, mapping.Prefix) {
|
||||
newModel = mapping.Target
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if newModel != "" && newModel != oldModel {
|
||||
newBody, err := sjson.SetBytes(body, config.modelKey, newModel)
|
||||
if err != nil {
|
||||
log.Errorf("failed to update model: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
proxywasm.ReplaceHttpRequestBody(newBody)
|
||||
log.Debugf("model mapped, before: %s, after: %s", oldModel, newModel)
|
||||
}
|
||||
|
||||
return types.ActionContinue
|
||||
}
|
||||
250
plugins/wasm-go/extensions/model-mapper/main_test.go
Normal file
250
plugins/wasm-go/extensions/model-mapper/main_test.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// Basic configs for wasm test host
|
||||
var (
|
||||
basicConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"modelKey": "model",
|
||||
"modelMapping": map[string]string{
|
||||
"gpt-3.5-turbo": "gpt-4",
|
||||
},
|
||||
"enableOnPathSuffix": []string{
|
||||
"/v1/chat/completions",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
customConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"modelKey": "request.model",
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-4o",
|
||||
"gpt-3.5*": "gpt-4-mini",
|
||||
"gpt-3.5-t": "gpt-4-turbo",
|
||||
"gpt-3.5-t1": "gpt-4-turbo-1",
|
||||
},
|
||||
"enableOnPathSuffix": []string{
|
||||
"/v1/chat/completions",
|
||||
"/v1/embeddings",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
)
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
t.Run("basic config with defaults", func(t *testing.T) {
|
||||
var cfg Config
|
||||
jsonData := []byte(`{
|
||||
"modelMapping": {
|
||||
"gpt-3.5-turbo": "gpt-4",
|
||||
"gpt-4*": "gpt-4o-mini",
|
||||
"*": "gpt-4o"
|
||||
}
|
||||
}`)
|
||||
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// default modelKey
|
||||
require.Equal(t, "model", cfg.modelKey)
|
||||
// exact mapping
|
||||
require.Equal(t, "gpt-4", cfg.exactModelMapping["gpt-3.5-turbo"])
|
||||
// prefix mapping
|
||||
require.Len(t, cfg.prefixModelMapping, 1)
|
||||
require.Equal(t, "gpt-4", cfg.prefixModelMapping[0].Prefix)
|
||||
// default model
|
||||
require.Equal(t, "gpt-4o", cfg.defaultModel)
|
||||
// default enabled path suffixes
|
||||
require.Contains(t, cfg.enableOnPathSuffix, "/completions")
|
||||
require.Contains(t, cfg.enableOnPathSuffix, "/embeddings")
|
||||
})
|
||||
|
||||
t.Run("custom modelKey and enableOnPathSuffix", func(t *testing.T) {
|
||||
var cfg Config
|
||||
jsonData := []byte(`{
|
||||
"modelKey": "request.model",
|
||||
"modelMapping": {
|
||||
"gpt-3.5-turbo": "gpt-4",
|
||||
"gpt-3.5*": "gpt-4-mini"
|
||||
},
|
||||
"enableOnPathSuffix": ["/v1/chat/completions", "/v1/embeddings"]
|
||||
}`)
|
||||
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "request.model", cfg.modelKey)
|
||||
require.Equal(t, "gpt-4", cfg.exactModelMapping["gpt-3.5-turbo"])
|
||||
require.Len(t, cfg.prefixModelMapping, 1)
|
||||
require.Equal(t, "gpt-3.5", cfg.prefixModelMapping[0].Prefix)
|
||||
require.Equal(t, "gpt-4-mini", cfg.prefixModelMapping[0].Target)
|
||||
require.Equal(t, 2, len(cfg.enableOnPathSuffix))
|
||||
require.Contains(t, cfg.enableOnPathSuffix, "/v1/chat/completions")
|
||||
require.Contains(t, cfg.enableOnPathSuffix, "/v1/embeddings")
|
||||
})
|
||||
|
||||
t.Run("modelMapping must be object", func(t *testing.T) {
|
||||
var cfg Config
|
||||
jsonData := []byte(`{
|
||||
"modelMapping": "invalid"
|
||||
}`)
|
||||
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("enableOnPathSuffix must be array", func(t *testing.T) {
|
||||
var cfg Config
|
||||
jsonData := []byte(`{
|
||||
"enableOnPathSuffix": "not-array"
|
||||
}`)
|
||||
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
|
||||
require.Error(t, err)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestHeaders(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("skip when path not matched", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
originalHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/other"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"content-length", "123"},
|
||||
}
|
||||
action := host.CallOnHttpRequestHeaders(originalHeaders)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
// content-length should still exist because path is not enabled
|
||||
foundContentLength := false
|
||||
for _, h := range newHeaders {
|
||||
if strings.ToLower(h[0]) == "content-length" {
|
||||
foundContentLength = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundContentLength)
|
||||
})
|
||||
|
||||
t.Run("process when path and content-type match", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
originalHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"content-length", "123"},
|
||||
}
|
||||
action := host.CallOnHttpRequestHeaders(originalHeaders)
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
// content-length should be removed
|
||||
for _, h := range newHeaders {
|
||||
require.NotEqual(t, strings.ToLower(h[0]), "content-length")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestBody_ModelMapping(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("exact mapping", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
origBody := []byte(`{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(origBody)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
require.Equal(t, "gpt-4", gjson.GetBytes(processed, "model").String())
|
||||
})
|
||||
|
||||
t.Run("default model when key missing", func(t *testing.T) {
|
||||
// use customConfig where default model is set with "*"
|
||||
host, status := test.NewTestHost(customConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
origBody := []byte(`{
|
||||
"request": {
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(origBody)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
// default model should be set at request.model
|
||||
require.Equal(t, "gpt-4o", gjson.GetBytes(processed, "request.model").String())
|
||||
})
|
||||
|
||||
t.Run("prefix mapping takes effect", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(customConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
origBody := []byte(`{
|
||||
"request": {
|
||||
"model": "gpt-3.5-turbo-16k",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(origBody)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
require.Equal(t, "gpt-4-mini", gjson.GetBytes(processed, "request.model").String())
|
||||
})
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user