feat: Support extracting model argument from body in multipart/form-data format (#1940)

This commit is contained in:
Kent Dong
2025-04-22 13:52:50 +08:00
committed by GitHub
parent b8133a95b2
commit 1c37c361e1
4 changed files with 334 additions and 26 deletions

View File

@@ -15,6 +15,7 @@
#include "extensions/model_router/plugin.h"
#include <cstddef>
#include <regex>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -86,7 +87,7 @@ class ModelRouterTest : public ::testing::Test {
.WillByDefault([&](WasmHeaderMapType, std::string_view header,
std::string_view* result) {
if (header == "content-type") {
*result = "application/json";
*result = content_type_;
} else if (header == "content-length") {
*result = "1024";
} else if (header == ":path") {
@@ -125,6 +126,7 @@ class ModelRouterTest : public ::testing::Test {
std::unique_ptr<PluginContext> context_;
std::string route_name_;
std::string path_;
std::string content_type_ = "application/json";
BufferBase body_;
BufferBase config_;
};
@@ -133,7 +135,7 @@ TEST_F(ModelRouterTest, RewriteModelAndHeader) {
std::string configuration = R"(
{
"addProviderHeader": "x-higress-llm-provider"
})";
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
@@ -155,14 +157,14 @@ TEST_F(ModelRouterTest, RewriteModelAndHeader) {
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, ModelToHeader) {
std::string configuration = R"(
{
"modelToHeader": "x-higress-llm-model"
})";
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
@@ -181,14 +183,14 @@ TEST_F(ModelRouterTest, ModelToHeader) {
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, IgnorePath) {
std::string configuration = R"(
{
"addProviderHeader": "x-higress-llm-provider"
})";
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
@@ -208,7 +210,7 @@ TEST_F(ModelRouterTest, IgnorePath) {
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::Continue);
EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) {
@@ -242,7 +244,178 @@ TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) {
route_name_ = "route-a";
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, RewriteModelAndHeaderMultipartFormData) {
std::string configuration = R"({
"addProviderHeader": "x-higress-llm-provider"
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/v1/chat/completions";
content_type_ = "multipart/form-data; boundary=--------------------------100751621174704322650451";
std::string request_data = std::regex_replace(R"(
----------------------------100751621174704322650451
Content-Disposition: form-data; name="purpose"
batch
----------------------------100751621174704322650451
Content-Disposition: form-data; name="model"
qwen/qwen-turbo
----------------------------100751621174704322650451
Content-Disposition: form-data; name="file"; filename="test-data.json"
Content-Type: application/json
[
]
----------------------------100751621174704322650451--
)", std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t start, size_t length, std::string_view body) {
std::cerr << "===============" << "\n";
std::cerr << body << "\n";
std::cerr << "===============" << "\n";
EXPECT_EQ(start, 0);
EXPECT_EQ(length, std::numeric_limits<size_t>::max());
auto expected_body= std::regex_replace(R"(
----------------------------100751621174704322650451
Content-Disposition: form-data; name="purpose"
batch
----------------------------100751621174704322650451
Content-Disposition: form-data; name="model"
qwen-turbo
)", std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
EXPECT_EQ(body, expected_body);
return WasmResult::Ok;
});
EXPECT_CALL(*mock_context_,
replaceHeaderMapValue(testing::_,
std::string_view("x-higress-llm-provider"),
std::string_view("qwen")));
body_.set(request_data);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
auto last_body_size = 0;
auto body = request_data.substr(0, request_data.find("batch") + 5 + 2 /* batch + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("\"model\"") + 5 + 2 + 2 /* "model" + CRLF + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen") + 4 /* "qwen" */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 /* "qwen-turbo" */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 + 2 /* "qwen-turbo" + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 + 2 + 50 /* "qwen-turbo" + CRLF + boundary */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
last_body_size = body.size();
body_.set(request_data);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, true), FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, ModelToHeaderMultipartFormData) {
std::string configuration = R"(
{
"modelToHeader": "x-higress-llm-model"
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/v1/chat/completions";
content_type_ = "multipart/form-data; boundary=--------------------------100751621174704322650451";
std::string request_data = std::regex_replace(R"(
----------------------------100751621174704322650451
Content-Disposition: form-data; name="purpose"
batch
----------------------------100751621174704322650451
Content-Disposition: form-data; name="model"
qwen-max
----------------------------100751621174704322650451
Content-Disposition: form-data; name="file"; filename="test-data.json"
Content-Type: application/json
[
]
----------------------------100751621174704322650451--
)", std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.Times(0);
EXPECT_CALL(
*mock_context_,
replaceHeaderMapValue(testing::_, std::string_view("x-higress-llm-model"),
std::string_view("qwen-max")));
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
auto last_body_size = 0;
auto body = request_data.substr(0, request_data.find("batch") + 5 + 2 /* batch + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("\"model\"") + 5 + 2 + 2 /* "model" + CRLF + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen") + 4 /* "qwen" */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-max") + 8 /* "qwen-max" */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-max") + 8 + 2 /* "qwen-max" + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-max") + 8 + 2 + 50 /* "qwen-max" + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
last_body_size = body.size();
body_.set(request_data);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, true), FilterDataStatus::Continue);
}
} // namespace model_router