add model-mapper plugin & optimize model-router plugin (#1538)

This commit is contained in:
澄潭
2024-11-22 22:24:42 +08:00
committed by GitHub
parent 96575b982e
commit e68a8ac25f
11 changed files with 1067 additions and 73 deletions

View File

@@ -0,0 +1,70 @@
# Copyright (c) 2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@proxy_wasm_cpp_sdk//bazel:defs.bzl", "proxy_wasm_cc_binary")
load("//bazel:wasm.bzl", "declare_wasm_image_targets")
proxy_wasm_cc_binary(
name = "model_mapper.wasm",
srcs = [
"plugin.cc",
"plugin.h",
],
deps = [
"@proxy_wasm_cpp_sdk//:proxy_wasm_intrinsics_higress",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"//common:json_util",
"//common:http_util",
"//common:rule_util",
],
)
cc_library(
name = "model_mapper_lib",
srcs = [
"plugin.cc",
],
hdrs = [
"plugin.h",
],
copts = ["-DNULL_PLUGIN"],
deps = [
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"//common:json_util",
"@proxy_wasm_cpp_host//:lib",
"//common:http_util_nullvm",
"//common:rule_util_nullvm",
],
)
cc_test(
name = "model_mapper_test",
srcs = [
"plugin_test.cc",
],
copts = ["-DNULL_PLUGIN"],
deps = [
":model_mapper_lib",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
"@proxy_wasm_cpp_host//:lib",
],
)
declare_wasm_image_targets(
name = "model_mapper",
wasm_file = ":model_mapper.wasm",
)

View File

@@ -0,0 +1,63 @@
## 功能说明
`model-mapper`插件实现了基于LLM协议中的model参数路由的功能
## 配置字段
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
| `modelKey` | string | 选填 | model | 请求body中model参数的位置 |
| `modelMapping` | map of string | 选填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
| `enableOnPathSuffix` | array of string | 选填 | ["/v1/chat/completions"] | 只对这些特定路径后缀的请求生效 ## 运行属性
插件执行阶段:认证阶段
插件执行优先级800
|
## 效果说明
如下配置
```yaml
modelMapping:
'gpt-4-*': "qwen-max"
'gpt-4o': "qwen-vl-plus"
'*': "qwen-turbo"
```
开启后,`gpt-4-` 开头的模型参数会被改写为 `qwen-max`, `gpt-4o` 会被改写为 `qwen-vl-plus`,其他所有模型会被改写为 `qwen-turbo`
例如原本的请求是:
```json
{
"model": "gpt-4o",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "higress项目主仓库的github地址是什么"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
经过这个插件后,原始的 LLM 请求体将被改成:
```json
{
"model": "qwen-vl-plus",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "higress项目主仓库的github地址是什么"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```

View File

@@ -0,0 +1,65 @@
## Function Description
The `model-mapper` plugin implements the functionality of routing based on the model parameter in the LLM protocol.
## Configuration Fields
| Name | Data Type | Filling Requirement | Default Value | Description |
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
| `modelKey` | string | Optional | model | The location of the model parameter in the request body. |
| `modelMapping` | map of string | Optional | - | AI model mapping table, used to map the model names in the request to the model names supported by the service provider.<br/>1. Supports prefix matching. For example, use "gpt-3-*" to match all models whose names start with “gpt-3-”;<br/>2. Supports using "*" as the key to configure a generic fallback mapping relationship;<br/>3. If the target name in the mapping is an empty string "", it means to keep the original model name. |
| `enableOnPathSuffix` | array of string | Optional | ["/v1/chat/completions"] | Only applies to requests with these specific path suffixes. |
## Runtime Properties
Plugin execution phase: Authentication phase
Plugin execution priority: 800
## Effect Description
With the following configuration:
```yaml
modelMapping:
'gpt-4-*': "qwen-max"
'gpt-4o': "qwen-vl-plus"
'*': "qwen-turbo"
```
After enabling, model parameters starting with `gpt-4-` will be rewritten to `qwen-max`, `gpt-4o` will be rewritten to `qwen-vl-plus`, and all other models will be rewritten to `qwen-turbo`.
For example, if the original request was:
```json
{
"model": "gpt-4o",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address of the main repository for the higress project?"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
After processing by this plugin, the original LLM request body will be modified to:
```json
{
"model": "qwen-vl-plus",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address of the main repository for the higress project?"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```

View File

@@ -0,0 +1,243 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/model_mapper/plugin.h"
#include <array>
#include <limits>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "common/http_util.h"
#include "common/json_util.h"
using ::nlohmann::json;
using ::Wasm::Common::JsonArrayIterate;
using ::Wasm::Common::JsonGetField;
using ::Wasm::Common::JsonObjectIterate;
using ::Wasm::Common::JsonValueAs;
#ifdef NULL_PLUGIN
namespace proxy_wasm {
namespace null_plugin {
namespace model_mapper {
PROXY_WASM_NULL_PLUGIN_REGISTRY
#endif
static RegisterContextFactory register_ModelMapper(
CONTEXT_FACTORY(PluginContext), ROOT_FACTORY(PluginRootContext));
namespace {
constexpr std::string_view SetDecoderBufferLimitKey =
"SetRequestBodyBufferLimit";
constexpr std::string_view DefaultMaxBodyBytes = "10485760";
} // namespace
bool PluginRootContext::parsePluginConfig(const json& configuration,
ModelMapperConfigRule& rule) {
if (auto it = configuration.find("modelKey"); it != configuration.end()) {
if (it->is_string()) {
rule.model_key_ = it->get<std::string>();
} else {
LOG_ERROR("Invalid type for modelKey. Expected string.");
return false;
}
}
if (auto it = configuration.find("modelMapping"); it != configuration.end()) {
if (!it->is_object()) {
LOG_ERROR("Invalid type for modelMapping. Expected object.");
return false;
}
auto model_mapping = it->get<Wasm::Common::JsonObject>();
if (!JsonObjectIterate(model_mapping, [&](std::string key) -> bool {
auto model_json = model_mapping.find(key);
if (!model_json->is_string()) {
LOG_ERROR(
"Invalid type for item in modelMapping. Expected string.");
return false;
}
if (key == "*") {
rule.default_model_mapping_ = model_json->get<std::string>();
return true;
}
if (absl::EndsWith(key, "*")) {
rule.prefix_model_mapping_.emplace_back(
absl::StripSuffix(key, "*"), model_json->get<std::string>());
return true;
}
auto ret = rule.exact_model_mapping_.emplace(
key, model_json->get<std::string>());
if (!ret.second) {
LOG_ERROR("Duplicate key in modelMapping: " + key);
return false;
}
return true;
})) {
return false;
}
}
if (!JsonArrayIterate(
configuration, "enableOnPathSuffix", [&](const json& item) -> bool {
if (item.is_string()) {
rule.enable_on_path_suffix_.emplace_back(item.get<std::string>());
return true;
}
return false;
})) {
LOG_WARN("Invalid type for item in enableOnPathSuffix. Expected string.");
return false;
}
return true;
}
bool PluginRootContext::onConfigure(size_t size) {
// Parse configuration JSON string.
if (size > 0 && !configure(size)) {
LOG_WARN("configuration has errors initialization will not continue.");
return false;
}
return true;
}
bool PluginRootContext::configure(size_t configuration_size) {
auto configuration_data = getBufferBytes(WasmBufferType::PluginConfiguration,
0, configuration_size);
// Parse configuration JSON string.
auto result = ::Wasm::Common::JsonParse(configuration_data->view());
if (!result) {
LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
return false;
}
if (!parseRuleConfig(result.value())) {
LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
return false;
}
return true;
}
FilterHeadersStatus PluginRootContext::onHeader(
const ModelMapperConfigRule& rule) {
if (!Wasm::Common::Http::hasRequestBody()) {
return FilterHeadersStatus::Continue;
}
auto path = getRequestHeader(Wasm::Common::Http::Header::Path)->toString();
auto params_pos = path.find('?');
size_t uri_end;
if (params_pos == std::string::npos) {
uri_end = path.size();
} else {
uri_end = params_pos;
}
bool enable = false;
for (const auto& enable_suffix : rule.enable_on_path_suffix_) {
if (absl::EndsWith({path.c_str(), uri_end}, enable_suffix)) {
enable = true;
break;
}
}
if (!enable) {
return FilterHeadersStatus::Continue;
}
auto content_type_value =
getRequestHeader(Wasm::Common::Http::Header::ContentType);
if (!absl::StrContains(content_type_value->view(),
Wasm::Common::Http::ContentTypeValues::Json)) {
return FilterHeadersStatus::Continue;
}
removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
setFilterState(SetDecoderBufferLimitKey, DefaultMaxBodyBytes);
return FilterHeadersStatus::StopIteration;
}
FilterDataStatus PluginRootContext::onBody(const ModelMapperConfigRule& rule,
std::string_view body) {
const auto& exact_model_mapping = rule.exact_model_mapping_;
const auto& prefix_model_mapping = rule.prefix_model_mapping_;
const auto& default_model_mapping = rule.default_model_mapping_;
const auto& model_key = rule.model_key_;
auto body_json_opt = ::Wasm::Common::JsonParse(body);
if (!body_json_opt) {
LOG_WARN(absl::StrCat("cannot parse body to JSON string: ", body));
return FilterDataStatus::Continue;
}
auto body_json = body_json_opt.value();
std::string old_model;
if (body_json.contains(model_key)) {
old_model = body_json[model_key];
}
std::string model =
default_model_mapping.empty() ? old_model : default_model_mapping;
if (auto it = exact_model_mapping.find(old_model);
it != exact_model_mapping.end()) {
model = it->second;
} else {
for (auto& prefix_model_pair : prefix_model_mapping) {
if (absl::StartsWith(old_model, prefix_model_pair.first)) {
model = prefix_model_pair.second;
break;
}
}
}
if (!model.empty() && model != old_model) {
body_json[model_key] = model;
setBuffer(WasmBufferType::HttpRequestBody, 0,
std::numeric_limits<size_t>::max(), body_json.dump());
LOG_DEBUG(
absl::StrCat("model mapped, before:", old_model, ", after:", model));
}
return FilterDataStatus::Continue;
}
FilterHeadersStatus PluginContext::onRequestHeaders(uint32_t, bool) {
auto* rootCtx = rootContext();
return rootCtx->onHeaders([rootCtx, this](const auto& config) {
auto ret = rootCtx->onHeader(config);
if (ret == FilterHeadersStatus::StopIteration) {
this->config_ = &config;
}
return ret;
});
}
FilterDataStatus PluginContext::onRequestBody(size_t body_size,
bool end_stream) {
if (config_ == nullptr) {
return FilterDataStatus::Continue;
}
body_total_size_ += body_size;
if (!end_stream) {
return FilterDataStatus::StopIterationAndBuffer;
}
auto body =
getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
auto* rootCtx = rootContext();
return rootCtx->onBody(*config_, body->view());
}
#ifdef NULL_PLUGIN
} // namespace model_mapper
} // namespace null_plugin
} // namespace proxy_wasm
#endif

View File

@@ -0,0 +1,87 @@
/*
* Copyright (c) 2022 Alibaba Group Holding Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <assert.h>
#include <string>
#include <unordered_set>
#include "common/route_rule_matcher.h"
#define ASSERT(_X) assert(_X)
#ifndef NULL_PLUGIN
#include "proxy_wasm_intrinsics.h"
#else
#include "include/proxy-wasm/null_plugin.h"
namespace proxy_wasm {
namespace null_plugin {
namespace model_mapper {
#endif
struct ModelMapperConfigRule {
std::string model_key_ = "model";
std::map<std::string, std::string> exact_model_mapping_;
std::vector<std::pair<std::string, std::string>> prefix_model_mapping_;
std::string default_model_mapping_;
std::vector<std::string> enable_on_path_suffix_ = {"/v1/chat/completions"};
};
// PluginRootContext is the root context for all streams processed by the
// thread. It has the same lifetime as the worker thread and acts as target for
// interactions that outlives individual stream, e.g. timer, async calls.
class PluginRootContext : public RootContext,
public RouteRuleMatcher<ModelMapperConfigRule> {
public:
PluginRootContext(uint32_t id, std::string_view root_id)
: RootContext(id, root_id) {}
~PluginRootContext() {}
bool onConfigure(size_t) override;
FilterHeadersStatus onHeader(const ModelMapperConfigRule&);
FilterDataStatus onBody(const ModelMapperConfigRule&, std::string_view);
bool configure(size_t);
private:
bool parsePluginConfig(const json&, ModelMapperConfigRule&) override;
};
// Per-stream context.
class PluginContext : public Context {
public:
explicit PluginContext(uint32_t id, RootContext* root) : Context(id, root) {}
FilterHeadersStatus onRequestHeaders(uint32_t, bool) override;
FilterDataStatus onRequestBody(size_t, bool) override;
private:
inline PluginRootContext* rootContext() {
return dynamic_cast<PluginRootContext*>(this->root());
}
size_t body_total_size_ = 0;
const ModelMapperConfigRule* config_ = nullptr;
};
#ifdef NULL_PLUGIN
} // namespace model_mapper
} // namespace null_plugin
} // namespace proxy_wasm
#endif

View File

@@ -0,0 +1,301 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/model_mapper/plugin.h"
#include <cstddef>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "include/proxy-wasm/context.h"
#include "include/proxy-wasm/null.h"
namespace proxy_wasm {
namespace null_plugin {
namespace model_mapper {
NullPluginRegistry* context_registry_;
RegisterNullVmPluginFactory register_model_mapper_plugin("model_mapper", []() {
return std::make_unique<NullPlugin>(model_mapper::context_registry_);
});
class MockContext : public proxy_wasm::ContextBase {
public:
MockContext(WasmBase* wasm) : ContextBase(wasm) {}
MOCK_METHOD(BufferInterface*, getBuffer, (WasmBufferType));
MOCK_METHOD(WasmResult, log, (uint32_t, std::string_view));
MOCK_METHOD(WasmResult, setBuffer,
(WasmBufferType, size_t, size_t, std::string_view));
MOCK_METHOD(WasmResult, getHeaderMapValue,
(WasmHeaderMapType /* type */, std::string_view /* key */,
std::string_view* /*result */));
MOCK_METHOD(WasmResult, replaceHeaderMapValue,
(WasmHeaderMapType /* type */, std::string_view /* key */,
std::string_view /* value */));
MOCK_METHOD(WasmResult, removeHeaderMapValue,
(WasmHeaderMapType /* type */, std::string_view /* key */));
MOCK_METHOD(WasmResult, addHeaderMapValue,
(WasmHeaderMapType, std::string_view, std::string_view));
MOCK_METHOD(WasmResult, getProperty, (std::string_view, std::string*));
MOCK_METHOD(WasmResult, setProperty, (std::string_view, std::string_view));
};
class ModelMapperTest : public ::testing::Test {
protected:
ModelMapperTest() {
// Initialize test VM
test_vm_ = createNullVm();
wasm_base_ = std::make_unique<WasmBase>(
std::move(test_vm_), "test-vm", "", "",
std::unordered_map<std::string, std::string>{},
AllowedCapabilitiesMap{});
wasm_base_->load("model_mapper");
wasm_base_->initialize();
// Initialize host side context
mock_context_ = std::make_unique<MockContext>(wasm_base_.get());
current_context_ = mock_context_.get();
// Initialize Wasm sandbox context
root_context_ = std::make_unique<PluginRootContext>(0, "");
context_ = std::make_unique<PluginContext>(1, root_context_.get());
ON_CALL(*mock_context_, log(testing::_, testing::_))
.WillByDefault([](uint32_t, std::string_view m) {
std::cerr << m << "\n";
return WasmResult::Ok;
});
ON_CALL(*mock_context_, getBuffer(testing::_))
.WillByDefault([&](WasmBufferType type) {
if (type == WasmBufferType::HttpRequestBody) {
return &body_;
}
return &config_;
});
ON_CALL(*mock_context_, getHeaderMapValue(WasmHeaderMapType::RequestHeaders,
testing::_, testing::_))
.WillByDefault([&](WasmHeaderMapType, std::string_view header,
std::string_view* result) {
if (header == "content-type") {
*result = "application/json";
} else if (header == "content-length") {
*result = "1024";
} else if (header == ":path") {
*result = path_;
}
return WasmResult::Ok;
});
ON_CALL(*mock_context_,
replaceHeaderMapValue(WasmHeaderMapType::RequestHeaders, testing::_,
testing::_))
.WillByDefault([&](WasmHeaderMapType, std::string_view key,
std::string_view value) { return WasmResult::Ok; });
ON_CALL(*mock_context_,
removeHeaderMapValue(WasmHeaderMapType::RequestHeaders, testing::_))
.WillByDefault([&](WasmHeaderMapType, std::string_view key) {
return WasmResult::Ok;
});
ON_CALL(*mock_context_, addHeaderMapValue(WasmHeaderMapType::RequestHeaders,
testing::_, testing::_))
.WillByDefault([&](WasmHeaderMapType, std::string_view header,
std::string_view value) { return WasmResult::Ok; });
ON_CALL(*mock_context_, getProperty(testing::_, testing::_))
.WillByDefault([&](std::string_view path, std::string* result) {
if (absl::StartsWith(path, "route_name")) {
*result = route_name_;
} else if (absl::StartsWith(path, "cluster_name")) {
*result = service_name_;
}
return WasmResult::Ok;
});
ON_CALL(*mock_context_, setProperty(testing::_, testing::_))
.WillByDefault(
[&](std::string_view, std::string_view) { return WasmResult::Ok; });
}
~ModelMapperTest() override {}
std::unique_ptr<WasmBase> wasm_base_;
std::unique_ptr<WasmVm> test_vm_;
std::unique_ptr<MockContext> mock_context_;
std::unique_ptr<PluginRootContext> root_context_;
std::unique_ptr<PluginContext> context_;
std::string route_name_;
std::string service_name_;
std::string path_;
BufferBase body_;
BufferBase config_;
};
TEST_F(ModelMapperTest, ModelMappingTest) {
std::string configuration = R"(
{
"modelMapping": {
"*": "qwen-long",
"gpt-4*": "qwen-max",
"gpt-4o": "qwen-turbo",
"gpt-4o-mini": "qwen-plus",
"text-embedding-v1": ""
}
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/v1/chat/completions";
std::string request_json = R"({"model": "gpt-3.5"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-long"})");
return WasmResult::Ok;
});
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
request_json = R"({"model": "gpt-4"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-max"})");
return WasmResult::Ok;
});
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
request_json = R"({"model": "gpt-4o"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-turbo"})");
return WasmResult::Ok;
});
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
request_json = R"({"model": "gpt-4o-mini"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-plus"})");
return WasmResult::Ok;
});
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
request_json = R"({"model": "text-embedding-v1"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.Times(0);
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
request_json = R"({})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-long"})");
return WasmResult::Ok;
});
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
}
TEST_F(ModelMapperTest, RouteLevelModelMappingTest) {
std::string configuration = R"(
{
"_rules_": [
{
"_match_route_": ["route-a"],
"_match_service_": ["service-1"],
"modelMapping": {
"*": "qwen-long"
}
},
{
"_match_route_": ["route-b"],
"_match_service_": ["service-2"],
"modelMapping": {
"*": "qwen-max"
}
},
{
"_match_route_": ["route-b"],
"_match_service_": ["service-3"],
"modelMapping": {
"*": "qwen-turbo"
}
}
]})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/api/v1/chat/completions";
std::string request_json = R"({"model": "gpt-4"})";
body_.set(request_json);
route_name_ = "route-a";
service_name_ = "outbound|80||service-1";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-long"})");
return WasmResult::Ok;
});
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
route_name_ = "route-b";
service_name_ = "outbound|80||service-2";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-max"})");
return WasmResult::Ok;
});
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
route_name_ = "route-b";
service_name_ = "outbound|80||service-3";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-turbo"})");
return WasmResult::Ok;
});
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
}
} // namespace model_mapper
} // namespace null_plugin
} // namespace proxy_wasm