From e497d8017a0002943dd20f0249026678f6c1ef62 Mon Sep 17 00:00:00 2001 From: rinfx Date: Fri, 15 May 2026 14:38:54 +0800 Subject: [PATCH] feat(model-mapper): sync model header on remap and disable reroute (#3827) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 钰诚 --- .../wasm-go/extensions/model-mapper/main.go | 11 + .../extensions/model-mapper/main_test.go | 268 +++++++++++++++++- 2 files changed, 266 insertions(+), 13 deletions(-) diff --git a/plugins/wasm-go/extensions/model-mapper/main.go b/plugins/wasm-go/extensions/model-mapper/main.go index 68f5becb..eb4a8f57 100644 --- a/plugins/wasm-go/extensions/model-mapper/main.go +++ b/plugins/wasm-go/extensions/model-mapper/main.go @@ -42,6 +42,7 @@ type Config struct { prefixModelMapping []ModelMapping defaultModel string enableOnPathSuffix []string + modelToHeader string } func parseConfig(json gjson.Result, config *Config) error { @@ -50,6 +51,11 @@ func parseConfig(json gjson.Result, config *Config) error { config.modelKey = "model" } + config.modelToHeader = json.Get("modelToHeader").String() + if config.modelToHeader == "" { + config.modelToHeader = "x-higress-llm-model-final" + } + modelMapping := json.Get("modelMapping") if modelMapping.Exists() && !modelMapping.IsObject() { return errors.New("modelMapping must be an object") @@ -144,6 +150,8 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action { return types.ActionContinue } + // Disable re-route since the plugin may modify some headers related to the chosen route. + ctx.DisableReroute() // Prepare for body processing proxywasm.RemoveHttpRequestHeader("content-length") // 100MB buffer limit @@ -182,6 +190,9 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte) type } } + // update x-higress-llm-model-final header + proxywasm.ReplaceHttpRequestHeader(config.modelToHeader, newModel) + log.Debugf("set header %s: %s", config.modelToHeader, newModel) if newModel != "" && newModel != oldModel { newBody, err := sjson.SetBytes(body, config.modelKey, newModel) if err != nil { diff --git a/plugins/wasm-go/extensions/model-mapper/main_test.go b/plugins/wasm-go/extensions/model-mapper/main_test.go index 7be9c785..885fb64e 100644 --- a/plugins/wasm-go/extensions/model-mapper/main_test.go +++ b/plugins/wasm-go/extensions/model-mapper/main_test.go @@ -11,6 +11,15 @@ import ( "github.com/tidwall/gjson" ) +func getHeader(headers [][2]string, key string) (string, bool) { + for _, h := range headers { + if strings.EqualFold(h[0], key) { + return h[1], true + } + } + return "", false +} + // Basic configs for wasm test host var ( basicConfig = func() json.RawMessage { @@ -42,6 +51,20 @@ var ( }) return data }() + + headerSyncConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "modelKey": "model", + "modelMapping": map[string]string{ + "gpt-3.5-turbo": "gpt-4", + }, + "modelToHeader": "x-final-model", + "enableOnPathSuffix": []string{ + "/v1/chat/completions", + }, + }) + return data + }() ) func TestParseConfig(t *testing.T) { @@ -112,6 +135,61 @@ func TestParseConfig(t *testing.T) { err := parseConfig(gjson.ParseBytes(jsonData), &cfg) require.Error(t, err) }) + + t.Run("modelToHeader default and custom", func(t *testing.T) { + var cfgDefault Config + require.NoError(t, parseConfig(gjson.ParseBytes([]byte(`{"modelMapping":{}}`)), &cfgDefault)) + require.Equal(t, "x-higress-llm-model-final", cfgDefault.modelToHeader) + + var cfgCustom Config + err := parseConfig(gjson.ParseBytes([]byte(`{ + "modelToHeader": "x-my-model", + "modelMapping": {} + }`)), &cfgCustom) + require.NoError(t, err) + require.Equal(t, "x-my-model", cfgCustom.modelToHeader) + }) + + t.Run("empty modelMapping", func(t *testing.T) { + var cfg Config + err := parseConfig(gjson.ParseBytes([]byte(`{"modelMapping": {}}`)), &cfg) + require.NoError(t, err) + require.Empty(t, cfg.exactModelMapping) + require.Empty(t, cfg.prefixModelMapping) + require.Equal(t, "", cfg.defaultModel) + }) + + t.Run("prefix rules sorted by key for stable iteration", func(t *testing.T) { + var cfg Config + // Object key order in JSON is z then a; after sort, prefix "a" is tried before "z". + jsonData := []byte(`{ + "modelMapping": { + "z*": "Z", + "a*": "A" + } + }`) + require.NoError(t, parseConfig(gjson.ParseBytes(jsonData), &cfg)) + require.Len(t, cfg.prefixModelMapping, 2) + require.Equal(t, "a", cfg.prefixModelMapping[0].Prefix) + require.Equal(t, "A", cfg.prefixModelMapping[0].Target) + require.Equal(t, "z", cfg.prefixModelMapping[1].Prefix) + require.Equal(t, "Z", cfg.prefixModelMapping[1].Target) + }) + + t.Run("exact mapping wins over prefix", func(t *testing.T) { + var cfg Config + jsonData := []byte(`{ + "modelKey": "model", + "modelMapping": { + "gpt-3.5*": "from-prefix", + "gpt-3.5-turbo": "from-exact" + } + }`) + require.NoError(t, parseConfig(gjson.ParseBytes(jsonData), &cfg)) + require.Equal(t, "from-exact", cfg.exactModelMapping["gpt-3.5-turbo"]) + require.Len(t, cfg.prefixModelMapping, 1) + require.Equal(t, "gpt-3.5", cfg.prefixModelMapping[0].Prefix) + }) }) } @@ -133,15 +211,8 @@ func TestOnHttpRequestHeaders(t *testing.T) { require.Equal(t, types.ActionContinue, action) newHeaders := host.GetRequestHeaders() - // content-length should still exist because path is not enabled - foundContentLength := false - for _, h := range newHeaders { - if strings.ToLower(h[0]) == "content-length" { - foundContentLength = true - break - } - } - require.True(t, foundContentLength) + _, foundContentLength := getHeader(newHeaders, "content-length") + require.True(t, foundContentLength, "content-length should be kept when path is not enabled") }) t.Run("process when path and content-type match", func(t *testing.T) { @@ -160,10 +231,25 @@ func TestOnHttpRequestHeaders(t *testing.T) { require.Equal(t, types.HeaderStopIteration, action) newHeaders := host.GetRequestHeaders() - // content-length should be removed - for _, h := range newHeaders { - require.NotEqual(t, strings.ToLower(h[0]), "content-length") - } + _, foundCL := getHeader(newHeaders, "content-length") + require.False(t, foundCL, "content-length should be removed when buffering body") + }) + + t.Run("path with query string still matches suffix", func(t *testing.T) { + host, status := test.NewTestHost(basicConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions?trace=1"}, + {":method", "POST"}, + {"content-type", "application/json"}, + {"content-length", "99"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + _, foundCL := getHeader(host.GetRequestHeaders(), "content-length") + require.False(t, foundCL) }) }) } @@ -192,6 +278,9 @@ func TestOnHttpRequestBody_ModelMapping(t *testing.T) { processed := host.GetRequestBody() require.NotNil(t, processed) require.Equal(t, "gpt-4", gjson.GetBytes(processed, "model").String()) + v, ok := getHeader(host.GetRequestHeaders(), "x-higress-llm-model-final") + require.True(t, ok) + require.Equal(t, "gpt-4", v) }) t.Run("default model when key missing", func(t *testing.T) { @@ -219,6 +308,9 @@ func TestOnHttpRequestBody_ModelMapping(t *testing.T) { require.NotNil(t, processed) // default model should be set at request.model require.Equal(t, "gpt-4o", gjson.GetBytes(processed, "request.model").String()) + v, ok := getHeader(host.GetRequestHeaders(), "x-higress-llm-model-final") + require.True(t, ok) + require.Equal(t, "gpt-4o", v) }) t.Run("prefix mapping takes effect", func(t *testing.T) { @@ -245,6 +337,156 @@ func TestOnHttpRequestBody_ModelMapping(t *testing.T) { processed := host.GetRequestBody() require.NotNil(t, processed) require.Equal(t, "gpt-4-mini", gjson.GetBytes(processed, "request.model").String()) + v, ok := getHeader(host.GetRequestHeaders(), "x-higress-llm-model-final") + require.True(t, ok) + require.Equal(t, "gpt-4-mini", v) + }) + + t.Run("exact mapping beats prefix for same family", func(t *testing.T) { + host, status := test.NewTestHost(customConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/embeddings"}, + {":method", "POST"}, + {"content-type", "application/json"}, + }) + + origBody := []byte(`{ + "request": { + "model": "gpt-3.5-t1", + "input": "hello" + } + }`) + action := host.CallOnHttpRequestBody(origBody) + require.Equal(t, types.ActionContinue, action) + + processed := host.GetRequestBody() + require.NotNil(t, processed) + require.Equal(t, "gpt-4-turbo-1", gjson.GetBytes(processed, "request.model").String()) + v, ok := getHeader(host.GetRequestHeaders(), "x-higress-llm-model-final") + require.True(t, ok) + require.Equal(t, "gpt-4-turbo-1", v) + }) + + t.Run("empty request body is a no-op", func(t *testing.T) { + host, status := test.NewTestHost(basicConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + }) + + action := host.CallOnHttpRequestBody(nil) + require.Equal(t, types.ActionContinue, action) + require.Nil(t, host.GetRequestBody()) + + action = host.CallOnHttpRequestBody([]byte{}) + require.Equal(t, types.ActionContinue, action) + }) + + t.Run("invalid json body is skipped", func(t *testing.T) { + host, status := test.NewTestHost(basicConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + {"x-higress-llm-model-final", "should-not-change"}, + }) + + bad := []byte(`not json`) + action := host.CallOnHttpRequestBody(bad) + require.Equal(t, types.ActionContinue, action) + out := host.GetRequestBody() + if out != nil { + require.Equal(t, string(bad), string(out)) + } + v, ok := getHeader(host.GetRequestHeaders(), "x-higress-llm-model-final") + require.True(t, ok) + require.Equal(t, "should-not-change", v, "invalid JSON must not refresh model header") + }) + + t.Run("no body rewrite when already mapped target but header still refreshed", func(t *testing.T) { + host, status := test.NewTestHost(basicConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + }) + + origBody := []byte(`{"model":"gpt-4","messages":[]}`) + action := host.CallOnHttpRequestBody(origBody) + require.Equal(t, types.ActionContinue, action) + out := host.GetRequestBody() + if out != nil { + require.Equal(t, string(origBody), string(out)) + } + v, ok := getHeader(host.GetRequestHeaders(), "x-higress-llm-model-final") + require.True(t, ok) + require.Equal(t, "gpt-4", v) + }) + + t.Run("modelToHeader always set to resolved model", func(t *testing.T) { + host, status := test.NewTestHost(headerSyncConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + {"x-final-model", "gpt-3.5-turbo"}, + }) + + origBody := []byte(`{"model":"gpt-3.5-turbo"}`) + require.Equal(t, types.ActionContinue, host.CallOnHttpRequestBody(origBody)) + + processed := host.GetRequestBody() + require.NotNil(t, processed) + require.Equal(t, "gpt-4", gjson.GetBytes(processed, "model").String()) + + v, ok := getHeader(host.GetRequestHeaders(), "x-final-model") + require.True(t, ok) + require.Equal(t, "gpt-4", v) + }) + + t.Run("modelToHeader refreshed even when it already matches resolved model", func(t *testing.T) { + host, status := test.NewTestHost(basicConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + {"x-higress-llm-model-final", "gpt-4"}, + }) + + origBody := []byte(`{"model":"gpt-3.5-turbo","messages":[]}`) + require.Equal(t, types.ActionContinue, host.CallOnHttpRequestBody(origBody)) + + processed := host.GetRequestBody() + require.NotNil(t, processed) + require.Equal(t, "gpt-4", gjson.GetBytes(processed, "model").String()) + v, ok := getHeader(host.GetRequestHeaders(), "x-higress-llm-model-final") + require.True(t, ok) + require.Equal(t, "gpt-4", v) }) }) }