// Mirror of https://github.com/alibaba/higress.git, synced 2026-03-04 00:20:50 +08:00 (693 lines, 21 KiB, Go).
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"io"
|
|
"mime/multipart"
|
|
"regexp"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
|
"github.com/higress-group/wasm-go/pkg/test"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
// Shared plugin configurations for the wasm test host.
var (
	// basicConfig restricts the plugin to the chat-completions path suffix and
	// configures both the provider header and the model header.
	basicConfig = func() json.RawMessage {
		data, err := json.Marshal(map[string]interface{}{
			"modelKey":          "model",
			"addProviderHeader": "x-provider",
			"modelToHeader":     "x-model",
			"enableOnPathSuffix": []string{
				"/v1/chat/completions",
			},
		})
		if err != nil {
			// Fail loudly at init rather than handing every test an empty config.
			panic("marshal basicConfig: " + err.Error())
		}
		return data
	}()

	// defaultSuffixConfig is basicConfig without enableOnPathSuffix; it is the
	// input for tests that inspect the parser's default path suffixes.
	defaultSuffixConfig = func() json.RawMessage {
		data, err := json.Marshal(map[string]interface{}{
			"modelKey":          "model",
			"addProviderHeader": "x-provider",
			"modelToHeader":     "x-model",
		})
		if err != nil {
			// Fail loudly at init rather than handing every test an empty config.
			panic("marshal defaultSuffixConfig: " + err.Error())
		}
		return data
	}()
)
|
|
|
|
// getHeader looks up key in headers, comparing names case-insensitively.
// It returns the value of the first matching entry and true, or an empty
// string and false when no entry matches.
func getHeader(headers [][2]string, key string) (string, bool) {
	for i := range headers {
		name, value := headers[i][0], headers[i][1]
		if !strings.EqualFold(name, key) {
			continue
		}
		return value, true
	}
	return "", false
}
|
|
|
|
func TestParseConfig(t *testing.T) {
|
|
test.RunGoTest(t, func(t *testing.T) {
|
|
t.Run("basic config with defaults", func(t *testing.T) {
|
|
var cfg ModelRouterConfig
|
|
err := parseConfig(gjson.ParseBytes(defaultSuffixConfig), &cfg)
|
|
require.NoError(t, err)
|
|
|
|
// default modelKey
|
|
require.Equal(t, "model", cfg.modelKey)
|
|
// headers
|
|
require.Equal(t, "x-provider", cfg.addProviderHeader)
|
|
require.Equal(t, "x-model", cfg.modelToHeader)
|
|
// default enabled path suffixes should contain common openai paths
|
|
require.Contains(t, cfg.enableOnPathSuffix, "/completions")
|
|
require.Contains(t, cfg.enableOnPathSuffix, "/embeddings")
|
|
})
|
|
|
|
t.Run("custom enableOnPathSuffix", func(t *testing.T) {
|
|
jsonData := []byte(`{
|
|
"modelKey": "my_model",
|
|
"addProviderHeader": "x-prov",
|
|
"modelToHeader": "x-mod",
|
|
"enableOnPathSuffix": ["/foo", "/bar"]
|
|
}`)
|
|
var cfg ModelRouterConfig
|
|
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, "my_model", cfg.modelKey)
|
|
require.Equal(t, "x-prov", cfg.addProviderHeader)
|
|
require.Equal(t, "x-mod", cfg.modelToHeader)
|
|
require.Equal(t, []string{"/foo", "/bar"}, cfg.enableOnPathSuffix)
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestOnHttpRequestHeaders(t *testing.T) {
|
|
test.RunTest(t, func(t *testing.T) {
|
|
t.Run("skip when path not matched", func(t *testing.T) {
|
|
host, status := test.NewTestHost(basicConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
originalHeaders := [][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/other"},
|
|
{":method", "POST"},
|
|
{"content-type", "application/json"},
|
|
{"content-length", "123"},
|
|
}
|
|
action := host.CallOnHttpRequestHeaders(originalHeaders)
|
|
require.Equal(t, types.ActionContinue, action)
|
|
|
|
newHeaders := host.GetRequestHeaders()
|
|
_, found := getHeader(newHeaders, "content-length")
|
|
require.True(t, found, "content-length should be kept when path not enabled")
|
|
})
|
|
|
|
t.Run("process when path and content-type match", func(t *testing.T) {
|
|
host, status := test.NewTestHost(basicConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
originalHeaders := [][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/chat/completions"},
|
|
{":method", "POST"},
|
|
{"content-type", "application/json"},
|
|
{"content-length", "123"},
|
|
}
|
|
action := host.CallOnHttpRequestHeaders(originalHeaders)
|
|
require.Equal(t, types.HeaderStopIteration, action)
|
|
|
|
newHeaders := host.GetRequestHeaders()
|
|
_, found := getHeader(newHeaders, "content-length")
|
|
require.False(t, found, "content-length should be removed when buffering body")
|
|
})
|
|
|
|
t.Run("do not process for unsupported content-type", func(t *testing.T) {
|
|
host, status := test.NewTestHost(basicConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
originalHeaders := [][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/chat/completions"},
|
|
{":method", "POST"},
|
|
{"content-type", "text/plain"},
|
|
{"content-length", "123"},
|
|
}
|
|
action := host.CallOnHttpRequestHeaders(originalHeaders)
|
|
require.Equal(t, types.HeaderStopIteration, action)
|
|
|
|
newHeaders := host.GetRequestHeaders()
|
|
_, found := getHeader(newHeaders, "content-length")
|
|
require.False(t, found, "content-length should not be removed for unsupported content-type")
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestOnHttpRequestBody_JSON(t *testing.T) {
|
|
test.RunTest(t, func(t *testing.T) {
|
|
t.Run("set headers and rewrite model when provider/model format", func(t *testing.T) {
|
|
host, status := test.NewTestHost(basicConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
host.CallOnHttpRequestHeaders([][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/chat/completions"},
|
|
{":method", "POST"},
|
|
{"content-type", "application/json"},
|
|
})
|
|
|
|
origBody := []byte(`{
|
|
"model": "openai/gpt-4o",
|
|
"messages": [{"role": "user", "content": "hello"}]
|
|
}`)
|
|
action := host.CallOnHttpRequestBody(origBody)
|
|
require.Equal(t, types.ActionContinue, action)
|
|
|
|
processed := host.GetRequestBody()
|
|
require.NotNil(t, processed)
|
|
// model should be rewritten to only the model part
|
|
require.Equal(t, "gpt-4o", gjson.GetBytes(processed, "model").String())
|
|
|
|
headers := host.GetRequestHeaders()
|
|
hv, found := getHeader(headers, "x-model")
|
|
require.True(t, found)
|
|
require.Equal(t, "openai/gpt-4o", hv)
|
|
pv, found := getHeader(headers, "x-provider")
|
|
require.True(t, found)
|
|
require.Equal(t, "openai", pv)
|
|
})
|
|
|
|
t.Run("no change when model not provided", func(t *testing.T) {
|
|
host, status := test.NewTestHost(basicConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
host.CallOnHttpRequestHeaders([][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/chat/completions"},
|
|
{":method", "POST"},
|
|
{"content-type", "application/json"},
|
|
})
|
|
|
|
origBody := []byte(`{
|
|
"messages": [{"role": "user", "content": "hello"}]
|
|
}`)
|
|
action := host.CallOnHttpRequestBody(origBody)
|
|
require.Equal(t, types.ActionContinue, action)
|
|
|
|
processed := host.GetRequestBody()
|
|
// body should remain nil or unchanged as plugin does nothing
|
|
if processed != nil {
|
|
require.JSONEq(t, string(origBody), string(processed))
|
|
}
|
|
_, found := getHeader(host.GetRequestHeaders(), "x-provider")
|
|
require.False(t, found)
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestOnHttpRequestBody_Multipart(t *testing.T) {
|
|
test.RunTest(t, func(t *testing.T) {
|
|
host, status := test.NewTestHost(basicConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
var buf bytes.Buffer
|
|
writer := multipart.NewWriter(&buf)
|
|
|
|
// model field
|
|
modelWriter, err := writer.CreateFormField("model")
|
|
require.NoError(t, err)
|
|
_, err = modelWriter.Write([]byte("openai/gpt-4o"))
|
|
require.NoError(t, err)
|
|
|
|
// another field to ensure others are preserved
|
|
fileWriter, err := writer.CreateFormField("prompt")
|
|
require.NoError(t, err)
|
|
_, err = fileWriter.Write([]byte("hello"))
|
|
require.NoError(t, err)
|
|
|
|
err = writer.Close()
|
|
require.NoError(t, err)
|
|
|
|
contentType := "multipart/form-data; boundary=" + writer.Boundary()
|
|
|
|
host.CallOnHttpRequestHeaders([][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/chat/completions"},
|
|
{":method", "POST"},
|
|
{"content-type", contentType},
|
|
})
|
|
|
|
action := host.CallOnHttpRequestBody(buf.Bytes())
|
|
require.Equal(t, types.ActionContinue, action)
|
|
|
|
processed := host.GetRequestBody()
|
|
require.NotNil(t, processed)
|
|
|
|
// Parse multipart body again to verify fields
|
|
reader := multipart.NewReader(bytes.NewReader(processed), writer.Boundary())
|
|
|
|
foundModel := false
|
|
foundPrompt := false
|
|
for {
|
|
part, err := reader.NextPart()
|
|
if err != nil {
|
|
break
|
|
}
|
|
name := part.FormName()
|
|
data, err := io.ReadAll(part)
|
|
require.NoError(t, err)
|
|
|
|
switch name {
|
|
case "model":
|
|
foundModel = true
|
|
require.Equal(t, "gpt-4o", string(data))
|
|
case "prompt":
|
|
foundPrompt = true
|
|
require.Equal(t, "hello", string(data))
|
|
}
|
|
}
|
|
|
|
require.True(t, foundModel)
|
|
require.True(t, foundPrompt)
|
|
|
|
headers := host.GetRequestHeaders()
|
|
hv, found := getHeader(headers, "x-model")
|
|
require.True(t, found)
|
|
require.Equal(t, "openai/gpt-4o", hv)
|
|
pv, found := getHeader(headers, "x-provider")
|
|
require.True(t, found)
|
|
require.Equal(t, "openai", pv)
|
|
})
|
|
}
|
|
|
|
// autoRoutingConfig enables auto routing on the chat-completions path with
// four regex rules and "qwen-turbo" as the default model.
var autoRoutingConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]interface{}{
		"modelKey":      "model",
		"modelToHeader": "x-model",
		"enableOnPathSuffix": []string{
			"/v1/chat/completions",
		},
		"autoRouting": map[string]interface{}{
			"enable":       true,
			"defaultModel": "qwen-turbo",
			"rules": []map[string]string{
				{"pattern": "(?i)(画|绘|生成图|图片|image|draw|paint)", "model": "qwen-vl-max"},
				{"pattern": "(?i)(代码|编程|code|program|function|debug)", "model": "qwen-coder"},
				{"pattern": "(?i)(翻译|translate|translation)", "model": "qwen-turbo"},
				{"pattern": "(?i)(数学|计算|math|calculate)", "model": "qwen-math"},
			},
		},
	})
	if err != nil {
		// Fail loudly at init rather than handing every test an empty config.
		panic("marshal autoRoutingConfig: " + err.Error())
	}
	return data
}()
|
|
|
|
// autoRoutingNoDefaultConfig enables auto routing with a single rule and no
// defaultModel, for tests where no rule matches and nothing can be used as a fallback.
var autoRoutingNoDefaultConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]interface{}{
		"modelKey":      "model",
		"modelToHeader": "x-model",
		"enableOnPathSuffix": []string{
			"/v1/chat/completions",
		},
		"autoRouting": map[string]interface{}{
			"enable": true,
			"rules": []map[string]string{
				{"pattern": "(?i)(画|绘)", "model": "qwen-vl-max"},
			},
		},
	})
	if err != nil {
		// Fail loudly at init rather than handing every test an empty config.
		panic("marshal autoRoutingNoDefaultConfig: " + err.Error())
	}
	return data
}()
|
|
|
|
func TestParseConfigAutoRouting(t *testing.T) {
|
|
test.RunGoTest(t, func(t *testing.T) {
|
|
t.Run("parse auto routing config", func(t *testing.T) {
|
|
var cfg ModelRouterConfig
|
|
err := parseConfig(gjson.ParseBytes(autoRoutingConfig), &cfg)
|
|
require.NoError(t, err)
|
|
|
|
require.True(t, cfg.enableAutoRouting)
|
|
require.Equal(t, "qwen-turbo", cfg.defaultModel)
|
|
require.Len(t, cfg.autoRoutingRules, 4)
|
|
|
|
// Verify first rule
|
|
require.Equal(t, "qwen-vl-max", cfg.autoRoutingRules[0].Model)
|
|
require.NotNil(t, cfg.autoRoutingRules[0].Pattern)
|
|
})
|
|
|
|
t.Run("skip invalid regex patterns", func(t *testing.T) {
|
|
jsonData := []byte(`{
|
|
"autoRouting": {
|
|
"enable": true,
|
|
"rules": [
|
|
{"pattern": "[invalid", "model": "model1"},
|
|
{"pattern": "valid", "model": "model2"}
|
|
]
|
|
}
|
|
}`)
|
|
var cfg ModelRouterConfig
|
|
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
|
|
require.NoError(t, err)
|
|
|
|
// Only valid rule should be parsed
|
|
require.Len(t, cfg.autoRoutingRules, 1)
|
|
require.Equal(t, "model2", cfg.autoRoutingRules[0].Model)
|
|
})
|
|
|
|
t.Run("skip rules with empty pattern or model", func(t *testing.T) {
|
|
jsonData := []byte(`{
|
|
"autoRouting": {
|
|
"enable": true,
|
|
"rules": [
|
|
{"pattern": "", "model": "model1"},
|
|
{"pattern": "test", "model": ""},
|
|
{"pattern": "valid", "model": "model2"}
|
|
]
|
|
}
|
|
}`)
|
|
var cfg ModelRouterConfig
|
|
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, cfg.autoRoutingRules, 1)
|
|
require.Equal(t, "model2", cfg.autoRoutingRules[0].Model)
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestExtractLastUserMessage(t *testing.T) {
|
|
test.RunGoTest(t, func(t *testing.T) {
|
|
t.Run("extract from simple string content", func(t *testing.T) {
|
|
body := []byte(`{
|
|
"model": "higress/auto",
|
|
"messages": [
|
|
{"role": "system", "content": "You are a helpful assistant"},
|
|
{"role": "user", "content": "Hello, how are you?"},
|
|
{"role": "assistant", "content": "I am fine"},
|
|
{"role": "user", "content": "Please draw a cat"}
|
|
]
|
|
}`)
|
|
result := extractLastUserMessage(body)
|
|
require.Equal(t, "Please draw a cat", result)
|
|
})
|
|
|
|
t.Run("extract from array content (multimodal)", func(t *testing.T) {
|
|
body := []byte(`{
|
|
"model": "higress/auto",
|
|
"messages": [
|
|
{"role": "user", "content": [
|
|
{"type": "text", "text": "What is in this image?"},
|
|
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
|
|
]}
|
|
]
|
|
}`)
|
|
result := extractLastUserMessage(body)
|
|
require.Equal(t, "What is in this image?", result)
|
|
})
|
|
|
|
t.Run("extract last text from array with multiple text items", func(t *testing.T) {
|
|
body := []byte(`{
|
|
"model": "higress/auto",
|
|
"messages": [
|
|
{"role": "user", "content": [
|
|
{"type": "text", "text": "First text"},
|
|
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
|
|
{"type": "text", "text": "Second text about drawing"}
|
|
]}
|
|
]
|
|
}`)
|
|
result := extractLastUserMessage(body)
|
|
require.Equal(t, "Second text about drawing", result)
|
|
})
|
|
|
|
t.Run("return empty when no messages", func(t *testing.T) {
|
|
body := []byte(`{"model": "higress/auto"}`)
|
|
result := extractLastUserMessage(body)
|
|
require.Equal(t, "", result)
|
|
})
|
|
|
|
t.Run("return empty when no user messages", func(t *testing.T) {
|
|
body := []byte(`{
|
|
"model": "higress/auto",
|
|
"messages": [
|
|
{"role": "system", "content": "You are a helpful assistant"},
|
|
{"role": "assistant", "content": "Hello!"}
|
|
]
|
|
}`)
|
|
result := extractLastUserMessage(body)
|
|
require.Equal(t, "", result)
|
|
})
|
|
|
|
t.Run("handle multiple user messages", func(t *testing.T) {
|
|
body := []byte(`{
|
|
"model": "higress/auto",
|
|
"messages": [
|
|
{"role": "user", "content": "First question"},
|
|
{"role": "assistant", "content": "First answer"},
|
|
{"role": "user", "content": "帮我写一段代码"}
|
|
]
|
|
}`)
|
|
result := extractLastUserMessage(body)
|
|
require.Equal(t, "帮我写一段代码", result)
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestMatchAutoRoutingRule(t *testing.T) {
|
|
test.RunGoTest(t, func(t *testing.T) {
|
|
config := ModelRouterConfig{
|
|
autoRoutingRules: []AutoRoutingRule{
|
|
{Pattern: regexp.MustCompile(`(?i)(画|绘|图片)`), Model: "qwen-vl-max"},
|
|
{Pattern: regexp.MustCompile(`(?i)(代码|编程|code)`), Model: "qwen-coder"},
|
|
{Pattern: regexp.MustCompile(`(?i)(数学|计算)`), Model: "qwen-math"},
|
|
},
|
|
}
|
|
|
|
t.Run("match drawing keywords", func(t *testing.T) {
|
|
model, found := matchAutoRoutingRule(config, "请帮我画一只猫")
|
|
require.True(t, found)
|
|
require.Equal(t, "qwen-vl-max", model)
|
|
})
|
|
|
|
t.Run("match code keywords", func(t *testing.T) {
|
|
model, found := matchAutoRoutingRule(config, "Write a Python code to sort a list")
|
|
require.True(t, found)
|
|
require.Equal(t, "qwen-coder", model)
|
|
})
|
|
|
|
t.Run("match Chinese code keywords", func(t *testing.T) {
|
|
model, found := matchAutoRoutingRule(config, "帮我写一段编程代码")
|
|
require.True(t, found)
|
|
// First matching rule wins (代码 matches first rule with 代码)
|
|
require.Equal(t, "qwen-coder", model)
|
|
})
|
|
|
|
t.Run("match math keywords", func(t *testing.T) {
|
|
model, found := matchAutoRoutingRule(config, "计算123+456等于多少")
|
|
require.True(t, found)
|
|
require.Equal(t, "qwen-math", model)
|
|
})
|
|
|
|
t.Run("no match returns false", func(t *testing.T) {
|
|
model, found := matchAutoRoutingRule(config, "今天天气怎么样?")
|
|
require.False(t, found)
|
|
require.Equal(t, "", model)
|
|
})
|
|
|
|
t.Run("case insensitive matching", func(t *testing.T) {
|
|
model, found := matchAutoRoutingRule(config, "Write some CODE for me")
|
|
require.True(t, found)
|
|
require.Equal(t, "qwen-coder", model)
|
|
})
|
|
|
|
t.Run("first matching rule wins", func(t *testing.T) {
|
|
// Message contains both "图片" and "代码"
|
|
model, found := matchAutoRoutingRule(config, "生成一张图片的代码")
|
|
require.True(t, found)
|
|
// "图片" rule comes first
|
|
require.Equal(t, "qwen-vl-max", model)
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestAutoRoutingIntegration(t *testing.T) {
|
|
test.RunTest(t, func(t *testing.T) {
|
|
t.Run("auto routing with matching rule", func(t *testing.T) {
|
|
host, status := test.NewTestHost(autoRoutingConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
host.CallOnHttpRequestHeaders([][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/chat/completions"},
|
|
{":method", "POST"},
|
|
{"content-type", "application/json"},
|
|
})
|
|
|
|
body := []byte(`{
|
|
"model": "higress/auto",
|
|
"messages": [
|
|
{"role": "system", "content": "You are a helpful assistant"},
|
|
{"role": "user", "content": "请帮我画一只可爱的小猫"}
|
|
]
|
|
}`)
|
|
action := host.CallOnHttpRequestBody(body)
|
|
require.Equal(t, types.ActionContinue, action)
|
|
|
|
headers := host.GetRequestHeaders()
|
|
modelHeader, found := getHeader(headers, "x-higress-llm-model")
|
|
require.True(t, found, "x-higress-llm-model header should be set")
|
|
require.Equal(t, "qwen-vl-max", modelHeader)
|
|
})
|
|
|
|
t.Run("auto routing with code keywords", func(t *testing.T) {
|
|
host, status := test.NewTestHost(autoRoutingConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
host.CallOnHttpRequestHeaders([][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/chat/completions"},
|
|
{":method", "POST"},
|
|
{"content-type", "application/json"},
|
|
})
|
|
|
|
body := []byte(`{
|
|
"model": "higress/auto",
|
|
"messages": [
|
|
{"role": "user", "content": "Write a function to calculate fibonacci numbers"}
|
|
]
|
|
}`)
|
|
action := host.CallOnHttpRequestBody(body)
|
|
require.Equal(t, types.ActionContinue, action)
|
|
|
|
headers := host.GetRequestHeaders()
|
|
modelHeader, found := getHeader(headers, "x-higress-llm-model")
|
|
require.True(t, found)
|
|
require.Equal(t, "qwen-coder", modelHeader)
|
|
})
|
|
|
|
t.Run("auto routing falls back to default model", func(t *testing.T) {
|
|
host, status := test.NewTestHost(autoRoutingConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
host.CallOnHttpRequestHeaders([][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/chat/completions"},
|
|
{":method", "POST"},
|
|
{"content-type", "application/json"},
|
|
})
|
|
|
|
body := []byte(`{
|
|
"model": "higress/auto",
|
|
"messages": [
|
|
{"role": "user", "content": "今天天气怎么样?"}
|
|
]
|
|
}`)
|
|
action := host.CallOnHttpRequestBody(body)
|
|
require.Equal(t, types.ActionContinue, action)
|
|
|
|
headers := host.GetRequestHeaders()
|
|
modelHeader, found := getHeader(headers, "x-higress-llm-model")
|
|
require.True(t, found)
|
|
require.Equal(t, "qwen-turbo", modelHeader)
|
|
})
|
|
|
|
t.Run("auto routing no default model configured", func(t *testing.T) {
|
|
host, status := test.NewTestHost(autoRoutingNoDefaultConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
host.CallOnHttpRequestHeaders([][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/chat/completions"},
|
|
{":method", "POST"},
|
|
{"content-type", "application/json"},
|
|
})
|
|
|
|
body := []byte(`{
|
|
"model": "higress/auto",
|
|
"messages": [
|
|
{"role": "user", "content": "今天天气怎么样?"}
|
|
]
|
|
}`)
|
|
action := host.CallOnHttpRequestBody(body)
|
|
require.Equal(t, types.ActionContinue, action)
|
|
|
|
headers := host.GetRequestHeaders()
|
|
_, found := getHeader(headers, "x-higress-llm-model")
|
|
require.False(t, found, "x-higress-llm-model should not be set when no rule matches and no default")
|
|
})
|
|
|
|
t.Run("normal routing when model is not higress/auto", func(t *testing.T) {
|
|
host, status := test.NewTestHost(autoRoutingConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
host.CallOnHttpRequestHeaders([][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/chat/completions"},
|
|
{":method", "POST"},
|
|
{"content-type", "application/json"},
|
|
})
|
|
|
|
body := []byte(`{
|
|
"model": "qwen-long",
|
|
"messages": [
|
|
{"role": "user", "content": "请帮我画一只猫"}
|
|
]
|
|
}`)
|
|
action := host.CallOnHttpRequestBody(body)
|
|
require.Equal(t, types.ActionContinue, action)
|
|
|
|
headers := host.GetRequestHeaders()
|
|
modelHeader, found := getHeader(headers, "x-model")
|
|
require.True(t, found)
|
|
require.Equal(t, "qwen-long", modelHeader)
|
|
|
|
// x-higress-llm-model should NOT be set (auto routing not triggered)
|
|
_, found = getHeader(headers, "x-higress-llm-model")
|
|
require.False(t, found)
|
|
})
|
|
|
|
t.Run("auto routing with multimodal content", func(t *testing.T) {
|
|
host, status := test.NewTestHost(autoRoutingConfig)
|
|
defer host.Reset()
|
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
|
|
|
host.CallOnHttpRequestHeaders([][2]string{
|
|
{":authority", "example.com"},
|
|
{":path", "/v1/chat/completions"},
|
|
{":method", "POST"},
|
|
{"content-type", "application/json"},
|
|
})
|
|
|
|
body := []byte(`{
|
|
"model": "higress/auto",
|
|
"messages": [
|
|
{"role": "user", "content": [
|
|
{"type": "text", "text": "帮我翻译这段话"},
|
|
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
|
|
]}
|
|
]
|
|
}`)
|
|
action := host.CallOnHttpRequestBody(body)
|
|
require.Equal(t, types.ActionContinue, action)
|
|
|
|
headers := host.GetRequestHeaders()
|
|
modelHeader, found := getHeader(headers, "x-higress-llm-model")
|
|
require.True(t, found)
|
|
require.Equal(t, "qwen-turbo", modelHeader) // matches 翻译 rule
|
|
})
|
|
})
|
|
}
|