add model router plugin (#1414)

This commit is contained in:
澄潭
2024-10-21 16:46:18 +08:00
committed by GitHub
parent badf4b7101
commit f8d62a8ac3
8 changed files with 621 additions and 6 deletions

View File

@@ -0,0 +1,70 @@
# Copyright (c) 2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@proxy_wasm_cpp_sdk//bazel/wasm:wasm.bzl", "wasm_cc_binary")
load("//bazel:wasm.bzl", "declare_wasm_image_targets")
wasm_cc_binary(
name = "model_router.wasm",
srcs = [
"plugin.cc",
"plugin.h",
],
deps = [
"@proxy_wasm_cpp_sdk//:proxy_wasm_intrinsics_higress",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"//common:json_util",
"//common:http_util",
"//common:rule_util",
],
)
cc_library(
name = "model_router_lib",
srcs = [
"plugin.cc",
],
hdrs = [
"plugin.h",
],
copts = ["-DNULL_PLUGIN"],
deps = [
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"//common:json_util",
"@proxy_wasm_cpp_host//:lib",
"//common:http_util_nullvm",
"//common:rule_util_nullvm",
],
)
cc_test(
name = "model_router_test",
srcs = [
"plugin_test.cc",
],
copts = ["-DNULL_PLUGIN"],
deps = [
":model_router_lib",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
"@proxy_wasm_cpp_host//:lib",
],
)
declare_wasm_image_targets(
name = "model_router",
wasm_file = ":model_router.wasm",
)

View File

@@ -0,0 +1,64 @@
## 功能说明
`model-router`插件实现了基于LLM协议中的model参数路由的功能
## 运行属性
插件执行阶段:`默认阶段`
插件执行优先级:`260`
## 配置字段
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
| `enable` | bool | 选填 | false | 是否开启基于model参数路由 |
| `model_key` | string | 选填 | model | 请求body中model参数的位置 |
| `add_header_key` | string | 选填 | x-higress-llm-provider | 从model参数中解析出的provider名字放到哪个请求header中 |
## 效果说明
如下开启基于model参数路由的功能
```yaml
enable: true
```
开启后,插件将请求中 model 参数的 provider 部分(如果有)提取出来,设置到 x-higress-llm-provider 这个请求 header 中,用于后续路由,并将 model 参数重写为模型名称部分。举例来说,原生的 LLM 请求体是:
```json
{
"model": "qwen/qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "higress项目主仓库的github地址是什么"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
经过这个插件后,将添加下面这个请求头(可以用于路由匹配)
x-higress-llm-provider: qwen
原始的 LLM 请求体将被改成:
```json
{
"model": "qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "higress项目主仓库的github地址是什么"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```

View File

@@ -0,0 +1,63 @@
## Function Description
The `model-router` plugin implements the functionality of routing based on the `model` parameter in the LLM protocol.
## Runtime Properties
Plugin Execution Phase: `Default Phase`
Plugin Execution Priority: `260`
## Configuration Fields
| Name | Data Type | Filling Requirement | Default Value | Description |
| -------------------- | ------------- | --------------------- | ---------------------- | ----------------------------------------------------- |
| `enable` | bool | Optional | false | Whether to enable routing based on the `model` parameter |
| `model_key` | string | Optional | model | The location of the `model` parameter in the request body |
| `add_header_key` | string | Optional | x-higress-llm-provider | The header where the parsed provider name from the `model` parameter will be placed |
## Effect Description
To enable routing based on the `model` parameter, use the following configuration:
```yaml
enable: true
```
After enabling, the plugin extracts the provider part (if any) from the `model` parameter in the request, and sets it in the `x-higress-llm-provider` request header for subsequent routing. It also rewrites the `model` parameter to the model name part. For example, the original LLM request body is:
```json
{
"model": "openai/gpt-4o",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address for the main repository of the Higress project?"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
After processing by the plugin, the following request header (which can be used for routing matching) will be added:
`x-higress-llm-provider: openai`
The original LLM request body will be modified to:
```json
{
"model": "gpt-4o",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address for the main repository of the Higress project?"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```

View File

@@ -0,0 +1,189 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/model_router/plugin.h"
#include <array>
#include <limits>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "common/http_util.h"
#include "common/json_util.h"
using ::nlohmann::json;
using ::Wasm::Common::JsonArrayIterate;
using ::Wasm::Common::JsonGetField;
using ::Wasm::Common::JsonObjectIterate;
using ::Wasm::Common::JsonValueAs;
#ifdef NULL_PLUGIN
namespace proxy_wasm {
namespace null_plugin {
namespace model_router {
PROXY_WASM_NULL_PLUGIN_REGISTRY
#endif
static RegisterContextFactory register_ModelRouter(
CONTEXT_FACTORY(PluginContext), ROOT_FACTORY(PluginRootContext));
namespace {
constexpr std::string_view SetDecoderBufferLimitKey =
"SetRequestBodyBufferLimit";
constexpr std::string_view DefaultMaxBodyBytes = "10485760";
} // namespace
bool PluginRootContext::parsePluginConfig(const json& configuration,
ModelRouterConfigRule& rule) {
if (auto it = configuration.find("enable"); it != configuration.end()) {
if (it->is_boolean()) {
rule.enable_ = it->get<bool>();
} else {
LOG_WARN("Invalid type for enable. Expected boolean.");
return false;
}
}
if (auto it = configuration.find("model_key"); it != configuration.end()) {
if (it->is_string()) {
rule.model_key_ = it->get<std::string>();
} else {
LOG_WARN("Invalid type for model_key. Expected string.");
return false;
}
}
if (auto it = configuration.find("add_header_key");
it != configuration.end()) {
if (it->is_string()) {
rule.add_header_key_ = it->get<std::string>();
} else {
LOG_WARN("Invalid type for add_header_key. Expected string.");
return false;
}
}
return true;
}
bool PluginRootContext::onConfigure(size_t size) {
// Parse configuration JSON string.
if (size > 0 && !configure(size)) {
LOG_WARN("configuration has errors initialization will not continue.");
return false;
}
return true;
}
bool PluginRootContext::configure(size_t configuration_size) {
auto configuration_data = getBufferBytes(WasmBufferType::PluginConfiguration,
0, configuration_size);
// Parse configuration JSON string.
auto result = ::Wasm::Common::JsonParse(configuration_data->view());
if (!result) {
LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
return false;
}
if (!parseAuthRuleConfig(result.value())) {
LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
return false;
}
return true;
}
FilterHeadersStatus PluginRootContext::onHeader(
const ModelRouterConfigRule& rule) {
if (!rule.enable_ || !Wasm::Common::Http::hasRequestBody()) {
return FilterHeadersStatus::Continue;
}
auto content_type_value =
getRequestHeader(Wasm::Common::Http::Header::ContentType);
if (!absl::StrContains(content_type_value->view(),
Wasm::Common::Http::ContentTypeValues::Json)) {
return FilterHeadersStatus::Continue;
}
removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
setFilterState(SetDecoderBufferLimitKey, DefaultMaxBodyBytes);
return FilterHeadersStatus::StopIteration;
}
FilterDataStatus PluginRootContext::onBody(const ModelRouterConfigRule& rule,
std::string_view body) {
const auto& model_key = rule.model_key_;
const auto& add_header_key = rule.add_header_key_;
auto body_json_opt = ::Wasm::Common::JsonParse(body);
if (!body_json_opt) {
LOG_WARN(absl::StrCat("cannot parse body to JSON string: ", body));
return FilterDataStatus::Continue;
}
auto body_json = body_json_opt.value();
if (body_json.contains(model_key)) {
std::string model_value = body_json[model_key];
auto pos = model_value.find('/');
if (pos != std::string::npos) {
const auto& provider = model_value.substr(0, pos);
const auto& model = model_value.substr(pos + 1);
replaceRequestHeader(add_header_key, provider);
body_json[model_key] = model;
setBuffer(WasmBufferType::HttpRequestBody, 0,
std::numeric_limits<size_t>::max(), body_json.dump());
LOG_DEBUG(absl::StrCat("model route to provider:", provider,
", model:", model));
} else {
LOG_DEBUG(absl::StrCat("model route not work, model:", model_value));
}
}
return FilterDataStatus::Continue;
}
FilterHeadersStatus PluginContext::onRequestHeaders(uint32_t, bool) {
auto* rootCtx = rootContext();
return rootCtx->onHeaders([rootCtx, this](const auto& config) {
auto ret = rootCtx->onHeader(config);
if (ret == FilterHeadersStatus::StopIteration) {
this->config_ = &config;
}
return ret;
});
}
FilterDataStatus PluginContext::onRequestBody(size_t body_size,
bool end_stream) {
if (config_ == nullptr) {
return FilterDataStatus::Continue;
}
body_total_size_ += body_size;
if (!end_stream) {
return FilterDataStatus::StopIterationAndBuffer;
}
auto body =
getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
auto* rootCtx = rootContext();
return rootCtx->onBody(*config_, body->view());
}
#ifdef NULL_PLUGIN
} // namespace model_router
} // namespace null_plugin
} // namespace proxy_wasm
#endif

View File

@@ -0,0 +1,85 @@
/*
* Copyright (c) 2022 Alibaba Group Holding Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <assert.h>
#include <string>
#include <unordered_set>
#include "common/route_rule_matcher.h"
#define ASSERT(_X) assert(_X)
#ifndef NULL_PLUGIN
#include "proxy_wasm_intrinsics.h"
#else
#include "include/proxy-wasm/null_plugin.h"
namespace proxy_wasm {
namespace null_plugin {
namespace model_router {
#endif
struct ModelRouterConfigRule {
bool enable_ = false;
std::string model_key_ = "model";
std::string add_header_key_ = "x-higress-llm-provider";
};
// PluginRootContext is the root context for all streams processed by the
// thread. It has the same lifetime as the worker thread and acts as target for
// interactions that outlives individual stream, e.g. timer, async calls.
class PluginRootContext : public RootContext,
public RouteRuleMatcher<ModelRouterConfigRule> {
public:
PluginRootContext(uint32_t id, std::string_view root_id)
: RootContext(id, root_id) {}
~PluginRootContext() {}
bool onConfigure(size_t) override;
FilterHeadersStatus onHeader(const ModelRouterConfigRule&);
FilterDataStatus onBody(const ModelRouterConfigRule&, std::string_view);
bool configure(size_t);
private:
bool parsePluginConfig(const json&, ModelRouterConfigRule&) override;
};
// Per-stream context.
class PluginContext : public Context {
public:
explicit PluginContext(uint32_t id, RootContext* root) : Context(id, root) {}
FilterHeadersStatus onRequestHeaders(uint32_t, bool) override;
FilterDataStatus onRequestBody(size_t, bool) override;
private:
inline PluginRootContext* rootContext() {
return dynamic_cast<PluginRootContext*>(this->root());
}
size_t body_total_size_ = 0;
const ModelRouterConfigRule* config_ = nullptr;
};
#ifdef NULL_PLUGIN
} // namespace model_router
} // namespace null_plugin
} // namespace proxy_wasm
#endif

View File

@@ -0,0 +1,144 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/model_router/plugin.h"
#include <cstddef>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "include/proxy-wasm/context.h"
#include "include/proxy-wasm/null.h"
namespace proxy_wasm {
namespace null_plugin {
namespace model_router {
NullPluginRegistry* context_registry_;
RegisterNullVmPluginFactory register_model_router_plugin("model_router", []() {
return std::make_unique<NullPlugin>(model_router::context_registry_);
});
class MockContext : public proxy_wasm::ContextBase {
public:
MockContext(WasmBase* wasm) : ContextBase(wasm) {}
MOCK_METHOD(BufferInterface*, getBuffer, (WasmBufferType));
MOCK_METHOD(WasmResult, log, (uint32_t, std::string_view));
MOCK_METHOD(WasmResult, setBuffer,
(WasmBufferType, size_t, size_t, std::string_view));
MOCK_METHOD(WasmResult, getHeaderMapValue,
(WasmHeaderMapType /* type */, std::string_view /* key */,
std::string_view* /*result */));
MOCK_METHOD(WasmResult, addHeaderMapValue,
(WasmHeaderMapType, std::string_view, std::string_view));
MOCK_METHOD(WasmResult, getProperty, (std::string_view, std::string*));
MOCK_METHOD(WasmResult, setProperty, (std::string_view, std::string_view));
};
class ModelRouterTest : public ::testing::Test {
protected:
ModelRouterTest() {
// Initialize test VM
test_vm_ = createNullVm();
wasm_base_ = std::make_unique<WasmBase>(
std::move(test_vm_), "test-vm", "", "",
std::unordered_map<std::string, std::string>{},
AllowedCapabilitiesMap{});
wasm_base_->load("model_router");
wasm_base_->initialize();
// Initialize host side context
mock_context_ = std::make_unique<MockContext>(wasm_base_.get());
current_context_ = mock_context_.get();
// Initialize Wasm sandbox context
root_context_ = std::make_unique<PluginRootContext>(0, "");
context_ = std::make_unique<PluginContext>(1, root_context_.get());
ON_CALL(*mock_context_, log(testing::_, testing::_))
.WillByDefault([](uint32_t, std::string_view m) {
std::cerr << m << "\n";
return WasmResult::Ok;
});
ON_CALL(*mock_context_, getBuffer(testing::_))
.WillByDefault([&](WasmBufferType type) {
if (type == WasmBufferType::HttpRequestBody) {
return &body_;
}
return &config_;
});
ON_CALL(*mock_context_, getHeaderMapValue(WasmHeaderMapType::RequestHeaders,
testing::_, testing::_))
.WillByDefault([&](WasmHeaderMapType, std::string_view header,
std::string_view* result) {
if (header == "content-type") {
*result = "application/json";
} else if (header == "content-length") {
*result = "1024";
}
return WasmResult::Ok;
});
ON_CALL(*mock_context_, addHeaderMapValue(WasmHeaderMapType::RequestHeaders,
testing::_, testing::_))
.WillByDefault([&](WasmHeaderMapType, std::string_view header,
std::string_view value) { return WasmResult::Ok; });
ON_CALL(*mock_context_, getProperty(testing::_, testing::_))
.WillByDefault([&](std::string_view path, std::string* result) {
*result = route_name_;
return WasmResult::Ok;
});
ON_CALL(*mock_context_, setProperty(testing::_, testing::_))
.WillByDefault(
[&](std::string_view, std::string_view) { return WasmResult::Ok; });
}
~ModelRouterTest() override {}
std::unique_ptr<WasmBase> wasm_base_;
std::unique_ptr<WasmVm> test_vm_;
std::unique_ptr<MockContext> mock_context_;
std::unique_ptr<PluginRootContext> root_context_;
std::unique_ptr<PluginContext> context_;
std::string route_name_;
BufferBase body_;
BufferBase config_;
};
TEST_F(ModelRouterTest, RewriteModelAndHeader) {
std::string configuration = R"(
{
"enable": true
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
std::string request_json = R"({"model": "qwen/qwen-long"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-long"})");
return WasmResult::Ok;
});
EXPECT_CALL(
*mock_context_,
addHeaderMapValue(testing::_, std::string_view("x-higress-llm-provider"),
std::string_view("qwen")));
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
}
} // namespace model_router
} // namespace null_plugin
} // namespace proxy_wasm