add model-mapper plugin & optimize model-router plugin (#1538)

This commit is contained in:
澄潭
2024-11-22 22:24:42 +08:00
committed by GitHub
parent 96575b982e
commit e68a8ac25f
11 changed files with 1067 additions and 73 deletions

View File

@@ -51,41 +51,54 @@ constexpr std::string_view DefaultMaxBodyBytes = "10485760";
bool PluginRootContext::parsePluginConfig(const json& configuration,
ModelRouterConfigRule& rule) {
if (auto it = configuration.find("enable"); it != configuration.end()) {
if (it->is_boolean()) {
rule.enable_ = it->get<bool>();
} else {
LOG_WARN("Invalid type for enable. Expected boolean.");
return false;
}
}
if (auto it = configuration.find("model_key"); it != configuration.end()) {
if (auto it = configuration.find("modelKey"); it != configuration.end()) {
if (it->is_string()) {
rule.model_key_ = it->get<std::string>();
} else {
LOG_WARN("Invalid type for model_key. Expected string.");
LOG_ERROR("Invalid type for modelKey. Expected string.");
return false;
}
}
if (auto it = configuration.find("add_header_key");
if (auto it = configuration.find("addProviderHeader");
it != configuration.end()) {
if (it->is_string()) {
rule.add_header_key_ = it->get<std::string>();
rule.add_provider_header_ = it->get<std::string>();
} else {
LOG_WARN("Invalid type for add_header_key. Expected string.");
LOG_ERROR("Invalid type for addProviderHeader. Expected string.");
return false;
}
}
if (auto it = configuration.find("modelToHeader");
it != configuration.end()) {
if (it->is_string()) {
rule.model_to_header_ = it->get<std::string>();
} else {
LOG_ERROR("Invalid type for modelToHeader. Expected string.");
return false;
}
}
if (!JsonArrayIterate(
configuration, "enableOnPathSuffix", [&](const json& item) -> bool {
if (item.is_string()) {
rule.enable_on_path_suffix_.emplace_back(item.get<std::string>());
return true;
}
return false;
})) {
LOG_ERROR("Invalid type for item in enableOnPathSuffix. Expected string.");
return false;
}
return true;
}
bool PluginRootContext::onConfigure(size_t size) {
// Parse configuration JSON string.
if (size > 0 && !configure(size)) {
LOG_WARN("configuration has errors initialization will not continue.");
LOG_ERROR("configuration has errors initialization will not continue.");
return false;
}
return true;
@@ -97,13 +110,13 @@ bool PluginRootContext::configure(size_t configuration_size) {
// Parse configuration JSON string.
auto result = ::Wasm::Common::JsonParse(configuration_data->view());
if (!result) {
LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
LOG_ERROR(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
return false;
}
if (!parseRuleConfig(result.value())) {
LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
LOG_ERROR(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
return false;
}
return true;
@@ -111,7 +124,25 @@ bool PluginRootContext::configure(size_t configuration_size) {
FilterHeadersStatus PluginRootContext::onHeader(
const ModelRouterConfigRule& rule) {
if (!rule.enable_ || !Wasm::Common::Http::hasRequestBody()) {
if (!Wasm::Common::Http::hasRequestBody()) {
return FilterHeadersStatus::Continue;
}
auto path = getRequestHeader(Wasm::Common::Http::Header::Path)->toString();
auto params_pos = path.find('?');
size_t uri_end;
if (params_pos == std::string::npos) {
uri_end = path.size();
} else {
uri_end = params_pos;
}
bool enable = false;
for (const auto& enable_suffix : rule.enable_on_path_suffix_) {
if (absl::EndsWith({path.c_str(), uri_end}, enable_suffix)) {
enable = true;
break;
}
}
if (!enable) {
return FilterHeadersStatus::Continue;
}
auto content_type_value =
@@ -128,7 +159,8 @@ FilterHeadersStatus PluginRootContext::onHeader(
FilterDataStatus PluginRootContext::onBody(const ModelRouterConfigRule& rule,
std::string_view body) {
const auto& model_key = rule.model_key_;
const auto& add_header_key = rule.add_header_key_;
const auto& add_provider_header = rule.add_provider_header_;
const auto& model_to_header = rule.model_to_header_;
auto body_json_opt = ::Wasm::Common::JsonParse(body);
if (!body_json_opt) {
LOG_WARN(absl::StrCat("cannot parse body to JSON string: ", body));
@@ -137,18 +169,24 @@ FilterDataStatus PluginRootContext::onBody(const ModelRouterConfigRule& rule,
auto body_json = body_json_opt.value();
if (body_json.contains(model_key)) {
std::string model_value = body_json[model_key];
auto pos = model_value.find('/');
if (pos != std::string::npos) {
const auto& provider = model_value.substr(0, pos);
const auto& model = model_value.substr(pos + 1);
replaceRequestHeader(add_header_key, provider);
body_json[model_key] = model;
setBuffer(WasmBufferType::HttpRequestBody, 0,
std::numeric_limits<size_t>::max(), body_json.dump());
LOG_DEBUG(absl::StrCat("model route to provider:", provider,
", model:", model));
} else {
LOG_DEBUG(absl::StrCat("model route not work, model:", model_value));
if (!model_to_header.empty()) {
replaceRequestHeader(model_to_header, model_value);
}
if (!add_provider_header.empty()) {
auto pos = model_value.find('/');
if (pos != std::string::npos) {
const auto& provider = model_value.substr(0, pos);
const auto& model = model_value.substr(pos + 1);
replaceRequestHeader(add_provider_header, provider);
body_json[model_key] = model;
setBuffer(WasmBufferType::HttpRequestBody, 0,
std::numeric_limits<size_t>::max(), body_json.dump());
LOG_DEBUG(absl::StrCat("model route to provider:", provider,
", model:", model));
} else {
LOG_DEBUG(absl::StrCat("model route to provider not work, model:",
model_value));
}
}
}
return FilterDataStatus::Continue;