feat: Support extracting model argument from body in multipart/form-data format (#1940)

This commit is contained in:
Kent Dong
2025-04-22 13:52:50 +08:00
committed by GitHub
parent b8133a95b2
commit 1c37c361e1
4 changed files with 334 additions and 26 deletions

View File

@@ -55,6 +55,7 @@ constexpr std::string_view Host(":authority");
constexpr std::string_view Path(":path");
constexpr std::string_view EnvoyOriginalPath("x-envoy-original-path");
constexpr std::string_view Accept("accept");
constexpr std::string_view ContentDisposition("content-disposition");
constexpr std::string_view ContentMD5("content-md5");
constexpr std::string_view ContentType("content-type");
constexpr std::string_view ContentLength("content-length");
@@ -68,6 +69,7 @@ constexpr std::string_view StrictTransportSecurity("strict-transport-security");
namespace ContentTypeValues {
constexpr std::string_view Grpc{"application/grpc"};
constexpr std::string_view Json{"application/json"};
constexpr std::string_view MultipartFormData{"multipart/form-data"};
} // namespace ContentTypeValues
class PercentEncoding {

View File

@@ -16,6 +16,7 @@
#include <array>
#include <limits>
#include <regex>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
@@ -123,6 +124,7 @@ bool PluginRootContext::configure(size_t configuration_size) {
}
FilterHeadersStatus PluginRootContext::onHeader(
PluginContext& ctx,
const ModelRouterConfigRule& rule) {
if (!Wasm::Common::Http::hasRequestBody()) {
return FilterHeadersStatus::Continue;
@@ -150,19 +152,49 @@ FilterHeadersStatus PluginRootContext::onHeader(
if (!enable) {
return FilterHeadersStatus::Continue;
}
auto content_type_value =
auto content_type_ptr =
getRequestHeader(Wasm::Common::Http::Header::ContentType);
if (!absl::StrContains(content_type_value->view(),
auto content_type_value = content_type_ptr->view();
LOG_DEBUG(absl::StrCat("Content-Type: ", content_type_value));
if (absl::StrContains(content_type_value,
Wasm::Common::Http::ContentTypeValues::Json)) {
return FilterHeadersStatus::Continue;
ctx.mode_ = MODE_JSON;
LOG_DEBUG("Enable JSON mode.");
removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
setFilterState(SetDecoderBufferLimitKey, DefaultMaxBodyBytes);
LOG_INFO(absl::StrCat("SetRequestBodyBufferLimit: ", DefaultMaxBodyBytes));
return FilterHeadersStatus::StopIteration;
}
removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
setFilterState(SetDecoderBufferLimitKey, DefaultMaxBodyBytes);
LOG_INFO(absl::StrCat("SetRequestBodyBufferLimit: ", DefaultMaxBodyBytes));
return FilterHeadersStatus::StopIteration;
if (absl::StrContains(content_type_value,
Wasm::Common::Http::ContentTypeValues::MultipartFormData)) {
// Get the boundary from the content type
auto boundary_start = content_type_value.find("boundary=");
if (boundary_start == std::string::npos) {
LOG_WARN(absl::StrCat("No boundary found in a multipart/form-data content-type: ", content_type_value));
return FilterHeadersStatus::Continue;
}
boundary_start += 9;
auto boundary_end = content_type_value.find(';', boundary_start);
if (boundary_end == std::string::npos) {
boundary_end = content_type_value.size();
}
auto boundary_length = boundary_end - boundary_start;
if (boundary_length < 1 || boundary_length > 70) {
// See https://www.w3.org/Protocols/rfc1341/7_2_Multipart.html
LOG_WARN(absl::StrCat("Invalid boundary value in a multipart/form-data content-type: ", content_type_value));
return FilterHeadersStatus::Continue;
}
auto boundary_value = content_type_value.substr(boundary_start, boundary_end - boundary_start);
ctx.mode_ = MODE_MULTIPART;
ctx.boundary_ = boundary_value;
LOG_DEBUG(absl::StrCat("Enable multipart/form-data mode. Boundary=", boundary_value));
removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
return FilterHeadersStatus::StopIteration;
}
return FilterHeadersStatus::Continue;
}
FilterDataStatus PluginRootContext::onBody(const ModelRouterConfigRule& rule,
FilterDataStatus PluginRootContext::onJsonBody(const ModelRouterConfigRule& rule,
std::string_view body) {
const auto& model_key = rule.model_key_;
const auto& add_provider_header = rule.add_provider_header_;
@@ -198,10 +230,85 @@ FilterDataStatus PluginRootContext::onBody(const ModelRouterConfigRule& rule,
return FilterDataStatus::Continue;
}
FilterDataStatus PluginRootContext::onMultipartBody(
PluginContext& ctx,
const ModelRouterConfigRule& rule,
WasmDataPtr& body,
bool end_stream) {
const auto& add_provider_header = rule.add_provider_header_;
const auto& model_to_header = rule.model_to_header_;
const auto boundary = ctx.boundary_;
const auto body_view = body->view();
const auto model_param_header = absl::StrCat("Content-Disposition: form-data; name=\"", rule.model_key_, "\"");
for (size_t pos = 0; (pos = body_view.find(boundary, pos)) != std::string_view::npos;) {
LOG_DEBUG(absl::StrCat("Found boundary at ", pos));
pos += boundary.length();
size_t end_pos = body_view.find(boundary, pos);
if (end_pos == std::string_view::npos) {
end_pos = body_view.length();
}
std::string_view part = body_view.substr(pos, end_pos - pos);
LOG_DEBUG(absl::StrCat("Part: ", part));
auto part_pos = pos;
pos = end_pos;
// Check if this part contains the model parameter
if (!absl::StrContains(part, model_param_header)) {
LOG_DEBUG("Part does not contain model parameter");
continue;
}
size_t value_start = part.find(CRLF_CRLF);
if (value_start == std::string_view::npos) {
LOG_DEBUG("No value start found in part");
break;
}
value_start += 4; // Skip the "\r\n\r\n"
// model parameter should be only one line
size_t value_end = part.find(CRLF, value_start);
if (value_end == std::string_view::npos) {
LOG_DEBUG("No value end found in part");
break;
}
auto model_value = part.substr(value_start, value_end - value_start);
LOG_DEBUG(absl::StrCat("Model value: ", model_value));
if (!model_to_header.empty()) {
replaceRequestHeader(model_to_header, model_value);
}
if (!add_provider_header.empty()) {
auto pos = model_value.find('/');
if (pos != std::string::npos) {
const auto& provider = model_value.substr(0, pos);
const auto& model = model_value.substr(pos + 1);
replaceRequestHeader(add_provider_header, provider);
size_t new_size = 0;
auto new_buffer_data = absl::StrCat(body_view.substr(0, part_pos + value_start), model, body_view.substr(part_pos + value_end));
auto result = setBuffer(WasmBufferType::HttpRequestBody, 0, std::numeric_limits<size_t>::max(), new_buffer_data, &new_size);
LOG_DEBUG(absl::StrCat("model route to provider:", provider,
", model:", model));
LOG_DEBUG(absl::StrCat("result=", result, " new_size=", new_size));
} else {
LOG_DEBUG(absl::StrCat("model route to provider not work, model:",
model_value));
}
}
// We are done now. We can stop processing the body.
LOG_DEBUG(absl::StrCat("Done processing multipart body after caching ", body_view.length() , " bytes."));
ctx.mode_ = MODE_BYPASS;
return FilterDataStatus::Continue;
}
if (end_stream) {
LOG_DEBUG("No model parameter found in the body");
return FilterDataStatus::Continue;
}
return FilterDataStatus::StopIterationAndBuffer;
}
FilterHeadersStatus PluginContext::onRequestHeaders(uint32_t, bool) {
auto* rootCtx = rootContext();
return rootCtx->onHeaders([rootCtx, this](const auto& config) {
auto ret = rootCtx->onHeader(config);
auto ret = rootCtx->onHeader(*this, config);
if (ret == FilterHeadersStatus::StopIteration) {
this->config_ = &config;
}
@@ -214,14 +321,28 @@ FilterDataStatus PluginContext::onRequestBody(size_t body_size,
if (config_ == nullptr) {
return FilterDataStatus::Continue;
}
body_total_size_ += body_size;
if (!end_stream) {
return FilterDataStatus::StopIterationAndBuffer;
}
auto body =
getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
auto* rootCtx = rootContext();
return rootCtx->onBody(*config_, body->view());
body_total_size_ += body_size;
switch (mode_) {
case MODE_JSON:
{
if (!end_stream) {
return FilterDataStatus::StopIterationAndBuffer;
}
auto body =
getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
return rootCtx->onJsonBody(*config_, body->view());
}
case MODE_MULTIPART:
{
auto body =
getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
return rootCtx->onMultipartBody(*this, *config_, body, end_stream);
}
case MODE_BYPASS:
default:
return FilterDataStatus::Continue;
}
}
#ifdef NULL_PLUGIN

View File

@@ -36,6 +36,13 @@ namespace model_router {
#endif
#define MODE_BYPASS 0
#define MODE_JSON 1
#define MODE_MULTIPART 2
#define CRLF ("\r\n")
#define CRLF_CRLF ("\r\n\r\n")
struct ModelRouterConfigRule {
std::string model_key_ = "model";
std::string add_provider_header_;
@@ -45,6 +52,8 @@ struct ModelRouterConfigRule {
"/audio/speech", "/fine_tuning/jobs", "/moderations"};
};
class PluginContext;
// PluginRootContext is the root context for all streams processed by the
// thread. It has the same lifetime as the worker thread and acts as target for
// interactions that outlives individual stream, e.g. timer, async calls.
@@ -55,8 +64,9 @@ class PluginRootContext : public RootContext,
: RootContext(id, root_id) {}
~PluginRootContext() {}
bool onConfigure(size_t) override;
FilterHeadersStatus onHeader(const ModelRouterConfigRule&);
FilterDataStatus onBody(const ModelRouterConfigRule&, std::string_view);
FilterHeadersStatus onHeader(PluginContext& ctx, const ModelRouterConfigRule&);
FilterDataStatus onJsonBody(const ModelRouterConfigRule&, std::string_view);
FilterDataStatus onMultipartBody(PluginContext& ctx, const ModelRouterConfigRule& rule, WasmDataPtr& body, bool end_stream);
bool configure(size_t);
private:
@@ -69,6 +79,8 @@ class PluginContext : public Context {
explicit PluginContext(uint32_t id, RootContext* root) : Context(id, root) {}
FilterHeadersStatus onRequestHeaders(uint32_t, bool) override;
FilterDataStatus onRequestBody(size_t, bool) override;
int mode_;
std::string boundary_;
private:
inline PluginRootContext* rootContext() {

View File

@@ -15,6 +15,7 @@
#include "extensions/model_router/plugin.h"
#include <cstddef>
#include <regex>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -86,7 +87,7 @@ class ModelRouterTest : public ::testing::Test {
.WillByDefault([&](WasmHeaderMapType, std::string_view header,
std::string_view* result) {
if (header == "content-type") {
*result = "application/json";
*result = content_type_;
} else if (header == "content-length") {
*result = "1024";
} else if (header == ":path") {
@@ -125,6 +126,7 @@ class ModelRouterTest : public ::testing::Test {
std::unique_ptr<PluginContext> context_;
std::string route_name_;
std::string path_;
std::string content_type_ = "application/json";
BufferBase body_;
BufferBase config_;
};
@@ -133,7 +135,7 @@ TEST_F(ModelRouterTest, RewriteModelAndHeader) {
std::string configuration = R"(
{
"addProviderHeader": "x-higress-llm-provider"
})";
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
@@ -155,14 +157,14 @@ TEST_F(ModelRouterTest, RewriteModelAndHeader) {
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, ModelToHeader) {
std::string configuration = R"(
{
"modelToHeader": "x-higress-llm-model"
})";
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
@@ -181,14 +183,14 @@ TEST_F(ModelRouterTest, ModelToHeader) {
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, IgnorePath) {
std::string configuration = R"(
{
"addProviderHeader": "x-higress-llm-provider"
})";
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
@@ -208,7 +210,7 @@ TEST_F(ModelRouterTest, IgnorePath) {
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::Continue);
EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) {
@@ -242,7 +244,178 @@ TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) {
route_name_ = "route-a";
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, RewriteModelAndHeaderMultipartFormData) {
std::string configuration = R"({
"addProviderHeader": "x-higress-llm-provider"
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/v1/chat/completions";
content_type_ = "multipart/form-data; boundary=--------------------------100751621174704322650451";
std::string request_data = std::regex_replace(R"(
----------------------------100751621174704322650451
Content-Disposition: form-data; name="purpose"
batch
----------------------------100751621174704322650451
Content-Disposition: form-data; name="model"
qwen/qwen-turbo
----------------------------100751621174704322650451
Content-Disposition: form-data; name="file"; filename="test-data.json"
Content-Type: application/json
[
]
----------------------------100751621174704322650451--
)", std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t start, size_t length, std::string_view body) {
std::cerr << "===============" << "\n";
std::cerr << body << "\n";
std::cerr << "===============" << "\n";
EXPECT_EQ(start, 0);
EXPECT_EQ(length, std::numeric_limits<size_t>::max());
auto expected_body= std::regex_replace(R"(
----------------------------100751621174704322650451
Content-Disposition: form-data; name="purpose"
batch
----------------------------100751621174704322650451
Content-Disposition: form-data; name="model"
qwen-turbo
)", std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
EXPECT_EQ(body, expected_body);
return WasmResult::Ok;
});
EXPECT_CALL(*mock_context_,
replaceHeaderMapValue(testing::_,
std::string_view("x-higress-llm-provider"),
std::string_view("qwen")));
body_.set(request_data);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
auto last_body_size = 0;
auto body = request_data.substr(0, request_data.find("batch") + 5 + 2 /* batch + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("\"model\"") + 5 + 2 + 2 /* "model" + CRLF + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen") + 4 /* "qwen" */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 /* "qwen-turbo" */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 + 2 /* "qwen-turbo" + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 + 2 + 50 /* "qwen-turbo" + CRLF + boundary */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
last_body_size = body.size();
body_.set(request_data);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, true), FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, ModelToHeaderMultipartFormData) {
std::string configuration = R"(
{
"modelToHeader": "x-higress-llm-model"
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/v1/chat/completions";
content_type_ = "multipart/form-data; boundary=--------------------------100751621174704322650451";
std::string request_data = std::regex_replace(R"(
----------------------------100751621174704322650451
Content-Disposition: form-data; name="purpose"
batch
----------------------------100751621174704322650451
Content-Disposition: form-data; name="model"
qwen-max
----------------------------100751621174704322650451
Content-Disposition: form-data; name="file"; filename="test-data.json"
Content-Type: application/json
[
]
----------------------------100751621174704322650451--
)", std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.Times(0);
EXPECT_CALL(
*mock_context_,
replaceHeaderMapValue(testing::_, std::string_view("x-higress-llm-model"),
std::string_view("qwen-max")));
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
auto last_body_size = 0;
auto body = request_data.substr(0, request_data.find("batch") + 5 + 2 /* batch + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("\"model\"") + 5 + 2 + 2 /* "model" + CRLF + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen") + 4 /* "qwen" */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-max") + 8 /* "qwen-max" */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-max") + 8 + 2 /* "qwen-max" + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-max") + 8 + 2 + 50 /* "qwen-max" + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
last_body_size = body.size();
body_.set(request_data);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, true), FilterDataStatus::Continue);
}
} // namespace model_router