mirror of
https://github.com/alibaba/higress.git
synced 2026-06-08 20:27:31 +08:00
Replace model-router and model-mapper with Go implementation (#3317)
This commit is contained in:
288
plugins/wasm-go/extensions/model-router/main_test.go
Normal file
288
plugins/wasm-go/extensions/model-router/main_test.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// Basic configs for wasm test host
|
||||
var (
|
||||
basicConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"modelKey": "model",
|
||||
"addProviderHeader": "x-provider",
|
||||
"modelToHeader": "x-model",
|
||||
"enableOnPathSuffix": []string{
|
||||
"/v1/chat/completions",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
defaultSuffixConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"modelKey": "model",
|
||||
"addProviderHeader": "x-provider",
|
||||
"modelToHeader": "x-model",
|
||||
})
|
||||
return data
|
||||
}()
|
||||
)
|
||||
|
||||
func getHeader(headers [][2]string, key string) (string, bool) {
|
||||
for _, h := range headers {
|
||||
if strings.EqualFold(h[0], key) {
|
||||
return h[1], true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
t.Run("basic config with defaults", func(t *testing.T) {
|
||||
var cfg ModelRouterConfig
|
||||
err := parseConfig(gjson.ParseBytes(defaultSuffixConfig), &cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// default modelKey
|
||||
require.Equal(t, "model", cfg.modelKey)
|
||||
// headers
|
||||
require.Equal(t, "x-provider", cfg.addProviderHeader)
|
||||
require.Equal(t, "x-model", cfg.modelToHeader)
|
||||
// default enabled path suffixes should contain common openai paths
|
||||
require.Contains(t, cfg.enableOnPathSuffix, "/completions")
|
||||
require.Contains(t, cfg.enableOnPathSuffix, "/embeddings")
|
||||
})
|
||||
|
||||
t.Run("custom enableOnPathSuffix", func(t *testing.T) {
|
||||
jsonData := []byte(`{
|
||||
"modelKey": "my_model",
|
||||
"addProviderHeader": "x-prov",
|
||||
"modelToHeader": "x-mod",
|
||||
"enableOnPathSuffix": ["/foo", "/bar"]
|
||||
}`)
|
||||
var cfg ModelRouterConfig
|
||||
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "my_model", cfg.modelKey)
|
||||
require.Equal(t, "x-prov", cfg.addProviderHeader)
|
||||
require.Equal(t, "x-mod", cfg.modelToHeader)
|
||||
require.Equal(t, []string{"/foo", "/bar"}, cfg.enableOnPathSuffix)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestHeaders(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("skip when path not matched", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
originalHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/other"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"content-length", "123"},
|
||||
}
|
||||
action := host.CallOnHttpRequestHeaders(originalHeaders)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
_, found := getHeader(newHeaders, "content-length")
|
||||
require.True(t, found, "content-length should be kept when path not enabled")
|
||||
})
|
||||
|
||||
t.Run("process when path and content-type match", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
originalHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"content-length", "123"},
|
||||
}
|
||||
action := host.CallOnHttpRequestHeaders(originalHeaders)
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
_, found := getHeader(newHeaders, "content-length")
|
||||
require.False(t, found, "content-length should be removed when buffering body")
|
||||
})
|
||||
|
||||
t.Run("do not process for unsupported content-type", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
originalHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "text/plain"},
|
||||
{"content-length", "123"},
|
||||
}
|
||||
action := host.CallOnHttpRequestHeaders(originalHeaders)
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
_, found := getHeader(newHeaders, "content-length")
|
||||
require.False(t, found, "content-length should not be removed for unsupported content-type")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestBody_JSON(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("set headers and rewrite model when provider/model format", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
origBody := []byte(`{
|
||||
"model": "openai/gpt-4o",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(origBody)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
// model should be rewritten to only the model part
|
||||
require.Equal(t, "gpt-4o", gjson.GetBytes(processed, "model").String())
|
||||
|
||||
headers := host.GetRequestHeaders()
|
||||
hv, found := getHeader(headers, "x-model")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "openai/gpt-4o", hv)
|
||||
pv, found := getHeader(headers, "x-provider")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "openai", pv)
|
||||
})
|
||||
|
||||
t.Run("no change when model not provided", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
origBody := []byte(`{
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(origBody)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processed := host.GetRequestBody()
|
||||
// body should remain nil or unchanged as plugin does nothing
|
||||
if processed != nil {
|
||||
require.JSONEq(t, string(origBody), string(processed))
|
||||
}
|
||||
_, found := getHeader(host.GetRequestHeaders(), "x-provider")
|
||||
require.False(t, found)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestBody_Multipart(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
|
||||
// model field
|
||||
modelWriter, err := writer.CreateFormField("model")
|
||||
require.NoError(t, err)
|
||||
_, err = modelWriter.Write([]byte("openai/gpt-4o"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// another field to ensure others are preserved
|
||||
fileWriter, err := writer.CreateFormField("prompt")
|
||||
require.NoError(t, err)
|
||||
_, err = fileWriter.Write([]byte("hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
contentType := "multipart/form-data; boundary=" + writer.Boundary()
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", contentType},
|
||||
})
|
||||
|
||||
action := host.CallOnHttpRequestBody(buf.Bytes())
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
|
||||
// Parse multipart body again to verify fields
|
||||
reader := multipart.NewReader(bytes.NewReader(processed), writer.Boundary())
|
||||
|
||||
foundModel := false
|
||||
foundPrompt := false
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
name := part.FormName()
|
||||
data, err := io.ReadAll(part)
|
||||
require.NoError(t, err)
|
||||
|
||||
switch name {
|
||||
case "model":
|
||||
foundModel = true
|
||||
require.Equal(t, "gpt-4o", string(data))
|
||||
case "prompt":
|
||||
foundPrompt = true
|
||||
require.Equal(t, "hello", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, foundModel)
|
||||
require.True(t, foundPrompt)
|
||||
|
||||
headers := host.GetRequestHeaders()
|
||||
hv, found := getHeader(headers, "x-model")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "openai/gpt-4o", hv)
|
||||
pv, found := getHeader(headers, "x-provider")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "openai", pv)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user