mirror of
https://github.com/alibaba/higress.git
synced 2026-05-21 11:17:28 +08:00
feat(model-mapper): sync model header on remap and disable reroute (#3827)
Signed-off-by: 钰诚 <yucheng.lxr@alibaba-inc.com>
This commit is contained in:
@@ -42,6 +42,7 @@ type Config struct {
|
||||
prefixModelMapping []ModelMapping
|
||||
defaultModel string
|
||||
enableOnPathSuffix []string
|
||||
modelToHeader string
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *Config) error {
|
||||
@@ -50,6 +51,11 @@ func parseConfig(json gjson.Result, config *Config) error {
|
||||
config.modelKey = "model"
|
||||
}
|
||||
|
||||
config.modelToHeader = json.Get("modelToHeader").String()
|
||||
if config.modelToHeader == "" {
|
||||
config.modelToHeader = "x-higress-llm-model-final"
|
||||
}
|
||||
|
||||
modelMapping := json.Get("modelMapping")
|
||||
if modelMapping.Exists() && !modelMapping.IsObject() {
|
||||
return errors.New("modelMapping must be an object")
|
||||
@@ -144,6 +150,8 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
// Disable re-route since the plugin may modify some headers related to the chosen route.
|
||||
ctx.DisableReroute()
|
||||
// Prepare for body processing
|
||||
proxywasm.RemoveHttpRequestHeader("content-length")
|
||||
// 100MB buffer limit
|
||||
@@ -182,6 +190,9 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte) type
|
||||
}
|
||||
}
|
||||
|
||||
// update x-higress-llm-model-final header
|
||||
proxywasm.ReplaceHttpRequestHeader(config.modelToHeader, newModel)
|
||||
log.Debugf("set header %s: %s", config.modelToHeader, newModel)
|
||||
if newModel != "" && newModel != oldModel {
|
||||
newBody, err := sjson.SetBytes(body, config.modelKey, newModel)
|
||||
if err != nil {
|
||||
|
||||
@@ -11,6 +11,15 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// getHeader looks up key in headers, a list of [name, value] pairs, and
// returns the value of the first pair whose name matches together with true.
// When no pair matches it returns "" and false, so callers can distinguish a
// missing header from one that is present with an empty value.
func getHeader(headers [][2]string, key string) (string, bool) {
|
||||
for _, h := range headers {
|
||||
// Header names are compared case-insensitively.
if strings.EqualFold(h[0], key) {
|
||||
return h[1], true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Basic configs for wasm test host
|
||||
var (
|
||||
basicConfig = func() json.RawMessage {
|
||||
@@ -42,6 +51,20 @@ var (
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// headerSyncConfig is a plugin configuration, built once at package
// initialisation, that maps "gpt-3.5-turbo" to "gpt-4", names a custom
// modelToHeader ("x-final-model") instead of relying on the default, and
// lists "/v1/chat/completions" under enableOnPathSuffix.
headerSyncConfig = func() json.RawMessage {
|
||||
// The Marshal error is discarded; the input is a literal map of strings
// and string slices.
data, _ := json.Marshal(map[string]interface{}{
|
||||
"modelKey": "model",
|
||||
"modelMapping": map[string]string{
|
||||
"gpt-3.5-turbo": "gpt-4",
|
||||
},
|
||||
"modelToHeader": "x-final-model",
|
||||
"enableOnPathSuffix": []string{
|
||||
"/v1/chat/completions",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
)
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
@@ -112,6 +135,61 @@ func TestParseConfig(t *testing.T) {
|
||||
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("modelToHeader default and custom", func(t *testing.T) {
|
||||
var cfgDefault Config
|
||||
require.NoError(t, parseConfig(gjson.ParseBytes([]byte(`{"modelMapping":{}}`)), &cfgDefault))
|
||||
require.Equal(t, "x-higress-llm-model-final", cfgDefault.modelToHeader)
|
||||
|
||||
var cfgCustom Config
|
||||
err := parseConfig(gjson.ParseBytes([]byte(`{
|
||||
"modelToHeader": "x-my-model",
|
||||
"modelMapping": {}
|
||||
}`)), &cfgCustom)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "x-my-model", cfgCustom.modelToHeader)
|
||||
})
|
||||
|
||||
t.Run("empty modelMapping", func(t *testing.T) {
|
||||
var cfg Config
|
||||
err := parseConfig(gjson.ParseBytes([]byte(`{"modelMapping": {}}`)), &cfg)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, cfg.exactModelMapping)
|
||||
require.Empty(t, cfg.prefixModelMapping)
|
||||
require.Equal(t, "", cfg.defaultModel)
|
||||
})
|
||||
|
||||
t.Run("prefix rules sorted by key for stable iteration", func(t *testing.T) {
|
||||
var cfg Config
|
||||
// Object key order in JSON is z then a; after sort, prefix "a" is tried before "z".
|
||||
jsonData := []byte(`{
|
||||
"modelMapping": {
|
||||
"z*": "Z",
|
||||
"a*": "A"
|
||||
}
|
||||
}`)
|
||||
require.NoError(t, parseConfig(gjson.ParseBytes(jsonData), &cfg))
|
||||
require.Len(t, cfg.prefixModelMapping, 2)
|
||||
require.Equal(t, "a", cfg.prefixModelMapping[0].Prefix)
|
||||
require.Equal(t, "A", cfg.prefixModelMapping[0].Target)
|
||||
require.Equal(t, "z", cfg.prefixModelMapping[1].Prefix)
|
||||
require.Equal(t, "Z", cfg.prefixModelMapping[1].Target)
|
||||
})
|
||||
|
||||
t.Run("exact mapping wins over prefix", func(t *testing.T) {
|
||||
var cfg Config
|
||||
jsonData := []byte(`{
|
||||
"modelKey": "model",
|
||||
"modelMapping": {
|
||||
"gpt-3.5*": "from-prefix",
|
||||
"gpt-3.5-turbo": "from-exact"
|
||||
}
|
||||
}`)
|
||||
require.NoError(t, parseConfig(gjson.ParseBytes(jsonData), &cfg))
|
||||
require.Equal(t, "from-exact", cfg.exactModelMapping["gpt-3.5-turbo"])
|
||||
require.Len(t, cfg.prefixModelMapping, 1)
|
||||
require.Equal(t, "gpt-3.5", cfg.prefixModelMapping[0].Prefix)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -133,15 +211,8 @@ func TestOnHttpRequestHeaders(t *testing.T) {
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
// content-length should still exist because path is not enabled
|
||||
foundContentLength := false
|
||||
for _, h := range newHeaders {
|
||||
if strings.ToLower(h[0]) == "content-length" {
|
||||
foundContentLength = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundContentLength)
|
||||
_, foundContentLength := getHeader(newHeaders, "content-length")
|
||||
require.True(t, foundContentLength, "content-length should be kept when path is not enabled")
|
||||
})
|
||||
|
||||
t.Run("process when path and content-type match", func(t *testing.T) {
|
||||
@@ -160,10 +231,25 @@ func TestOnHttpRequestHeaders(t *testing.T) {
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
// content-length should be removed
|
||||
for _, h := range newHeaders {
|
||||
require.NotEqual(t, strings.ToLower(h[0]), "content-length")
|
||||
}
|
||||
_, foundCL := getHeader(newHeaders, "content-length")
|
||||
require.False(t, foundCL, "content-length should be removed when buffering body")
|
||||
})
|
||||
|
||||
t.Run("path with query string still matches suffix", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions?trace=1"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"content-length", "99"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
_, foundCL := getHeader(host.GetRequestHeaders(), "content-length")
|
||||
require.False(t, foundCL)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -192,6 +278,9 @@ func TestOnHttpRequestBody_ModelMapping(t *testing.T) {
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
require.Equal(t, "gpt-4", gjson.GetBytes(processed, "model").String())
|
||||
v, ok := getHeader(host.GetRequestHeaders(), "x-higress-llm-model-final")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "gpt-4", v)
|
||||
})
|
||||
|
||||
t.Run("default model when key missing", func(t *testing.T) {
|
||||
@@ -219,6 +308,9 @@ func TestOnHttpRequestBody_ModelMapping(t *testing.T) {
|
||||
require.NotNil(t, processed)
|
||||
// default model should be set at request.model
|
||||
require.Equal(t, "gpt-4o", gjson.GetBytes(processed, "request.model").String())
|
||||
v, ok := getHeader(host.GetRequestHeaders(), "x-higress-llm-model-final")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "gpt-4o", v)
|
||||
})
|
||||
|
||||
t.Run("prefix mapping takes effect", func(t *testing.T) {
|
||||
@@ -245,6 +337,156 @@ func TestOnHttpRequestBody_ModelMapping(t *testing.T) {
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
require.Equal(t, "gpt-4-mini", gjson.GetBytes(processed, "request.model").String())
|
||||
v, ok := getHeader(host.GetRequestHeaders(), "x-higress-llm-model-final")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "gpt-4-mini", v)
|
||||
})
|
||||
|
||||
t.Run("exact mapping beats prefix for same family", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(customConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
origBody := []byte(`{
|
||||
"request": {
|
||||
"model": "gpt-3.5-t1",
|
||||
"input": "hello"
|
||||
}
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(origBody)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
require.Equal(t, "gpt-4-turbo-1", gjson.GetBytes(processed, "request.model").String())
|
||||
v, ok := getHeader(host.GetRequestHeaders(), "x-higress-llm-model-final")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "gpt-4-turbo-1", v)
|
||||
})
|
||||
|
||||
t.Run("empty request body is a no-op", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
action := host.CallOnHttpRequestBody(nil)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
require.Nil(t, host.GetRequestBody())
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte{})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
t.Run("invalid json body is skipped", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"x-higress-llm-model-final", "should-not-change"},
|
||||
})
|
||||
|
||||
bad := []byte(`not json`)
|
||||
action := host.CallOnHttpRequestBody(bad)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
out := host.GetRequestBody()
|
||||
if out != nil {
|
||||
require.Equal(t, string(bad), string(out))
|
||||
}
|
||||
v, ok := getHeader(host.GetRequestHeaders(), "x-higress-llm-model-final")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "should-not-change", v, "invalid JSON must not refresh model header")
|
||||
})
|
||||
|
||||
t.Run("no body rewrite when already mapped target but header still refreshed", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
origBody := []byte(`{"model":"gpt-4","messages":[]}`)
|
||||
action := host.CallOnHttpRequestBody(origBody)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
out := host.GetRequestBody()
|
||||
if out != nil {
|
||||
require.Equal(t, string(origBody), string(out))
|
||||
}
|
||||
v, ok := getHeader(host.GetRequestHeaders(), "x-higress-llm-model-final")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "gpt-4", v)
|
||||
})
|
||||
|
||||
t.Run("modelToHeader always set to resolved model", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(headerSyncConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"x-final-model", "gpt-3.5-turbo"},
|
||||
})
|
||||
|
||||
origBody := []byte(`{"model":"gpt-3.5-turbo"}`)
|
||||
require.Equal(t, types.ActionContinue, host.CallOnHttpRequestBody(origBody))
|
||||
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
require.Equal(t, "gpt-4", gjson.GetBytes(processed, "model").String())
|
||||
|
||||
v, ok := getHeader(host.GetRequestHeaders(), "x-final-model")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "gpt-4", v)
|
||||
})
|
||||
|
||||
t.Run("modelToHeader refreshed even when it already matches resolved model", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"x-higress-llm-model-final", "gpt-4"},
|
||||
})
|
||||
|
||||
origBody := []byte(`{"model":"gpt-3.5-turbo","messages":[]}`)
|
||||
require.Equal(t, types.ActionContinue, host.CallOnHttpRequestBody(origBody))
|
||||
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
require.Equal(t, "gpt-4", gjson.GetBytes(processed, "model").String())
|
||||
v, ok := getHeader(host.GetRequestHeaders(), "x-higress-llm-model-final")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "gpt-4", v)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user