mirror of
https://github.com/alibaba/higress.git
synced 2026-06-10 05:07:30 +08:00
feat(model-router): add keepOriginalModelName option to preserve full model name (#3916)
Signed-off-by: Cai Rui <yangjuan.cr@alibaba-inc.com>
This commit is contained in:
@@ -288,6 +288,126 @@ func TestOnHttpRequestBody_Multipart(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// keepOriginalModelConfig is the JSON plugin configuration shared by the
// keepOriginalModelName tests. It enables keepOriginalModelName together with
// the provider header (x-provider) and the model header (x-model), and limits
// the plugin to the /v1/chat/completions path suffix.
var keepOriginalModelConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"modelKey":              "model",
		"addProviderHeader":     "x-provider",
		"modelToHeader":         "x-model",
		"keepOriginalModelName": true,
		"enableOnPathSuffix":    []string{"/v1/chat/completions"},
	}
	// The map only holds strings, a bool and a string slice, so the
	// marshalling error is intentionally discarded.
	raw, _ := json.Marshal(settings)
	return raw
}()
|
||||
|
||||
func TestParseConfigKeepOriginalModelName(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
t.Run("default false", func(t *testing.T) {
|
||||
var cfg ModelRouterConfig
|
||||
err := parseConfig(gjson.ParseBytes(basicConfig), &cfg)
|
||||
require.NoError(t, err)
|
||||
require.False(t, cfg.keepOriginalModelName)
|
||||
})
|
||||
|
||||
t.Run("parse true", func(t *testing.T) {
|
||||
var cfg ModelRouterConfig
|
||||
err := parseConfig(gjson.ParseBytes(keepOriginalModelConfig), &cfg)
|
||||
require.NoError(t, err)
|
||||
require.True(t, cfg.keepOriginalModelName)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeepOriginalModelName(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("json: provider header set but body model preserved", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(keepOriginalModelConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
origBody := []byte(`{
|
||||
"model": "MiniMax/MiniMax-M2.7",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(origBody)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
headers := host.GetRequestHeaders()
|
||||
// model header keeps the full name
|
||||
hv, found := getHeader(headers, "x-model")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "MiniMax/MiniMax-M2.7", hv)
|
||||
// provider header IS set (split still extracts provider)
|
||||
pv, found := getHeader(headers, "x-provider")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "MiniMax", pv)
|
||||
|
||||
// body model must remain intact (not rewritten)
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
require.Equal(t, "MiniMax/MiniMax-M2.7", gjson.GetBytes(processed, "model").String())
|
||||
})
|
||||
|
||||
t.Run("multipart: provider header set but body model preserved", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(keepOriginalModelConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
modelWriter, err := writer.CreateFormField("model")
|
||||
require.NoError(t, err)
|
||||
_, err = modelWriter.Write([]byte("MiniMax/MiniMax-M2.7"))
|
||||
require.NoError(t, err)
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
contentType := "multipart/form-data; boundary=" + writer.Boundary()
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", contentType},
|
||||
})
|
||||
|
||||
action := host.CallOnHttpRequestBody(buf.Bytes())
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
headers := host.GetRequestHeaders()
|
||||
hv, found := getHeader(headers, "x-model")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "MiniMax/MiniMax-M2.7", hv)
|
||||
// provider header IS set
|
||||
pv, found := getHeader(headers, "x-provider")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "MiniMax", pv)
|
||||
|
||||
// body model should not be rewritten
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
reader := multipart.NewReader(bytes.NewReader(processed), writer.Boundary())
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if part.FormName() == "model" {
|
||||
data, _ := io.ReadAll(part)
|
||||
require.Equal(t, "MiniMax/MiniMax-M2.7", string(data))
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Auto routing config for tests
|
||||
var autoRoutingConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
|
||||
Reference in New Issue
Block a user