mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 20:57:32 +08:00
feat: enhance model mapper and router with rebuild triggers and path extensions (#3218)
This commit is contained in:
1
plugins/wasm-cpp/.clang-format
Normal file
1
plugins/wasm-cpp/.clang-format
Normal file
@@ -0,0 +1 @@
|
|||||||
|
BasedOnStyle: Google
|
||||||
@@ -135,8 +135,40 @@ bool PluginRootContext::configure(size_t configuration_size) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PluginRootContext::incrementRequestCount() {
|
||||||
|
request_count_++;
|
||||||
|
if (request_count_ >= REBUILD_THRESHOLD) {
|
||||||
|
LOG_DEBUG("Request count reached threshold, triggering rebuild");
|
||||||
|
setFilterState("wasm_need_rebuild", "true");
|
||||||
|
request_count_ = 0; // Reset counter after setting rebuild flag
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
FilterHeadersStatus PluginRootContext::onHeader(
|
FilterHeadersStatus PluginRootContext::onHeader(
|
||||||
const ModelMapperConfigRule& rule) {
|
const ModelMapperConfigRule& rule) {
|
||||||
|
// Increment request count and check for rebuild
|
||||||
|
incrementRequestCount();
|
||||||
|
|
||||||
|
// Check memory threshold and trigger rebuild if needed
|
||||||
|
std::string value;
|
||||||
|
if (getValue({"plugin_vm_memory"}, &value)) {
|
||||||
|
// The value is stored as binary uint64_t, convert to string for logging
|
||||||
|
if (value.size() == sizeof(uint64_t)) {
|
||||||
|
uint64_t memory_size;
|
||||||
|
memcpy(&memory_size, value.data(), sizeof(uint64_t));
|
||||||
|
LOG_DEBUG(absl::StrCat("vm memory size is ", memory_size));
|
||||||
|
if (memory_size >= MEMORY_THRESHOLD_BYTES) {
|
||||||
|
LOG_INFO(absl::StrCat("Memory threshold reached (", memory_size, " >= ",
|
||||||
|
MEMORY_THRESHOLD_BYTES, "), triggering rebuild"));
|
||||||
|
setFilterState("wasm_need_rebuild", "true");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
LOG_ERROR("invalid memory size format");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
LOG_ERROR("get vm memory size failed");
|
||||||
|
}
|
||||||
|
|
||||||
if (!Wasm::Common::Http::hasRequestBody()) {
|
if (!Wasm::Common::Http::hasRequestBody()) {
|
||||||
return FilterHeadersStatus::Continue;
|
return FilterHeadersStatus::Continue;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,9 +42,10 @@ struct ModelMapperConfigRule {
|
|||||||
std::vector<std::pair<std::string, std::string>> prefix_model_mapping_;
|
std::vector<std::pair<std::string, std::string>> prefix_model_mapping_;
|
||||||
std::string default_model_mapping_;
|
std::string default_model_mapping_;
|
||||||
std::vector<std::string> enable_on_path_suffix_ = {
|
std::vector<std::string> enable_on_path_suffix_ = {
|
||||||
"/completions", "/embeddings", "/images/generations",
|
"/completions", "/embeddings", "/images/generations",
|
||||||
"/audio/speech", "/fine_tuning/jobs", "/moderations",
|
"/audio/speech", "/fine_tuning/jobs", "/moderations",
|
||||||
"/image-synthesis", "/video-synthesis"};
|
"/image-synthesis", "/video-synthesis", "/rerank",
|
||||||
|
"/messages"};
|
||||||
};
|
};
|
||||||
|
|
||||||
// PluginRootContext is the root context for all streams processed by the
|
// PluginRootContext is the root context for all streams processed by the
|
||||||
@@ -60,9 +61,13 @@ class PluginRootContext : public RootContext,
|
|||||||
FilterHeadersStatus onHeader(const ModelMapperConfigRule&);
|
FilterHeadersStatus onHeader(const ModelMapperConfigRule&);
|
||||||
FilterDataStatus onBody(const ModelMapperConfigRule&, std::string_view);
|
FilterDataStatus onBody(const ModelMapperConfigRule&, std::string_view);
|
||||||
bool configure(size_t);
|
bool configure(size_t);
|
||||||
|
void incrementRequestCount();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool parsePluginConfig(const json&, ModelMapperConfigRule&) override;
|
bool parsePluginConfig(const json&, ModelMapperConfigRule&) override;
|
||||||
|
uint64_t request_count_ = 0;
|
||||||
|
static constexpr uint64_t REBUILD_THRESHOLD = 1000;
|
||||||
|
static constexpr size_t MEMORY_THRESHOLD_BYTES = 200 * 1024 * 1024;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Per-stream context.
|
// Per-stream context.
|
||||||
|
|||||||
@@ -123,9 +123,40 @@ bool PluginRootContext::configure(size_t configuration_size) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PluginRootContext::incrementRequestCount() {
|
||||||
|
request_count_++;
|
||||||
|
if (request_count_ >= REBUILD_THRESHOLD) {
|
||||||
|
LOG_DEBUG("Request count reached threshold, triggering rebuild");
|
||||||
|
setFilterState("wasm_need_rebuild", "true");
|
||||||
|
request_count_ = 0; // Reset counter after setting rebuild flag
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
FilterHeadersStatus PluginRootContext::onHeader(
|
FilterHeadersStatus PluginRootContext::onHeader(
|
||||||
PluginContext& ctx,
|
PluginContext& ctx, const ModelRouterConfigRule& rule) {
|
||||||
const ModelRouterConfigRule& rule) {
|
// Increment request count and check for rebuild
|
||||||
|
incrementRequestCount();
|
||||||
|
|
||||||
|
// Check memory threshold and trigger rebuild if needed
|
||||||
|
std::string value;
|
||||||
|
if (getValue({"plugin_vm_memory"}, &value)) {
|
||||||
|
// The value is stored as binary uint64_t, convert to string for logging
|
||||||
|
if (value.size() == sizeof(uint64_t)) {
|
||||||
|
uint64_t memory_size;
|
||||||
|
memcpy(&memory_size, value.data(), sizeof(uint64_t));
|
||||||
|
LOG_DEBUG(absl::StrCat("vm memory size is ", memory_size));
|
||||||
|
if (memory_size >= MEMORY_THRESHOLD_BYTES) {
|
||||||
|
LOG_INFO(absl::StrCat("Memory threshold reached (", memory_size, " >= ",
|
||||||
|
MEMORY_THRESHOLD_BYTES, "), triggering rebuild"));
|
||||||
|
setFilterState("wasm_need_rebuild", "true");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
LOG_ERROR("invalid memory size format");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
LOG_ERROR("get vm memory size failed");
|
||||||
|
}
|
||||||
|
|
||||||
if (!Wasm::Common::Http::hasRequestBody()) {
|
if (!Wasm::Common::Http::hasRequestBody()) {
|
||||||
return FilterHeadersStatus::Continue;
|
return FilterHeadersStatus::Continue;
|
||||||
}
|
}
|
||||||
@@ -157,7 +188,7 @@ FilterHeadersStatus PluginRootContext::onHeader(
|
|||||||
auto content_type_value = content_type_ptr->view();
|
auto content_type_value = content_type_ptr->view();
|
||||||
LOG_DEBUG(absl::StrCat("Content-Type: ", content_type_value));
|
LOG_DEBUG(absl::StrCat("Content-Type: ", content_type_value));
|
||||||
if (absl::StrContains(content_type_value,
|
if (absl::StrContains(content_type_value,
|
||||||
Wasm::Common::Http::ContentTypeValues::Json)) {
|
Wasm::Common::Http::ContentTypeValues::Json)) {
|
||||||
ctx.mode_ = MODE_JSON;
|
ctx.mode_ = MODE_JSON;
|
||||||
LOG_DEBUG("Enable JSON mode.");
|
LOG_DEBUG("Enable JSON mode.");
|
||||||
removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
|
removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
|
||||||
@@ -165,12 +196,15 @@ FilterHeadersStatus PluginRootContext::onHeader(
|
|||||||
LOG_INFO(absl::StrCat("SetRequestBodyBufferLimit: ", DefaultMaxBodyBytes));
|
LOG_INFO(absl::StrCat("SetRequestBodyBufferLimit: ", DefaultMaxBodyBytes));
|
||||||
return FilterHeadersStatus::StopIteration;
|
return FilterHeadersStatus::StopIteration;
|
||||||
}
|
}
|
||||||
if (absl::StrContains(content_type_value,
|
if (absl::StrContains(
|
||||||
Wasm::Common::Http::ContentTypeValues::MultipartFormData)) {
|
content_type_value,
|
||||||
|
Wasm::Common::Http::ContentTypeValues::MultipartFormData)) {
|
||||||
// Get the boundary from the content type
|
// Get the boundary from the content type
|
||||||
auto boundary_start = content_type_value.find("boundary=");
|
auto boundary_start = content_type_value.find("boundary=");
|
||||||
if (boundary_start == std::string::npos) {
|
if (boundary_start == std::string::npos) {
|
||||||
LOG_WARN(absl::StrCat("No boundary found in a multipart/form-data content-type: ", content_type_value));
|
LOG_WARN(absl::StrCat(
|
||||||
|
"No boundary found in a multipart/form-data content-type: ",
|
||||||
|
content_type_value));
|
||||||
return FilterHeadersStatus::Continue;
|
return FilterHeadersStatus::Continue;
|
||||||
}
|
}
|
||||||
boundary_start += 9;
|
boundary_start += 9;
|
||||||
@@ -181,21 +215,25 @@ FilterHeadersStatus PluginRootContext::onHeader(
|
|||||||
auto boundary_length = boundary_end - boundary_start;
|
auto boundary_length = boundary_end - boundary_start;
|
||||||
if (boundary_length < 1 || boundary_length > 70) {
|
if (boundary_length < 1 || boundary_length > 70) {
|
||||||
// See https://www.w3.org/Protocols/rfc1341/7_2_Multipart.html
|
// See https://www.w3.org/Protocols/rfc1341/7_2_Multipart.html
|
||||||
LOG_WARN(absl::StrCat("Invalid boundary value in a multipart/form-data content-type: ", content_type_value));
|
LOG_WARN(absl::StrCat(
|
||||||
|
"Invalid boundary value in a multipart/form-data content-type: ",
|
||||||
|
content_type_value));
|
||||||
return FilterHeadersStatus::Continue;
|
return FilterHeadersStatus::Continue;
|
||||||
}
|
}
|
||||||
auto boundary_value = content_type_value.substr(boundary_start, boundary_end - boundary_start);
|
auto boundary_value = content_type_value.substr(
|
||||||
|
boundary_start, boundary_end - boundary_start);
|
||||||
ctx.mode_ = MODE_MULTIPART;
|
ctx.mode_ = MODE_MULTIPART;
|
||||||
ctx.boundary_ = boundary_value;
|
ctx.boundary_ = boundary_value;
|
||||||
LOG_DEBUG(absl::StrCat("Enable multipart/form-data mode. Boundary=", boundary_value));
|
LOG_DEBUG(absl::StrCat("Enable multipart/form-data mode. Boundary=",
|
||||||
|
boundary_value));
|
||||||
removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
|
removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
|
||||||
return FilterHeadersStatus::StopIteration;
|
return FilterHeadersStatus::StopIteration;
|
||||||
}
|
}
|
||||||
return FilterHeadersStatus::Continue;
|
return FilterHeadersStatus::Continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
FilterDataStatus PluginRootContext::onJsonBody(const ModelRouterConfigRule& rule,
|
FilterDataStatus PluginRootContext::onJsonBody(
|
||||||
std::string_view body) {
|
const ModelRouterConfigRule& rule, std::string_view body) {
|
||||||
const auto& model_key = rule.model_key_;
|
const auto& model_key = rule.model_key_;
|
||||||
const auto& add_provider_header = rule.add_provider_header_;
|
const auto& add_provider_header = rule.add_provider_header_;
|
||||||
const auto& model_to_header = rule.model_to_header_;
|
const auto& model_to_header = rule.model_to_header_;
|
||||||
@@ -231,18 +269,18 @@ FilterDataStatus PluginRootContext::onJsonBody(const ModelRouterConfigRule& rule
|
|||||||
}
|
}
|
||||||
|
|
||||||
FilterDataStatus PluginRootContext::onMultipartBody(
|
FilterDataStatus PluginRootContext::onMultipartBody(
|
||||||
PluginContext& ctx,
|
PluginContext& ctx, const ModelRouterConfigRule& rule, WasmDataPtr& body,
|
||||||
const ModelRouterConfigRule& rule,
|
|
||||||
WasmDataPtr& body,
|
|
||||||
bool end_stream) {
|
bool end_stream) {
|
||||||
const auto& add_provider_header = rule.add_provider_header_;
|
const auto& add_provider_header = rule.add_provider_header_;
|
||||||
const auto& model_to_header = rule.model_to_header_;
|
const auto& model_to_header = rule.model_to_header_;
|
||||||
|
|
||||||
const auto boundary = ctx.boundary_;
|
const auto boundary = ctx.boundary_;
|
||||||
const auto body_view = body->view();
|
const auto body_view = body->view();
|
||||||
const auto model_param_header = absl::StrCat("Content-Disposition: form-data; name=\"", rule.model_key_, "\"");
|
const auto model_param_header = absl::StrCat(
|
||||||
|
"Content-Disposition: form-data; name=\"", rule.model_key_, "\"");
|
||||||
|
|
||||||
for (size_t pos = 0; (pos = body_view.find(boundary, pos)) != std::string_view::npos;) {
|
for (size_t pos = 0;
|
||||||
|
(pos = body_view.find(boundary, pos)) != std::string_view::npos;) {
|
||||||
LOG_DEBUG(absl::StrCat("Found boundary at ", pos));
|
LOG_DEBUG(absl::StrCat("Found boundary at ", pos));
|
||||||
pos += boundary.length();
|
pos += boundary.length();
|
||||||
size_t end_pos = body_view.find(boundary, pos);
|
size_t end_pos = body_view.find(boundary, pos);
|
||||||
@@ -264,7 +302,7 @@ FilterDataStatus PluginRootContext::onMultipartBody(
|
|||||||
LOG_DEBUG("No value start found in part");
|
LOG_DEBUG("No value start found in part");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
value_start += 4; // Skip the "\r\n\r\n"
|
value_start += 4; // Skip the "\r\n\r\n"
|
||||||
// model parameter should be only one line
|
// model parameter should be only one line
|
||||||
size_t value_end = part.find(CRLF, value_start);
|
size_t value_end = part.find(CRLF, value_start);
|
||||||
if (value_end == std::string_view::npos) {
|
if (value_end == std::string_view::npos) {
|
||||||
@@ -283,8 +321,12 @@ FilterDataStatus PluginRootContext::onMultipartBody(
|
|||||||
const auto& model = model_value.substr(pos + 1);
|
const auto& model = model_value.substr(pos + 1);
|
||||||
replaceRequestHeader(add_provider_header, provider);
|
replaceRequestHeader(add_provider_header, provider);
|
||||||
size_t new_size = 0;
|
size_t new_size = 0;
|
||||||
auto new_buffer_data = absl::StrCat(body_view.substr(0, part_pos + value_start), model, body_view.substr(part_pos + value_end));
|
auto new_buffer_data =
|
||||||
auto result = setBuffer(WasmBufferType::HttpRequestBody, 0, std::numeric_limits<size_t>::max(), new_buffer_data, &new_size);
|
absl::StrCat(body_view.substr(0, part_pos + value_start), model,
|
||||||
|
body_view.substr(part_pos + value_end));
|
||||||
|
auto result = setBuffer(WasmBufferType::HttpRequestBody, 0,
|
||||||
|
std::numeric_limits<size_t>::max(),
|
||||||
|
new_buffer_data, &new_size);
|
||||||
LOG_DEBUG(absl::StrCat("model route to provider:", provider,
|
LOG_DEBUG(absl::StrCat("model route to provider:", provider,
|
||||||
", model:", model));
|
", model:", model));
|
||||||
LOG_DEBUG(absl::StrCat("result=", result, " new_size=", new_size));
|
LOG_DEBUG(absl::StrCat("result=", result, " new_size=", new_size));
|
||||||
@@ -294,7 +336,8 @@ FilterDataStatus PluginRootContext::onMultipartBody(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// We are done now. We can stop processing the body.
|
// We are done now. We can stop processing the body.
|
||||||
LOG_DEBUG(absl::StrCat("Done processing multipart body after caching ", body_view.length() , " bytes."));
|
LOG_DEBUG(absl::StrCat("Done processing multipart body after caching ",
|
||||||
|
body_view.length(), " bytes."));
|
||||||
ctx.mode_ = MODE_BYPASS;
|
ctx.mode_ = MODE_BYPASS;
|
||||||
return FilterDataStatus::Continue;
|
return FilterDataStatus::Continue;
|
||||||
}
|
}
|
||||||
@@ -324,8 +367,7 @@ FilterDataStatus PluginContext::onRequestBody(size_t body_size,
|
|||||||
auto* rootCtx = rootContext();
|
auto* rootCtx = rootContext();
|
||||||
body_total_size_ += body_size;
|
body_total_size_ += body_size;
|
||||||
switch (mode_) {
|
switch (mode_) {
|
||||||
case MODE_JSON:
|
case MODE_JSON: {
|
||||||
{
|
|
||||||
if (!end_stream) {
|
if (!end_stream) {
|
||||||
return FilterDataStatus::StopIterationAndBuffer;
|
return FilterDataStatus::StopIterationAndBuffer;
|
||||||
}
|
}
|
||||||
@@ -333,8 +375,7 @@ FilterDataStatus PluginContext::onRequestBody(size_t body_size,
|
|||||||
getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
|
getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
|
||||||
return rootCtx->onJsonBody(*config_, body->view());
|
return rootCtx->onJsonBody(*config_, body->view());
|
||||||
}
|
}
|
||||||
case MODE_MULTIPART:
|
case MODE_MULTIPART: {
|
||||||
{
|
|
||||||
auto body =
|
auto body =
|
||||||
getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
|
getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
|
||||||
return rootCtx->onMultipartBody(*this, *config_, body, end_stream);
|
return rootCtx->onMultipartBody(*this, *config_, body, end_stream);
|
||||||
|
|||||||
@@ -48,9 +48,10 @@ struct ModelRouterConfigRule {
|
|||||||
std::string add_provider_header_;
|
std::string add_provider_header_;
|
||||||
std::string model_to_header_;
|
std::string model_to_header_;
|
||||||
std::vector<std::string> enable_on_path_suffix_ = {
|
std::vector<std::string> enable_on_path_suffix_ = {
|
||||||
"/completions", "/embeddings", "/images/generations",
|
"/completions", "/embeddings", "/images/generations",
|
||||||
"/audio/speech", "/fine_tuning/jobs", "/moderations",
|
"/audio/speech", "/fine_tuning/jobs", "/moderations",
|
||||||
"/image-synthesis", "/video-synthesis"};
|
"/image-synthesis", "/video-synthesis", "/rerank",
|
||||||
|
"/messages"};
|
||||||
};
|
};
|
||||||
|
|
||||||
class PluginContext;
|
class PluginContext;
|
||||||
@@ -65,13 +66,20 @@ class PluginRootContext : public RootContext,
|
|||||||
: RootContext(id, root_id) {}
|
: RootContext(id, root_id) {}
|
||||||
~PluginRootContext() {}
|
~PluginRootContext() {}
|
||||||
bool onConfigure(size_t) override;
|
bool onConfigure(size_t) override;
|
||||||
FilterHeadersStatus onHeader(PluginContext& ctx, const ModelRouterConfigRule&);
|
FilterHeadersStatus onHeader(PluginContext& ctx,
|
||||||
|
const ModelRouterConfigRule&);
|
||||||
FilterDataStatus onJsonBody(const ModelRouterConfigRule&, std::string_view);
|
FilterDataStatus onJsonBody(const ModelRouterConfigRule&, std::string_view);
|
||||||
FilterDataStatus onMultipartBody(PluginContext& ctx, const ModelRouterConfigRule& rule, WasmDataPtr& body, bool end_stream);
|
FilterDataStatus onMultipartBody(PluginContext& ctx,
|
||||||
|
const ModelRouterConfigRule& rule,
|
||||||
|
WasmDataPtr& body, bool end_stream);
|
||||||
bool configure(size_t);
|
bool configure(size_t);
|
||||||
|
void incrementRequestCount();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool parsePluginConfig(const json&, ModelRouterConfigRule&) override;
|
bool parsePluginConfig(const json&, ModelRouterConfigRule&) override;
|
||||||
|
uint64_t request_count_ = 0;
|
||||||
|
static constexpr uint64_t REBUILD_THRESHOLD = 1000;
|
||||||
|
static constexpr size_t MEMORY_THRESHOLD_BYTES = 200 * 1024 * 1024;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Per-stream context.
|
// Per-stream context.
|
||||||
|
|||||||
@@ -157,7 +157,8 @@ TEST_F(ModelRouterTest, RewriteModelAndHeader) {
|
|||||||
body_.set(request_json);
|
body_.set(request_json);
|
||||||
EXPECT_EQ(context_->onRequestHeaders(0, false),
|
EXPECT_EQ(context_->onRequestHeaders(0, false),
|
||||||
FilterHeadersStatus::StopIteration);
|
FilterHeadersStatus::StopIteration);
|
||||||
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
|
EXPECT_EQ(context_->onRequestBody(request_json.length(), true),
|
||||||
|
FilterDataStatus::Continue);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ModelRouterTest, ModelToHeader) {
|
TEST_F(ModelRouterTest, ModelToHeader) {
|
||||||
@@ -183,7 +184,8 @@ TEST_F(ModelRouterTest, ModelToHeader) {
|
|||||||
body_.set(request_json);
|
body_.set(request_json);
|
||||||
EXPECT_EQ(context_->onRequestHeaders(0, false),
|
EXPECT_EQ(context_->onRequestHeaders(0, false),
|
||||||
FilterHeadersStatus::StopIteration);
|
FilterHeadersStatus::StopIteration);
|
||||||
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
|
EXPECT_EQ(context_->onRequestBody(request_json.length(), true),
|
||||||
|
FilterDataStatus::Continue);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ModelRouterTest, IgnorePath) {
|
TEST_F(ModelRouterTest, IgnorePath) {
|
||||||
@@ -210,7 +212,8 @@ TEST_F(ModelRouterTest, IgnorePath) {
|
|||||||
body_.set(request_json);
|
body_.set(request_json);
|
||||||
EXPECT_EQ(context_->onRequestHeaders(0, false),
|
EXPECT_EQ(context_->onRequestHeaders(0, false),
|
||||||
FilterHeadersStatus::Continue);
|
FilterHeadersStatus::Continue);
|
||||||
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
|
EXPECT_EQ(context_->onRequestBody(request_json.length(), true),
|
||||||
|
FilterDataStatus::Continue);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) {
|
TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) {
|
||||||
@@ -244,10 +247,10 @@ TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) {
|
|||||||
route_name_ = "route-a";
|
route_name_ = "route-a";
|
||||||
EXPECT_EQ(context_->onRequestHeaders(0, false),
|
EXPECT_EQ(context_->onRequestHeaders(0, false),
|
||||||
FilterHeadersStatus::StopIteration);
|
FilterHeadersStatus::StopIteration);
|
||||||
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
|
EXPECT_EQ(context_->onRequestBody(request_json.length(), true),
|
||||||
|
FilterDataStatus::Continue);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST_F(ModelRouterTest, RewriteModelAndHeaderMultipartFormData) {
|
TEST_F(ModelRouterTest, RewriteModelAndHeaderMultipartFormData) {
|
||||||
std::string configuration = R"({
|
std::string configuration = R"({
|
||||||
"addProviderHeader": "x-higress-llm-provider"
|
"addProviderHeader": "x-higress-llm-provider"
|
||||||
@@ -257,8 +260,11 @@ TEST_F(ModelRouterTest, RewriteModelAndHeaderMultipartFormData) {
|
|||||||
EXPECT_TRUE(root_context_->configure(configuration.size()));
|
EXPECT_TRUE(root_context_->configure(configuration.size()));
|
||||||
|
|
||||||
path_ = "/v1/chat/completions";
|
path_ = "/v1/chat/completions";
|
||||||
content_type_ = "multipart/form-data; boundary=--------------------------100751621174704322650451";
|
content_type_ =
|
||||||
std::string request_data = std::regex_replace(R"(
|
"multipart/form-data; "
|
||||||
|
"boundary=--------------------------100751621174704322650451";
|
||||||
|
std::string request_data = std::regex_replace(
|
||||||
|
R"(
|
||||||
----------------------------100751621174704322650451
|
----------------------------100751621174704322650451
|
||||||
Content-Disposition: form-data; name="purpose"
|
Content-Disposition: form-data; name="purpose"
|
||||||
|
|
||||||
@@ -274,16 +280,21 @@ Content-Type: application/json
|
|||||||
[
|
[
|
||||||
]
|
]
|
||||||
----------------------------100751621174704322650451--
|
----------------------------100751621174704322650451--
|
||||||
)", std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
|
)",
|
||||||
|
std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
|
||||||
EXPECT_CALL(*mock_context_,
|
EXPECT_CALL(*mock_context_,
|
||||||
setBuffer(testing::_, testing::_, testing::_, testing::_))
|
setBuffer(testing::_, testing::_, testing::_, testing::_))
|
||||||
.WillOnce([&](WasmBufferType, size_t start, size_t length, std::string_view body) {
|
.WillOnce([&](WasmBufferType, size_t start, size_t length,
|
||||||
std::cerr << "===============" << "\n";
|
std::string_view body) {
|
||||||
|
std::cerr << "==============="
|
||||||
|
<< "\n";
|
||||||
std::cerr << body << "\n";
|
std::cerr << body << "\n";
|
||||||
std::cerr << "===============" << "\n";
|
std::cerr << "==============="
|
||||||
|
<< "\n";
|
||||||
EXPECT_EQ(start, 0);
|
EXPECT_EQ(start, 0);
|
||||||
EXPECT_EQ(length, std::numeric_limits<size_t>::max());
|
EXPECT_EQ(length, std::numeric_limits<size_t>::max());
|
||||||
auto expected_body= std::regex_replace(R"(
|
auto expected_body = std::regex_replace(
|
||||||
|
R"(
|
||||||
----------------------------100751621174704322650451
|
----------------------------100751621174704322650451
|
||||||
Content-Disposition: form-data; name="purpose"
|
Content-Disposition: form-data; name="purpose"
|
||||||
|
|
||||||
@@ -292,7 +303,9 @@ batch
|
|||||||
Content-Disposition: form-data; name="model"
|
Content-Disposition: form-data; name="model"
|
||||||
|
|
||||||
qwen-turbo
|
qwen-turbo
|
||||||
)", std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
|
)",
|
||||||
|
std::regex("\n"),
|
||||||
|
"\r\n"); // Multipart data requires CRLF line endings
|
||||||
EXPECT_EQ(body, expected_body);
|
EXPECT_EQ(body, expected_body);
|
||||||
return WasmResult::Ok;
|
return WasmResult::Ok;
|
||||||
});
|
});
|
||||||
@@ -308,42 +321,54 @@ qwen-turbo
|
|||||||
|
|
||||||
auto last_body_size = 0;
|
auto last_body_size = 0;
|
||||||
|
|
||||||
auto body = request_data.substr(0, request_data.find("batch") + 5 + 2 /* batch + CRLF */);
|
auto body = request_data.substr(
|
||||||
|
0, request_data.find("batch") + 5 + 2 /* batch + CRLF */);
|
||||||
body_.set(body);
|
body_.set(body);
|
||||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
|
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||||
|
FilterDataStatus::StopIterationAndBuffer);
|
||||||
last_body_size = body.size();
|
last_body_size = body.size();
|
||||||
|
|
||||||
body = request_data.substr(0, request_data.find("\"model\"") + 5 + 2 + 2 /* "model" + CRLF + CRLF */);
|
body = request_data.substr(0, request_data.find("\"model\"") + 5 + 2 +
|
||||||
|
2 /* "model" + CRLF + CRLF */);
|
||||||
body_.set(body);
|
body_.set(body);
|
||||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
|
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||||
|
FilterDataStatus::StopIterationAndBuffer);
|
||||||
last_body_size = body.size();
|
last_body_size = body.size();
|
||||||
|
|
||||||
body = request_data.substr(0, request_data.find("qwen") + 4 /* "qwen" */);
|
body = request_data.substr(0, request_data.find("qwen") + 4 /* "qwen" */);
|
||||||
body_.set(body);
|
body_.set(body);
|
||||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
|
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||||
|
FilterDataStatus::StopIterationAndBuffer);
|
||||||
last_body_size = body.size();
|
last_body_size = body.size();
|
||||||
|
|
||||||
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 /* "qwen-turbo" */);
|
body = request_data.substr(
|
||||||
|
0, request_data.find("qwen-turbo") + 10 /* "qwen-turbo" */);
|
||||||
body_.set(body);
|
body_.set(body);
|
||||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
|
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||||
|
FilterDataStatus::StopIterationAndBuffer);
|
||||||
last_body_size = body.size();
|
last_body_size = body.size();
|
||||||
|
|
||||||
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 + 2 /* "qwen-turbo" + CRLF */);
|
body = request_data.substr(
|
||||||
|
0, request_data.find("qwen-turbo") + 10 + 2 /* "qwen-turbo" + CRLF */);
|
||||||
body_.set(body);
|
body_.set(body);
|
||||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
|
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||||
|
FilterDataStatus::Continue);
|
||||||
last_body_size = body.size();
|
last_body_size = body.size();
|
||||||
|
|
||||||
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 + 2 + 50 /* "qwen-turbo" + CRLF + boundary */);
|
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 + 2 +
|
||||||
|
50 /* "qwen-turbo" + CRLF + boundary */);
|
||||||
body_.set(body);
|
body_.set(body);
|
||||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
|
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||||
|
FilterDataStatus::Continue);
|
||||||
last_body_size = body.size();
|
last_body_size = body.size();
|
||||||
|
|
||||||
body_.set(request_data);
|
body_.set(request_data);
|
||||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, true), FilterDataStatus::Continue);
|
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, true),
|
||||||
|
FilterDataStatus::Continue);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ModelRouterTest, ModelToHeaderMultipartFormData) {
|
TEST_F(ModelRouterTest, ModelToHeaderMultipartFormData) {
|
||||||
std::string configuration = R"(
|
std::string configuration = R"(
|
||||||
{
|
{
|
||||||
"modelToHeader": "x-higress-llm-model"
|
"modelToHeader": "x-higress-llm-model"
|
||||||
})";
|
})";
|
||||||
@@ -352,8 +377,11 @@ TEST_F(ModelRouterTest, ModelToHeaderMultipartFormData) {
|
|||||||
EXPECT_TRUE(root_context_->configure(configuration.size()));
|
EXPECT_TRUE(root_context_->configure(configuration.size()));
|
||||||
|
|
||||||
path_ = "/v1/chat/completions";
|
path_ = "/v1/chat/completions";
|
||||||
content_type_ = "multipart/form-data; boundary=--------------------------100751621174704322650451";
|
content_type_ =
|
||||||
std::string request_data = std::regex_replace(R"(
|
"multipart/form-data; "
|
||||||
|
"boundary=--------------------------100751621174704322650451";
|
||||||
|
std::string request_data = std::regex_replace(
|
||||||
|
R"(
|
||||||
----------------------------100751621174704322650451
|
----------------------------100751621174704322650451
|
||||||
Content-Disposition: form-data; name="purpose"
|
Content-Disposition: form-data; name="purpose"
|
||||||
|
|
||||||
@@ -369,7 +397,8 @@ Content-Type: application/json
|
|||||||
[
|
[
|
||||||
]
|
]
|
||||||
----------------------------100751621174704322650451--
|
----------------------------100751621174704322650451--
|
||||||
)", std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
|
)",
|
||||||
|
std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
|
||||||
EXPECT_CALL(*mock_context_,
|
EXPECT_CALL(*mock_context_,
|
||||||
setBuffer(testing::_, testing::_, testing::_, testing::_))
|
setBuffer(testing::_, testing::_, testing::_, testing::_))
|
||||||
.Times(0);
|
.Times(0);
|
||||||
@@ -384,38 +413,50 @@ Content-Type: application/json
|
|||||||
|
|
||||||
auto last_body_size = 0;
|
auto last_body_size = 0;
|
||||||
|
|
||||||
auto body = request_data.substr(0, request_data.find("batch") + 5 + 2 /* batch + CRLF */);
|
auto body = request_data.substr(
|
||||||
|
0, request_data.find("batch") + 5 + 2 /* batch + CRLF */);
|
||||||
body_.set(body);
|
body_.set(body);
|
||||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
|
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||||
|
FilterDataStatus::StopIterationAndBuffer);
|
||||||
last_body_size = body.size();
|
last_body_size = body.size();
|
||||||
|
|
||||||
body = request_data.substr(0, request_data.find("\"model\"") + 5 + 2 + 2 /* "model" + CRLF + CRLF */);
|
body = request_data.substr(0, request_data.find("\"model\"") + 5 + 2 +
|
||||||
|
2 /* "model" + CRLF + CRLF */);
|
||||||
body_.set(body);
|
body_.set(body);
|
||||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
|
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||||
|
FilterDataStatus::StopIterationAndBuffer);
|
||||||
last_body_size = body.size();
|
last_body_size = body.size();
|
||||||
|
|
||||||
body = request_data.substr(0, request_data.find("qwen") + 4 /* "qwen" */);
|
body = request_data.substr(0, request_data.find("qwen") + 4 /* "qwen" */);
|
||||||
body_.set(body);
|
body_.set(body);
|
||||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
|
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||||
|
FilterDataStatus::StopIterationAndBuffer);
|
||||||
last_body_size = body.size();
|
last_body_size = body.size();
|
||||||
|
|
||||||
body = request_data.substr(0, request_data.find("qwen-max") + 8 /* "qwen-max" */);
|
body = request_data.substr(
|
||||||
|
0, request_data.find("qwen-max") + 8 /* "qwen-max" */);
|
||||||
body_.set(body);
|
body_.set(body);
|
||||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
|
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||||
|
FilterDataStatus::StopIterationAndBuffer);
|
||||||
last_body_size = body.size();
|
last_body_size = body.size();
|
||||||
|
|
||||||
body = request_data.substr(0, request_data.find("qwen-max") + 8 + 2 /* "qwen-max" + CRLF */);
|
body = request_data.substr(
|
||||||
|
0, request_data.find("qwen-max") + 8 + 2 /* "qwen-max" + CRLF */);
|
||||||
body_.set(body);
|
body_.set(body);
|
||||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
|
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||||
|
FilterDataStatus::Continue);
|
||||||
last_body_size = body.size();
|
last_body_size = body.size();
|
||||||
|
|
||||||
body = request_data.substr(0, request_data.find("qwen-max") + 8 + 2 + 50 /* "qwen-max" + CRLF */);
|
body = request_data.substr(
|
||||||
|
0, request_data.find("qwen-max") + 8 + 2 + 50 /* "qwen-max" + CRLF */);
|
||||||
body_.set(body);
|
body_.set(body);
|
||||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
|
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||||
|
FilterDataStatus::Continue);
|
||||||
last_body_size = body.size();
|
last_body_size = body.size();
|
||||||
|
|
||||||
body_.set(request_data);
|
body_.set(request_data);
|
||||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, true), FilterDataStatus::Continue);
|
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, true),
|
||||||
|
FilterDataStatus::Continue);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace model_router
|
} // namespace model_router
|
||||||
|
|||||||
Reference in New Issue
Block a user