add model-mapper plugin & optimize model-router plugin (#1538)

澄潭
2024-11-22 22:24:42 +08:00
committed by GitHub
parent 96575b982e
commit e68a8ac25f
11 changed files with 1067 additions and 73 deletions

View File

@@ -1,33 +1,35 @@
## Function Description
The `model-router` plugin routes requests based on the `model` parameter in the LLM protocol.
## Configuration Fields
| Name | Data Type | Requirement | Default | Description |
| -------------------- | --------------- | ----------- | ------------------------ | ------------------------------------------------------------ |
| `modelKey` | string | Optional | model | Location of the `model` parameter in the request body |
| `addProviderHeader` | string | Optional | - | Request header to receive the provider name parsed from the `model` parameter |
| `modelToHeader` | string | Optional | - | Request header to receive the `model` parameter as-is |
| `enableOnPathSuffix` | array of string | Optional | ["/v1/chat/completions"] | Only takes effect for requests whose path ends with one of these suffixes |
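For example, to forward the model name while also covering an additional endpoint, a configuration like the following could be used (a minimal sketch: the extra `/v1/completions` suffix is only an illustration, not a default):
```yaml
modelToHeader: x-higress-llm-model
enableOnPathSuffix:
  - /v1/chat/completions
  - /v1/completions
```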
## Runtime Properties
Plugin execution phase: `Authentication Phase`
Plugin execution priority: `900`
## Effect Description
### Routing Based on the model Parameter
The following configuration is required:
```yaml
modelToHeader: x-higress-llm-model
```
The plugin extracts the `model` parameter from the request and sets it in the `x-higress-llm-model` request header for subsequent routing. For example, given the original LLM request body:
```json
{
"model": "qwen/qwen-long",
"model": "qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
@@ -43,7 +45,39 @@ enable: true
After the plugin processes the request, the following request header is added (it can be used for route matching):
x-higress-llm-model: qwen-long
### Extracting the provider Field from the model Parameter for Routing
> Note: this mode requires the client to specify the provider in the `model` parameter using a `/` separator.
The following configuration is required:
```yaml
addProviderHeader: x-higress-llm-provider
```
The plugin extracts the provider part (if any) of the `model` parameter, sets it in the `x-higress-llm-provider` request header for subsequent routing, and rewrites the `model` parameter to the model name part. For example, given the original LLM request body:
```json
{
"model": "dashscope/qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "higress项目主仓库的github地址是什么"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
After the plugin processes the request, the following request header is added (it can be used for route matching):
x-higress-llm-provider: dashscope
The original LLM request body will be rewritten to:

View File

@@ -1,38 +1,41 @@
## Function Description
The `model-router` plugin routes requests based on the `model` parameter in the LLM protocol.
## Configuration Fields
| Name | Data Type | Requirement | Default | Description |
| -------------------- | --------------- | ----------- | ------------------------ | ------------------------------------------------------------ |
| `modelKey` | string | Optional | model | Location of the `model` parameter in the request body |
| `addProviderHeader` | string | Optional | - | Request header to receive the provider name parsed from the `model` parameter |
| `modelToHeader` | string | Optional | - | Request header to receive the `model` parameter as-is |
| `enableOnPathSuffix` | array of string | Optional | ["/v1/chat/completions"] | Only takes effect for requests whose path ends with one of these suffixes |
## Runtime Properties
Plugin execution phase: `Authentication Phase`
Plugin execution priority: `900`
## Effect Description
### Routing Based on the model Parameter
The following configuration is required:
```yaml
modelToHeader: x-higress-llm-model
```
The plugin extracts the `model` parameter from the request and sets it in the `x-higress-llm-model` request header for subsequent routing. For example, given the original LLM request body:
```json
{
"model": "openai/gpt-4o",
"model": "qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address for the main repository of the Higress project?"
"content": "What is the GitHub address of the main repository for the higress project"
}],
"presence_penalty": 0,
"temperature": 0.7,
@@ -40,24 +43,55 @@ After enabling, the plugin extracts the provider part (if any) from the `model`
}
```
After processing by this plugin, the following request header (which can be used for route matching) will be added:
x-higress-llm-model: qwen-long
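The header added above can then be used to select a route. The following is a hypothetical sketch, assuming Higress's `higress.io/exact-match-header-*` Ingress annotation for header matching; `qwen-long-route` and `qwen-long-service` are placeholder names:
```yaml
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: qwen-long-route
  annotations:
    # Route only requests whose x-higress-llm-model header equals "qwen-long".
    higress.io/exact-match-header-x-higress-llm-model: qwen-long
spec:
  ingressClassName: higress
  rules:
    - http:
        paths:
          - path: /
            pathType: Prefix
            backend:
              service:
                name: qwen-long-service  # placeholder backend
                port:
                  number: 80
```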
### Extracting the provider Field from the model Parameter for Routing
> Note that this mode requires the client to specify the provider using a `/` separator in the model parameter.
The following configuration is required:
```yaml
addProviderHeader: x-higress-llm-provider
```
The plugin extracts the provider part (if present) of the `model` parameter, sets it in the `x-higress-llm-provider` request header for subsequent routing, and rewrites the `model` parameter to the model name part. For example, given the original LLM request body:
```json
{
"model": "gpt-4o",
"model": "dashscope/qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address for the main repository of the Higress project?"
"content": "What is the GitHub address of the main repository for the higress project"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
After processing by this plugin, the following request header (which can be used for route matching) will be added:
x-higress-llm-provider: dashscope
The original LLM request body will be changed to:
```json
{
"model": "qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address of the main repository for the higress project"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
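The `x-higress-llm-provider` header can then drive per-provider routing. As above, this is a hypothetical sketch assuming Higress's `higress.io/exact-match-header-*` annotation; `dashscope-route` and `dashscope-service` are placeholder names:
```yaml
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: dashscope-route
  annotations:
    # Route only requests whose x-higress-llm-provider header equals "dashscope".
    higress.io/exact-match-header-x-higress-llm-provider: dashscope
spec:
  ingressClassName: higress
  rules:
    - http:
        paths:
          - path: /
            pathType: Prefix
            backend:
              service:
                name: dashscope-service  # placeholder backend
                port:
                  number: 80
```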

View File

@@ -51,41 +51,54 @@ constexpr std::string_view DefaultMaxBodyBytes = "10485760";
bool PluginRootContext::parsePluginConfig(const json& configuration,
ModelRouterConfigRule& rule) {
if (auto it = configuration.find("enable"); it != configuration.end()) {
if (it->is_boolean()) {
rule.enable_ = it->get<bool>();
} else {
LOG_WARN("Invalid type for enable. Expected boolean.");
return false;
}
}
if (auto it = configuration.find("model_key"); it != configuration.end()) {
if (auto it = configuration.find("modelKey"); it != configuration.end()) {
if (it->is_string()) {
rule.model_key_ = it->get<std::string>();
} else {
LOG_WARN("Invalid type for model_key. Expected string.");
LOG_ERROR("Invalid type for modelKey. Expected string.");
return false;
}
}
if (auto it = configuration.find("add_header_key");
if (auto it = configuration.find("addProviderHeader");
it != configuration.end()) {
if (it->is_string()) {
rule.add_provider_header_ = it->get<std::string>();
} else {
LOG_WARN("Invalid type for add_header_key. Expected string.");
LOG_ERROR("Invalid type for addProviderHeader. Expected string.");
return false;
}
}
if (auto it = configuration.find("modelToHeader");
it != configuration.end()) {
if (it->is_string()) {
rule.model_to_header_ = it->get<std::string>();
} else {
LOG_ERROR("Invalid type for modelToHeader. Expected string.");
return false;
}
}
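// enableOnPathSuffix must be an array of strings; reject any other item type.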
if (!JsonArrayIterate(
configuration, "enableOnPathSuffix", [&](const json& item) -> bool {
if (item.is_string()) {
rule.enable_on_path_suffix_.emplace_back(item.get<std::string>());
return true;
}
return false;
})) {
LOG_ERROR("Invalid type for item in enableOnPathSuffix. Expected string.");
return false;
}
return true;
}
bool PluginRootContext::onConfigure(size_t size) {
// Parse configuration JSON string.
if (size > 0 && !configure(size)) {
LOG_WARN("configuration has errors initialization will not continue.");
LOG_ERROR("configuration has errors initialization will not continue.");
return false;
}
return true;
@@ -97,13 +110,13 @@ bool PluginRootContext::configure(size_t configuration_size) {
// Parse configuration JSON string.
auto result = ::Wasm::Common::JsonParse(configuration_data->view());
if (!result) {
LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
LOG_ERROR(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
return false;
}
if (!parseRuleConfig(result.value())) {
LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
LOG_ERROR(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
return false;
}
return true;
@@ -111,7 +124,25 @@ bool PluginRootContext::configure(size_t configuration_size) {
FilterHeadersStatus PluginRootContext::onHeader(
const ModelRouterConfigRule& rule) {
if (!Wasm::Common::Http::hasRequestBody()) {
return FilterHeadersStatus::Continue;
}
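// Extract the request path and strip any query string before suffix matching.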
auto path = getRequestHeader(Wasm::Common::Http::Header::Path)->toString();
auto params_pos = path.find('?');
size_t uri_end;
if (params_pos == std::string::npos) {
uri_end = path.size();
} else {
uri_end = params_pos;
}
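// The plugin only takes effect when the path ends with a configured suffix.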
bool enable = false;
for (const auto& enable_suffix : rule.enable_on_path_suffix_) {
if (absl::EndsWith({path.c_str(), uri_end}, enable_suffix)) {
enable = true;
break;
}
}
if (!enable) {
return FilterHeadersStatus::Continue;
}
auto content_type_value =
@@ -128,7 +159,8 @@ FilterHeadersStatus PluginRootContext::onHeader(
FilterDataStatus PluginRootContext::onBody(const ModelRouterConfigRule& rule,
std::string_view body) {
const auto& model_key = rule.model_key_;
const auto& add_provider_header = rule.add_provider_header_;
const auto& model_to_header = rule.model_to_header_;
auto body_json_opt = ::Wasm::Common::JsonParse(body);
if (!body_json_opt) {
LOG_WARN(absl::StrCat("cannot parse body to JSON string: ", body));
@@ -137,18 +169,24 @@ FilterDataStatus PluginRootContext::onBody(const ModelRouterConfigRule& rule,
auto body_json = body_json_opt.value();
if (body_json.contains(model_key)) {
std::string model_value = body_json[model_key];
if (!model_to_header.empty()) {
replaceRequestHeader(model_to_header, model_value);
}
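// When addProviderHeader is set, split "provider/model": put the provider
// into the configured header and rewrite the body's model field to the
// bare model name.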
if (!add_provider_header.empty()) {
auto pos = model_value.find('/');
if (pos != std::string::npos) {
const auto& provider = model_value.substr(0, pos);
const auto& model = model_value.substr(pos + 1);
replaceRequestHeader(add_provider_header, provider);
body_json[model_key] = model;
setBuffer(WasmBufferType::HttpRequestBody, 0,
std::numeric_limits<size_t>::max(), body_json.dump());
LOG_DEBUG(absl::StrCat("model route to provider:", provider,
", model:", model));
} else {
LOG_DEBUG(absl::StrCat("model route to provider not work, model:",
model_value));
}
}
}
return FilterDataStatus::Continue;

View File

@@ -37,9 +37,10 @@ namespace model_router {
#endif
struct ModelRouterConfigRule {
std::string model_key_ = "model";
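// Headers that receive the parsed provider name and the raw model value.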
std::string add_provider_header_;
std::string model_to_header_;
std::vector<std::string> enable_on_path_suffix_ = {"/v1/chat/completions"};
};
// PluginRootContext is the root context for all streams processed by the

View File

@@ -89,6 +89,8 @@ class ModelRouterTest : public ::testing::Test {
*result = "application/json";
} else if (header == "content-length") {
*result = "1024";
} else if (header == ":path") {
*result = path_;
}
return WasmResult::Ok;
});
@@ -122,6 +124,7 @@ class ModelRouterTest : public ::testing::Test {
std::unique_ptr<PluginRootContext> root_context_;
std::unique_ptr<PluginContext> context_;
std::string route_name_;
std::string path_;
BufferBase body_;
BufferBase config_;
};
@@ -129,12 +132,13 @@ class ModelRouterTest : public ::testing::Test {
TEST_F(ModelRouterTest, RewriteModelAndHeader) {
std::string configuration = R"(
{
"enable": true
"addProviderHeader": "x-higress-llm-provider"
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/v1/chat/completions";
std::string request_json = R"({"model": "qwen/qwen-long"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
@@ -154,19 +158,73 @@ TEST_F(ModelRouterTest, RewriteModelAndHeader) {
EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, ModelToHeader) {
std::string configuration = R"(
{
"modelToHeader": "x-higress-llm-model"
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/v1/chat/completions";
std::string request_json = R"({"model": "qwen-long"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.Times(0);
EXPECT_CALL(
*mock_context_,
replaceHeaderMapValue(testing::_, std::string_view("x-higress-llm-model"),
std::string_view("qwen-long")));
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, IgnorePath) {
std::string configuration = R"(
{
"addProviderHeader": "x-higress-llm-provider"
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/v1/chat/xxxx";
std::string request_json = R"({"model": "qwen/qwen-long"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.Times(0);
EXPECT_CALL(*mock_context_,
replaceHeaderMapValue(testing::_,
std::string_view("x-higress-llm-provider"),
std::string_view("qwen")))
.Times(0);
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::Continue);
EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) {
std::string configuration = R"(
{
"_rules_": [
{
"_match_route_": ["route-a"],
"enable": true
"addProviderHeader": "x-higress-llm-provider"
}
]})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/api/v1/chat/completions";
std::string request_json = R"({"model": "qwen/qwen-long"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))