feat: enhance model mapper and router with rebuild triggers and path extensions (#3218)

This commit is contained in:
澄潭
2025-12-12 18:10:57 +08:00
committed by GitHub
parent 5c17d3faa3
commit 0ada107ec5
6 changed files with 202 additions and 74 deletions

View File

@@ -123,9 +123,40 @@ bool PluginRootContext::configure(size_t configuration_size) {
return true;
}
void PluginRootContext::incrementRequestCount() {
request_count_++;
if (request_count_ >= REBUILD_THRESHOLD) {
LOG_DEBUG("Request count reached threshold, triggering rebuild");
setFilterState("wasm_need_rebuild", "true");
request_count_ = 0; // Reset counter after setting rebuild flag
}
}
FilterHeadersStatus PluginRootContext::onHeader(
PluginContext& ctx,
const ModelRouterConfigRule& rule) {
PluginContext& ctx, const ModelRouterConfigRule& rule) {
// Increment request count and check for rebuild
incrementRequestCount();
// Check memory threshold and trigger rebuild if needed
std::string value;
if (getValue({"plugin_vm_memory"}, &value)) {
// The value is stored as binary uint64_t, convert to string for logging
if (value.size() == sizeof(uint64_t)) {
uint64_t memory_size;
memcpy(&memory_size, value.data(), sizeof(uint64_t));
LOG_DEBUG(absl::StrCat("vm memory size is ", memory_size));
if (memory_size >= MEMORY_THRESHOLD_BYTES) {
LOG_INFO(absl::StrCat("Memory threshold reached (", memory_size, " >= ",
MEMORY_THRESHOLD_BYTES, "), triggering rebuild"));
setFilterState("wasm_need_rebuild", "true");
}
} else {
LOG_ERROR("invalid memory size format");
}
} else {
LOG_ERROR("get vm memory size failed");
}
if (!Wasm::Common::Http::hasRequestBody()) {
return FilterHeadersStatus::Continue;
}
@@ -157,7 +188,7 @@ FilterHeadersStatus PluginRootContext::onHeader(
auto content_type_value = content_type_ptr->view();
LOG_DEBUG(absl::StrCat("Content-Type: ", content_type_value));
if (absl::StrContains(content_type_value,
Wasm::Common::Http::ContentTypeValues::Json)) {
Wasm::Common::Http::ContentTypeValues::Json)) {
ctx.mode_ = MODE_JSON;
LOG_DEBUG("Enable JSON mode.");
removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
@@ -165,12 +196,15 @@ FilterHeadersStatus PluginRootContext::onHeader(
LOG_INFO(absl::StrCat("SetRequestBodyBufferLimit: ", DefaultMaxBodyBytes));
return FilterHeadersStatus::StopIteration;
}
if (absl::StrContains(content_type_value,
Wasm::Common::Http::ContentTypeValues::MultipartFormData)) {
if (absl::StrContains(
content_type_value,
Wasm::Common::Http::ContentTypeValues::MultipartFormData)) {
// Get the boundary from the content type
auto boundary_start = content_type_value.find("boundary=");
if (boundary_start == std::string::npos) {
LOG_WARN(absl::StrCat("No boundary found in a multipart/form-data content-type: ", content_type_value));
LOG_WARN(absl::StrCat(
"No boundary found in a multipart/form-data content-type: ",
content_type_value));
return FilterHeadersStatus::Continue;
}
boundary_start += 9;
@@ -181,21 +215,25 @@ FilterHeadersStatus PluginRootContext::onHeader(
auto boundary_length = boundary_end - boundary_start;
if (boundary_length < 1 || boundary_length > 70) {
// See https://www.w3.org/Protocols/rfc1341/7_2_Multipart.html
LOG_WARN(absl::StrCat("Invalid boundary value in a multipart/form-data content-type: ", content_type_value));
LOG_WARN(absl::StrCat(
"Invalid boundary value in a multipart/form-data content-type: ",
content_type_value));
return FilterHeadersStatus::Continue;
}
auto boundary_value = content_type_value.substr(boundary_start, boundary_end - boundary_start);
auto boundary_value = content_type_value.substr(
boundary_start, boundary_end - boundary_start);
ctx.mode_ = MODE_MULTIPART;
ctx.boundary_ = boundary_value;
LOG_DEBUG(absl::StrCat("Enable multipart/form-data mode. Boundary=", boundary_value));
LOG_DEBUG(absl::StrCat("Enable multipart/form-data mode. Boundary=",
boundary_value));
removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
return FilterHeadersStatus::StopIteration;
}
return FilterHeadersStatus::Continue;
}
FilterDataStatus PluginRootContext::onJsonBody(const ModelRouterConfigRule& rule,
std::string_view body) {
FilterDataStatus PluginRootContext::onJsonBody(
const ModelRouterConfigRule& rule, std::string_view body) {
const auto& model_key = rule.model_key_;
const auto& add_provider_header = rule.add_provider_header_;
const auto& model_to_header = rule.model_to_header_;
@@ -231,18 +269,18 @@ FilterDataStatus PluginRootContext::onJsonBody(const ModelRouterConfigRule& rule
}
FilterDataStatus PluginRootContext::onMultipartBody(
PluginContext& ctx,
const ModelRouterConfigRule& rule,
WasmDataPtr& body,
PluginContext& ctx, const ModelRouterConfigRule& rule, WasmDataPtr& body,
bool end_stream) {
const auto& add_provider_header = rule.add_provider_header_;
const auto& model_to_header = rule.model_to_header_;
const auto boundary = ctx.boundary_;
const auto body_view = body->view();
const auto model_param_header = absl::StrCat("Content-Disposition: form-data; name=\"", rule.model_key_, "\"");
const auto model_param_header = absl::StrCat(
"Content-Disposition: form-data; name=\"", rule.model_key_, "\"");
for (size_t pos = 0; (pos = body_view.find(boundary, pos)) != std::string_view::npos;) {
for (size_t pos = 0;
(pos = body_view.find(boundary, pos)) != std::string_view::npos;) {
LOG_DEBUG(absl::StrCat("Found boundary at ", pos));
pos += boundary.length();
size_t end_pos = body_view.find(boundary, pos);
@@ -264,7 +302,7 @@ FilterDataStatus PluginRootContext::onMultipartBody(
LOG_DEBUG("No value start found in part");
break;
}
value_start += 4; // Skip the "\r\n\r\n"
value_start += 4; // Skip the "\r\n\r\n"
// model parameter should be only one line
size_t value_end = part.find(CRLF, value_start);
if (value_end == std::string_view::npos) {
@@ -283,8 +321,12 @@ FilterDataStatus PluginRootContext::onMultipartBody(
const auto& model = model_value.substr(pos + 1);
replaceRequestHeader(add_provider_header, provider);
size_t new_size = 0;
auto new_buffer_data = absl::StrCat(body_view.substr(0, part_pos + value_start), model, body_view.substr(part_pos + value_end));
auto result = setBuffer(WasmBufferType::HttpRequestBody, 0, std::numeric_limits<size_t>::max(), new_buffer_data, &new_size);
auto new_buffer_data =
absl::StrCat(body_view.substr(0, part_pos + value_start), model,
body_view.substr(part_pos + value_end));
auto result = setBuffer(WasmBufferType::HttpRequestBody, 0,
std::numeric_limits<size_t>::max(),
new_buffer_data, &new_size);
LOG_DEBUG(absl::StrCat("model route to provider:", provider,
", model:", model));
LOG_DEBUG(absl::StrCat("result=", result, " new_size=", new_size));
@@ -294,7 +336,8 @@ FilterDataStatus PluginRootContext::onMultipartBody(
}
}
// We are done now. We can stop processing the body.
LOG_DEBUG(absl::StrCat("Done processing multipart body after caching ", body_view.length() , " bytes."));
LOG_DEBUG(absl::StrCat("Done processing multipart body after caching ",
body_view.length(), " bytes."));
ctx.mode_ = MODE_BYPASS;
return FilterDataStatus::Continue;
}
@@ -324,8 +367,7 @@ FilterDataStatus PluginContext::onRequestBody(size_t body_size,
auto* rootCtx = rootContext();
body_total_size_ += body_size;
switch (mode_) {
case MODE_JSON:
{
case MODE_JSON: {
if (!end_stream) {
return FilterDataStatus::StopIterationAndBuffer;
}
@@ -333,8 +375,7 @@ FilterDataStatus PluginContext::onRequestBody(size_t body_size,
getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
return rootCtx->onJsonBody(*config_, body->view());
}
case MODE_MULTIPART:
{
case MODE_MULTIPART: {
auto body =
getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
return rootCtx->onMultipartBody(*this, *config_, body, end_stream);