mirror of
https://github.com/alibaba/higress.git
synced 2026-02-25 21:21:01 +08:00
Compare commits
43 Commits
v1.4.0-rc.
...
v1.4.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
63539ca15c | ||
|
|
1eea75f130 | ||
|
|
d333656cc3 | ||
|
|
51dca7055a | ||
|
|
ab1bc0a73a | ||
|
|
ffee7dc5ea | ||
|
|
1ea87f0e7a | ||
|
|
7164653446 | ||
|
|
2a1a391054 | ||
|
|
0785d4aac4 | ||
|
|
4ca4bec2b5 | ||
|
|
174350d3fb | ||
|
|
0380cb03d3 | ||
|
|
15d9f76ff9 | ||
|
|
5f15017963 | ||
|
|
634de3f7f8 | ||
|
|
12cc44b324 | ||
|
|
d53c713561 | ||
|
|
5acc6f73b2 | ||
|
|
2db0b60a98 | ||
|
|
c6e3db95e0 | ||
|
|
ed976c6d06 | ||
|
|
6a40d83ec0 | ||
|
|
2807ddfbb7 | ||
|
|
6e4ade05a8 | ||
|
|
bdd050b926 | ||
|
|
38ddc49360 | ||
|
|
26ec0d3d55 | ||
|
|
909f8bc719 | ||
|
|
863d0e5872 | ||
|
|
3e7a63bd9b | ||
|
|
206152daa0 | ||
|
|
812edf1490 | ||
|
|
b00f79f3af | ||
|
|
ed05da13f4 | ||
|
|
53bccf89f4 | ||
|
|
51b9d9ec4b | ||
|
|
50f79c9099 | ||
|
|
93966bf14b | ||
|
|
ffa690994b | ||
|
|
ca1ad1dc73 | ||
|
|
e09edff827 | ||
|
|
2fee28d4e8 |
@@ -138,11 +138,11 @@ export ENVOY_TAR_PATH:=/home/package/envoy.tar.gz
|
||||
|
||||
external/package/envoy-amd64.tar.gz:
|
||||
# cd external/proxy; BUILD_WITH_CONTAINER=1 make test_release
|
||||
cd external/package; wget -O envoy-amd64.tar.gz "https://github.com/alibaba/higress/releases/download/v1.4.0-rc.1/envoy-symbol-amd64.tar.gz"
|
||||
cd external/package; wget -O envoy-amd64.tar.gz "https://github.com/alibaba/higress/releases/download/v1.4.0/envoy-symbol-amd64.tar.gz"
|
||||
|
||||
external/package/envoy-arm64.tar.gz:
|
||||
# cd external/proxy; BUILD_WITH_CONTAINER=1 make test_release
|
||||
cd external/package; wget -O envoy-arm64.tar.gz "https://github.com/alibaba/higress/releases/download/v1.4.0-rc.1/envoy-symbol-arm64.tar.gz"
|
||||
cd external/package; wget -O envoy-arm64.tar.gz "https://github.com/alibaba/higress/releases/download/v1.4.0/envoy-symbol-arm64.tar.gz"
|
||||
|
||||
build-pilot:
|
||||
cd external/istio; rm -rf out/linux_amd64; GOOS_LOCAL=linux TARGET_OS=linux TARGET_ARCH=amd64 BUILD_WITH_CONTAINER=1 make build-linux
|
||||
@@ -177,8 +177,8 @@ install: pre-install
|
||||
cd helm/higress; helm dependency build
|
||||
helm install higress helm/higress -n higress-system --create-namespace --set 'global.local=true'
|
||||
|
||||
ENVOY_LATEST_IMAGE_TAG ?= sha-d91b22f
|
||||
ISTIO_LATEST_IMAGE_TAG ?= sha-d91b22f
|
||||
ENVOY_LATEST_IMAGE_TAG ?= sha-93966bf
|
||||
ISTIO_LATEST_IMAGE_TAG ?= sha-b00f79f
|
||||
|
||||
install-dev: pre-install
|
||||
helm install higress helm/core -n higress-system --create-namespace --set 'controller.tag=$(TAG)' --set 'gateway.replicas=1' --set 'pilot.tag=$(ISTIO_LATEST_IMAGE_TAG)' --set 'gateway.tag=$(ENVOY_LATEST_IMAGE_TAG)' --set 'global.local=true'
|
||||
@@ -305,7 +305,7 @@ run-higress-e2e-test:
|
||||
kubectl wait --timeout=10m -n higress-system deployment/higress-controller --for=condition=Available
|
||||
@echo -e "\n\033[36mWaiting higress-gateway to be ready...\033[0m\n"
|
||||
kubectl wait --timeout=10m -n higress-system deployment/higress-gateway --for=condition=Available
|
||||
go test -v -tags conformance ./test/e2e/e2e_test.go --ingress-class=higress --debug=true --test-area=all
|
||||
go test -v -tags conformance ./test/e2e/e2e_test.go --ingress-class=higress --debug=true --test-area=all --execute-tests=$(TEST_SHORTNAME)
|
||||
|
||||
# run-higress-e2e-test-run starts to run ingress e2e conformance tests.
|
||||
.PHONY: run-higress-e2e-test-run
|
||||
@@ -315,7 +315,7 @@ run-higress-e2e-test-run:
|
||||
kubectl wait --timeout=10m -n higress-system deployment/higress-controller --for=condition=Available
|
||||
@echo -e "\n\033[36mWaiting higress-gateway to be ready...\033[0m\n"
|
||||
kubectl wait --timeout=10m -n higress-system deployment/higress-gateway --for=condition=Available
|
||||
go test -v -tags conformance ./test/e2e/e2e_test.go --ingress-class=higress --debug=true --test-area=run
|
||||
go test -v -tags conformance ./test/e2e/e2e_test.go --ingress-class=higress --debug=true --test-area=run --execute-tests=$(TEST_SHORTNAME)
|
||||
|
||||
# run-higress-e2e-test-clean starts to clean ingress e2e tests.
|
||||
.PHONY: run-higress-e2e-test-clean
|
||||
@@ -345,7 +345,7 @@ run-higress-e2e-test-wasmplugin:
|
||||
kubectl wait --timeout=10m -n higress-system deployment/higress-controller --for=condition=Available
|
||||
@echo -e "\n\033[36mWaiting higress-gateway to be ready...\033[0m\n"
|
||||
kubectl wait --timeout=10m -n higress-system deployment/higress-gateway --for=condition=Available
|
||||
go test -v -tags conformance ./test/e2e/e2e_test.go -isWasmPluginTest=true -wasmPluginType=$(PLUGIN_TYPE) -wasmPluginName=$(PLUGIN_NAME) --ingress-class=higress --debug=true --test-area=all
|
||||
go test -v -tags conformance ./test/e2e/e2e_test.go -isWasmPluginTest=true -wasmPluginType=$(PLUGIN_TYPE) -wasmPluginName=$(PLUGIN_NAME) --ingress-class=higress --debug=true --test-area=all --execute-tests=$(TEST_SHORTNAME)
|
||||
|
||||
# run-higress-e2e-test-wasmplugin-run starts to run ingress e2e conformance tests.
|
||||
.PHONY: run-higress-e2e-test-wasmplugin-run
|
||||
@@ -355,7 +355,7 @@ run-higress-e2e-test-wasmplugin-run:
|
||||
kubectl wait --timeout=10m -n higress-system deployment/higress-controller --for=condition=Available
|
||||
@echo -e "\n\033[36mWaiting higress-gateway to be ready...\033[0m\n"
|
||||
kubectl wait --timeout=10m -n higress-system deployment/higress-gateway --for=condition=Available
|
||||
go test -v -tags conformance ./test/e2e/e2e_test.go -isWasmPluginTest=true -wasmPluginType=$(PLUGIN_TYPE) -wasmPluginName=$(PLUGIN_NAME) --ingress-class=higress --debug=true --test-area=run
|
||||
go test -v -tags conformance ./test/e2e/e2e_test.go -isWasmPluginTest=true -wasmPluginType=$(PLUGIN_TYPE) -wasmPluginName=$(PLUGIN_NAME) --ingress-class=higress --debug=true --test-area=run --execute-tests=$(TEST_SHORTNAME)
|
||||
|
||||
# run-higress-e2e-test-wasmplugin-clean starts to clean ingress e2e tests.
|
||||
.PHONY: run-higress-e2e-test-wasmplugin-clean
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
"github.com/alibaba/higress/pkg/cmd/hgctl/kubernetes"
|
||||
"github.com/alibaba/higress/pkg/cmd/options"
|
||||
"istio.io/istio/istioctl/pkg/writer/envoy/configdump"
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
"sigs.k8s.io/yaml"
|
||||
)
|
||||
@@ -61,6 +62,23 @@ func NewDefaultGetEnvoyConfigOptions() *GetEnvoyConfigOptions {
|
||||
}
|
||||
}
|
||||
|
||||
func setupConfigdumpEnvoyConfigWriter(debug []byte, stdout io.Writer) (*configdump.ConfigWriter, error) {
|
||||
cw := &configdump.ConfigWriter{Stdout: stdout}
|
||||
err := cw.Prime(debug)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cw, nil
|
||||
}
|
||||
|
||||
func GetEnvoyConfigWriter(config *GetEnvoyConfigOptions, stdout io.Writer) (*configdump.ConfigWriter, error) {
|
||||
configDump, err := retrieveConfigDump(config.PodName, config.PodNamespace, config.BindAddress, config.IncludeEds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return setupConfigdumpEnvoyConfigWriter(configDump, stdout)
|
||||
}
|
||||
|
||||
func GetEnvoyConfig(config *GetEnvoyConfigOptions) ([]byte, error) {
|
||||
configDump, err := retrieveConfigDump(config.PodName, config.PodNamespace, config.BindAddress, config.IncludeEds)
|
||||
if err != nil {
|
||||
@@ -144,14 +162,12 @@ func formatGatewayConfig(configDump any, output string) ([]byte, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if output == "yaml" {
|
||||
out, err = yaml.JSONToYAML(out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
|
||||
25
envoy/1.20/patches/envoy/20240527-fix-wasm-recover.patch
Normal file
25
envoy/1.20/patches/envoy/20240527-fix-wasm-recover.patch
Normal file
@@ -0,0 +1,25 @@
|
||||
diff -Naur envoy/bazel/repository_locations.bzl envoy-new/bazel/repository_locations.bzl
|
||||
--- envoy/bazel/repository_locations.bzl 2024-05-27 18:04:13.116443196 +0800
|
||||
+++ envoy-new/bazel/repository_locations.bzl 2024-05-27 18:02:24.812441069 +0800
|
||||
@@ -1031,8 +1031,8 @@
|
||||
project_name = "WebAssembly for Proxies (C++ host implementation)",
|
||||
project_desc = "WebAssembly for Proxies (C++ host implementation)",
|
||||
project_url = "https://github.com/higress-group/proxy-wasm-cpp-host",
|
||||
- version = "cad2eb04d402dbf559101f3cb4f44da0d9c5b0b0",
|
||||
- sha256 = "4efbcc97c58994fab92c9dc50c051ad16463647d4c0c6df36a7204d2984c1e63",
|
||||
+ version = "28a33a5a3e6c1ff8f53128a74e89aeca47850f68",
|
||||
+ sha256 = "1aaa5898c169aeff115eff2fedf58095b3509d2e59861ad498e661a990d78b3d",
|
||||
strip_prefix = "proxy-wasm-cpp-host-{version}",
|
||||
urls = ["https://github.com/higress-group/proxy-wasm-cpp-host/archive/{version}.tar.gz"],
|
||||
use_category = ["dataplane_ext"],
|
||||
diff -Naur envoy/source/extensions/filters/http/wasm/wasm_filter.h envoy-new/source/extensions/filters/http/wasm/wasm_filter.h
|
||||
--- envoy/source/extensions/filters/http/wasm/wasm_filter.h 2024-05-27 18:04:13.112443196 +0800
|
||||
+++ envoy-new/source/extensions/filters/http/wasm/wasm_filter.h 2024-05-27 18:03:25.360442258 +0800
|
||||
@@ -51,6 +51,7 @@
|
||||
if (opt_ref->recover()) {
|
||||
ENVOY_LOG(info, "wasm vm recover success");
|
||||
wasm = opt_ref->handle()->wasmHandle()->wasm().get();
|
||||
+ handle = opt_ref->handle();
|
||||
} else {
|
||||
ENVOY_LOG(info, "wasm vm recover failed");
|
||||
failed = true;
|
||||
259
envoy/1.20/patches/envoy/20240610-optimize-xds.patch
Normal file
259
envoy/1.20/patches/envoy/20240610-optimize-xds.patch
Normal file
@@ -0,0 +1,259 @@
|
||||
diff --git a/source/common/router/BUILD b/source/common/router/BUILD
|
||||
index 5c58501..4db76cd 100644
|
||||
--- a/source/common/router/BUILD
|
||||
+++ b/source/common/router/BUILD
|
||||
@@ -212,6 +212,7 @@ envoy_cc_library(
|
||||
"//envoy/router:rds_interface",
|
||||
"//envoy/router:scopes_interface",
|
||||
"//envoy/thread_local:thread_local_interface",
|
||||
+ "//source/common/protobuf:utility_lib",
|
||||
"@envoy_api//envoy/config/route/v3:pkg_cc_proto",
|
||||
"@envoy_api//envoy/extensions/filters/network/http_connection_manager/v3:pkg_cc_proto",
|
||||
],
|
||||
diff --git a/source/common/router/config_impl.cc b/source/common/router/config_impl.cc
|
||||
index ff7b4c8..5ac4523 100644
|
||||
--- a/source/common/router/config_impl.cc
|
||||
+++ b/source/common/router/config_impl.cc
|
||||
@@ -550,19 +550,11 @@ RouteEntryImplBase::RouteEntryImplBase(const VirtualHostImpl& vhost,
|
||||
"not be stripped: {}",
|
||||
path_redirect_);
|
||||
}
|
||||
- ENVOY_LOG(info, "route stats is {}, name is {}", route.stat_prefix(), route.name());
|
||||
if (!route.stat_prefix().empty()) {
|
||||
route_stats_context_ = std::make_unique<RouteStatsContext>(
|
||||
factory_context.scope(), factory_context.routerContext().routeStatNames(), vhost.statName(),
|
||||
route.stat_prefix());
|
||||
- } else if (!route.name().empty()) {
|
||||
- // Added by Ingress
|
||||
- // use route_name as default stat_prefix
|
||||
- route_stats_context_ = std::make_unique<RouteStatsContext>(
|
||||
- factory_context.scope(), factory_context.routerContext().routeStatNames(), vhost.statName(),
|
||||
- route.name());
|
||||
}
|
||||
- // End Added
|
||||
}
|
||||
|
||||
bool RouteEntryImplBase::evaluateRuntimeMatch(const uint64_t random_value) const {
|
||||
@@ -1415,9 +1407,7 @@ VirtualHostImpl::VirtualHostImpl(
|
||||
retry_shadow_buffer_limit_(PROTOBUF_GET_WRAPPED_OR_DEFAULT(
|
||||
virtual_host, per_request_buffer_limit_bytes, std::numeric_limits<uint32_t>::max())),
|
||||
include_attempt_count_in_request_(virtual_host.include_request_attempt_count()),
|
||||
- include_attempt_count_in_response_(virtual_host.include_attempt_count_in_response()),
|
||||
- virtual_cluster_catch_all_(*vcluster_scope_,
|
||||
- factory_context.routerContext().virtualClusterStatNames()) {
|
||||
+ include_attempt_count_in_response_(virtual_host.include_attempt_count_in_response()) {
|
||||
switch (virtual_host.require_tls()) {
|
||||
case envoy::config::route::v3::VirtualHost::NONE:
|
||||
ssl_requirements_ = SslRequirements::None;
|
||||
@@ -1478,10 +1468,14 @@ VirtualHostImpl::VirtualHostImpl(
|
||||
}
|
||||
}
|
||||
|
||||
- for (const auto& virtual_cluster : virtual_host.virtual_clusters()) {
|
||||
- virtual_clusters_.push_back(
|
||||
- VirtualClusterEntry(virtual_cluster, *vcluster_scope_,
|
||||
- factory_context.routerContext().virtualClusterStatNames()));
|
||||
+ if (!virtual_host.virtual_clusters().empty()) {
|
||||
+ virtual_cluster_catch_all_ = std::make_unique<CatchAllVirtualCluster>(
|
||||
+ *vcluster_scope_, factory_context.routerContext().virtualClusterStatNames());
|
||||
+ for (const auto& virtual_cluster : virtual_host.virtual_clusters()) {
|
||||
+ virtual_clusters_.push_back(
|
||||
+ VirtualClusterEntry(virtual_cluster, *vcluster_scope_,
|
||||
+ factory_context.routerContext().virtualClusterStatNames()));
|
||||
+ }
|
||||
}
|
||||
|
||||
if (virtual_host.has_cors()) {
|
||||
@@ -1774,7 +1768,7 @@ VirtualHostImpl::virtualClusterFromEntries(const Http::HeaderMap& headers) const
|
||||
}
|
||||
|
||||
if (!virtual_clusters_.empty()) {
|
||||
- return &virtual_cluster_catch_all_;
|
||||
+ return virtual_cluster_catch_all_.get();
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
diff --git a/source/common/router/config_impl.h b/source/common/router/config_impl.h
|
||||
index cf0ddf3..d83eb94 100644
|
||||
--- a/source/common/router/config_impl.h
|
||||
+++ b/source/common/router/config_impl.h
|
||||
@@ -352,10 +352,10 @@ private:
|
||||
const bool include_attempt_count_in_response_;
|
||||
absl::optional<envoy::config::route::v3::RetryPolicy> retry_policy_;
|
||||
absl::optional<envoy::config::route::v3::HedgePolicy> hedge_policy_;
|
||||
- const CatchAllVirtualCluster virtual_cluster_catch_all_;
|
||||
#if defined(ALIMESH)
|
||||
std::vector<std::string> allow_server_names_;
|
||||
#endif
|
||||
+ std::unique_ptr<const CatchAllVirtualCluster> virtual_cluster_catch_all_;
|
||||
};
|
||||
|
||||
using VirtualHostSharedPtr = std::shared_ptr<VirtualHostImpl>;
|
||||
diff --git a/source/common/router/scoped_config_impl.cc b/source/common/router/scoped_config_impl.cc
|
||||
index 594d571..6482615 100644
|
||||
--- a/source/common/router/scoped_config_impl.cc
|
||||
+++ b/source/common/router/scoped_config_impl.cc
|
||||
@@ -7,6 +7,8 @@
|
||||
#include "source/common/http/header_utility.h"
|
||||
#endif
|
||||
|
||||
+#include "source/common/protobuf/utility.h"
|
||||
+
|
||||
namespace Envoy {
|
||||
namespace Router {
|
||||
|
||||
@@ -239,7 +241,8 @@ HeaderValueExtractorImpl::computeFragment(const Http::HeaderMap& headers) const
|
||||
|
||||
ScopedRouteInfo::ScopedRouteInfo(envoy::config::route::v3::ScopedRouteConfiguration&& config_proto,
|
||||
ConfigConstSharedPtr&& route_config)
|
||||
- : config_proto_(std::move(config_proto)), route_config_(std::move(route_config)) {
|
||||
+ : config_proto_(std::move(config_proto)), route_config_(std::move(route_config)),
|
||||
+ config_hash_(MessageUtil::hash(config_proto)) {
|
||||
// TODO(stevenzzzz): Maybe worth a KeyBuilder abstraction when there are more than one type of
|
||||
// Fragment.
|
||||
for (const auto& fragment : config_proto_.key().fragments()) {
|
||||
diff --git a/source/common/router/scoped_config_impl.h b/source/common/router/scoped_config_impl.h
|
||||
index 9f6a1b2..28e2ee5 100644
|
||||
--- a/source/common/router/scoped_config_impl.h
|
||||
+++ b/source/common/router/scoped_config_impl.h
|
||||
@@ -154,11 +154,13 @@ public:
|
||||
return config_proto_;
|
||||
}
|
||||
const std::string& scopeName() const { return config_proto_.name(); }
|
||||
+ uint64_t configHash() const { return config_hash_; }
|
||||
|
||||
private:
|
||||
envoy::config::route::v3::ScopedRouteConfiguration config_proto_;
|
||||
ScopeKey scope_key_;
|
||||
ConfigConstSharedPtr route_config_;
|
||||
+ const uint64_t config_hash_;
|
||||
};
|
||||
using ScopedRouteInfoConstSharedPtr = std::shared_ptr<const ScopedRouteInfo>;
|
||||
// Ordered map for consistent config dumping.
|
||||
diff --git a/source/common/router/scoped_rds.cc b/source/common/router/scoped_rds.cc
|
||||
index 133e91e..9b2096e 100644
|
||||
--- a/source/common/router/scoped_rds.cc
|
||||
+++ b/source/common/router/scoped_rds.cc
|
||||
@@ -245,6 +245,11 @@ bool ScopedRdsConfigSubscription::addOrUpdateScopes(
|
||||
dynamic_cast<const envoy::config::route::v3::ScopedRouteConfiguration&>(
|
||||
resource.get().resource());
|
||||
const std::string scope_name = scoped_route_config.name();
|
||||
+ if (const auto& scope_info_iter = scoped_route_map_.find(scope_name);
|
||||
+ scope_info_iter != scoped_route_map_.end() &&
|
||||
+ scope_info_iter->second->configHash() == MessageUtil::hash(scoped_route_config)) {
|
||||
+ continue;
|
||||
+ }
|
||||
rds.set_route_config_name(scoped_route_config.route_configuration_name());
|
||||
std::unique_ptr<RdsRouteConfigProviderHelper> rds_config_provider_helper;
|
||||
std::shared_ptr<ScopedRouteInfo> scoped_route_info = nullptr;
|
||||
@@ -398,6 +403,7 @@ void ScopedRdsConfigSubscription::onRdsConfigUpdate(const std::string& scope_nam
|
||||
auto new_scoped_route_info = std::make_shared<ScopedRouteInfo>(
|
||||
envoy::config::route::v3::ScopedRouteConfiguration(iter->second->configProto()),
|
||||
std::move(new_rds_config));
|
||||
+ scoped_route_map_[new_scoped_route_info->scopeName()] = new_scoped_route_info;
|
||||
applyConfigUpdate([new_scoped_route_info](ConfigProvider::ConfigConstSharedPtr config)
|
||||
-> ConfigProvider::ConfigConstSharedPtr {
|
||||
auto* thread_local_scoped_config =
|
||||
diff --git a/source/common/router/scoped_rds.h b/source/common/router/scoped_rds.h
|
||||
index d21d812..a510c1f 100644
|
||||
--- a/source/common/router/scoped_rds.h
|
||||
+++ b/source/common/router/scoped_rds.h
|
||||
@@ -104,7 +104,7 @@ struct ScopedRdsStats {
|
||||
// A scoped RDS subscription to be used with the dynamic scoped RDS ConfigProvider.
|
||||
class ScopedRdsConfigSubscription
|
||||
: public Envoy::Config::DeltaConfigSubscriptionInstance,
|
||||
- Envoy::Config::SubscriptionBase<envoy::config::route::v3::ScopedRouteConfiguration> {
|
||||
+ public Envoy::Config::SubscriptionBase<envoy::config::route::v3::ScopedRouteConfiguration> {
|
||||
public:
|
||||
using ScopedRouteConfigurationMap =
|
||||
std::map<std::string, envoy::config::route::v3::ScopedRouteConfiguration>;
|
||||
diff --git a/test/common/router/scoped_config_impl_test.cc b/test/common/router/scoped_config_impl_test.cc
|
||||
index f63f258..69a2f4b 100644
|
||||
--- a/test/common/router/scoped_config_impl_test.cc
|
||||
+++ b/test/common/router/scoped_config_impl_test.cc
|
||||
@@ -452,6 +452,24 @@ TEST_F(ScopedRouteInfoTest, Creation) {
|
||||
EXPECT_EQ(info_->scopeKey(), makeKey({"foo", "bar"}));
|
||||
}
|
||||
|
||||
+// Tests that config hash changes if ScopedRouteConfiguration of the ScopedRouteInfo changes.
|
||||
+TEST_F(ScopedRouteInfoTest, Hash) {
|
||||
+ const envoy::config::route::v3::ScopedRouteConfiguration config_copy = scoped_route_config_;
|
||||
+ info_ = std::make_unique<ScopedRouteInfo>(scoped_route_config_, route_config_);
|
||||
+ EXPECT_EQ(info_->routeConfig().get(), route_config_.get());
|
||||
+ EXPECT_TRUE(TestUtility::protoEqual(info_->configProto(), config_copy));
|
||||
+ EXPECT_EQ(info_->scopeName(), "foo_scope");
|
||||
+ EXPECT_EQ(info_->scopeKey(), makeKey({"foo", "bar"}));
|
||||
+
|
||||
+ const auto info2 = std::make_unique<ScopedRouteInfo>(scoped_route_config_, route_config_);
|
||||
+ ASSERT_EQ(info2->configHash(), info_->configHash());
|
||||
+
|
||||
+ // Mutate the config and hash should be different now.
|
||||
+ scoped_route_config_.set_on_demand(true);
|
||||
+ const auto info3 = std::make_unique<ScopedRouteInfo>(scoped_route_config_, route_config_);
|
||||
+ ASSERT_NE(info3->configHash(), info_->configHash());
|
||||
+}
|
||||
+
|
||||
class ScopedConfigImplTest : public testing::Test {
|
||||
public:
|
||||
void SetUp() override {
|
||||
diff --git a/test/common/router/scoped_rds_test.cc b/test/common/router/scoped_rds_test.cc
|
||||
index 09b96a6..b4776c9 100644
|
||||
--- a/test/common/router/scoped_rds_test.cc
|
||||
+++ b/test/common/router/scoped_rds_test.cc
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "envoy/stats/scope.h"
|
||||
|
||||
#include "source/common/config/api_version.h"
|
||||
+#include "source/common/config/config_provider_impl.h"
|
||||
#include "source/common/config/grpc_mux_impl.h"
|
||||
#include "source/common/protobuf/message_validator_impl.h"
|
||||
#include "source/common/router/scoped_rds.h"
|
||||
@@ -365,6 +366,48 @@ key:
|
||||
"Didn't find a registered implementation for name: 'filter.unknown'");
|
||||
}
|
||||
|
||||
+// Test that scopes with same config as existing scopes will be skipped in a config push.
|
||||
+TEST_F(ScopedRdsTest, UnchangedScopesAreSkipped) {
|
||||
+ setup();
|
||||
+ init_watcher_.expectReady();
|
||||
+ const std::string config_yaml = R"EOF(
|
||||
+name: foo_scope
|
||||
+route_configuration_name: foo_routes
|
||||
+key:
|
||||
+ fragments:
|
||||
+ - string_key: x-foo-key
|
||||
+)EOF";
|
||||
+ const auto resource = parseScopedRouteConfigurationFromYaml(config_yaml);
|
||||
+ const std::string config_yaml2 = R"EOF(
|
||||
+name: foo_scope2
|
||||
+route_configuration_name: foo_routes
|
||||
+key:
|
||||
+ fragments:
|
||||
+ - string_key: x-bar-key
|
||||
+)EOF";
|
||||
+ const auto resource_2 = parseScopedRouteConfigurationFromYaml(config_yaml2);
|
||||
+
|
||||
+ // Delta API.
|
||||
+ const auto decoded_resources = TestUtility::decodeResources({resource, resource_2});
|
||||
+ context_init_manager_.initialize(init_watcher_);
|
||||
+ EXPECT_NO_THROW(srds_subscription_->onConfigUpdate(decoded_resources.refvec_, {}, "v1"));
|
||||
+ EXPECT_EQ(1UL,
|
||||
+ server_factory_context_.scope_.counter("foo.scoped_rds.foo_scoped_routes.config_reload")
|
||||
+ .value());
|
||||
+ EXPECT_EQ(2UL, all_scopes_.value());
|
||||
+ pushRdsConfig({"foo_routes"}, "111");
|
||||
+ Envoy::Router::ScopedRdsConfigSubscription* srds_delta_subscription =
|
||||
+ static_cast<Envoy::Router::ScopedRdsConfigSubscription*>(srds_subscription_);
|
||||
+ ASSERT_NE(srds_delta_subscription, nullptr);
|
||||
+ ASSERT_EQ("v1", srds_delta_subscription->configInfo()->last_config_version_);
|
||||
+ // Push again the same set of config with different version number, the config will be skipped.
|
||||
+ EXPECT_NO_THROW(srds_subscription_->onConfigUpdate(decoded_resources.refvec_, {}, "123"));
|
||||
+ ASSERT_EQ("v1", srds_delta_subscription->configInfo()->last_config_version_);
|
||||
+ EXPECT_EQ(2UL,
|
||||
+ server_factory_context_.scope_.counter("foo.scoped_rds.foo_scoped_routes.config_reload")
|
||||
+ .value());
|
||||
+}
|
||||
+
|
||||
// Test ignoring the optional unknown factory in the per-virtualhost typed config.
|
||||
TEST_F(ScopedRdsTest, OptionalUnknownFactoryForPerVirtualHostTypedConfig) {
|
||||
OptionalHttpFilters optional_http_filters;
|
||||
6
go.mod
6
go.mod
@@ -259,7 +259,6 @@ require (
|
||||
golang.org/x/crypto v0.17.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect
|
||||
golang.org/x/mod v0.11.0 // indirect
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/oauth2 v0.6.0 // indirect
|
||||
golang.org/x/sync v0.3.0 // indirect
|
||||
golang.org/x/sys v0.15.0 // indirect
|
||||
@@ -281,6 +280,8 @@ require (
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||
gopkg.in/warnings.v0 v0.1.2 // indirect
|
||||
k8s.io/apiserver v0.22.5 // indirect
|
||||
k8s.io/component-base v0.22.5 // indirect
|
||||
k8s.io/klog/v2 v2.60.1 // indirect
|
||||
k8s.io/kube-openapi v0.0.0-20211109043538-20434351676c // indirect
|
||||
k8s.io/utils v0.0.0-20220728103510-ee6ede2d64ed // indirect
|
||||
oras.land/oras-go v0.4.0 // indirect
|
||||
@@ -312,10 +313,9 @@ require (
|
||||
github.com/kylelemons/godebug v1.1.0
|
||||
github.com/mholt/acmez v1.2.0
|
||||
github.com/tidwall/gjson v1.17.0
|
||||
golang.org/x/net v0.17.0
|
||||
helm.sh/helm/v3 v3.7.1
|
||||
k8s.io/apiextensions-apiserver v0.25.4
|
||||
k8s.io/component-base v0.22.5
|
||||
k8s.io/klog/v2 v2.60.1
|
||||
knative.dev/networking v0.0.0-20220302134042-e8b2eb995165
|
||||
knative.dev/pkg v0.0.0-20220301181942-2fdd5f232e77
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
apiVersion: v2
|
||||
appVersion: 1.4.0-rc.1
|
||||
appVersion: 1.4.1
|
||||
description: Helm chart for deploying higress gateways
|
||||
icon: https://higress.io/img/higress_logo_small.png
|
||||
home: http://higress.io/
|
||||
@@ -10,4 +10,4 @@ name: higress-core
|
||||
sources:
|
||||
- http://github.com/alibaba/higress
|
||||
type: application
|
||||
version: 1.4.0-rc.1
|
||||
version: 1.4.1
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
revision: ""
|
||||
global:
|
||||
liteMetrics: false
|
||||
liteMetrics: true
|
||||
xdsMaxRecvMsgSize: "104857600"
|
||||
defaultUpstreamConcurrencyThreshold: 10000
|
||||
enableSRDS: true
|
||||
@@ -589,7 +589,7 @@ controller:
|
||||
maxReplicas: 5
|
||||
targetCPUUtilizationPercentage: 80
|
||||
automaticHttps:
|
||||
enabled: false
|
||||
enabled: true
|
||||
email: ""
|
||||
|
||||
## Discovery Settings
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
dependencies:
|
||||
- name: higress-core
|
||||
repository: file://../core
|
||||
version: 1.4.0-rc.1
|
||||
version: 1.4.1
|
||||
- name: higress-console
|
||||
repository: https://higress.io/helm-charts/
|
||||
version: 1.4.0
|
||||
digest: sha256:320b1b3ed08fad56dff0d21faaffe41a0325fdcdb96847e53a588d6b0df7e73e
|
||||
generated: "2024-05-19T17:52:19.676747+08:00"
|
||||
version: 1.4.1
|
||||
digest: sha256:de41b8f771e869aef9b83d2334fea5d34492a1c5df37e5aaff383189877cba23
|
||||
generated: "2024-06-19T17:10:02.426994+08:00"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
apiVersion: v2
|
||||
appVersion: 1.4.0-rc.1
|
||||
appVersion: 1.4.1
|
||||
description: Helm chart for deploying Higress gateways
|
||||
icon: https://higress.io/img/higress_logo_small.png
|
||||
home: http://higress.io/
|
||||
@@ -12,9 +12,9 @@ sources:
|
||||
dependencies:
|
||||
- name: higress-core
|
||||
repository: "file://../core"
|
||||
version: 1.4.0-rc.1
|
||||
version: 1.4.1
|
||||
- name: higress-console
|
||||
repository: "https://higress.io/helm-charts/"
|
||||
version: 1.4.0
|
||||
version: 1.4.1
|
||||
type: application
|
||||
version: 1.4.0-rc.1
|
||||
version: 1.4.1
|
||||
|
||||
69
istio/1.12/patches/istio/20240527-fix-vs-merge.patch
Normal file
69
istio/1.12/patches/istio/20240527-fix-vs-merge.patch
Normal file
@@ -0,0 +1,69 @@
|
||||
diff -Naur istio/pilot/pkg/model/push_context.go istio-new/pilot/pkg/model/push_context.go
|
||||
--- istio/pilot/pkg/model/push_context.go 2024-05-27 23:03:09.000000000 +0800
|
||||
+++ istio-new/pilot/pkg/model/push_context.go 2024-05-27 21:33:45.000000000 +0800
|
||||
@@ -1482,8 +1482,14 @@
|
||||
ns := virtualService.Namespace
|
||||
rule := virtualService.Spec.(*networking.VirtualService)
|
||||
// Added by ingress
|
||||
- for _, host := range rule.Hosts {
|
||||
- ps.virtualServiceIndex.byHost[host] = append(ps.virtualServiceIndex.byHost[host], virtualService)
|
||||
+ if len(rule.Gateways) > 0 {
|
||||
+ if len(rule.Hosts) == 0 {
|
||||
+ ps.virtualServiceIndex.byHost[constants.GlobalWildcardHost] = append(ps.virtualServiceIndex.byHost[constants.GlobalWildcardHost], virtualService)
|
||||
+ } else {
|
||||
+ for _, host := range rule.Hosts {
|
||||
+ ps.virtualServiceIndex.byHost[host] = append(ps.virtualServiceIndex.byHost[host], virtualService)
|
||||
+ }
|
||||
+ }
|
||||
}
|
||||
// End added by ingress
|
||||
gwNames := getGatewayNames(rule)
|
||||
diff -Naur istio/pilot/pkg/networking/core/v1alpha3/gateway.go istio-new/pilot/pkg/networking/core/v1alpha3/gateway.go
|
||||
--- istio/pilot/pkg/networking/core/v1alpha3/gateway.go 2024-05-27 23:03:09.000000000 +0800
|
||||
+++ istio-new/pilot/pkg/networking/core/v1alpha3/gateway.go 2024-05-27 22:58:33.000000000 +0800
|
||||
@@ -376,8 +376,15 @@
|
||||
gatewayVirtualServices[gatewayName] = virtualServices
|
||||
}
|
||||
for _, virtualService := range virtualServices {
|
||||
- for _, host := range virtualService.Spec.(*networking.VirtualService).Hosts {
|
||||
- hostSet.Insert(host)
|
||||
+ rule := virtualService.Spec.(*networking.VirtualService)
|
||||
+ if len(rule.Gateways) > 0 {
|
||||
+ if len(rule.Hosts) == 0 {
|
||||
+ hostSet.Insert(constants.GlobalWildcardHost)
|
||||
+ break
|
||||
+ }
|
||||
+ for _, host := range rule.Hosts {
|
||||
+ hostSet.Insert(host)
|
||||
+ }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -689,7 +696,7 @@
|
||||
vHost = &route.VirtualHost{
|
||||
Name: util.DomainName(hostRDSHost, port),
|
||||
Domains: buildGatewayVirtualHostDomains(hostRDSHost, port),
|
||||
- Routes: routes,
|
||||
+ Routes: append(routes[:0:0], routes...),
|
||||
IncludeRequestAttemptCount: true,
|
||||
TypedPerFilterConfig: mseingress.ConstructTypedPerFilterConfigForVHost(globalHTTPFilters, virtualService),
|
||||
}
|
||||
@@ -884,7 +891,7 @@
|
||||
newVHost := &route.VirtualHost{
|
||||
Name: util.DomainName(string(hostname), port),
|
||||
Domains: buildGatewayVirtualHostDomains(string(hostname), port),
|
||||
- Routes: routes,
|
||||
+ Routes: append(routes[:0:0], routes...),
|
||||
IncludeRequestAttemptCount: true,
|
||||
TypedPerFilterConfig: mseingress.ConstructTypedPerFilterConfigForVHost(globalHTTPFilters, virtualService),
|
||||
}
|
||||
diff -Naur istio/pkg/config/constants/constants.go istio-new/pkg/config/constants/constants.go
|
||||
--- istio/pkg/config/constants/constants.go 2024-05-27 23:03:09.000000000 +0800
|
||||
+++ istio-new/pkg/config/constants/constants.go 2024-05-27 21:31:58.000000000 +0800
|
||||
@@ -145,5 +145,6 @@
|
||||
// Added by ingress
|
||||
HigressHostRDSNamePrefix = "higress-rds-"
|
||||
DefaultScopedRouteName = "scoped-route"
|
||||
+ GlobalWildcardHost = "*"
|
||||
// End added by ingress
|
||||
)
|
||||
17
istio/1.12/patches/istio/20240529-optimize-mcp-cds.patch
Normal file
17
istio/1.12/patches/istio/20240529-optimize-mcp-cds.patch
Normal file
@@ -0,0 +1,17 @@
|
||||
diff -Naur istio/pilot/pkg/model/push_context.go istio-new/pilot/pkg/model/push_context.go
|
||||
--- istio/pilot/pkg/model/push_context.go 2024-05-29 19:29:45.000000000 +0800
|
||||
+++ istio-new/pilot/pkg/model/push_context.go 2024-05-29 19:11:03.000000000 +0800
|
||||
@@ -769,6 +769,13 @@
|
||||
for _, s := range svcs {
|
||||
svcHost := string(s.Hostname)
|
||||
|
||||
+ // Added by ingress
|
||||
+ if s.Attributes.Namespace == "mcp" {
|
||||
+ gwSvcs = append(gwSvcs, s)
|
||||
+ continue
|
||||
+ }
|
||||
+ // End added by ingress
|
||||
+
|
||||
if _, ok := hostsFromGateways[svcHost]; ok {
|
||||
gwSvcs = append(gwSvcs, s)
|
||||
}
|
||||
21
istio/1.12/patches/istio/20240607-fix-stats.patch
Normal file
21
istio/1.12/patches/istio/20240607-fix-stats.patch
Normal file
@@ -0,0 +1,21 @@
|
||||
diff -Naur istio/tools/packaging/common/envoy_bootstrap.json istio-new/tools/packaging/common/envoy_bootstrap.json
|
||||
--- istio/tools/packaging/common/envoy_bootstrap.json 2024-06-07 16:50:21.000000000 +0800
|
||||
+++ istio-new/tools/packaging/common/envoy_bootstrap.json 2024-06-07 16:47:42.000000000 +0800
|
||||
@@ -38,7 +38,7 @@
|
||||
"stats_tags": [
|
||||
{
|
||||
"tag_name": "cluster_name",
|
||||
- "regex": "^cluster\\.((.+?(\\..+?\\.svc\\.cluster\\.local)?)\\.)"
|
||||
+ "regex": "^cluster\\.((.*?)\\.)(http1\\.|http2\\.|health_check\\.|zone\\.|external\\.|circuit_breakers\\.|[^\\.]+$)"
|
||||
},
|
||||
{
|
||||
"tag_name": "tcp_prefix",
|
||||
@@ -58,7 +58,7 @@
|
||||
},
|
||||
{
|
||||
"tag_name": "http_conn_manager_prefix",
|
||||
- "regex": "^http\\.(((?:[_.[:digit:]]*|[_\\[\\]aAbBcCdDeEfF[:digit:]]*))\\.)"
|
||||
+ "regex": "^http\\.(((outbound_([0-9]{1,3}\\.{0,1}){4}_\\d{0,5})|([^\\.]+))\\.)"
|
||||
},
|
||||
{
|
||||
"tag_name": "listener_address",
|
||||
53
istio/1.12/patches/istio/20240619-ai-stats.patch
Normal file
53
istio/1.12/patches/istio/20240619-ai-stats.patch
Normal file
@@ -0,0 +1,53 @@
|
||||
diff -Naur istio/tools/packaging/common/envoy_bootstrap.json istio-new/tools/packaging/common/envoy_bootstrap.json
|
||||
--- istio/tools/packaging/common/envoy_bootstrap.json 2024-06-19 13:39:49.179159469 +0800
|
||||
+++ istio-new/tools/packaging/common/envoy_bootstrap.json 2024-06-19 13:39:28.299159059 +0800
|
||||
@@ -37,6 +37,18 @@
|
||||
"use_all_default_tags": false,
|
||||
"stats_tags": [
|
||||
{
|
||||
+ "tag_name": "ai_route",
|
||||
+ "regex": "^wasmcustom\\.route\\.((.*?)\\.)upstream"
|
||||
+ },
|
||||
+ {
|
||||
+ "tag_name": "ai_cluster",
|
||||
+ "regex": "^wasmcustom\\..*?\\.upstream\\.((.*?)\\.)model"
|
||||
+ },
|
||||
+ {
|
||||
+ "tag_name": "ai_model",
|
||||
+ "regex": "^wasmcustom\\..*?\\.model\\.((.*?)\\.)(input_token|output_token)"
|
||||
+ },
|
||||
+ {
|
||||
"tag_name": "cluster_name",
|
||||
"regex": "^cluster\\.((.*?)\\.)(http1\\.|http2\\.|health_check\\.|zone\\.|external\\.|circuit_breakers\\.|[^\\.]+$)"
|
||||
},
|
||||
diff -Naur istio/tools/packaging/common/envoy_bootstrap_lite.json istio-new/tools/packaging/common/envoy_bootstrap_lite.json
|
||||
--- istio/tools/packaging/common/envoy_bootstrap_lite.json 2024-06-19 13:39:49.175159469 +0800
|
||||
+++ istio-new/tools/packaging/common/envoy_bootstrap_lite.json 2024-06-19 13:38:52.283158352 +0800
|
||||
@@ -37,6 +37,18 @@
|
||||
"use_all_default_tags": false,
|
||||
"stats_tags": [
|
||||
{
|
||||
+ "tag_name": "ai_route",
|
||||
+ "regex": "^wasmcustom\\.route\\.((.*?)\\.)upstream"
|
||||
+ },
|
||||
+ {
|
||||
+ "tag_name": "ai_cluster",
|
||||
+ "regex": "^wasmcustom\\..*?\\.upstream\\.((.*?)\\.)model"
|
||||
+ },
|
||||
+ {
|
||||
+ "tag_name": "ai_model",
|
||||
+ "regex": "^wasmcustom\\..*?\\.model\\.((.*?)\\.)(input_token|output_token)"
|
||||
+ },
|
||||
+ {
|
||||
"tag_name": "response_code_class",
|
||||
"regex": "_rq(_(\\dxx))$"
|
||||
},
|
||||
@@ -60,7 +72,7 @@
|
||||
"prefix": "vhost"
|
||||
},
|
||||
{
|
||||
- "safe_regex": {"regex": "^http.*rds.*", "google_re2":{}}
|
||||
+ "safe_regex": {"regex": "^http.*\\.rds\\..*", "google_re2":{}}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -45,11 +45,12 @@ const (
|
||||
|
||||
// Config is the configuration of automatic https.
|
||||
type Config struct {
|
||||
AutomaticHttps bool `json:"automaticHttps"`
|
||||
RenewBeforeDays int `json:"renewBeforeDays"`
|
||||
CredentialConfig []CredentialEntry `json:"credentialConfig"`
|
||||
ACMEIssuer []ACMEIssuerEntry `json:"acmeIssuer"`
|
||||
Version string `json:"version"`
|
||||
AutomaticHttps bool `json:"automaticHttps"`
|
||||
FallbackForInvalidSecret bool `json:"fallbackForInvalidSecret"`
|
||||
RenewBeforeDays int `json:"renewBeforeDays"`
|
||||
CredentialConfig []CredentialEntry `json:"credentialConfig"`
|
||||
ACMEIssuer []ACMEIssuerEntry `json:"acmeIssuer"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
func (c *Config) GetIssuer(issuerName IssuerName) *ACMEIssuerEntry {
|
||||
@@ -274,11 +275,12 @@ func newDefaultConfig(email string) *Config {
|
||||
}
|
||||
defaultCredentialConfig := make([]CredentialEntry, 0)
|
||||
config := &Config{
|
||||
AutomaticHttps: true,
|
||||
RenewBeforeDays: DefaultRenewBeforeDays,
|
||||
ACMEIssuer: defaultIssuer,
|
||||
CredentialConfig: defaultCredentialConfig,
|
||||
Version: time.Now().Format("20060102030405"),
|
||||
AutomaticHttps: true,
|
||||
FallbackForInvalidSecret: false,
|
||||
RenewBeforeDays: DefaultRenewBeforeDays,
|
||||
ACMEIssuer: defaultIssuer,
|
||||
CredentialConfig: defaultCredentialConfig,
|
||||
Version: time.Now().Format("20060102030405"),
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
@@ -15,7 +15,8 @@
|
||||
package hgctl
|
||||
|
||||
const (
|
||||
yamlOutput = "yaml"
|
||||
jsonOutput = "json"
|
||||
flagsOutput = "flags"
|
||||
summaryOutput = "short"
|
||||
yamlOutput = "yaml"
|
||||
jsonOutput = "json"
|
||||
flagsOutput = "flags"
|
||||
)
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/alibaba/higress/cmd/hgctl/config"
|
||||
"github.com/spf13/cobra"
|
||||
"istio.io/istio/istioctl/pkg/writer/envoy/configdump"
|
||||
cmdutil "k8s.io/kubectl/pkg/cmd/util"
|
||||
)
|
||||
|
||||
@@ -49,17 +50,23 @@ func runClusterConfig(c *cobra.Command, args []string) error {
|
||||
if len(args) != 0 {
|
||||
podName = args[0]
|
||||
}
|
||||
envoyConfig, err := config.GetEnvoyConfig(&config.GetEnvoyConfigOptions{
|
||||
configWriter, err := config.GetEnvoyConfigWriter(&config.GetEnvoyConfigOptions{
|
||||
PodName: podName,
|
||||
PodNamespace: podNamespace,
|
||||
BindAddress: bindAddress,
|
||||
Output: output,
|
||||
EnvoyConfigType: config.ClusterEnvoyConfigType,
|
||||
IncludeEds: true,
|
||||
})
|
||||
}, c.OutOrStdout())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintln(c.OutOrStdout(), string(envoyConfig))
|
||||
return err
|
||||
switch output {
|
||||
case summaryOutput:
|
||||
return configWriter.PrintClusterSummary(configdump.ClusterFilter{})
|
||||
case jsonOutput, yamlOutput:
|
||||
return configWriter.PrintClusterDump(configdump.ClusterFilter{}, output)
|
||||
default:
|
||||
return fmt.Errorf("output format %q not supported", output)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func newConfigCommand() *cobra.Command {
|
||||
flags := cfgCommand.Flags()
|
||||
options.AddKubeConfigFlags(flags)
|
||||
|
||||
cfgCommand.PersistentFlags().StringVarP(&output, "output", "o", "json", "One of 'yaml' or 'json'")
|
||||
cfgCommand.PersistentFlags().StringVarP(&output, "output", "o", "json", "Output format: one of json|yaml|short")
|
||||
cfgCommand.PersistentFlags().StringVarP(&podNamespace, "namespace", "n", "higress-system", "Namespace where envoy proxy pod are installed.")
|
||||
|
||||
return cfgCommand
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/alibaba/higress/cmd/hgctl/config"
|
||||
"github.com/spf13/cobra"
|
||||
"istio.io/istio/istioctl/pkg/writer/envoy/configdump"
|
||||
cmdutil "k8s.io/kubectl/pkg/cmd/util"
|
||||
)
|
||||
|
||||
@@ -49,17 +50,23 @@ func runListenerConfig(c *cobra.Command, args []string) error {
|
||||
if len(args) != 0 {
|
||||
podName = args[0]
|
||||
}
|
||||
envoyConfig, err := config.GetEnvoyConfig(&config.GetEnvoyConfigOptions{
|
||||
configWriter, err := config.GetEnvoyConfigWriter(&config.GetEnvoyConfigOptions{
|
||||
PodName: podName,
|
||||
PodNamespace: podNamespace,
|
||||
BindAddress: bindAddress,
|
||||
Output: output,
|
||||
EnvoyConfigType: config.ListenerEnvoyConfigType,
|
||||
IncludeEds: true,
|
||||
})
|
||||
}, c.OutOrStdout())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintln(c.OutOrStdout(), string(envoyConfig))
|
||||
return err
|
||||
switch output {
|
||||
case summaryOutput:
|
||||
return configWriter.PrintListenerSummary(configdump.ListenerFilter{Verbose: true})
|
||||
case jsonOutput, yamlOutput:
|
||||
return configWriter.PrintListenerDump(configdump.ListenerFilter{Verbose: true}, output)
|
||||
default:
|
||||
return fmt.Errorf("output format %q not supported", output)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/alibaba/higress/cmd/hgctl/config"
|
||||
"github.com/spf13/cobra"
|
||||
"istio.io/istio/istioctl/pkg/writer/envoy/configdump"
|
||||
cmdutil "k8s.io/kubectl/pkg/cmd/util"
|
||||
)
|
||||
|
||||
@@ -49,17 +50,23 @@ func runRouteConfig(c *cobra.Command, args []string) error {
|
||||
if len(args) != 0 {
|
||||
podName = args[0]
|
||||
}
|
||||
envoyConfig, err := config.GetEnvoyConfig(&config.GetEnvoyConfigOptions{
|
||||
configWriter, err := config.GetEnvoyConfigWriter(&config.GetEnvoyConfigOptions{
|
||||
PodName: podName,
|
||||
PodNamespace: podNamespace,
|
||||
BindAddress: bindAddress,
|
||||
Output: output,
|
||||
EnvoyConfigType: config.RouteEnvoyConfigType,
|
||||
IncludeEds: true,
|
||||
})
|
||||
}, c.OutOrStdout())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintln(c.OutOrStdout(), string(envoyConfig))
|
||||
return err
|
||||
switch output {
|
||||
case summaryOutput:
|
||||
return configWriter.PrintRouteSummary(configdump.RouteFilter{Verbose: true})
|
||||
case jsonOutput, yamlOutput:
|
||||
return configWriter.PrintRouteDump(configdump.RouteFilter{Verbose: true}, output)
|
||||
default:
|
||||
return fmt.Errorf("output format %q not supported", output)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"github.com/alibaba/higress/pkg/cmd/hgctl/plugin/option"
|
||||
"github.com/alibaba/higress/pkg/cmd/hgctl/plugin/utils"
|
||||
@@ -86,6 +87,12 @@ func runInit(w io.Writer, target string) (err error) {
|
||||
return errors.Wrap(err, "failed to create option.yaml")
|
||||
}
|
||||
|
||||
cmd := exec.Command("go", "mod", "tidy")
|
||||
cmd.Dir = dir
|
||||
if err := cmd.Run(); err != nil {
|
||||
return errors.Wrap(err, "failed to run go mod tidy")
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "Initialized the project in %q\n", dir)
|
||||
|
||||
return nil
|
||||
|
||||
@@ -31,8 +31,8 @@ package main
|
||||
|
||||
import (
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
@@ -93,8 +93,8 @@ module {{ .Name }}
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v0.0.0-20231019123123-86b223bc75f1
|
||||
github.com/tetratelabs/proxy-wasm-go-sdk v0.22.0
|
||||
github.com/alibaba/higress/plugins/wasm-go main
|
||||
github.com/higress-group/proxy-wasm-go-sdk main
|
||||
github.com/tidwall/gjson v1.14.3
|
||||
)
|
||||
`
|
||||
|
||||
@@ -51,6 +51,7 @@ import (
|
||||
higressv1 "github.com/alibaba/higress/api/networking/v1"
|
||||
extlisterv1 "github.com/alibaba/higress/client/pkg/listers/extensions/v1alpha1"
|
||||
netlisterv1 "github.com/alibaba/higress/client/pkg/listers/networking/v1"
|
||||
"github.com/alibaba/higress/pkg/cert"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/annotations"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/common"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/configmap"
|
||||
@@ -144,6 +145,8 @@ type IngressConfig struct {
|
||||
namespace string
|
||||
|
||||
clusterId string
|
||||
|
||||
httpsConfigMgr *cert.ConfigMgr
|
||||
}
|
||||
|
||||
func NewIngressConfig(localKubeClient kube.Client, XDSUpdater model.XDSUpdater, namespace, clusterId string) *IngressConfig {
|
||||
@@ -180,6 +183,9 @@ func NewIngressConfig(localKubeClient kube.Client, XDSUpdater model.XDSUpdater,
|
||||
higressConfigController := configmap.NewController(localKubeClient, clusterId, namespace)
|
||||
config.configmapMgr = configmap.NewConfigmapMgr(XDSUpdater, namespace, higressConfigController, higressConfigController.Lister())
|
||||
|
||||
httpsConfigMgr, _ := cert.NewConfigMgr(namespace, localKubeClient)
|
||||
config.httpsConfigMgr = httpsConfigMgr
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
@@ -347,6 +353,10 @@ func (m *IngressConfig) convertGateways(configs []common.WrapperConfig) []config
|
||||
Gateways: map[string]*common.WrapperGateway{},
|
||||
}
|
||||
|
||||
httpsCredentialConfig, err := m.httpsConfigMgr.GetConfigFromConfigmap()
|
||||
if err != nil {
|
||||
IngressLog.Errorf("Get higress https configmap err %v", err)
|
||||
}
|
||||
for idx := range configs {
|
||||
cfg := configs[idx]
|
||||
clusterId := common.GetClusterId(cfg.Config.Annotations)
|
||||
@@ -356,7 +366,7 @@ func (m *IngressConfig) convertGateways(configs []common.WrapperConfig) []config
|
||||
if ingressController == nil {
|
||||
continue
|
||||
}
|
||||
if err := ingressController.ConvertGateway(&convertOptions, &cfg); err != nil {
|
||||
if err := ingressController.ConvertGateway(&convertOptions, &cfg, httpsCredentialConfig); err != nil {
|
||||
IngressLog.Errorf("Convert ingress %s/%s to gateway fail in cluster %s, err %v", cfg.Config.Namespace, cfg.Config.Name, clusterId, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,14 +17,14 @@ package common
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/pkg/cert"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/annotations"
|
||||
networking "istio.io/api/networking/v1alpha3"
|
||||
"istio.io/istio/pilot/pkg/model"
|
||||
"istio.io/istio/pkg/config"
|
||||
gatewaytool "istio.io/istio/pkg/config/gateway"
|
||||
listerv1 "k8s.io/client-go/listers/core/v1"
|
||||
"k8s.io/client-go/tools/cache"
|
||||
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/annotations"
|
||||
)
|
||||
|
||||
type ServiceKey struct {
|
||||
@@ -121,7 +121,7 @@ type IngressController interface {
|
||||
|
||||
SecretLister() listerv1.SecretLister
|
||||
|
||||
ConvertGateway(convertOptions *ConvertOptions, wrapper *WrapperConfig) error
|
||||
ConvertGateway(convertOptions *ConvertOptions, wrapper *WrapperConfig, httpsCredentialConfig *cert.Config) error
|
||||
|
||||
ConvertHTTPRoute(convertOptions *ConvertOptions, wrapper *WrapperConfig) error
|
||||
|
||||
|
||||
@@ -55,6 +55,7 @@ import (
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/secret"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/util"
|
||||
. "github.com/alibaba/higress/pkg/ingress/log"
|
||||
k8serrors "k8s.io/apimachinery/pkg/api/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -87,8 +88,6 @@ type controller struct {
|
||||
secretController secret.SecretController
|
||||
|
||||
statusSyncer *statusSyncer
|
||||
|
||||
configMgr *cert.ConfigMgr
|
||||
}
|
||||
|
||||
// NewController creates a new Kubernetes controller
|
||||
@@ -107,7 +106,6 @@ func NewController(localKubeClient, client kubeclient.Client, options common.Opt
|
||||
IngressLog.Infof("Skipping IngressClass, resource not supported for cluster %s", options.ClusterId)
|
||||
}
|
||||
|
||||
configMgr, _ := cert.NewConfigMgr(options.SystemNamespace, client.Kube())
|
||||
c := &controller{
|
||||
options: options,
|
||||
queue: q,
|
||||
@@ -118,7 +116,6 @@ func NewController(localKubeClient, client kubeclient.Client, options common.Opt
|
||||
serviceInformer: serviceInformer.Informer(),
|
||||
serviceLister: serviceInformer.Lister(),
|
||||
secretController: secretController,
|
||||
configMgr: configMgr,
|
||||
}
|
||||
|
||||
handler := controllers.LatestVersionHandlerFuncs(controllers.EnqueueForSelf(q))
|
||||
@@ -354,7 +351,7 @@ func extractTLSSecretName(host string, tls []ingress.IngressTLS) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *controller) ConvertGateway(convertOptions *common.ConvertOptions, wrapper *common.WrapperConfig) error {
|
||||
func (c *controller) ConvertGateway(convertOptions *common.ConvertOptions, wrapper *common.WrapperConfig, httpsCredentialConfig *cert.Config) error {
|
||||
if convertOptions == nil {
|
||||
return fmt.Errorf("convertOptions is nil")
|
||||
}
|
||||
@@ -377,7 +374,6 @@ func (c *controller) ConvertGateway(convertOptions *common.ConvertOptions, wrapp
|
||||
common.IncrementInvalidIngress(c.options.ClusterId, common.EmptyRule)
|
||||
return fmt.Errorf("invalid ingress rule %s:%s in cluster %s, either `defaultBackend` or `rules` must be specified", cfg.Namespace, cfg.Name, c.options.ClusterId)
|
||||
}
|
||||
httpsCredentialConfig, _ := c.configMgr.GetConfigFromConfigmap()
|
||||
for _, rule := range ingressV1Beta.Rules {
|
||||
// Need create builder for every rule.
|
||||
domainBuilder := &common.IngressDomainBuilder{
|
||||
@@ -429,10 +425,23 @@ func (c *controller) ConvertGateway(convertOptions *common.ConvertOptions, wrapp
|
||||
// Get tls secret matching the rule host
|
||||
secretName := extractTLSSecretName(rule.Host, ingressV1Beta.TLS)
|
||||
secretNamespace := cfg.Namespace
|
||||
// If there is no matching secret, try to get it from configmap.
|
||||
if secretName == "" && httpsCredentialConfig != nil {
|
||||
secretName = httpsCredentialConfig.MatchSecretNameByDomain(rule.Host)
|
||||
secretNamespace = c.options.SystemNamespace
|
||||
if secretName != "" {
|
||||
if httpsCredentialConfig != nil && httpsCredentialConfig.FallbackForInvalidSecret {
|
||||
_, err := c.secretController.Lister().Secrets(secretNamespace).Get(secretName)
|
||||
if err != nil {
|
||||
if k8serrors.IsNotFound(err) {
|
||||
// If there is no matching secret, try to get it from configmap.
|
||||
secretName = httpsCredentialConfig.MatchSecretNameByDomain(rule.Host)
|
||||
secretNamespace = c.options.SystemNamespace
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If there is no matching secret, try to get it from configmap.
|
||||
if httpsCredentialConfig != nil {
|
||||
secretName = httpsCredentialConfig.MatchSecretNameByDomain(rule.Host)
|
||||
secretNamespace = c.options.SystemNamespace
|
||||
}
|
||||
}
|
||||
if secretName == "" {
|
||||
// There no matching secret, so just skip.
|
||||
|
||||
@@ -334,7 +334,7 @@ func testConvertGateway(t *testing.T, c common.IngressController) {
|
||||
}
|
||||
|
||||
for _, testcase := range testcases {
|
||||
err := c.ConvertGateway(testcase.input.options, testcase.input.wrapperConfig)
|
||||
err := c.ConvertGateway(testcase.input.options, testcase.input.wrapperConfig, nil)
|
||||
if err != nil {
|
||||
require.Equal(t, testcase.expectNoError, false)
|
||||
} else {
|
||||
|
||||
@@ -54,6 +54,7 @@ import (
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/secret"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/util"
|
||||
. "github.com/alibaba/higress/pkg/ingress/log"
|
||||
k8serrors "k8s.io/apimachinery/pkg/api/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -85,8 +86,6 @@ type controller struct {
|
||||
secretController secret.SecretController
|
||||
|
||||
statusSyncer *statusSyncer
|
||||
|
||||
configMgr *cert.ConfigMgr
|
||||
}
|
||||
|
||||
// NewController creates a new Kubernetes controller
|
||||
@@ -99,7 +98,6 @@ func NewController(localKubeClient, client kubeclient.Client, options common.Opt
|
||||
classes := client.KubeInformer().Networking().V1().IngressClasses()
|
||||
classes.Informer()
|
||||
|
||||
configMgr, _ := cert.NewConfigMgr(options.SystemNamespace, client.Kube())
|
||||
c := &controller{
|
||||
options: options,
|
||||
queue: q,
|
||||
@@ -110,7 +108,6 @@ func NewController(localKubeClient, client kubeclient.Client, options common.Opt
|
||||
serviceInformer: serviceInformer.Informer(),
|
||||
serviceLister: serviceInformer.Lister(),
|
||||
secretController: secretController,
|
||||
configMgr: configMgr,
|
||||
}
|
||||
|
||||
handler := controllers.LatestVersionHandlerFuncs(controllers.EnqueueForSelf(q))
|
||||
@@ -346,7 +343,7 @@ func extractTLSSecretName(host string, tls []ingress.IngressTLS) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *controller) ConvertGateway(convertOptions *common.ConvertOptions, wrapper *common.WrapperConfig) error {
|
||||
func (c *controller) ConvertGateway(convertOptions *common.ConvertOptions, wrapper *common.WrapperConfig, httpsCredentialConfig *cert.Config) error {
|
||||
// Ignore canary config.
|
||||
if wrapper.AnnotationsConfig.IsCanary() {
|
||||
return nil
|
||||
@@ -363,7 +360,6 @@ func (c *controller) ConvertGateway(convertOptions *common.ConvertOptions, wrapp
|
||||
return fmt.Errorf("invalid ingress rule %s:%s in cluster %s, either `defaultBackend` or `rules` must be specified", cfg.Namespace, cfg.Name, c.options.ClusterId)
|
||||
}
|
||||
|
||||
httpsCredentialConfig, _ := c.configMgr.GetConfigFromConfigmap()
|
||||
for _, rule := range ingressV1.Rules {
|
||||
// Need create builder for every rule.
|
||||
domainBuilder := &common.IngressDomainBuilder{
|
||||
@@ -415,11 +411,25 @@ func (c *controller) ConvertGateway(convertOptions *common.ConvertOptions, wrapp
|
||||
// Get tls secret matching the rule host
|
||||
secretName := extractTLSSecretName(rule.Host, ingressV1.TLS)
|
||||
secretNamespace := cfg.Namespace
|
||||
// If there is no matching secret, try to get it from configmap.
|
||||
if secretName == "" && httpsCredentialConfig != nil {
|
||||
secretName = httpsCredentialConfig.MatchSecretNameByDomain(rule.Host)
|
||||
secretNamespace = c.options.SystemNamespace
|
||||
if secretName != "" {
|
||||
if httpsCredentialConfig != nil && httpsCredentialConfig.FallbackForInvalidSecret {
|
||||
_, err := c.secretController.Lister().Secrets(secretNamespace).Get(secretName)
|
||||
if err != nil {
|
||||
if k8serrors.IsNotFound(err) {
|
||||
// If there is no matching secret, try to get it from configmap.
|
||||
secretName = httpsCredentialConfig.MatchSecretNameByDomain(rule.Host)
|
||||
secretNamespace = c.options.SystemNamespace
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If there is no matching secret, try to get it from configmap.
|
||||
if httpsCredentialConfig != nil {
|
||||
secretName = httpsCredentialConfig.MatchSecretNameByDomain(rule.Host)
|
||||
secretNamespace = c.options.SystemNamespace
|
||||
}
|
||||
}
|
||||
|
||||
if secretName == "" {
|
||||
// There no matching secret, so just skip.
|
||||
continue
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
PLUGIN_NAME ?= hello-world
|
||||
BUILDER_REGISTRY ?= higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/
|
||||
REGISTRY ?=
|
||||
REGISTRY ?= higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/
|
||||
GO_VERSION ?= 1.19
|
||||
TINYGO_VERSION ?= 0.28.1
|
||||
ORAS_VERSION ?= 1.0.0
|
||||
@@ -12,12 +12,14 @@ COMMIT_ID := $(shell git rev-parse --short HEAD 2>/dev/null)
|
||||
IMAGE_TAG = $(if $(strip $(PLUGIN_VERSION)),${PLUGIN_VERSION},${BUILD_TIME}-${COMMIT_ID})
|
||||
IMG ?= ${REGISTRY}${PLUGIN_NAME}:${IMAGE_TAG}
|
||||
GOPROXY := $(shell go env GOPROXY)
|
||||
EXTRA_TAGS ?=
|
||||
|
||||
.DEFAULT:
|
||||
build:
|
||||
DOCKER_BUILDKIT=1 docker build --build-arg PLUGIN_NAME=${PLUGIN_NAME} \
|
||||
--build-arg BUILDER=${BUILDER} \
|
||||
--build-arg GOPROXY=$(GOPROXY) \
|
||||
--build-arg EXTRA_TAGS=$(EXTRA_TAGS) \
|
||||
-t ${IMG} \
|
||||
--output extensions/${PLUGIN_NAME} \
|
||||
.
|
||||
@@ -28,6 +30,7 @@ build-image:
|
||||
DOCKER_BUILDKIT=1 docker build --build-arg PLUGIN_NAME=${PLUGIN_NAME} \
|
||||
--build-arg BUILDER=${BUILDER} \
|
||||
--build-arg GOPROXY=$(GOPROXY) \
|
||||
--build-arg EXTRA_TAGS=$(EXTRA_TAGS) \
|
||||
-t ${IMG} \
|
||||
.
|
||||
@echo ""
|
||||
|
||||
19
plugins/wasm-go/extensions/ai-cache/.gitignore
vendored
Normal file
19
plugins/wasm-go/extensions/ai-cache/.gitignore
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
# File generated by hgctl. Modify as required.
|
||||
|
||||
*
|
||||
|
||||
!/.gitignore
|
||||
|
||||
!*.go
|
||||
!go.sum
|
||||
!go.mod
|
||||
|
||||
!LICENSE
|
||||
!*.md
|
||||
!*.yaml
|
||||
!*.yml
|
||||
|
||||
!*/
|
||||
|
||||
/out
|
||||
/test
|
||||
34
plugins/wasm-go/extensions/ai-cache/README.md
Normal file
34
plugins/wasm-go/extensions/ai-cache/README.md
Normal file
@@ -0,0 +1,34 @@
|
||||
## 简介
|
||||
|
||||
**Note**
|
||||
|
||||
> 需要数据面的proxy wasm版本大于等于0.2.100
|
||||
|
||||
> 编译时,需要带上版本的tag,例如:`tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags="custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100" ./`
|
||||
|
||||
LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的结果缓存,同时支持流式和非流式响应的缓存。
|
||||
|
||||
## 配置说明
|
||||
|
||||
| Name | Type | Requirement | Default | Description |
|
||||
| -------- | -------- | -------- | -------- | -------- |
|
||||
| cacheKeyFrom.requestBody | string | optional | "messages.@reverse.0.content" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 |
|
||||
| cacheValueFrom.responseBody | string | optional | "choices.0.message.content" | 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 |
|
||||
| cacheStreamValueFrom.responseBody | string | optional | "choices.0.delta.content" | 从流式响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 |
|
||||
| cacheKeyPrefix | string | optional | "higress-ai-cache:" | Redis缓存Key的前缀 |
|
||||
| cacheTTL | integer | optional | 0 | 缓存的过期时间,单位是秒,默认值为0,即永不过期 |
|
||||
| redis.serviceName | string | requried | - | redis 服务名称,带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local |
|
||||
| redis.servicePort | integer | optional | 6379 | redis 服务端口 |
|
||||
| redis.timeout | integer | optional | 1000 | 请求 redis 的超时时间,单位为毫秒 |
|
||||
| redis.username | string | optional | - | 登陆 redis 的用户名 |
|
||||
| redis.password | string | optional | - | 登陆 redis 的密码 |
|
||||
| returnResponseTemplate | string | optional | `{"id":"from-cache","choices":[%s],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | 返回 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 |
|
||||
| returnStreamResponseTemplate | string | optional | `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | 返回流式 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 |
|
||||
|
||||
## 配置示例
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
serviceName: my-redis.dns
|
||||
timeout: 2000
|
||||
```
|
||||
23
plugins/wasm-go/extensions/ai-cache/go.mod
Normal file
23
plugins/wasm-go/extensions/ai-cache/go.mod
Normal file
@@ -0,0 +1,23 @@
|
||||
// File generated by hgctl. Modify as required.
|
||||
|
||||
module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache
|
||||
|
||||
go 1.19
|
||||
|
||||
replace github.com/alibaba/higress/plugins/wasm-go => ../..
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240528060522-53bccf89f441
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc
|
||||
github.com/tidwall/gjson v1.14.3
|
||||
github.com/tidwall/resp v0.1.1
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
|
||||
github.com/magefile/mage v1.14.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
)
|
||||
23
plugins/wasm-go/extensions/ai-cache/go.sum
Normal file
23
plugins/wasm-go/extensions/ai-cache/go.sum
Normal file
@@ -0,0 +1,23 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
371
plugins/wasm-go/extensions/ai-cache/main.go
Normal file
371
plugins/wasm-go/extensions/ai-cache/main.go
Normal file
@@ -0,0 +1,371 @@
|
||||
// File generated by hgctl. Modify as required.
|
||||
// See: https://higress.io/zh-cn/docs/user/wasm-go#2-%E7%BC%96%E5%86%99-maingo-%E6%96%87%E4%BB%B6
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/resp"
|
||||
)
|
||||
|
||||
const (
|
||||
CacheKeyContextKey = "cacheKey"
|
||||
CacheContentContextKey = "cacheContent"
|
||||
PartialMessageContextKey = "partialMessage"
|
||||
ToolCallsContextKey = "toolCalls"
|
||||
StreamContextKey = "stream"
|
||||
DefaultCacheKeyPrefix = "higress-ai-cache:"
|
||||
)
|
||||
|
||||
func main() {
|
||||
wrapper.SetCtx(
|
||||
"ai-cache",
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
||||
wrapper.ProcessStreamingResponseBodyBy(onHttpResponseBody),
|
||||
)
|
||||
}
|
||||
|
||||
// @Name ai-cache
|
||||
// @Category protocol
|
||||
// @Phase AUTHN
|
||||
// @Priority 10
|
||||
// @Title zh-CN AI Cache
|
||||
// @Description zh-CN 大模型结果缓存
|
||||
// @IconUrl
|
||||
// @Version 0.1.0
|
||||
//
|
||||
// @Contact.name johnlanni
|
||||
// @Contact.url
|
||||
// @Contact.email
|
||||
//
|
||||
// @Example
|
||||
// redis:
|
||||
// serviceName: my-redis.dns
|
||||
// timeout: 2000
|
||||
// cacheKeyFrom:
|
||||
// requestBody: "messages.@reverse.0.content"
|
||||
// cacheValueFrom:
|
||||
// responseBody: "choices.0.message.content"
|
||||
// cacheStreamValueFrom:
|
||||
// responseBody: "choices.0.delta.content"
|
||||
// returnResponseTemplate: |
|
||||
// {"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}
|
||||
// returnStreamResponseTemplate: |
|
||||
// data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}
|
||||
//
|
||||
// data:[DONE]
|
||||
//
|
||||
// @End
|
||||
|
||||
type RedisInfo struct {
|
||||
// @Title zh-CN redis 服务名称
|
||||
// @Description zh-CN 带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local
|
||||
ServiceName string `required:"true" yaml:"serviceName" json:"serviceName"`
|
||||
// @Title zh-CN redis 服务端口
|
||||
// @Description zh-CN 默认值为6379
|
||||
ServicePort int `required:"false" yaml:"servicePort" json:"servicePort"`
|
||||
// @Title zh-CN 用户名
|
||||
// @Description zh-CN 登陆 redis 的用户名,非必填
|
||||
Username string `required:"false" yaml:"username" json:"username"`
|
||||
// @Title zh-CN 密码
|
||||
// @Description zh-CN 登陆 redis 的密码,非必填,可以只填密码
|
||||
Password string `required:"false" yaml:"password" json:"password"`
|
||||
// @Title zh-CN 请求超时
|
||||
// @Description zh-CN 请求 redis 的超时时间,单位为毫秒。默认值是1000,即1秒
|
||||
Timeout int `required:"false" yaml:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
type KVExtractor struct {
|
||||
// @Title zh-CN 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串
|
||||
RequestBody string `required:"false" yaml:"requestBody" json:"requestBody"`
|
||||
// @Title zh-CN 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串
|
||||
ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"`
|
||||
}
|
||||
|
||||
type PluginConfig struct {
|
||||
// @Title zh-CN Redis 地址信息
|
||||
// @Description zh-CN 用于存储缓存结果的 Redis 地址
|
||||
RedisInfo RedisInfo `required:"true" yaml:"redis" json:"redis"`
|
||||
// @Title zh-CN 缓存 key 的来源
|
||||
// @Description zh-CN 往 redis 里存时,使用的 key 的提取方式
|
||||
CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"`
|
||||
// @Title zh-CN 缓存 value 的来源
|
||||
// @Description zh-CN 往 redis 里存时,使用的 value 的提取方式
|
||||
CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"`
|
||||
// @Title zh-CN 流式响应下,缓存 value 的来源
|
||||
// @Description zh-CN 往 redis 里存时,使用的 value 的提取方式
|
||||
CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"`
|
||||
// @Title zh-CN 返回 HTTP 响应的模版
|
||||
// @Description zh-CN 用 %s 标记需要被 cache value 替换的部分
|
||||
ReturnResponseTemplate string `required:"true" yaml:"returnResponseTemplate" json:"returnResponseTemplate"`
|
||||
// @Title zh-CN 返回流式 HTTP 响应的模版
|
||||
// @Description zh-CN 用 %s 标记需要被 cache value 替换的部分
|
||||
ReturnStreamResponseTemplate string `required:"true" yaml:"returnStreamResponseTemplate" json:"returnStreamResponseTemplate"`
|
||||
// @Title zh-CN 缓存的过期时间
|
||||
// @Description zh-CN 单位是秒,默认值为0,即永不过期
|
||||
CacheTTL int `required:"false" yaml:"cacheTTL" json:"cacheTTL"`
|
||||
// @Title zh-CN Redis缓存Key的前缀
|
||||
// @Description zh-CN 默认值是"higress-ai-cache:"
|
||||
CacheKeyPrefix string `required:"false" yaml:"cacheKeyPrefix" json:"cacheKeyPrefix"`
|
||||
redisClient wrapper.RedisClient `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error {
|
||||
c.RedisInfo.ServiceName = json.Get("redis.serviceName").String()
|
||||
if c.RedisInfo.ServiceName == "" {
|
||||
return errors.New("redis service name must not by empty")
|
||||
}
|
||||
c.RedisInfo.ServicePort = int(json.Get("redis.servicePort").Int())
|
||||
if c.RedisInfo.ServicePort == 0 {
|
||||
if strings.HasSuffix(c.RedisInfo.ServiceName, ".static") {
|
||||
// use default logic port which is 80 for static service
|
||||
c.RedisInfo.ServicePort = 80
|
||||
} else {
|
||||
c.RedisInfo.ServicePort = 6379
|
||||
}
|
||||
}
|
||||
c.RedisInfo.Username = json.Get("redis.username").String()
|
||||
c.RedisInfo.Password = json.Get("redis.password").String()
|
||||
c.RedisInfo.Timeout = int(json.Get("redis.timeout").Int())
|
||||
if c.RedisInfo.Timeout == 0 {
|
||||
c.RedisInfo.Timeout = 1000
|
||||
}
|
||||
c.CacheKeyFrom.RequestBody = json.Get("cacheKeyFrom.requestBody").String()
|
||||
if c.CacheKeyFrom.RequestBody == "" {
|
||||
c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content"
|
||||
}
|
||||
c.CacheValueFrom.ResponseBody = json.Get("cacheValueFrom.responseBody").String()
|
||||
if c.CacheValueFrom.ResponseBody == "" {
|
||||
c.CacheValueFrom.ResponseBody = "choices.0.message.content"
|
||||
}
|
||||
c.CacheStreamValueFrom.ResponseBody = json.Get("cacheStreamValueFrom.responseBody").String()
|
||||
if c.CacheStreamValueFrom.ResponseBody == "" {
|
||||
c.CacheStreamValueFrom.ResponseBody = "choices.0.delta.content"
|
||||
}
|
||||
c.ReturnResponseTemplate = json.Get("returnResponseTemplate").String()
|
||||
if c.ReturnResponseTemplate == "" {
|
||||
c.ReturnResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
|
||||
}
|
||||
c.ReturnStreamResponseTemplate = json.Get("returnStreamResponseTemplate").String()
|
||||
if c.ReturnStreamResponseTemplate == "" {
|
||||
c.ReturnStreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n"
|
||||
}
|
||||
c.CacheKeyPrefix = json.Get("cacheKeyPrefix").String()
|
||||
if c.CacheKeyPrefix == "" {
|
||||
c.CacheKeyPrefix = DefaultCacheKeyPrefix
|
||||
}
|
||||
c.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: c.RedisInfo.ServiceName,
|
||||
Port: int64(c.RedisInfo.ServicePort),
|
||||
})
|
||||
return c.redisClient.Init(c.RedisInfo.Username, c.RedisInfo.Password, int64(c.RedisInfo.Timeout))
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||||
contentType, _ := proxywasm.GetHttpRequestHeader("content-type")
|
||||
// The request does not have a body.
|
||||
if contentType == "" {
|
||||
return types.ActionContinue
|
||||
}
|
||||
if !strings.Contains(contentType, "application/json") {
|
||||
log.Warnf("content is not json, can't process:%s", contentType)
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
// The request has a body and requires delaying the header transmission until a cache miss occurs,
|
||||
// at which point the header should be sent.
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func TrimQuote(source string) string {
|
||||
return strings.Trim(source, `"`)
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||||
bodyJson := gjson.ParseBytes(body)
|
||||
// TODO: It may be necessary to support stream mode determination for different LLM providers.
|
||||
stream := false
|
||||
if bodyJson.Get("stream").Bool() {
|
||||
stream = true
|
||||
ctx.SetContext(StreamContextKey, struct{}{})
|
||||
} else if ctx.GetContext(StreamContextKey) != nil {
|
||||
stream = true
|
||||
}
|
||||
key := TrimQuote(bodyJson.Get(config.CacheKeyFrom.RequestBody).Raw)
|
||||
if key == "" {
|
||||
log.Debug("parse key from request body failed")
|
||||
return types.ActionContinue
|
||||
}
|
||||
ctx.SetContext(CacheKeyContextKey, key)
|
||||
err := config.redisClient.Get(config.CacheKeyPrefix+key, func(response resp.Value) {
|
||||
if err := response.Error(); err != nil {
|
||||
log.Errorf("redis get key:%s failed, err:%v", key, err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
if response.IsNull() {
|
||||
log.Debugf("cache miss, key:%s", key)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
log.Debugf("cache hit, key:%s", key)
|
||||
ctx.SetContext(CacheKeyContextKey, nil)
|
||||
if !stream {
|
||||
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, response.String())), -1)
|
||||
} else {
|
||||
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnStreamResponseTemplate, response.String())), -1)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
log.Error("redis access failed")
|
||||
return types.ActionContinue
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string {
|
||||
subMessages := strings.Split(sseMessage, "\n")
|
||||
var message string
|
||||
for _, msg := range subMessages {
|
||||
if strings.HasPrefix(msg, "data:") {
|
||||
message = msg
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(message) < 6 {
|
||||
log.Errorf("invalid message:%s", message)
|
||||
return ""
|
||||
}
|
||||
// skip the prefix "data:"
|
||||
bodyJson := message[5:]
|
||||
if gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Exists() {
|
||||
tempContentI := ctx.GetContext(CacheContentContextKey)
|
||||
if tempContentI == nil {
|
||||
content := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw)
|
||||
ctx.SetContext(CacheContentContextKey, content)
|
||||
return content
|
||||
}
|
||||
append := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw)
|
||||
content := tempContentI.(string) + append
|
||||
ctx.SetContext(CacheContentContextKey, content)
|
||||
return content
|
||||
} else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() {
|
||||
// TODO: compatible with other providers
|
||||
ctx.SetContext(ToolCallsContextKey, struct{}{})
|
||||
return ""
|
||||
}
|
||||
log.Debugf("unknown message:%s", bodyJson)
|
||||
return ""
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
ctx.SetContext(StreamContextKey, struct{}{})
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
|
||||
if ctx.GetContext(ToolCallsContextKey) != nil {
|
||||
// we should not cache tool call result
|
||||
return chunk
|
||||
}
|
||||
keyI := ctx.GetContext(CacheKeyContextKey)
|
||||
if keyI == nil {
|
||||
return chunk
|
||||
}
|
||||
if !isLastChunk {
|
||||
stream := ctx.GetContext(StreamContextKey)
|
||||
if stream == nil {
|
||||
tempContentI := ctx.GetContext(CacheContentContextKey)
|
||||
if tempContentI == nil {
|
||||
ctx.SetContext(CacheContentContextKey, chunk)
|
||||
return chunk
|
||||
}
|
||||
tempContent := tempContentI.([]byte)
|
||||
tempContent = append(tempContent, chunk...)
|
||||
ctx.SetContext(CacheContentContextKey, tempContent)
|
||||
} else {
|
||||
var partialMessage []byte
|
||||
partialMessageI := ctx.GetContext(PartialMessageContextKey)
|
||||
if partialMessageI != nil {
|
||||
partialMessage = append(partialMessageI.([]byte), chunk...)
|
||||
} else {
|
||||
partialMessage = chunk
|
||||
}
|
||||
messages := strings.Split(string(partialMessage), "\n\n")
|
||||
for i, msg := range messages {
|
||||
if i < len(messages)-1 {
|
||||
// process complete message
|
||||
processSSEMessage(ctx, config, msg, log)
|
||||
}
|
||||
}
|
||||
if !strings.HasSuffix(string(partialMessage), "\n\n") {
|
||||
ctx.SetContext(PartialMessageContextKey, []byte(messages[len(messages)-1]))
|
||||
} else {
|
||||
ctx.SetContext(PartialMessageContextKey, nil)
|
||||
}
|
||||
}
|
||||
return chunk
|
||||
}
|
||||
// last chunk
|
||||
key := keyI.(string)
|
||||
stream := ctx.GetContext(StreamContextKey)
|
||||
var value string
|
||||
if stream == nil {
|
||||
var body []byte
|
||||
tempContentI := ctx.GetContext(CacheContentContextKey)
|
||||
if tempContentI != nil {
|
||||
body = append(tempContentI.([]byte), chunk...)
|
||||
} else {
|
||||
body = chunk
|
||||
}
|
||||
bodyJson := gjson.ParseBytes(body)
|
||||
|
||||
value = TrimQuote(bodyJson.Get(config.CacheValueFrom.ResponseBody).Raw)
|
||||
if value == "" {
|
||||
log.Warnf("parse value from response body failded, body:%s", body)
|
||||
return chunk
|
||||
}
|
||||
} else {
|
||||
if len(chunk) > 0 {
|
||||
var lastMessage []byte
|
||||
partialMessageI := ctx.GetContext(PartialMessageContextKey)
|
||||
if partialMessageI != nil {
|
||||
lastMessage = append(partialMessageI.([]byte), chunk...)
|
||||
} else {
|
||||
lastMessage = chunk
|
||||
}
|
||||
if !strings.HasSuffix(string(lastMessage), "\n\n") {
|
||||
log.Warnf("invalid lastMessage:%s", lastMessage)
|
||||
return chunk
|
||||
}
|
||||
// remove the last \n\n
|
||||
lastMessage = lastMessage[:len(lastMessage)-2]
|
||||
value = processSSEMessage(ctx, config, string(lastMessage), log)
|
||||
} else {
|
||||
tempContentI := ctx.GetContext(CacheContentContextKey)
|
||||
if tempContentI == nil {
|
||||
return chunk
|
||||
}
|
||||
value = tempContentI.(string)
|
||||
}
|
||||
}
|
||||
config.redisClient.Set(config.CacheKeyPrefix+key, value, nil)
|
||||
if config.CacheTTL != 0 {
|
||||
config.redisClient.Expire(config.CacheKeyPrefix+key, config.CacheTTL, nil)
|
||||
}
|
||||
return chunk
|
||||
}
|
||||
52
plugins/wasm-go/extensions/ai-cache/option.yaml
Normal file
52
plugins/wasm-go/extensions/ai-cache/option.yaml
Normal file
@@ -0,0 +1,52 @@
|
||||
# File generated by hgctl. Modify as required.
|
||||
|
||||
version: 1.0.0
|
||||
|
||||
build:
|
||||
# The official builder image version
|
||||
builder:
|
||||
go: 1.19
|
||||
tinygo: 0.28.1
|
||||
oras: 1.0.0
|
||||
# The WASM plugin project directory
|
||||
input: ./
|
||||
# The output of the build products
|
||||
output:
|
||||
# Choose between 'files' and 'image'
|
||||
type: files
|
||||
# Destination address: when type=files, specify the local directory path, e.g., './out' or
|
||||
# type=image, specify the remote docker repository, e.g., 'docker.io/<your_username>/<your_image>'
|
||||
dest: ./out
|
||||
# The authentication configuration for pushing image to the docker repository
|
||||
docker-auth: ~/.docker/config.json
|
||||
# The directory for the WASM plugin configuration structure
|
||||
model-dir: ./
|
||||
# The WASM plugin configuration structure name
|
||||
model: PluginConfig
|
||||
# Enable debug mode
|
||||
debug: false
|
||||
|
||||
test:
|
||||
# Test environment name, that is a docker compose project name
|
||||
name: wasm-test
|
||||
# The output path to build products, that is the source of test configuration parameters
|
||||
from-path: ./out
|
||||
# The test configuration source
|
||||
test-path: ./test
|
||||
# Docker compose configuration, which is empty, looks for the following files from 'test-path':
|
||||
# compose.yaml, compose.yml, docker-compose.yml, docker-compose.yaml
|
||||
compose-file:
|
||||
# Detached mode: Run containers in the background
|
||||
detach: false
|
||||
|
||||
install:
|
||||
# The namespace of the installation
|
||||
namespace: higress-system
|
||||
# Use to validate WASM plugin configuration when install by yaml
|
||||
spec-yaml: ./out/spec.yaml
|
||||
# Installation source. Choose between 'from-yaml' and 'from-go-project'
|
||||
from-yaml: ./test/plugin-conf.yaml
|
||||
# If 'from-go-src' is non-empty, the output type of the build option must be 'image'
|
||||
from-go-src:
|
||||
# Enable debug mode
|
||||
debug: false
|
||||
3
plugins/wasm-go/extensions/ai-prompt-decorator/.gitignore
vendored
Normal file
3
plugins/wasm-go/extensions/ai-prompt-decorator/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
config.yaml
|
||||
main.wasm
|
||||
tmp/
|
||||
82
plugins/wasm-go/extensions/ai-prompt-decorator/README.md
Normal file
82
plugins/wasm-go/extensions/ai-prompt-decorator/README.md
Normal file
@@ -0,0 +1,82 @@
|
||||
# 简介
|
||||
AI提示词修饰插件,通过在与大模型发起的请求前后插入指定信息来调整大模型的输出。
|
||||
|
||||
# 配置说明
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|----------------|-----------------|------|-----|----------------------------------|
|
||||
| `decorators` | array of object | 必填 | - | 修饰设置 |
|
||||
|
||||
template object 配置说明:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|----------------|-----------------|------|-----|----------------------------------|
|
||||
| `name` | string | 必填 | - | 修饰名称 |
|
||||
| `decorator.prepend` | array of message object | 必填 | - | 在初始输入之前插入的语句 |
|
||||
| `decorator.append` | array of message object | 必填 | - | 在初始输入之后插入的语句 |
|
||||
|
||||
message object 配置说明:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|----------------|-----------------|------|-----|----------------------------------|
|
||||
| `role` | string | 必填 | - | 角色 |
|
||||
| `content` | string | 必填 | - | 消息 |
|
||||
|
||||
# 示例
|
||||
|
||||
配置示例如下:
|
||||
|
||||
```yaml
|
||||
decorators:
|
||||
- name: "hangzhou-guide"
|
||||
decorator:
|
||||
prepend:
|
||||
- role: system
|
||||
content: "You will always respond in the Chinese language."
|
||||
- role: user
|
||||
content: "Assume you are from Hangzhou."
|
||||
append:
|
||||
- role: user
|
||||
content: "Don't introduce Hangzhou's food."
|
||||
```
|
||||
|
||||
使用以上配置发起请求:
|
||||
|
||||
```bash
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Please introduce your home."
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
响应如下:
|
||||
|
||||
```
|
||||
{
|
||||
"id": "chatcmpl-9UYwQlEg6GwAswEZBDYXl41RU4gab",
|
||||
"object": "chat.completion",
|
||||
"created": 1717071182,
|
||||
"model": "gpt-3.5-turbo-0125",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "杭州是一个美丽的城市,有着悠久的历史和富有特色的文化。这里风景优美,有西湖、雷峰塔等著名景点,吸引着许多游客前来观光。杭州人民热情好客,城市宁静安逸,是一个适合居住和旅游的地方。"
|
||||
},
|
||||
"logprobs": null,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 49,
|
||||
"completion_tokens": 117,
|
||||
"total_tokens": 166
|
||||
},
|
||||
"system_fingerprint": null
|
||||
}
|
||||
```
|
||||
19
plugins/wasm-go/extensions/ai-prompt-decorator/go.mod
Normal file
19
plugins/wasm-go/extensions/ai-prompt-decorator/go.mod
Normal file
@@ -0,0 +1,19 @@
|
||||
module ai-prompt-decorator
|
||||
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a
|
||||
github.com/tidwall/gjson v1.14.3
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
|
||||
github.com/magefile/mage v1.14.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
)
|
||||
25
plugins/wasm-go/extensions/ai-prompt-decorator/go.sum
Normal file
25
plugins/wasm-go/extensions/ai-prompt-decorator/go.sum
Normal file
@@ -0,0 +1,25 @@
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
94
plugins/wasm-go/extensions/ai-prompt-decorator/main.go
Normal file
94
plugins/wasm-go/extensions/ai-prompt-decorator/main.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
func main() {
|
||||
wrapper.SetCtx(
|
||||
"ai-prompt-decorator",
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||
)
|
||||
}
|
||||
|
||||
type AIPromptDecoratorConfig struct {
|
||||
decorators map[string]string
|
||||
}
|
||||
|
||||
func removeBrackets(raw string) (string, error) {
|
||||
startIndex := strings.Index(raw, "{")
|
||||
endIndex := strings.LastIndex(raw, "}")
|
||||
if startIndex == -1 || endIndex == -1 {
|
||||
return raw, errors.New("message format is wrong!")
|
||||
} else {
|
||||
return raw[startIndex : endIndex+1], nil
|
||||
}
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *AIPromptDecoratorConfig, log wrapper.Log) error {
|
||||
config.decorators = make(map[string]string)
|
||||
for _, v := range json.Get("decorators").Array() {
|
||||
config.decorators[v.Get("name").String()] = v.Get("decorator").Raw
|
||||
// log.Info(v.Get("decorator").Raw)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptDecoratorConfig, log wrapper.Log) types.Action {
|
||||
decorator, _ := proxywasm.GetHttpRequestHeader("decorator")
|
||||
if decorator == "" {
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
ctx.SetContext("decorator", decorator)
|
||||
proxywasm.RemoveHttpRequestHeader("decorator")
|
||||
proxywasm.RemoveHttpRequestHeader("content-length")
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AIPromptDecoratorConfig, body []byte, log wrapper.Log) types.Action {
|
||||
decoratorName := ctx.GetContext("decorator").(string)
|
||||
decorator := config.decorators[decoratorName]
|
||||
|
||||
messageJson := `{"messages":[]}`
|
||||
|
||||
prependMessage := gjson.Get(decorator, "prepend")
|
||||
if prependMessage.Exists() {
|
||||
for _, entry := range prependMessage.Array() {
|
||||
messageJson, _ = sjson.SetRaw(messageJson, "messages.-1", entry.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
rawMessage := gjson.GetBytes(body, "messages")
|
||||
if rawMessage.Exists() {
|
||||
for _, entry := range rawMessage.Array() {
|
||||
messageJson, _ = sjson.SetRaw(messageJson, "messages.-1", entry.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
appendMessage := gjson.Get(decorator, "append")
|
||||
if appendMessage.Exists() {
|
||||
for _, entry := range appendMessage.Array() {
|
||||
messageJson, _ = sjson.SetRaw(messageJson, "messages.-1", entry.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
newbody, err := sjson.SetRaw(string(body), "messages", gjson.Get(messageJson, "messages").Raw)
|
||||
if err != nil {
|
||||
log.Error("modify body failed")
|
||||
}
|
||||
if err = proxywasm.ReplaceHttpRequestBody([]byte(newbody)); err != nil {
|
||||
log.Error("rewrite body failed")
|
||||
}
|
||||
|
||||
return types.ActionContinue
|
||||
}
|
||||
3
plugins/wasm-go/extensions/ai-prompt-template/.gitignore
vendored
Normal file
3
plugins/wasm-go/extensions/ai-prompt-template/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
config.yaml
|
||||
main.wasm
|
||||
tmp/
|
||||
48
plugins/wasm-go/extensions/ai-prompt-template/README.md
Normal file
48
plugins/wasm-go/extensions/ai-prompt-template/README.md
Normal file
@@ -0,0 +1,48 @@
|
||||
# 简介
|
||||
AI提示词模板,用于快速构建同类型的AI请求。
|
||||
|
||||
# 配置说明
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|----------------|-----------------|------|-----|----------------------------------|
|
||||
| `templates` | array of object | 必填 | - | 模板设置 |
|
||||
|
||||
template object 配置说明:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|----------------|-----------------|------|-----|----------------------------------|
|
||||
| `name` | string | 必填 | - | 模板名称 |
|
||||
| `template.model` | string | 必填 | - | 模型名称 |
|
||||
| `template.messages` | array of object | 必填 | - | 大模型输入 |
|
||||
|
||||
message object 配置说明:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|----------------|-----------------|------|-----|----------------------------------|
|
||||
| `role` | string | 必填 | - | 角色 |
|
||||
| `content` | string | 必填 | - | 消息 |
|
||||
|
||||
配置示例如下:
|
||||
|
||||
```yaml
|
||||
templates:
|
||||
- name: "developer-chat"
|
||||
template:
|
||||
model: gpt-3.5-turbo
|
||||
messages:
|
||||
- role: system
|
||||
content: "You are a {{program}} expert, in {{language}} programming language."
|
||||
- role: user
|
||||
content: "Write me a {{program}} program."
|
||||
```
|
||||
|
||||
使用以上配置的请求body示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"template": "developer-chat",
|
||||
"properties": {
|
||||
"program": "quick sort",
|
||||
"language": "python"
|
||||
}
|
||||
}
|
||||
```
|
||||
19
plugins/wasm-go/extensions/ai-prompt-template/go.mod
Normal file
19
plugins/wasm-go/extensions/ai-prompt-template/go.mod
Normal file
@@ -0,0 +1,19 @@
|
||||
module ai-prompt-template
|
||||
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a
|
||||
github.com/tidwall/gjson v1.14.3
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
|
||||
github.com/magefile/mage v1.14.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
)
|
||||
25
plugins/wasm-go/extensions/ai-prompt-template/go.sum
Normal file
25
plugins/wasm-go/extensions/ai-prompt-template/go.sum
Normal file
@@ -0,0 +1,25 @@
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
55
plugins/wasm-go/extensions/ai-prompt-template/main.go
Normal file
55
plugins/wasm-go/extensions/ai-prompt-template/main.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func main() {
|
||||
wrapper.SetCtx(
|
||||
"ai-prompt-template",
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||
)
|
||||
}
|
||||
|
||||
type AIPromptTemplateConfig struct {
|
||||
templates map[string]string
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *AIPromptTemplateConfig, log wrapper.Log) error {
|
||||
config.templates = make(map[string]string)
|
||||
for _, v := range json.Get("templates").Array() {
|
||||
config.templates[v.Get("name").String()] = v.Get("template").Raw
|
||||
log.Info(v.Get("template").Raw)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptTemplateConfig, log wrapper.Log) types.Action {
|
||||
templateEnable, _ := proxywasm.GetHttpRequestHeader("template-enable")
|
||||
if templateEnable != "true" {
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
proxywasm.RemoveHttpRequestHeader("content-length")
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AIPromptTemplateConfig, body []byte, log wrapper.Log) types.Action {
|
||||
if gjson.GetBytes(body, "template").Exists() && gjson.GetBytes(body, "properties").Exists() {
|
||||
name := gjson.GetBytes(body, "template").String()
|
||||
template := config.templates[name]
|
||||
for key, value := range gjson.GetBytes(body, "properties").Map() {
|
||||
template = strings.ReplaceAll(template, fmt.Sprintf("{{%s}}", key), value.String())
|
||||
}
|
||||
proxywasm.ReplaceHttpRequestBody([]byte(template))
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
@@ -19,14 +19,14 @@ description: AI 代理插件配置参考
|
||||
|
||||
`provider`的配置字段说明如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|----------------|-----------------|------|-----|----------------------------------------------------------------------------------|
|
||||
| `type` | string | 必填 | - | AI 服务提供商名称。目前支持以下取值:openai, azure, moonshot, qwen |
|
||||
| `apiTokens` | array of string | 必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
|
||||
| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 |
|
||||
| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>可以使用 "*" 为键来配置通用兜底映射关系 |
|
||||
| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
|
||||
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------ |
|
||||
| `type` | string | 必填 | - | AI 服务提供商名称 |
|
||||
| `apiTokens` | array of string | 必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
|
||||
| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 |
|
||||
| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>可以使用 "*" 为键来配置通用兜底映射关系 |
|
||||
| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
|
||||
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
|
||||
|
||||
`context`的配置字段说明如下:
|
||||
|
||||
@@ -77,6 +77,10 @@ Azure OpenAI 所对应的 `type` 为 `azure`。它特有的配置字段如下:
|
||||
|
||||
零一万物所对应的 `type` 为 `yi`。它并无特有的配置字段。
|
||||
|
||||
#### 智谱AI(Zhipu AI)
|
||||
|
||||
智谱AI所对应的 `type` 为 `zhipuai`。它并无特有的配置字段。
|
||||
|
||||
#### DeepSeek(DeepSeek)
|
||||
|
||||
DeepSeek所对应的 `type` 为 `deepseek`。它并无特有的配置字段。
|
||||
@@ -85,13 +89,47 @@ DeepSeek所对应的 `type` 为 `deepseek`。它并无特有的配置字段。
|
||||
|
||||
Groq 所对应的 `type` 为 `groq`。它并无特有的配置字段。
|
||||
|
||||
#### 文心一言(Baidu)
|
||||
|
||||
文心一言所对应的 `type` 为 `baidu`。它并无特有的配置字段。
|
||||
|
||||
#### MiniMax
|
||||
|
||||
MiniMax所对应的 `type` 为 `minimax`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ---------------- | -------- | ------------------------------------------------------------ | ------ | ------------------------------------------------------------ |
|
||||
| `minimaxGroupId` | string | 当使用`abab6.5-chat`, `abab6.5s-chat`, `abab5.5s-chat`, `abab5.5-chat`四种模型时必填 | - | 当使用`abab6.5-chat`, `abab6.5s-chat`, `abab5.5s-chat`, `abab5.5-chat`四种模型时会使用ChatCompletion Pro,需要设置groupID |
|
||||
|
||||
#### Anthropic Claude
|
||||
|
||||
Anthropic Claude 所对应的 `type` 为 `claude`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-----------|--------|-----|-----|-------------------|
|
||||
| `version` | string | 必填 | - | Claude 服务的 API 版本 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-----------|--------|------|-----|----------------------------------|
|
||||
| `claudeVersion` | string | 可选 | - | Claude 服务的 API 版本,默认为 2023-06-01 |
|
||||
|
||||
#### Ollama
|
||||
|
||||
Ollama 所对应的 `type` 为 `ollama`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-------------------|--------|------|-----|----------------------------------------------|
|
||||
| `ollamaServerHost` | string | 必填 | - | Ollama 服务器的主机地址 |
|
||||
| `ollamaServerPort` | number | 必填 | - | Ollama 服务器的端口号,默认为11434 |
|
||||
|
||||
#### 混元
|
||||
|
||||
混元所对应的 `type` 为 `hunyuan`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-------------------|--------|------|-----|----------------------------------------------|
|
||||
| `hunyuanAuthId` | string | 必填 | - | 混元用于v3版本认证的id |
|
||||
| `hunyuanAuthKey` | string | 必填 | - | 混元用于v3版本认证的key |
|
||||
|
||||
#### 阶跃星辰 (Stepfun)
|
||||
|
||||
阶跃星辰所对应的 `type` 为 `stepfun`。它并无特有的配置字段。
|
||||
|
||||
## 用法示例
|
||||
|
||||
@@ -494,6 +532,7 @@ provider:
|
||||
type: claude
|
||||
apiTokens:
|
||||
- "YOUR_CLAUDE_API_TOKEN"
|
||||
version: "2023-06-01"
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
@@ -515,27 +554,214 @@ provider:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg_01K8iLH18FGN7Xd9deurwtoD",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "claude-3-opus-20240229",
|
||||
"stop_sequence": null,
|
||||
"usage": {
|
||||
"input_tokens": 16,
|
||||
"output_tokens": 141
|
||||
},
|
||||
"content": [
|
||||
"id": "msg_01Jt3GzyjuzymnxmZERJguLK",
|
||||
"choices": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "你好!我是Claude,一个由Anthropic公司开发的人工智能助手。我的任务是尽我所能帮助人类,比如回答问题,提供建议和意见,协助完成任务等。我掌握了很多知识,也具备一定的分析和推理能力,但我不是人类,也没有实体的身体。很高兴认识你!如果有什么需要帮助的地方,欢迎随时告诉我。"
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "您好,我是一个由人工智能公司Anthropic开发的聊天助手。我的名字叫Claude,是一个聪明友善、知识渊博的对话系统。很高兴认识您!我可以就各种话题与您聊天,回答问题,提供建议和帮助。我会尽最大努力给您有帮助的回复。希望我们能有个愉快的交流!"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"stop_reason": "end_turn"
|
||||
"created": 1717385918,
|
||||
"model": "claude-3-opus-20240229",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 16,
|
||||
"completion_tokens": 126,
|
||||
"total_tokens": 142
|
||||
}
|
||||
}
|
||||
```
|
||||
### 使用 OpenAI 协议代理混元服务
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: "hunyuan"
|
||||
hunyuanAuthKey: "<YOUR AUTH KEY>"
|
||||
apiTokens:
|
||||
- ""
|
||||
hunyuanAuthId: "<YOUR AUTH ID>"
|
||||
timeout: 1200000
|
||||
modelMapping:
|
||||
"*": "hunyuan-lite"
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
请求脚本:
|
||||
```sh
|
||||
|
||||
curl --location 'http://<your higress domain>/v1/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "gpt-3",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个名专业的开发人员!"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你是谁?"
|
||||
}
|
||||
],
|
||||
"temperature": 0.3,
|
||||
"stream": false
|
||||
}'
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "fd140c3e-0b69-4b19-849b-d354d32a6162",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"role": "assistant",
|
||||
"content": "你好!我是一名专业的开发人员。"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"created": 1717493117,
|
||||
"model": "hunyuan-lite",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 15,
|
||||
"completion_tokens": 9,
|
||||
"total_tokens": 24
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理百度文心一言服务
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: baidu
|
||||
apiTokens:
|
||||
- "YOUR_BAIDU_API_TOKEN"
|
||||
modelMapping:
|
||||
'gpt-3': "ERNIE-4.0"
|
||||
'*': "ERNIE-4.0"
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你是谁?"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "as-e90yfg1pk1",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "你好,我是文心一言,英文名是ERNIE Bot。我能够与人对话互动,回答问题,协助创作,高效便捷地帮助人们获取信息、知识和灵感。"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"created": 1717251488,
|
||||
"model": "ERNIE-4.0",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 4,
|
||||
"completion_tokens": 33,
|
||||
"total_tokens": 37
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理MiniMax服务
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: minimax
|
||||
apiTokens:
|
||||
- "YOUR_MINIMAX_API_TOKEN"
|
||||
modelMapping:
|
||||
"gpt-3": "abab6.5g-chat"
|
||||
"gpt-4": "abab6.5-chat"
|
||||
"*": "abab6.5g-chat"
|
||||
minimaxGroupId: "YOUR_MINIMAX_GROUP_ID"
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你是谁?"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "02b2251f8c6c09d68c1743f07c72afd7",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"message": {
|
||||
"content": "你好!我是MM智能助理,一款由MiniMax自研的大型语言模型。我可以帮助你解答问题,提供信息,进行对话等。有什么可以帮助你的吗?",
|
||||
"role": "assistant"
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1717760544,
|
||||
"model": "abab6.5s-chat",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"total_tokens": 106
|
||||
},
|
||||
"input_sensitive": false,
|
||||
"output_sensitive": false,
|
||||
"input_sensitive_type": 0,
|
||||
"output_sensitive_type": 0,
|
||||
"base_resp": {
|
||||
"status_code": 0,
|
||||
"status_msg": ""
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 完整配置示例
|
||||
|
||||
### Kubernetes 示例
|
||||
|
||||
以下以使用 OpenAI 协议代理 Groq 服务为例,展示完整的插件配置示例。
|
||||
|
||||
```yaml
|
||||
@@ -606,4 +832,131 @@ curl "http://<YOUR-DOMAIN>/v1/chat/completions" -H "Content-Type: application/js
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
```
|
||||
|
||||
### Docker-Compose 示例
|
||||
|
||||
`docker-compose.yml` 配置文件:
|
||||
|
||||
```yaml
|
||||
version: '3.7'
|
||||
services:
|
||||
envoy:
|
||||
image: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/envoy:1.20
|
||||
entrypoint: /usr/local/bin/envoy
|
||||
# 开启了 debug 级别日志方便调试
|
||||
command: -c /etc/envoy/envoy.yaml --component-log-level wasm:debug
|
||||
networks:
|
||||
- higress-net
|
||||
ports:
|
||||
- "10000:10000"
|
||||
volumes:
|
||||
- ./envoy.yaml:/etc/envoy/envoy.yaml
|
||||
- ./plugin.wasm:/etc/envoy/plugin.wasm
|
||||
networks:
|
||||
higress-net: {}
|
||||
```
|
||||
|
||||
`envoy.yaml` 配置文件:
|
||||
|
||||
```yaml
|
||||
admin:
|
||||
address:
|
||||
socket_address:
|
||||
protocol: TCP
|
||||
address: 0.0.0.0
|
||||
port_value: 9901
|
||||
static_resources:
|
||||
listeners:
|
||||
- name: listener_0
|
||||
address:
|
||||
socket_address:
|
||||
protocol: TCP
|
||||
address: 0.0.0.0
|
||||
port_value: 10000
|
||||
filter_chains:
|
||||
- filters:
|
||||
- name: envoy.filters.network.http_connection_manager
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
|
||||
scheme_header_transformation:
|
||||
scheme_to_overwrite: https
|
||||
stat_prefix: ingress_http
|
||||
# Output envoy logs to stdout
|
||||
access_log:
|
||||
- name: envoy.access_loggers.stdout
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.access_loggers.stream.v3.StdoutAccessLog
|
||||
# Modify as required
|
||||
route_config:
|
||||
name: local_route
|
||||
virtual_hosts:
|
||||
- name: local_service
|
||||
domains: [ "*" ]
|
||||
routes:
|
||||
- match:
|
||||
prefix: "/"
|
||||
route:
|
||||
cluster: claude
|
||||
timeout: 300s
|
||||
http_filters:
|
||||
- name: claude
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/udpa.type.v1.TypedStruct
|
||||
type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
|
||||
value:
|
||||
config:
|
||||
name: claude
|
||||
vm_config:
|
||||
runtime: envoy.wasm.runtime.v8
|
||||
code:
|
||||
local:
|
||||
filename: /etc/envoy/plugin.wasm
|
||||
configuration:
|
||||
"@type": "type.googleapis.com/google.protobuf.StringValue"
|
||||
value: | # 插件配置
|
||||
{
|
||||
"provider": {
|
||||
"type": "claude",
|
||||
"apiTokens": [
|
||||
"YOUR_API_TOKEN"
|
||||
]
|
||||
}
|
||||
}
|
||||
- name: envoy.filters.http.router
|
||||
clusters:
|
||||
- name: claude
|
||||
connect_timeout: 30s
|
||||
type: LOGICAL_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: claude
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: api.anthropic.com # API 服务地址
|
||||
port_value: 443
|
||||
transport_socket:
|
||||
name: envoy.transport_sockets.tls
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
|
||||
"sni": "api.anthropic.com"
|
||||
```
|
||||
|
||||
访问示例:
|
||||
|
||||
```bash
|
||||
curl "http://localhost:10000/v1/chat/completions" -H "Content-Type: application/json" -d '{
|
||||
"model": "claude-3-opus-20240229",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你是谁?"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
64
plugins/wasm-go/extensions/ai-proxy/README_dev.md
Normal file
64
plugins/wasm-go/extensions/ai-proxy/README_dev.md
Normal file
@@ -0,0 +1,64 @@
|
||||
## 构建方法
|
||||
|
||||
确认本机已安装 Docker,然后根据操作系统选择对应的构建命令,并在 `ai-proxy` 目录下执行。构建产物将输出至 `out` 目录。
|
||||
|
||||
***Linux/macOS:***
|
||||
|
||||
```shell
|
||||
DOCKER_BUILDKIT=1; docker build --build-arg PLUGIN_NAME=ai-proxy --build-arg EXTRA_TAGS=proxy_wasm_version_0_2_100 --build-arg BUILDER=higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/wasm-go-builder:go1.19-tinygo0.28.1-oras1.0.0 -t ai-proxy:0.0.1 --output ./out ../..
|
||||
```
|
||||
|
||||
***Windows:***
|
||||
|
||||
```powershell
|
||||
$env:DOCKER_BUILDKIT=1; docker build --build-arg PLUGIN_NAME=ai-proxy --build-arg EXTRA_TAGS=proxy_wasm_version_0_2_100 --build-arg BUILDER=higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/wasm-go-builder:go1.19-tinygo0.28.1-oras1.0.0 -t ai-proxy:0.0.1 --output .\out ..\..
|
||||
```
|
||||
|
||||
## 本地运行
|
||||
参考:https://higress.io/zh-cn/docs/user/wasm-go
|
||||
需要注意的是,higress/plugins/wasm-go/extensions/ai-proxy/envoy.yaml中的clusters字段,记得改成你需要地址,比如混元的话:就会有如下的一个cluster的配置:
|
||||
```yaml
|
||||
<省略>
|
||||
static_resources:
|
||||
<省略>
|
||||
clusters:
|
||||
load_assignment:
|
||||
cluster_name: moonshot
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: hunyuan.tencentcloudapi.com
|
||||
port_value: 443
|
||||
transport_socket:
|
||||
name: envoy.transport_sockets.tls
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
|
||||
"sni": "hunyuan.tencentcloudapi.com"
|
||||
```
|
||||
|
||||
而后你就可以在本地的pod中查看相应的输出,请求样例如下:
|
||||
```sh
|
||||
curl --location 'http://127.0.0.1:10000/v1/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "gpt-3",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个名专业的开发人员!"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你是谁?"
|
||||
}
|
||||
],
|
||||
"temperature": 0.3,
|
||||
"stream": false
|
||||
}'
|
||||
```
|
||||
|
||||
## 测试须知
|
||||
|
||||
由于 `ai-proxy` 插件使用了 Higress 对数据面定制的特殊功能,因此在测试时需要使用版本不低于 1.4.0-rc.1 的 Higress Gateway 镜像。
|
||||
@@ -36,7 +36,7 @@ func main() {
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, pluginConfig *config.PluginConfig, log wrapper.Log) error {
|
||||
//log.Debugf("loading config: %s", json.String())
|
||||
// log.Debugf("loading config: %s", json.String())
|
||||
|
||||
pluginConfig.FromJson(json)
|
||||
if err := pluginConfig.Validate(); err != nil {
|
||||
|
||||
338
plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
Normal file
338
plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
Normal file
@@ -0,0 +1,338 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
// baiduProvider is the provider for baidu ernie bot service.
|
||||
|
||||
const (
|
||||
baiduDomain = "aip.baidubce.com"
|
||||
)
|
||||
|
||||
var baiduModelToPathSuffixMap = map[string]string{
|
||||
"ERNIE-4.0-8K": "completions_pro",
|
||||
"ERNIE-3.5-8K": "completions",
|
||||
"ERNIE-3.5-128K": "ernie-3.5-128k",
|
||||
"ERNIE-Speed-8K": "ernie_speed",
|
||||
"ERNIE-Speed-128K": "ernie-speed-128k",
|
||||
"ERNIE-Tiny-8K": "ernie-tiny-8k",
|
||||
"ERNIE-Bot-8K": "ernie_bot_8k",
|
||||
"BLOOMZ-7B": "bloomz_7b1",
|
||||
}
|
||||
|
||||
type baiduProviderInitializer struct {
|
||||
}
|
||||
|
||||
func (b *baiduProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
return &baiduProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type baiduProvider struct {
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (b *baiduProvider) GetProviderType() string {
|
||||
return providerTypeBaidu
|
||||
}
|
||||
|
||||
func (b *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestHost(baiduDomain)
|
||||
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
}
|
||||
|
||||
func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
// 使用文心一言接口协议
|
||||
if b.config.protocol == protocolOriginal {
|
||||
request := &baiduTextGenRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
if request.Model == "" {
|
||||
return types.ActionContinue, errors.New("request model is empty")
|
||||
}
|
||||
// 根据模型重写requestPath
|
||||
path := b.GetRequestPath(request.Model)
|
||||
_ = util.OverwriteRequestPath(path)
|
||||
|
||||
if b.config.context == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
err := b.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
b.setSystemContent(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
// 映射模型重写requestPath
|
||||
model := request.Model
|
||||
if model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
mappedModel := getMappedModel(model, b.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
request.Model = mappedModel
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
|
||||
path := b.GetRequestPath(mappedModel)
|
||||
_ = util.OverwriteRequestPath(path)
|
||||
|
||||
if b.config.context == nil {
|
||||
baiduRequest := b.baiduTextGenRequest(request)
|
||||
return types.ActionContinue, replaceJsonRequestBody(baiduRequest, log)
|
||||
}
|
||||
|
||||
err := b.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
baiduRequest := b.baiduTextGenRequest(request)
|
||||
if err := replaceJsonRequestBody(baiduRequest, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace Request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
func (b *baiduProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
// 使用文心一言接口协议,跳过OnStreamingResponseBody()和OnResponseBody()
|
||||
if b.config.protocol == protocolOriginal {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (b *baiduProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
// sample event response:
|
||||
// data: {"id":"as-vb0m37ti8y","object":"chat.completion","created":1709089502,"sentence_id":0,"is_end":false,"is_truncated":false,"result":"当然可以,","need_clear_history":false,"finish_reason":"normal","usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}
|
||||
|
||||
// sample end event response:
|
||||
// data: {"id":"as-vb0m37ti8y","object":"chat.completion","created":1709089531,"sentence_id":20,"is_end":true,"is_truncated":false,"result":"","need_clear_history":false,"finish_reason":"normal","usage":{"prompt_tokens":5,"completion_tokens":420,"total_tokens":425}}
|
||||
responseBuilder := &strings.Builder{}
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
for _, data := range lines {
|
||||
if len(data) < 6 {
|
||||
// ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
data = data[6:]
|
||||
var baiduResponse baiduTextGenStreamResponse
|
||||
if err := json.Unmarshal([]byte(data), &baiduResponse); err != nil {
|
||||
log.Errorf("unable to unmarshal baidu response: %v", err)
|
||||
continue
|
||||
}
|
||||
response := b.streamResponseBaidu2OpenAI(ctx, &baiduResponse)
|
||||
responseBody, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
log.Errorf("unable to marshal response: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
b.appendResponse(responseBuilder, string(responseBody))
|
||||
}
|
||||
modifiedResponseChunk := responseBuilder.String()
|
||||
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
|
||||
return []byte(modifiedResponseChunk), nil
|
||||
}
|
||||
|
||||
func (b *baiduProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
baiduResponse := &baiduTextGenResponse{}
|
||||
if err := json.Unmarshal(body, baiduResponse); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal baidu response: %v", err)
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return types.ActionContinue, fmt.Errorf("baidu response error, error_code: %d, error_message: %s", baiduResponse.ErrorCode, baiduResponse.ErrorMsg)
|
||||
}
|
||||
response := b.responseBaidu2OpenAI(ctx, baiduResponse)
|
||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
||||
}
|
||||
|
||||
type baiduTextGenRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
PenaltyScore float64 `json:"penalty_score,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
DisableSearch bool `json:"disable_search,omitempty"`
|
||||
EnableCitation bool `json:"enable_citation,omitempty"`
|
||||
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
func (b *baiduProvider) GetRequestPath(baiduModel string) string {
|
||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
|
||||
suffix, ok := baiduModelToPathSuffixMap[baiduModel]
|
||||
if !ok {
|
||||
suffix = baiduModel
|
||||
}
|
||||
return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetRandomToken())
|
||||
}
|
||||
|
||||
func (b *baiduProvider) setSystemContent(request *baiduTextGenRequest, content string) {
|
||||
request.System = content
|
||||
}
|
||||
|
||||
func (b *baiduProvider) baiduTextGenRequest(request *chatCompletionRequest) *baiduTextGenRequest {
|
||||
baiduRequest := baiduTextGenRequest{
|
||||
Messages: make([]chatMessage, 0, len(request.Messages)),
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
PenaltyScore: request.FrequencyPenalty,
|
||||
Stream: request.Stream,
|
||||
DisableSearch: false,
|
||||
EnableCitation: false,
|
||||
MaxOutputTokens: request.MaxTokens,
|
||||
UserId: request.User,
|
||||
}
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == roleSystem {
|
||||
baiduRequest.System = message.Content
|
||||
} else {
|
||||
baiduRequest.Messages = append(baiduRequest.Messages, chatMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
return &baiduRequest
|
||||
}
|
||||
|
||||
type baiduTextGenResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Result string `json:"result"`
|
||||
IsTruncated bool `json:"is_truncated"`
|
||||
NeedClearHistory bool `json:"need_clear_history"`
|
||||
Usage baiduTextGenResponseUsage `json:"usage"`
|
||||
baiduTextGenResponseError
|
||||
}
|
||||
|
||||
type baiduTextGenResponseError struct {
|
||||
ErrorCode int `json:"error_code"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
}
|
||||
|
||||
type baiduTextGenStreamResponse struct {
|
||||
baiduTextGenResponse
|
||||
SentenceId int `json:"sentence_id"`
|
||||
IsEnd bool `json:"is_end"`
|
||||
}
|
||||
|
||||
type baiduTextGenResponseUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
func (b *baiduProvider) responseBaidu2OpenAI(ctx wrapper.HttpContext, response *baiduTextGenResponse) *chatCompletionResponse {
|
||||
choice := chatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: &chatMessage{Role: roleAssistant, Content: response.Result},
|
||||
FinishReason: finishReasonStop,
|
||||
}
|
||||
return &chatCompletionResponse{
|
||||
Id: response.Id,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
|
||||
SystemFingerprint: "",
|
||||
Object: objectChatCompletion,
|
||||
Choices: []chatCompletionChoice{choice},
|
||||
Usage: chatCompletionUsage{
|
||||
PromptTokens: response.Usage.PromptTokens,
|
||||
CompletionTokens: response.Usage.CompletionTokens,
|
||||
TotalTokens: response.Usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, response *baiduTextGenStreamResponse) *chatCompletionResponse {
|
||||
choice := chatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: &chatMessage{Role: roleAssistant, Content: response.Result},
|
||||
}
|
||||
if response.IsEnd {
|
||||
choice.FinishReason = finishReasonStop
|
||||
}
|
||||
return &chatCompletionResponse{
|
||||
Id: response.Id,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
|
||||
SystemFingerprint: "",
|
||||
Object: objectChatCompletion,
|
||||
Choices: []chatCompletionChoice{choice},
|
||||
Usage: chatCompletionUsage{
|
||||
PromptTokens: response.Usage.PromptTokens,
|
||||
CompletionTokens: response.Usage.CompletionTokens,
|
||||
TotalTokens: response.Usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *baiduProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
|
||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||
}
|
||||
367
plugins/wasm-go/extensions/ai-proxy/provider/claude.go
Normal file
367
plugins/wasm-go/extensions/ai-proxy/provider/claude.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// claudeProvider is the provider for Claude service.
|
||||
const (
|
||||
claudeDomain = "api.anthropic.com"
|
||||
claudeChatCompletionPath = "/v1/messages"
|
||||
defaultVersion = "2023-06-01"
|
||||
defaultMaxTokens = 4096
|
||||
)
|
||||
|
||||
type claudeProviderInitializer struct{}
|
||||
|
||||
type claudeTextGenRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
}
|
||||
|
||||
type claudeTextGenResponse struct {
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Content []claudeTextGenContent `json:"content"`
|
||||
Model string `json:"model"`
|
||||
StopReason *string `json:"stop_reason"`
|
||||
StopSequence *string `json:"stop_sequence"`
|
||||
Usage claudeTextGenUsage `json:"usage"`
|
||||
Error *claudeTextGenError `json:"error"`
|
||||
}
|
||||
|
||||
type claudeTextGenContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type claudeTextGenUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
type claudeTextGenError struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type claudeTextGenStreamResponse struct {
|
||||
Type string `json:"type"`
|
||||
Message claudeTextGenResponse `json:"message"`
|
||||
Index int `json:"index"`
|
||||
ContentBlock *claudeTextGenContent `json:"content_block"`
|
||||
Delta *claudeTextGenDelta `json:"delta"`
|
||||
Usage claudeTextGenUsage `json:"usage"`
|
||||
}
|
||||
|
||||
type claudeTextGenDelta struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
StopReason *string `json:"stop_reason"`
|
||||
StopSequence *string `json:"stop_sequence"`
|
||||
}
|
||||
|
||||
func (c *claudeProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *claudeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
return &claudeProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type claudeProvider struct {
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (c *claudeProvider) GetProviderType() string {
|
||||
return providerTypeClaude
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
|
||||
_ = util.OverwriteRequestPath(claudeChatCompletionPath)
|
||||
_ = util.OverwriteRequestHost(claudeDomain)
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("x-api-key", c.config.GetRandomToken())
|
||||
|
||||
if c.config.claudeVersion == "" {
|
||||
c.config.claudeVersion = defaultVersion
|
||||
}
|
||||
_ = proxywasm.AddHttpRequestHeader("anthropic-version", c.config.claudeVersion)
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
|
||||
// use original protocol
|
||||
if c.config.protocol == protocolOriginal {
|
||||
if c.config.context == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
request := &claudeTextGenRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
|
||||
err := c.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
// use openai protocol
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
model := request.Model
|
||||
if model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
mappedModel := getMappedModel(model, c.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
request.Model = mappedModel
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
|
||||
|
||||
streaming := request.Stream
|
||||
if streaming {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
|
||||
}
|
||||
|
||||
if c.config.context == nil {
|
||||
claudeRequest := c.buildClaudeTextGenRequest(request)
|
||||
return types.ActionContinue, replaceJsonRequestBody(claudeRequest, log)
|
||||
}
|
||||
|
||||
err := c.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
claudeRequest := c.buildClaudeTextGenRequest(request)
|
||||
if err := replaceJsonRequestBody(claudeRequest, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
claudeResponse := &claudeTextGenResponse{}
|
||||
if err := json.Unmarshal(body, claudeResponse); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal claude response: %v", err)
|
||||
}
|
||||
if claudeResponse.Error != nil {
|
||||
return types.ActionContinue, fmt.Errorf("claude response error, error_type: %s, error_message: %s", claudeResponse.Error.Type, claudeResponse.Error.Message)
|
||||
}
|
||||
response := c.responseClaude2OpenAI(ctx, claudeResponse)
|
||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
// use original protocol, skip OnStreamingResponseBody() and OnResponseBody()
|
||||
if c.config.protocol == protocolOriginal {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
responseBuilder := &strings.Builder{}
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
for _, data := range lines {
|
||||
// only process the line starting with "data:"
|
||||
if strings.HasPrefix(data, "data:") {
|
||||
// extract json data from the line
|
||||
jsonData := strings.TrimPrefix(data, "data:")
|
||||
var claudeResponse claudeTextGenStreamResponse
|
||||
if err := json.Unmarshal([]byte(jsonData), &claudeResponse); err != nil {
|
||||
log.Errorf("unable to unmarshal claude response: %v", err)
|
||||
continue
|
||||
}
|
||||
response := c.streamResponseClaude2OpenAI(ctx, &claudeResponse, log)
|
||||
if response != nil {
|
||||
responseBody, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
log.Errorf("unable to marshal response: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
c.appendResponse(responseBuilder, string(responseBody))
|
||||
}
|
||||
}
|
||||
}
|
||||
modifiedResponseChunk := responseBuilder.String()
|
||||
log.Debugf("modified response chunk: %s", modifiedResponseChunk)
|
||||
return []byte(modifiedResponseChunk), nil
|
||||
}
|
||||
|
||||
func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRequest) *claudeTextGenRequest {
|
||||
claudeRequest := claudeTextGenRequest{
|
||||
Model: origRequest.Model,
|
||||
MaxTokens: origRequest.MaxTokens,
|
||||
StopSequences: origRequest.Stop,
|
||||
Stream: origRequest.Stream,
|
||||
Temperature: origRequest.Temperature,
|
||||
TopP: origRequest.TopP,
|
||||
}
|
||||
if claudeRequest.MaxTokens == 0 {
|
||||
claudeRequest.MaxTokens = defaultMaxTokens
|
||||
}
|
||||
|
||||
for _, message := range origRequest.Messages {
|
||||
if message.Role == roleSystem {
|
||||
claudeRequest.System = message.Content
|
||||
continue
|
||||
}
|
||||
claudeMessage := chatMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
}
|
||||
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
|
||||
}
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
func (c *claudeProvider) responseClaude2OpenAI(ctx wrapper.HttpContext, origResponse *claudeTextGenResponse) *chatCompletionResponse {
|
||||
choice := chatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: &chatMessage{Role: roleAssistant, Content: origResponse.Content[0].Text},
|
||||
FinishReason: stopReasonClaude2OpenAI(origResponse.StopReason),
|
||||
}
|
||||
|
||||
return &chatCompletionResponse{
|
||||
Id: origResponse.Id,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
|
||||
SystemFingerprint: "",
|
||||
Object: objectChatCompletion,
|
||||
Choices: []chatCompletionChoice{choice},
|
||||
Usage: chatCompletionUsage{
|
||||
PromptTokens: origResponse.Usage.InputTokens,
|
||||
CompletionTokens: origResponse.Usage.OutputTokens,
|
||||
TotalTokens: origResponse.Usage.InputTokens + origResponse.Usage.OutputTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func stopReasonClaude2OpenAI(reason *string) string {
|
||||
if reason == nil {
|
||||
return ""
|
||||
}
|
||||
switch *reason {
|
||||
case "end_turn":
|
||||
return finishReasonStop
|
||||
case "stop_sequence":
|
||||
return finishReasonStop
|
||||
case "max_tokens":
|
||||
return finishReasonLength
|
||||
default:
|
||||
return *reason
|
||||
}
|
||||
}
|
||||
|
||||
func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, origResponse *claudeTextGenStreamResponse, log wrapper.Log) *chatCompletionResponse {
|
||||
switch origResponse.Type {
|
||||
case "message_start":
|
||||
choice := chatCompletionChoice{
|
||||
Index: 0,
|
||||
Delta: &chatMessage{Role: roleAssistant, Content: ""},
|
||||
}
|
||||
return createChatCompletionResponse(ctx, origResponse, choice)
|
||||
|
||||
case "content_block_delta":
|
||||
choice := chatCompletionChoice{
|
||||
Index: 0,
|
||||
Delta: &chatMessage{Content: origResponse.Delta.Text},
|
||||
}
|
||||
return createChatCompletionResponse(ctx, origResponse, choice)
|
||||
|
||||
case "message_delta":
|
||||
choice := chatCompletionChoice{
|
||||
Index: 0,
|
||||
Delta: &chatMessage{},
|
||||
FinishReason: stopReasonClaude2OpenAI(origResponse.Delta.StopReason),
|
||||
}
|
||||
return createChatCompletionResponse(ctx, origResponse, choice)
|
||||
case "content_block_stop", "message_stop":
|
||||
log.Debugf("skip processing response type: %s", origResponse.Type)
|
||||
return nil
|
||||
default:
|
||||
log.Errorf("Unexpected response type: %s", origResponse.Type)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextGenStreamResponse, choice chatCompletionChoice) *chatCompletionResponse {
|
||||
return &chatCompletionResponse{
|
||||
Id: response.Message.Id,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
|
||||
Object: objectChatCompletionChunk,
|
||||
Choices: []chatCompletionChoice{choice},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *claudeProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
|
||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||
}
|
||||
563
plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
Normal file
563
plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
Normal file
@@ -0,0 +1,563 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
// hunyuanProvider is the provider for hunyuan AI service.
|
||||
|
||||
const (
|
||||
hunyuanDomain = "hunyuan.tencentcloudapi.com"
|
||||
hunyuanRequestPath = "/"
|
||||
hunyuanChatCompletionTCAction = "ChatCompletions"
|
||||
|
||||
// headers necessary for TC hunyuan api call:
|
||||
// ref: https://cloud.tencent.com/document/api/1729/105701, https://cloud.tencent.com/document/api/1729/101842
|
||||
actionKey = "X-TC-Action"
|
||||
timestampKey = "X-TC-Timestamp"
|
||||
authorizationKey = "Authorization"
|
||||
versionKey = "X-TC-Version"
|
||||
versionValue = "2023-09-01"
|
||||
hostKey = "Host"
|
||||
|
||||
ssePrefix = "data: " // Server-Sent Events (SSE) 类型的流式响应的开始标记
|
||||
hunyuanStreamEndMark = "stop" // 混元的流式的finishReason为stop时,表示结束
|
||||
|
||||
hunyuanAuthKeyLen = 32
|
||||
hunyuanAuthIdLen = 36
|
||||
)
|
||||
|
||||
type hunyuanProviderInitializer struct {
|
||||
}
|
||||
|
||||
// ref: https://console.cloud.tencent.com/api/explorer?Product=hunyuan&Version=2023-09-01&Action=ChatCompletions
|
||||
type hunyuanTextGenRequest struct {
|
||||
Model string `json:"Model"`
|
||||
Messages []hunyuanChatMessage `json:"Messages"`
|
||||
Stream bool `json:"Stream,omitempty"`
|
||||
StreamModeration bool `json:"StreamModeration,omitempty"`
|
||||
TopP float32 `json:"TopP,omitempty"`
|
||||
Temperature float32 `json:"Temperature,omitempty"`
|
||||
EnableEnhancement bool `json:"EnableEnhancement,omitempty"`
|
||||
}
|
||||
|
||||
type hunyuanTextGenResponseNonStreaming struct {
|
||||
Response hunyuanTextGenDetailedResponseNonStreaming `json:"Response"`
|
||||
}
|
||||
|
||||
type hunyuanTextGenDetailedResponseNonStreaming struct {
|
||||
RequestId string `json:"RequestId,omitempty"`
|
||||
Note string `json:"Note"`
|
||||
Choices []hunyuanTextGenChoice `json:"Choices"`
|
||||
Created int64 `json:"Created"`
|
||||
Id string `json:"Id"`
|
||||
Usage hunyuanTextGenUsage `json:"Usage"`
|
||||
}
|
||||
|
||||
type hunyuanTextGenChoice struct {
|
||||
FinishReason string `json:"FinishReason"`
|
||||
Message hunyuanChatMessage `json:"Message,omitempty"` // 当非流式返回时存储大模型生成文字
|
||||
Delta hunyuanChatMessage `json:"Delta,omitempty"` // 流式返回时存储大模型生成文字
|
||||
}
|
||||
|
||||
type hunyuanTextGenUsage struct {
|
||||
PromptTokens int `json:"PromptTokens"`
|
||||
CompletionTokens int `json:"CompletionTokens"`
|
||||
TotalTokens int `json:"TotalTokens"`
|
||||
}
|
||||
|
||||
type hunyuanChatMessage struct {
|
||||
Role string `json:"Role,omitempty"`
|
||||
Content string `json:"Content,omitempty"`
|
||||
}
|
||||
|
||||
func (m *hunyuanProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
// 校验hunyuan id 和 key的合法性
|
||||
if len(config.hunyuanAuthId) != hunyuanAuthIdLen || len(config.hunyuanAuthKey) != hunyuanAuthKeyLen {
|
||||
return errors.New("hunyuanAuthId / hunyuanAuthKey is illegal in config file")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *hunyuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
return &hunyuanProvider{
|
||||
config: config,
|
||||
client: wrapper.NewClusterClient(wrapper.RouteCluster{
|
||||
Host: hunyuanDomain,
|
||||
}),
|
||||
contextCache: createContextCache(&config),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type hunyuanProvider struct {
|
||||
config ProviderConfig
|
||||
|
||||
client wrapper.HttpClient
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) GetProviderType() string {
|
||||
return providerTypeHunyuan
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
// log.Debugf("hunyuanProvider.OnRequestHeaders called! hunyunSecretKey/id is: %s/%s", m.config.hunyuanAuthKey, m.config.hunyuanAuthId)
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
|
||||
_ = util.OverwriteRequestHost(hunyuanDomain)
|
||||
_ = util.OverwriteRequestPath(hunyuanRequestPath)
|
||||
|
||||
// 添加hunyuan需要的自定义字段
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(actionKey, hunyuanChatCompletionTCAction)
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(versionKey, versionValue)
|
||||
|
||||
// 删除一些字段
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
|
||||
// 为header添加时间戳字段 (因为需要根据body进行签名时依赖时间戳,故于body处理部分创建时间戳)
|
||||
var timestamp int64 = time.Now().Unix()
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(timestampKey, fmt.Sprintf("%d", timestamp))
|
||||
// log.Debugf("#debug nash5# OnRequestBody set timestamp header: ", timestamp)
|
||||
|
||||
// 使用混元本身接口的协议
|
||||
if m.config.protocol == protocolOriginal {
|
||||
request := &hunyuanTextGenRequest{}
|
||||
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
|
||||
// 根据确定好的payload进行签名
|
||||
hunyuanBody, _ := json.Marshal(request)
|
||||
authorizedValueNew := GetTC3Authorizationcode(m.config.hunyuanAuthId, m.config.hunyuanAuthKey, timestamp, hunyuanDomain, hunyuanChatCompletionTCAction, string(hunyuanBody))
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(authorizationKey, authorizedValueNew)
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "*/*")
|
||||
// log.Debugf("#debug nash5# OnRequestBody call hunyuan api using original api! signature computation done!")
|
||||
|
||||
// 若无配置文件,直接返回
|
||||
if m.config.context == nil {
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
}
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
log.Debugf("#debug nash5# ctx file loaded! callback start, content is: %s", content)
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
m.insertContextMessageIntoHunyuanRequest(request, content)
|
||||
|
||||
// 因为手动插入了context内容,这里需要重新计算签名
|
||||
hunyuanBody, _ := json.Marshal(request)
|
||||
authorizedValueNew := GetTC3Authorizationcode(m.config.hunyuanAuthId, m.config.hunyuanAuthKey, timestamp, hunyuanDomain, hunyuanChatCompletionTCAction, string(hunyuanBody))
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(authorizationKey, authorizedValueNew)
|
||||
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
log.Debugf("#debug nash5# ctx file load success!")
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
|
||||
log.Debugf("#debug nash5# ctx file load failed!")
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
}
|
||||
|
||||
// 使用open ai接口协议
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
// log.Debugf("#debug nash5# OnRequestBody call hunyuan api using openai's api!")
|
||||
|
||||
model := request.Model
|
||||
if model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model) // 设置原始请求的model,以便返回值使用
|
||||
mappedModel := getMappedModel(model, m.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
request.Model = mappedModel
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, request.Model) // 设置真实请求的模型,以便返回值使用
|
||||
|
||||
// 看请求中的stream的设置,相应的我们更该http头
|
||||
streaming := request.Stream
|
||||
if streaming {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
|
||||
} else {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "*/*")
|
||||
}
|
||||
|
||||
// 若没有配置上下文,直接开始请求
|
||||
if m.config.context == nil {
|
||||
hunyuanRequest := m.buildHunyuanTextGenerationRequest(request)
|
||||
|
||||
// 根据确定好的payload进行签名:
|
||||
body, _ := json.Marshal(hunyuanRequest)
|
||||
authorizedValueNew := GetTC3Authorizationcode(
|
||||
m.config.hunyuanAuthId,
|
||||
m.config.hunyuanAuthKey,
|
||||
timestamp,
|
||||
hunyuanDomain,
|
||||
hunyuanChatCompletionTCAction,
|
||||
string(body),
|
||||
)
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(authorizationKey, authorizedValueNew)
|
||||
// log.Debugf("#debug nash5# OnRequestBody done, body is: ", string(body))
|
||||
|
||||
// // 打印所有的headers
|
||||
// headers, err2 := proxywasm.GetHttpRequestHeaders()
|
||||
// if err2 != nil {
|
||||
// log.Errorf("failed to get request headers: %v", err2)
|
||||
// } else {
|
||||
// // 迭代并打印所有请求头
|
||||
// for _, header := range headers {
|
||||
// log.Infof("#debug nash5# inB Request header - %s: %s", header[0], header[1])
|
||||
// }
|
||||
// }
|
||||
return types.ActionContinue, replaceJsonRequestBody(hunyuanRequest, log)
|
||||
}
|
||||
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
return
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
hunyuanRequest := m.buildHunyuanTextGenerationRequest(request)
|
||||
|
||||
// 因为手动插入了context内容,这里需要重新计算签名
|
||||
hunyuanBody, _ := json.Marshal(hunyuanRequest)
|
||||
authorizedValueNew := GetTC3Authorizationcode(m.config.hunyuanAuthId, m.config.hunyuanAuthKey, timestamp, hunyuanDomain, hunyuanChatCompletionTCAction, string(hunyuanBody))
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(authorizationKey, authorizedValueNew)
|
||||
|
||||
if err := replaceJsonRequestBody(hunyuanRequest, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
if m.config.protocol == protocolOriginal {
|
||||
return chunk, nil
|
||||
}
|
||||
|
||||
// hunyuan的流式返回:
|
||||
//data: {"Note":"以上内容为AI生成,不代表开发者立场,请勿删除或修改本标记","Choices":[{"Delta":{"Role":"assistant","Content":"有助于"},"FinishReason":""}],"Created":1716359713,"Id":"086b6b19-8b2c-4def-a65c-db6a7bc86acd","Usage":{"PromptTokens":7,"CompletionTokens":145,"TotalTokens":152}}
|
||||
|
||||
// openai的流式返回
|
||||
// data: {"id": "chatcmpl-7QyqpwdfhqwajicIEznoc6Q47XAyW", "object": "chat.completion.chunk", "created": 1677664795, "model": "gpt-3.5-turbo-0613", "choices": [{"delta": {"content": "The "}, "index": 0, "finish_reason": null}]}
|
||||
|
||||
// log.Debugf("#debug nash5# [OnStreamingResponseBody] chunk is: %s", string(chunk))
|
||||
|
||||
// 从上下文获取现有缓冲区数据
|
||||
newBufferedBody := chunk
|
||||
if bufferedBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
|
||||
newBufferedBody = append(bufferedBody, chunk...)
|
||||
}
|
||||
|
||||
// 初始化处理下标,以及将要返回的处理过的chunks
|
||||
var newEventPivot = -1
|
||||
var outputBuffer []byte
|
||||
|
||||
// 从buffer区取出若干完整的chunk,将其转为openAI格式后返回
|
||||
// 处理可能包含多个事件的缓冲区
|
||||
for {
|
||||
eventStartIndex := bytes.Index(newBufferedBody, []byte(ssePrefix))
|
||||
if eventStartIndex == -1 {
|
||||
break // 没有找到新事件,跳出循环
|
||||
}
|
||||
|
||||
// 移除缓冲区前面非事件部分
|
||||
newBufferedBody = newBufferedBody[eventStartIndex+len(ssePrefix):]
|
||||
|
||||
// 查找事件结束的位置(即下一个事件的开始)
|
||||
newEventPivot = bytes.Index(newBufferedBody, []byte("\n\n"))
|
||||
if newEventPivot == -1 && !isLastChunk {
|
||||
// 未找到事件结束标识,跳出循环等待更多数据,若是最后一个chunk,不一定有2个换行符
|
||||
break
|
||||
}
|
||||
|
||||
// 提取并处理一个完整的事件
|
||||
eventData := newBufferedBody[:newEventPivot]
|
||||
// log.Debugf("@@@ <<< ori chun is: %s", string(newBufferedBody[:newEventPivot]))
|
||||
newBufferedBody = newBufferedBody[newEventPivot+2:] // 跳过结束标识
|
||||
|
||||
// 转换并追加到输出缓冲区
|
||||
convertedData, _ := m.convertChunkFromHunyuanToOpenAI(ctx, eventData, log)
|
||||
// log.Debugf("@@@ >>> converted one chunk: %s", string(convertedData))
|
||||
outputBuffer = append(outputBuffer, convertedData...)
|
||||
}
|
||||
|
||||
// 刷新剩余的不完整事件回到上下文缓冲区以便下次继续处理
|
||||
ctx.SetContext(ctxKeyStreamingBody, newBufferedBody)
|
||||
|
||||
log.Debugf("=== modified response chunk: %s", string(outputBuffer))
|
||||
return outputBuffer, nil
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContext, hunyuanChunk []byte, log wrapper.Log) ([]byte, error) {
|
||||
// 将hunyuan的chunk转为openai的chunk
|
||||
hunyuanFormattedChunk := &hunyuanTextGenDetailedResponseNonStreaming{}
|
||||
if err := json.Unmarshal(hunyuanChunk, hunyuanFormattedChunk); err != nil {
|
||||
return []byte(""), nil
|
||||
}
|
||||
|
||||
openAIFormattedChunk := &chatCompletionResponse{
|
||||
Id: hunyuanFormattedChunk.Id,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
|
||||
SystemFingerprint: "",
|
||||
Object: objectChatCompletionChunk,
|
||||
Usage: chatCompletionUsage{
|
||||
PromptTokens: hunyuanFormattedChunk.Usage.PromptTokens,
|
||||
CompletionTokens: hunyuanFormattedChunk.Usage.CompletionTokens,
|
||||
TotalTokens: hunyuanFormattedChunk.Usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
// tmpStr3, _ := json.Marshal(hunyuanFormattedChunk)
|
||||
// log.Debugf("@@@ --- 源数据是:: %s", tmpStr3)
|
||||
|
||||
// 是否为最后一个chunk?
|
||||
if hunyuanFormattedChunk.Choices[0].FinishReason == hunyuanStreamEndMark {
|
||||
// log.Debugf("@@@ --- 最后chunk: ")
|
||||
openAIFormattedChunk.Choices = append(openAIFormattedChunk.Choices, chatCompletionChoice{
|
||||
FinishReason: hunyuanFormattedChunk.Choices[0].FinishReason,
|
||||
})
|
||||
} else {
|
||||
deltaMsg := chatMessage{
|
||||
Name: "",
|
||||
Role: hunyuanFormattedChunk.Choices[0].Delta.Role,
|
||||
Content: hunyuanFormattedChunk.Choices[0].Delta.Content,
|
||||
ToolCalls: []toolCall{},
|
||||
}
|
||||
|
||||
// tmpStr2, _ := json.Marshal(deltaMsg)
|
||||
// log.Debugf("@@@ --- 中间chunk: choices.chatMsg 是: %s", tmpStr2)
|
||||
|
||||
openAIFormattedChunk.Choices = append(
|
||||
openAIFormattedChunk.Choices,
|
||||
chatCompletionChoice{Delta: &deltaMsg},
|
||||
)
|
||||
// tmpStr, _ := json.Marshal(openAIFormattedChunk.Choices)
|
||||
// log.Debugf("@@@ --- 中间chunk: choices 是: %s", tmpStr)
|
||||
}
|
||||
|
||||
// 返回的格式
|
||||
openAIFormattedChunkBytes, _ := json.Marshal(openAIFormattedChunk)
|
||||
var openAIChunk strings.Builder
|
||||
openAIChunk.WriteString(ssePrefix)
|
||||
openAIChunk.WriteString(string(openAIFormattedChunkBytes))
|
||||
openAIChunk.WriteString("\n\n")
|
||||
|
||||
return []byte(openAIChunk.String()), nil
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
|
||||
log.Debugf("#debug nash5# onRespBody's resp is: %s", string(body))
|
||||
hunyuanResponse := &hunyuanTextGenResponseNonStreaming{}
|
||||
if err := json.Unmarshal(body, hunyuanResponse); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal hunyuan response: %v", err)
|
||||
}
|
||||
|
||||
if m.config.protocol == protocolOriginal {
|
||||
return types.ActionContinue, replaceJsonResponseBody(hunyuanResponse, log)
|
||||
}
|
||||
|
||||
response := m.buildChatCompletionResponse(ctx, hunyuanResponse)
|
||||
|
||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) insertContextMessageIntoHunyuanRequest(request *hunyuanTextGenRequest, content string) {
|
||||
|
||||
fileMessage := hunyuanChatMessage{
|
||||
Role: roleSystem,
|
||||
Content: content,
|
||||
}
|
||||
messages := request.Messages
|
||||
request.Messages = append([]hunyuanChatMessage{},
|
||||
append([]hunyuanChatMessage{fileMessage}, messages...)...,
|
||||
)
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) buildHunyuanTextGenerationRequest(request *chatCompletionRequest) *hunyuanTextGenRequest {
|
||||
hunyuanRequest := &hunyuanTextGenRequest{
|
||||
Model: request.Model,
|
||||
Messages: convertMessagesFromOpenAIToHunyuan(request.Messages),
|
||||
Stream: request.Stream,
|
||||
StreamModeration: false,
|
||||
TopP: float32(request.TopP),
|
||||
Temperature: float32(request.Temperature),
|
||||
EnableEnhancement: false,
|
||||
}
|
||||
|
||||
return hunyuanRequest
|
||||
}
|
||||
|
||||
func convertMessagesFromOpenAIToHunyuan(openAIMessages []chatMessage) []hunyuanChatMessage {
|
||||
// 将chatgpt的messages转换为hunyuan的messages
|
||||
hunyuanChatMessages := make([]hunyuanChatMessage, 0, len(openAIMessages))
|
||||
for _, msg := range openAIMessages {
|
||||
hunyuanChatMessages = append(hunyuanChatMessages, hunyuanChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
|
||||
return hunyuanChatMessages
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, hunyuanResponse *hunyuanTextGenResponseNonStreaming) *chatCompletionResponse {
|
||||
choices := make([]chatCompletionChoice, 0, len(hunyuanResponse.Response.Choices))
|
||||
for _, choice := range hunyuanResponse.Response.Choices {
|
||||
choices = append(choices, chatCompletionChoice{
|
||||
Message: &chatMessage{
|
||||
Name: "",
|
||||
Role: choice.Message.Role,
|
||||
Content: choice.Message.Content,
|
||||
ToolCalls: nil,
|
||||
},
|
||||
FinishReason: choice.FinishReason,
|
||||
})
|
||||
}
|
||||
return &chatCompletionResponse{
|
||||
Id: hunyuanResponse.Response.Id,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
|
||||
SystemFingerprint: "",
|
||||
Object: objectChatCompletion,
|
||||
Choices: choices,
|
||||
Usage: chatCompletionUsage{
|
||||
PromptTokens: hunyuanResponse.Response.Usage.PromptTokens,
|
||||
CompletionTokens: hunyuanResponse.Response.Usage.CompletionTokens,
|
||||
TotalTokens: hunyuanResponse.Response.Usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func Sha256hex(s string) string {
|
||||
b := sha256.Sum256([]byte(s))
|
||||
return hex.EncodeToString(b[:])
|
||||
}
|
||||
|
||||
func Hmacsha256(s, key string) string {
|
||||
hashed := hmac.New(sha256.New, []byte(key))
|
||||
hashed.Write([]byte(s))
|
||||
return string(hashed.Sum(nil))
|
||||
}
|
||||
|
||||
/**
|
||||
* @param secretId 秘钥id
|
||||
* @param secretKey 秘钥
|
||||
* @param timestamp 时间戳
|
||||
* @param host 目标域名
|
||||
* @param action 请求动作
|
||||
* @param payload 请求体
|
||||
* @return 签名
|
||||
*/
|
||||
func GetTC3Authorizationcode(secretId string, secretKey string, timestamp int64, host string, action string, payload string) string {
|
||||
algorithm := "TC3-HMAC-SHA256"
|
||||
service := "hunyuan" // 注意,必须和域名中的产品名保持一致
|
||||
|
||||
// step 1: build canonical request string
|
||||
httpRequestMethod := "POST"
|
||||
canonicalURI := "/"
|
||||
canonicalQueryString := ""
|
||||
canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
|
||||
"application/json", host, strings.ToLower(action))
|
||||
signedHeaders := "content-type;host;x-tc-action"
|
||||
|
||||
// fmt.Println("payload is: %s", payload)
|
||||
hashedRequestPayload := Sha256hex(payload)
|
||||
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
|
||||
httpRequestMethod,
|
||||
canonicalURI,
|
||||
canonicalQueryString,
|
||||
canonicalHeaders,
|
||||
signedHeaders,
|
||||
hashedRequestPayload)
|
||||
// fmt.Println(canonicalRequest)
|
||||
|
||||
// step 2: build string to sign
|
||||
date := time.Unix(timestamp, 0).UTC().Format("2006-01-02")
|
||||
credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, service)
|
||||
hashedCanonicalRequest := Sha256hex(canonicalRequest)
|
||||
string2sign := fmt.Sprintf("%s\n%d\n%s\n%s",
|
||||
algorithm,
|
||||
timestamp,
|
||||
credentialScope,
|
||||
hashedCanonicalRequest)
|
||||
// fmt.Println(string2sign)
|
||||
|
||||
// step 3: sign string
|
||||
secretDate := Hmacsha256(date, "TC3"+secretKey)
|
||||
secretService := Hmacsha256(service, secretDate)
|
||||
secretSigning := Hmacsha256("tc3_request", secretService)
|
||||
signature := hex.EncodeToString([]byte(Hmacsha256(string2sign, secretSigning)))
|
||||
// fmt.Println(signature)
|
||||
|
||||
// step 4: build authorization
|
||||
authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
|
||||
algorithm,
|
||||
secretId,
|
||||
credentialScope,
|
||||
signedHeaders,
|
||||
signature)
|
||||
|
||||
// curl := fmt.Sprintf(`curl -X POST https://%s \
|
||||
// -H "Authorization: %s" \
|
||||
// -H "Content-Type: application/json" \
|
||||
// -H "Host: %s" -H "X-TC-Action: %s" \
|
||||
// -H "X-TC-Timestamp: %d" \
|
||||
// -H "X-TC-Version: 2023-09-01" \
|
||||
// -d '%s'`, host, authorization, host, action, timestamp, payload)
|
||||
// fmt.Println(curl)
|
||||
return authorization
|
||||
}
|
||||
472
plugins/wasm-go/extensions/ai-proxy/provider/minimax.go
Normal file
472
plugins/wasm-go/extensions/ai-proxy/provider/minimax.go
Normal file
@@ -0,0 +1,472 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// minimaxProvider is the provider for minimax service.
|
||||
|
||||
const (
|
||||
minimaxDomain = "api.minimax.chat"
|
||||
// minimaxChatCompletionV2Path 接口请求响应格式与OpenAI相同
|
||||
// 接口文档: https://platform.minimaxi.com/document/guides/chat-model/V2?id=65e0736ab2845de20908e2dd
|
||||
minimaxChatCompletionV2Path = "/v1/text/chatcompletion_v2"
|
||||
// minimaxChatCompletionProPath 接口请求响应格式与OpenAI不同
|
||||
// 接口文档: https://platform.minimaxi.com/document/guides/chat-model/pro/api?id=6569c85948bc7b684b30377e
|
||||
minimaxChatCompletionProPath = "/v1/text/chatcompletion_pro"
|
||||
|
||||
senderTypeUser string = "USER" // 用户发送的内容
|
||||
senderTypeBot string = "BOT" // 模型生成的内容
|
||||
|
||||
// 默认机器人设置
|
||||
defaultBotName string = "MM智能助理"
|
||||
defaultBotSettingContent string = "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。"
|
||||
defaultSenderName string = "小明"
|
||||
)
|
||||
|
||||
// chatCompletionProModels 这些模型对应接口为ChatCompletion Pro
|
||||
var chatCompletionProModels = map[string]struct{}{
|
||||
"abab6.5-chat": {},
|
||||
"abab6.5s-chat": {},
|
||||
"abab5.5s-chat": {},
|
||||
"abab5.5-chat": {},
|
||||
}
|
||||
|
||||
type minimaxProviderInitializer struct {
|
||||
}
|
||||
|
||||
func (m *minimaxProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
// 如果存在模型对应接口为ChatCompletion Pro必须配置minimaxGroupId
|
||||
if len(config.modelMapping) > 0 && config.minimaxGroupId == "" {
|
||||
for _, minimaxModel := range config.modelMapping {
|
||||
if _, exists := chatCompletionProModels[minimaxModel]; exists {
|
||||
return errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when %s model is provided", minimaxModel))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *minimaxProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
return &minimaxProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type minimaxProvider struct {
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) GetProviderType() string {
|
||||
return providerTypeMinimax
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestHost(minimaxDomain)
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+m.config.GetRandomToken())
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
// 解析并映射模型,设置上下文
|
||||
model, err := m.parseModel(body)
|
||||
if err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
mappedModel := getMappedModel(model, m.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
|
||||
_, ok := chatCompletionProModels[mappedModel]
|
||||
if ok {
|
||||
// 使用ChatCompletion Pro接口
|
||||
return m.handleRequestBodyByChatCompletionPro(body, log)
|
||||
} else {
|
||||
// 使用ChatCompletion v2接口
|
||||
return m.handleRequestBodyByChatCompletionV2(body, log)
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequestBodyByChatCompletionPro 使用ChatCompletion Pro接口处理请求体
|
||||
func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log wrapper.Log) (types.Action, error) {
|
||||
// 使用minimax接口协议
|
||||
if m.config.protocol == protocolOriginal {
|
||||
request := &minimaxChatCompletionV2Request{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
if request.Model == "" {
|
||||
return types.ActionContinue, errors.New("request model is empty")
|
||||
}
|
||||
// 根据模型重写requestPath
|
||||
if m.config.minimaxGroupId == "" {
|
||||
return types.ActionContinue, errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when use %s model ", request.Model))
|
||||
}
|
||||
_ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))
|
||||
|
||||
if m.config.context == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
m.setBotSettings(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
// 映射模型重写requestPath
|
||||
request.Model = getMappedModel(request.Model, m.config.modelMapping, log)
|
||||
_ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))
|
||||
|
||||
if m.config.context == nil {
|
||||
minimaxRequest := m.buildMinimaxChatCompletionV2Request(request, "")
|
||||
return types.ActionContinue, replaceJsonRequestBody(minimaxRequest, log)
|
||||
}
|
||||
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
minimaxRequest := m.buildMinimaxChatCompletionV2Request(request, content)
|
||||
if err := replaceJsonRequestBody(minimaxRequest, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace Request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
// handleRequestBodyByChatCompletionV2 使用ChatCompletion v2接口处理请求体
|
||||
func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, log wrapper.Log) (types.Action, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
// 映射模型重写requestPath
|
||||
request.Model = getMappedModel(request.Model, m.config.modelMapping, log)
|
||||
_ = util.OverwriteRequestPath(minimaxChatCompletionV2Path)
|
||||
|
||||
if m.contextCache == nil {
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
}
|
||||
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
// 使用minimax接口协议,跳过OnStreamingResponseBody()和OnResponseBody()
|
||||
if m.config.protocol == protocolOriginal {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
// 模型对应接口为ChatCompletion v2,跳过OnStreamingResponseBody()和OnResponseBody()
|
||||
model := ctx.GetContext(ctxKeyFinalRequestModel)
|
||||
if model != nil {
|
||||
_, ok := chatCompletionProModels[model.(string)]
|
||||
if !ok {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
}
|
||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
// OnStreamingResponseBody 只处理使用OpenAI协议 且 模型对应接口为ChatCompletion Pro的流式响应
|
||||
func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
// sample event response:
|
||||
// data: {"created":1689747645,"model":"abab6.5s-chat","reply":"","choices":[{"messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"am from China."}]}],"output_sensitive":false}
|
||||
|
||||
// sample end event response:
|
||||
// data: {"created":1689747645,"model":"abab6.5s-chat","reply":"I am from China.","choices":[{"finish_reason":"stop","messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"I am from China."}]}],"usage":{"total_tokens":187},"input_sensitive":false,"output_sensitive":false,"id":"0106b3bc9fd844a9f3de1aa06004e2ab","base_resp":{"status_code":0,"status_msg":""}}
|
||||
responseBuilder := &strings.Builder{}
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
for _, data := range lines {
|
||||
if len(data) < 6 {
|
||||
// ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
data = data[6:]
|
||||
var minimaxResp minimaxChatCompletionV2Resp
|
||||
if err := json.Unmarshal([]byte(data), &minimaxResp); err != nil {
|
||||
log.Errorf("unable to unmarshal minimax response: %v", err)
|
||||
continue
|
||||
}
|
||||
response := m.responseV2ToOpenAI(&minimaxResp)
|
||||
responseBody, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
log.Errorf("unable to marshal response: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
m.appendResponse(responseBuilder, string(responseBody))
|
||||
}
|
||||
modifiedResponseChunk := responseBuilder.String()
|
||||
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
|
||||
return []byte(modifiedResponseChunk), nil
|
||||
}
|
||||
|
||||
// OnResponseBody 只处理使用OpenAI协议 且 模型对应接口为ChatCompletion Pro的流式响应
|
||||
func (m *minimaxProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
minimaxResp := &minimaxChatCompletionV2Resp{}
|
||||
if err := json.Unmarshal(body, minimaxResp); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal minimax response: %v", err)
|
||||
}
|
||||
if minimaxResp.BaseResp.StatusCode != 0 {
|
||||
return types.ActionContinue, fmt.Errorf("minimax response error, error_code: %d, error_message: %s", minimaxResp.BaseResp.StatusCode, minimaxResp.BaseResp.StatusMsg)
|
||||
}
|
||||
response := m.responseV2ToOpenAI(minimaxResp)
|
||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
||||
}
|
||||
|
||||
// minimaxChatCompletionV2Request 表示ChatCompletion V2请求的结构体
|
||||
type minimaxChatCompletionV2Request struct {
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
TokensToGenerate int64 `json:"tokens_to_generate,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
MaskSensitiveInfo bool `json:"mask_sensitive_info"` // 是否开启隐私信息打码,默认true
|
||||
Messages []minimaxMessage `json:"messages"`
|
||||
BotSettings []minimaxBotSetting `json:"bot_setting"`
|
||||
ReplyConstraints minimaxReplyConstraints `json:"reply_constraints"`
|
||||
}
|
||||
|
||||
// minimaxMessage 表示对话中的消息
|
||||
type minimaxMessage struct {
|
||||
SenderType string `json:"sender_type"`
|
||||
SenderName string `json:"sender_name"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// minimaxBotSetting 表示机器人的设置
|
||||
type minimaxBotSetting struct {
|
||||
BotName string `json:"bot_name"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// minimaxReplyConstraints 表示模型回复要求
|
||||
type minimaxReplyConstraints struct {
|
||||
SenderType string `json:"sender_type"`
|
||||
SenderName string `json:"sender_name"`
|
||||
}
|
||||
|
||||
// minimaxChatCompletionV2Resp Minimax Chat Completion V2响应结构体
|
||||
type minimaxChatCompletionV2Resp struct {
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Reply string `json:"reply"`
|
||||
InputSensitive bool `json:"input_sensitive,omitempty"`
|
||||
InputSensitiveType int64 `json:"input_sensitive_type,omitempty"`
|
||||
OutputSensitive bool `json:"output_sensitive,omitempty"`
|
||||
OutputSensitiveType int64 `json:"output_sensitive_type,omitempty"`
|
||||
Choices []minimaxChoice `json:"choices,omitempty"`
|
||||
Usage minimaxUsage `json:"usage,omitempty"`
|
||||
Id string `json:"id"`
|
||||
BaseResp minimaxBaseResp `json:"base_resp"`
|
||||
}
|
||||
|
||||
// minimaxBaseResp 包含错误状态码和详情
|
||||
type minimaxBaseResp struct {
|
||||
StatusCode int64 `json:"status_code"`
|
||||
StatusMsg string `json:"status_msg"`
|
||||
}
|
||||
|
||||
// minimaxChoice 结果选项
|
||||
type minimaxChoice struct {
|
||||
Messages []minimaxMessage `json:"messages"`
|
||||
Index int64 `json:"index"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
// minimaxUsage 令牌使用情况
|
||||
type minimaxUsage struct {
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) parseModel(body []byte) (string, error) {
|
||||
var tempMap map[string]interface{}
|
||||
if err := json.Unmarshal(body, &tempMap); err != nil {
|
||||
return "", err
|
||||
}
|
||||
model, ok := tempMap["model"].(string)
|
||||
if !ok {
|
||||
return "", errors.New("missing model in chat completion request")
|
||||
}
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) setBotSettings(request *minimaxChatCompletionV2Request, botSettingContent string) {
|
||||
if len(request.BotSettings) == 0 {
|
||||
request.BotSettings = []minimaxBotSetting{
|
||||
{
|
||||
BotName: defaultBotName,
|
||||
Content: func() string {
|
||||
if botSettingContent != "" {
|
||||
return botSettingContent
|
||||
}
|
||||
return defaultBotSettingContent
|
||||
}(),
|
||||
},
|
||||
}
|
||||
} else if botSettingContent != "" {
|
||||
newSetting := minimaxBotSetting{
|
||||
BotName: request.BotSettings[0].BotName,
|
||||
Content: botSettingContent,
|
||||
}
|
||||
request.BotSettings = append([]minimaxBotSetting{newSetting}, request.BotSettings...)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) buildMinimaxChatCompletionV2Request(request *chatCompletionRequest, botSettingContent string) *minimaxChatCompletionV2Request {
|
||||
var messages []minimaxMessage
|
||||
var botSetting []minimaxBotSetting
|
||||
var botName string
|
||||
|
||||
determineName := func(name string, defaultName string) string {
|
||||
if name != "" {
|
||||
return name
|
||||
}
|
||||
return defaultName
|
||||
}
|
||||
|
||||
for _, message := range request.Messages {
|
||||
switch message.Role {
|
||||
case roleSystem:
|
||||
botName = determineName(message.Name, defaultBotName)
|
||||
botSetting = append(botSetting, minimaxBotSetting{
|
||||
BotName: botName,
|
||||
Content: message.Content,
|
||||
})
|
||||
case roleAssistant:
|
||||
messages = append(messages, minimaxMessage{
|
||||
SenderType: senderTypeBot,
|
||||
SenderName: determineName(message.Name, defaultBotName),
|
||||
Text: message.Content,
|
||||
})
|
||||
case roleUser:
|
||||
messages = append(messages, minimaxMessage{
|
||||
SenderType: senderTypeUser,
|
||||
SenderName: determineName(message.Name, defaultSenderName),
|
||||
Text: message.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
replyConstraints := minimaxReplyConstraints{
|
||||
SenderType: senderTypeBot,
|
||||
SenderName: determineName(botName, defaultBotName),
|
||||
}
|
||||
result := &minimaxChatCompletionV2Request{
|
||||
Model: request.Model,
|
||||
Stream: request.Stream,
|
||||
TokensToGenerate: int64(request.MaxTokens),
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
MaskSensitiveInfo: true,
|
||||
Messages: messages,
|
||||
BotSettings: botSetting,
|
||||
ReplyConstraints: replyConstraints,
|
||||
}
|
||||
|
||||
m.setBotSettings(result, botSettingContent)
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) responseV2ToOpenAI(response *minimaxChatCompletionV2Resp) *chatCompletionResponse {
|
||||
var choices []chatCompletionChoice
|
||||
messageIndex := 0
|
||||
for _, choice := range response.Choices {
|
||||
for _, message := range choice.Messages {
|
||||
message := &chatMessage{
|
||||
Name: message.SenderName,
|
||||
Role: roleAssistant,
|
||||
Content: message.Text,
|
||||
}
|
||||
choices = append(choices, chatCompletionChoice{
|
||||
FinishReason: choice.FinishReason,
|
||||
Index: messageIndex,
|
||||
Message: message,
|
||||
})
|
||||
messageIndex++
|
||||
}
|
||||
}
|
||||
return &chatCompletionResponse{
|
||||
Id: response.Id,
|
||||
Object: objectChatCompletion,
|
||||
Created: response.Created,
|
||||
Model: response.Model,
|
||||
Choices: choices,
|
||||
Usage: chatCompletionUsage{
|
||||
TotalTokens: int(response.Usage.TotalTokens),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
|
||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||
}
|
||||
@@ -30,6 +30,7 @@ type chatCompletionRequest struct {
|
||||
Tools []tool `json:"tools,omitempty"`
|
||||
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
}
|
||||
|
||||
type streamOptions struct {
|
||||
@@ -54,7 +55,7 @@ type toolChoice struct {
|
||||
|
||||
type chatCompletionResponse struct {
|
||||
Id string `json:"id,omitempty"`
|
||||
Choices []chatCompletionChoice `json:"choices,omitempty"`
|
||||
Choices []chatCompletionChoice `json:"choices"`
|
||||
Created int64 `json:"created,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
@@ -102,14 +103,15 @@ func (m *chatMessage) IsEmpty() bool {
|
||||
}
|
||||
|
||||
type toolCall struct {
|
||||
Index int `json:"index"`
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function functionCall `json:"function"`
|
||||
}
|
||||
|
||||
type functionCall struct {
|
||||
Id string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
|
||||
114
plugins/wasm-go/extensions/ai-proxy/provider/ollama.go
Normal file
114
plugins/wasm-go/extensions/ai-proxy/provider/ollama.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
// ollamaProvider is the provider for Ollama service.
|
||||
|
||||
const (
|
||||
ollamaChatCompletionPath = "/v1/chat/completions"
|
||||
)
|
||||
|
||||
type ollamaProviderInitializer struct {
|
||||
}
|
||||
|
||||
func (m *ollamaProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
if config.ollamaServerHost == "" {
|
||||
return errors.New("missing ollamaServerHost in provider config")
|
||||
}
|
||||
if config.ollamaServerPort == 0 {
|
||||
return errors.New("missing ollamaServerPort in provider config")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ollamaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
serverPortStr := fmt.Sprintf("%d", config.ollamaServerPort)
|
||||
serviceDomain := config.ollamaServerHost + ":" + serverPortStr
|
||||
return &ollamaProvider{
|
||||
config: config,
|
||||
serviceDomain: serviceDomain,
|
||||
contextCache: createContextCache(&config),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type ollamaProvider struct {
|
||||
config ProviderConfig
|
||||
serviceDomain string
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (m *ollamaProvider) GetProviderType() string {
|
||||
return providerTypeOllama
|
||||
}
|
||||
|
||||
func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestPath(ollamaChatCompletionPath)
|
||||
_ = util.OverwriteRequestHost(m.serviceDomain)
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
|
||||
if m.config.modelMapping == nil && m.contextCache == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
model := request.Model
|
||||
if model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
}
|
||||
mappedModel := getMappedModel(model, m.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
request.Model = mappedModel
|
||||
|
||||
if m.contextCache != nil {
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
} else {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
} else {
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
}
|
||||
@@ -23,13 +23,26 @@ const (
|
||||
providerTypeBaichuan = "baichuan"
|
||||
providerTypeYi = "yi"
|
||||
providerTypeDeepSeek = "deepseek"
|
||||
providerTypeZhipuAi = "zhipuai"
|
||||
providerTypeOllama = "ollama"
|
||||
providerTypeClaude = "claude"
|
||||
providerTypeBaidu = "baidu"
|
||||
providerTypeHunyuan = "hunyuan"
|
||||
providerTypeStepfun = "stepfun"
|
||||
providerTypeMinimax = "minimax"
|
||||
|
||||
protocolOpenAI = "openai"
|
||||
protocolOriginal = "original"
|
||||
|
||||
roleSystem = "system"
|
||||
roleSystem = "system"
|
||||
roleAssistant = "assistant"
|
||||
roleUser = "user"
|
||||
|
||||
finishReasonStop = "stop"
|
||||
finishReasonLength = "length"
|
||||
|
||||
ctxKeyIncrementalStreaming = "incrementalStreaming"
|
||||
ctxKeyApiName = "apiKey"
|
||||
ctxKeyStreamingBody = "streamingBody"
|
||||
ctxKeyOriginalRequestModel = "originalRequestModel"
|
||||
ctxKeyFinalRequestModel = "finalRequestModel"
|
||||
@@ -60,6 +73,13 @@ var (
|
||||
providerTypeBaichuan: &baichuanProviderInitializer{},
|
||||
providerTypeYi: &yiProviderInitializer{},
|
||||
providerTypeDeepSeek: &deepseekProviderInitializer{},
|
||||
providerTypeZhipuAi: &zhipuAiProviderInitializer{},
|
||||
providerTypeOllama: &ollamaProviderInitializer{},
|
||||
providerTypeClaude: &claudeProviderInitializer{},
|
||||
providerTypeBaidu: &baiduProviderInitializer{},
|
||||
providerTypeHunyuan: &hunyuanProviderInitializer{},
|
||||
providerTypeStepfun: &stepfunProviderInitializer{},
|
||||
providerTypeMinimax: &minimaxProviderInitializer{},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -89,7 +109,7 @@ type ResponseBodyHandler interface {
|
||||
|
||||
type ProviderConfig struct {
|
||||
// @Title zh-CN AI服务提供商
|
||||
// @Description zh-CN AI服务提供商类型,目前支持的取值为:"moonshot"、"qwen"、"openai"、"azure"、"baichuan"、"yi"
|
||||
// @Description zh-CN AI服务提供商类型
|
||||
typ string `required:"true" yaml:"type" json:"type"`
|
||||
// @Title zh-CN API Tokens
|
||||
// @Description zh-CN 在请求AI服务时用于认证的API Token列表。不同的AI服务提供商可能有不同的名称。部分供应商只支持配置一个API Token(如Azure OpenAI)。
|
||||
@@ -109,6 +129,21 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN 启用通义千问搜索服务
|
||||
// @Description zh-CN 仅适用于通义千问服务,表示是否启用通义千问的互联网搜索功能。
|
||||
qwenEnableSearch bool `required:"false" yaml:"qwenEnableSearch" json:"qwenEnableSearch"`
|
||||
// @Title zh-CN Ollama Server IP/Domain
|
||||
// @Description zh-CN 仅适用于 Ollama 服务。Ollama 服务器的主机地址。
|
||||
ollamaServerHost string `required:"false" yaml:"ollamaServerHost" json:"ollamaServerHost"`
|
||||
// @Title zh-CN Ollama Server Port
|
||||
// @Description zh-CN 仅适用于 Ollama 服务。Ollama 服务器的端口号。
|
||||
ollamaServerPort uint32 `required:"false" yaml:"ollamaServerPort" json:"ollamaServerPort"`
|
||||
// @Title zh-CN hunyuan api key for authorization
|
||||
// @Description zh-CN 仅适用于Hun Yuan AI服务鉴权,API key/id 参考:https://cloud.tencent.com/document/api/1729/101843#Golang
|
||||
hunyuanAuthKey string `required:"false" yaml:"hunyuanAuthKey" json:"hunyuanAuthKey"`
|
||||
// @Title zh-CN hunyuan api id for authorization
|
||||
// @Description zh-CN 仅适用于Hun Yuan AI服务鉴权
|
||||
hunyuanAuthId string `required:"false" yaml:"hunyuanAuthId" json:"hunyuanAuthId"`
|
||||
// @Title zh-CN minimax group id
|
||||
// @Description zh-CN 仅适用于minimax使用ChatCompletion Pro接口的模型
|
||||
minimaxGroupId string `required:"false" yaml:"minimaxGroupId" json:"minimaxGroupId"`
|
||||
// @Title zh-CN 模型名称映射表
|
||||
// @Description zh-CN 用于将请求中的模型名称映射为目标AI服务商支持的模型名称。支持通过“*”来配置全局映射
|
||||
modelMapping map[string]string `required:"false" yaml:"modelMapping" json:"modelMapping"`
|
||||
@@ -118,6 +153,9 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN 模型对话上下文
|
||||
// @Description zh-CN 配置一个外部获取对话上下文的文件来源,用于在AI请求中补充对话上下文
|
||||
context *ContextConfig `required:"false" yaml:"context" json:"context"`
|
||||
// @Title zh-CN 版本
|
||||
// @Description zh-CN 请求AI服务的版本,目前仅适用于Claude AI服务
|
||||
claudeVersion string `required:"false" yaml:"version" json:"version"`
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
@@ -137,6 +175,8 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.qwenFileIds = append(c.qwenFileIds, fileId.String())
|
||||
}
|
||||
c.qwenEnableSearch = json.Get("qwenEnableSearch").Bool()
|
||||
c.ollamaServerHost = json.Get("ollamaServerHost").String()
|
||||
c.ollamaServerPort = uint32(json.Get("ollamaServerPort").Uint())
|
||||
c.modelMapping = make(map[string]string)
|
||||
for k, v := range json.Get("modelMapping").Map() {
|
||||
c.modelMapping[k] = v.String()
|
||||
@@ -150,6 +190,10 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.context = &ContextConfig{}
|
||||
c.context.FromJson(contextJson)
|
||||
}
|
||||
c.claudeVersion = json.Get("claudeVersion").String()
|
||||
c.hunyuanAuthId = json.Get("hunyuanAuthId").String()
|
||||
c.hunyuanAuthKey = json.Get("hunyuanAuthKey").String()
|
||||
c.minimaxGroupId = json.Get("minimaxGroupId").String()
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) Validate() error {
|
||||
|
||||
@@ -26,6 +26,8 @@ const (
|
||||
qwenTopPMax = 0.999999
|
||||
|
||||
qwenDummySystemMessageContent = "You are a helpful assistant."
|
||||
|
||||
qwenLongModelName = "qwen-long"
|
||||
)
|
||||
|
||||
type qwenProviderInitializer struct {
|
||||
@@ -99,7 +101,7 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
m.insertContextMessage(request, content)
|
||||
m.insertContextMessage(request, content, false)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
@@ -292,7 +294,7 @@ func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletio
|
||||
Tools: origRequest.Tools,
|
||||
},
|
||||
}
|
||||
if len(m.config.qwenFileIds) != 0 {
|
||||
if len(m.config.qwenFileIds) != 0 && origRequest.Model == qwenLongModelName {
|
||||
builder := strings.Builder{}
|
||||
for _, fileId := range m.config.qwenFileIds {
|
||||
if builder.Len() != 0 {
|
||||
@@ -301,7 +303,7 @@ func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletio
|
||||
builder.WriteString("fileid://")
|
||||
builder.WriteString(fileId)
|
||||
}
|
||||
contextMessageId := m.insertContextMessage(request, builder.String())
|
||||
contextMessageId := m.insertContextMessage(request, builder.String(), true)
|
||||
if contextMessageId == 0 {
|
||||
// The context message cannot come first. We need to add another dummy system message before it.
|
||||
request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...)
|
||||
@@ -339,6 +341,7 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
|
||||
Id: qwenResponse.RequestId,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
|
||||
Choices: make([]chatCompletionChoice, 0),
|
||||
SystemFingerprint: "",
|
||||
Object: objectChatCompletionChunk,
|
||||
}
|
||||
@@ -346,14 +349,26 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
|
||||
responses := make([]*chatCompletionResponse, 0)
|
||||
|
||||
qwenChoice := qwenResponse.Output.Choices[0]
|
||||
// Yes, Qwen uses a string "null" as null.
|
||||
finished := qwenChoice.FinishReason != "" && qwenChoice.FinishReason != "null"
|
||||
message := qwenChoice.Message
|
||||
|
||||
deltaMessage := &chatMessage{Role: message.Role, Content: message.Content, ToolCalls: append([]toolCall{}, message.ToolCalls...)}
|
||||
deltaContentMessage := &chatMessage{Role: message.Role, Content: message.Content}
|
||||
deltaToolCallsMessage := &chatMessage{Role: message.Role, ToolCalls: append([]toolCall{}, message.ToolCalls...)}
|
||||
if !incrementalStreaming {
|
||||
for _, tc := range message.ToolCalls {
|
||||
if tc.Function.Arguments == "" && !finished {
|
||||
// We don't push any tool call until its arguments are available.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if pushedMessage, ok := ctx.GetContext(ctxKeyPushedMessage).(qwenMessage); ok {
|
||||
deltaMessage.Content = util.StripPrefix(deltaMessage.Content, pushedMessage.Content)
|
||||
if len(deltaMessage.ToolCalls) > 0 && pushedMessage.ToolCalls != nil {
|
||||
for i, tc := range deltaMessage.ToolCalls {
|
||||
if message.Content == "" {
|
||||
message.Content = pushedMessage.Content
|
||||
}
|
||||
deltaContentMessage.Content = util.StripPrefix(deltaContentMessage.Content, pushedMessage.Content)
|
||||
if len(deltaToolCallsMessage.ToolCalls) > 0 && pushedMessage.ToolCalls != nil {
|
||||
for i, tc := range deltaToolCallsMessage.ToolCalls {
|
||||
if i >= len(pushedMessage.ToolCalls) {
|
||||
break
|
||||
}
|
||||
@@ -361,24 +376,36 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
|
||||
tc.Function.Id = util.StripPrefix(tc.Function.Id, pushedFunction.Id)
|
||||
tc.Function.Name = util.StripPrefix(tc.Function.Name, pushedFunction.Name)
|
||||
tc.Function.Arguments = util.StripPrefix(tc.Function.Arguments, pushedFunction.Arguments)
|
||||
deltaMessage.ToolCalls[i] = tc
|
||||
deltaToolCallsMessage.ToolCalls[i] = tc
|
||||
}
|
||||
}
|
||||
}
|
||||
ctx.SetContext(ctxKeyPushedMessage, message)
|
||||
}
|
||||
|
||||
if !deltaMessage.IsEmpty() {
|
||||
deltaResponse := *&baseMessage
|
||||
deltaResponse.Choices = append(deltaResponse.Choices, chatCompletionChoice{Delta: deltaMessage})
|
||||
responses = append(responses, &deltaResponse)
|
||||
if !deltaContentMessage.IsEmpty() {
|
||||
response := *&baseMessage
|
||||
response.Choices = append(response.Choices, chatCompletionChoice{Delta: deltaContentMessage})
|
||||
responses = append(responses, &response)
|
||||
}
|
||||
if !deltaToolCallsMessage.IsEmpty() {
|
||||
response := *&baseMessage
|
||||
response.Choices = append(response.Choices, chatCompletionChoice{Delta: deltaToolCallsMessage})
|
||||
responses = append(responses, &response)
|
||||
}
|
||||
|
||||
// Yes, Qwen uses a string "null" as null.
|
||||
if qwenChoice.FinishReason != "" && qwenChoice.FinishReason != "null" {
|
||||
if finished {
|
||||
finishResponse := *&baseMessage
|
||||
finishResponse.Choices = append(finishResponse.Choices, chatCompletionChoice{FinishReason: qwenChoice.FinishReason})
|
||||
responses = append(responses, &finishResponse)
|
||||
|
||||
usageResponse := *&baseMessage
|
||||
usageResponse.Usage = chatCompletionUsage{
|
||||
PromptTokens: qwenResponse.Usage.InputTokens,
|
||||
CompletionTokens: qwenResponse.Usage.OutputTokens,
|
||||
TotalTokens: qwenResponse.Usage.TotalTokens,
|
||||
}
|
||||
|
||||
responses = append(responses, &finishResponse, &usageResponse)
|
||||
}
|
||||
|
||||
return responses
|
||||
@@ -417,12 +444,12 @@ func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuild
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content string) int {
|
||||
func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content string, onlyOneSystemBeforeFile bool) int {
|
||||
fileMessage := qwenMessage{
|
||||
Role: roleSystem,
|
||||
Content: content,
|
||||
}
|
||||
firstNonSystemMessageIndex := -1
|
||||
var firstNonSystemMessageIndex int
|
||||
messages := request.Input.Messages
|
||||
if messages != nil {
|
||||
for i, message := range request.Input.Messages {
|
||||
@@ -432,12 +459,22 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content
|
||||
}
|
||||
}
|
||||
}
|
||||
if firstNonSystemMessageIndex == -1 {
|
||||
if firstNonSystemMessageIndex == 0 {
|
||||
request.Input.Messages = append([]qwenMessage{fileMessage}, request.Input.Messages...)
|
||||
return 0
|
||||
} else {
|
||||
} else if !onlyOneSystemBeforeFile {
|
||||
request.Input.Messages = append(request.Input.Messages[:firstNonSystemMessageIndex], append([]qwenMessage{fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)...)
|
||||
return firstNonSystemMessageIndex
|
||||
} else {
|
||||
builder := strings.Builder{}
|
||||
for _, message := range request.Input.Messages[:firstNonSystemMessageIndex] {
|
||||
if builder.Len() != 0 {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
builder.WriteString(message.Content)
|
||||
}
|
||||
request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: builder.String()}, fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -37,14 +37,14 @@ func insertContextMessage(request *chatCompletionRequest, content string) {
|
||||
Role: roleSystem,
|
||||
Content: content,
|
||||
}
|
||||
firstNonSystemMessageIndex := -1
|
||||
var firstNonSystemMessageIndex int
|
||||
for i, message := range request.Messages {
|
||||
if message.Role != roleSystem {
|
||||
firstNonSystemMessageIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if firstNonSystemMessageIndex == -1 {
|
||||
if firstNonSystemMessageIndex == 0 {
|
||||
request.Messages = append([]chatMessage{fileMessage}, request.Messages...)
|
||||
} else {
|
||||
request.Messages = append(request.Messages[:firstNonSystemMessageIndex], append([]chatMessage{fileMessage}, request.Messages[firstNonSystemMessageIndex:]...)...)
|
||||
|
||||
85
plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go
Normal file
85
plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
const (
|
||||
stepfunDomain = "api.stepfun.com"
|
||||
stepfunChatCompletionPath = "/v1/chat/completions"
|
||||
)
|
||||
|
||||
type stepfunProviderInitializer struct {
|
||||
}
|
||||
|
||||
func (m *stepfunProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *stepfunProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
return &stepfunProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type stepfunProvider struct {
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (m *stepfunProvider) GetProviderType() string {
|
||||
return providerTypeStepfun
|
||||
}
|
||||
|
||||
func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestPath(stepfunChatCompletionPath)
|
||||
_ = util.OverwriteRequestHost(stepfunDomain)
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+m.config.GetRandomToken())
|
||||
|
||||
if m.contextCache == nil {
|
||||
ctx.DontReadRequestBody()
|
||||
} else {
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
}
|
||||
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if m.contextCache == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
84
plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go
Normal file
84
plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
const (
|
||||
zhipuAiDomain = "open.bigmodel.cn"
|
||||
zhipuAiChatCompletionPath = "/api/paas/v4/chat/completions"
|
||||
)
|
||||
|
||||
type zhipuAiProviderInitializer struct{}
|
||||
|
||||
func (m *zhipuAiProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *zhipuAiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
return &zhipuAiProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type zhipuAiProvider struct {
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (m *zhipuAiProvider) GetProviderType() string {
|
||||
return providerTypeZhipuAi
|
||||
}
|
||||
|
||||
func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestPath(zhipuAiChatCompletionPath)
|
||||
_ = util.OverwriteRequestHost(zhipuAiDomain)
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+m.config.GetRandomToken())
|
||||
|
||||
if m.contextCache == nil {
|
||||
ctx.DontReadRequestBody()
|
||||
} else {
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
}
|
||||
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if m.contextCache == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
3
plugins/wasm-go/extensions/ai-rag/.gitignore
vendored
Normal file
3
plugins/wasm-go/extensions/ai-rag/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
config.yaml
|
||||
main.wasm
|
||||
tmp/
|
||||
49
plugins/wasm-go/extensions/ai-rag/README.md
Normal file
49
plugins/wasm-go/extensions/ai-rag/README.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# 简介
|
||||
通过对接阿里云向量检索服务实现LLM-RAG,流程如图所示:
|
||||
|
||||

|
||||
|
||||
# 配置说明
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|----------------|-----------------|------|-----|----------------------------------------------------------------------------------|
|
||||
| `dashscope.apiKey` | string | 必填 | - | 用于在访问通义千问服务时进行认证的令牌。 |
|
||||
| `dashscope.serviceName` | string | 必填 | - | 通义千问服务名 |
|
||||
| `dashscope.servicePort` | int | 必填 | - | 通义千问服务端口 |
|
||||
| `dashscope.domain` | string | 必填 | - | 访问通义千问服务时域名 |
|
||||
| `dashvector.apiKey` | string | 必填 | - | 用于在访问阿里云向量检索服务时进行认证的令牌。 |
|
||||
| `dashvector.serviceName` | string | 必填 | - | 阿里云向量检索服务名 |
|
||||
| `dashvector.servicePort` | int | 必填 | - | 阿里云向量检索服务端口 |
|
||||
| `dashvector.domain` | string | 必填 | - | 访问阿里云向量检索服务时域名 |
|
||||
|
||||
# 示例
|
||||
|
||||
```yaml
|
||||
dashscope:
|
||||
apiKey: xxxxxxxxxxxxxxx
|
||||
serviceName: dashscope
|
||||
servicePort: 443
|
||||
domain: dashscope.aliyuncs.com
|
||||
dashvector:
|
||||
apiKey: xxxxxxxxxxxxxxxxxxxx
|
||||
serviceName: dashvector
|
||||
servicePort: 443
|
||||
domain: vrs-cn-xxxxxxxxxxxxxxx.dashvector.cn-hangzhou.aliyuncs.com
|
||||
collection: xxxxxxxxxxxxxxx
|
||||
```
|
||||
|
||||
[CEC-Corpus](https://github.com/shijiebei2009/CEC-Corpus) 数据集包含 332 篇突发事件的新闻报道的语料和标注数据,提取其原始的新闻稿文本,将其向量化后添加到阿里云向量检索服务。文本向量化的教程可以参考[《基于向量检索服务与灵积实现语义搜索》](https://help.aliyun.com/document_detail/2510234.html)。
|
||||
|
||||
以下为使用RAG进行增强的例子,原始请求为:
|
||||
```
|
||||
海南追尾事故,发生在哪里?原因是什么?人员伤亡情况如何?
|
||||
```
|
||||
|
||||
未经过RAG插件处理LLM返回的结果为:
|
||||
```
|
||||
抱歉,作为AI模型,我无法实时获取和更新新闻事件的具体信息,包括地点、原因、人员伤亡等细节。对于此类具体事件,建议您查阅最新的新闻报道或官方通报以获取准确信息。您可以访问主流媒体网站、使用新闻应用或者关注相关政府部门的公告来获取这类动态资讯。
|
||||
```
|
||||
|
||||
经过RAG插件处理后LLM返回的结果为:
|
||||
```
|
||||
海南追尾事故发生在海文高速公路文昌至海口方向37公里处。关于事故的具体原因,交警部门当时仍在进一步调查中,所以根据提供的信息无法确定事故的确切原因。人员伤亡情况是1人死亡(司机当场死亡),另有8人受伤(包括2名儿童和6名成人),所有受伤人员都被解救并送往医院进行治疗。
|
||||
```
|
||||
36
plugins/wasm-go/extensions/ai-rag/dashscope/types.go
Normal file
36
plugins/wasm-go/extensions/ai-rag/dashscope/types.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package dashscope
|
||||
|
||||
// DashScope embedding service: Request
|
||||
type Request struct {
|
||||
Model string `json:"model"`
|
||||
Input Input `json:"input"`
|
||||
Parameter Parameter `json:"parameters"`
|
||||
}
|
||||
|
||||
type Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
}
|
||||
|
||||
type Parameter struct {
|
||||
TextType string `json:"text_type"`
|
||||
}
|
||||
|
||||
// DashScope embedding service: Response
|
||||
type Response struct {
|
||||
Output Output `json:"output"`
|
||||
Usage Usage `json:"usage"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
|
||||
type Output struct {
|
||||
Embeddings []Embedding `json:"embeddings"`
|
||||
}
|
||||
|
||||
type Embedding struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
TextIndex int32 `json:"text_index"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
TotalTokens int32 `json:"total_tokens"`
|
||||
}
|
||||
26
plugins/wasm-go/extensions/ai-rag/dashvector/types.go
Normal file
26
plugins/wasm-go/extensions/ai-rag/dashvector/types.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package dashvector
|
||||
|
||||
// DashVecotor document search: Request
|
||||
type Request struct {
|
||||
TopK int32 `json:"topk"`
|
||||
OutputFileds []string `json:"output_fileds"`
|
||||
Vector []float32 `json:"vector"`
|
||||
}
|
||||
|
||||
// DashVecotor document search: Response
|
||||
type Response struct {
|
||||
Code int32 `json:"code"`
|
||||
RequestID string `json:"request_id"`
|
||||
Message string `json:"message"`
|
||||
Output []OutputObject `json:"output"`
|
||||
}
|
||||
|
||||
type OutputObject struct {
|
||||
ID string `json:"id"`
|
||||
Fields FieldObject `json:"fields"`
|
||||
Score float32 `json:"score"`
|
||||
}
|
||||
|
||||
type FieldObject struct {
|
||||
Raw string `json:"raw"`
|
||||
}
|
||||
18
plugins/wasm-go/extensions/ai-rag/go.mod
Normal file
18
plugins/wasm-go/extensions/ai-rag/go.mod
Normal file
@@ -0,0 +1,18 @@
|
||||
module ai-rag
|
||||
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a
|
||||
github.com/tidwall/gjson v1.14.3
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
|
||||
github.com/magefile/mage v1.14.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
)
|
||||
22
plugins/wasm-go/extensions/ai-rag/go.sum
Normal file
22
plugins/wasm-go/extensions/ai-rag/go.sum
Normal file
@@ -0,0 +1,22 @@
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
126
plugins/wasm-go/extensions/ai-rag/main.go
Normal file
126
plugins/wasm-go/extensions/ai-rag/main.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"ai-rag/dashscope"
|
||||
"ai-rag/dashvector"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func main() {
|
||||
wrapper.SetCtx(
|
||||
"ai-rag",
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||
)
|
||||
}
|
||||
|
||||
type AIRagConfig struct {
|
||||
DashScopeClient wrapper.HttpClient
|
||||
DashScopeAPIKey string
|
||||
DashVectorClient wrapper.HttpClient
|
||||
DashVectorAPIKey string
|
||||
DashVectorCollection string
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty"`
|
||||
PresencePenalty float64 `json:"presence_penalty"`
|
||||
Stream bool `json:"stream"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
Topp int32 `json:"top_p"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *AIRagConfig, log wrapper.Log) error {
|
||||
config.DashScopeAPIKey = json.Get("dashscope.apiKey").String()
|
||||
|
||||
config.DashScopeClient = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
ServiceName: json.Get("dashscope.serviceName").String(),
|
||||
Port: json.Get("dashscope.servicePort").Int(),
|
||||
Domain: json.Get("dashscope.domain").String(),
|
||||
})
|
||||
config.DashVectorAPIKey = json.Get("dashvector.apiKey").String()
|
||||
config.DashVectorCollection = json.Get("dashvector.collection").String()
|
||||
config.DashVectorClient = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
ServiceName: json.Get("dashvector.serviceName").String(),
|
||||
Port: json.Get("dashvector.servicePort").Int(),
|
||||
Domain: json.Get("dashvector.domain").String(),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
|
||||
p, _ := proxywasm.GetHttpRequestHeader(":path")
|
||||
if p != "/api/openai/v1/chat/completions" {
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
proxywasm.RemoveHttpRequestHeader("content-length")
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte, log wrapper.Log) types.Action {
|
||||
var rawRequest Request
|
||||
_ = json.Unmarshal(body, &rawRequest)
|
||||
messageLength := len(rawRequest.Messages)
|
||||
rawContent := rawRequest.Messages[messageLength-1].Content
|
||||
requestEmbedding := dashscope.Request{
|
||||
Model: "text-embedding-v1",
|
||||
Input: dashscope.Input{
|
||||
Texts: []string{rawContent},
|
||||
},
|
||||
Parameter: dashscope.Parameter{
|
||||
TextType: "query",
|
||||
},
|
||||
}
|
||||
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.DashScopeAPIKey}}
|
||||
reqEmbeddingSerialized, _ := json.Marshal(requestEmbedding)
|
||||
// log.Info(string(reqEmbeddingSerialized))
|
||||
config.DashScopeClient.Post(
|
||||
"/api/v1/services/embeddings/text-embedding/text-embedding",
|
||||
headers,
|
||||
reqEmbeddingSerialized,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
var responseEmbedding dashscope.Response
|
||||
_ = json.Unmarshal(responseBody, &responseEmbedding)
|
||||
requestQuery := dashvector.Request{
|
||||
TopK: 1,
|
||||
OutputFileds: []string{"raw"},
|
||||
Vector: responseEmbedding.Output.Embeddings[0].Embedding,
|
||||
}
|
||||
requestQuerySerialized, _ := json.Marshal(requestQuery)
|
||||
config.DashVectorClient.Post(
|
||||
fmt.Sprintf("/v1/collections/%s/query", config.DashVectorCollection),
|
||||
[][2]string{{"Content-Type", "application/json"}, {"dashvector-auth-token", config.DashVectorAPIKey}},
|
||||
requestQuerySerialized,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
var response dashvector.Response
|
||||
_ = json.Unmarshal(responseBody, &response)
|
||||
doc := response.Output[0].Fields.Raw
|
||||
rawRequest.Messages[messageLength-1].Content = fmt.Sprintf("%s\n以上是一些可能有帮助的参考信息,你可以自行选择是否使用这些参考信息,现在请回答以下问题:\n%s", doc, rawContent)
|
||||
newBody, _ := json.Marshal(rawRequest)
|
||||
// log.Info(string(newBody))
|
||||
proxywasm.ReplaceHttpRequestBody(newBody)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
},
|
||||
)
|
||||
},
|
||||
50000,
|
||||
)
|
||||
return types.ActionPause
|
||||
}
|
||||
4
plugins/wasm-go/extensions/ai-security-guard/.gitignore
vendored
Normal file
4
plugins/wasm-go/extensions/ai-security-guard/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
main.wasm
|
||||
v1/
|
||||
v2/
|
||||
config.yaml
|
||||
22
plugins/wasm-go/extensions/ai-security-guard/README.md
Normal file
22
plugins/wasm-go/extensions/ai-security-guard/README.md
Normal file
@@ -0,0 +1,22 @@
|
||||
# 简介
|
||||
|
||||
# 配置说明
|
||||
| Name | Type | Requirement | Default | Description |
|
||||
| :-: | :-: | :-: | :-: | :-: |
|
||||
| serviceSource | string | requried | - | 服务来源,填dns |
|
||||
| serviceName | string | requried | - | 服务名 |
|
||||
| servicePort | string | requried | - | 服务端口 |
|
||||
| domain | string | requried | - | 阿里云内容安全endpoint |
|
||||
| ak | string | requried | - | 阿里云AK |
|
||||
| sk | string | requried | - | 阿里云SK |
|
||||
|
||||
|
||||
# 配置示例
|
||||
```yaml
|
||||
serviceSource: "dns"
|
||||
serviceName: "safecheck"
|
||||
servicePort: 443
|
||||
domain: "green-cip.cn-shanghai.aliyuncs.com"
|
||||
ak: "XXXXXXXXX"
|
||||
sk: "XXXXXXXXXXXXXXX"
|
||||
```
|
||||
18
plugins/wasm-go/extensions/ai-security-guard/go.mod
Normal file
18
plugins/wasm-go/extensions/ai-security-guard/go.mod
Normal file
@@ -0,0 +1,18 @@
|
||||
module myplugin
|
||||
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc
|
||||
github.com/tidwall/gjson v1.14.3
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
|
||||
github.com/magefile/mage v1.14.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
)
|
||||
24
plugins/wasm-go/extensions/ai-security-guard/go.sum
Normal file
24
plugins/wasm-go/extensions/ai-security-guard/go.sum
Normal file
@@ -0,0 +1,24 @@
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906 h1:RhEmB+ApLKsClZD7joTC4ifmsVgOVz4pFLdPR3xhNaE=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906/go.mod h1:10jQXKsYFUF7djs+Oy7t82f4dbie9pISfP9FJwpPLuk=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
272
plugins/wasm-go/extensions/ai-security-guard/main.go
Normal file
272
plugins/wasm-go/extensions/ai-security-guard/main.go
Normal file
@@ -0,0 +1,272 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func main() {
|
||||
wrapper.SetCtx(
|
||||
"ai-security-guard",
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
||||
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
|
||||
)
|
||||
}
|
||||
|
||||
type AISecurityConfig struct {
|
||||
client wrapper.HttpClient
|
||||
ak string
|
||||
sk string
|
||||
}
|
||||
|
||||
type StandardResponse struct {
|
||||
Code int `json:"Code"`
|
||||
Phase string `json:"BlockPhase"`
|
||||
Message string `json:"Message"`
|
||||
}
|
||||
|
||||
func urlEncoding(rawStr string) string {
|
||||
encodedStr := url.PathEscape(rawStr)
|
||||
encodedStr = strings.ReplaceAll(encodedStr, "+", "%20")
|
||||
encodedStr = strings.ReplaceAll(encodedStr, ":", "%3A")
|
||||
encodedStr = strings.ReplaceAll(encodedStr, "=", "%3D")
|
||||
encodedStr = strings.ReplaceAll(encodedStr, "&", "%26")
|
||||
return encodedStr
|
||||
}
|
||||
|
||||
func hmacSha1(message, secret string) string {
|
||||
key := []byte(secret)
|
||||
h := hmac.New(sha1.New, key)
|
||||
h.Write([]byte(message))
|
||||
hash := h.Sum(nil)
|
||||
return base64.StdEncoding.EncodeToString(hash)
|
||||
}
|
||||
|
||||
func getSign(params map[string]string, secret string) string {
|
||||
paramArray := []string{}
|
||||
for k, v := range params {
|
||||
paramArray = append(paramArray, urlEncoding(k)+"="+urlEncoding(v))
|
||||
}
|
||||
sort.Slice(paramArray, func(i, j int) bool {
|
||||
return paramArray[i] <= paramArray[j]
|
||||
})
|
||||
canonicalStr := strings.Join(paramArray, "&")
|
||||
signStr := "POST&%2F&" + urlEncoding(canonicalStr)
|
||||
fmt.Println(signStr)
|
||||
return hmacSha1(signStr, secret)
|
||||
}
|
||||
|
||||
func generateHexID(length int) (string, error) {
|
||||
bytes := make([]byte, length/2)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) error {
|
||||
serviceName := json.Get("serviceName").String()
|
||||
servicePort := json.Get("servicePort").Int()
|
||||
domain := json.Get("domain").String()
|
||||
config.ak = json.Get("ak").String()
|
||||
config.sk = json.Get("sk").String()
|
||||
if serviceName == "" || servicePort == 0 || domain == "" {
|
||||
return errors.New("invalid service config")
|
||||
}
|
||||
config.client = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
ServiceName: serviceName,
|
||||
Port: servicePort,
|
||||
Domain: domain,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
||||
messages := gjson.GetBytes(body, "messages").Array()
|
||||
if len(messages) > 0 {
|
||||
role := messages[len(messages)-1].Get("role").String()
|
||||
content := messages[len(messages)-1].Get("content").String()
|
||||
if role != "user" {
|
||||
return types.ActionContinue
|
||||
}
|
||||
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
randomID, _ := generateHexID(16)
|
||||
params := map[string]string{
|
||||
"Format": "JSON",
|
||||
"Version": "2022-03-02",
|
||||
"SignatureMethod": "Hmac-SHA1",
|
||||
"SignatureNonce": randomID,
|
||||
"SignatureVersion": "1.0",
|
||||
"Action": "TextModerationPlus",
|
||||
"AccessKeyId": config.ak,
|
||||
"Timestamp": timestamp,
|
||||
"Service": "llm_query_moderation",
|
||||
"ServiceParameters": `{"content": "` + content + `"}`,
|
||||
}
|
||||
signature := getSign(params, config.sk+"&")
|
||||
reqParams := url.Values{}
|
||||
for k, v := range params {
|
||||
reqParams.Add(k, v)
|
||||
}
|
||||
reqParams.Add("Signature", signature)
|
||||
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
respData := gjson.GetBytes(responseBody, "Data")
|
||||
if respData.Exists() {
|
||||
respAdvice := respData.Get("Advice")
|
||||
respResult := respData.Get("Result")
|
||||
if respAdvice.Exists() {
|
||||
sr := StandardResponse{
|
||||
Code: 403,
|
||||
Phase: "Request",
|
||||
Message: respAdvice.Array()[0].Get("Answer").String(),
|
||||
}
|
||||
jsonData, _ := json.MarshalIndent(sr, "", " ")
|
||||
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
} else if respResult.Array()[0].Get("Label").String() != "nonLabel" {
|
||||
sr := StandardResponse{
|
||||
Code: 403,
|
||||
Phase: "Request",
|
||||
Message: "risk detected",
|
||||
}
|
||||
jsonData, _ := json.MarshalIndent(sr, "", " ")
|
||||
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
},
|
||||
)
|
||||
return types.ActionPause
|
||||
} else {
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
|
||||
func convertHeaders(hs [][2]string) map[string][]string {
|
||||
ret := make(map[string][]string)
|
||||
for _, h := range hs {
|
||||
k, v := strings.ToLower(h[0]), h[1]
|
||||
ret[k] = append(ret[k], v)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// headers: map[string][]string -> [][2]string
|
||||
func reconvertHeaders(hs map[string][]string) [][2]string {
|
||||
var ret [][2]string
|
||||
for k, vs := range hs {
|
||||
for _, v := range vs {
|
||||
ret = append(ret, [2]string{k, v})
|
||||
}
|
||||
}
|
||||
sort.SliceStable(ret, func(i, j int) bool {
|
||||
return ret[i][0] < ret[j][0]
|
||||
})
|
||||
return ret
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
|
||||
headers, err := proxywasm.GetHttpResponseHeaders()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get response headers: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
hdsMap := convertHeaders(headers)
|
||||
ctx.SetContext("headers", hdsMap)
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
||||
messages := gjson.GetBytes(body, "choices").Array()
|
||||
if len(messages) > 0 {
|
||||
content := messages[0].Get("message").Get("content").String()
|
||||
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
randomID, _ := generateHexID(16)
|
||||
params := map[string]string{
|
||||
"Format": "JSON",
|
||||
"Version": "2022-03-02",
|
||||
"SignatureMethod": "Hmac-SHA1",
|
||||
"SignatureNonce": randomID,
|
||||
"SignatureVersion": "1.0",
|
||||
"Action": "TextModerationPlus",
|
||||
"AccessKeyId": config.ak,
|
||||
"Timestamp": timestamp,
|
||||
"Service": "llm_response_moderation",
|
||||
"ServiceParameters": `{"content": "` + content + `"}`,
|
||||
}
|
||||
signature := getSign(params, config.sk+"&")
|
||||
reqParams := url.Values{}
|
||||
for k, v := range params {
|
||||
reqParams.Add(k, v)
|
||||
}
|
||||
reqParams.Add("Signature", signature)
|
||||
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
defer proxywasm.ResumeHttpResponse()
|
||||
respData := gjson.GetBytes(responseBody, "Data")
|
||||
if respData.Exists() {
|
||||
respAdvice := respData.Get("Advice")
|
||||
respResult := respData.Get("Result")
|
||||
if respAdvice.Exists() {
|
||||
sr := StandardResponse{
|
||||
Code: 403,
|
||||
Phase: "Response",
|
||||
Message: respAdvice.Array()[0].Get("Answer").String(),
|
||||
}
|
||||
jsonData, _ := json.MarshalIndent(sr, "", " ")
|
||||
hdsMap := ctx.GetContext("headers").(map[string][]string)
|
||||
delete(hdsMap, "content-length")
|
||||
hdsMap[":status"] = []string{"403"}
|
||||
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
|
||||
proxywasm.ReplaceHttpResponseBody(jsonData)
|
||||
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||
} else if respResult.Array()[0].Get("Label").String() != "nonLabel" {
|
||||
sr := StandardResponse{
|
||||
Code: 403,
|
||||
Phase: "Response",
|
||||
Message: "risk detected",
|
||||
}
|
||||
jsonData, _ := json.MarshalIndent(sr, "", " ")
|
||||
hdsMap := ctx.GetContext("headers").(map[string][]string)
|
||||
delete(hdsMap, "content-length")
|
||||
hdsMap[":status"] = []string{"403"}
|
||||
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
|
||||
proxywasm.ReplaceHttpResponseBody(jsonData)
|
||||
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
return types.ActionPause
|
||||
} else {
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
2
plugins/wasm-go/extensions/ai-statistics/.gitignore
vendored
Normal file
2
plugins/wasm-go/extensions/ai-statistics/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
main.wasm
|
||||
config.yaml
|
||||
44
plugins/wasm-go/extensions/ai-statistics/README.md
Normal file
44
plugins/wasm-go/extensions/ai-statistics/README.md
Normal file
@@ -0,0 +1,44 @@
|
||||
# 介绍
|
||||
提供AI可观测基础能力,其后需接ai-proxy插件,如果不接ai-proxy插件的话,则只支持openai协议。
|
||||
|
||||
# 配置说明
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|------------|--------|------|-----|------------------|
|
||||
| `enable` | bool | 必填 | - | 是否开启ai统计功能 |
|
||||
|
||||
开启后 metrics 示例:
|
||||
```
|
||||
route_upstream_model_input_token{ai_route="openai",ai_cluster="qwen",ai_model="qwen-max"} 21
|
||||
route_upstream_model_output_token{ai_route="openai",ai_cluster="qwen",ai_model="qwen-max"} 17
|
||||
```
|
||||
|
||||
日志示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-max",
|
||||
"input_token": "21",
|
||||
"output_token": "17",
|
||||
"authority": "dashscope.aliyuncs.com",
|
||||
"bytes_received": "336",
|
||||
"bytes_sent": "1675",
|
||||
"duration": "1590",
|
||||
"istio_policy_status": "-",
|
||||
"method": "POST",
|
||||
"path": "/v1/chat/completions",
|
||||
"protocol": "HTTP/1.1",
|
||||
"request_id": "5895f5a9-e4e3-425b-98db-6c6a926195b7",
|
||||
"requested_server_name": "-",
|
||||
"response_code": "200",
|
||||
"response_flags": "-",
|
||||
"route_name": "openai",
|
||||
"start_time": "2024-06-18T09:37:14.078Z",
|
||||
"trace_id": "-",
|
||||
"upstream_cluster": "qwen",
|
||||
"upstream_service_time": "496",
|
||||
"upstream_transport_failure_reason": "-",
|
||||
"user_agent": "PostmanRuntime/7.37.3",
|
||||
"x_forwarded_for": "-"
|
||||
}
|
||||
```
|
||||
21
plugins/wasm-go/extensions/ai-statistics/go.mod
Normal file
21
plugins/wasm-go/extensions/ai-statistics/go.mod
Normal file
@@ -0,0 +1,21 @@
|
||||
module ai-statistics
|
||||
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc
|
||||
github.com/tidwall/gjson v1.14.3
|
||||
)
|
||||
|
||||
require github.com/tetratelabs/wazero v1.7.1 // indirect
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
|
||||
github.com/magefile/mage v1.14.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
github.com/wasilibs/go-re2 v1.5.3
|
||||
)
|
||||
26
plugins/wasm-go/extensions/ai-statistics/go.sum
Normal file
26
plugins/wasm-go/extensions/ai-statistics/go.sum
Normal file
@@ -0,0 +1,26 @@
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906 h1:RhEmB+ApLKsClZD7joTC4ifmsVgOVz4pFLdPR3xhNaE=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906/go.mod h1:10jQXKsYFUF7djs+Oy7t82f4dbie9pISfP9FJwpPLuk=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tetratelabs/wazero v1.7.1 h1:QtSfd6KLc41DIMpDYlJdoMc6k7QTN246DM2+n2Y/Dx8=
|
||||
github.com/tetratelabs/wazero v1.7.1/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/wasilibs/go-re2 v1.5.3 h1:wiuTcgDZdLhu8NG8oqF5sF5Q3yIU14lPAvXqeYzDK3g=
|
||||
github.com/wasilibs/go-re2 v1.5.3/go.mod h1:PzpVPsBdFC7vM8QJbbEnOeTmwA0DGE783d/Gex8eCV8=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
143
plugins/wasm-go/extensions/ai-statistics/main.go
Normal file
143
plugins/wasm-go/extensions/ai-statistics/main.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func main() {
|
||||
wrapper.SetCtx(
|
||||
"ai-statistics",
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
||||
wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingBody),
|
||||
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
|
||||
)
|
||||
}
|
||||
|
||||
type AIStatisticsConfig struct {
|
||||
enable bool
|
||||
metrics map[string]proxywasm.MetricCounter
|
||||
}
|
||||
|
||||
func (config *AIStatisticsConfig) incrementCounter(metricName string, inc uint64, log wrapper.Log) {
|
||||
counter, ok := config.metrics[metricName]
|
||||
if !ok {
|
||||
counter = proxywasm.DefineCounterMetric(metricName)
|
||||
config.metrics[metricName] = counter
|
||||
}
|
||||
counter.Increment(inc)
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *AIStatisticsConfig, log wrapper.Log) error {
|
||||
config.enable = json.Get("enable").Bool()
|
||||
config.metrics = make(map[string]proxywasm.MetricCounter)
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) types.Action {
|
||||
if !config.enable {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
if !strings.Contains(contentType, "text/event-stream") {
|
||||
ctx.BufferResponseBody()
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func getLastChunk(data []byte) []byte {
|
||||
chunks := strings.Split(strings.TrimSpace(string(data)), "\n\n")
|
||||
length := len(chunks)
|
||||
if length < 2 {
|
||||
return data
|
||||
}
|
||||
// ai-proxy append extra usage chunk
|
||||
return []byte(chunks[length-1])
|
||||
}
|
||||
|
||||
func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, data []byte, endOfStream bool, log wrapper.Log) []byte {
|
||||
lastChunk := getLastChunk(data)
|
||||
modelObj := gjson.GetBytes(lastChunk, "model")
|
||||
inputTokenObj := gjson.GetBytes(lastChunk, "usage.prompt_tokens")
|
||||
outputTokenObj := gjson.GetBytes(lastChunk, "usage.completion_tokens")
|
||||
if modelObj.Exists() && inputTokenObj.Exists() && outputTokenObj.Exists() {
|
||||
ctx.SetContext("model", modelObj.String())
|
||||
ctx.SetContext("input_token", inputTokenObj.Int())
|
||||
ctx.SetContext("output_token", outputTokenObj.Int())
|
||||
}
|
||||
|
||||
if endOfStream {
|
||||
var route, cluster string
|
||||
if raw, err := proxywasm.GetProperty([]string{"route_name"}); err == nil {
|
||||
route = string(raw)
|
||||
}
|
||||
if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err == nil {
|
||||
cluster = string(raw)
|
||||
}
|
||||
model, ok := ctx.GetContext("model").(string)
|
||||
if !ok {
|
||||
log.Error("Get model failed!")
|
||||
return data
|
||||
}
|
||||
inputToken, ok := ctx.GetContext("input_token").(int64)
|
||||
if !ok {
|
||||
log.Error("Get input_token failed!")
|
||||
return data
|
||||
}
|
||||
outputToken, ok := ctx.GetContext("output_token").(int64)
|
||||
if !ok {
|
||||
log.Error("Get output_token failed!")
|
||||
return data
|
||||
}
|
||||
config.incrementCounter("route."+route+".upstream."+cluster+".model."+model+".input_token", uint64(inputToken), log)
|
||||
config.incrementCounter("route."+route+".upstream."+cluster+".model."+model+".output_token", uint64(outputToken), log)
|
||||
proxywasm.SetProperty([]string{"model"}, []byte(model))
|
||||
proxywasm.SetProperty([]string{"input_token"}, []byte(fmt.Sprint(inputToken)))
|
||||
proxywasm.SetProperty([]string{"output_token"}, []byte(fmt.Sprint(outputToken)))
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action {
|
||||
modeObj := gjson.GetBytes(body, "model")
|
||||
inputTokenObj := gjson.GetBytes(body, "usage.prompt_tokens")
|
||||
outputTokenObj := gjson.GetBytes(body, "usage.completion_tokens")
|
||||
if !modeObj.Exists() {
|
||||
log.Error("Get model failed")
|
||||
return types.ActionContinue
|
||||
}
|
||||
if !inputTokenObj.Exists() {
|
||||
log.Error("Get input_token failed")
|
||||
return types.ActionContinue
|
||||
}
|
||||
if !outputTokenObj.Exists() {
|
||||
log.Error("Get output_token failed")
|
||||
return types.ActionContinue
|
||||
}
|
||||
model := modeObj.String()
|
||||
inputToken := inputTokenObj.Int()
|
||||
outputToken := outputTokenObj.Int()
|
||||
var route, cluster string
|
||||
if raw, err := proxywasm.GetProperty([]string{"route_name"}); err == nil {
|
||||
route = string(raw)
|
||||
}
|
||||
if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err == nil {
|
||||
cluster = string(raw)
|
||||
}
|
||||
config.incrementCounter("route."+route+".upstream."+cluster+".model."+model+".input_token", uint64(inputToken), log)
|
||||
config.incrementCounter("route."+route+".upstream."+cluster+".model."+model+".output_token", uint64(outputToken), log)
|
||||
|
||||
proxywasm.SetProperty([]string{"model"}, []byte(model))
|
||||
proxywasm.SetProperty([]string{"input_token"}, []byte(fmt.Sprint(inputToken)))
|
||||
proxywasm.SetProperty([]string{"output_token"}, []byte(fmt.Sprint(outputToken)))
|
||||
|
||||
return types.ActionContinue
|
||||
}
|
||||
2
plugins/wasm-go/extensions/ai-token-ratelimit/.gitignore
vendored
Normal file
2
plugins/wasm-go/extensions/ai-token-ratelimit/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
main.wasm
|
||||
config.yaml
|
||||
186
plugins/wasm-go/extensions/ai-token-ratelimit/README.md
Normal file
186
plugins/wasm-go/extensions/ai-token-ratelimit/README.md
Normal file
@@ -0,0 +1,186 @@
|
||||
# 功能说明
|
||||
|
||||
`ai-token-ratelimit`插件实现了基于特定键值实现token限流,键值来源可以是 URL 参数、HTTP 请求头、客户端 IP 地址、consumer 名称、cookie中 key 名称
|
||||
|
||||
|
||||
|
||||
# 配置说明
|
||||
|
||||
| 配置项 | 类型 | 必填 | 默认值 | 说明 |
|
||||
| ----------------------- | ------ | ---- | ------ |---------------------------------------------------------------------------|
|
||||
| rule_name | string | 是 | - | 限流规则名称,根据限流规则名称+限流类型+限流key名称+限流key对应的实际值来拼装redis key |
|
||||
| rule_items | array of object | 是 | - | 限流规则项,按照rule_items下的排列顺序,匹配第一个rule_item后命中限流规则,后续规则将被忽略 |
|
||||
| rejected_code | int | 否 | 429 | 请求被限流时,返回的HTTP状态码 |
|
||||
| rejected_msg | string | 否 | Too many requests | 请求被限流时,返回的响应体 |
|
||||
| redis | object | 是 | - | redis相关配置 |
|
||||
|
||||
`rule_items`中每一项的配置字段说明
|
||||
|
||||
| 配置项 | 类型 | 必填 | 默认值 | 说明 |
|
||||
| --------------------- | --------------- | -------------------------- | ------ | ------------------------------------------------------------ |
|
||||
| limit_by_header | string | 否,`limit_by_*`中选填一项 | - | 配置获取限流键值的来源 HTTP 请求头名称 |
|
||||
| limit_by_param | string | 否,`limit_by_*`中选填一项 | - | 配置获取限流键值的来源 URL 参数名称 |
|
||||
| limit_by_consumer | string | 否,`limit_by_*`中选填一项 | - | 根据 consumer 名称进行限流,无需添加实际值 |
|
||||
| limit_by_cookie | string | 否,`limit_by_*`中选填一项 | - | 配置获取限流键值的来源 Cookie中 key 名称 |
|
||||
| limit_by_per_header | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 HTTP 请求头,并对每个请求头分别计算限流,配置获取限流键值的来源 HTTP 请求头名称,配置`limit_keys`时支持正则表达式或`*` |
|
||||
| limit_by_per_param | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 URL 参数,并对每个参数分别计算限流,配置获取限流键值的来源 URL 参数名称,配置`limit_keys`时支持正则表达式或`*` |
|
||||
| limit_by_per_consumer | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 consumer,并对每个 consumer 分别计算限流,根据 consumer 名称进行限流,无需添加实际值,配置`limit_keys`时支持正则表达式或`*` |
|
||||
| limit_by_per_cookie | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 Cookie,并对每个 Cookie 分别计算限流,配置获取限流键值的来源 Cookie中 key 名称,配置`limit_keys`时支持正则表达式或`*` |
|
||||
| limit_by_per_ip | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 IP,并对每个 IP 分别计算限流,配置获取限流键值的来源 IP 参数名称,从请求头获取,以`from-header-对应的header名`,示例:`from-header-x-forwarded-for`,直接获取对端socket ip,配置为`from-remote-addr` |
|
||||
| limit_keys | array of object | 是 | - | 配置匹配键值后的限流次数 |
|
||||
|
||||
`limit_keys`中每一项的配置字段说明
|
||||
|
||||
| 配置项 | 类型 | 必填 | 默认值 | 说明 |
|
||||
| ---------------- | ------ | ------------------------------------------------------------ | ------ | ------------------------------------------------------------ |
|
||||
| key | string | 是 | - | 匹配的键值,`limit_by_per_header`,`limit_by_per_param`,`limit_by_per_consumer`,`limit_by_per_cookie` 类型支持配置正则表达式(以regexp:开头后面跟正则表达式)或者*(代表所有),正则表达式示例:`regexp:^d.*`(以d开头的所有字符串);`limit_by_per_ip`支持配置 IP 地址或 IP 段 |
|
||||
| token_per_second | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每秒请求token数 |
|
||||
| token_per_minute | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每分钟请求token数 |
|
||||
| token_per_hour | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每小时请求token数 |
|
||||
| token_per_day | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每天请求token数 |
|
||||
|
||||
`redis`中每一项的配置字段说明
|
||||
|
||||
| 配置项 | 类型 | 必填 | 默认值 | 说明 |
|
||||
| ------------ | ------ | ---- | ---------------------------------------------------------- | --------------------------- |
|
||||
| service_name | string | 必填 | - | redis 服务名称,带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local |
|
||||
| service_port | int | 否 | 服务类型为固定地址(static service)默认值为80,其他为6379 | 输入redis服务的服务端口 |
|
||||
| username | string | 否 | - | redis用户名 |
|
||||
| password | string | 否 | - | redis密码 |
|
||||
| timeout | int | 否 | 1000 | redis连接超时时间,单位毫秒 |
|
||||
|
||||
|
||||
|
||||
# 配置示例
|
||||
|
||||
## 识别请求参数 apikey,进行区别限流
|
||||
|
||||
```yaml
|
||||
rule_name: default_rule
|
||||
rule_items:
|
||||
- limit_by_param: apikey
|
||||
limit_keys:
|
||||
- key: 9a342114-ba8a-11ec-b1bf-00163e1250b5
|
||||
token_per_minute: 10
|
||||
- key: a6a6d7f2-ba8a-11ec-bec2-00163e1250b5
|
||||
token_per_hour: 100
|
||||
- limit_by_per_param: apikey
|
||||
limit_keys:
|
||||
# 正则表达式,匹配以a开头的所有字符串,每个apikey对应的请求10qds
|
||||
- key: "regexp:^a.*"
|
||||
token_per_second: 10
|
||||
# 正则表达式,匹配以b开头的所有字符串,每个apikey对应的请求100qd
|
||||
- key: "regexp:^b.*"
|
||||
token_per_minute: 100
|
||||
# 兜底用,匹配所有请求,每个apikey对应的请求1000qdh
|
||||
- key: "*"
|
||||
token_per_hour: 1000
|
||||
redis:
|
||||
service_name: redis.static
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 识别请求头 x-ca-key,进行区别限流
|
||||
|
||||
```yaml
|
||||
rule_name: default_rule
|
||||
rule_items:
|
||||
- limit_by_header: x-ca-key
|
||||
limit_keys:
|
||||
- key: 102234
|
||||
token_per_minute: 10
|
||||
- key: 308239
|
||||
token_per_hour: 10
|
||||
- limit_by_per_header: x-ca-key
|
||||
limit_keys:
|
||||
# 正则表达式,匹配以a开头的所有字符串,每个apikey对应的请求10qds
|
||||
- key: "regexp:^a.*"
|
||||
token_per_second: 10
|
||||
# 正则表达式,匹配以b开头的所有字符串,每个apikey对应的请求100qd
|
||||
- key: "regexp:^b.*"
|
||||
token_per_minute: 100
|
||||
# 兜底用,匹配所有请求,每个apikey对应的请求1000qdh
|
||||
- key: "*"
|
||||
token_per_hour: 1000
|
||||
redis:
|
||||
service_name: redis.static
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 根据请求头 x-forwarded-for 获取对端IP,进行区别限流
|
||||
|
||||
```yaml
|
||||
rule_name: default_rule
|
||||
rule_items:
|
||||
- limit_by_per_ip: from-header-x-forwarded-for
|
||||
limit_keys:
|
||||
# 精确ip
|
||||
- key: 1.1.1.1
|
||||
token_per_day: 10
|
||||
# ip段,符合这个ip段的ip,每个ip 100qpd
|
||||
- key: 1.1.1.0/24
|
||||
token_per_day: 100
|
||||
# 兜底用,即默认每个ip 1000qpd
|
||||
- key: 0.0.0.0/0
|
||||
token_per_day: 1000
|
||||
redis:
|
||||
service_name: redis.static
|
||||
```
|
||||
|
||||
## 识别consumer,进行区别限流
|
||||
|
||||
```yaml
|
||||
rule_name: default_rule
|
||||
rule_items:
|
||||
- limit_by_consumer: ''
|
||||
limit_keys:
|
||||
- key: consumer1
|
||||
token_per_second: 10
|
||||
- key: consumer2
|
||||
token_per_hour: 100
|
||||
- limit_by_per_consumer: ''
|
||||
limit_keys:
|
||||
# 正则表达式,匹配以a开头的所有字符串,每个consumer对应的请求10qds
|
||||
- key: "regexp:^a.*"
|
||||
token_per_second: 10
|
||||
# 正则表达式,匹配以b开头的所有字符串,每个consumer对应的请求100qd
|
||||
- key: "regexp:^b.*"
|
||||
token_per_minute: 100
|
||||
# 兜底用,匹配所有请求,每个consumer对应的请求1000qdh
|
||||
- key: "*"
|
||||
token_per_hour: 1000
|
||||
redis:
|
||||
service_name: redis.static
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 识别cookie中的键值对,进行区别限流
|
||||
|
||||
```yaml
|
||||
rule_name: default_rule
|
||||
rule_items:
|
||||
- limit_by_cookie: key1
|
||||
limit_keys:
|
||||
- key: value1
|
||||
token_per_minute: 10
|
||||
- key: value2
|
||||
token_per_hour: 100
|
||||
- limit_by_per_cookie: key1
|
||||
limit_keys:
|
||||
# 正则表达式,匹配以a开头的所有字符串,每个cookie中的value对应的请求10qds
|
||||
- key: "regexp:^a.*"
|
||||
token_per_second: 10
|
||||
# 正则表达式,匹配以b开头的所有字符串,每个cookie中的value对应的请求100qd
|
||||
- key: "regexp:^b.*"
|
||||
token_per_minute: 100
|
||||
# 兜底用,匹配所有请求,每个cookie中的value对应的请求1000qdh
|
||||
- key: "*"
|
||||
token_per_hour: 1000
|
||||
rejected_code: 200
|
||||
rejected_msg: '{"code":-1,"msg":"Too many requests"}'
|
||||
redis:
|
||||
service_name: redis.static
|
||||
```
|
||||
297
plugins/wasm-go/extensions/ai-token-ratelimit/config.go
Normal file
297
plugins/wasm-go/extensions/ai-token-ratelimit/config.go
Normal file
@@ -0,0 +1,297 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
re "github.com/wasilibs/go-re2"
|
||||
"github.com/zmap/go-iptree/iptree"
|
||||
)
|
||||
|
||||
// 限流规则项类型
|
||||
type limitRuleItemType string
|
||||
|
||||
// 限流配置项key类型
|
||||
type limitConfigItemType string
|
||||
|
||||
const (
|
||||
limitByHeaderType limitRuleItemType = "limit_by_header"
|
||||
limitByParamType limitRuleItemType = "limit_by_param"
|
||||
limitByConsumerType limitRuleItemType = "limit_by_consumer"
|
||||
limitByCookieType limitRuleItemType = "limit_by_cookie"
|
||||
limitByPerHeaderType limitRuleItemType = "limit_by_per_header"
|
||||
limitByPerParamType limitRuleItemType = "limit_by_per_param"
|
||||
limitByPerConsumerType limitRuleItemType = "limit_by_per_consumer"
|
||||
limitByPerCookieType limitRuleItemType = "limit_by_per_cookie"
|
||||
limitByPerIpType limitRuleItemType = "limit_by_per_ip"
|
||||
|
||||
exactType limitConfigItemType = "exact" // 精确匹配
|
||||
regexpType limitConfigItemType = "regexp" // 正则表达式
|
||||
allType limitConfigItemType = "*" // 匹配所有情况
|
||||
ipNetType limitConfigItemType = "ipNet" // ip段
|
||||
|
||||
RemoteAddrSourceType = "remote-addr"
|
||||
HeaderSourceType = "header"
|
||||
|
||||
DefaultRejectedCode uint32 = 429
|
||||
DefaultRejectedMsg string = "Too many requests"
|
||||
|
||||
Second int64 = 1
|
||||
SecondsPerMinute = 60 * Second
|
||||
SecondsPerHour = 60 * SecondsPerMinute
|
||||
SecondsPerDay = 24 * SecondsPerHour
|
||||
)
|
||||
|
||||
var timeWindows = map[string]int64{
|
||||
"token_per_second": Second,
|
||||
"token_per_minute": SecondsPerMinute,
|
||||
"token_per_hour": SecondsPerHour,
|
||||
"token_per_day": SecondsPerDay,
|
||||
}
|
||||
|
||||
type ClusterKeyRateLimitConfig struct {
|
||||
ruleName string // 限流规则名称
|
||||
ruleItems []LimitRuleItem // 限流规则项
|
||||
showLimitQuotaHeader bool // 响应头中是否显示X-RateLimit-Limit和X-RateLimit-Remaining
|
||||
rejectedCode uint32 // 当请求超过阈值被拒绝时,返回的HTTP状态码
|
||||
rejectedMsg string // 当请求超过阈值被拒绝时,返回的响应体
|
||||
redisClient wrapper.RedisClient
|
||||
}
|
||||
|
||||
type LimitRuleItem struct {
|
||||
limitType limitRuleItemType // 限流类型
|
||||
key string // 根据该key值进行限流,limit_by_consumer和limit_by_per_consumer两种类型为ConsumerHeader,其他类型为对应的key值
|
||||
limitByPerIp LimitByPerIp // 对端ip地址或ip段
|
||||
configItems []LimitConfigItem // 限流配置项
|
||||
}
|
||||
|
||||
type LimitByPerIp struct {
|
||||
sourceType string // ip来源类型
|
||||
headerName string // 根据该请求头获取客户端ip
|
||||
}
|
||||
|
||||
type LimitConfigItem struct {
|
||||
configType limitConfigItemType // 限流配置项key类型
|
||||
key string // 限流key
|
||||
ipNet *iptree.IPTree // 限流key转换的ip地址或者ip段,仅用于itemType为ipNetType
|
||||
regexp *re.Regexp // 正则表达式,仅用于itemType为regexpType
|
||||
count int64 // 指定时间窗口内的总请求数量阈值
|
||||
timeWindow int64 // 时间窗口大小
|
||||
}
|
||||
|
||||
func initRedisClusterClient(json gjson.Result, config *ClusterKeyRateLimitConfig) error {
|
||||
redisConfig := json.Get("redis")
|
||||
if !redisConfig.Exists() {
|
||||
return errors.New("missing redis in config")
|
||||
}
|
||||
serviceName := redisConfig.Get("service_name").String()
|
||||
if serviceName == "" {
|
||||
return errors.New("redis service name must not be empty")
|
||||
}
|
||||
servicePort := int(redisConfig.Get("service_port").Int())
|
||||
if servicePort == 0 {
|
||||
if strings.HasSuffix(serviceName, ".static") {
|
||||
// use default logic port which is 80 for static service
|
||||
servicePort = 80
|
||||
} else {
|
||||
servicePort = 6379
|
||||
}
|
||||
}
|
||||
username := redisConfig.Get("username").String()
|
||||
password := redisConfig.Get("password").String()
|
||||
timeout := int(redisConfig.Get("timeout").Int())
|
||||
if timeout == 0 {
|
||||
timeout = 1000
|
||||
}
|
||||
config.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: serviceName,
|
||||
Port: int64(servicePort),
|
||||
})
|
||||
return config.redisClient.Init(username, password, int64(timeout))
|
||||
}
|
||||
|
||||
func parseClusterKeyRateLimitConfig(json gjson.Result, config *ClusterKeyRateLimitConfig) error {
|
||||
ruleName := json.Get("rule_name")
|
||||
if !ruleName.Exists() {
|
||||
return errors.New("missing rule_name in config")
|
||||
}
|
||||
config.ruleName = ruleName.String()
|
||||
|
||||
// 初始化ruleItems
|
||||
err := initRuleItems(json, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rejectedCode := json.Get("rejected_code")
|
||||
if rejectedCode.Exists() {
|
||||
config.rejectedCode = uint32(rejectedCode.Uint())
|
||||
} else {
|
||||
config.rejectedCode = DefaultRejectedCode
|
||||
}
|
||||
rejectedMsg := json.Get("rejected_msg")
|
||||
if rejectedCode.Exists() {
|
||||
config.rejectedMsg = rejectedMsg.String()
|
||||
} else {
|
||||
config.rejectedMsg = DefaultRejectedMsg
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func initRuleItems(json gjson.Result, config *ClusterKeyRateLimitConfig) error {
|
||||
ruleItemsResult := json.Get("rule_items")
|
||||
if !ruleItemsResult.Exists() {
|
||||
return errors.New("missing rule_items in config")
|
||||
}
|
||||
if len(ruleItemsResult.Array()) == 0 {
|
||||
return errors.New("config rule_items cannot be empty")
|
||||
}
|
||||
var ruleItems []LimitRuleItem
|
||||
for _, item := range ruleItemsResult.Array() {
|
||||
var ruleItem LimitRuleItem
|
||||
|
||||
// 根据配置区分限流类型
|
||||
var limitType limitRuleItemType
|
||||
setLimitByKeyIfExists := func(field gjson.Result, limitTypeStr limitRuleItemType) {
|
||||
if field.Exists() && field.String() != "" {
|
||||
ruleItem.key = field.String()
|
||||
limitType = limitTypeStr
|
||||
}
|
||||
}
|
||||
setLimitByKeyIfExists(item.Get("limit_by_header"), limitByHeaderType)
|
||||
setLimitByKeyIfExists(item.Get("limit_by_param"), limitByParamType)
|
||||
setLimitByKeyIfExists(item.Get("limit_by_cookie"), limitByCookieType)
|
||||
setLimitByKeyIfExists(item.Get("limit_by_per_header"), limitByPerHeaderType)
|
||||
setLimitByKeyIfExists(item.Get("limit_by_per_param"), limitByPerParamType)
|
||||
setLimitByKeyIfExists(item.Get("limit_by_per_cookie"), limitByPerCookieType)
|
||||
|
||||
limitByConsumer := item.Get("limit_by_consumer")
|
||||
if limitByConsumer.Exists() {
|
||||
ruleItem.key = ConsumerHeader
|
||||
limitType = limitByConsumerType
|
||||
}
|
||||
limitByPerConsumer := item.Get("limit_by_per_consumer")
|
||||
if limitByPerConsumer.Exists() {
|
||||
ruleItem.key = ConsumerHeader
|
||||
limitType = limitByPerConsumerType
|
||||
}
|
||||
|
||||
limitByPerIpResult := item.Get("limit_by_per_ip")
|
||||
if limitByPerIpResult.Exists() && limitByPerIpResult.String() != "" {
|
||||
limitByPerIp := limitByPerIpResult.String()
|
||||
ruleItem.key = limitByPerIp
|
||||
if strings.HasPrefix(limitByPerIp, "from-header-") {
|
||||
headerName := limitByPerIp[len("from-header-"):]
|
||||
if headerName == "" {
|
||||
return errors.New("limit_by_per_ip parse error: empty after 'from-header-'")
|
||||
}
|
||||
ruleItem.limitByPerIp = LimitByPerIp{
|
||||
sourceType: HeaderSourceType,
|
||||
headerName: headerName,
|
||||
}
|
||||
} else if limitByPerIp == "from-remote-addr" {
|
||||
ruleItem.limitByPerIp = LimitByPerIp{
|
||||
sourceType: RemoteAddrSourceType,
|
||||
headerName: "",
|
||||
}
|
||||
} else {
|
||||
return errors.New("the 'limit_by_per_ip' restriction must start with 'from-header-' or be exactly 'from-remote-addr'")
|
||||
}
|
||||
limitType = limitByPerIpType
|
||||
}
|
||||
|
||||
if limitType == "" {
|
||||
return errors.New("only one of 'limit_by_header' and 'limit_by_param' and 'limit_by_consumer' and 'limit_by_cookie' and 'limit_by_per_header' and 'limit_by_per_param' and 'limit_by_per_consumer' and 'limit_by_per_cookie' and 'limit_by_per_ip' can be set")
|
||||
}
|
||||
ruleItem.limitType = limitType
|
||||
|
||||
// 初始化configItems
|
||||
err := initConfigItems(item, &ruleItem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ruleItems = append(ruleItems, ruleItem)
|
||||
}
|
||||
config.ruleItems = ruleItems
|
||||
return nil
|
||||
}
|
||||
|
||||
func initConfigItems(json gjson.Result, rule *LimitRuleItem) error {
|
||||
limitKeys := json.Get("limit_keys")
|
||||
if !limitKeys.Exists() {
|
||||
return errors.New("missing limit_keys in config")
|
||||
}
|
||||
if len(limitKeys.Array()) == 0 {
|
||||
return errors.New("config limit_keys cannot be empty")
|
||||
}
|
||||
var configItems []LimitConfigItem
|
||||
for _, item := range limitKeys.Array() {
|
||||
key := item.Get("key")
|
||||
if !key.Exists() || key.String() == "" {
|
||||
return errors.New("limit_keys key is required")
|
||||
}
|
||||
|
||||
var (
|
||||
itemKey = key.String()
|
||||
itemType limitConfigItemType
|
||||
ipNet *iptree.IPTree
|
||||
regexp *re.Regexp
|
||||
)
|
||||
if rule.limitType == limitByPerIpType {
|
||||
var err error
|
||||
ipNet, err = parseIPNet(itemKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse IPNet for key '%s': %w", itemKey, err)
|
||||
}
|
||||
itemType = ipNetType
|
||||
} else if rule.limitType == limitByPerHeaderType ||
|
||||
rule.limitType == limitByPerParamType ||
|
||||
rule.limitType == limitByPerConsumerType ||
|
||||
rule.limitType == limitByPerCookieType {
|
||||
if itemKey == "*" {
|
||||
itemType = allType
|
||||
} else if strings.HasPrefix(itemKey, "regexp:") {
|
||||
regexpStr := itemKey[len("regexp:"):]
|
||||
var err error
|
||||
regexp, err = re.Compile(regexpStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to compile regex for key '%s': %w", itemKey, err)
|
||||
}
|
||||
itemType = regexpType
|
||||
} else {
|
||||
return fmt.Errorf("the '%s' restriction must start with 'regexp:' or be exactly '*'", rule.limitType)
|
||||
}
|
||||
} else {
|
||||
itemType = exactType
|
||||
}
|
||||
|
||||
if configItem, err := createConfigItemFromRate(item, itemType, itemKey, ipNet, regexp); err != nil {
|
||||
return err
|
||||
} else if configItem != nil {
|
||||
configItems = append(configItems, *configItem)
|
||||
}
|
||||
}
|
||||
rule.configItems = configItems
|
||||
return nil
|
||||
}
|
||||
|
||||
func createConfigItemFromRate(item gjson.Result, itemType limitConfigItemType, key string, ipNet *iptree.IPTree, regexp *re.Regexp) (*LimitConfigItem, error) {
|
||||
for timeWindowKey, duration := range timeWindows {
|
||||
q := item.Get(timeWindowKey)
|
||||
if q.Exists() && q.Int() > 0 {
|
||||
return &LimitConfigItem{
|
||||
configType: itemType,
|
||||
key: key,
|
||||
ipNet: ipNet,
|
||||
regexp: regexp,
|
||||
count: q.Int(),
|
||||
timeWindow: duration,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("one of 'token_per_second', 'token_per_minute', 'token_per_hour', or 'token_per_day' must be set for key: " + key)
|
||||
}
|
||||
25
plugins/wasm-go/extensions/ai-token-ratelimit/go.mod
Normal file
25
plugins/wasm-go/extensions/ai-token-ratelimit/go.mod
Normal file
@@ -0,0 +1,25 @@
|
||||
module ai-token-ratelimit
|
||||
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.1-0.20240617024146-5f150179637c
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc
|
||||
github.com/tidwall/gjson v1.14.3
|
||||
github.com/wasilibs/go-re2 v1.5.3
|
||||
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.1 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
|
||||
github.com/magefile/mage v1.14.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/resp v0.1.1
|
||||
)
|
||||
31
plugins/wasm-go/extensions/ai-token-ratelimit/go.sum
Normal file
31
plugins/wasm-go/extensions/ai-token-ratelimit/go.sum
Normal file
@@ -0,0 +1,31 @@
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.1-0.20240617024146-5f150179637c h1:wKCSg4rYfwkZrMk7tYY7navjgcHCMZjcgFrCsjLQBmg=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.1-0.20240617024146-5f150179637c/go.mod h1:10jQXKsYFUF7djs+Oy7t82f4dbie9pISfP9FJwpPLuk=
|
||||
github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 h1:Wi5Tgn8K+jDcBYL+dIMS1+qXYH2r7tpRAyBgqrWfQtw=
|
||||
github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56/go.mod h1:8BhOLuqtSuT5NZtZMwfvEibi09RO3u79uqfHZzfDTR4=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tetratelabs/wazero v1.7.1 h1:QtSfd6KLc41DIMpDYlJdoMc6k7QTN246DM2+n2Y/Dx8=
|
||||
github.com/tetratelabs/wazero v1.7.1/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/wasilibs/go-re2 v1.5.3 h1:wiuTcgDZdLhu8NG8oqF5sF5Q3yIU14lPAvXqeYzDK3g=
|
||||
github.com/wasilibs/go-re2 v1.5.3/go.mod h1:PzpVPsBdFC7vM8QJbbEnOeTmwA0DGE783d/Gex8eCV8=
|
||||
github.com/wasilibs/nottinygc v0.4.0 h1:h1TJMihMC4neN6Zq+WKpLxgd9xCFMw7O9ETLwY2exJQ=
|
||||
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837 h1:DjHnADS2r2zynZ3WkCFAQ+PNYngMSNceRROi0pO6c3M=
|
||||
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837/go.mod h1:9vp0bxqozzQwcjBwenEXfKVq8+mYbwHkQ1NF9Ap0DMw=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
303
plugins/wasm-go/extensions/ai-token-ratelimit/main.go
Normal file
303
plugins/wasm-go/extensions/ai-token-ratelimit/main.go
Normal file
@@ -0,0 +1,303 @@
|
||||
// Copyright (c) 2024 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/resp"
|
||||
)
|
||||
|
||||
func main() {
|
||||
wrapper.SetCtx(
|
||||
"ai-token-ratelimit",
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||
wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingBody),
|
||||
)
|
||||
}
|
||||
|
||||
const (
|
||||
ClusterRateLimitFormat string = "higress-token-ratelimit:%s:%s:%s:%s"
|
||||
RequestPhaseFixedWindowScript string = `
|
||||
local ttl = redis.call('ttl', KEYS[1])
|
||||
if ttl < 0 then
|
||||
redis.call('set', KEYS[1], ARGV[1], 'EX', ARGV[2])
|
||||
return {ARGV[1], ARGV[1], ARGV[2]}
|
||||
end
|
||||
return {ARGV[1], redis.call('get', KEYS[1]), ttl}
|
||||
`
|
||||
ResponsePhaseFixedWindowScript string = `
|
||||
local ttl = redis.call('ttl', KEYS[1])
|
||||
if ttl < 0 then
|
||||
redis.call('set', KEYS[1], ARGV[1]-ARGV[3], 'EX', ARGV[2])
|
||||
return {ARGV[1], ARGV[1]-ARGV[3], ARGV[2]}
|
||||
end
|
||||
return {ARGV[1], redis.call('decrby', KEYS[1], ARGV[3]), ttl}
|
||||
`
|
||||
|
||||
LimitRedisContextKey string = "LimitRedisContext"
|
||||
|
||||
ConsumerHeader string = "x-mse-consumer" // LimitByConsumer从该request header获取consumer的名字
|
||||
CookieHeader string = "cookie"
|
||||
|
||||
RateLimitLimitHeader string = "X-RateLimit-Limit" // 限制的总请求数
|
||||
RateLimitRemainingHeader string = "X-RateLimit-Remaining" // 剩余还可以发送的请求数
|
||||
RateLimitResetHeader string = "X-RateLimit-Reset" // 限流重置时间(触发限流时返回)
|
||||
)
|
||||
|
||||
type LimitContext struct {
|
||||
count int
|
||||
remaining int
|
||||
reset int
|
||||
}
|
||||
|
||||
type LimitRedisContext struct {
|
||||
key string
|
||||
count int64
|
||||
window int64
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log wrapper.Log) error {
|
||||
err := initRedisClusterClient(json, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = parseClusterKeyRateLimitConfig(json, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, log wrapper.Log) types.Action {
|
||||
// 判断是否命中限流规则
|
||||
val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems, log)
|
||||
if ruleItem == nil || configItem == nil {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
// 构建redis限流key和参数
|
||||
limitKey := fmt.Sprintf(ClusterRateLimitFormat, config.ruleName, ruleItem.limitType, ruleItem.key, val)
|
||||
keys := []interface{}{limitKey}
|
||||
args := []interface{}{configItem.count, configItem.timeWindow}
|
||||
|
||||
limitRedisContext := LimitRedisContext{
|
||||
key: limitKey,
|
||||
count: configItem.count,
|
||||
window: configItem.timeWindow,
|
||||
}
|
||||
ctx.SetContext(LimitRedisContextKey, limitRedisContext)
|
||||
|
||||
// 执行限流逻辑
|
||||
err := config.redisClient.Eval(RequestPhaseFixedWindowScript, 1, keys, args, func(response resp.Value) {
|
||||
resultArray := response.Array()
|
||||
if len(resultArray) != 3 {
|
||||
log.Errorf("redis response parse error, response: %v", response)
|
||||
return
|
||||
}
|
||||
context := LimitContext{
|
||||
count: resultArray[0].Integer(),
|
||||
remaining: resultArray[1].Integer(),
|
||||
reset: resultArray[2].Integer(),
|
||||
}
|
||||
if context.remaining < 0 {
|
||||
// 触发限流
|
||||
rejected(config, context)
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("redis call failed: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool, log wrapper.Log) []byte {
|
||||
if !endOfStream {
|
||||
return data
|
||||
}
|
||||
inputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.input_token"})
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
outputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.output_token"})
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
inputToken, err := strconv.Atoi(string(inputTokenStr))
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
outputToken, err := strconv.Atoi(string(outputTokenStr))
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext)
|
||||
if !ok {
|
||||
return data
|
||||
}
|
||||
keys := []interface{}{limitRedisContext.key}
|
||||
args := []interface{}{limitRedisContext.count, limitRedisContext.window, inputToken + outputToken}
|
||||
|
||||
err = config.redisClient.Eval(ResponsePhaseFixedWindowScript, 1, keys, args, func(response resp.Value) {
|
||||
if response.Error() != nil {
|
||||
log.Errorf("call Eval error: %v", response.Error())
|
||||
}
|
||||
proxywasm.ResumeHttpResponse()
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("redis call failed: %v", err)
|
||||
return data
|
||||
} else {
|
||||
return data
|
||||
}
|
||||
}
|
||||
|
||||
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
for _, rule := range ruleItems {
|
||||
val, ruleItem, configItem := hitRateRuleItem(ctx, rule, log)
|
||||
if ruleItem != nil && configItem != nil {
|
||||
return val, ruleItem, configItem
|
||||
}
|
||||
}
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
switch rule.limitType {
|
||||
// 根据HTTP请求头限流
|
||||
case limitByHeaderType, limitByPerHeaderType:
|
||||
val, err := proxywasm.GetHttpRequestHeader(rule.key)
|
||||
if err != nil {
|
||||
return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", rule.key, err)
|
||||
}
|
||||
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
|
||||
// 根据HTTP请求参数限流
|
||||
case limitByParamType, limitByPerParamType:
|
||||
parse, err := url.Parse(ctx.Path())
|
||||
if err != nil {
|
||||
return logDebugAndReturnEmpty(log, "failed to parse request path: %v", err)
|
||||
}
|
||||
query, err := url.ParseQuery(parse.RawQuery)
|
||||
if err != nil {
|
||||
return logDebugAndReturnEmpty(log, "failed to parse query params: %v", err)
|
||||
}
|
||||
val, ok := query[rule.key]
|
||||
if !ok {
|
||||
return logDebugAndReturnEmpty(log, "request param %s is empty", rule.key)
|
||||
}
|
||||
return val[0], &rule, findMatchingItem(rule.limitType, rule.configItems, val[0])
|
||||
// 根据consumer限流
|
||||
case limitByConsumerType, limitByPerConsumerType:
|
||||
val, err := proxywasm.GetHttpRequestHeader(ConsumerHeader)
|
||||
if err != nil {
|
||||
return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", ConsumerHeader, err)
|
||||
}
|
||||
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
|
||||
// 根据cookie中key值限流
|
||||
case limitByCookieType, limitByPerCookieType:
|
||||
cookie, err := proxywasm.GetHttpRequestHeader(CookieHeader)
|
||||
if err != nil {
|
||||
return logDebugAndReturnEmpty(log, "failed to get request cookie : %v", err)
|
||||
}
|
||||
val := extractCookieValueByKey(cookie, rule.key)
|
||||
if val == "" {
|
||||
return logDebugAndReturnEmpty(log, "cookie key '%s' extracted from cookie '%s' is empty.", rule.key, cookie)
|
||||
}
|
||||
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
|
||||
// 根据客户端IP限流
|
||||
case limitByPerIpType:
|
||||
realIp, err := getDownStreamIp(rule)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get down stream ip: %v", err)
|
||||
return "", &rule, nil
|
||||
}
|
||||
for _, item := range rule.configItems {
|
||||
if _, found, _ := item.ipNet.Get(realIp); !found {
|
||||
continue
|
||||
}
|
||||
return realIp.String(), &rule, &item
|
||||
}
|
||||
}
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
func logDebugAndReturnEmpty(log wrapper.Log, errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
log.Debugf(errMsg, args...)
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
func findMatchingItem(limitType limitRuleItemType, items []LimitConfigItem, key string) *LimitConfigItem {
|
||||
for _, item := range items {
|
||||
// per类型,检查allType和regexpType
|
||||
if limitType == limitByPerHeaderType ||
|
||||
limitType == limitByPerParamType ||
|
||||
limitType == limitByPerConsumerType ||
|
||||
limitType == limitByPerCookieType {
|
||||
if item.configType == allType || (item.configType == regexpType && item.regexp.MatchString(key)) {
|
||||
return &item
|
||||
}
|
||||
}
|
||||
// 其他类型,直接比较key
|
||||
if item.key == key {
|
||||
return &item
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getDownStreamIp(rule LimitRuleItem) (net.IP, error) {
|
||||
var (
|
||||
realIpStr string
|
||||
err error
|
||||
)
|
||||
if rule.limitByPerIp.sourceType == HeaderSourceType {
|
||||
realIpStr, err = proxywasm.GetHttpRequestHeader(rule.limitByPerIp.headerName)
|
||||
if err == nil {
|
||||
realIpStr = strings.Split(strings.Trim(realIpStr, " "), ",")[0]
|
||||
}
|
||||
} else {
|
||||
var bs []byte
|
||||
bs, err = proxywasm.GetProperty([]string{"source", "address"})
|
||||
realIpStr = string(bs)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ip := parseIP(realIpStr)
|
||||
realIP := net.ParseIP(ip)
|
||||
if realIP == nil {
|
||||
return nil, fmt.Errorf("invalid ip[%s]", ip)
|
||||
}
|
||||
return realIP, nil
|
||||
}
|
||||
|
||||
func rejected(config ClusterKeyRateLimitConfig, context LimitContext) {
|
||||
headers := make(map[string][]string)
|
||||
headers[RateLimitResetHeader] = []string{strconv.Itoa(context.reset)}
|
||||
_ = proxywasm.SendHttpResponse(
|
||||
config.rejectedCode, reconvertHeaders(headers), []byte(config.rejectedMsg), -1)
|
||||
}
|
||||
60
plugins/wasm-go/extensions/ai-token-ratelimit/utils.go
Normal file
60
plugins/wasm-go/extensions/ai-token-ratelimit/utils.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/zmap/go-iptree/iptree"
|
||||
)
|
||||
|
||||
// parseIPNet 解析Ip段配置
|
||||
func parseIPNet(key string) (*iptree.IPTree, error) {
|
||||
tree := iptree.New()
|
||||
err := tree.AddByString(key, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid IP[%s]", key)
|
||||
}
|
||||
return tree, nil
|
||||
}
|
||||
|
||||
// parseIP 解析IP
|
||||
func parseIP(source string) string {
|
||||
if strings.Contains(source, ".") {
|
||||
// parse ipv4
|
||||
return strings.Split(source, ":")[0]
|
||||
}
|
||||
// parse ipv6
|
||||
if strings.Contains(source, "]") {
|
||||
return strings.Split(source, "]")[0][1:]
|
||||
}
|
||||
return source
|
||||
}
|
||||
|
||||
// reconvertHeaders headers: map[string][]string -> [][2]string
|
||||
func reconvertHeaders(hs map[string][]string) [][2]string {
|
||||
var ret [][2]string
|
||||
for k, vs := range hs {
|
||||
for _, v := range vs {
|
||||
ret = append(ret, [2]string{k, v})
|
||||
}
|
||||
}
|
||||
sort.SliceStable(ret, func(i, j int) bool {
|
||||
return ret[i][0] < ret[j][0]
|
||||
})
|
||||
return ret
|
||||
}
|
||||
|
||||
// extractCookieValueByKey 从cookie中提取key对应的value
|
||||
func extractCookieValueByKey(cookie string, key string) (value string) {
|
||||
pairs := strings.Split(cookie, ";")
|
||||
for _, pair := range pairs {
|
||||
pair = strings.TrimSpace(pair)
|
||||
kv := strings.Split(pair, "=")
|
||||
if kv[0] == key {
|
||||
value = kv[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
3
plugins/wasm-go/extensions/ai-transformer/.gitignore
vendored
Normal file
3
plugins/wasm-go/extensions/ai-transformer/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
config.yaml
|
||||
main.wasm
|
||||
tmp/
|
||||
81
plugins/wasm-go/extensions/ai-transformer/README.md
Normal file
81
plugins/wasm-go/extensions/ai-transformer/README.md
Normal file
@@ -0,0 +1,81 @@
|
||||
# 简介
|
||||
低代码开发插件,通过LLM对请求/响应的header以及body进行修改。
|
||||
|
||||
# 配置说明
|
||||
| Name | Type | Requirement | Default | Description |
|
||||
| :- | :- | :- | :- | :- |
|
||||
| request.enable | bool | requried | - | 是否在request阶段开启转换 |
|
||||
| request.prompt | string | requried | - | request阶段转换使用的prompt |
|
||||
| response.enable | string | requried | - | 是否在response阶段开启转换 |
|
||||
| response.prompt | string | requried | - | response阶段转换使用的prompt |
|
||||
| provider.serviceName | string | requried | - | DNS类型的服务名,目前仅支持通义千问 |
|
||||
| provider.domain | string | requried | - | LLM服务域名 |
|
||||
| provider.apiKey | string | requried | - | 阿里云dashscope服务的API Key |
|
||||
|
||||
# 配置示例
|
||||
```yaml
|
||||
request:
|
||||
enable: false
|
||||
prompt: "如果请求path是以/httpbin开头的,帮我去掉/httpbin前缀,其他的不要改。"
|
||||
response:
|
||||
enable: true
|
||||
prompt: "帮我修改以下HTTP应答信息,要求:1. content-type修改为application/json;2. body由xml转化为json;3. 移除content-length。"
|
||||
provider:
|
||||
serviceName: qwen
|
||||
domain: dashscope.aliyuncs.com
|
||||
apiKey: xxxxxxxxxxxxx
|
||||
```
|
||||
|
||||
访问原始的httbin的/xml接口,结果为:
|
||||
```
|
||||
<?xml version='1.0' encoding='us-ascii'?>
|
||||
|
||||
<!-- A SAMPLE set of slides -->
|
||||
|
||||
<slideshow
|
||||
title="Sample Slide Show"
|
||||
date="Date of publication"
|
||||
author="Yours Truly"
|
||||
>
|
||||
|
||||
<!-- TITLE SLIDE -->
|
||||
<slide type="all">
|
||||
<title>Wake up to WonderWidgets!</title>
|
||||
</slide>
|
||||
|
||||
<!-- OVERVIEW -->
|
||||
<slide type="all">
|
||||
<title>Overview</title>
|
||||
<item>Why <em>WonderWidgets</em> are great</item>
|
||||
<item/>
|
||||
<item>Who <em>buys</em> WonderWidgets</item>
|
||||
</slide>
|
||||
|
||||
</slideshow>
|
||||
```
|
||||
|
||||
使用以上配置,通过网关访问httpbin的/xml接口,结果为:
|
||||
```
|
||||
{
|
||||
"slideshow": {
|
||||
"title": "Sample Slide Show",
|
||||
"date": "Date of publication",
|
||||
"author": "Yours Truly",
|
||||
"slides": [
|
||||
{
|
||||
"type": "all",
|
||||
"title": "Wake up to WonderWidgets!"
|
||||
},
|
||||
{
|
||||
"type": "all",
|
||||
"title": "Overview",
|
||||
"items": [
|
||||
"Why <em>WonderWidgets</em> are great",
|
||||
"",
|
||||
"Who <em>buys</em> WonderWidgets"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
19
plugins/wasm-go/extensions/ai-transformer/go.mod
Normal file
19
plugins/wasm-go/extensions/ai-transformer/go.mod
Normal file
@@ -0,0 +1,19 @@
|
||||
module ai-transformer
|
||||
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.0
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc
|
||||
github.com/tidwall/gjson v1.14.3
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
|
||||
github.com/magefile/mage v1.14.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
)
|
||||
29
plugins/wasm-go/extensions/ai-transformer/go.sum
Normal file
29
plugins/wasm-go/extensions/ai-transformer/go.sum
Normal file
@@ -0,0 +1,29 @@
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.0 h1:uFf+mbZ2iuRXJzRbmWBuxiHvNDMGf3PCBJ6TI86bopY=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.0/go.mod h1:10jQXKsYFUF7djs+Oy7t82f4dbie9pISfP9FJwpPLuk=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
176
plugins/wasm-go/extensions/ai-transformer/main.go
Normal file
176
plugins/wasm-go/extensions/ai-transformer/main.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
func main() {
|
||||
wrapper.SetCtx(
|
||||
"ai-transformer",
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
||||
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
|
||||
)
|
||||
}
|
||||
|
||||
type AITransformerConfig struct {
|
||||
client wrapper.HttpClient
|
||||
requestTransformEnable bool
|
||||
requestTransformPrompt string
|
||||
responseTransformEnable bool
|
||||
responseTransformPrompt string
|
||||
providerAPIKey string
|
||||
}
|
||||
|
||||
const llmRequestTemplate = `{
|
||||
"model": "qwen-max",
|
||||
"input":{
|
||||
"messages":[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "假设你是一个http 1.1协议专家,你的回答应该只包含http报文,除此之外不要有任何其他内容。"
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": ""
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": ""
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
func parseConfig(json gjson.Result, config *AITransformerConfig, log wrapper.Log) error {
|
||||
config.requestTransformEnable = json.Get("request.enable").Bool()
|
||||
config.requestTransformPrompt = json.Get("request.prompt").String()
|
||||
config.responseTransformEnable = json.Get("response.enable").Bool()
|
||||
config.responseTransformPrompt = json.Get("response.prompt").String()
|
||||
config.providerAPIKey = json.Get("provider.apiKey").String()
|
||||
config.client = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
ServiceName: json.Get("provider.serviceName").String(),
|
||||
Port: 443,
|
||||
Domain: json.Get("provider.domain").String(),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func getSplitPos(header string) int {
|
||||
for i, ch := range header {
|
||||
if ch == ':' && i != 0 {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func extraceHttpFrame(frame string) ([][2]string, []byte, error) {
|
||||
pos := strings.Index(frame, "\n\n")
|
||||
headers := [][2]string{}
|
||||
for _, header := range strings.Split(frame[:pos], "\n") {
|
||||
splitPos := getSplitPos(header)
|
||||
if splitPos == -1 {
|
||||
return nil, nil, errors.New("invalid http frame.")
|
||||
}
|
||||
headers = append(headers, [2]string{header[:splitPos], header[splitPos+1:]})
|
||||
}
|
||||
body := []byte(frame[pos+2:])
|
||||
return headers, body, nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AITransformerConfig, log wrapper.Log) types.Action {
|
||||
log.Info("onHttpRequestHeaders")
|
||||
if !config.requestTransformEnable || config.requestTransformPrompt == "" {
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
} else {
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AITransformerConfig, body []byte, log wrapper.Log) types.Action {
|
||||
log.Info("onHttpRequestBody")
|
||||
headers, err := proxywasm.GetHttpRequestHeaders()
|
||||
if err != nil {
|
||||
log.Error("Failed to get http response headers.")
|
||||
return types.ActionContinue
|
||||
}
|
||||
headerStr := ""
|
||||
for _, hd := range headers {
|
||||
headerStr += hd[0] + ":" + hd[1] + "\n"
|
||||
}
|
||||
var llmRequestBody string
|
||||
llmRequestBody, _ = sjson.Set(llmRequestTemplate, "input.messages.1.content", config.requestTransformPrompt)
|
||||
llmRequestBody, _ = sjson.Set(llmRequestBody, "input.messages.2.content", headerStr+"\n"+string(body))
|
||||
hds := [][2]string{{"Authorization", "Bearer " + config.providerAPIKey}, {"Content-Type", "application/json"}}
|
||||
log.Info(headerStr + "\n" + string(body))
|
||||
config.client.Post(
|
||||
"/api/v1/services/aigc/text-generation/generation",
|
||||
hds,
|
||||
[]byte(llmRequestBody),
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
newHeaders, newBody, err := extraceHttpFrame(gjson.GetBytes(responseBody, "output.text").String())
|
||||
if err == nil {
|
||||
proxywasm.ReplaceHttpRequestHeaders(newHeaders)
|
||||
proxywasm.ReplaceHttpRequestBody(newBody)
|
||||
}
|
||||
proxywasm.ResumeHttpRequest()
|
||||
},
|
||||
50000,
|
||||
)
|
||||
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AITransformerConfig, log wrapper.Log) types.Action {
|
||||
if !config.responseTransformEnable || config.responseTransformPrompt == "" {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
} else {
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config AITransformerConfig, body []byte, log wrapper.Log) types.Action {
|
||||
headers, err := proxywasm.GetHttpResponseHeaders()
|
||||
if err != nil {
|
||||
log.Error("Failed to get http response headers.")
|
||||
return types.ActionContinue
|
||||
}
|
||||
headerStr := ""
|
||||
for _, hd := range headers {
|
||||
headerStr += hd[0] + ":" + hd[1] + "\n"
|
||||
}
|
||||
var llmRequestBody string
|
||||
llmRequestBody, _ = sjson.Set(llmRequestTemplate, "input.messages.1.content", config.responseTransformPrompt)
|
||||
llmRequestBody, _ = sjson.Set(llmRequestBody, "input.messages.2.content", headerStr+"\n"+string(body))
|
||||
hds := [][2]string{{"Authorization", "Bearer " + config.providerAPIKey}, {"Content-Type", "application/json"}}
|
||||
log.Info(headerStr + "\n" + string(body))
|
||||
config.client.Post(
|
||||
"/api/v1/services/aigc/text-generation/generation",
|
||||
hds,
|
||||
[]byte(llmRequestBody),
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
newHeaders, newBody, err := extraceHttpFrame(gjson.GetBytes(responseBody, "output.text").String())
|
||||
if err == nil {
|
||||
proxywasm.ReplaceHttpResponseHeaders(newHeaders)
|
||||
proxywasm.ReplaceHttpResponseBody(newBody)
|
||||
}
|
||||
proxywasm.ResumeHttpResponse()
|
||||
},
|
||||
50000,
|
||||
)
|
||||
|
||||
return types.ActionPause
|
||||
}
|
||||
194
plugins/wasm-go/extensions/cluster-key-rate-limit/README.md
Normal file
194
plugins/wasm-go/extensions/cluster-key-rate-limit/README.md
Normal file
@@ -0,0 +1,194 @@
|
||||
# 功能说明
|
||||
|
||||
`key-cluster-rate-limit`插件实现了基于特定键值实现集群限流,键值来源可以是 URL 参数、HTTP 请求头、客户端 IP 地址、consumer 名称、cookie中 key 名称
|
||||
|
||||
|
||||
|
||||
# 配置说明
|
||||
|
||||
| 配置项 | 类型 | 必填 | 默认值 | 说明 |
|
||||
| ----------------------- | ------ | ---- | ------ |---------------------------------------------------------------------------|
|
||||
| rule_name | string | 是 | - | 限流规则名称,根据限流规则名称+限流类型+限流key名称+限流key对应的实际值来拼装redis key |
|
||||
| rule_items | array of object | 是 | - | 限流规则项,按照rule_items下的排列顺序,匹配第一个rule_item后命中限流规则,后续规则将被忽略 |
|
||||
| show_limit_quota_header | bool | 否 | false | 响应头中是否显示`X-RateLimit-Limit`(限制的总请求数)和`X-RateLimit-Remaining`(剩余还可以发送的请求数) |
|
||||
| rejected_code | int | 否 | 429 | 请求被限流时,返回的HTTP状态码 |
|
||||
| rejected_msg | string | 否 | Too many requests | 请求被限流时,返回的响应体 |
|
||||
| redis | object | 是 | - | redis相关配置 |
|
||||
|
||||
`rule_items`中每一项的配置字段说明
|
||||
|
||||
| 配置项 | 类型 | 必填 | 默认值 | 说明 |
|
||||
| --------------------- | --------------- | -------------------------- | ------ | ------------------------------------------------------------ |
|
||||
| limit_by_header | string | 否,`limit_by_*`中选填一项 | - | 配置获取限流键值的来源 HTTP 请求头名称 |
|
||||
| limit_by_param | string | 否,`limit_by_*`中选填一项 | - | 配置获取限流键值的来源 URL 参数名称 |
|
||||
| limit_by_consumer | string | 否,`limit_by_*`中选填一项 | - | 根据 consumer 名称进行限流,无需添加实际值 |
|
||||
| limit_by_cookie | string | 否,`limit_by_*`中选填一项 | - | 配置获取限流键值的来源 Cookie中 key 名称 |
|
||||
| limit_by_per_header | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 HTTP 请求头,并对每个请求头分别计算限流,配置获取限流键值的来源 HTTP 请求头名称,配置`limit_keys`时支持正则表达式或`*` |
|
||||
| limit_by_per_param | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 URL 参数,并对每个参数分别计算限流,配置获取限流键值的来源 URL 参数名称,配置`limit_keys`时支持正则表达式或`*` |
|
||||
| limit_by_per_consumer | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 consumer,并对每个 consumer 分别计算限流,根据 consumer 名称进行限流,无需添加实际值,配置`limit_keys`时支持正则表达式或`*` |
|
||||
| limit_by_per_cookie | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 Cookie,并对每个 Cookie 分别计算限流,配置获取限流键值的来源 Cookie中 key 名称,配置`limit_keys`时支持正则表达式或`*` |
|
||||
| limit_by_per_ip | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 IP,并对每个 IP 分别计算限流,配置获取限流键值的来源 IP 参数名称,从请求头获取,以`from-header-对应的header名`,示例:`from-header-x-forwarded-for`,直接获取对端socket ip,配置为`from-remote-addr` |
|
||||
| limit_keys | array of object | 是 | - | 配置匹配键值后的限流次数 |
|
||||
|
||||
`limit_keys`中每一项的配置字段说明
|
||||
|
||||
| 配置项 | 类型 | 必填 | 默认值 | 说明 |
|
||||
| ---------------- | ------ | ------------------------------------------------------------ | ------ | ------------------------------------------------------------ |
|
||||
| key | string | 是 | - | 匹配的键值,`limit_by_per_header`,`limit_by_per_param`,`limit_by_per_consumer`,`limit_by_per_cookie` 类型支持配置正则表达式(以regexp:开头后面跟正则表达式)或者*(代表所有),正则表达式示例:`regexp:^d.*`(以d开头的所有字符串);`limit_by_per_ip`支持配置 IP 地址或 IP 段 |
|
||||
| query_per_second | int | 否,`query_per_second`,`query_per_minute`,`query_per_hour`,`query_per_day` 中选填一项 | - | 允许每秒请求次数 |
|
||||
| query_per_minute | int | 否,`query_per_second`,`query_per_minute`,`query_per_hour`,`query_per_day` 中选填一项 | - | 允许每分钟请求次数 |
|
||||
| query_per_hour | int | 否,`query_per_second`,`query_per_minute`,`query_per_hour`,`query_per_day` 中选填一项 | - | 允许每小时请求次数 |
|
||||
| query_per_day | int | 否,`query_per_second`,`query_per_minute`,`query_per_hour`,`query_per_day` 中选填一项 | - | 允许每天请求次数 |
|
||||
|
||||
`redis`中每一项的配置字段说明
|
||||
|
||||
| 配置项 | 类型 | 必填 | 默认值 | 说明 |
|
||||
| ------------ | ------ | ---- | ---------------------------------------------------------- | --------------------------- |
|
||||
| service_name | string | 必填 | - | redis 服务名称,带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local |
|
||||
| service_port | int | 否 | 服务类型为固定地址(static service)默认值为80,其他为6379 | 输入redis服务的服务端口 |
|
||||
| username | string | 否 | - | redis用户名 |
|
||||
| password | string | 否 | - | redis密码 |
|
||||
| timeout | int | 否 | 1000 | redis连接超时时间,单位毫秒 |
|
||||
|
||||
|
||||
|
||||
# 配置示例
|
||||
|
||||
## 识别请求参数 apikey,进行区别限流
|
||||
|
||||
```yaml
|
||||
rule_name: default_rule
|
||||
rule_items:
|
||||
- limit_by_param: apikey
|
||||
limit_keys:
|
||||
- key: 9a342114-ba8a-11ec-b1bf-00163e1250b5
|
||||
query_per_minute: 10
|
||||
- key: a6a6d7f2-ba8a-11ec-bec2-00163e1250b5
|
||||
query_per_hour: 100
|
||||
- limit_by_per_param: apikey
|
||||
limit_keys:
|
||||
# 正则表达式,匹配以a开头的所有字符串,每个apikey对应的请求10qds
|
||||
- key: "regexp:^a.*"
|
||||
query_per_second: 10
|
||||
# 正则表达式,匹配以b开头的所有字符串,每个apikey对应的请求100qd
|
||||
- key: "regexp:^b.*"
|
||||
query_per_minute: 100
|
||||
# 兜底用,匹配所有请求,每个apikey对应的请求1000qdh
|
||||
- key: "*"
|
||||
query_per_hour: 1000
|
||||
redis:
|
||||
service_name: redis.static
|
||||
show_limit_quota_header: true
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 识别请求头 x-ca-key,进行区别限流
|
||||
|
||||
```yaml
|
||||
rule_name: default_rule
|
||||
rule_items:
|
||||
- limit_by_header: x-ca-key
|
||||
limit_keys:
|
||||
- key: 102234
|
||||
query_per_minute: 10
|
||||
- key: 308239
|
||||
query_per_hour: 10
|
||||
- limit_by_per_header: x-ca-key
|
||||
limit_keys:
|
||||
# 正则表达式,匹配以a开头的所有字符串,每个apikey对应的请求10qds
|
||||
- key: "regexp:^a.*"
|
||||
query_per_second: 10
|
||||
# 正则表达式,匹配以b开头的所有字符串,每个apikey对应的请求100qd
|
||||
- key: "regexp:^b.*"
|
||||
query_per_minute: 100
|
||||
# 兜底用,匹配所有请求,每个apikey对应的请求1000qdh
|
||||
- key: "*"
|
||||
query_per_hour: 1000
|
||||
redis:
|
||||
service_name: redis.static
|
||||
show_limit_quota_header: true
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 根据请求头 x-forwarded-for 获取对端IP,进行区别限流
|
||||
|
||||
```yaml
|
||||
rule_name: default_rule
|
||||
rule_items:
|
||||
- limit_by_per_ip: from-header-x-forwarded-for
|
||||
limit_keys:
|
||||
# 精确ip
|
||||
- key: 1.1.1.1
|
||||
query_per_day: 10
|
||||
# ip段,符合这个ip段的ip,每个ip 100qpd
|
||||
- key: 1.1.1.0/24
|
||||
query_per_day: 100
|
||||
# 兜底用,即默认每个ip 1000qpd
|
||||
- key: 0.0.0.0/0
|
||||
query_per_day: 1000
|
||||
redis:
|
||||
service_name: redis.static
|
||||
show_limit_quota_header: true
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 识别consumer,进行区别限流
|
||||
|
||||
```yaml
|
||||
rule_name: default_rule
|
||||
rule_items:
|
||||
- limit_by_consumer: ''
|
||||
limit_keys:
|
||||
- key: consumer1
|
||||
query_per_second: 10
|
||||
- key: consumer2
|
||||
query_per_hour: 100
|
||||
- limit_by_per_consumer: ''
|
||||
limit_keys:
|
||||
# 正则表达式,匹配以a开头的所有字符串,每个consumer对应的请求10qds
|
||||
- key: "regexp:^a.*"
|
||||
query_per_second: 10
|
||||
# 正则表达式,匹配以b开头的所有字符串,每个consumer对应的请求100qd
|
||||
- key: "regexp:^b.*"
|
||||
query_per_minute: 100
|
||||
# 兜底用,匹配所有请求,每个consumer对应的请求1000qdh
|
||||
- key: "*"
|
||||
query_per_hour: 1000
|
||||
redis:
|
||||
service_name: redis.static
|
||||
show_limit_quota_header: true
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 识别cookie中的键值对,进行区别限流
|
||||
|
||||
```yaml
|
||||
rule_name: default_rule
|
||||
rule_items:
|
||||
- limit_by_cookie: key1
|
||||
limit_keys:
|
||||
- key: value1
|
||||
query_per_minute: 10
|
||||
- key: value2
|
||||
query_per_hour: 100
|
||||
- limit_by_per_cookie: key1
|
||||
limit_keys:
|
||||
# 正则表达式,匹配以a开头的所有字符串,每个cookie中的value对应的请求10qds
|
||||
- key: "regexp:^a.*"
|
||||
query_per_second: 10
|
||||
# 正则表达式,匹配以b开头的所有字符串,每个cookie中的value对应的请求100qd
|
||||
- key: "regexp:^b.*"
|
||||
query_per_minute: 100
|
||||
# 兜底用,匹配所有请求,每个cookie中的value对应的请求1000qdh
|
||||
- key: "*"
|
||||
query_per_hour: 1000
|
||||
rejected_code: 200
|
||||
rejected_msg: '{"code":-1,"msg":"Too many requests"}'
|
||||
redis:
|
||||
service_name: redis.static
|
||||
show_limit_quota_header: true
|
||||
```
|
||||
301
plugins/wasm-go/extensions/cluster-key-rate-limit/config.go
Normal file
301
plugins/wasm-go/extensions/cluster-key-rate-limit/config.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
re "github.com/wasilibs/go-re2"
|
||||
"github.com/zmap/go-iptree/iptree"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 限流规则项类型
|
||||
type limitRuleItemType string
|
||||
|
||||
// 限流配置项key类型
|
||||
type limitConfigItemType string
|
||||
|
||||
const (
|
||||
limitByHeaderType limitRuleItemType = "limit_by_header"
|
||||
limitByParamType limitRuleItemType = "limit_by_param"
|
||||
limitByConsumerType limitRuleItemType = "limit_by_consumer"
|
||||
limitByCookieType limitRuleItemType = "limit_by_cookie"
|
||||
limitByPerHeaderType limitRuleItemType = "limit_by_per_header"
|
||||
limitByPerParamType limitRuleItemType = "limit_by_per_param"
|
||||
limitByPerConsumerType limitRuleItemType = "limit_by_per_consumer"
|
||||
limitByPerCookieType limitRuleItemType = "limit_by_per_cookie"
|
||||
limitByPerIpType limitRuleItemType = "limit_by_per_ip"
|
||||
|
||||
exactType limitConfigItemType = "exact" // 精确匹配
|
||||
regexpType limitConfigItemType = "regexp" // 正则表达式
|
||||
allType limitConfigItemType = "*" // 匹配所有情况
|
||||
ipNetType limitConfigItemType = "ipNet" // ip段
|
||||
|
||||
RemoteAddrSourceType = "remote-addr"
|
||||
HeaderSourceType = "header"
|
||||
|
||||
DefaultRejectedCode uint32 = 429
|
||||
DefaultRejectedMsg string = "Too many requests"
|
||||
|
||||
Second int64 = 1
|
||||
SecondsPerMinute = 60 * Second
|
||||
SecondsPerHour = 60 * SecondsPerMinute
|
||||
SecondsPerDay = 24 * SecondsPerHour
|
||||
)
|
||||
|
||||
var timeWindows = map[string]int64{
|
||||
"query_per_second": Second,
|
||||
"query_per_minute": SecondsPerMinute,
|
||||
"query_per_hour": SecondsPerHour,
|
||||
"query_per_day": SecondsPerDay,
|
||||
}
|
||||
|
||||
type ClusterKeyRateLimitConfig struct {
|
||||
ruleName string // 限流规则名称
|
||||
ruleItems []LimitRuleItem // 限流规则项
|
||||
showLimitQuotaHeader bool // 响应头中是否显示X-RateLimit-Limit和X-RateLimit-Remaining
|
||||
rejectedCode uint32 // 当请求超过阈值被拒绝时,返回的HTTP状态码
|
||||
rejectedMsg string // 当请求超过阈值被拒绝时,返回的响应体
|
||||
redisClient wrapper.RedisClient
|
||||
}
|
||||
|
||||
type LimitRuleItem struct {
|
||||
limitType limitRuleItemType // 限流类型
|
||||
key string // 根据该key值进行限流,limit_by_consumer和limit_by_per_consumer两种类型为ConsumerHeader,其他类型为对应的key值
|
||||
limitByPerIp LimitByPerIp // 对端ip地址或ip段
|
||||
configItems []LimitConfigItem // 限流配置项
|
||||
}
|
||||
|
||||
type LimitByPerIp struct {
|
||||
sourceType string // ip来源类型
|
||||
headerName string // 根据该请求头获取客户端ip
|
||||
}
|
||||
|
||||
type LimitConfigItem struct {
|
||||
configType limitConfigItemType // 限流配置项key类型
|
||||
key string // 限流key
|
||||
ipNet *iptree.IPTree // 限流key转换的ip地址或者ip段,仅用于itemType为ipNetType
|
||||
regexp *re.Regexp // 正则表达式,仅用于itemType为regexpType
|
||||
count int64 // 指定时间窗口内的总请求数量阈值
|
||||
timeWindow int64 // 时间窗口大小
|
||||
}
|
||||
|
||||
func initRedisClusterClient(json gjson.Result, config *ClusterKeyRateLimitConfig) error {
|
||||
redisConfig := json.Get("redis")
|
||||
if !redisConfig.Exists() {
|
||||
return errors.New("missing redis in config")
|
||||
}
|
||||
serviceName := redisConfig.Get("service_name").String()
|
||||
if serviceName == "" {
|
||||
return errors.New("redis service name must not be empty")
|
||||
}
|
||||
servicePort := int(redisConfig.Get("service_port").Int())
|
||||
if servicePort == 0 {
|
||||
if strings.HasSuffix(serviceName, ".static") {
|
||||
// use default logic port which is 80 for static service
|
||||
servicePort = 80
|
||||
} else {
|
||||
servicePort = 6379
|
||||
}
|
||||
}
|
||||
username := redisConfig.Get("username").String()
|
||||
password := redisConfig.Get("password").String()
|
||||
timeout := int(redisConfig.Get("timeout").Int())
|
||||
if timeout == 0 {
|
||||
timeout = 1000
|
||||
}
|
||||
config.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: serviceName,
|
||||
Port: int64(servicePort),
|
||||
})
|
||||
return config.redisClient.Init(username, password, int64(timeout))
|
||||
}
|
||||
|
||||
func parseClusterKeyRateLimitConfig(json gjson.Result, config *ClusterKeyRateLimitConfig) error {
|
||||
ruleName := json.Get("rule_name")
|
||||
if !ruleName.Exists() {
|
||||
return errors.New("missing rule_name in config")
|
||||
}
|
||||
config.ruleName = ruleName.String()
|
||||
|
||||
// 初始化ruleItems
|
||||
err := initRuleItems(json, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
showLimitQuotaHeader := json.Get("show_limit_quota_header")
|
||||
if showLimitQuotaHeader.Exists() {
|
||||
config.showLimitQuotaHeader = showLimitQuotaHeader.Bool()
|
||||
}
|
||||
|
||||
rejectedCode := json.Get("rejected_code")
|
||||
if rejectedCode.Exists() {
|
||||
config.rejectedCode = uint32(rejectedCode.Uint())
|
||||
} else {
|
||||
config.rejectedCode = DefaultRejectedCode
|
||||
}
|
||||
rejectedMsg := json.Get("rejected_msg")
|
||||
if rejectedCode.Exists() {
|
||||
config.rejectedMsg = rejectedMsg.String()
|
||||
} else {
|
||||
config.rejectedMsg = DefaultRejectedMsg
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func initRuleItems(json gjson.Result, config *ClusterKeyRateLimitConfig) error {
|
||||
ruleItemsResult := json.Get("rule_items")
|
||||
if !ruleItemsResult.Exists() {
|
||||
return errors.New("missing rule_items in config")
|
||||
}
|
||||
if len(ruleItemsResult.Array()) == 0 {
|
||||
return errors.New("config rule_items cannot be empty")
|
||||
}
|
||||
var ruleItems []LimitRuleItem
|
||||
for _, item := range ruleItemsResult.Array() {
|
||||
var ruleItem LimitRuleItem
|
||||
|
||||
// 根据配置区分限流类型
|
||||
var limitType limitRuleItemType
|
||||
setLimitByKeyIfExists := func(field gjson.Result, limitTypeStr limitRuleItemType) {
|
||||
if field.Exists() && field.String() != "" {
|
||||
ruleItem.key = field.String()
|
||||
limitType = limitTypeStr
|
||||
}
|
||||
}
|
||||
setLimitByKeyIfExists(item.Get("limit_by_header"), limitByHeaderType)
|
||||
setLimitByKeyIfExists(item.Get("limit_by_param"), limitByParamType)
|
||||
setLimitByKeyIfExists(item.Get("limit_by_cookie"), limitByCookieType)
|
||||
setLimitByKeyIfExists(item.Get("limit_by_per_header"), limitByPerHeaderType)
|
||||
setLimitByKeyIfExists(item.Get("limit_by_per_param"), limitByPerParamType)
|
||||
setLimitByKeyIfExists(item.Get("limit_by_per_cookie"), limitByPerCookieType)
|
||||
|
||||
limitByConsumer := item.Get("limit_by_consumer")
|
||||
if limitByConsumer.Exists() {
|
||||
ruleItem.key = ConsumerHeader
|
||||
limitType = limitByConsumerType
|
||||
}
|
||||
limitByPerConsumer := item.Get("limit_by_per_consumer")
|
||||
if limitByPerConsumer.Exists() {
|
||||
ruleItem.key = ConsumerHeader
|
||||
limitType = limitByPerConsumerType
|
||||
}
|
||||
|
||||
limitByPerIpResult := item.Get("limit_by_per_ip")
|
||||
if limitByPerIpResult.Exists() && limitByPerIpResult.String() != "" {
|
||||
limitByPerIp := limitByPerIpResult.String()
|
||||
ruleItem.key = limitByPerIp
|
||||
if strings.HasPrefix(limitByPerIp, "from-header-") {
|
||||
headerName := limitByPerIp[len("from-header-"):]
|
||||
if headerName == "" {
|
||||
return errors.New("limit_by_per_ip parse error: empty after 'from-header-'")
|
||||
}
|
||||
ruleItem.limitByPerIp = LimitByPerIp{
|
||||
sourceType: HeaderSourceType,
|
||||
headerName: headerName,
|
||||
}
|
||||
} else if limitByPerIp == "from-remote-addr" {
|
||||
ruleItem.limitByPerIp = LimitByPerIp{
|
||||
sourceType: RemoteAddrSourceType,
|
||||
headerName: "",
|
||||
}
|
||||
} else {
|
||||
return errors.New("the 'limit_by_per_ip' restriction must start with 'from-header-' or be exactly 'from-remote-addr'")
|
||||
}
|
||||
limitType = limitByPerIpType
|
||||
}
|
||||
|
||||
if limitType == "" {
|
||||
return errors.New("only one of 'limit_by_header' and 'limit_by_param' and 'limit_by_consumer' and 'limit_by_cookie' and 'limit_by_per_header' and 'limit_by_per_param' and 'limit_by_per_consumer' and 'limit_by_per_cookie' and 'limit_by_per_ip' can be set")
|
||||
}
|
||||
ruleItem.limitType = limitType
|
||||
|
||||
// 初始化configItems
|
||||
err := initConfigItems(item, &ruleItem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ruleItems = append(ruleItems, ruleItem)
|
||||
}
|
||||
config.ruleItems = ruleItems
|
||||
return nil
|
||||
}
|
||||
|
||||
func initConfigItems(json gjson.Result, rule *LimitRuleItem) error {
|
||||
limitKeys := json.Get("limit_keys")
|
||||
if !limitKeys.Exists() {
|
||||
return errors.New("missing limit_keys in config")
|
||||
}
|
||||
if len(limitKeys.Array()) == 0 {
|
||||
return errors.New("config limit_keys cannot be empty")
|
||||
}
|
||||
var configItems []LimitConfigItem
|
||||
for _, item := range limitKeys.Array() {
|
||||
key := item.Get("key")
|
||||
if !key.Exists() || key.String() == "" {
|
||||
return errors.New("limit_keys key is required")
|
||||
}
|
||||
|
||||
var (
|
||||
itemKey = key.String()
|
||||
itemType limitConfigItemType
|
||||
ipNet *iptree.IPTree
|
||||
regexp *re.Regexp
|
||||
)
|
||||
if rule.limitType == limitByPerIpType {
|
||||
var err error
|
||||
ipNet, err = parseIPNet(itemKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse IPNet for key '%s': %w", itemKey, err)
|
||||
}
|
||||
itemType = ipNetType
|
||||
} else if rule.limitType == limitByPerHeaderType ||
|
||||
rule.limitType == limitByPerParamType ||
|
||||
rule.limitType == limitByPerConsumerType ||
|
||||
rule.limitType == limitByPerCookieType {
|
||||
if itemKey == "*" {
|
||||
itemType = allType
|
||||
} else if strings.HasPrefix(itemKey, "regexp:") {
|
||||
regexpStr := itemKey[len("regexp:"):]
|
||||
var err error
|
||||
regexp, err = re.Compile(regexpStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to compile regex for key '%s': %w", itemKey, err)
|
||||
}
|
||||
itemType = regexpType
|
||||
} else {
|
||||
return fmt.Errorf("the '%s' restriction must start with 'regexp:' or be exactly '*'", rule.limitType)
|
||||
}
|
||||
} else {
|
||||
itemType = exactType
|
||||
}
|
||||
|
||||
if configItem, err := createConfigItemFromRate(item, itemType, itemKey, ipNet, regexp); err != nil {
|
||||
return err
|
||||
} else if configItem != nil {
|
||||
configItems = append(configItems, *configItem)
|
||||
}
|
||||
}
|
||||
rule.configItems = configItems
|
||||
return nil
|
||||
}
|
||||
|
||||
func createConfigItemFromRate(item gjson.Result, itemType limitConfigItemType, key string, ipNet *iptree.IPTree, regexp *re.Regexp) (*LimitConfigItem, error) {
|
||||
for timeWindowKey, duration := range timeWindows {
|
||||
q := item.Get(timeWindowKey)
|
||||
if q.Exists() && q.Int() > 0 {
|
||||
return &LimitConfigItem{
|
||||
configType: itemType,
|
||||
key: key,
|
||||
ipNet: ipNet,
|
||||
regexp: regexp,
|
||||
count: q.Int(),
|
||||
timeWindow: duration,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("one of 'query_per_second', 'query_per_minute', 'query_per_hour', or 'query_per_day' must be set for key: " + key)
|
||||
}
|
||||
24
plugins/wasm-go/extensions/cluster-key-rate-limit/go.mod
Normal file
24
plugins/wasm-go/extensions/cluster-key-rate-limit/go.mod
Normal file
@@ -0,0 +1,24 @@
|
||||
module github.com/alibaba/higress/plugins/wasm-go/extensions/key-cluster-rate-limit
|
||||
|
||||
go 1.19
|
||||
|
||||
replace github.com/alibaba/higress/plugins/wasm-go => ../..
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v0.0.0
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc
|
||||
github.com/tidwall/gjson v1.14.3
|
||||
github.com/tidwall/resp v0.1.1
|
||||
github.com/wasilibs/go-re2 v1.5.3
|
||||
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
|
||||
github.com/magefile/mage v1.14.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.1 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
)
|
||||
28
plugins/wasm-go/extensions/cluster-key-rate-limit/go.sum
Normal file
28
plugins/wasm-go/extensions/cluster-key-rate-limit/go.sum
Normal file
@@ -0,0 +1,28 @@
|
||||
github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 h1:Wi5Tgn8K+jDcBYL+dIMS1+qXYH2r7tpRAyBgqrWfQtw=
|
||||
github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56/go.mod h1:8BhOLuqtSuT5NZtZMwfvEibi09RO3u79uqfHZzfDTR4=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tetratelabs/wazero v1.7.1 h1:QtSfd6KLc41DIMpDYlJdoMc6k7QTN246DM2+n2Y/Dx8=
|
||||
github.com/tetratelabs/wazero v1.7.1/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/wasilibs/go-re2 v1.5.3 h1:wiuTcgDZdLhu8NG8oqF5sF5Q3yIU14lPAvXqeYzDK3g=
|
||||
github.com/wasilibs/go-re2 v1.5.3/go.mod h1:PzpVPsBdFC7vM8QJbbEnOeTmwA0DGE783d/Gex8eCV8=
|
||||
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837 h1:DjHnADS2r2zynZ3WkCFAQ+PNYngMSNceRROi0pO6c3M=
|
||||
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837/go.mod h1:9vp0bxqozzQwcjBwenEXfKVq8+mYbwHkQ1NF9Ap0DMw=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
258
plugins/wasm-go/extensions/cluster-key-rate-limit/main.go
Normal file
258
plugins/wasm-go/extensions/cluster-key-rate-limit/main.go
Normal file
@@ -0,0 +1,258 @@
|
||||
// Copyright (c) 2024 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func main() {
|
||||
wrapper.SetCtx(
|
||||
"cluster-key-rate-limit",
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
||||
)
|
||||
}
|
||||
|
||||
const (
|
||||
ClusterRateLimitFormat string = "higress-cluster-key-rate-limit:%s:%s:%s:%s" // redis key为前缀:限流规则名称:限流类型:限流key名称:限流key对应的实际值
|
||||
FixedWindowScript string = `
|
||||
local ttl = redis.call('ttl', KEYS[1])
|
||||
if ttl < 0 then
|
||||
redis.call('set', KEYS[1], ARGV[1] - 1, 'EX', ARGV[2])
|
||||
return {ARGV[1], ARGV[1] - 1, ARGV[2]}
|
||||
end
|
||||
return {ARGV[1], redis.call('incrby', KEYS[1], -1), ttl}
|
||||
`
|
||||
|
||||
LimitContextKey string = "LimitContext" // 限流上下文信息
|
||||
|
||||
ConsumerHeader string = "x-mse-consumer" // LimitByConsumer从该request header获取consumer的名字
|
||||
CookieHeader string = "cookie"
|
||||
|
||||
RateLimitLimitHeader string = "X-RateLimit-Limit" // 限制的总请求数
|
||||
RateLimitRemainingHeader string = "X-RateLimit-Remaining" // 剩余还可以发送的请求数
|
||||
RateLimitResetHeader string = "X-RateLimit-Reset" // 限流重置时间(触发限流时返回)
|
||||
)
|
||||
|
||||
type LimitContext struct {
|
||||
count int
|
||||
remaining int
|
||||
reset int
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log wrapper.Log) error {
|
||||
err := initRedisClusterClient(json, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = parseClusterKeyRateLimitConfig(json, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, log wrapper.Log) types.Action {
|
||||
// 判断是否命中限流规则
|
||||
val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems, log)
|
||||
if ruleItem == nil || configItem == nil {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
// 构建redis限流key和参数
|
||||
limitKey := fmt.Sprintf(ClusterRateLimitFormat, config.ruleName, ruleItem.limitType, ruleItem.key, val)
|
||||
keys := []interface{}{limitKey}
|
||||
args := []interface{}{configItem.count, configItem.timeWindow}
|
||||
// 执行限流逻辑
|
||||
err := config.redisClient.Eval(FixedWindowScript, 1, keys, args, func(response resp.Value) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
resultArray := response.Array()
|
||||
if len(resultArray) != 3 {
|
||||
log.Errorf("redis response parse error, response: %v", response)
|
||||
return
|
||||
}
|
||||
context := LimitContext{
|
||||
count: resultArray[0].Integer(),
|
||||
remaining: resultArray[1].Integer(),
|
||||
reset: resultArray[2].Integer(),
|
||||
}
|
||||
if context.remaining < 0 {
|
||||
// 触发限流
|
||||
rejected(config, context)
|
||||
} else {
|
||||
ctx.SetContext(LimitContextKey, context)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("redis call failed: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, log wrapper.Log) types.Action {
|
||||
limitContext, ok := ctx.GetContext(LimitContextKey).(LimitContext)
|
||||
if !ok {
|
||||
return types.ActionContinue
|
||||
}
|
||||
if config.showLimitQuotaHeader {
|
||||
_ = proxywasm.ReplaceHttpResponseHeader(RateLimitLimitHeader, strconv.Itoa(limitContext.count))
|
||||
_ = proxywasm.ReplaceHttpResponseHeader(RateLimitRemainingHeader, strconv.Itoa(limitContext.remaining))
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
for _, rule := range ruleItems {
|
||||
val, ruleItem, configItem := hitRateRuleItem(ctx, rule, log)
|
||||
if ruleItem != nil && configItem != nil {
|
||||
return val, ruleItem, configItem
|
||||
}
|
||||
}
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
switch rule.limitType {
|
||||
// 根据HTTP请求头限流
|
||||
case limitByHeaderType, limitByPerHeaderType:
|
||||
val, err := proxywasm.GetHttpRequestHeader(rule.key)
|
||||
if err != nil {
|
||||
return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", rule.key, err)
|
||||
}
|
||||
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
|
||||
// 根据HTTP请求参数限流
|
||||
case limitByParamType, limitByPerParamType:
|
||||
parse, err := url.Parse(ctx.Path())
|
||||
if err != nil {
|
||||
return logDebugAndReturnEmpty(log, "failed to parse request path: %v", err)
|
||||
}
|
||||
query, err := url.ParseQuery(parse.RawQuery)
|
||||
if err != nil {
|
||||
return logDebugAndReturnEmpty(log, "failed to parse query params: %v", err)
|
||||
}
|
||||
val, ok := query[rule.key]
|
||||
if !ok {
|
||||
return logDebugAndReturnEmpty(log, "request param %s is empty", rule.key)
|
||||
}
|
||||
return val[0], &rule, findMatchingItem(rule.limitType, rule.configItems, val[0])
|
||||
// 根据consumer限流
|
||||
case limitByConsumerType, limitByPerConsumerType:
|
||||
val, err := proxywasm.GetHttpRequestHeader(ConsumerHeader)
|
||||
if err != nil {
|
||||
return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", ConsumerHeader, err)
|
||||
}
|
||||
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
|
||||
// 根据cookie中key值限流
|
||||
case limitByCookieType, limitByPerCookieType:
|
||||
cookie, err := proxywasm.GetHttpRequestHeader(CookieHeader)
|
||||
if err != nil {
|
||||
return logDebugAndReturnEmpty(log, "failed to get request cookie : %v", err)
|
||||
}
|
||||
val := extractCookieValueByKey(cookie, rule.key)
|
||||
if val == "" {
|
||||
return logDebugAndReturnEmpty(log, "cookie key '%s' extracted from cookie '%s' is empty.", rule.key, cookie)
|
||||
}
|
||||
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
|
||||
// 根据客户端IP限流
|
||||
case limitByPerIpType:
|
||||
realIp, err := getDownStreamIp(rule)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get down stream ip: %v", err)
|
||||
return "", &rule, nil
|
||||
}
|
||||
for _, item := range rule.configItems {
|
||||
if _, found, _ := item.ipNet.Get(realIp); !found {
|
||||
continue
|
||||
}
|
||||
return realIp.String(), &rule, &item
|
||||
}
|
||||
}
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
func logDebugAndReturnEmpty(log wrapper.Log, errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
log.Debugf(errMsg, args...)
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
func findMatchingItem(limitType limitRuleItemType, items []LimitConfigItem, key string) *LimitConfigItem {
|
||||
for _, item := range items {
|
||||
// per类型,检查allType和regexpType
|
||||
if limitType == limitByPerHeaderType ||
|
||||
limitType == limitByPerParamType ||
|
||||
limitType == limitByPerConsumerType ||
|
||||
limitType == limitByPerCookieType {
|
||||
if item.configType == allType || (item.configType == regexpType && item.regexp.MatchString(key)) {
|
||||
return &item
|
||||
}
|
||||
}
|
||||
// 其他类型,直接比较key
|
||||
if item.key == key {
|
||||
return &item
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getDownStreamIp(rule LimitRuleItem) (net.IP, error) {
|
||||
var (
|
||||
realIpStr string
|
||||
err error
|
||||
)
|
||||
if rule.limitByPerIp.sourceType == HeaderSourceType {
|
||||
realIpStr, err = proxywasm.GetHttpRequestHeader(rule.limitByPerIp.headerName)
|
||||
if err == nil {
|
||||
realIpStr = strings.Split(strings.Trim(realIpStr, " "), ",")[0]
|
||||
}
|
||||
} else {
|
||||
var bs []byte
|
||||
bs, err = proxywasm.GetProperty([]string{"source", "address"})
|
||||
realIpStr = string(bs)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ip := parseIP(realIpStr)
|
||||
realIP := net.ParseIP(ip)
|
||||
if realIP == nil {
|
||||
return nil, fmt.Errorf("invalid ip[%s]", ip)
|
||||
}
|
||||
return realIP, nil
|
||||
}
|
||||
|
||||
func rejected(config ClusterKeyRateLimitConfig, context LimitContext) {
|
||||
headers := make(map[string][]string)
|
||||
headers[RateLimitResetHeader] = []string{strconv.Itoa(context.reset)}
|
||||
if config.showLimitQuotaHeader {
|
||||
headers[RateLimitLimitHeader] = []string{strconv.Itoa(context.count)}
|
||||
headers[RateLimitRemainingHeader] = []string{strconv.Itoa(0)}
|
||||
}
|
||||
_ = proxywasm.SendHttpResponse(
|
||||
config.rejectedCode, reconvertHeaders(headers), []byte(config.rejectedMsg), -1)
|
||||
}
|
||||
59
plugins/wasm-go/extensions/cluster-key-rate-limit/utils.go
Normal file
59
plugins/wasm-go/extensions/cluster-key-rate-limit/utils.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/zmap/go-iptree/iptree"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// parseIPNet 解析Ip段配置
|
||||
func parseIPNet(key string) (*iptree.IPTree, error) {
|
||||
tree := iptree.New()
|
||||
err := tree.AddByString(key, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid IP[%s]", key)
|
||||
}
|
||||
return tree, nil
|
||||
}
|
||||
|
||||
// parseIP 解析IP
|
||||
func parseIP(source string) string {
|
||||
if strings.Contains(source, ".") {
|
||||
// parse ipv4
|
||||
return strings.Split(source, ":")[0]
|
||||
}
|
||||
// parse ipv6
|
||||
if strings.Contains(source, "]") {
|
||||
return strings.Split(source, "]")[0][1:]
|
||||
}
|
||||
return source
|
||||
}
|
||||
|
||||
// reconvertHeaders headers: map[string][]string -> [][2]string
|
||||
func reconvertHeaders(hs map[string][]string) [][2]string {
|
||||
var ret [][2]string
|
||||
for k, vs := range hs {
|
||||
for _, v := range vs {
|
||||
ret = append(ret, [2]string{k, v})
|
||||
}
|
||||
}
|
||||
sort.SliceStable(ret, func(i, j int) bool {
|
||||
return ret[i][0] < ret[j][0]
|
||||
})
|
||||
return ret
|
||||
}
|
||||
|
||||
// extractCookieValueByKey 从cookie中提取key对应的value
|
||||
func extractCookieValueByKey(cookie string, key string) (value string) {
|
||||
pairs := strings.Split(cookie, ";")
|
||||
for _, pair := range pairs {
|
||||
pair = strings.TrimSpace(pair)
|
||||
kv := strings.Split(pair, "=")
|
||||
if kv[0] == key {
|
||||
value = kv[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
2
plugins/wasm-go/extensions/jwt-auth/Dockerfile
Normal file
2
plugins/wasm-go/extensions/jwt-auth/Dockerfile
Normal file
@@ -0,0 +1,2 @@
|
||||
FROM scratch
|
||||
COPY main.wasm plugin.wasm
|
||||
5
plugins/wasm-go/extensions/jwt-auth/Makefile
Normal file
5
plugins/wasm-go/extensions/jwt-auth/Makefile
Normal file
@@ -0,0 +1,5 @@
|
||||
build:
|
||||
go mod tidy
|
||||
tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags="custommalloc nottinygc_finalizer"
|
||||
|
||||
default: build
|
||||
1
plugins/wasm-go/extensions/jwt-auth/VERSION
Normal file
1
plugins/wasm-go/extensions/jwt-auth/VERSION
Normal file
@@ -0,0 +1 @@
|
||||
0.1.0
|
||||
34
plugins/wasm-go/extensions/jwt-auth/config/checker.go
Normal file
34
plugins/wasm-go/extensions/jwt-auth/config/checker.go
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright (c) 2023 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package config
|
||||
|
||||
type GlobalAuthType int
|
||||
|
||||
const (
|
||||
GlobalAuthTrue GlobalAuthType = 10000 + iota
|
||||
GlobalAuthFalse
|
||||
GlobalAuthNoSet
|
||||
)
|
||||
|
||||
func (c *JWTAuthConfig) GlobalAuthCheck() GlobalAuthType {
|
||||
if c.GlobalAuth == nil {
|
||||
return GlobalAuthNoSet
|
||||
}
|
||||
|
||||
if *c.GlobalAuth {
|
||||
return GlobalAuthTrue
|
||||
}
|
||||
return GlobalAuthFalse
|
||||
}
|
||||
125
plugins/wasm-go/extensions/jwt-auth/config/config.go
Normal file
125
plugins/wasm-go/extensions/jwt-auth/config/config.go
Normal file
@@ -0,0 +1,125 @@
|
||||
// Copyright (c) 2023 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package config
|
||||
|
||||
var (
|
||||
// DefaultClaimToHeaderOverride 是 claim_to_override 中 override 字段的默认值
|
||||
DefaultClaimToHeaderOverride = true
|
||||
|
||||
// DefaultClockSkewSeconds 是 ClockSkewSeconds 的默认值
|
||||
DefaultClockSkewSeconds = int64(60)
|
||||
|
||||
// DefaultKeepToken 是 KeepToken 的默认值
|
||||
DefaultKeepToken = true
|
||||
|
||||
// DefaultFromHeader 是 from_header 的默认值
|
||||
DefaultFromHeader = []FromHeader{{
|
||||
Name: "Authorization",
|
||||
ValuePrefix: "Bearer ",
|
||||
}}
|
||||
|
||||
// DefaultFromParams 是 from_params 的默认值
|
||||
DefaultFromParams = []string{"access_token"}
|
||||
|
||||
// DefaultFromCookies 是 from_cookies 的默认值
|
||||
DefaultFromCookies = []string{}
|
||||
)
|
||||
|
||||
// JWTAuthConfig defines the struct of the global config of higress wasm plugin jwt-auth.
|
||||
// https://higress.io/zh-cn/docs/plugins/jwt-auth
|
||||
type JWTAuthConfig struct {
|
||||
// 全局配置
|
||||
//
|
||||
// Consumers 配置服务的调用者,用于对请求进行认证
|
||||
Consumers []*Consumer `json:"consumers"`
|
||||
|
||||
// 全局配置
|
||||
//
|
||||
// GlobalAuth 若配置为true,则全局生效认证机制;
|
||||
// 若配置为false,则只对做了配置的域名和路由生效认证机制;
|
||||
// 若不配置则仅当没有域名和路由配置时全局生效(兼容机制)
|
||||
GlobalAuth *bool `json:"global_auth,omitempty"`
|
||||
|
||||
// 域名和路由级配置
|
||||
//
|
||||
// Allow 对于符合匹配条件的请求,配置允许访问的consumer名称
|
||||
Allow []string `json:"allow"`
|
||||
}
|
||||
|
||||
// Consumer 配置服务的调用者,用于对请求进行认证
|
||||
type Consumer struct {
|
||||
// Name 配置该consumer的名称
|
||||
Name string `json:"name"`
|
||||
|
||||
// JWKs 指定的json格式字符串,是由验证JWT中签名的公钥(或对称密钥)组成的Json Web Key Set
|
||||
//
|
||||
// https://www.rfc-editor.org/rfc/rfc7517
|
||||
JWKs string `json:"jwks"`
|
||||
|
||||
// Issuer JWT的签发者,需要和payload中的iss字段保持一致
|
||||
Issuer string `json:"issuer"`
|
||||
|
||||
// ClaimsToHeaders 抽取JWT的payload中指定字段,设置到指定的请求头中转发给后端
|
||||
ClaimsToHeaders *[]ClaimsToHeader `json:"claims_to_headers,omitempty"`
|
||||
|
||||
// FromHeaders 从指定的请求头中抽取JWT
|
||||
//
|
||||
// 默认值为 [{"name":"Authorization","value_prefix":"Bearer "}]
|
||||
//
|
||||
// 只有当from_headers,from_params,from_cookies均未配置时,才会使用默认值
|
||||
FromHeaders *[]FromHeader `json:"from_headers,omitempty"`
|
||||
|
||||
// FromParams 从指定的URL参数中抽取JWT
|
||||
//
|
||||
// 默认值为 access_token
|
||||
//
|
||||
// 只有当from_headers,from_params,from_cookies均未配置时,才会使用默认值
|
||||
FromParams *[]string `json:"from_params,omitempty"`
|
||||
|
||||
// FromCookies 从指定的cookie中抽取JWT
|
||||
FromCookies *[]string `json:"from_cookies,omitempty"`
|
||||
|
||||
// ClockSkewSeconds 校验JWT的exp和iat字段时允许的时钟偏移量,单位为秒
|
||||
//
|
||||
// 默认值为 60
|
||||
ClockSkewSeconds *int64 `json:"clock_skew_seconds,omitempty"`
|
||||
|
||||
// KeepToken 转发给后端时是否保留JWT
|
||||
//
|
||||
// 默认值为 true
|
||||
KeepToken *bool `json:"keep_token,omitempty"`
|
||||
}
|
||||
|
||||
// ClaimsToHeader 抽取JWT的payload中指定字段,设置到指定的请求头中转发给后端
|
||||
type ClaimsToHeader struct {
|
||||
// Claim JWT payload中的指定字段,要求必须是字符串或无符号整数类型
|
||||
Claim string `json:"claim"`
|
||||
|
||||
// Header 从payload取出字段的值设置到这个请求头中,转发给后端
|
||||
Header string `json:"header"`
|
||||
|
||||
// Override true时,存在同名请求头会进行覆盖;false时,追加同名请求头
|
||||
//
|
||||
// 默认值为 true
|
||||
Override *bool `json:"override,omitempty"`
|
||||
}
|
||||
|
||||
// FromHeader 从指定的请求头中抽取JWT
|
||||
type FromHeader struct {
|
||||
// Name 抽取JWT的请求header
|
||||
Name string `json:"name"`
|
||||
// ValuePrefix 对请求header的value去除此前缀,剩余部分作为JWT
|
||||
ValuePrefix string `json:"value_prefix"`
|
||||
}
|
||||
138
plugins/wasm-go/extensions/jwt-auth/config/parser.go
Normal file
138
plugins/wasm-go/extensions/jwt-auth/config/parser.go
Normal file
@@ -0,0 +1,138 @@
|
||||
// Copyright (c) 2023 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// RuleSet 插件是否至少在一个 domain 或 route 上生效
|
||||
var RuleSet bool
|
||||
|
||||
// ParseGlobalConfig 从wrapper提供的配置中解析并转换到插件运行时需要使用的配置。
|
||||
// 此处解析的是全局配置,域名和路由级配置由 ParseRuleConfig 负责。
|
||||
func ParseGlobalConfig(json gjson.Result, config *JWTAuthConfig, log wrapper.Log) error {
|
||||
RuleSet = false
|
||||
consumers := json.Get("consumers")
|
||||
if !consumers.IsArray() {
|
||||
return fmt.Errorf("failed to parse configuration for consumers: consumers is not a array")
|
||||
}
|
||||
|
||||
consumerNames := map[string]struct{}{}
|
||||
for _, v := range consumers.Array() {
|
||||
c, err := ParseConsumer(v, consumerNames)
|
||||
if err != nil {
|
||||
log.Warn(err.Error())
|
||||
continue
|
||||
}
|
||||
config.Consumers = append(config.Consumers, c)
|
||||
}
|
||||
if len(config.Consumers) == 0 {
|
||||
return fmt.Errorf("at least one consumer should be configured for a rule")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseRuleConfig 从wrapper提供的配置中解析并转换到插件运行时需要使用的配置。
|
||||
// 此处解析的是域名和路由级配置,全局配置由 ParseConfig 负责。
|
||||
func ParseRuleConfig(json gjson.Result, global JWTAuthConfig, config *JWTAuthConfig, log wrapper.Log) error {
|
||||
// override config via global
|
||||
*config = global
|
||||
|
||||
allow := json.Get("allow")
|
||||
if !allow.Exists() {
|
||||
return fmt.Errorf("allow is required")
|
||||
}
|
||||
|
||||
if len(allow.Array()) == 0 {
|
||||
return fmt.Errorf("allow cannot be empty")
|
||||
}
|
||||
|
||||
for _, item := range allow.Array() {
|
||||
config.Allow = append(config.Allow, item.String())
|
||||
}
|
||||
|
||||
RuleSet = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParseConsumer(consumer gjson.Result, names map[string]struct{}) (c *Consumer, err error) {
|
||||
c = &Consumer{}
|
||||
|
||||
// 从gjson中取得原始JSON字符串,并使用标准库反序列化,以降低代码复杂度。
|
||||
err = json.Unmarshal([]byte(consumer.Raw), c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse consumer: %s", err.Error())
|
||||
}
|
||||
|
||||
// 检查consumer是否重复
|
||||
if _, ok := names[c.Name]; ok {
|
||||
return nil, fmt.Errorf("consumer already exists: %s", c.Name)
|
||||
}
|
||||
|
||||
// 检查JWKs是否合法
|
||||
jwks := &jose.JSONWebKeySet{}
|
||||
err = json.Unmarshal([]byte(c.JWKs), jwks)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("jwks is invalid, consumer:%s, status:%s, jwks:%s", c.Name, err.Error(), c.JWKs)
|
||||
}
|
||||
|
||||
// 检查是否需要使用默认jwt抽取来源
|
||||
if c.FromHeaders == nil && c.FromParams == nil && c.FromCookies == nil {
|
||||
c.FromHeaders = &DefaultFromHeader
|
||||
c.FromParams = &DefaultFromParams
|
||||
c.FromCookies = &DefaultFromCookies
|
||||
}
|
||||
|
||||
// 检查ClaimsToHeaders
|
||||
if c.ClaimsToHeaders != nil {
|
||||
// header去重
|
||||
c2h := map[string]struct{}{}
|
||||
|
||||
// 此处需要先把指针解引用到临时变量
|
||||
tmp := *c.ClaimsToHeaders
|
||||
for i := range tmp {
|
||||
if _, ok := c2h[tmp[i].Header]; ok {
|
||||
return nil, fmt.Errorf("claim to header already exists: %s", c2h[tmp[i].Header])
|
||||
}
|
||||
c2h[tmp[i].Header] = struct{}{}
|
||||
|
||||
// 为Override填充默认值
|
||||
if tmp[i].Override == nil {
|
||||
tmp[i].Override = &DefaultClaimToHeaderOverride
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 为ClockSkewSeconds填充默认值
|
||||
if c.ClockSkewSeconds == nil {
|
||||
c.ClockSkewSeconds = &DefaultClockSkewSeconds
|
||||
}
|
||||
|
||||
// 为KeepToken填充默认值
|
||||
if c.KeepToken == nil {
|
||||
c.KeepToken = &DefaultKeepToken
|
||||
}
|
||||
|
||||
// consumer合法,记录consumer名称
|
||||
names[c.Name] = struct{}{}
|
||||
return c, nil
|
||||
}
|
||||
22
plugins/wasm-go/extensions/jwt-auth/go.mod
Normal file
22
plugins/wasm-go/extensions/jwt-auth/go.mod
Normal file
@@ -0,0 +1,22 @@
|
||||
module github.com/alibaba/higress/plugins/wasm-go/extensions/jwt-auth
|
||||
|
||||
go 1.19
|
||||
|
||||
replace github.com/alibaba/higress/plugins/wasm-go => ../..
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5
|
||||
github.com/go-jose/go-jose/v3 v3.0.3
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc
|
||||
github.com/tidwall/gjson v1.17.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
|
||||
github.com/magefile/mage v1.15.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
golang.org/x/crypto v0.23.0 // indirect
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user