Replace model-router and model-mapper with Go implementation (#3317)

This commit is contained in:
rinfx
2026-01-13 20:14:29 +08:00
committed by jingze
parent 032a69556f
commit e23ab3ca7c
14 changed files with 1418 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
build-go:
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o main.wasm main.go

View File

@@ -0,0 +1,98 @@
## 功能说明
`model-router`插件实现了基于LLM协议中的model参数路由的功能
## 配置字段
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
| `modelKey` | string | 选填 | model | 请求body中model参数的位置 |
| `addProviderHeader` | string | 选填 | - | 从model参数中解析出的provider名字放到哪个请求header中 |
| `modelToHeader` | string | 选填 | - | 直接将model参数放到哪个请求header中 |
| `enableOnPathSuffix` | array of string | 选填 | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | 只对这些特定路径后缀的请求生效,可以配置为 "*" 以匹配所有路径 |
## 运行属性
插件执行阶段:认证阶段
插件执行优先级900
## 效果说明
### 基于 model 参数进行路由
需要做如下配置:
```yaml
modelToHeader: x-higress-llm-model
```
插件会将请求中 model 参数提取出来,设置到 x-higress-llm-model 这个请求 header 中,用于后续路由,举例来说,原生的 LLM 请求体是:
```json
{
"model": "qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "higress项目主仓库的github地址是什么"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
经过这个插件后,将添加下面这个请求头(可以用于路由匹配)
x-higress-llm-model: qwen-long
### 提取 model 参数中的 provider 字段用于路由
> 注意这种模式需要客户端在 model 参数中通过`/`分隔的方式,来指定 provider
需要做如下配置:
```yaml
addProviderHeader: x-higress-llm-provider
```
插件会将请求中 model 参数的 provider 部分(如果有)提取出来,设置到 x-higress-llm-provider 这个请求 header 中,用于后续路由,并将 model 参数重写为模型名称部分。举例来说,原生的 LLM 请求体是:
```json
{
"model": "dashscope/qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "higress项目主仓库的github地址是什么"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
经过这个插件后,将添加下面这个请求头(可以用于路由匹配)
x-higress-llm-provider: dashscope
原始的 LLM 请求体将被改成:
```json
{
"model": "qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "higress项目主仓库的github地址是什么"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```

View File

@@ -0,0 +1,97 @@
## Feature Description
The `model-router` plugin implements routing functionality based on the model parameter in LLM protocols.
## Configuration Fields
| Name | Data Type | Requirement | Default Value | Description |
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
| `modelKey` | string | Optional | model | Location of the model parameter in the request body |
| `addProviderHeader` | string | Optional | - | Which request header to add the provider name parsed from the model parameter |
| `modelToHeader` | string | Optional | - | Which request header to directly add the model parameter to |
| `enableOnPathSuffix` | array of string | Optional | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | Only effective for requests with these specific path suffixes, can be configured as "*" to match all paths |
## Runtime Properties
Plugin execution phase: Authentication phase
Plugin execution priority: 900
## Effect Description
### Routing Based on Model Parameter
The following configuration is needed:
```yaml
modelToHeader: x-higress-llm-model
```
The plugin extracts the model parameter from the request and sets it to the x-higress-llm-model request header for subsequent routing. For example, the original LLM request body is:
```json
{
"model": "qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address of the Higress project's main repository?"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
After processing by this plugin, the following request header will be added (can be used for route matching):
x-higress-llm-model: qwen-long
### Extracting Provider Field from Model Parameter for Routing
> Note that this mode requires the client to specify the provider in the model parameter using the `/` delimiter
The following configuration is needed:
```yaml
addProviderHeader: x-higress-llm-provider
```
The plugin extracts the provider part (if any) from the model parameter in the request, sets it to the x-higress-llm-provider request header for subsequent routing, and rewrites the model parameter to only contain the model name part. For example, the original LLM request body is:
```json
{
"model": "dashscope/qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address of the Higress project's main repository?"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
After processing by this plugin, the following request header will be added (can be used for route matching):
x-higress-llm-provider: dashscope
The original LLM request body will be changed to:
```json
{
"model": "qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address of the Higress project's main repository?"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}

View File

@@ -0,0 +1,24 @@
module model-router
go 1.24.1
toolchain go1.24.7
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -0,0 +1,30 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c h1:DdVPyaMHSYBqO5jwB9Wl3PqsBGIf4u29BHMI0uIVB1Y=
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,259 @@
package main
import (
"bytes"
"io"
"mime"
"mime/multipart"
"net/http"
"net/textproto"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
DefaultMaxBodyBytes = 100 * 1024 * 1024 // 100MB
)
func main() {}
func init() {
wrapper.SetCtx(
"model-router",
wrapper.ParseConfig(parseConfig),
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
wrapper.ProcessRequestBody(onHttpRequestBody),
wrapper.WithRebuildAfterRequests[ModelRouterConfig](1000),
wrapper.WithRebuildMaxMemBytes[ModelRouterConfig](200*1024*1024),
)
}
type ModelRouterConfig struct {
modelKey string
addProviderHeader string
modelToHeader string
enableOnPathSuffix []string
}
func parseConfig(json gjson.Result, config *ModelRouterConfig) error {
config.modelKey = json.Get("modelKey").String()
if config.modelKey == "" {
config.modelKey = "model"
}
config.addProviderHeader = json.Get("addProviderHeader").String()
config.modelToHeader = json.Get("modelToHeader").String()
enableOnPathSuffix := json.Get("enableOnPathSuffix")
if enableOnPathSuffix.Exists() && enableOnPathSuffix.IsArray() {
for _, item := range enableOnPathSuffix.Array() {
config.enableOnPathSuffix = append(config.enableOnPathSuffix, item.String())
}
} else {
// Default suffixes if not provided
config.enableOnPathSuffix = []string{
"/completions",
"/embeddings",
"/images/generations",
"/audio/speech",
"/fine_tuning/jobs",
"/moderations",
"/image-synthesis",
"/video-synthesis",
"/rerank",
"/messages",
}
}
return nil
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ModelRouterConfig) types.Action {
path, err := proxywasm.GetHttpRequestHeader(":path")
if err != nil {
return types.ActionContinue
}
// Remove query parameters for suffix check
if idx := strings.Index(path, "?"); idx != -1 {
path = path[:idx]
}
enable := false
for _, suffix := range config.enableOnPathSuffix {
if suffix == "*" || strings.HasSuffix(path, suffix) {
enable = true
break
}
}
if !enable {
ctx.DontReadRequestBody()
return types.ActionContinue
}
if !ctx.HasRequestBody() {
return types.ActionContinue
}
// Prepare for body processing
proxywasm.RemoveHttpRequestHeader("content-length")
// 100MB buffer limit
ctx.SetRequestBodyBufferLimit(DefaultMaxBodyBytes)
return types.HeaderStopIteration
}
func onHttpRequestBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []byte) types.Action {
contentType, err := proxywasm.GetHttpRequestHeader("content-type")
if err != nil {
return types.ActionContinue
}
if strings.Contains(contentType, "application/json") {
return handleJsonBody(ctx, config, body)
} else if strings.Contains(contentType, "multipart/form-data") {
return handleMultipartBody(ctx, config, body, contentType)
}
return types.ActionContinue
}
func handleJsonBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []byte) types.Action {
modelValue := gjson.GetBytes(body, config.modelKey).String()
if modelValue == "" {
return types.ActionContinue
}
if config.modelToHeader != "" {
_ = proxywasm.ReplaceHttpRequestHeader(config.modelToHeader, modelValue)
}
if config.addProviderHeader != "" {
parts := strings.SplitN(modelValue, "/", 2)
if len(parts) == 2 {
provider := parts[0]
model := parts[1]
_ = proxywasm.ReplaceHttpRequestHeader(config.addProviderHeader, provider)
newBody, err := sjson.SetBytes(body, config.modelKey, model)
if err != nil {
log.Errorf("failed to update model in json body: %v", err)
return types.ActionContinue
}
_ = proxywasm.ReplaceHttpRequestBody(newBody)
log.Debugf("model route to provider: %s, model: %s", provider, model)
} else {
log.Debugf("model route to provider not work, model: %s", modelValue)
}
}
return types.ActionContinue
}
func handleMultipartBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []byte, contentType string) types.Action {
_, params, err := mime.ParseMediaType(contentType)
if err != nil {
log.Errorf("failed to parse content type: %v", err)
return types.ActionContinue
}
boundary, ok := params["boundary"]
if !ok {
log.Errorf("no boundary in content type")
return types.ActionContinue
}
reader := multipart.NewReader(bytes.NewReader(body), boundary)
var newBody bytes.Buffer
writer := multipart.NewWriter(&newBody)
writer.SetBoundary(boundary)
modified := false
for {
part, err := reader.NextPart()
if err == io.EOF {
break
}
if err != nil {
log.Errorf("failed to read multipart part: %v", err)
return types.ActionContinue
}
// Read part content
partContent, err := io.ReadAll(part)
if err != nil {
log.Errorf("failed to read part content: %v", err)
return types.ActionContinue
}
formName := part.FormName()
if formName == config.modelKey {
modelValue := string(partContent)
if config.modelToHeader != "" {
_ = proxywasm.ReplaceHttpRequestHeader(config.modelToHeader, modelValue)
}
if config.addProviderHeader != "" {
parts := strings.SplitN(modelValue, "/", 2)
if len(parts) == 2 {
provider := parts[0]
model := parts[1]
_ = proxywasm.ReplaceHttpRequestHeader(config.addProviderHeader, provider)
// Write modified part
h := make(http.Header)
for k, v := range part.Header {
h[k] = v
}
pw, err := writer.CreatePart(textproto.MIMEHeader(h))
if err != nil {
log.Errorf("failed to create part: %v", err)
return types.ActionContinue
}
_, err = pw.Write([]byte(model))
if err != nil {
log.Errorf("failed to write part content: %v", err)
return types.ActionContinue
}
modified = true
log.Debugf("model route to provider: %s, model: %s", provider, model)
continue
} else {
log.Debugf("model route to provider not work, model: %s", modelValue)
}
}
}
// Write original part
h := make(http.Header)
for k, v := range part.Header {
h[k] = v
}
pw, err := writer.CreatePart(textproto.MIMEHeader(h))
if err != nil {
log.Errorf("failed to create part: %v", err)
return types.ActionContinue
}
_, err = pw.Write(partContent)
if err != nil {
log.Errorf("failed to write part content: %v", err)
return types.ActionContinue
}
}
writer.Close()
if modified {
_ = proxywasm.ReplaceHttpRequestBody(newBody.Bytes())
}
return types.ActionContinue
}

View File

@@ -0,0 +1,288 @@
package main
import (
"bytes"
"encoding/json"
"io"
"mime/multipart"
"strings"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// Basic configs for wasm test host
var (
basicConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"modelKey": "model",
"addProviderHeader": "x-provider",
"modelToHeader": "x-model",
"enableOnPathSuffix": []string{
"/v1/chat/completions",
},
})
return data
}()
defaultSuffixConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"modelKey": "model",
"addProviderHeader": "x-provider",
"modelToHeader": "x-model",
})
return data
}()
)
func getHeader(headers [][2]string, key string) (string, bool) {
for _, h := range headers {
if strings.EqualFold(h[0], key) {
return h[1], true
}
}
return "", false
}
func TestParseConfig(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
t.Run("basic config with defaults", func(t *testing.T) {
var cfg ModelRouterConfig
err := parseConfig(gjson.ParseBytes(defaultSuffixConfig), &cfg)
require.NoError(t, err)
// default modelKey
require.Equal(t, "model", cfg.modelKey)
// headers
require.Equal(t, "x-provider", cfg.addProviderHeader)
require.Equal(t, "x-model", cfg.modelToHeader)
// default enabled path suffixes should contain common openai paths
require.Contains(t, cfg.enableOnPathSuffix, "/completions")
require.Contains(t, cfg.enableOnPathSuffix, "/embeddings")
})
t.Run("custom enableOnPathSuffix", func(t *testing.T) {
jsonData := []byte(`{
"modelKey": "my_model",
"addProviderHeader": "x-prov",
"modelToHeader": "x-mod",
"enableOnPathSuffix": ["/foo", "/bar"]
}`)
var cfg ModelRouterConfig
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
require.NoError(t, err)
require.Equal(t, "my_model", cfg.modelKey)
require.Equal(t, "x-prov", cfg.addProviderHeader)
require.Equal(t, "x-mod", cfg.modelToHeader)
require.Equal(t, []string{"/foo", "/bar"}, cfg.enableOnPathSuffix)
})
})
}
func TestOnHttpRequestHeaders(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
t.Run("skip when path not matched", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
originalHeaders := [][2]string{
{":authority", "example.com"},
{":path", "/v1/other"},
{":method", "POST"},
{"content-type", "application/json"},
{"content-length", "123"},
}
action := host.CallOnHttpRequestHeaders(originalHeaders)
require.Equal(t, types.ActionContinue, action)
newHeaders := host.GetRequestHeaders()
_, found := getHeader(newHeaders, "content-length")
require.True(t, found, "content-length should be kept when path not enabled")
})
t.Run("process when path and content-type match", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
originalHeaders := [][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"content-type", "application/json"},
{"content-length", "123"},
}
action := host.CallOnHttpRequestHeaders(originalHeaders)
require.Equal(t, types.HeaderStopIteration, action)
newHeaders := host.GetRequestHeaders()
_, found := getHeader(newHeaders, "content-length")
require.False(t, found, "content-length should be removed when buffering body")
})
t.Run("do not process for unsupported content-type", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
originalHeaders := [][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"content-type", "text/plain"},
{"content-length", "123"},
}
action := host.CallOnHttpRequestHeaders(originalHeaders)
require.Equal(t, types.HeaderStopIteration, action)
newHeaders := host.GetRequestHeaders()
_, found := getHeader(newHeaders, "content-length")
require.False(t, found, "content-length should not be removed for unsupported content-type")
})
})
}
func TestOnHttpRequestBody_JSON(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
t.Run("set headers and rewrite model when provider/model format", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"content-type", "application/json"},
})
origBody := []byte(`{
"model": "openai/gpt-4o",
"messages": [{"role": "user", "content": "hello"}]
}`)
action := host.CallOnHttpRequestBody(origBody)
require.Equal(t, types.ActionContinue, action)
processed := host.GetRequestBody()
require.NotNil(t, processed)
// model should be rewritten to only the model part
require.Equal(t, "gpt-4o", gjson.GetBytes(processed, "model").String())
headers := host.GetRequestHeaders()
hv, found := getHeader(headers, "x-model")
require.True(t, found)
require.Equal(t, "openai/gpt-4o", hv)
pv, found := getHeader(headers, "x-provider")
require.True(t, found)
require.Equal(t, "openai", pv)
})
t.Run("no change when model not provided", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"content-type", "application/json"},
})
origBody := []byte(`{
"messages": [{"role": "user", "content": "hello"}]
}`)
action := host.CallOnHttpRequestBody(origBody)
require.Equal(t, types.ActionContinue, action)
processed := host.GetRequestBody()
// body should remain nil or unchanged as plugin does nothing
if processed != nil {
require.JSONEq(t, string(origBody), string(processed))
}
_, found := getHeader(host.GetRequestHeaders(), "x-provider")
require.False(t, found)
})
})
}
func TestOnHttpRequestBody_Multipart(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
// model field
modelWriter, err := writer.CreateFormField("model")
require.NoError(t, err)
_, err = modelWriter.Write([]byte("openai/gpt-4o"))
require.NoError(t, err)
// another field to ensure others are preserved
fileWriter, err := writer.CreateFormField("prompt")
require.NoError(t, err)
_, err = fileWriter.Write([]byte("hello"))
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
contentType := "multipart/form-data; boundary=" + writer.Boundary()
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"content-type", contentType},
})
action := host.CallOnHttpRequestBody(buf.Bytes())
require.Equal(t, types.ActionContinue, action)
processed := host.GetRequestBody()
require.NotNil(t, processed)
// Parse multipart body again to verify fields
reader := multipart.NewReader(bytes.NewReader(processed), writer.Boundary())
foundModel := false
foundPrompt := false
for {
part, err := reader.NextPart()
if err != nil {
break
}
name := part.FormName()
data, err := io.ReadAll(part)
require.NoError(t, err)
switch name {
case "model":
foundModel = true
require.Equal(t, "gpt-4o", string(data))
case "prompt":
foundPrompt = true
require.Equal(t, "hello", string(data))
}
}
require.True(t, foundModel)
require.True(t, foundPrompt)
headers := host.GetRequestHeaders()
hv, found := getHeader(headers, "x-model")
require.True(t, found)
require.Equal(t, "openai/gpt-4o", hv)
pv, found := getHeader(headers, "x-provider")
require.True(t, found)
require.Equal(t, "openai", pv)
})
}