diff --git a/plugins/wasm-go/extensions/model-router/README.md b/plugins/wasm-go/extensions/model-router/README.md index a63165e2f..63ba35460 100644 --- a/plugins/wasm-go/extensions/model-router/README.md +++ b/plugins/wasm-go/extensions/model-router/README.md @@ -9,6 +9,22 @@ | `addProviderHeader` | string | 选填 | - | 从model参数中解析出的provider名字放到哪个请求header中 | | `modelToHeader` | string | 选填 | - | 直接将model参数放到哪个请求header中 | | `enableOnPathSuffix` | array of string | 选填 | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | 只对这些特定路径后缀的请求生效,可以配置为 "*" 以匹配所有路径 | +| `autoRouting` | object | 选填 | - | 自动路由配置,详见下方说明 | + +### autoRouting 配置 + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------ | +| `enable` | bool | 必填 | false | 是否启用自动路由功能 | +| `defaultModel` | string | 选填 | - | 当没有规则匹配时使用的默认模型 | +| `rules` | array of object | 选填 | - | 路由规则数组,按顺序匹配 | + +### rules 配置 + +| 名称 | 数据类型 | 填写要求 | 描述 | +| --------- | -------- | -------- | ------------------------------------------------------------ | +| `pattern` | string | 必填 | 正则表达式,用于匹配用户消息内容 | +| `model` | string | 必填 | 匹配成功时设置的模型名称,将设置到 `x-higress-llm-model` 请求头 | ## 运行属性 @@ -96,3 +112,91 @@ x-higress-llm-provider: dashscope "top_p": 0.95 } ``` + +### 自动路由模式(基于用户消息内容) + +当请求中的 model 参数设置为 `higress/auto` 时,插件会自动分析用户消息内容,并根据配置的正则规则选择合适的模型进行路由。 + +配置示例: + +```yaml +autoRouting: + enable: true + defaultModel: "qwen-turbo" + rules: + - pattern: "(?i)(画|绘|生成图|图片|image|draw|paint)" + model: "qwen-vl-max" + - pattern: "(?i)(代码|编程|code|program|function|debug)" + model: "qwen-coder" + - pattern: "(?i)(翻译|translate|translation)" + model: "qwen-turbo" + - pattern: "(?i)(数学|计算|math|calculate)" + model: "qwen-math" +``` + +#### 工作原理 + +1. 当检测到请求体中的 model 参数值为 `higress/auto` 时,触发自动路由逻辑 +2. 
// AutoRoutingRule defines a regex-based routing rule for auto model selection
type AutoRoutingRule struct {
	// Pattern is the compiled regular expression tested against the user message.
	Pattern *regexp.Regexp
	// Model is the model name reported when Pattern matches.
	Model string
}

