mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 20:57:32 +08:00
feat(model-router): add auto routing based on user message content (#3403)
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -286,3 +287,406 @@ func TestOnHttpRequestBody_Multipart(t *testing.T) {
|
||||
require.Equal(t, "openai", pv)
|
||||
})
|
||||
}
|
||||
|
||||
// Auto routing config for tests
|
||||
var autoRoutingConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"modelKey": "model",
|
||||
"modelToHeader": "x-model",
|
||||
"enableOnPathSuffix": []string{
|
||||
"/v1/chat/completions",
|
||||
},
|
||||
"autoRouting": map[string]interface{}{
|
||||
"enable": true,
|
||||
"defaultModel": "qwen-turbo",
|
||||
"rules": []map[string]string{
|
||||
{"pattern": "(?i)(画|绘|生成图|图片|image|draw|paint)", "model": "qwen-vl-max"},
|
||||
{"pattern": "(?i)(代码|编程|code|program|function|debug)", "model": "qwen-coder"},
|
||||
{"pattern": "(?i)(翻译|translate|translation)", "model": "qwen-turbo"},
|
||||
{"pattern": "(?i)(数学|计算|math|calculate)", "model": "qwen-math"},
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
var autoRoutingNoDefaultConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"modelKey": "model",
|
||||
"modelToHeader": "x-model",
|
||||
"enableOnPathSuffix": []string{
|
||||
"/v1/chat/completions",
|
||||
},
|
||||
"autoRouting": map[string]interface{}{
|
||||
"enable": true,
|
||||
"rules": []map[string]string{
|
||||
{"pattern": "(?i)(画|绘)", "model": "qwen-vl-max"},
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func TestParseConfigAutoRouting(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
t.Run("parse auto routing config", func(t *testing.T) {
|
||||
var cfg ModelRouterConfig
|
||||
err := parseConfig(gjson.ParseBytes(autoRoutingConfig), &cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, cfg.enableAutoRouting)
|
||||
require.Equal(t, "qwen-turbo", cfg.defaultModel)
|
||||
require.Len(t, cfg.autoRoutingRules, 4)
|
||||
|
||||
// Verify first rule
|
||||
require.Equal(t, "qwen-vl-max", cfg.autoRoutingRules[0].Model)
|
||||
require.NotNil(t, cfg.autoRoutingRules[0].Pattern)
|
||||
})
|
||||
|
||||
t.Run("skip invalid regex patterns", func(t *testing.T) {
|
||||
jsonData := []byte(`{
|
||||
"autoRouting": {
|
||||
"enable": true,
|
||||
"rules": [
|
||||
{"pattern": "[invalid", "model": "model1"},
|
||||
{"pattern": "valid", "model": "model2"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
var cfg ModelRouterConfig
|
||||
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Only valid rule should be parsed
|
||||
require.Len(t, cfg.autoRoutingRules, 1)
|
||||
require.Equal(t, "model2", cfg.autoRoutingRules[0].Model)
|
||||
})
|
||||
|
||||
t.Run("skip rules with empty pattern or model", func(t *testing.T) {
|
||||
jsonData := []byte(`{
|
||||
"autoRouting": {
|
||||
"enable": true,
|
||||
"rules": [
|
||||
{"pattern": "", "model": "model1"},
|
||||
{"pattern": "test", "model": ""},
|
||||
{"pattern": "valid", "model": "model2"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
var cfg ModelRouterConfig
|
||||
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, cfg.autoRoutingRules, 1)
|
||||
require.Equal(t, "model2", cfg.autoRoutingRules[0].Model)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractLastUserMessage(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
t.Run("extract from simple string content", func(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "higress/auto",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I am fine"},
|
||||
{"role": "user", "content": "Please draw a cat"}
|
||||
]
|
||||
}`)
|
||||
result := extractLastUserMessage(body)
|
||||
require.Equal(t, "Please draw a cat", result)
|
||||
})
|
||||
|
||||
t.Run("extract from array content (multimodal)", func(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "higress/auto",
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "text", "text": "What is in this image?"},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
result := extractLastUserMessage(body)
|
||||
require.Equal(t, "What is in this image?", result)
|
||||
})
|
||||
|
||||
t.Run("extract last text from array with multiple text items", func(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "higress/auto",
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "text", "text": "First text"},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
|
||||
{"type": "text", "text": "Second text about drawing"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
result := extractLastUserMessage(body)
|
||||
require.Equal(t, "Second text about drawing", result)
|
||||
})
|
||||
|
||||
t.Run("return empty when no messages", func(t *testing.T) {
|
||||
body := []byte(`{"model": "higress/auto"}`)
|
||||
result := extractLastUserMessage(body)
|
||||
require.Equal(t, "", result)
|
||||
})
|
||||
|
||||
t.Run("return empty when no user messages", func(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "higress/auto",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "assistant", "content": "Hello!"}
|
||||
]
|
||||
}`)
|
||||
result := extractLastUserMessage(body)
|
||||
require.Equal(t, "", result)
|
||||
})
|
||||
|
||||
t.Run("handle multiple user messages", func(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "higress/auto",
|
||||
"messages": [
|
||||
{"role": "user", "content": "First question"},
|
||||
{"role": "assistant", "content": "First answer"},
|
||||
{"role": "user", "content": "帮我写一段代码"}
|
||||
]
|
||||
}`)
|
||||
result := extractLastUserMessage(body)
|
||||
require.Equal(t, "帮我写一段代码", result)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatchAutoRoutingRule(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
config := ModelRouterConfig{
|
||||
autoRoutingRules: []AutoRoutingRule{
|
||||
{Pattern: regexp.MustCompile(`(?i)(画|绘|图片)`), Model: "qwen-vl-max"},
|
||||
{Pattern: regexp.MustCompile(`(?i)(代码|编程|code)`), Model: "qwen-coder"},
|
||||
{Pattern: regexp.MustCompile(`(?i)(数学|计算)`), Model: "qwen-math"},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("match drawing keywords", func(t *testing.T) {
|
||||
model, found := matchAutoRoutingRule(config, "请帮我画一只猫")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "qwen-vl-max", model)
|
||||
})
|
||||
|
||||
t.Run("match code keywords", func(t *testing.T) {
|
||||
model, found := matchAutoRoutingRule(config, "Write a Python code to sort a list")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "qwen-coder", model)
|
||||
})
|
||||
|
||||
t.Run("match Chinese code keywords", func(t *testing.T) {
|
||||
model, found := matchAutoRoutingRule(config, "帮我写一段编程代码")
|
||||
require.True(t, found)
|
||||
// First matching rule wins (代码 matches first rule with 代码)
|
||||
require.Equal(t, "qwen-coder", model)
|
||||
})
|
||||
|
||||
t.Run("match math keywords", func(t *testing.T) {
|
||||
model, found := matchAutoRoutingRule(config, "计算123+456等于多少")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "qwen-math", model)
|
||||
})
|
||||
|
||||
t.Run("no match returns false", func(t *testing.T) {
|
||||
model, found := matchAutoRoutingRule(config, "今天天气怎么样?")
|
||||
require.False(t, found)
|
||||
require.Equal(t, "", model)
|
||||
})
|
||||
|
||||
t.Run("case insensitive matching", func(t *testing.T) {
|
||||
model, found := matchAutoRoutingRule(config, "Write some CODE for me")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "qwen-coder", model)
|
||||
})
|
||||
|
||||
t.Run("first matching rule wins", func(t *testing.T) {
|
||||
// Message contains both "图片" and "代码"
|
||||
model, found := matchAutoRoutingRule(config, "生成一张图片的代码")
|
||||
require.True(t, found)
|
||||
// "图片" rule comes first
|
||||
require.Equal(t, "qwen-vl-max", model)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAutoRoutingIntegration(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("auto routing with matching rule", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(autoRoutingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
body := []byte(`{
|
||||
"model": "higress/auto",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "请帮我画一只可爱的小猫"}
|
||||
]
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(body)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
headers := host.GetRequestHeaders()
|
||||
modelHeader, found := getHeader(headers, "x-higress-llm-model")
|
||||
require.True(t, found, "x-higress-llm-model header should be set")
|
||||
require.Equal(t, "qwen-vl-max", modelHeader)
|
||||
})
|
||||
|
||||
t.Run("auto routing with code keywords", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(autoRoutingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
body := []byte(`{
|
||||
"model": "higress/auto",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Write a function to calculate fibonacci numbers"}
|
||||
]
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(body)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
headers := host.GetRequestHeaders()
|
||||
modelHeader, found := getHeader(headers, "x-higress-llm-model")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "qwen-coder", modelHeader)
|
||||
})
|
||||
|
||||
t.Run("auto routing falls back to default model", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(autoRoutingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
body := []byte(`{
|
||||
"model": "higress/auto",
|
||||
"messages": [
|
||||
{"role": "user", "content": "今天天气怎么样?"}
|
||||
]
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(body)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
headers := host.GetRequestHeaders()
|
||||
modelHeader, found := getHeader(headers, "x-higress-llm-model")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "qwen-turbo", modelHeader)
|
||||
})
|
||||
|
||||
t.Run("auto routing no default model configured", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(autoRoutingNoDefaultConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
body := []byte(`{
|
||||
"model": "higress/auto",
|
||||
"messages": [
|
||||
{"role": "user", "content": "今天天气怎么样?"}
|
||||
]
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(body)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
headers := host.GetRequestHeaders()
|
||||
_, found := getHeader(headers, "x-higress-llm-model")
|
||||
require.False(t, found, "x-higress-llm-model should not be set when no rule matches and no default")
|
||||
})
|
||||
|
||||
t.Run("normal routing when model is not higress/auto", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(autoRoutingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
body := []byte(`{
|
||||
"model": "qwen-long",
|
||||
"messages": [
|
||||
{"role": "user", "content": "请帮我画一只猫"}
|
||||
]
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(body)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
headers := host.GetRequestHeaders()
|
||||
modelHeader, found := getHeader(headers, "x-model")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "qwen-long", modelHeader)
|
||||
|
||||
// x-higress-llm-model should NOT be set (auto routing not triggered)
|
||||
_, found = getHeader(headers, "x-higress-llm-model")
|
||||
require.False(t, found)
|
||||
})
|
||||
|
||||
t.Run("auto routing with multimodal content", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(autoRoutingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
body := []byte(`{
|
||||
"model": "higress/auto",
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "text", "text": "帮我翻译这段话"},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(body)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
headers := host.GetRequestHeaders()
|
||||
modelHeader, found := getHeader(headers, "x-higress-llm-model")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "qwen-turbo", modelHeader) // matches 翻译 rule
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user