diff --git a/plugins/wasm-go/extensions/model-mapper/main.go b/plugins/wasm-go/extensions/model-mapper/main.go index 68f5becb..44986f6f 100644 --- a/plugins/wasm-go/extensions/model-mapper/main.go +++ b/plugins/wasm-go/extensions/model-mapper/main.go @@ -42,6 +42,7 @@ type Config struct { prefixModelMapping []ModelMapping defaultModel string enableOnPathSuffix []string + modelToHeader string } func parseConfig(json gjson.Result, config *Config) error { @@ -49,6 +50,10 @@ func parseConfig(json gjson.Result, config *Config) error { if config.modelKey == "" { config.modelKey = "model" } + config.modelToHeader = json.Get("modelToHeader").String() + if config.modelToHeader == "" { + config.modelToHeader = "x-higress-llm-model" + } modelMapping := json.Get("modelMapping") if modelMapping.Exists() && !modelMapping.IsObject() { @@ -144,6 +149,8 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action { return types.ActionContinue } + // Disable re-route since the plugin may modify some headers related to the chosen route. 
+ ctx.DisableReroute() // Prepare for body processing proxywasm.RemoveHttpRequestHeader("content-length") // 100MB buffer limit @@ -183,6 +190,12 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte) type } if newModel != "" && newModel != oldModel { + // If the configured model header (config.modelToHeader, x-higress-llm-model by default) is already set and differs from the new model, update it + // so the header stays in sync with the mapped model; this is to support fallback and token rate limit + model, _ := proxywasm.GetHttpRequestHeader(config.modelToHeader) + if model != "" && model != newModel { + proxywasm.ReplaceHttpRequestHeader(config.modelToHeader, newModel) + } newBody, err := sjson.SetBytes(body, config.modelKey, newModel) if err != nil { log.Errorf("failed to update model: %v", err) diff --git a/plugins/wasm-go/extensions/model-mapper/main_test.go b/plugins/wasm-go/extensions/model-mapper/main_test.go index 7be9c785..d8084a33 100644 --- a/plugins/wasm-go/extensions/model-mapper/main_test.go +++ b/plugins/wasm-go/extensions/model-mapper/main_test.go @@ -42,6 +42,20 @@ var ( }) return data }() + + customHeaderConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "modelKey": "model", + "modelToHeader": "x-custom-model-header", + "modelMapping": map[string]string{ + "gpt-3.5-turbo": "gpt-4", + }, + "enableOnPathSuffix": []string{ + "/v1/chat/completions", + }, + }) + return data + }() ) func TestParseConfig(t *testing.T) { @@ -95,6 +109,21 @@ func TestParseConfig(t *testing.T) { require.Contains(t, cfg.enableOnPathSuffix, "/v1/embeddings") }) + t.Run("custom modelToHeader", func(t *testing.T) { + var cfg Config + jsonData := []byte(`{ + "modelKey": "model", + "modelToHeader": "x-custom-model-header", + "modelMapping": { + "gpt-3.5-turbo": "gpt-4" + } + }`) + err := parseConfig(gjson.ParseBytes(jsonData), &cfg) + require.NoError(t, err) + + require.Equal(t, "model", cfg.modelKey) // NOTE(review): cfg.modelToHeader itself is not asserted in this subtest + }) + t.Run("modelMapping must be object", func(t *testing.T) { var cfg Config jsonData := []byte(`{ @@ -246,5 +275,179 @@ 
func TestOnHttpRequestBody_ModelMapping(t *testing.T) { require.NotNil(t, processed) require.Equal(t, "gpt-4-mini", gjson.GetBytes(processed, "request.model").String()) }) + + t.Run("update model header when model changes", func(t *testing.T) { + host, status := test.NewTestHost(basicConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + {"x-higress-llm-model", "gpt-3.5-turbo-fallback"}, + }) + + origBody := []byte(`{ + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "hello"}] + }`) + action := host.CallOnHttpRequestBody(origBody) + require.Equal(t, types.ActionContinue, action) + + // verify x-higress-llm-model header was updated to the mapped target + newHeaders := host.GetRequestHeaders() + foundUpdatedHeader := false + for _, h := range newHeaders { + if strings.ToLower(h[0]) == "x-higress-llm-model" { + require.Equal(t, "gpt-4", h[1]) + foundUpdatedHeader = true + break + } + } + require.True(t, foundUpdatedHeader, "x-higress-llm-model header should be updated") + }) + + t.Run("skip model header update when header not set", func(t *testing.T) { + host, status := test.NewTestHost(basicConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + }) + + origBody := []byte(`{ + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "hello"}] + }`) + action := host.CallOnHttpRequestBody(origBody) + require.Equal(t, types.ActionContinue, action) + + // verify x-higress-llm-model header was NOT added (should not exist) + newHeaders := host.GetRequestHeaders() + for _, h := range newHeaders { + require.NotEqual(t, 
strings.ToLower(h[0]), "x-higress-llm-model") + } + }) + + t.Run("skip model header update when header already matches new model", func(t *testing.T) { + host, status := test.NewTestHost(basicConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + {"x-higress-llm-model", "gpt-4"}, + }) + + origBody := []byte(`{ + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "hello"}] + }`) + action := host.CallOnHttpRequestBody(origBody) + require.Equal(t, types.ActionContinue, action) + + // if the x-higress-llm-model header is present, verify it carries the mapped model (its presence is not asserted) + newHeaders := host.GetRequestHeaders() + for _, h := range newHeaders { + if strings.ToLower(h[0]) == "x-higress-llm-model" { + require.Equal(t, "gpt-4", h[1]) + break + } + } + }) + + t.Run("no model mapping keeps header unchanged", func(t *testing.T) { + host, status := test.NewTestHost(basicConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + {"x-higress-llm-model", "some-other-model"}, + }) + + origBody := []byte(`{ + "model": "unknown-model", + "messages": [{"role": "user", "content": "hello"}] + }`) + action := host.CallOnHttpRequestBody(origBody) + require.Equal(t, types.ActionContinue, action) + + // model in the body should remain unchanged (no mapping); NOTE(review): the header value is not asserted here + processed := host.GetRequestBody() + require.NotNil(t, processed) + require.Equal(t, "unknown-model", gjson.GetBytes(processed, "model").String()) + }) + + t.Run("use custom modelToHeader config", func(t *testing.T) { + host, status := test.NewTestHost(customHeaderConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) 
+ + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + {"x-custom-model-header", "original-model"}, + }) + + origBody := []byte(`{ + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "hello"}] + }`) + action := host.CallOnHttpRequestBody(origBody) + require.Equal(t, types.ActionContinue, action) + + // verify custom header was updated to the mapped target + newHeaders := host.GetRequestHeaders() + foundUpdatedHeader := false + for _, h := range newHeaders { + if strings.ToLower(h[0]) == "x-custom-model-header" { + require.Equal(t, "gpt-4", h[1]) + foundUpdatedHeader = true + break + } + } + require.True(t, foundUpdatedHeader, "x-custom-model-header should be updated") + }) + + t.Run("use custom modelToHeader with empty header value", func(t *testing.T) { + host, status := test.NewTestHost(customHeaderConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"content-type", "application/json"}, + }) + + origBody := []byte(`{ + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "hello"}] + }`) + action := host.CallOnHttpRequestBody(origBody) + require.Equal(t, types.ActionContinue, action) + + // verify custom header was NOT added when not present + newHeaders := host.GetRequestHeaders() + for _, h := range newHeaders { + require.NotEqual(t, strings.ToLower(h[0]), "x-custom-model-header") + } + }) }) }