// ModelRouterConfig holds the plugin settings populated by parseConfig.
type ModelRouterConfig struct {
	modelKey           string
	addProviderHeader  string
	modelToHeader      string
	enableOnPathSuffix []string
	// Auto routing configuration
	// enableAutoRouting mirrors the "autoRouting.enable" config value.
	enableAutoRouting bool
	// autoRoutingRules keeps the successfully compiled rules in config order.
	autoRoutingRules []AutoRoutingRule
	// defaultModel mirrors "autoRouting.defaultModel"; empty when not configured.
	defaultModel string
}
error { @@ -70,6 +82,36 @@ func parseConfig(json gjson.Result, config *ModelRouterConfig) error { "/messages", } } + + // Parse auto routing configuration + autoRouting := json.Get("autoRouting") + if autoRouting.Exists() { + config.enableAutoRouting = autoRouting.Get("enable").Bool() + config.defaultModel = autoRouting.Get("defaultModel").String() + + rules := autoRouting.Get("rules") + if rules.Exists() && rules.IsArray() { + for _, rule := range rules.Array() { + patternStr := rule.Get("pattern").String() + model := rule.Get("model").String() + if patternStr == "" || model == "" { + log.Warnf("skipping invalid auto routing rule: pattern=%s, model=%s", patternStr, model) + continue + } + compiled, err := regexp.Compile(patternStr) + if err != nil { + log.Warnf("failed to compile regex pattern '%s': %v", patternStr, err) + continue + } + config.autoRoutingRules = append(config.autoRoutingRules, AutoRoutingRule{ + Pattern: compiled, + Model: model, + }) + log.Debugf("loaded auto routing rule: pattern=%s, model=%s", patternStr, model) + } + } + } + return nil } @@ -120,6 +162,43 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config ModelRouterConfig, body [ return types.ActionContinue } +// extractLastUserMessage extracts the content of the last message with role "user" from the messages array +func extractLastUserMessage(body []byte) string { + messages := gjson.GetBytes(body, "messages") + if !messages.Exists() || !messages.IsArray() { + return "" + } + + var lastUserContent string + for _, msg := range messages.Array() { + if msg.Get("role").String() == "user" { + content := msg.Get("content") + if content.IsArray() { + // Handle array content (e.g., multimodal messages with text and images) + for _, item := range content.Array() { + if item.Get("type").String() == "text" { + lastUserContent = item.Get("text").String() + } + } + } else { + lastUserContent = content.String() + } + } + } + return lastUserContent +} + +// matchAutoRoutingRule matches the user 
message against auto routing rules and returns the matched model +func matchAutoRoutingRule(config ModelRouterConfig, userMessage string) (string, bool) { + for _, rule := range config.autoRoutingRules { + if rule.Pattern.MatchString(userMessage) { + log.Debugf("auto routing rule matched: pattern=%s, model=%s", rule.Pattern.String(), rule.Model) + return rule.Model, true + } + } + return "", false +} + func handleJsonBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []byte) types.Action { if !json.Valid(body) { log.Error("invalid json body") @@ -130,6 +209,27 @@ func handleJsonBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []by return types.ActionContinue } + // Check if auto routing should be triggered + if config.enableAutoRouting && modelValue == AutoModelPrefix { + userMessage := extractLastUserMessage(body) + if userMessage != "" { + if matchedModel, found := matchAutoRoutingRule(config, userMessage); found { + // Set the matched model to the header for routing + _ = proxywasm.ReplaceHttpRequestHeader("x-higress-llm-model", matchedModel) + log.Infof("auto routing: user message matched, routing to model: %s", matchedModel) + return types.ActionContinue + } + } + // No rule matched, use default model if configured + if config.defaultModel != "" { + _ = proxywasm.ReplaceHttpRequestHeader("x-higress-llm-model", config.defaultModel) + log.Infof("auto routing: no rule matched, using default model: %s", config.defaultModel) + } else { + log.Warnf("auto routing: no rule matched and no default model configured") + } + return types.ActionContinue + } + if config.modelToHeader != "" { _ = proxywasm.ReplaceHttpRequestHeader(config.modelToHeader, modelValue) } diff --git a/plugins/wasm-go/extensions/model-router/main_test.go b/plugins/wasm-go/extensions/model-router/main_test.go index 9d6263ac8..f5666e331 100644 --- a/plugins/wasm-go/extensions/model-router/main_test.go +++ b/plugins/wasm-go/extensions/model-router/main_test.go @@ -5,6 +5,7 @@ 
import ( "encoding/json" "io" "mime/multipart" + "regexp" "strings" "testing" @@ -286,3 +287,406 @@ func TestOnHttpRequestBody_Multipart(t *testing.T) { require.Equal(t, "openai", pv) }) } + +// Auto routing config for tests +var autoRoutingConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "modelKey": "model", + "modelToHeader": "x-model", + "enableOnPathSuffix": []string{ + "/v1/chat/completions", + }, + "autoRouting": map[string]interface{}{ + "enable": true, + "defaultModel": "qwen-turbo", + "rules": []map[string]string{ + {"pattern": "(?i)(画|绘|生成图|图片|image|draw|paint)", "model": "qwen-vl-max"}, + {"pattern": "(?i)(代码|编程|code|program|function|debug)", "model": "qwen-coder"}, + {"pattern": "(?i)(翻译|translate|translation)", "model": "qwen-turbo"}, + {"pattern": "(?i)(数学|计算|math|calculate)", "model": "qwen-math"}, + }, + }, + }) + return data +}() + +var autoRoutingNoDefaultConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "modelKey": "model", + "modelToHeader": "x-model", + "enableOnPathSuffix": []string{ + "/v1/chat/completions", + }, + "autoRouting": map[string]interface{}{ + "enable": true, + "rules": []map[string]string{ + {"pattern": "(?i)(画|绘)", "model": "qwen-vl-max"}, + }, + }, + }) + return data +}() + +func TestParseConfigAutoRouting(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + t.Run("parse auto routing config", func(t *testing.T) { + var cfg ModelRouterConfig + err := parseConfig(gjson.ParseBytes(autoRoutingConfig), &cfg) + require.NoError(t, err) + + require.True(t, cfg.enableAutoRouting) + require.Equal(t, "qwen-turbo", cfg.defaultModel) + require.Len(t, cfg.autoRoutingRules, 4) + + // Verify first rule + require.Equal(t, "qwen-vl-max", cfg.autoRoutingRules[0].Model) + require.NotNil(t, cfg.autoRoutingRules[0].Pattern) + }) + + t.Run("skip invalid regex patterns", func(t *testing.T) { + jsonData := []byte(`{ + "autoRouting": { + "enable": true, + "rules": [ + 
{"pattern": "[invalid", "model": "model1"}, + {"pattern": "valid", "model": "model2"} + ] + } + }`) + var cfg ModelRouterConfig + err := parseConfig(gjson.ParseBytes(jsonData), &cfg) + require.NoError(t, err) + + // Only valid rule should be parsed + require.Len(t, cfg.autoRoutingRules, 1) + require.Equal(t, "model2", cfg.autoRoutingRules[0].Model) + }) + + t.Run("skip rules with empty pattern or model", func(t *testing.T) { + jsonData := []byte(`{ + "autoRouting": { + "enable": true, + "rules": [ + {"pattern": "", "model": "model1"}, + {"pattern": "test", "model": ""}, + {"pattern": "valid", "model": "model2"} + ] + } + }`) + var cfg ModelRouterConfig + err := parseConfig(gjson.ParseBytes(jsonData), &cfg) + require.NoError(t, err) + + require.Len(t, cfg.autoRoutingRules, 1) + require.Equal(t, "model2", cfg.autoRoutingRules[0].Model) + }) + }) +} + +func TestExtractLastUserMessage(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + t.Run("extract from simple string content", func(t *testing.T) { + body := []byte(`{ + "model": "higress/auto", + "messages": [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine"}, + {"role": "user", "content": "Please draw a cat"} + ] + }`) + result := extractLastUserMessage(body) + require.Equal(t, "Please draw a cat", result) + }) + + t.Run("extract from array content (multimodal)", func(t *testing.T) { + body := []byte(`{ + "model": "higress/auto", + "messages": [ + {"role": "user", "content": [ + {"type": "text", "text": "What is in this image?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}} + ]} + ] + }`) + result := extractLastUserMessage(body) + require.Equal(t, "What is in this image?", result) + }) + + t.Run("extract last text from array with multiple text items", func(t *testing.T) { + body := []byte(`{ + "model": "higress/auto", + "messages": [ + {"role": "user", "content": 
[ + {"type": "text", "text": "First text"}, + {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}, + {"type": "text", "text": "Second text about drawing"} + ]} + ] + }`) + result := extractLastUserMessage(body) + require.Equal(t, "Second text about drawing", result) + }) + + t.Run("return empty when no messages", func(t *testing.T) { + body := []byte(`{"model": "higress/auto"}`) + result := extractLastUserMessage(body) + require.Equal(t, "", result) + }) + + t.Run("return empty when no user messages", func(t *testing.T) { + body := []byte(`{ + "model": "higress/auto", + "messages": [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "assistant", "content": "Hello!"} + ] + }`) + result := extractLastUserMessage(body) + require.Equal(t, "", result) + }) + + t.Run("handle multiple user messages", func(t *testing.T) { + body := []byte(`{ + "model": "higress/auto", + "messages": [ + {"role": "user", "content": "First question"}, + {"role": "assistant", "content": "First answer"}, + {"role": "user", "content": "帮我写一段代码"} + ] + }`) + result := extractLastUserMessage(body) + require.Equal(t, "帮我写一段代码", result) + }) + }) +} + +func TestMatchAutoRoutingRule(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + config := ModelRouterConfig{ + autoRoutingRules: []AutoRoutingRule{ + {Pattern: regexp.MustCompile(`(?i)(画|绘|图片)`), Model: "qwen-vl-max"}, + {Pattern: regexp.MustCompile(`(?i)(代码|编程|code)`), Model: "qwen-coder"}, + {Pattern: regexp.MustCompile(`(?i)(数学|计算)`), Model: "qwen-math"}, + }, + } + + t.Run("match drawing keywords", func(t *testing.T) { + model, found := matchAutoRoutingRule(config, "请帮我画一只猫") + require.True(t, found) + require.Equal(t, "qwen-vl-max", model) + }) + + t.Run("match code keywords", func(t *testing.T) { + model, found := matchAutoRoutingRule(config, "Write a Python code to sort a list") + require.True(t, found) + require.Equal(t, "qwen-coder", model) + }) + + t.Run("match Chinese code 
keywords", func(t *testing.T) { + model, found := matchAutoRoutingRule(config, "帮我写一段编程代码") + require.True(t, found) + // First matching rule wins (代码 matches first rule with 代码) + require.Equal(t, "qwen-coder", model) + }) + + t.Run("match math keywords", func(t *testing.T) { + model, found := matchAutoRoutingRule(config, "计算123+456等于多少") + require.True(t, found) + require.Equal(t, "qwen-math", model) + }) + + t.Run("no match returns false", func(t *testing.T) { + model, found := matchAutoRoutingRule(config, "今天天气怎么样?") + require.False(t, found) + require.Equal(t, "", model) + }) + + t.Run("case insensitive matching", func(t *testing.T) { + model, found := matchAutoRoutingRule(config, "Write some CODE for me") + require.True(t, found) + require.Equal(t, "qwen-coder", model) + }) + + t.Run("first matching rule wins", func(t *testing.T) { + // Message contains both "图片" and "代码" + model, found := matchAutoRoutingRule(config, "生成一张图片的代码") + require.True(t, found) + // "图片" rule comes first + require.Equal(t, "qwen-vl-max", model) + }) + }) +} + +func TestAutoRoutingIntegration(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("auto routing with matching rule", func(t *testing.T) { + host, status := test.NewTestHost(autoRoutingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + }) + + body := []byte(`{ + "model": "higress/auto", + "messages": [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "请帮我画一只可爱的小猫"} + ] + }`) + action := host.CallOnHttpRequestBody(body) + require.Equal(t, types.ActionContinue, action) + + headers := host.GetRequestHeaders() + modelHeader, found := getHeader(headers, "x-higress-llm-model") + require.True(t, found, "x-higress-llm-model header should be set") + require.Equal(t, 
"qwen-vl-max", modelHeader) + }) + + t.Run("auto routing with code keywords", func(t *testing.T) { + host, status := test.NewTestHost(autoRoutingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + }) + + body := []byte(`{ + "model": "higress/auto", + "messages": [ + {"role": "user", "content": "Write a function to calculate fibonacci numbers"} + ] + }`) + action := host.CallOnHttpRequestBody(body) + require.Equal(t, types.ActionContinue, action) + + headers := host.GetRequestHeaders() + modelHeader, found := getHeader(headers, "x-higress-llm-model") + require.True(t, found) + require.Equal(t, "qwen-coder", modelHeader) + }) + + t.Run("auto routing falls back to default model", func(t *testing.T) { + host, status := test.NewTestHost(autoRoutingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + }) + + body := []byte(`{ + "model": "higress/auto", + "messages": [ + {"role": "user", "content": "今天天气怎么样?"} + ] + }`) + action := host.CallOnHttpRequestBody(body) + require.Equal(t, types.ActionContinue, action) + + headers := host.GetRequestHeaders() + modelHeader, found := getHeader(headers, "x-higress-llm-model") + require.True(t, found) + require.Equal(t, "qwen-turbo", modelHeader) + }) + + t.Run("auto routing no default model configured", func(t *testing.T) { + host, status := test.NewTestHost(autoRoutingNoDefaultConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + 
{"content-type", "application/json"}, + }) + + body := []byte(`{ + "model": "higress/auto", + "messages": [ + {"role": "user", "content": "今天天气怎么样?"} + ] + }`) + action := host.CallOnHttpRequestBody(body) + require.Equal(t, types.ActionContinue, action) + + headers := host.GetRequestHeaders() + _, found := getHeader(headers, "x-higress-llm-model") + require.False(t, found, "x-higress-llm-model should not be set when no rule matches and no default") + }) + + t.Run("normal routing when model is not higress/auto", func(t *testing.T) { + host, status := test.NewTestHost(autoRoutingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + }) + + body := []byte(`{ + "model": "qwen-long", + "messages": [ + {"role": "user", "content": "请帮我画一只猫"} + ] + }`) + action := host.CallOnHttpRequestBody(body) + require.Equal(t, types.ActionContinue, action) + + headers := host.GetRequestHeaders() + modelHeader, found := getHeader(headers, "x-model") + require.True(t, found) + require.Equal(t, "qwen-long", modelHeader) + + // x-higress-llm-model should NOT be set (auto routing not triggered) + _, found = getHeader(headers, "x-higress-llm-model") + require.False(t, found) + }) + + t.Run("auto routing with multimodal content", func(t *testing.T) { + host, status := test.NewTestHost(autoRoutingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + }) + + body := []byte(`{ + "model": "higress/auto", + "messages": [ + {"role": "user", "content": [ + {"type": "text", "text": "帮我翻译这段话"}, + {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}} + ]} + ] + 
}`) + action := host.CallOnHttpRequestBody(body) + require.Equal(t, types.ActionContinue, action) + + headers := host.GetRequestHeaders() + modelHeader, found := getHeader(headers, "x-higress-llm-model") + require.True(t, found) + require.Equal(t, "qwen-turbo", modelHeader) // matches 翻译 rule + }) + }) +}