diff --git a/plugins/wasm-cpp/extensions/model_router/plugin.cc b/plugins/wasm-cpp/extensions/model_router/plugin.cc index 457864d26..66a90973f 100644 --- a/plugins/wasm-cpp/extensions/model_router/plugin.cc +++ b/plugins/wasm-cpp/extensions/model_router/plugin.cc @@ -101,7 +101,7 @@ bool PluginRootContext::configure(size_t configuration_size) { configuration_data->view())); return false; } - if (!parseAuthRuleConfig(result.value())) { + if (!parseRuleConfig(result.value())) { LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ", configuration_data->view())); return false; diff --git a/plugins/wasm-cpp/extensions/model_router/plugin_test.cc b/plugins/wasm-cpp/extensions/model_router/plugin_test.cc index 9ce599805..dc351ecdc 100644 --- a/plugins/wasm-cpp/extensions/model_router/plugin_test.cc +++ b/plugins/wasm-cpp/extensions/model_router/plugin_test.cc @@ -40,6 +40,11 @@ class MockContext : public proxy_wasm::ContextBase { MOCK_METHOD(WasmResult, getHeaderMapValue, (WasmHeaderMapType /* type */, std::string_view /* key */, std::string_view* /*result */)); + MOCK_METHOD(WasmResult, replaceHeaderMapValue, + (WasmHeaderMapType /* type */, std::string_view /* key */, + std::string_view /* value */)); + MOCK_METHOD(WasmResult, removeHeaderMapValue, + (WasmHeaderMapType /* type */, std::string_view /* key */)); MOCK_METHOD(WasmResult, addHeaderMapValue, (WasmHeaderMapType, std::string_view, std::string_view)); MOCK_METHOD(WasmResult, getProperty, (std::string_view, std::string*)); @@ -87,6 +92,16 @@ class ModelRouterTest : public ::testing::Test { } return WasmResult::Ok; }); + ON_CALL(*mock_context_, + replaceHeaderMapValue(WasmHeaderMapType::RequestHeaders, testing::_, + testing::_)) + .WillByDefault([&](WasmHeaderMapType, std::string_view key, + std::string_view value) { return WasmResult::Ok; }); + ON_CALL(*mock_context_, + removeHeaderMapValue(WasmHeaderMapType::RequestHeaders, testing::_)) + .WillByDefault([&](WasmHeaderMapType, std::string_view key) { + return WasmResult::Ok; + }); ON_CALL(*mock_context_, addHeaderMapValue(WasmHeaderMapType::RequestHeaders, testing::_, testing::_)) .WillByDefault([&](WasmHeaderMapType, std::string_view header, @@ -128,10 +143,10 @@ TEST_F(ModelRouterTest, RewriteModelAndHeader) { return WasmResult::Ok; }); - EXPECT_CALL( - *mock_context_, - addHeaderMapValue(testing::_, std::string_view("x-higress-llm-provider"), - std::string_view("qwen"))); + EXPECT_CALL(*mock_context_, + replaceHeaderMapValue(testing::_, + std::string_view("x-higress-llm-provider"), + std::string_view("qwen"))); body_.set(request_json); EXPECT_EQ(context_->onRequestHeaders(0, false), @@ -139,6 +154,39 @@ TEST_F(ModelRouterTest, RewriteModelAndHeader) { EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue); } +TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) { + std::string configuration = R"( +{ + "_rules_": [ + { + "_match_route_": ["route-a"], + "enable": true + } +]})"; + + config_.set(configuration); + EXPECT_TRUE(root_context_->configure(configuration.size())); + + std::string request_json = R"({"model": "qwen/qwen-long"})"; + EXPECT_CALL(*mock_context_, + setBuffer(testing::_, testing::_, testing::_, testing::_)) + .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) { + EXPECT_EQ(body, R"({"model":"qwen-long"})"); + return WasmResult::Ok; + }); + + EXPECT_CALL(*mock_context_, + replaceHeaderMapValue(testing::_, + std::string_view("x-higress-llm-provider"), + std::string_view("qwen"))); + + body_.set(request_json); + route_name_ = "route-a"; + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue); +} + } // namespace model_router } // namespace null_plugin } // namespace proxy_wasm