mirror of
https://github.com/alibaba/higress.git
synced 2026-03-10 03:30:48 +08:00
Compare commits
24 Commits
update-hel
...
f2fcd68ef8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f2fcd68ef8 | ||
|
|
cbcc3ecf43 | ||
|
|
a92c89ce61 | ||
|
|
819f773297 | ||
|
|
255f0bde76 | ||
|
|
a2eb599eff | ||
|
|
3a28a9b6a7 | ||
|
|
399d2f372e | ||
|
|
ac69eb5b27 | ||
|
|
9d8a1c2e95 | ||
|
|
fb71d7b33d | ||
|
|
eb7b22d2b9 | ||
|
|
f1a5f18c78 | ||
|
|
e7010256fe | ||
|
|
5e787b3258 | ||
|
|
23fbe0e9e9 | ||
|
|
72c87b3e15 | ||
|
|
78d4b33424 | ||
|
|
b09793c3d4 | ||
|
|
5d7a30783f | ||
|
|
b98b51ef06 | ||
|
|
9c11c5406f | ||
|
|
10ca6d9515 | ||
|
|
08a7204085 |
13
README.md
13
README.md
@@ -45,7 +45,7 @@ Higress was born within Alibaba to solve the issues of Tengine reload affecting
|
||||
|
||||
You can click the button below to install the enterprise version of Higress:
|
||||
|
||||
[](https://www.aliyun.com/product/apigateway?spm=higress-github.topbar.0.0.0)
|
||||
[](https://www.aliyun.com/product/api-gateway?spm=higress-github.topbar.0.0.0)
|
||||
|
||||
|
||||
If you use open-source Higress and wish to obtain enterprise-level support, you can contact the project maintainer johnlanni's email: **zty98751@alibaba-inc.com** or social media accounts (WeChat ID: **nomadao**, DingTalk ID: **chengtanzty**). Please note **Higress** when adding as a friend :)
|
||||
@@ -119,7 +119,16 @@ If you are deploying on the cloud, it is recommended to use the [Enterprise Edit
|
||||
|
||||
Higress can function as a feature-rich ingress controller, which is compatible with many annotations of K8s' nginx ingress controller.
|
||||
|
||||
[Gateway API](https://gateway-api.sigs.k8s.io/) support is coming soon and will support smooth migration from Ingress API to Gateway API.
|
||||
[Gateway API](https://gateway-api.sigs.k8s.io/) is already supported, and it supports a smooth migration from Ingress API to Gateway API.
|
||||
|
||||
Compared to ingress-nginx, the resource overhead has significantly decreased, and the speed at which route changes take effect has improved by ten times.
|
||||
|
||||
> The following resource overhead comparison comes from [sealos](https://github.com/labring).
|
||||
>
|
||||
> For details, you can read this [article](https://sealos.io/blog/sealos-envoy-vs-nginx-2000-tenants) to understand how sealos migrates the monitoring of **tens of thousands of ingress** resources from nginx ingress to higress.
|
||||
|
||||

|
||||
|
||||
|
||||
- **Microservice gateway**:
|
||||
|
||||
|
||||
@@ -250,6 +250,10 @@ template:
|
||||
tolerations:
|
||||
{{- toYaml . | nindent 6 }}
|
||||
{{- end }}
|
||||
{{- with .Values.gateway.topologySpreadConstraints }}
|
||||
topologySpreadConstraints:
|
||||
{{- toYaml . | nindent 6 }}
|
||||
{{- end }}
|
||||
volumes:
|
||||
- emptyDir: {}
|
||||
name: workload-socket
|
||||
|
||||
@@ -301,6 +301,10 @@ spec:
|
||||
tolerations:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.controller.topologySpreadConstraints }}
|
||||
topologySpreadConstraints:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
volumes:
|
||||
- name: log
|
||||
emptyDir: {}
|
||||
|
||||
@@ -24,9 +24,6 @@ spec:
|
||||
{{- end }}
|
||||
{{- with .Values.gateway.service.externalTrafficPolicy }}
|
||||
externalTrafficPolicy: "{{ . }}"
|
||||
{{- end }}
|
||||
{{- with .Values.gateway.service.loadBalancerClass}}
|
||||
loadBalancerClass: "{{ . }}"
|
||||
{{- end }}
|
||||
type: {{ .Values.gateway.service.type }}
|
||||
ports:
|
||||
|
||||
@@ -524,6 +524,8 @@ gateway:
|
||||
|
||||
affinity: {}
|
||||
|
||||
topologySpreadConstraints: []
|
||||
|
||||
# -- If specified, the gateway will act as a network gateway for the given network.
|
||||
networkGateway: ""
|
||||
|
||||
@@ -631,6 +633,8 @@ controller:
|
||||
|
||||
affinity: {}
|
||||
|
||||
topologySpreadConstraints: []
|
||||
|
||||
autoscaling:
|
||||
enabled: false
|
||||
minReplicas: 1
|
||||
|
||||
@@ -83,6 +83,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| controller.serviceAccount.name | string | `""` | If not set and create is true, a name is generated using the fullname template |
|
||||
| controller.tag | string | `""` | |
|
||||
| controller.tolerations | list | `[]` | |
|
||||
| controller.topologySpreadConstraints | list | `[]` | |
|
||||
| downstream | object | `{"connectionBufferLimits":32768,"http2":{"initialConnectionWindowSize":1048576,"initialStreamWindowSize":65535,"maxConcurrentStreams":100},"idleTimeout":180,"maxRequestHeadersKb":60,"routeTimeout":0}` | Downstream config settings |
|
||||
| gateway.affinity | object | `{}` | |
|
||||
| gateway.annotations | object | `{}` | Annotations to apply to all resources |
|
||||
@@ -152,6 +153,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| gateway.serviceAccount.name | string | `""` | The name of the service account to use. If not set, the release name is used |
|
||||
| gateway.tag | string | `""` | |
|
||||
| gateway.tolerations | list | `[]` | |
|
||||
| gateway.topologySpreadConstraints | list | `[]` | |
|
||||
| gateway.unprivilegedPortSupported | string | `nil` | |
|
||||
| global.autoscalingv2API | bool | `true` | whether to use autoscaling/v2 template for HPA settings for internal usage only, not to be configured by users. |
|
||||
| global.caAddress | string | `""` | The customized CA address to retrieve certificates for the pods in the cluster. CSR clients such as the Istio Agent and ingress gateways can use this to specify the CA endpoint. If not set explicitly, default to the Istio discovery address. |
|
||||
|
||||
Submodule istio/istio updated: 3d7792ae28...c4703274ca
@@ -16,12 +16,13 @@ package bootstrap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"istio.io/istio/pkg/config/mesh/meshwatcher"
|
||||
"istio.io/istio/pkg/kube/krt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"istio.io/istio/pkg/config/mesh/meshwatcher"
|
||||
"istio.io/istio/pkg/kube/krt"
|
||||
|
||||
prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/reflection"
|
||||
@@ -436,10 +437,17 @@ func (s *Server) initHttpServer() error {
|
||||
}
|
||||
s.xdsServer.AddDebugHandlers(s.httpMux, nil, true, nil)
|
||||
s.httpMux.HandleFunc("/ready", s.readyHandler)
|
||||
s.httpMux.HandleFunc("/registry/watcherStatus", s.registryWatcherStatusHandler)
|
||||
s.httpMux.HandleFunc("/registry/watcherStatus", s.withConditionalAuth(s.registryWatcherStatusHandler))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) withConditionalAuth(handler http.HandlerFunc) http.HandlerFunc {
|
||||
if features.DebugAuth {
|
||||
return s.xdsServer.AllowAuthenticatedOrLocalhost(handler)
|
||||
}
|
||||
return handler
|
||||
}
|
||||
|
||||
// readyHandler checks whether the http server is ready
|
||||
func (s *Server) readyHandler(w http.ResponseWriter, _ *http.Request) {
|
||||
for name, fn := range s.readinessProbes {
|
||||
|
||||
@@ -26,8 +26,8 @@ type config struct {
|
||||
matchList []common.MatchRule
|
||||
enableUserLevelServer bool
|
||||
rateLimitConfig *handler.MCPRatelimitConfig
|
||||
defaultServer *common.SSEServer
|
||||
redisClient *common.RedisClient
|
||||
sharedMCPServer *common.MCPServer // Created once, thread-safe with sync.RWMutex
|
||||
}
|
||||
|
||||
func (c *config) Destroy() {
|
||||
@@ -110,6 +110,9 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
||||
}
|
||||
GlobalSSEPathSuffix = ssePathSuffix
|
||||
|
||||
// Create shared MCPServer once during config parsing (thread-safe with sync.RWMutex)
|
||||
conf.sharedMCPServer = common.NewMCPServer(DefaultServerName, Version)
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
@@ -125,9 +128,6 @@ func (p *Parser) Merge(parent interface{}, child interface{}) interface{} {
|
||||
if childConfig.rateLimitConfig != nil {
|
||||
newConfig.rateLimitConfig = childConfig.rateLimitConfig
|
||||
}
|
||||
if childConfig.defaultServer != nil {
|
||||
newConfig.defaultServer = childConfig.defaultServer
|
||||
}
|
||||
return &newConfig
|
||||
}
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ type filter struct {
|
||||
skipRequestBody bool
|
||||
skipResponseBody bool
|
||||
cachedResponseBody []byte
|
||||
sseServer *common.SSEServer // SSE server instance for this filter (per-request, not shared)
|
||||
|
||||
userLevelConfig bool
|
||||
mcpConfigHandler *handler.MCPConfigHandler
|
||||
@@ -135,11 +136,13 @@ func (f *filter) processMcpRequestHeadersForRestUpstream(header api.RequestHeade
|
||||
trimmed += "?" + rq
|
||||
}
|
||||
|
||||
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
|
||||
// Create SSE server instance for this filter (per-request, not shared)
|
||||
// MCPServer is shared (thread-safe), but SSEServer must be per-request (contains request-specific messageEndpoint)
|
||||
f.sseServer = common.NewSSEServer(f.config.sharedMCPServer,
|
||||
common.WithSSEEndpoint(GlobalSSEPathSuffix),
|
||||
common.WithMessageEndpoint(trimmed),
|
||||
common.WithRedisClient(f.config.redisClient))
|
||||
f.serverName = f.config.defaultServer.GetServerName()
|
||||
f.serverName = f.sseServer.GetServerName()
|
||||
body := "SSE connection create"
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
|
||||
}
|
||||
@@ -275,9 +278,9 @@ func (f *filter) encodeDataFromRestUpstream(buffer api.BufferInstance, endStream
|
||||
|
||||
if f.serverName != "" {
|
||||
if f.config.redisClient != nil {
|
||||
// handle default server
|
||||
// handle SSE server for this filter instance
|
||||
buffer.Reset()
|
||||
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
|
||||
f.sseServer.HandleSSE(f.callbacks, f.stopChan)
|
||||
return api.Running
|
||||
} else {
|
||||
_ = buffer.SetString(RedisNotEnabledResponseBody)
|
||||
|
||||
@@ -16,40 +16,91 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
|
||||
RedisLua = `local seed = KEYS[1]
|
||||
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
|
||||
RedisLastCleanKeyFormat = "higress:global_least_request_table:last_clean_time:%s:%s"
|
||||
RedisLua = `local seed = tonumber(KEYS[1])
|
||||
local hset_key = KEYS[2]
|
||||
local current_target = KEYS[3]
|
||||
local current_count = 0
|
||||
local last_clean_key = KEYS[3]
|
||||
local clean_interval = tonumber(KEYS[4])
|
||||
local current_target = KEYS[5]
|
||||
local healthy_count = tonumber(KEYS[6])
|
||||
local enable_detail_log = KEYS[7]
|
||||
|
||||
math.randomseed(seed)
|
||||
|
||||
local function randomBool()
|
||||
return math.random() >= 0.5
|
||||
end
|
||||
-- 1. Selection
|
||||
local current_count = 0
|
||||
local same_count_hits = 0
|
||||
|
||||
if redis.call('HEXISTS', hset_key, current_target) == 1 then
|
||||
current_count = redis.call('HGET', hset_key, current_target)
|
||||
for i = 4, #KEYS do
|
||||
if redis.call('HEXISTS', hset_key, KEYS[i]) == 1 then
|
||||
local count = redis.call('HGET', hset_key, KEYS[i])
|
||||
if tonumber(count) < tonumber(current_count) then
|
||||
current_target = KEYS[i]
|
||||
current_count = count
|
||||
elseif count == current_count and randomBool() then
|
||||
current_target = KEYS[i]
|
||||
end
|
||||
end
|
||||
end
|
||||
for i = 8, 8 + healthy_count - 1 do
|
||||
local host = KEYS[i]
|
||||
local count = 0
|
||||
local val = redis.call('HGET', hset_key, host)
|
||||
if val then
|
||||
count = tonumber(val) or 0
|
||||
end
|
||||
|
||||
if same_count_hits == 0 or count < current_count then
|
||||
current_target = host
|
||||
current_count = count
|
||||
same_count_hits = 1
|
||||
elseif count == current_count then
|
||||
same_count_hits = same_count_hits + 1
|
||||
if math.random(same_count_hits) == 1 then
|
||||
current_target = host
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
redis.call("HINCRBY", hset_key, current_target, 1)
|
||||
local new_count = redis.call("HGET", hset_key, current_target)
|
||||
|
||||
return current_target`
|
||||
-- Collect host counts for logging
|
||||
local host_details = {}
|
||||
if enable_detail_log == "1" then
|
||||
local fields = {}
|
||||
for i = 8, #KEYS do
|
||||
table.insert(fields, KEYS[i])
|
||||
end
|
||||
if #fields > 0 then
|
||||
local values = redis.call('HMGET', hset_key, (table.unpack or unpack)(fields))
|
||||
for i, val in ipairs(values) do
|
||||
table.insert(host_details, fields[i])
|
||||
table.insert(host_details, tostring(val or 0))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- 2. Cleanup
|
||||
local current_time = math.floor(seed / 1000000)
|
||||
local last_clean_time = tonumber(redis.call('GET', last_clean_key) or 0)
|
||||
|
||||
if current_time - last_clean_time >= clean_interval then
|
||||
local all_keys = redis.call('HKEYS', hset_key)
|
||||
if #all_keys > 0 then
|
||||
-- Create a lookup table for current hosts (from index 8 onwards)
|
||||
local current_hosts = {}
|
||||
for i = 8, #KEYS do
|
||||
current_hosts[KEYS[i]] = true
|
||||
end
|
||||
-- Remove keys not in current hosts
|
||||
for _, host in ipairs(all_keys) do
|
||||
if not current_hosts[host] then
|
||||
redis.call('HDEL', hset_key, host)
|
||||
end
|
||||
end
|
||||
end
|
||||
redis.call('SET', last_clean_key, current_time)
|
||||
end
|
||||
|
||||
return {current_target, new_count, host_details}`
|
||||
)
|
||||
|
||||
type GlobalLeastRequestLoadBalancer struct {
|
||||
redisClient wrapper.RedisClient
|
||||
redisClient wrapper.RedisClient
|
||||
maxRequestCount int64
|
||||
cleanInterval int64 // seconds
|
||||
enableDetailLog bool
|
||||
}
|
||||
|
||||
func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoadBalancer, error) {
|
||||
@@ -72,6 +123,18 @@ func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoa
|
||||
}
|
||||
// database default is 0
|
||||
database := json.Get("database").Int()
|
||||
lb.maxRequestCount = json.Get("maxRequestCount").Int()
|
||||
lb.cleanInterval = json.Get("cleanInterval").Int()
|
||||
if lb.cleanInterval == 0 {
|
||||
lb.cleanInterval = 60 * 60 // default 60 minutes
|
||||
} else {
|
||||
lb.cleanInterval = lb.cleanInterval * 60 // convert minutes to seconds
|
||||
}
|
||||
lb.enableDetailLog = true
|
||||
if val := json.Get("enableDetailLog"); val.Exists() {
|
||||
lb.enableDetailLog = val.Bool()
|
||||
}
|
||||
log.Infof("redis client init, serviceFQDN: %s, servicePort: %d, timeout: %d, database: %d, maxRequestCount: %d, cleanInterval: %d minutes, enableDetailLog: %v", serviceFQDN, servicePort, timeout, database, lb.maxRequestCount, lb.cleanInterval/60, lb.enableDetailLog)
|
||||
return lb, lb.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(int(database)))
|
||||
}
|
||||
|
||||
@@ -100,9 +163,11 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
}
|
||||
allHostMap := make(map[string]struct{})
|
||||
// Only healthy host can be selected
|
||||
healthyHostArray := []string{}
|
||||
for _, hostInfo := range hostInfos {
|
||||
allHostMap[hostInfo[0]] = struct{}{}
|
||||
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
|
||||
healthyHostArray = append(healthyHostArray, hostInfo[0])
|
||||
}
|
||||
@@ -113,10 +178,37 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
|
||||
}
|
||||
randomIndex := rand.Intn(len(healthyHostArray))
|
||||
hostSelected := healthyHostArray[randomIndex]
|
||||
keys := []interface{}{time.Now().UnixMicro(), fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected}
|
||||
|
||||
// KEYS structure: [seed, hset_key, last_clean_key, clean_interval, host_selected, healthy_count, ...healthy_hosts, enableDetailLog, ...unhealthy_hosts]
|
||||
keys := []interface{}{
|
||||
time.Now().UnixMicro(),
|
||||
fmt.Sprintf(RedisKeyFormat, routeName, clusterName),
|
||||
fmt.Sprintf(RedisLastCleanKeyFormat, routeName, clusterName),
|
||||
lb.cleanInterval,
|
||||
hostSelected,
|
||||
len(healthyHostArray),
|
||||
"0",
|
||||
}
|
||||
if lb.enableDetailLog {
|
||||
keys[6] = "1"
|
||||
}
|
||||
for _, v := range healthyHostArray {
|
||||
keys = append(keys, v)
|
||||
}
|
||||
// Append unhealthy hosts (those in allHostMap but not in healthyHostArray)
|
||||
for host := range allHostMap {
|
||||
isHealthy := false
|
||||
for _, hh := range healthyHostArray {
|
||||
if host == hh {
|
||||
isHealthy = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isHealthy {
|
||||
keys = append(keys, host)
|
||||
}
|
||||
}
|
||||
|
||||
err = lb.redisClient.Eval(RedisLua, len(keys), keys, []interface{}{}, func(response resp.Value) {
|
||||
if err := response.Error(); err != nil {
|
||||
log.Errorf("HGetAll failed: %+v", err)
|
||||
@@ -124,17 +216,54 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
hostSelected = response.String()
|
||||
valArray := response.Array()
|
||||
if len(valArray) < 2 {
|
||||
log.Errorf("redis eval lua result format error, expect at least [host, count], got: %+v", valArray)
|
||||
ctx.SetContext("error", true)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
hostSelected = valArray[0].String()
|
||||
currentCount := valArray[1].Integer()
|
||||
|
||||
// detail log
|
||||
if lb.enableDetailLog && len(valArray) >= 3 {
|
||||
detailLogStr := "host and count: "
|
||||
details := valArray[2].Array()
|
||||
for i := 0; i+1 < len(details); i += 2 {
|
||||
h := details[i].String()
|
||||
c := details[i+1].String()
|
||||
detailLogStr += fmt.Sprintf("{%s: %s}, ", h, c)
|
||||
}
|
||||
log.Debugf("host_selected: %s + 1, %s", hostSelected, detailLogStr)
|
||||
}
|
||||
|
||||
// check rate limit
|
||||
if !lb.checkRateLimit(hostSelected, int64(currentCount), ctx, routeName, clusterName) {
|
||||
ctx.SetContext("error", true)
|
||||
log.Warnf("host_selected: %s, current_count: %d, exceed max request limit %d", hostSelected, currentCount, lb.maxRequestCount)
|
||||
// return 429
|
||||
proxywasm.SendHttpResponse(429, [][2]string{}, []byte("Exceeded maximum request limit from ai-load-balancer."), -1)
|
||||
ctx.DontReadResponseBody()
|
||||
return
|
||||
}
|
||||
|
||||
if err := proxywasm.SetUpstreamOverrideHost([]byte(hostSelected)); err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
log.Errorf("override upstream host failed, fallback to default lb policy, error informations: %+v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("host_selected: %s", hostSelected)
|
||||
|
||||
// finally resume the request
|
||||
ctx.SetContext("host_selected", hostSelected)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
})
|
||||
if err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
log.Errorf("redis eval failed, fallback to default lb policy, error informations: %+v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
return types.ActionPause
|
||||
@@ -161,7 +290,10 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpCo
|
||||
if host_selected == "" {
|
||||
log.Errorf("get host_selected failed")
|
||||
} else {
|
||||
lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
|
||||
err := lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
|
||||
if err != nil {
|
||||
log.Errorf("host_selected: %s - 1, failed to update count from redis: %v", host_selected, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,220 @@
|
||||
-- Mocking Redis environment
|
||||
local redis_data = {
|
||||
hset = {},
|
||||
kv = {}
|
||||
}
|
||||
|
||||
local redis = {
|
||||
call = function(cmd, ...)
|
||||
local args = {...}
|
||||
if cmd == "HGET" then
|
||||
local key, field = args[1], args[2]
|
||||
return redis_data.hset[field]
|
||||
elseif cmd == "HSET" then
|
||||
local key, field, val = args[1], args[2], args[3]
|
||||
redis_data.hset[field] = val
|
||||
elseif cmd == "HINCRBY" then
|
||||
local key, field, increment = args[1], args[2], args[3]
|
||||
local val = tonumber(redis_data.hset[field] or 0)
|
||||
redis_data.hset[field] = tostring(val + increment)
|
||||
return redis_data.hset[field]
|
||||
elseif cmd == "HKEYS" then
|
||||
local keys = {}
|
||||
for k, _ in pairs(redis_data.hset) do
|
||||
table.insert(keys, k)
|
||||
end
|
||||
return keys
|
||||
elseif cmd == "HDEL" then
|
||||
local key, field = args[1], args[2]
|
||||
redis_data.hset[field] = nil
|
||||
elseif cmd == "GET" then
|
||||
return redis_data.kv[args[1]]
|
||||
elseif cmd == "HMGET" then
|
||||
local key = args[1]
|
||||
local res = {}
|
||||
for i = 2, #args do
|
||||
table.insert(res, redis_data.hset[args[i]])
|
||||
end
|
||||
return res
|
||||
elseif cmd == "SET" then
|
||||
redis_data.kv[args[1]] = args[2]
|
||||
end
|
||||
end
|
||||
}
|
||||
|
||||
-- The actual logic from lb_policy.go
|
||||
local function run_lb_logic(KEYS)
|
||||
local seed = tonumber(KEYS[1])
|
||||
local hset_key = KEYS[2]
|
||||
local last_clean_key = KEYS[3]
|
||||
local clean_interval = tonumber(KEYS[4])
|
||||
local current_target = KEYS[5]
|
||||
local healthy_count = tonumber(KEYS[6])
|
||||
local enable_detail_log = KEYS[7]
|
||||
|
||||
math.randomseed(seed)
|
||||
|
||||
-- 1. Selection
|
||||
local current_count = 0
|
||||
local same_count_hits = 0
|
||||
|
||||
for i = 8, 8 + healthy_count - 1 do
|
||||
local host = KEYS[i]
|
||||
local count = 0
|
||||
local val = redis.call('HGET', hset_key, host)
|
||||
if val then
|
||||
count = tonumber(val) or 0
|
||||
end
|
||||
|
||||
if same_count_hits == 0 or count < current_count then
|
||||
current_target = host
|
||||
current_count = count
|
||||
same_count_hits = 1
|
||||
elseif count == current_count then
|
||||
same_count_hits = same_count_hits + 1
|
||||
if math.random(same_count_hits) == 1 then
|
||||
current_target = host
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
redis.call("HINCRBY", hset_key, current_target, 1)
|
||||
local new_count = redis.call("HGET", hset_key, current_target)
|
||||
|
||||
-- Collect host counts for logging
|
||||
local host_details = {}
|
||||
if enable_detail_log == "1" then
|
||||
local fields = {}
|
||||
for i = 8, #KEYS do
|
||||
table.insert(fields, KEYS[i])
|
||||
end
|
||||
if #fields > 0 then
|
||||
local values = redis.call('HMGET', hset_key, (table.unpack or unpack)(fields))
|
||||
for i, val in ipairs(values) do
|
||||
table.insert(host_details, fields[i])
|
||||
table.insert(host_details, tostring(val or 0))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- 2. Cleanup
|
||||
local current_time = math.floor(seed / 1000000)
|
||||
local last_clean_time = tonumber(redis.call('GET', last_clean_key) or 0)
|
||||
|
||||
if current_time - last_clean_time >= clean_interval then
|
||||
local all_keys = redis.call('HKEYS', hset_key)
|
||||
if #all_keys > 0 then
|
||||
-- Create a lookup table for current hosts (from index 8 onwards)
|
||||
local current_hosts = {}
|
||||
for i = 8, #KEYS do
|
||||
current_hosts[KEYS[i]] = true
|
||||
end
|
||||
-- Remove keys not in current hosts
|
||||
for _, host in ipairs(all_keys) do
|
||||
if not current_hosts[host] then
|
||||
redis.call('HDEL', hset_key, host)
|
||||
end
|
||||
end
|
||||
end
|
||||
redis.call('SET', last_clean_key, current_time)
|
||||
end
|
||||
|
||||
return {current_target, new_count, host_details}
|
||||
end
|
||||
|
||||
-- --- Test 1: Load Balancing Distribution ---
|
||||
print("--- Test 1: Load Balancing Distribution ---")
|
||||
local hosts = {"host1", "host2", "host3", "host4", "host5"}
|
||||
local iterations = 100000
|
||||
local results = {}
|
||||
for _, h in ipairs(hosts) do results[h] = 0 end
|
||||
|
||||
-- Reset redis
|
||||
redis_data.hset = {}
|
||||
for _, h in ipairs(hosts) do redis_data.hset[h] = "0" end
|
||||
|
||||
print(string.format("Running %d iterations with %d hosts (all counts started at 0)...", iterations, #hosts))
|
||||
|
||||
for i = 1, iterations do
|
||||
local initial_host = hosts[math.random(#hosts)]
|
||||
-- KEYS structure: [seed, hset_key, last_clean_key, clean_interval, host_selected, healthy_count, enable_detail_log, ...healthy_hosts]
|
||||
local keys = {i * 1000000, "table_key", "clean_key", 3600, initial_host, #hosts, "1"}
|
||||
for _, h in ipairs(hosts) do table.insert(keys, h) end
|
||||
|
||||
local res = run_lb_logic(keys)
|
||||
local selected = res[1]
|
||||
results[selected] = results[selected] + 1
|
||||
end
|
||||
|
||||
for _, h in ipairs(hosts) do
|
||||
local percentage = (results[h] / iterations) * 100
|
||||
print(string.format("%s: %6d (%.2f%%)", h, results[h], percentage))
|
||||
end
|
||||
|
||||
-- --- Test 2: IP Cleanup Logic ---
|
||||
print("\n--- Test 2: IP Cleanup Logic ---")
|
||||
|
||||
local function test_cleanup()
|
||||
redis_data.hset = {
|
||||
["host1"] = "10",
|
||||
["host2"] = "5",
|
||||
["old_ip_1"] = "1",
|
||||
["old_ip_2"] = "1",
|
||||
}
|
||||
redis_data.kv["clean_key"] = "1000" -- Last cleaned at 1000s
|
||||
|
||||
local current_hosts = {"host1", "host2"}
|
||||
local current_time_ms = 1000 * 1000000 + 500 * 1000000 -- 1500s (interval is 300s, let's say)
|
||||
local clean_interval = 300
|
||||
|
||||
print("Initial Redis IPs:", table.concat((function() local res={} for k,_ in pairs(redis_data.hset) do table.insert(res, k) end return res end)(), ", "))
|
||||
|
||||
-- Run logic (seed is microtime)
|
||||
local keys = {current_time_ms, "table_key", "clean_key", clean_interval, "host1", #current_hosts, "1"}
|
||||
for _, h in ipairs(current_hosts) do table.insert(keys, h) end
|
||||
|
||||
run_lb_logic(keys)
|
||||
|
||||
print("After Cleanup Redis IPs:", table.concat((function() local res={} for k,_ in pairs(redis_data.hset) do table.insert(res, k) end table.sort(res) return res end)(), ", "))
|
||||
|
||||
local exists_old1 = redis_data.hset["old_ip_1"] ~= nil
|
||||
local exists_old2 = redis_data.hset["old_ip_2"] ~= nil
|
||||
|
||||
if not exists_old1 and not exists_old2 then
|
||||
print("Success: Outdated IPs removed.")
|
||||
else
|
||||
print("Failure: Outdated IPs still exist.")
|
||||
end
|
||||
|
||||
print("New last_clean_time:", redis_data.kv["clean_key"])
|
||||
end
|
||||
|
||||
test_cleanup()
|
||||
|
||||
-- --- Test 3: No Cleanup if Interval Not Reached ---
|
||||
print("\n--- Test 3: No Cleanup if Interval Not Reached ---")
|
||||
|
||||
local function test_no_cleanup()
|
||||
redis_data.hset = {
|
||||
["host1"] = "10",
|
||||
["old_ip_1"] = "1",
|
||||
}
|
||||
redis_data.kv["clean_key"] = "1000"
|
||||
|
||||
local current_hosts = {"host1"}
|
||||
local current_time_ms = 1000 * 1000000 + 100 * 1000000 -- 1100s (interval 300s, not reached)
|
||||
local clean_interval = 300
|
||||
|
||||
local keys = {current_time_ms, "table_key", "clean_key", clean_interval, "host1", #current_hosts, "0"}
|
||||
for _, h in ipairs(current_hosts) do table.insert(keys, h) end
|
||||
|
||||
run_lb_logic(keys)
|
||||
|
||||
if redis_data.hset["old_ip_1"] then
|
||||
print("Success: Cleanup not triggered as expected.")
|
||||
else
|
||||
print("Failure: Cleanup triggered unexpectedly.")
|
||||
end
|
||||
end
|
||||
|
||||
test_no_cleanup()
|
||||
@@ -0,0 +1,24 @@
|
||||
package global_least_request
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
func (lb GlobalLeastRequestLoadBalancer) checkRateLimit(hostSelected string, currentCount int64, ctx wrapper.HttpContext, routeName string, clusterName string) bool {
|
||||
// 如果没有配置最大请求数,直接通过
|
||||
if lb.maxRequestCount <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// 如果当前请求数大于最大请求数,则限流
|
||||
// 注意:Lua脚本已经加了1,所以这里比较的是加1后的值
|
||||
if currentCount > lb.maxRequestCount {
|
||||
// 恢复 Redis 计数
|
||||
lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected, -1, nil)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -26,6 +26,8 @@ description: AI 代理插件配置参考
|
||||
|
||||
> 请求路径后缀匹配 `/v1/embeddings` 时,对应文本向量场景,会用 OpenAI 的文本向量协议解析请求 Body,再转换为对应 LLM 厂商的文本向量协议
|
||||
|
||||
> 请求路径后缀匹配 `/v1/images/generations` 时,对应文生图场景,会用 OpenAI 的图片生成协议解析请求 Body,再转换为对应 LLM 厂商的图片生成协议
|
||||
|
||||
## 运行属性
|
||||
|
||||
插件执行阶段:`默认阶段`
|
||||
@@ -309,7 +311,9 @@ Dify 所对应的 `type` 为 `dify`。它特有的配置字段如下:
|
||||
|
||||
#### Google Vertex AI
|
||||
|
||||
Google Vertex AI 所对应的 type 为 vertex。它特有的配置字段如下:
|
||||
Google Vertex AI 所对应的 type 为 vertex。支持两种认证模式:
|
||||
|
||||
**标准模式**(使用 Service Account):
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------|
|
||||
@@ -320,25 +324,56 @@ Google Vertex AI 所对应的 type 为 vertex。它特有的配置字段如下
|
||||
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
|
||||
| `vertexTokenRefreshAhead` | number | 非必填 | - | Vertex access token刷新提前时间(单位秒) |
|
||||
|
||||
**Express Mode**(使用 API Key,简化配置):
|
||||
|
||||
Express Mode 是 Vertex AI 推出的简化访问模式,只需 API Key 即可快速开始使用,无需配置 Service Account。详见 [Vertex AI Express Mode 文档](https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview)。
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------|
|
||||
| `apiTokens` | array of string | 必填 | - | Express Mode 使用的 API Key,从 Google Cloud Console 的 API & Services > Credentials 获取 |
|
||||
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
|
||||
|
||||
**OpenAI 兼容模式**(使用 Vertex AI Chat Completions API):
|
||||
|
||||
Vertex AI 提供了 OpenAI 兼容的 Chat Completions API 端点,可以直接使用 OpenAI 格式的请求和响应,无需进行协议转换。详见 [Vertex AI OpenAI 兼容性文档](https://cloud.google.com/vertex-ai/generative-ai/docs/migrate/openai/overview)。
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------|
|
||||
| `vertexOpenAICompatible` | boolean | 非必填 | false | 启用 OpenAI 兼容模式。启用后将使用 Vertex AI 的 OpenAI-compatible Chat Completions API |
|
||||
| `vertexAuthKey` | string | 必填 | - | 用于认证的 Google Service Account JSON Key |
|
||||
| `vertexRegion` | string | 必填 | - | Google Cloud 区域(如 us-central1, europe-west4 等) |
|
||||
| `vertexProjectId` | string | 必填 | - | Google Cloud 项目 ID |
|
||||
| `vertexAuthServiceName` | string | 必填 | - | 用于 OAuth2 认证的服务名称 |
|
||||
|
||||
**注意**:OpenAI 兼容模式与 Express Mode 互斥,不能同时配置 `apiTokens` 和 `vertexOpenAICompatible`。
|
||||
|
||||
#### AWS Bedrock
|
||||
|
||||
AWS Bedrock 所对应的 type 为 bedrock。它特有的配置字段如下:
|
||||
AWS Bedrock 所对应的 type 为 bedrock。它支持两种认证方式:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|---------------------------|--------|------|-----|------------------------------|
|
||||
| `modelVersion` | string | 非必填 | - | 用于指定 Triton Server 中 model version |
|
||||
| `tritonDomain` | string | 非必填 | - | Triton Server 部署的指定请求 Domain |
|
||||
1. **AWS Signature V4 认证**:使用 `awsAccessKey` 和 `awsSecretKey` 进行 AWS 标准签名认证
|
||||
2. **Bearer Token 认证**:使用 `apiTokens` 配置 AWS Bearer Token(适用于 IAM Identity Center 等场景)
|
||||
|
||||
**注意**:两种认证方式二选一,如果同时配置了 `apiTokens`,将优先使用 Bearer Token 认证方式。
|
||||
|
||||
它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|---------------------------|---------------|-------------------|-------|---------------------------------------------------|
|
||||
| `apiTokens` | array of string | 与 ak/sk 二选一 | - | AWS Bearer Token,用于 Bearer Token 认证方式 |
|
||||
| `awsAccessKey` | string | 与 apiTokens 二选一 | - | AWS Access Key,用于 AWS Signature V4 认证 |
|
||||
| `awsSecretKey` | string | 与 apiTokens 二选一 | - | AWS Secret Access Key,用于 AWS Signature V4 认证 |
|
||||
| `awsRegion` | string | 必填 | - | AWS 区域,例如:us-east-1 |
|
||||
| `bedrockAdditionalFields` | map | 非必填 | - | Bedrock 额外模型请求参数 |
|
||||
|
||||
#### NVIDIA Triton Interference Server
|
||||
|
||||
NVIDIA Triton Interference Server 所对应的 type 为 triton。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|---------------------------|--------|------|-----|------------------------------|
|
||||
| `awsAccessKey` | string | 必填 | - | AWS Access Key,用于身份认证 |
|
||||
| `awsSecretKey` | string | 必填 | - | AWS Secret Access Key,用于身份认证 |
|
||||
| `awsRegion` | string | 必填 | - | AWS 区域,例如:us-east-1 |
|
||||
| `bedrockAdditionalFields` | map | 非必填 | - | Bedrock 额外模型请求参数 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|----------------------|--------|--------|-------|------------------------------------------|
|
||||
| `tritonModelVersion` | string | 非必填 | - | 用于指定 Triton Server 中 model version |
|
||||
| `tritonDomain` | string | 非必填 | - | Triton Server 部署的指定请求 Domain |
|
||||
|
||||
## 用法示例
|
||||
|
||||
@@ -1947,7 +1982,7 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理 Google Vertex 服务
|
||||
### 使用 OpenAI 协议代理 Google Vertex 服务(标准模式)
|
||||
|
||||
**配置信息**
|
||||
|
||||
@@ -2009,8 +2044,236 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理 Google Vertex 服务(Express Mode)
|
||||
|
||||
Express Mode 是 Vertex AI 的简化访问模式,只需 API Key 即可快速开始使用。
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: vertex
|
||||
apiTokens:
|
||||
- "YOUR_API_KEY"
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gemini-2.5-flash",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你是谁?"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-0000000000000",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "你好!我是 Gemini,由 Google 开发的人工智能助手。有什么我可以帮您的吗?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"created": 1729986750,
|
||||
"model": "gemini-2.5-flash",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 25,
|
||||
"total_tokens": 35
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理 Google Vertex 服务(OpenAI 兼容模式)
|
||||
|
||||
OpenAI 兼容模式使用 Vertex AI 的 OpenAI-compatible Chat Completions API,请求和响应都使用 OpenAI 格式,无需进行协议转换。
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: vertex
|
||||
vertexOpenAICompatible: true
|
||||
vertexAuthKey: |
|
||||
{
|
||||
"type": "service_account",
|
||||
"project_id": "your-project-id",
|
||||
"private_key_id": "your-private-key-id",
|
||||
"private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
|
||||
"client_email": "your-service-account@your-project.iam.gserviceaccount.com",
|
||||
"token_uri": "https://oauth2.googleapis.com/token"
|
||||
}
|
||||
vertexRegion: us-central1
|
||||
vertexProjectId: your-project-id
|
||||
vertexAuthServiceName: your-auth-service-name
|
||||
modelMapping:
|
||||
"gpt-4": "gemini-2.0-flash"
|
||||
"*": "gemini-1.5-flash"
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你是谁?"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-abc123",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "你好!我是由 Google 开发的 Gemini 模型。我可以帮助回答问题、提供信息和进行对话。有什么我可以帮您的吗?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"created": 1729986750,
|
||||
"model": "gemini-2.0-flash",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 12,
|
||||
"completion_tokens": 35,
|
||||
"total_tokens": 47
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理 Google Vertex 图片生成服务
|
||||
|
||||
Vertex AI 支持使用 Gemini 模型进行图片生成。通过 ai-proxy 插件,可以使用 OpenAI 的 `/v1/images/generations` 接口协议来调用 Vertex AI 的图片生成能力。
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: vertex
|
||||
apiTokens:
|
||||
- "YOUR_API_KEY"
|
||||
modelMapping:
|
||||
"dall-e-3": "gemini-2.0-flash-exp"
|
||||
geminiSafetySetting:
|
||||
HARM_CATEGORY_HARASSMENT: "OFF"
|
||||
HARM_CATEGORY_HATE_SPEECH: "OFF"
|
||||
HARM_CATEGORY_SEXUALLY_EXPLICIT: "OFF"
|
||||
HARM_CATEGORY_DANGEROUS_CONTENT: "OFF"
|
||||
```
|
||||
|
||||
**使用 curl 请求**
|
||||
|
||||
```bash
|
||||
curl -X POST "http://your-gateway-address/v1/images/generations" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gemini-2.0-flash-exp",
|
||||
"prompt": "一只可爱的橘猫在阳光下打盹",
|
||||
"size": "1024x1024"
|
||||
}'
|
||||
```
|
||||
|
||||
**使用 OpenAI Python SDK**
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
api_key="any-value", # 可以是任意值,认证由网关处理
|
||||
base_url="http://your-gateway-address/v1"
|
||||
)
|
||||
|
||||
response = client.images.generate(
|
||||
model="gemini-2.0-flash-exp",
|
||||
prompt="一只可爱的橘猫在阳光下打盹",
|
||||
size="1024x1024",
|
||||
n=1
|
||||
)
|
||||
|
||||
# 获取生成的图片(base64 编码)
|
||||
image_data = response.data[0].b64_json
|
||||
print(f"Generated image (base64): {image_data[:100]}...")
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"created": 1729986750,
|
||||
"data": [
|
||||
{
|
||||
"b64_json": "iVBORw0KGgoAAAANSUhEUgAABAAAAAQACAIAAADwf7zUAAAA..."
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"total_tokens": 1356,
|
||||
"input_tokens": 13,
|
||||
"output_tokens": 1120
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**支持的尺寸参数**
|
||||
|
||||
Vertex AI 支持的宽高比(aspectRatio):`1:1`、`3:2`、`2:3`、`3:4`、`4:3`、`4:5`、`5:4`、`9:16`、`16:9`、`21:9`
|
||||
|
||||
Vertex AI 支持的分辨率(imageSize):`1k`、`2k`、`4k`
|
||||
|
||||
| OpenAI size 参数 | Vertex AI aspectRatio | Vertex AI imageSize |
|
||||
|------------------|----------------------|---------------------|
|
||||
| 256x256 | 1:1 | 1k |
|
||||
| 512x512 | 1:1 | 1k |
|
||||
| 1024x1024 | 1:1 | 1k |
|
||||
| 1792x1024 | 16:9 | 2k |
|
||||
| 1024x1792 | 9:16 | 2k |
|
||||
| 2048x2048 | 1:1 | 2k |
|
||||
| 4096x4096 | 1:1 | 4k |
|
||||
| 1536x1024 | 3:2 | 2k |
|
||||
| 1024x1536 | 2:3 | 2k |
|
||||
| 1024x768 | 4:3 | 1k |
|
||||
| 768x1024 | 3:4 | 1k |
|
||||
| 1280x1024 | 5:4 | 1k |
|
||||
| 1024x1280 | 4:5 | 1k |
|
||||
| 2560x1080 | 21:9 | 2k |
|
||||
|
||||
**注意事项**
|
||||
|
||||
- 图片生成使用 Gemini 模型(如 `gemini-2.0-flash-exp`、`gemini-3-pro-image-preview`),不同模型的可用性可能因区域而异
|
||||
- 返回的图片数据为 base64 编码格式(`b64_json`)
|
||||
- 可以通过 `geminiSafetySetting` 配置内容安全过滤级别
|
||||
- 如果需要使用模型映射(如将 `dall-e-3` 映射到 Gemini 模型),可以配置 `modelMapping`
|
||||
|
||||
### 使用 OpenAI 协议代理 AWS Bedrock 服务
|
||||
|
||||
AWS Bedrock 支持两种认证方式:
|
||||
|
||||
#### 方式一:使用 AWS Access Key/Secret Key 认证(AWS Signature V4)
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
@@ -2018,7 +2281,21 @@ provider:
|
||||
type: bedrock
|
||||
awsAccessKey: "YOUR_AWS_ACCESS_KEY_ID"
|
||||
awsSecretKey: "YOUR_AWS_SECRET_ACCESS_KEY"
|
||||
awsRegion: "YOUR_AWS_REGION"
|
||||
awsRegion: "us-east-1"
|
||||
bedrockAdditionalFields:
|
||||
top_k: 200
|
||||
```
|
||||
|
||||
#### 方式二:使用 Bearer Token 认证(适用于 IAM Identity Center 等场景)
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: bedrock
|
||||
apiTokens:
|
||||
- "YOUR_AWS_BEARER_TOKEN"
|
||||
awsRegion: "us-east-1"
|
||||
bedrockAdditionalFields:
|
||||
top_k: 200
|
||||
```
|
||||
@@ -2027,7 +2304,7 @@ provider:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"model": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
|
||||
@@ -25,6 +25,8 @@ The plugin now supports **automatic protocol detection**, allowing seamless comp
|
||||
|
||||
> When the request path suffix matches `/v1/embeddings`, it corresponds to text vector scenarios. The request body will be parsed using OpenAI's text vector protocol and then converted to the corresponding LLM vendor's text vector protocol.
|
||||
|
||||
> When the request path suffix matches `/v1/images/generations`, it corresponds to text-to-image scenarios. The request body will be parsed using OpenAI's image generation protocol and then converted to the corresponding LLM vendor's image generation protocol.
|
||||
|
||||
## Execution Properties
|
||||
Plugin execution phase: `Default Phase`
|
||||
Plugin execution priority: `100`
|
||||
@@ -255,7 +257,9 @@ For DeepL, the corresponding `type` is `deepl`. Its unique configuration field i
|
||||
| `targetLang` | string | Required | - | The target language required by the DeepL translation service |
|
||||
|
||||
#### Google Vertex AI
|
||||
For Vertex, the corresponding `type` is `vertex`. Its unique configuration field is:
|
||||
For Vertex, the corresponding `type` is `vertex`. It supports two authentication modes:
|
||||
|
||||
**Standard Mode** (using Service Account):
|
||||
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
|-----------------------------|---------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
@@ -266,16 +270,47 @@ For Vertex, the corresponding `type` is `vertex`. Its unique configuration field
|
||||
| `vertexGeminiSafetySetting` | map of string | Optional | - | Gemini model content safety filtering settings. |
|
||||
| `vertexTokenRefreshAhead` | number | Optional | - | Vertex access token refresh ahead time in seconds |
|
||||
|
||||
**Express Mode** (using API Key, simplified configuration):
|
||||
|
||||
Express Mode is a simplified access mode introduced by Vertex AI. You can quickly get started with just an API Key, without configuring a Service Account. See [Vertex AI Express Mode documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview).
|
||||
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
|-----------------------------|------------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `apiTokens` | array of string | Required | - | API Key for Express Mode, obtained from Google Cloud Console under API & Services > Credentials |
|
||||
| `vertexGeminiSafetySetting` | map of string | Optional | - | Gemini model content safety filtering settings. |
|
||||
|
||||
**OpenAI Compatible Mode** (using Vertex AI Chat Completions API):
|
||||
|
||||
Vertex AI provides an OpenAI-compatible Chat Completions API endpoint, allowing you to use OpenAI format requests and responses directly without protocol conversion. See [Vertex AI OpenAI Compatibility documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/migrate/openai/overview).
|
||||
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
|-----------------------------|------------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `vertexOpenAICompatible` | boolean | Optional | false | Enable OpenAI compatible mode. When enabled, uses Vertex AI's OpenAI-compatible Chat Completions API |
|
||||
| `vertexAuthKey` | string | Required | - | Google Service Account JSON Key for authentication |
|
||||
| `vertexRegion` | string | Required | - | Google Cloud region (e.g., us-central1, europe-west4) |
|
||||
| `vertexProjectId` | string | Required | - | Google Cloud Project ID |
|
||||
| `vertexAuthServiceName` | string | Required | - | Service name for OAuth2 authentication |
|
||||
|
||||
**Note**: OpenAI Compatible Mode and Express Mode are mutually exclusive. You cannot configure both `apiTokens` and `vertexOpenAICompatible` at the same time.
|
||||
|
||||
#### AWS Bedrock
|
||||
|
||||
For AWS Bedrock, the corresponding `type` is `bedrock`. Its unique configuration field is:
|
||||
For AWS Bedrock, the corresponding `type` is `bedrock`. It supports two authentication methods:
|
||||
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
|---------------------------|-----------|-------------|---------|---------------------------------------------------------|
|
||||
| `awsAccessKey` | string | Required | - | AWS Access Key used for authentication |
|
||||
| `awsSecretKey` | string | Required | - | AWS Secret Access Key used for authentication |
|
||||
| `awsRegion` | string | Required | - | AWS region, e.g., us-east-1 |
|
||||
| `bedrockAdditionalFields` | map | Optional | - | Additional inference parameters that the model supports |
|
||||
1. **AWS Signature V4 Authentication**: Uses `awsAccessKey` and `awsSecretKey` for standard AWS signature authentication
|
||||
2. **Bearer Token Authentication**: Uses `apiTokens` to configure AWS Bearer Token (suitable for IAM Identity Center and similar scenarios)
|
||||
|
||||
**Note**: Choose one of the two authentication methods. If `apiTokens` is configured, Bearer Token authentication will be used preferentially.
|
||||
|
||||
Its unique configuration fields are:
|
||||
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
|---------------------------|-----------------|--------------------------|---------|-------------------------------------------------------------------|
|
||||
| `apiTokens` | array of string | Either this or ak/sk | - | AWS Bearer Token for Bearer Token authentication |
|
||||
| `awsAccessKey` | string | Either this or apiTokens | - | AWS Access Key for AWS Signature V4 authentication |
|
||||
| `awsSecretKey` | string | Either this or apiTokens | - | AWS Secret Access Key for AWS Signature V4 authentication |
|
||||
| `awsRegion` | string | Required | - | AWS region, e.g., us-east-1 |
|
||||
| `bedrockAdditionalFields` | map | Optional | - | Additional inference parameters that the model supports |
|
||||
|
||||
## Usage Examples
|
||||
|
||||
@@ -1720,7 +1755,7 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### Utilizing OpenAI Protocol Proxy for Google Vertex Services
|
||||
### Utilizing OpenAI Protocol Proxy for Google Vertex Services (Standard Mode)
|
||||
**Configuration Information**
|
||||
```yaml
|
||||
provider:
|
||||
@@ -1778,14 +1813,250 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### Utilizing OpenAI Protocol Proxy for Google Vertex Services (Express Mode)
|
||||
|
||||
Express Mode is a simplified access mode for Vertex AI. You only need an API Key to get started quickly.
|
||||
|
||||
**Configuration Information**
|
||||
```yaml
|
||||
provider:
|
||||
type: vertex
|
||||
apiTokens:
|
||||
- "YOUR_API_KEY"
|
||||
```
|
||||
|
||||
**Request Example**
|
||||
```json
|
||||
{
|
||||
"model": "gemini-2.5-flash",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Who are you?"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example**
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-0000000000000",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! I am Gemini, an AI assistant developed by Google. How can I help you today?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"created": 1729986750,
|
||||
"model": "gemini-2.5-flash",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 25,
|
||||
"total_tokens": 35
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Utilizing OpenAI Protocol Proxy for Google Vertex Services (OpenAI Compatible Mode)
|
||||
|
||||
OpenAI Compatible Mode uses Vertex AI's OpenAI-compatible Chat Completions API. Both requests and responses use OpenAI format, requiring no protocol conversion.
|
||||
|
||||
**Configuration Information**
|
||||
```yaml
|
||||
provider:
|
||||
type: vertex
|
||||
vertexOpenAICompatible: true
|
||||
vertexAuthKey: |
|
||||
{
|
||||
"type": "service_account",
|
||||
"project_id": "your-project-id",
|
||||
"private_key_id": "your-private-key-id",
|
||||
"private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
|
||||
"client_email": "your-service-account@your-project.iam.gserviceaccount.com",
|
||||
"token_uri": "https://oauth2.googleapis.com/token"
|
||||
}
|
||||
vertexRegion: us-central1
|
||||
vertexProjectId: your-project-id
|
||||
vertexAuthServiceName: your-auth-service-name
|
||||
modelMapping:
|
||||
"gpt-4": "gemini-2.0-flash"
|
||||
"*": "gemini-1.5-flash"
|
||||
```
|
||||
|
||||
**Request Example**
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, who are you?"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example**
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-abc123",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! I am Gemini, an AI model developed by Google. I can help answer questions, provide information, and engage in conversations. How can I assist you today?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"created": 1729986750,
|
||||
"model": "gemini-2.0-flash",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 12,
|
||||
"completion_tokens": 35,
|
||||
"total_tokens": 47
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Utilizing OpenAI Protocol Proxy for Google Vertex Image Generation
|
||||
|
||||
Vertex AI supports image generation using Gemini models. Through the ai-proxy plugin, you can use OpenAI's `/v1/images/generations` API to call Vertex AI's image generation capabilities.
|
||||
|
||||
**Configuration Information**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: vertex
|
||||
apiTokens:
|
||||
- "YOUR_API_KEY"
|
||||
modelMapping:
|
||||
"dall-e-3": "gemini-2.0-flash-exp"
|
||||
geminiSafetySetting:
|
||||
HARM_CATEGORY_HARASSMENT: "OFF"
|
||||
HARM_CATEGORY_HATE_SPEECH: "OFF"
|
||||
HARM_CATEGORY_SEXUALLY_EXPLICIT: "OFF"
|
||||
HARM_CATEGORY_DANGEROUS_CONTENT: "OFF"
|
||||
```
|
||||
|
||||
**Using curl**
|
||||
|
||||
```bash
|
||||
curl -X POST "http://your-gateway-address/v1/images/generations" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gemini-2.0-flash-exp",
|
||||
"prompt": "A cute orange cat napping in the sunshine",
|
||||
"size": "1024x1024"
|
||||
}'
|
||||
```
|
||||
|
||||
**Using OpenAI Python SDK**
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
api_key="any-value", # Can be any value, authentication is handled by the gateway
|
||||
base_url="http://your-gateway-address/v1"
|
||||
)
|
||||
|
||||
response = client.images.generate(
|
||||
model="gemini-2.0-flash-exp",
|
||||
prompt="A cute orange cat napping in the sunshine",
|
||||
size="1024x1024",
|
||||
n=1
|
||||
)
|
||||
|
||||
# Get the generated image (base64 encoded)
|
||||
image_data = response.data[0].b64_json
|
||||
print(f"Generated image (base64): {image_data[:100]}...")
|
||||
```
|
||||
|
||||
**Response Example**
|
||||
|
||||
```json
|
||||
{
|
||||
"created": 1729986750,
|
||||
"data": [
|
||||
{
|
||||
"b64_json": "iVBORw0KGgoAAAANSUhEUgAABAAAAAQACAIAAADwf7zUAAAA..."
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"total_tokens": 1356,
|
||||
"input_tokens": 13,
|
||||
"output_tokens": 1120
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Supported Size Parameters**
|
||||
|
||||
Vertex AI supported aspect ratios: `1:1`, `3:2`, `2:3`, `3:4`, `4:3`, `4:5`, `5:4`, `9:16`, `16:9`, `21:9`
|
||||
|
||||
Vertex AI supported resolutions (imageSize): `1k`, `2k`, `4k`
|
||||
|
||||
| OpenAI size parameter | Vertex AI aspectRatio | Vertex AI imageSize |
|
||||
|-----------------------|----------------------|---------------------|
|
||||
| 256x256 | 1:1 | 1k |
|
||||
| 512x512 | 1:1 | 1k |
|
||||
| 1024x1024 | 1:1 | 1k |
|
||||
| 1792x1024 | 16:9 | 2k |
|
||||
| 1024x1792 | 9:16 | 2k |
|
||||
| 2048x2048 | 1:1 | 2k |
|
||||
| 4096x4096 | 1:1 | 4k |
|
||||
| 1536x1024 | 3:2 | 2k |
|
||||
| 1024x1536 | 2:3 | 2k |
|
||||
| 1024x768 | 4:3 | 1k |
|
||||
| 768x1024 | 3:4 | 1k |
|
||||
| 1280x1024 | 5:4 | 1k |
|
||||
| 1024x1280 | 4:5 | 1k |
|
||||
| 2560x1080 | 21:9 | 2k |
|
||||
|
||||
**Notes**
|
||||
|
||||
- Image generation uses Gemini models (e.g., `gemini-2.0-flash-exp`, `gemini-3-pro-image-preview`). Model availability may vary by region
|
||||
- The returned image data is in base64 encoded format (`b64_json`)
|
||||
- Content safety filtering levels can be configured via `geminiSafetySetting`
|
||||
- If you need model mapping (e.g., mapping `dall-e-3` to a Gemini model), configure `modelMapping`
|
||||
|
||||
### Utilizing OpenAI Protocol Proxy for AWS Bedrock Services
|
||||
|
||||
AWS Bedrock supports two authentication methods:
|
||||
|
||||
#### Method 1: Using AWS Access Key/Secret Key Authentication (AWS Signature V4)
|
||||
|
||||
**Configuration Information**
|
||||
```yaml
|
||||
provider:
|
||||
type: bedrock
|
||||
awsAccessKey: "YOUR_AWS_ACCESS_KEY_ID"
|
||||
awsSecretKey: "YOUR_AWS_SECRET_ACCESS_KEY"
|
||||
awsRegion: "YOUR_AWS_REGION"
|
||||
awsRegion: "us-east-1"
|
||||
bedrockAdditionalFields:
|
||||
top_k: 200
|
||||
```
|
||||
|
||||
#### Method 2: Using Bearer Token Authentication (suitable for IAM Identity Center and similar scenarios)
|
||||
|
||||
**Configuration Information**
|
||||
```yaml
|
||||
provider:
|
||||
type: bedrock
|
||||
apiTokens:
|
||||
- "YOUR_AWS_BEARER_TOKEN"
|
||||
awsRegion: "us-east-1"
|
||||
bedrockAdditionalFields:
|
||||
top_k: 200
|
||||
```
|
||||
@@ -1793,7 +2064,7 @@ provider:
|
||||
**Request Example**
|
||||
```json
|
||||
{
|
||||
"model": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"model": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
|
||||
@@ -8,7 +8,7 @@ toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
)
|
||||
|
||||
@@ -4,8 +4,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c h1:DdVPyaMHSYBqO5jwB9Wl3PqsBGIf4u29BHMI0uIVB1Y=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d h1:LgYbzEBtg0+LEqoebQeMVgAB6H5SgqG+KN+gBhNfKbM=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
|
||||
@@ -128,3 +128,25 @@ func TestGeneric(t *testing.T) {
|
||||
test.RunGenericOnHttpRequestHeadersTests(t)
|
||||
test.RunGenericOnHttpRequestBodyTests(t)
|
||||
}
|
||||
|
||||
func TestVertex(t *testing.T) {
|
||||
test.RunVertexParseConfigTests(t)
|
||||
test.RunVertexExpressModeOnHttpRequestHeadersTests(t)
|
||||
test.RunVertexExpressModeOnHttpRequestBodyTests(t)
|
||||
test.RunVertexExpressModeOnHttpResponseBodyTests(t)
|
||||
test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
|
||||
test.RunVertexExpressModeImageGenerationRequestBodyTests(t)
|
||||
test.RunVertexExpressModeImageGenerationResponseBodyTests(t)
|
||||
// Vertex Raw 模式测试
|
||||
test.RunVertexRawModeOnHttpRequestHeadersTests(t)
|
||||
test.RunVertexRawModeOnHttpRequestBodyTests(t)
|
||||
test.RunVertexRawModeOnHttpResponseBodyTests(t)
|
||||
}
|
||||
|
||||
func TestBedrock(t *testing.T) {
|
||||
test.RunBedrockParseConfigTests(t)
|
||||
test.RunBedrockOnHttpRequestHeadersTests(t)
|
||||
test.RunBedrockOnHttpRequestBodyTests(t)
|
||||
test.RunBedrockOnHttpResponseHeadersTests(t)
|
||||
test.RunBedrockOnHttpResponseBodyTests(t)
|
||||
}
|
||||
|
||||
@@ -43,8 +43,11 @@ const (
|
||||
type bedrockProviderInitializer struct{}
|
||||
|
||||
func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if len(config.awsAccessKey) == 0 || len(config.awsSecretKey) == 0 {
|
||||
return errors.New("missing bedrock access authentication parameters")
|
||||
hasAkSk := len(config.awsAccessKey) > 0 && len(config.awsSecretKey) > 0
|
||||
hasApiToken := len(config.apiTokens) > 0
|
||||
|
||||
if !hasAkSk && !hasApiToken {
|
||||
return errors.New("missing bedrock access authentication parameters: either apiTokens or (awsAccessKey + awsSecretKey) is required")
|
||||
}
|
||||
if len(config.awsRegion) == 0 {
|
||||
return errors.New("missing bedrock region parameters")
|
||||
@@ -634,6 +637,13 @@ func (b *bedrockProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
|
||||
|
||||
func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion))
|
||||
|
||||
// If apiTokens is configured, set Bearer token authentication here
|
||||
// This follows the same pattern as other providers (qwen, zhipuai, etc.)
|
||||
// AWS SigV4 authentication is handled in setAuthHeaders because it requires the request body
|
||||
if len(b.config.apiTokens) > 0 {
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+b.config.GetApiTokenInUse(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
@@ -659,18 +669,18 @@ func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
|
||||
case ApiNameChatCompletion:
|
||||
return b.onChatCompletionResponseBody(ctx, body)
|
||||
case ApiNameImageGeneration:
|
||||
return b.onImageGenerationResponseBody(ctx, body)
|
||||
return b.onImageGenerationResponseBody(body)
|
||||
}
|
||||
return nil, errUnsupportedApiName
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
func (b *bedrockProvider) onImageGenerationResponseBody(body []byte) ([]byte, error) {
|
||||
bedrockResponse := &bedrockImageGenerationResponse{}
|
||||
if err := json.Unmarshal(body, bedrockResponse); err != nil {
|
||||
log.Errorf("unable to unmarshal bedrock image gerneration response: %v", err)
|
||||
return nil, fmt.Errorf("unable to unmarshal bedrock image generation response: %v", err)
|
||||
}
|
||||
response := b.buildBedrockImageGenerationResponse(ctx, bedrockResponse)
|
||||
response := b.buildBedrockImageGenerationResponse(bedrockResponse)
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
@@ -710,7 +720,7 @@ func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageG
|
||||
return requestBytes, err
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) buildBedrockImageGenerationResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
|
||||
func (b *bedrockProvider) buildBedrockImageGenerationResponse(bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
|
||||
data := make([]imageGenerationData, len(bedrockResponse.Images))
|
||||
for i, image := range bedrockResponse.Images {
|
||||
data[i] = imageGenerationData{
|
||||
@@ -1138,6 +1148,13 @@ func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) {
|
||||
// Bearer token authentication is already set in TransformRequestHeaders
|
||||
// This function only handles AWS SigV4 authentication which requires the request body
|
||||
if len(b.config.apiTokens) > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Use AWS Signature V4 authentication
|
||||
t := time.Now().UTC()
|
||||
amzDate := t.Format("20060102T150405Z")
|
||||
dateStamp := t.Format("20060102")
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
@@ -134,8 +135,63 @@ func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
|
||||
} else {
|
||||
util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain)
|
||||
}
|
||||
|
||||
var token string
|
||||
|
||||
// 1. If apiTokens is configured, use it first
|
||||
if len(m.config.apiTokens) > 0 {
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
token = m.config.GetApiTokenInUse(ctx)
|
||||
if token == "" {
|
||||
log.Warnf("[openaiProvider.TransformRequestHeaders] apiTokens count > 0 but GetApiTokenInUse returned empty")
|
||||
}
|
||||
} else {
|
||||
// If no apiToken is configured, try to extract from original request headers
|
||||
|
||||
// 2. If authHeaderKey is configured, use the specified header
|
||||
if m.config.authHeaderKey != "" {
|
||||
if apiKey, err := proxywasm.GetHttpRequestHeader(m.config.authHeaderKey); err == nil && apiKey != "" {
|
||||
token = apiKey
|
||||
log.Debugf("[openaiProvider.TransformRequestHeaders] Using token from configured header: %s", m.config.authHeaderKey)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. If authHeaderKey is not configured, check default headers in priority order
|
||||
if token == "" {
|
||||
defaultHeaders := []string{"x-api-key", "x-authorization"}
|
||||
for _, headerName := range defaultHeaders {
|
||||
if apiKey, err := proxywasm.GetHttpRequestHeader(headerName); err == nil && apiKey != "" {
|
||||
token = apiKey
|
||||
log.Debugf("[openaiProvider.TransformRequestHeaders] Using token from %s header", headerName)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Finally check Authorization header
|
||||
if token == "" {
|
||||
if auth, err := proxywasm.GetHttpRequestHeader("Authorization"); err == nil && auth != "" {
|
||||
// Extract token from "Bearer <token>" format
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
token = strings.TrimPrefix(auth, "Bearer ")
|
||||
log.Debugf("[openaiProvider.TransformRequestHeaders] Using token from Authorization header (Bearer format)")
|
||||
} else {
|
||||
token = auth
|
||||
log.Debugf("[openaiProvider.TransformRequestHeaders] Using token from Authorization header (no Bearer prefix)")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Set Authorization header (avoid duplicate Bearer prefix)
|
||||
if token != "" {
|
||||
// Check if token already contains Bearer prefix
|
||||
if !strings.HasPrefix(token, "Bearer ") {
|
||||
token = "Bearer " + token
|
||||
}
|
||||
util.OverwriteRequestAuthorizationHeader(headers, token)
|
||||
log.Debugf("[openaiProvider.TransformRequestHeaders] Set Authorization header successfully")
|
||||
} else {
|
||||
log.Warnf("[openaiProvider.TransformRequestHeaders] No auth token available - neither configured in apiTokens nor in request headers")
|
||||
}
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
@@ -70,6 +70,7 @@ const (
|
||||
ApiNameGeminiStreamGenerateContent ApiName = "gemini/v1beta/streamgeneratecontent"
|
||||
ApiNameAnthropicMessages ApiName = "anthropic/v1/messages"
|
||||
ApiNameAnthropicComplete ApiName = "anthropic/v1/complete"
|
||||
ApiNameVertexRaw ApiName = "vertex/raw"
|
||||
|
||||
// OpenAI
|
||||
PathOpenAIPrefix = "/v1"
|
||||
@@ -387,12 +388,18 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN Vertex token刷新提前时间
|
||||
// @Description zh-CN 用于Google服务账号认证,access token过期时间判定提前刷新,单位为秒,默认值为60秒
|
||||
vertexTokenRefreshAhead int64 `required:"false" yaml:"vertexTokenRefreshAhead" json:"vertexTokenRefreshAhead"`
|
||||
// @Title zh-CN Vertex AI OpenAI兼容模式
|
||||
// @Description zh-CN 启用后将使用Vertex AI的OpenAI兼容API,请求和响应均使用OpenAI格式,无需协议转换。与Express Mode(apiTokens)互斥。
|
||||
vertexOpenAICompatible bool `required:"false" yaml:"vertexOpenAICompatible" json:"vertexOpenAICompatible"`
|
||||
// @Title zh-CN 翻译服务需指定的目标语种
|
||||
// @Description zh-CN 翻译结果的语种,目前仅适用于DeepL服务。
|
||||
targetLang string `required:"false" yaml:"targetLang" json:"targetLang"`
|
||||
// @Title zh-CN 指定服务返回的响应需满足的JSON Schema
|
||||
// @Description zh-CN 目前仅适用于OpenAI部分模型服务。参考:https://platform.openai.com/docs/guides/structured-outputs
|
||||
responseJsonSchema map[string]interface{} `required:"false" yaml:"responseJsonSchema" json:"responseJsonSchema"`
|
||||
// @Title zh-CN 自定义认证Header名称
|
||||
// @Description zh-CN 用于从请求中提取认证token的自定义header名称。如不配置,则按默认优先级检查 x-api-key、x-authorization、anthropic-api-key 和 Authorization header。
|
||||
authHeaderKey string `required:"false" yaml:"authHeaderKey" json:"authHeaderKey"`
|
||||
// @Title zh-CN 自定义大模型参数配置
|
||||
// @Description zh-CN 用于填充或者覆盖大模型调用时的参数
|
||||
customSettings []CustomSetting
|
||||
@@ -540,6 +547,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
if c.vertexTokenRefreshAhead == 0 {
|
||||
c.vertexTokenRefreshAhead = 60
|
||||
}
|
||||
c.vertexOpenAICompatible = json.Get("vertexOpenAICompatible").Bool()
|
||||
c.targetLang = json.Get("targetLang").String()
|
||||
|
||||
if schemaValue, ok := json.Get("responseJsonSchema").Value().(map[string]interface{}); ok {
|
||||
|
||||
@@ -21,14 +21,21 @@ import (
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
vertexAuthDomain = "oauth2.googleapis.com"
|
||||
vertexDomain = "aiplatform.googleapis.com"
|
||||
// /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION}
|
||||
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
|
||||
vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s"
|
||||
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
|
||||
vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s"
|
||||
// Express Mode 路径模板 (不含 project/location)
|
||||
vertexExpressPathTemplate = "/v1/publishers/google/models/%s:%s"
|
||||
vertexExpressPathAnthropicTemplate = "/v1/publishers/anthropic/models/%s:%s"
|
||||
// OpenAI-compatible endpoint 路径模板
|
||||
// /v1beta1/projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/openapi/chat/completions
|
||||
vertexOpenAICompatiblePathTemplate = "/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions"
|
||||
vertexChatCompletionAction = "generateContent"
|
||||
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
|
||||
vertexAnthropicMessageAction = "rawPredict"
|
||||
@@ -36,12 +43,44 @@ const (
|
||||
vertexEmbeddingAction = "predict"
|
||||
vertexGlobalRegion = "global"
|
||||
contextClaudeMarker = "isClaudeRequest"
|
||||
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
|
||||
contextVertexRawMarker = "isVertexRawRequest"
|
||||
vertexAnthropicVersion = "vertex-2023-10-16"
|
||||
)
|
||||
|
||||
// vertexRawPathRegex 匹配原生 Vertex AI REST API 路径
|
||||
// 格式: [任意前缀]/{api-version}/projects/{project}/locations/{location}/publishers/{publisher}/models/{model}:{action}
|
||||
// 允许任意 basePath 前缀,兼容 basePathHandling 配置
|
||||
var vertexRawPathRegex = regexp.MustCompile(`^.*/([^/]+)/projects/([^/]+)/locations/([^/]+)/publishers/([^/]+)/models/([^/:]+):([^/?]+)`)
|
||||
|
||||
type vertexProviderInitializer struct{}
|
||||
|
||||
func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
// Express Mode: 如果配置了 apiTokens,则使用 API Key 认证
|
||||
if len(config.apiTokens) > 0 {
|
||||
// Express Mode 与 OpenAI 兼容模式互斥
|
||||
if config.vertexOpenAICompatible {
|
||||
return errors.New("vertexOpenAICompatible is not compatible with Express Mode (apiTokens)")
|
||||
}
|
||||
// Express Mode 不需要其他配置
|
||||
return nil
|
||||
}
|
||||
|
||||
// OpenAI 兼容模式: 需要 OAuth 认证配置
|
||||
if config.vertexOpenAICompatible {
|
||||
if config.vertexAuthKey == "" {
|
||||
return errors.New("missing vertexAuthKey in vertex provider config for OpenAI compatible mode")
|
||||
}
|
||||
if config.vertexRegion == "" || config.vertexProjectId == "" {
|
||||
return errors.New("missing vertexRegion or vertexProjectId in vertex provider config for OpenAI compatible mode")
|
||||
}
|
||||
if config.vertexAuthServiceName == "" {
|
||||
return errors.New("missing vertexAuthServiceName in vertex provider config for OpenAI compatible mode")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 标准模式: 保持原有验证逻辑
|
||||
if config.vertexAuthKey == "" {
|
||||
return errors.New("missing vertexAuthKey in vertex provider config")
|
||||
}
|
||||
@@ -56,26 +95,47 @@ func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
|
||||
func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): vertexPathTemplate,
|
||||
string(ApiNameEmbeddings): vertexPathTemplate,
|
||||
string(ApiNameChatCompletion): vertexPathTemplate,
|
||||
string(ApiNameEmbeddings): vertexPathTemplate,
|
||||
string(ApiNameImageGeneration): vertexPathTemplate,
|
||||
string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换
|
||||
}
|
||||
}
|
||||
|
||||
func (v *vertexProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(v.DefaultCapabilities())
|
||||
return &vertexProvider{
|
||||
config: config,
|
||||
client: wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
Domain: vertexAuthDomain,
|
||||
ServiceName: config.vertexAuthServiceName,
|
||||
Port: 443,
|
||||
}),
|
||||
|
||||
provider := &vertexProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
claude: &claudeProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 仅标准模式需要 OAuth 客户端(Express Mode 通过 apiTokens 配置)
|
||||
if !provider.isExpressMode() {
|
||||
provider.client = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
Domain: vertexAuthDomain,
|
||||
ServiceName: config.vertexAuthServiceName,
|
||||
Port: 443,
|
||||
})
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// isExpressMode 检测是否启用 Express Mode
|
||||
// 如果配置了 apiTokens,则使用 Express Mode(API Key 认证)
|
||||
func (v *vertexProvider) isExpressMode() bool {
|
||||
return len(v.config.apiTokens) > 0
|
||||
}
|
||||
|
||||
// isOpenAICompatibleMode 检测是否启用 OpenAI 兼容模式
|
||||
// 使用 Vertex AI 的 OpenAI-compatible Chat Completions API
|
||||
func (v *vertexProvider) isOpenAICompatibleMode() bool {
|
||||
return v.config.vertexOpenAICompatible
|
||||
}
|
||||
|
||||
type vertexProvider struct {
|
||||
@@ -90,6 +150,12 @@ func (v *vertexProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (v *vertexProvider) GetApiName(path string) ApiName {
|
||||
// 优先匹配原生 Vertex AI REST API 路径,支持任意 basePath 前缀
|
||||
// 格式: [任意前缀]/{api-version}/projects/{project}/locations/{location}/publishers/{publisher}/models/{model}:{action}
|
||||
// 必须在其他 action 检查之前,因为 :predict、:generateContent 等 action 会被其他规则匹配
|
||||
if vertexRawPathRegex.MatchString(path) {
|
||||
return ApiNameVertexRaw
|
||||
}
|
||||
if strings.HasSuffix(path, vertexChatCompletionAction) || strings.HasSuffix(path, vertexChatCompletionStreamAction) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
@@ -106,11 +172,19 @@ func (v *vertexProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
|
||||
func (v *vertexProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
var finalVertexDomain string
|
||||
if v.config.vertexRegion != vertexGlobalRegion {
|
||||
finalVertexDomain = fmt.Sprintf("%s-%s", v.config.vertexRegion, vertexDomain)
|
||||
} else {
|
||||
|
||||
if v.isExpressMode() {
|
||||
// Express Mode: 固定域名,不带 region 前缀
|
||||
finalVertexDomain = vertexDomain
|
||||
} else {
|
||||
// 标准模式: 带 region 前缀
|
||||
if v.config.vertexRegion != vertexGlobalRegion {
|
||||
finalVertexDomain = fmt.Sprintf("%s-%s", v.config.vertexRegion, vertexDomain)
|
||||
} else {
|
||||
finalVertexDomain = vertexDomain
|
||||
}
|
||||
}
|
||||
|
||||
util.OverwriteRequestHostHeader(headers, finalVertexDomain)
|
||||
}
|
||||
|
||||
@@ -150,12 +224,66 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
if !v.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
|
||||
// Vertex Raw 模式: 透传请求体,只做 OAuth 认证
|
||||
// 用于直接访问 Vertex AI REST API,不做协议转换
|
||||
// 注意:此检查必须在 IsOriginal() 之前,因为 Vertex Raw 模式通常与 original 协议一起使用
|
||||
if apiName == ApiNameVertexRaw {
|
||||
ctx.SetContext(contextVertexRawMarker, true)
|
||||
// Express Mode 不需要 OAuth 认证
|
||||
if v.isExpressMode() {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
// 标准模式需要获取 OAuth token
|
||||
cached, err := v.getToken()
|
||||
if cached {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
if v.config.IsOriginal() {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
headers := util.GetRequestHeaders()
|
||||
|
||||
// OpenAI 兼容模式: 不转换请求体,只设置路径和进行模型映射
|
||||
if v.isOpenAICompatibleMode() {
|
||||
ctx.SetContext(contextOpenAICompatibleMarker, true)
|
||||
body, err := v.onOpenAICompatibleRequestBody(ctx, apiName, body, headers)
|
||||
headers.Set("Content-Length", fmt.Sprint(len(body)))
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
_ = proxywasm.ReplaceHttpRequestBody(body)
|
||||
if err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
// OpenAI 兼容模式需要 OAuth token
|
||||
cached, err := v.getToken()
|
||||
if cached {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
||||
headers.Set("Content-Length", fmt.Sprint(len(body)))
|
||||
|
||||
if v.isExpressMode() {
|
||||
// Express Mode: 不需要 Authorization header,API Key 已在 URL 中
|
||||
headers.Del("Authorization")
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
_ = proxywasm.ReplaceHttpRequestBody(body)
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
// 标准模式: 需要获取 OAuth token
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
_ = proxywasm.ReplaceHttpRequestBody(body)
|
||||
if err != nil {
|
||||
@@ -172,13 +300,44 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
}
|
||||
|
||||
func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
if apiName == ApiNameChatCompletion {
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
return v.onChatCompletionRequestBody(ctx, body, headers)
|
||||
} else {
|
||||
case ApiNameEmbeddings:
|
||||
return v.onEmbeddingsRequestBody(ctx, body, headers)
|
||||
case ApiNameImageGeneration:
|
||||
return v.onImageGenerationRequestBody(ctx, body, headers)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
// onOpenAICompatibleRequestBody 处理 OpenAI 兼容模式的请求
|
||||
// 不转换请求体格式,只进行模型映射和路径设置
|
||||
func (v *vertexProvider) onOpenAICompatibleRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return nil, fmt.Errorf("OpenAI compatible mode only supports chat completions API")
|
||||
}
|
||||
|
||||
// 解析请求进行模型映射
|
||||
request := &chatCompletionRequest{}
|
||||
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 设置 OpenAI 兼容端点路径
|
||||
path := v.getOpenAICompatibleRequestPath()
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
// 如果模型被映射,需要更新请求体中的模型字段
|
||||
if request.Model != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", request.Model)
|
||||
}
|
||||
|
||||
// 保持 OpenAI 格式,直接返回(可能更新了模型字段)
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
err := v.config.parseRequestAndMapModel(ctx, request, body)
|
||||
@@ -219,7 +378,126 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &imageGenerationRequest{}
|
||||
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 图片生成不使用流式端点,需要完整响应
|
||||
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest := v.buildVertexImageGenerationRequest(request)
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest {
|
||||
// 构建安全设置
|
||||
safetySettings := make([]vertexChatSafetySetting, 0)
|
||||
for category, threshold := range v.config.geminiSafetySetting {
|
||||
safetySettings = append(safetySettings, vertexChatSafetySetting{
|
||||
Category: category,
|
||||
Threshold: threshold,
|
||||
})
|
||||
}
|
||||
|
||||
// 解析尺寸参数
|
||||
aspectRatio, imageSize := v.parseImageSize(request.Size)
|
||||
|
||||
// 确定输出 MIME 类型
|
||||
mimeType := "image/png"
|
||||
if request.OutputFormat != "" {
|
||||
switch request.OutputFormat {
|
||||
case "jpeg", "jpg":
|
||||
mimeType = "image/jpeg"
|
||||
case "webp":
|
||||
mimeType = "image/webp"
|
||||
default:
|
||||
mimeType = "image/png"
|
||||
}
|
||||
}
|
||||
|
||||
vertexRequest := &vertexChatRequest{
|
||||
Contents: []vertexChatContent{{
|
||||
Role: roleUser,
|
||||
Parts: []vertexPart{{
|
||||
Text: request.Prompt,
|
||||
}},
|
||||
}},
|
||||
SafetySettings: safetySettings,
|
||||
GenerationConfig: vertexChatGenerationConfig{
|
||||
Temperature: 1.0,
|
||||
MaxOutputTokens: 32768,
|
||||
ResponseModalities: []string{"TEXT", "IMAGE"},
|
||||
ImageConfig: &vertexImageConfig{
|
||||
AspectRatio: aspectRatio,
|
||||
ImageSize: imageSize,
|
||||
ImageOutputOptions: &vertexImageOutputOptions{
|
||||
MimeType: mimeType,
|
||||
},
|
||||
PersonGeneration: "ALLOW_ALL",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return vertexRequest
|
||||
}
|
||||
|
||||
// parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize
|
||||
// Vertex AI 支持的 aspectRatio: 1:1, 3:2, 2:3, 3:4, 4:3, 4:5, 5:4, 9:16, 16:9, 21:9
|
||||
// Vertex AI 支持的 imageSize: 1k, 2k, 4k
|
||||
func (v *vertexProvider) parseImageSize(size string) (aspectRatio, imageSize string) {
|
||||
// 默认值
|
||||
aspectRatio = "1:1"
|
||||
imageSize = "1k"
|
||||
|
||||
if size == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// 预定义的尺寸映射(OpenAI 标准尺寸)
|
||||
sizeMapping := map[string]struct {
|
||||
aspectRatio string
|
||||
imageSize string
|
||||
}{
|
||||
// OpenAI DALL-E 标准尺寸
|
||||
"256x256": {"1:1", "1k"},
|
||||
"512x512": {"1:1", "1k"},
|
||||
"1024x1024": {"1:1", "1k"},
|
||||
"1792x1024": {"16:9", "2k"},
|
||||
"1024x1792": {"9:16", "2k"},
|
||||
// 扩展尺寸支持
|
||||
"2048x2048": {"1:1", "2k"},
|
||||
"4096x4096": {"1:1", "4k"},
|
||||
// 3:2 和 2:3 比例
|
||||
"1536x1024": {"3:2", "2k"},
|
||||
"1024x1536": {"2:3", "2k"},
|
||||
// 4:3 和 3:4 比例
|
||||
"1024x768": {"4:3", "1k"},
|
||||
"768x1024": {"3:4", "1k"},
|
||||
"1365x1024": {"4:3", "1k"},
|
||||
"1024x1365": {"3:4", "1k"},
|
||||
// 5:4 和 4:5 比例
|
||||
"1280x1024": {"5:4", "1k"},
|
||||
"1024x1280": {"4:5", "1k"},
|
||||
// 21:9 超宽比例
|
||||
"2560x1080": {"21:9", "2k"},
|
||||
}
|
||||
|
||||
if mapping, ok := sizeMapping[size]; ok {
|
||||
return mapping.aspectRatio, mapping.imageSize
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
||||
// OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列
|
||||
// Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON,将非 ASCII 字符编码为 \uXXXX
|
||||
if ctx.GetContext(contextOpenAICompatibleMarker) != nil && ctx.GetContext(contextOpenAICompatibleMarker).(bool) {
|
||||
return util.DecodeUnicodeEscapesInSSE(chunk), nil
|
||||
}
|
||||
|
||||
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
|
||||
return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk)
|
||||
}
|
||||
@@ -260,13 +538,25 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
|
||||
}
|
||||
|
||||
func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
// OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列
|
||||
// Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON,将非 ASCII 字符编码为 \uXXXX
|
||||
if ctx.GetContext(contextOpenAICompatibleMarker) != nil && ctx.GetContext(contextOpenAICompatibleMarker).(bool) {
|
||||
return util.DecodeUnicodeEscapes(body), nil
|
||||
}
|
||||
|
||||
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
|
||||
return v.claude.TransformResponseBody(ctx, apiName, body)
|
||||
}
|
||||
if apiName == ApiNameChatCompletion {
|
||||
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
return v.onChatCompletionResponseBody(ctx, body)
|
||||
} else {
|
||||
case ApiNameEmbeddings:
|
||||
return v.onEmbeddingsResponseBody(ctx, body)
|
||||
case ApiNameImageGeneration:
|
||||
return v.onImageGenerationResponseBody(ctx, body)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -359,6 +649,54 @@ func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertex
|
||||
return &response
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
// 使用 gjson 直接提取字段,避免完整反序列化大型 base64 数据
|
||||
// 这样可以显著减少内存分配和复制次数
|
||||
response := v.buildImageGenerationResponseFromJSON(body)
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
// buildImageGenerationResponseFromJSON 使用 gjson 从原始 JSON 中提取图片生成响应
|
||||
// 相比 json.Unmarshal 完整反序列化,这种方式内存效率更高
|
||||
func (v *vertexProvider) buildImageGenerationResponseFromJSON(body []byte) *imageGenerationResponse {
|
||||
result := gjson.ParseBytes(body)
|
||||
data := make([]imageGenerationData, 0)
|
||||
|
||||
// 遍历所有 candidates,提取图片数据
|
||||
candidates := result.Get("candidates")
|
||||
candidates.ForEach(func(_, candidate gjson.Result) bool {
|
||||
parts := candidate.Get("content.parts")
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
// 跳过思考过程 (thought: true)
|
||||
if part.Get("thought").Bool() {
|
||||
return true
|
||||
}
|
||||
// 提取图片数据
|
||||
inlineData := part.Get("inlineData.data")
|
||||
if inlineData.Exists() && inlineData.String() != "" {
|
||||
data = append(data, imageGenerationData{
|
||||
B64: inlineData.String(),
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
|
||||
// 提取 usage 信息
|
||||
usage := result.Get("usageMetadata")
|
||||
|
||||
return &imageGenerationResponse{
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Data: data,
|
||||
Usage: &imageGenerationUsage{
|
||||
TotalTokens: int(usage.Get("totalTokenCount").Int()),
|
||||
InputTokens: int(usage.Get("promptTokenCount").Int()),
|
||||
OutputTokens: int(usage.Get("candidatesTokenCount").Int()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
|
||||
var choice chatCompletionChoice
|
||||
choice.Delta = &chatMessage{}
|
||||
@@ -422,19 +760,62 @@ func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string
|
||||
} else {
|
||||
action = vertexAnthropicMessageAction
|
||||
}
|
||||
return fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
|
||||
if v.isExpressMode() {
|
||||
// Express Mode: 简化路径 + API Key 参数
|
||||
basePath := fmt.Sprintf(vertexExpressPathAnthropicTemplate, modelId, action)
|
||||
apiKey := v.config.GetRandomToken()
|
||||
// 如果 action 已经包含 ?,使用 & 拼接
|
||||
var fullPath string
|
||||
if strings.Contains(action, "?") {
|
||||
fullPath = basePath + "&key=" + apiKey
|
||||
} else {
|
||||
fullPath = basePath + "?key=" + apiKey
|
||||
}
|
||||
return fullPath
|
||||
}
|
||||
|
||||
path := fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
return path
|
||||
}
|
||||
|
||||
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
|
||||
action := ""
|
||||
if apiName == ApiNameEmbeddings {
|
||||
switch apiName {
|
||||
case ApiNameEmbeddings:
|
||||
action = vertexEmbeddingAction
|
||||
} else if stream {
|
||||
action = vertexChatCompletionStreamAction
|
||||
} else {
|
||||
case ApiNameImageGeneration:
|
||||
// 图片生成使用非流式端点,需要完整响应
|
||||
action = vertexChatCompletionAction
|
||||
default:
|
||||
if stream {
|
||||
action = vertexChatCompletionStreamAction
|
||||
} else {
|
||||
action = vertexChatCompletionAction
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
|
||||
if v.isExpressMode() {
|
||||
// Express Mode: 简化路径 + API Key 参数
|
||||
basePath := fmt.Sprintf(vertexExpressPathTemplate, modelId, action)
|
||||
apiKey := v.config.GetRandomToken()
|
||||
// 如果 action 已经包含 ?(如 streamGenerateContent?alt=sse),使用 & 拼接
|
||||
var fullPath string
|
||||
if strings.Contains(action, "?") {
|
||||
fullPath = basePath + "&key=" + apiKey
|
||||
} else {
|
||||
fullPath = basePath + "?key=" + apiKey
|
||||
}
|
||||
return fullPath
|
||||
}
|
||||
|
||||
path := fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
return path
|
||||
}
|
||||
|
||||
// getOpenAICompatibleRequestPath 获取 OpenAI 兼容模式的请求路径
|
||||
func (v *vertexProvider) getOpenAICompatibleRequestPath() string {
|
||||
return fmt.Sprintf(vertexOpenAICompatiblePathTemplate, v.config.vertexProjectId, v.config.vertexRegion)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest {
|
||||
@@ -521,7 +902,7 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest)
|
||||
})
|
||||
}
|
||||
case contentTypeImageUrl:
|
||||
vpart, err := convertImageContent(part.ImageUrl.Url)
|
||||
vpart, err := convertMediaContent(part.ImageUrl.Url)
|
||||
if err != nil {
|
||||
log.Errorf("unable to convert image content: %v", err)
|
||||
} else {
|
||||
@@ -636,12 +1017,25 @@ type vertexChatSafetySetting struct {
|
||||
}
|
||||
|
||||
type vertexChatGenerationConfig struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
ThinkingConfig vertexThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
ThinkingConfig vertexThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||
ImageConfig *vertexImageConfig `json:"imageConfig,omitempty"`
|
||||
}
|
||||
|
||||
type vertexImageConfig struct {
|
||||
AspectRatio string `json:"aspectRatio,omitempty"`
|
||||
ImageSize string `json:"imageSize,omitempty"`
|
||||
ImageOutputOptions *vertexImageOutputOptions `json:"imageOutputOptions,omitempty"`
|
||||
PersonGeneration string `json:"personGeneration,omitempty"`
|
||||
}
|
||||
|
||||
type vertexImageOutputOptions struct {
|
||||
MimeType string `json:"mimeType,omitempty"`
|
||||
}
|
||||
|
||||
type vertexThinkingConfig struct {
|
||||
@@ -852,32 +1246,106 @@ func setCachedAccessToken(key string, accessToken string, expireTime int64) erro
|
||||
return proxywasm.SetSharedData(key, data, cas)
|
||||
}
|
||||
|
||||
func convertImageContent(imageUrl string) (vertexPart, error) {
|
||||
// convertMediaContent 将 OpenAI 格式的媒体 URL 转换为 Vertex AI 格式
|
||||
// 支持图片、视频、音频等多种媒体类型
|
||||
func convertMediaContent(mediaUrl string) (vertexPart, error) {
|
||||
part := vertexPart{}
|
||||
if strings.HasPrefix(imageUrl, "http") {
|
||||
arr := strings.Split(imageUrl, ".")
|
||||
mimeType := "image/" + arr[len(arr)-1]
|
||||
if strings.HasPrefix(mediaUrl, "http") {
|
||||
mimeType := detectMimeTypeFromURL(mediaUrl)
|
||||
part.FileData = &fileData{
|
||||
MimeType: mimeType,
|
||||
FileUri: imageUrl,
|
||||
FileUri: mediaUrl,
|
||||
}
|
||||
return part, nil
|
||||
} else {
|
||||
// Base64 data URL 格式: data:<mimeType>;base64,<data>
|
||||
re := regexp.MustCompile(`^data:([^;]+);base64,`)
|
||||
matches := re.FindStringSubmatch(imageUrl)
|
||||
matches := re.FindStringSubmatch(mediaUrl)
|
||||
if len(matches) < 2 {
|
||||
return part, fmt.Errorf("invalid base64 format")
|
||||
return part, fmt.Errorf("invalid base64 format, expected data:<mimeType>;base64,<data>")
|
||||
}
|
||||
|
||||
mimeType := matches[1] // e.g. image/png
|
||||
mimeType := matches[1] // e.g. image/png, video/mp4, audio/mp3
|
||||
parts := strings.Split(mimeType, "/")
|
||||
if len(parts) < 2 {
|
||||
return part, fmt.Errorf("invalid mimeType")
|
||||
return part, fmt.Errorf("invalid mimeType: %s", mimeType)
|
||||
}
|
||||
part.InlineData = &blob{
|
||||
MimeType: mimeType,
|
||||
Data: strings.TrimPrefix(imageUrl, matches[0]),
|
||||
Data: strings.TrimPrefix(mediaUrl, matches[0]),
|
||||
}
|
||||
return part, nil
|
||||
}
|
||||
}
|
||||
|
||||
// detectMimeTypeFromURL 根据 URL 的文件扩展名检测 MIME 类型
|
||||
// 支持图片、视频、音频和文档类型
|
||||
func detectMimeTypeFromURL(url string) string {
|
||||
// 移除查询参数和片段标识符
|
||||
if idx := strings.Index(url, "?"); idx != -1 {
|
||||
url = url[:idx]
|
||||
}
|
||||
if idx := strings.Index(url, "#"); idx != -1 {
|
||||
url = url[:idx]
|
||||
}
|
||||
|
||||
// 获取最后一个路径段
|
||||
lastSlash := strings.LastIndex(url, "/")
|
||||
if lastSlash != -1 {
|
||||
url = url[lastSlash+1:]
|
||||
}
|
||||
|
||||
// 获取扩展名
|
||||
lastDot := strings.LastIndex(url, ".")
|
||||
if lastDot == -1 || lastDot == len(url)-1 {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
ext := strings.ToLower(url[lastDot+1:])
|
||||
|
||||
// 扩展名到 MIME 类型的映射
|
||||
mimeTypes := map[string]string{
|
||||
// 图片格式
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"png": "image/png",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
"bmp": "image/bmp",
|
||||
"svg": "image/svg+xml",
|
||||
"ico": "image/x-icon",
|
||||
"heic": "image/heic",
|
||||
"heif": "image/heif",
|
||||
"tiff": "image/tiff",
|
||||
"tif": "image/tiff",
|
||||
// 视频格式
|
||||
"mp4": "video/mp4",
|
||||
"mpeg": "video/mpeg",
|
||||
"mpg": "video/mpeg",
|
||||
"mov": "video/quicktime",
|
||||
"avi": "video/x-msvideo",
|
||||
"wmv": "video/x-ms-wmv",
|
||||
"webm": "video/webm",
|
||||
"mkv": "video/x-matroska",
|
||||
"flv": "video/x-flv",
|
||||
"3gp": "video/3gpp",
|
||||
"3g2": "video/3gpp2",
|
||||
"m4v": "video/x-m4v",
|
||||
// 音频格式
|
||||
"mp3": "audio/mpeg",
|
||||
"wav": "audio/wav",
|
||||
"ogg": "audio/ogg",
|
||||
"flac": "audio/flac",
|
||||
"aac": "audio/aac",
|
||||
"m4a": "audio/mp4",
|
||||
"wma": "audio/x-ms-wma",
|
||||
"opus": "audio/opus",
|
||||
// 文档格式
|
||||
"pdf": "application/pdf",
|
||||
}
|
||||
|
||||
if mimeType, ok := mimeTypes[ext]; ok {
|
||||
return mimeType
|
||||
}
|
||||
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
527
plugins/wasm-go/extensions/ai-proxy/test/bedrock.go
Normal file
527
plugins/wasm-go/extensions/ai-proxy/test/bedrock.go
Normal file
@@ -0,0 +1,527 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test config: Basic Bedrock config with AWS Access Key/Secret Key (AWS Signature V4)
|
||||
var basicBedrockConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"awsAccessKey": "test-ak-for-unit-test",
|
||||
"awsSecretKey": "test-sk-for-unit-test",
|
||||
"awsRegion": "us-east-1",
|
||||
"modelMapping": map[string]string{
|
||||
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Bedrock config with Bearer Token authentication
|
||||
var bedrockApiTokenConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"apiTokens": []string{
|
||||
"test-token-for-unit-test",
|
||||
},
|
||||
"awsRegion": "us-east-1",
|
||||
"modelMapping": map[string]string{
|
||||
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Bedrock config with multiple Bearer Tokens
|
||||
var bedrockMultiTokenConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"apiTokens": []string{
|
||||
"test-token-1-for-unit-test",
|
||||
"test-token-2-for-unit-test",
|
||||
},
|
||||
"awsRegion": "us-west-2",
|
||||
"modelMapping": map[string]string{
|
||||
"gpt-4": "anthropic.claude-3-opus-20240229-v1:0",
|
||||
"*": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Bedrock config with additional fields
|
||||
var bedrockWithAdditionalFieldsConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"awsAccessKey": "test-ak-for-unit-test",
|
||||
"awsSecretKey": "test-sk-for-unit-test",
|
||||
"awsRegion": "us-east-1",
|
||||
"bedrockAdditionalFields": map[string]interface{}{
|
||||
"top_k": 200,
|
||||
},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Invalid config - missing both apiTokens and ak/sk
|
||||
var bedrockInvalidConfigMissingAuth = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"awsRegion": "us-east-1",
|
||||
"modelMapping": map[string]string{
|
||||
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Invalid config - missing region
|
||||
var bedrockInvalidConfigMissingRegion = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"apiTokens": []string{
|
||||
"test-token-for-unit-test",
|
||||
},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Invalid config - only has access key without secret key
|
||||
var bedrockInvalidConfigPartialAkSk = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"awsAccessKey": "test-ak-for-unit-test",
|
||||
"awsRegion": "us-east-1",
|
||||
"modelMapping": map[string]string{
|
||||
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func RunBedrockParseConfigTests(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// Test basic Bedrock config with AWS Signature V4 authentication
|
||||
t.Run("basic bedrock config with ak/sk", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicBedrockConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// Test Bedrock config with Bearer Token authentication
|
||||
t.Run("bedrock config with api token", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// Test Bedrock config with multiple tokens
|
||||
t.Run("bedrock config with multiple tokens", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockMultiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// Test Bedrock config with additional fields
|
||||
t.Run("bedrock config with additional fields", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockWithAdditionalFieldsConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// Test invalid config - missing authentication
|
||||
t.Run("bedrock invalid config missing auth", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockInvalidConfigMissingAuth)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||||
})
|
||||
|
||||
// Test invalid config - missing region
|
||||
t.Run("bedrock invalid config missing region", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockInvalidConfigMissingRegion)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||||
})
|
||||
|
||||
// Test invalid config - partial ak/sk (only access key, no secret key)
|
||||
t.Run("bedrock invalid config partial ak/sk", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockInvalidConfigPartialAkSk)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunBedrockOnHttpRequestHeadersTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// Test Bedrock request headers processing with AWS Signature V4
|
||||
t.Run("bedrock chat completion request headers with ak/sk", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicBedrockConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// Set request headers
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// Verify request headers
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// Verify Host is changed to Bedrock service domain
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost, "Host header should exist")
|
||||
require.Contains(t, hostValue, "bedrock-runtime.us-east-1.amazonaws.com", "Host should be changed to Bedrock service domain")
|
||||
})
|
||||
|
||||
// Test Bedrock request headers processing with Bearer Token
|
||||
t.Run("bedrock chat completion request headers with api token", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// Set request headers
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// Verify request headers
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// Verify Host is changed to Bedrock service domain
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost, "Host header should exist")
|
||||
require.Contains(t, hostValue, "bedrock-runtime.us-east-1.amazonaws.com", "Host should be changed to Bedrock service domain")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// Test Bedrock request body processing with Bearer Token authentication
|
||||
t.Run("bedrock chat completion request body with api token", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// Set request headers
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// Set request body
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, how are you?"
|
||||
}
|
||||
],
|
||||
"temperature": 0.7
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// Verify request headers for Bearer Token authentication
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// Verify Authorization header uses Bearer token
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.Contains(t, authValue, "Bearer ", "Authorization should use Bearer token")
|
||||
require.Contains(t, authValue, "test-token-for-unit-test", "Authorization should contain the configured token")
|
||||
|
||||
// Verify path is transformed to Bedrock format
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.Contains(t, pathValue, "/model/", "Path should contain Bedrock model path")
|
||||
require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint")
|
||||
})
|
||||
|
||||
// Test Bedrock request body processing with AWS Signature V4 authentication
|
||||
t.Run("bedrock chat completion request body with ak/sk", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicBedrockConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// Set request headers
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// Set request body
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, how are you?"
|
||||
}
|
||||
],
|
||||
"temperature": 0.7
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// Verify request headers for AWS Signature V4 authentication
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// Verify Authorization header uses AWS Signature
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.Contains(t, authValue, "AWS4-HMAC-SHA256", "Authorization should use AWS4-HMAC-SHA256 signature")
|
||||
require.Contains(t, authValue, "Credential=", "Authorization should contain Credential")
|
||||
require.Contains(t, authValue, "Signature=", "Authorization should contain Signature")
|
||||
|
||||
// Verify X-Amz-Date header exists
|
||||
dateValue, hasDate := test.GetHeaderValue(requestHeaders, "X-Amz-Date")
|
||||
require.True(t, hasDate, "X-Amz-Date header should exist for AWS Signature V4")
|
||||
require.NotEmpty(t, dateValue, "X-Amz-Date should not be empty")
|
||||
|
||||
// Verify path is transformed to Bedrock format
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.Contains(t, pathValue, "/model/", "Path should contain Bedrock model path")
|
||||
require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint")
|
||||
})
|
||||
|
||||
// Test Bedrock streaming request
|
||||
t.Run("bedrock streaming request", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// Set request headers
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// Set streaming request body
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// Verify path is transformed to Bedrock streaming format
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.Contains(t, pathValue, "/model/", "Path should contain Bedrock model path")
|
||||
require.Contains(t, pathValue, "/converse-stream", "Path should contain converse-stream endpoint for streaming")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunBedrockOnHttpResponseHeadersTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// Test Bedrock response headers processing
|
||||
t.Run("bedrock response headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// Set request headers
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// Set request body
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// Process response headers
|
||||
action = host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"X-Amzn-Requestid", "test-request-id-12345"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// Verify response headers
|
||||
responseHeaders := host.GetResponseHeaders()
|
||||
require.NotNil(t, responseHeaders)
|
||||
|
||||
// Verify status code
|
||||
statusValue, hasStatus := test.GetHeaderValue(responseHeaders, ":status")
|
||||
require.True(t, hasStatus, "Status header should exist")
|
||||
require.Equal(t, "200", statusValue, "Status should be 200")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunBedrockOnHttpResponseBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// Test Bedrock response body processing
|
||||
t.Run("bedrock response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// Set request headers
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// Set request body
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// Set response property to ensure IsResponseFromUpstream() returns true
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
|
||||
// Process response headers (must include :status 200 for body processing)
|
||||
action = host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// Process response body (Bedrock format)
|
||||
responseBody := `{
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"text": "Hello! How can I help you today?"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"stopReason": "end_turn",
|
||||
"usage": {
|
||||
"inputTokens": 10,
|
||||
"outputTokens": 15,
|
||||
"totalTokens": 25
|
||||
}
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// Verify response body is transformed to OpenAI format
|
||||
transformedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, transformedResponseBody)
|
||||
|
||||
var responseMap map[string]interface{}
|
||||
err := json.Unmarshal(transformedResponseBody, &responseMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify choices exist in transformed response
|
||||
choices, exists := responseMap["choices"]
|
||||
require.True(t, exists, "Choices should exist in response body")
|
||||
require.NotNil(t, choices, "Choices should not be nil")
|
||||
|
||||
// Verify usage exists
|
||||
usage, exists := responseMap["usage"]
|
||||
require.True(t, exists, "Usage should exist in response body")
|
||||
require.NotNil(t, usage, "Usage should not be nil")
|
||||
})
|
||||
})
|
||||
}
|
||||
1585
plugins/wasm-go/extensions/ai-proxy/test/vertex.go
Normal file
1585
plugins/wasm-go/extensions/ai-proxy/test/vertex.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,10 @@
|
||||
package util
|
||||
|
||||
import "regexp"
|
||||
import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func StripPrefix(s string, prefix string) string {
|
||||
if len(prefix) != 0 && len(s) >= len(prefix) && s[0:len(prefix)] == prefix {
|
||||
@@ -18,3 +22,43 @@ func MatchStatus(status string, patterns []string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// unicodeEscapeRegex matches Unicode escape sequences like \uXXXX
|
||||
var unicodeEscapeRegex = regexp.MustCompile(`\\u([0-9a-fA-F]{4})`)
|
||||
|
||||
// DecodeUnicodeEscapes decodes Unicode escape sequences (\uXXXX) in a string to UTF-8 characters.
|
||||
// This is useful when a JSON response contains ASCII-safe encoded non-ASCII characters.
|
||||
func DecodeUnicodeEscapes(input []byte) []byte {
|
||||
result := unicodeEscapeRegex.ReplaceAllFunc(input, func(match []byte) []byte {
|
||||
// match is like \uXXXX, extract the hex part (XXXX)
|
||||
hexStr := string(match[2:6])
|
||||
codePoint, err := strconv.ParseInt(hexStr, 16, 32)
|
||||
if err != nil {
|
||||
return match // return original if parse fails
|
||||
}
|
||||
return []byte(string(rune(codePoint)))
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
// DecodeUnicodeEscapesInSSE decodes Unicode escape sequences in SSE formatted data.
|
||||
// It processes each line that starts with "data: " and decodes Unicode escapes in the JSON payload.
|
||||
func DecodeUnicodeEscapesInSSE(input []byte) []byte {
|
||||
lines := strings.Split(string(input), "\n")
|
||||
var result strings.Builder
|
||||
for i, line := range lines {
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
// Decode Unicode escapes in the JSON payload
|
||||
jsonData := line[6:]
|
||||
decodedData := DecodeUnicodeEscapes([]byte(jsonData))
|
||||
result.WriteString("data: ")
|
||||
result.Write(decodedData)
|
||||
} else {
|
||||
result.WriteString(line)
|
||||
}
|
||||
if i < len(lines)-1 {
|
||||
result.WriteString("\n")
|
||||
}
|
||||
}
|
||||
return []byte(result.String())
|
||||
}
|
||||
|
||||
108
plugins/wasm-go/extensions/ai-proxy/util/string_test.go
Normal file
108
plugins/wasm-go/extensions/ai-proxy/util/string_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDecodeUnicodeEscapes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Chinese characters",
|
||||
input: `\u4e2d\u6587\u6d4b\u8bd5`,
|
||||
expected: `中文测试`,
|
||||
},
|
||||
{
|
||||
name: "Mixed content",
|
||||
input: `Hello \u4e16\u754c World`,
|
||||
expected: `Hello 世界 World`,
|
||||
},
|
||||
{
|
||||
name: "No escape sequences",
|
||||
input: `Hello World`,
|
||||
expected: `Hello World`,
|
||||
},
|
||||
{
|
||||
name: "JSON with Unicode escapes",
|
||||
input: `{"content":"\u76c8\u5229\u80fd\u529b"}`,
|
||||
expected: `{"content":"盈利能力"}`,
|
||||
},
|
||||
{
|
||||
name: "Full width parentheses",
|
||||
input: `\uff08\u76c8\u5229\uff09`,
|
||||
expected: `(盈利)`,
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
input: ``,
|
||||
expected: ``,
|
||||
},
|
||||
{
|
||||
name: "Invalid escape sequence (not modified)",
|
||||
input: `\u00GG`,
|
||||
expected: `\u00GG`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := DecodeUnicodeEscapes([]byte(tt.input))
|
||||
assert.Equal(t, tt.expected, string(result))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeUnicodeEscapesInSSE(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "SSE data with Unicode escapes",
|
||||
input: `data: {"choices":[{"delta":{"content":"\u4e2d\u6587"}}]}
|
||||
|
||||
`,
|
||||
expected: `data: {"choices":[{"delta":{"content":"中文"}}]}
|
||||
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "Multiple SSE data lines",
|
||||
input: `data: {"content":"\u4e2d\u6587"}
|
||||
data: {"content":"\u82f1\u6587"}
|
||||
data: [DONE]
|
||||
`,
|
||||
expected: `data: {"content":"中文"}
|
||||
data: {"content":"英文"}
|
||||
data: [DONE]
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "Non-data lines unchanged",
|
||||
input: ": comment\nevent: message\ndata: test\n",
|
||||
expected: ": comment\nevent: message\ndata: test\n",
|
||||
},
|
||||
{
|
||||
name: "Real Vertex AI response format",
|
||||
input: `data: {"choices":[{"delta":{"content":"\uff08\u76c8\u5229\u80fd\u529b\uff09","role":"assistant"},"index":0}],"created":1768307454,"id":"test","model":"gemini","object":"chat.completion.chunk"}
|
||||
|
||||
`,
|
||||
expected: `data: {"choices":[{"delta":{"content":"(盈利能力)","role":"assistant"},"index":0}],"created":1768307454,"id":"test","model":"gemini","object":"chat.completion.chunk"}
|
||||
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := DecodeUnicodeEscapesInSSE([]byte(tt.input))
|
||||
assert.Equal(t, tt.expected, string(result))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -5,21 +5,21 @@ go 1.24.1
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -2,14 +2,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd h1:acTs8sqXf+qP+IypxFg3cu5Cluj7VT5BI+IDRlY5sag=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d h1:LgYbzEBtg0+LEqoebQeMVgAB6H5SgqG+KN+gBhNfKbM=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
|
||||
@@ -6,7 +6,7 @@ toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
)
|
||||
|
||||
@@ -4,8 +4,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c h1:DdVPyaMHSYBqO5jwB9Wl3PqsBGIf4u29BHMI0uIVB1Y=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d h1:LgYbzEBtg0+LEqoebQeMVgAB6H5SgqG+KN+gBhNfKbM=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
|
||||
@@ -83,7 +83,6 @@ global_threshold:
|
||||
token_per_minute: 1000 # 自定义规则组每分钟1000个token
|
||||
redis:
|
||||
service_name: redis.static
|
||||
show_limit_quota_header: true
|
||||
```
|
||||
|
||||
### 识别请求参数 apikey,进行区别限流
|
||||
|
||||
@@ -89,7 +89,6 @@ global_threshold:
|
||||
token_per_minute: 1000 # 1000 tokens per minute for the custom rule group
|
||||
redis:
|
||||
service_name: redis.static
|
||||
show_limit_quota_header: true
|
||||
```
|
||||
|
||||
### Identify request parameter apikey for differentiated rate limiting
|
||||
|
||||
2
plugins/wasm-go/extensions/model-mapper/Makefile
Normal file
2
plugins/wasm-go/extensions/model-mapper/Makefile
Normal file
@@ -0,0 +1,2 @@
|
||||
build-go:
|
||||
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o main.wasm main.go
|
||||
61
plugins/wasm-go/extensions/model-mapper/README.md
Normal file
61
plugins/wasm-go/extensions/model-mapper/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# 功能说明
|
||||
`model-mapper`插件实现了基于LLM协议中的model参数路由的功能
|
||||
|
||||
# 配置字段
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
|
||||
| `modelKey` | string | 选填 | model | 请求body中model参数的位置 |
|
||||
| `modelMapping` | map of string | 选填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
|
||||
| `enableOnPathSuffix` | array of string | 选填 | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | 只对这些特定路径后缀的请求生效 |
|
||||
|
||||
|
||||
## 效果说明
|
||||
|
||||
如下配置
|
||||
|
||||
```yaml
|
||||
modelMapping:
|
||||
'gpt-4-*': "qwen-max"
|
||||
'gpt-4o': "qwen-vl-plus"
|
||||
'*': "qwen-turbo"
|
||||
```
|
||||
|
||||
开启后,`gpt-4-` 开头的模型参数会被改写为 `qwen-max`, `gpt-4o` 会被改写为 `qwen-vl-plus`,其他所有模型会被改写为 `qwen-turbo`
|
||||
|
||||
例如原本的请求是:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4o",
|
||||
"frequency_penalty": 0,
|
||||
"max_tokens": 800,
|
||||
"stream": false,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "higress项目主仓库的github地址是什么"
|
||||
}],
|
||||
"presence_penalty": 0,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.95
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
经过这个插件后,原始的 LLM 请求体将被改成:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-vl-plus",
|
||||
"frequency_penalty": 0,
|
||||
"max_tokens": 800,
|
||||
"stream": false,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "higress项目主仓库的github地址是什么"
|
||||
}],
|
||||
"presence_penalty": 0,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.95
|
||||
}
|
||||
```
|
||||
61
plugins/wasm-go/extensions/model-mapper/README_EN.md
Normal file
61
plugins/wasm-go/extensions/model-mapper/README_EN.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# Function Description
|
||||
The `model-mapper` plugin implements model parameter mapping functionality based on the LLM protocol.
|
||||
|
||||
# Configuration Fields
|
||||
|
||||
| Name | Type | Requirement | Default Value | Description |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| `modelKey` | string | Optional | model | The position of the model parameter in the request body. |
|
||||
| `modelMapping` | map of string | Optional | - | AI model mapping table, used to map the model name in the request to the model name supported by the service provider.<br/>1. Supports prefix matching. For example, use "gpt-3-*" to match all names starting with "gpt-3-";<br/>2. Supports using "*" as a key to configure a generic fallback mapping;<br/>3. If the target mapping name is an empty string "", it indicates keeping the original model name. |
|
||||
| `enableOnPathSuffix` | array of string | Optional | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | Only effective for requests with these specific path suffixes. |
|
||||
|
||||
|
||||
## Effect Description
|
||||
|
||||
Configuration example:
|
||||
|
||||
```yaml
|
||||
modelMapping:
|
||||
'gpt-4-*': "qwen-max"
|
||||
'gpt-4o': "qwen-vl-plus"
|
||||
'*': "qwen-turbo"
|
||||
```
|
||||
|
||||
After enabling, model parameters starting with `gpt-4-` will be replaced with `qwen-max`, `gpt-4o` will be replaced with `qwen-vl-plus`, and all other models will be replaced with `qwen-turbo`.
|
||||
|
||||
For example, the original request is:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4o",
|
||||
"frequency_penalty": 0,
|
||||
"max_tokens": 800,
|
||||
"stream": false,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "What is the github address of the main repository of the higress project"
|
||||
}],
|
||||
"presence_penalty": 0,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.95
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
After processing by this plugin, the original LLM request body will be modified to:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-vl-plus",
|
||||
"frequency_penalty": 0,
|
||||
"max_tokens": 800,
|
||||
"stream": false,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "What is the github address of the main repository of the higress project"
|
||||
}],
|
||||
"presence_penalty": 0,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.95
|
||||
}
|
||||
```
|
||||
24
plugins/wasm-go/extensions/model-mapper/go.mod
Normal file
24
plugins/wasm-go/extensions/model-mapper/go.mod
Normal file
@@ -0,0 +1,24 @@
|
||||
module github.com/alibaba/higress/plugins/wasm-go/extensions/model-mapper
|
||||
|
||||
go 1.24.1
|
||||
|
||||
toolchain go1.24.7
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
30
plugins/wasm-go/extensions/model-mapper/go.sum
Normal file
30
plugins/wasm-go/extensions/model-mapper/go.sum
Normal file
@@ -0,0 +1,30 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c h1:DdVPyaMHSYBqO5jwB9Wl3PqsBGIf4u29BHMI0uIVB1Y=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
|
||||
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
195
plugins/wasm-go/extensions/model-mapper/main.go
Normal file
195
plugins/wasm-go/extensions/model-mapper/main.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMaxBodyBytes = 100 * 1024 * 1024 // 100MB
|
||||
)
|
||||
|
||||
func main() {}
|
||||
|
||||
func init() {
|
||||
wrapper.SetCtx(
|
||||
"model-mapper",
|
||||
wrapper.ParseConfig(parseConfig),
|
||||
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBody(onHttpRequestBody),
|
||||
wrapper.WithRebuildAfterRequests[Config](1000),
|
||||
wrapper.WithRebuildMaxMemBytes[Config](200*1024*1024),
|
||||
)
|
||||
}
|
||||
|
||||
type ModelMapping struct {
|
||||
Prefix string
|
||||
Target string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
modelKey string
|
||||
exactModelMapping map[string]string
|
||||
prefixModelMapping []ModelMapping
|
||||
defaultModel string
|
||||
enableOnPathSuffix []string
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *Config) error {
|
||||
config.modelKey = json.Get("modelKey").String()
|
||||
if config.modelKey == "" {
|
||||
config.modelKey = "model"
|
||||
}
|
||||
|
||||
modelMapping := json.Get("modelMapping")
|
||||
if modelMapping.Exists() && !modelMapping.IsObject() {
|
||||
return errors.New("modelMapping must be an object")
|
||||
}
|
||||
|
||||
config.exactModelMapping = make(map[string]string)
|
||||
config.prefixModelMapping = make([]ModelMapping, 0)
|
||||
|
||||
// To replicate C++ behavior (nlohmann::json iterates keys alphabetically),
|
||||
// we collect entries and sort them by key.
|
||||
type mappingEntry struct {
|
||||
key string
|
||||
value string
|
||||
}
|
||||
var entries []mappingEntry
|
||||
modelMapping.ForEach(func(key, value gjson.Result) bool {
|
||||
entries = append(entries, mappingEntry{
|
||||
key: key.String(),
|
||||
value: value.String(),
|
||||
})
|
||||
return true
|
||||
})
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].key < entries[j].key
|
||||
})
|
||||
|
||||
for _, entry := range entries {
|
||||
key := entry.key
|
||||
value := entry.value
|
||||
if key == "*" {
|
||||
config.defaultModel = value
|
||||
} else if strings.HasSuffix(key, "*") {
|
||||
prefix := strings.TrimSuffix(key, "*")
|
||||
config.prefixModelMapping = append(config.prefixModelMapping, ModelMapping{
|
||||
Prefix: prefix,
|
||||
Target: value,
|
||||
})
|
||||
} else {
|
||||
config.exactModelMapping[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
enableOnPathSuffix := json.Get("enableOnPathSuffix")
|
||||
if enableOnPathSuffix.Exists() {
|
||||
if !enableOnPathSuffix.IsArray() {
|
||||
return errors.New("enableOnPathSuffix must be an array")
|
||||
}
|
||||
for _, item := range enableOnPathSuffix.Array() {
|
||||
config.enableOnPathSuffix = append(config.enableOnPathSuffix, item.String())
|
||||
}
|
||||
} else {
|
||||
config.enableOnPathSuffix = []string{
|
||||
"/completions",
|
||||
"/embeddings",
|
||||
"/images/generations",
|
||||
"/audio/speech",
|
||||
"/fine_tuning/jobs",
|
||||
"/moderations",
|
||||
"/image-synthesis",
|
||||
"/video-synthesis",
|
||||
"/rerank",
|
||||
"/messages",
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action {
|
||||
// Check path suffix
|
||||
path, err := proxywasm.GetHttpRequestHeader(":path")
|
||||
if err != nil {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
// Strip query parameters
|
||||
if idx := strings.Index(path, "?"); idx != -1 {
|
||||
path = path[:idx]
|
||||
}
|
||||
|
||||
matched := false
|
||||
for _, suffix := range config.enableOnPathSuffix {
|
||||
if strings.HasSuffix(path, suffix) {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !matched || !ctx.HasRequestBody() {
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
// Prepare for body processing
|
||||
proxywasm.RemoveHttpRequestHeader("content-length")
|
||||
// 100MB buffer limit
|
||||
ctx.SetRequestBodyBufferLimit(DefaultMaxBodyBytes)
|
||||
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
|
||||
if len(body) == 0 {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
if !json.Valid(body) {
|
||||
log.Error("invalid json body")
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
oldModel := gjson.GetBytes(body, config.modelKey).String()
|
||||
|
||||
newModel := config.defaultModel
|
||||
if newModel == "" {
|
||||
newModel = oldModel
|
||||
}
|
||||
|
||||
// Exact match
|
||||
if target, ok := config.exactModelMapping[oldModel]; ok {
|
||||
newModel = target
|
||||
} else {
|
||||
// Prefix match
|
||||
for _, mapping := range config.prefixModelMapping {
|
||||
if strings.HasPrefix(oldModel, mapping.Prefix) {
|
||||
newModel = mapping.Target
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if newModel != "" && newModel != oldModel {
|
||||
newBody, err := sjson.SetBytes(body, config.modelKey, newModel)
|
||||
if err != nil {
|
||||
log.Errorf("failed to update model: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
proxywasm.ReplaceHttpRequestBody(newBody)
|
||||
log.Debugf("model mapped, before: %s, after: %s", oldModel, newModel)
|
||||
}
|
||||
|
||||
return types.ActionContinue
|
||||
}
|
||||
250
plugins/wasm-go/extensions/model-mapper/main_test.go
Normal file
250
plugins/wasm-go/extensions/model-mapper/main_test.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// Basic configs for wasm test host
|
||||
var (
|
||||
basicConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"modelKey": "model",
|
||||
"modelMapping": map[string]string{
|
||||
"gpt-3.5-turbo": "gpt-4",
|
||||
},
|
||||
"enableOnPathSuffix": []string{
|
||||
"/v1/chat/completions",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
customConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"modelKey": "request.model",
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-4o",
|
||||
"gpt-3.5*": "gpt-4-mini",
|
||||
"gpt-3.5-t": "gpt-4-turbo",
|
||||
"gpt-3.5-t1": "gpt-4-turbo-1",
|
||||
},
|
||||
"enableOnPathSuffix": []string{
|
||||
"/v1/chat/completions",
|
||||
"/v1/embeddings",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
)
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
t.Run("basic config with defaults", func(t *testing.T) {
|
||||
var cfg Config
|
||||
jsonData := []byte(`{
|
||||
"modelMapping": {
|
||||
"gpt-3.5-turbo": "gpt-4",
|
||||
"gpt-4*": "gpt-4o-mini",
|
||||
"*": "gpt-4o"
|
||||
}
|
||||
}`)
|
||||
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// default modelKey
|
||||
require.Equal(t, "model", cfg.modelKey)
|
||||
// exact mapping
|
||||
require.Equal(t, "gpt-4", cfg.exactModelMapping["gpt-3.5-turbo"])
|
||||
// prefix mapping
|
||||
require.Len(t, cfg.prefixModelMapping, 1)
|
||||
require.Equal(t, "gpt-4", cfg.prefixModelMapping[0].Prefix)
|
||||
// default model
|
||||
require.Equal(t, "gpt-4o", cfg.defaultModel)
|
||||
// default enabled path suffixes
|
||||
require.Contains(t, cfg.enableOnPathSuffix, "/completions")
|
||||
require.Contains(t, cfg.enableOnPathSuffix, "/embeddings")
|
||||
})
|
||||
|
||||
t.Run("custom modelKey and enableOnPathSuffix", func(t *testing.T) {
|
||||
var cfg Config
|
||||
jsonData := []byte(`{
|
||||
"modelKey": "request.model",
|
||||
"modelMapping": {
|
||||
"gpt-3.5-turbo": "gpt-4",
|
||||
"gpt-3.5*": "gpt-4-mini"
|
||||
},
|
||||
"enableOnPathSuffix": ["/v1/chat/completions", "/v1/embeddings"]
|
||||
}`)
|
||||
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "request.model", cfg.modelKey)
|
||||
require.Equal(t, "gpt-4", cfg.exactModelMapping["gpt-3.5-turbo"])
|
||||
require.Len(t, cfg.prefixModelMapping, 1)
|
||||
require.Equal(t, "gpt-3.5", cfg.prefixModelMapping[0].Prefix)
|
||||
require.Equal(t, "gpt-4-mini", cfg.prefixModelMapping[0].Target)
|
||||
require.Equal(t, 2, len(cfg.enableOnPathSuffix))
|
||||
require.Contains(t, cfg.enableOnPathSuffix, "/v1/chat/completions")
|
||||
require.Contains(t, cfg.enableOnPathSuffix, "/v1/embeddings")
|
||||
})
|
||||
|
||||
t.Run("modelMapping must be object", func(t *testing.T) {
|
||||
var cfg Config
|
||||
jsonData := []byte(`{
|
||||
"modelMapping": "invalid"
|
||||
}`)
|
||||
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("enableOnPathSuffix must be array", func(t *testing.T) {
|
||||
var cfg Config
|
||||
jsonData := []byte(`{
|
||||
"enableOnPathSuffix": "not-array"
|
||||
}`)
|
||||
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
|
||||
require.Error(t, err)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestHeaders(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("skip when path not matched", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
originalHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/other"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"content-length", "123"},
|
||||
}
|
||||
action := host.CallOnHttpRequestHeaders(originalHeaders)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
// content-length should still exist because path is not enabled
|
||||
foundContentLength := false
|
||||
for _, h := range newHeaders {
|
||||
if strings.ToLower(h[0]) == "content-length" {
|
||||
foundContentLength = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundContentLength)
|
||||
})
|
||||
|
||||
t.Run("process when path and content-type match", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
originalHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"content-length", "123"},
|
||||
}
|
||||
action := host.CallOnHttpRequestHeaders(originalHeaders)
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
// content-length should be removed
|
||||
for _, h := range newHeaders {
|
||||
require.NotEqual(t, strings.ToLower(h[0]), "content-length")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestBody_ModelMapping(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("exact mapping", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
origBody := []byte(`{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(origBody)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
require.Equal(t, "gpt-4", gjson.GetBytes(processed, "model").String())
|
||||
})
|
||||
|
||||
t.Run("default model when key missing", func(t *testing.T) {
|
||||
// use customConfig where default model is set with "*"
|
||||
host, status := test.NewTestHost(customConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
origBody := []byte(`{
|
||||
"request": {
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(origBody)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
// default model should be set at request.model
|
||||
require.Equal(t, "gpt-4o", gjson.GetBytes(processed, "request.model").String())
|
||||
})
|
||||
|
||||
t.Run("prefix mapping takes effect", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(customConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
origBody := []byte(`{
|
||||
"request": {
|
||||
"model": "gpt-3.5-turbo-16k",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(origBody)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
require.Equal(t, "gpt-4-mini", gjson.GetBytes(processed, "request.model").String())
|
||||
})
|
||||
})
|
||||
}
|
||||
2
plugins/wasm-go/extensions/model-router/Makefile
Normal file
2
plugins/wasm-go/extensions/model-router/Makefile
Normal file
@@ -0,0 +1,2 @@
|
||||
build-go:
|
||||
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o main.wasm main.go
|
||||
98
plugins/wasm-go/extensions/model-router/README.md
Normal file
98
plugins/wasm-go/extensions/model-router/README.md
Normal file
@@ -0,0 +1,98 @@
|
||||
## 功能说明
|
||||
`model-router`插件实现了基于LLM协议中的model参数路由的功能
|
||||
|
||||
## 配置字段
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
|
||||
| `modelKey` | string | 选填 | model | 请求body中model参数的位置 |
|
||||
| `addProviderHeader` | string | 选填 | - | 从model参数中解析出的provider名字放到哪个请求header中 |
|
||||
| `modelToHeader` | string | 选填 | - | 直接将model参数放到哪个请求header中 |
|
||||
| `enableOnPathSuffix` | array of string | 选填 | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | 只对这些特定路径后缀的请求生效,可以配置为 "*" 以匹配所有路径 |
|
||||
|
||||
## 运行属性
|
||||
|
||||
插件执行阶段:认证阶段
|
||||
插件执行优先级:900
|
||||
|
||||
## 效果说明
|
||||
|
||||
### 基于 model 参数进行路由
|
||||
|
||||
需要做如下配置:
|
||||
|
||||
```yaml
|
||||
modelToHeader: x-higress-llm-model
|
||||
```
|
||||
|
||||
插件会将请求中 model 参数提取出来,设置到 x-higress-llm-model 这个请求 header 中,用于后续路由,举例来说,原生的 LLM 请求体是:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-long",
|
||||
"frequency_penalty": 0,
|
||||
"max_tokens": 800,
|
||||
"stream": false,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "higress项目主仓库的github地址是什么"
|
||||
}],
|
||||
"presence_penalty": 0,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.95
|
||||
}
|
||||
```
|
||||
|
||||
经过这个插件后,将添加下面这个请求头(可以用于路由匹配):
|
||||
|
||||
x-higress-llm-model: qwen-long
|
||||
|
||||
### 提取 model 参数中的 provider 字段用于路由
|
||||
|
||||
> 注意这种模式需要客户端在 model 参数中通过`/`分隔的方式,来指定 provider
|
||||
|
||||
需要做如下配置:
|
||||
|
||||
```yaml
|
||||
addProviderHeader: x-higress-llm-provider
|
||||
```
|
||||
|
||||
插件会将请求中 model 参数的 provider 部分(如果有)提取出来,设置到 x-higress-llm-provider 这个请求 header 中,用于后续路由,并将 model 参数重写为模型名称部分。举例来说,原生的 LLM 请求体是:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "dashscope/qwen-long",
|
||||
"frequency_penalty": 0,
|
||||
"max_tokens": 800,
|
||||
"stream": false,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "higress项目主仓库的github地址是什么"
|
||||
}],
|
||||
"presence_penalty": 0,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.95
|
||||
}
|
||||
```
|
||||
|
||||
经过这个插件后,将添加下面这个请求头(可以用于路由匹配):
|
||||
|
||||
x-higress-llm-provider: dashscope
|
||||
|
||||
原始的 LLM 请求体将被改成:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-long",
|
||||
"frequency_penalty": 0,
|
||||
"max_tokens": 800,
|
||||
"stream": false,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "higress项目主仓库的github地址是什么"
|
||||
}],
|
||||
"presence_penalty": 0,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.95
|
||||
}
|
||||
```
|
||||
97
plugins/wasm-go/extensions/model-router/README_EN.md
Normal file
97
plugins/wasm-go/extensions/model-router/README_EN.md
Normal file
@@ -0,0 +1,97 @@
|
||||
## Feature Description
|
||||
The `model-router` plugin implements routing functionality based on the model parameter in LLM protocols.
|
||||
|
||||
## Configuration Fields
|
||||
|
||||
| Name | Data Type | Requirement | Default Value | Description |
|
||||
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
|
||||
| `modelKey` | string | Optional | model | Location of the model parameter in the request body |
|
||||
| `addProviderHeader` | string | Optional | - | Which request header to add the provider name parsed from the model parameter |
|
||||
| `modelToHeader` | string | Optional | - | Which request header to directly add the model parameter to |
|
||||
| `enableOnPathSuffix` | array of string | Optional | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | Only effective for requests with these specific path suffixes, can be configured as "*" to match all paths |
|
||||
|
||||
## Runtime Properties
|
||||
|
||||
Plugin execution phase: Authentication phase
|
||||
Plugin execution priority: 900
|
||||
|
||||
## Effect Description
|
||||
|
||||
### Routing Based on Model Parameter
|
||||
|
||||
The following configuration is needed:
|
||||
|
||||
```yaml
|
||||
modelToHeader: x-higress-llm-model
|
||||
```
|
||||
|
||||
The plugin extracts the model parameter from the request and sets it to the x-higress-llm-model request header for subsequent routing. For example, the original LLM request body is:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-long",
|
||||
"frequency_penalty": 0,
|
||||
"max_tokens": 800,
|
||||
"stream": false,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "What is the GitHub address of the Higress project's main repository?"
|
||||
}],
|
||||
"presence_penalty": 0,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.95
|
||||
}
|
||||
```
|
||||
|
||||
After processing by this plugin, the following request header will be added (can be used for route matching):
|
||||
|
||||
x-higress-llm-model: qwen-long
|
||||
|
||||
### Extracting Provider Field from Model Parameter for Routing
|
||||
|
||||
> Note that this mode requires the client to specify the provider in the model parameter using the `/` delimiter
|
||||
|
||||
The following configuration is needed:
|
||||
|
||||
```yaml
|
||||
addProviderHeader: x-higress-llm-provider
|
||||
```
|
||||
|
||||
The plugin extracts the provider part (if any) from the model parameter in the request, sets it to the x-higress-llm-provider request header for subsequent routing, and rewrites the model parameter to only contain the model name part. For example, the original LLM request body is:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "dashscope/qwen-long",
|
||||
"frequency_penalty": 0,
|
||||
"max_tokens": 800,
|
||||
"stream": false,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "What is the GitHub address of the Higress project's main repository?"
|
||||
}],
|
||||
"presence_penalty": 0,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.95
|
||||
}
|
||||
```
|
||||
|
||||
After processing by this plugin, the following request header will be added (can be used for route matching):
|
||||
|
||||
x-higress-llm-provider: dashscope
|
||||
|
||||
The original LLM request body will be changed to:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-long",
|
||||
"frequency_penalty": 0,
|
||||
"max_tokens": 800,
|
||||
"stream": false,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "What is the GitHub address of the Higress project's main repository?"
|
||||
}],
|
||||
"presence_penalty": 0,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.95
|
||||
}
|
||||
24
plugins/wasm-go/extensions/model-router/go.mod
Normal file
24
plugins/wasm-go/extensions/model-router/go.mod
Normal file
@@ -0,0 +1,24 @@
|
||||
module model-router
|
||||
|
||||
go 1.24.1
|
||||
|
||||
toolchain go1.24.7
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
30
plugins/wasm-go/extensions/model-router/go.sum
Normal file
30
plugins/wasm-go/extensions/model-router/go.sum
Normal file
@@ -0,0 +1,30 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c h1:DdVPyaMHSYBqO5jwB9Wl3PqsBGIf4u29BHMI0uIVB1Y=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
|
||||
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
259
plugins/wasm-go/extensions/model-router/main.go
Normal file
259
plugins/wasm-go/extensions/model-router/main.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMaxBodyBytes = 100 * 1024 * 1024 // 100MB
|
||||
)
|
||||
|
||||
func main() {}
|
||||
|
||||
func init() {
|
||||
wrapper.SetCtx(
|
||||
"model-router",
|
||||
wrapper.ParseConfig(parseConfig),
|
||||
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBody(onHttpRequestBody),
|
||||
wrapper.WithRebuildAfterRequests[ModelRouterConfig](1000),
|
||||
wrapper.WithRebuildMaxMemBytes[ModelRouterConfig](200*1024*1024),
|
||||
)
|
||||
}
|
||||
|
||||
type ModelRouterConfig struct {
|
||||
modelKey string
|
||||
addProviderHeader string
|
||||
modelToHeader string
|
||||
enableOnPathSuffix []string
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *ModelRouterConfig) error {
|
||||
config.modelKey = json.Get("modelKey").String()
|
||||
if config.modelKey == "" {
|
||||
config.modelKey = "model"
|
||||
}
|
||||
config.addProviderHeader = json.Get("addProviderHeader").String()
|
||||
config.modelToHeader = json.Get("modelToHeader").String()
|
||||
|
||||
enableOnPathSuffix := json.Get("enableOnPathSuffix")
|
||||
if enableOnPathSuffix.Exists() && enableOnPathSuffix.IsArray() {
|
||||
for _, item := range enableOnPathSuffix.Array() {
|
||||
config.enableOnPathSuffix = append(config.enableOnPathSuffix, item.String())
|
||||
}
|
||||
} else {
|
||||
// Default suffixes if not provided
|
||||
config.enableOnPathSuffix = []string{
|
||||
"/completions",
|
||||
"/embeddings",
|
||||
"/images/generations",
|
||||
"/audio/speech",
|
||||
"/fine_tuning/jobs",
|
||||
"/moderations",
|
||||
"/image-synthesis",
|
||||
"/video-synthesis",
|
||||
"/rerank",
|
||||
"/messages",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ModelRouterConfig) types.Action {
|
||||
path, err := proxywasm.GetHttpRequestHeader(":path")
|
||||
if err != nil {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
// Remove query parameters for suffix check
|
||||
if idx := strings.Index(path, "?"); idx != -1 {
|
||||
path = path[:idx]
|
||||
}
|
||||
|
||||
enable := false
|
||||
for _, suffix := range config.enableOnPathSuffix {
|
||||
if suffix == "*" || strings.HasSuffix(path, suffix) {
|
||||
enable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !enable || !ctx.HasRequestBody() {
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
// Prepare for body processing
|
||||
proxywasm.RemoveHttpRequestHeader("content-length")
|
||||
// 100MB buffer limit
|
||||
ctx.SetRequestBodyBufferLimit(DefaultMaxBodyBytes)
|
||||
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []byte) types.Action {
|
||||
contentType, err := proxywasm.GetHttpRequestHeader("content-type")
|
||||
if err != nil {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
return handleJsonBody(ctx, config, body)
|
||||
} else if strings.Contains(contentType, "multipart/form-data") {
|
||||
return handleMultipartBody(ctx, config, body, contentType)
|
||||
}
|
||||
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func handleJsonBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []byte) types.Action {
|
||||
if !json.Valid(body) {
|
||||
log.Error("invalid json body")
|
||||
return types.ActionContinue
|
||||
}
|
||||
modelValue := gjson.GetBytes(body, config.modelKey).String()
|
||||
if modelValue == "" {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
if config.modelToHeader != "" {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(config.modelToHeader, modelValue)
|
||||
}
|
||||
|
||||
if config.addProviderHeader != "" {
|
||||
parts := strings.SplitN(modelValue, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
provider := parts[0]
|
||||
model := parts[1]
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(config.addProviderHeader, provider)
|
||||
|
||||
newBody, err := sjson.SetBytes(body, config.modelKey, model)
|
||||
if err != nil {
|
||||
log.Errorf("failed to update model in json body: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
_ = proxywasm.ReplaceHttpRequestBody(newBody)
|
||||
log.Debugf("model route to provider: %s, model: %s", provider, model)
|
||||
} else {
|
||||
log.Debugf("model route to provider not work, model: %s", modelValue)
|
||||
}
|
||||
}
|
||||
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func handleMultipartBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []byte, contentType string) types.Action {
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse content type: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
boundary, ok := params["boundary"]
|
||||
if !ok {
|
||||
log.Errorf("no boundary in content type")
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
reader := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
var newBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&newBody)
|
||||
writer.SetBoundary(boundary)
|
||||
|
||||
modified := false
|
||||
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Errorf("failed to read multipart part: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
// Read part content
|
||||
partContent, err := io.ReadAll(part)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read part content: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
formName := part.FormName()
|
||||
if formName == config.modelKey {
|
||||
modelValue := string(partContent)
|
||||
|
||||
if config.modelToHeader != "" {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(config.modelToHeader, modelValue)
|
||||
}
|
||||
|
||||
if config.addProviderHeader != "" {
|
||||
parts := strings.SplitN(modelValue, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
provider := parts[0]
|
||||
model := parts[1]
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(config.addProviderHeader, provider)
|
||||
|
||||
// Write modified part
|
||||
h := make(http.Header)
|
||||
for k, v := range part.Header {
|
||||
h[k] = v
|
||||
}
|
||||
|
||||
pw, err := writer.CreatePart(textproto.MIMEHeader(h))
|
||||
if err != nil {
|
||||
log.Errorf("failed to create part: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
_, err = pw.Write([]byte(model))
|
||||
if err != nil {
|
||||
log.Errorf("failed to write part content: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
modified = true
|
||||
log.Debugf("model route to provider: %s, model: %s", provider, model)
|
||||
continue
|
||||
} else {
|
||||
log.Debugf("model route to provider not work, model: %s", modelValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write original part
|
||||
h := make(http.Header)
|
||||
for k, v := range part.Header {
|
||||
h[k] = v
|
||||
}
|
||||
pw, err := writer.CreatePart(textproto.MIMEHeader(h))
|
||||
if err != nil {
|
||||
log.Errorf("failed to create part: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
_, err = pw.Write(partContent)
|
||||
if err != nil {
|
||||
log.Errorf("failed to write part content: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
|
||||
writer.Close()
|
||||
|
||||
if modified {
|
||||
_ = proxywasm.ReplaceHttpRequestBody(newBody.Bytes())
|
||||
}
|
||||
|
||||
return types.ActionContinue
|
||||
}
|
||||
288
plugins/wasm-go/extensions/model-router/main_test.go
Normal file
288
plugins/wasm-go/extensions/model-router/main_test.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// Basic configs for wasm test host
|
||||
var (
|
||||
basicConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"modelKey": "model",
|
||||
"addProviderHeader": "x-provider",
|
||||
"modelToHeader": "x-model",
|
||||
"enableOnPathSuffix": []string{
|
||||
"/v1/chat/completions",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
defaultSuffixConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"modelKey": "model",
|
||||
"addProviderHeader": "x-provider",
|
||||
"modelToHeader": "x-model",
|
||||
})
|
||||
return data
|
||||
}()
|
||||
)
|
||||
|
||||
func getHeader(headers [][2]string, key string) (string, bool) {
|
||||
for _, h := range headers {
|
||||
if strings.EqualFold(h[0], key) {
|
||||
return h[1], true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
t.Run("basic config with defaults", func(t *testing.T) {
|
||||
var cfg ModelRouterConfig
|
||||
err := parseConfig(gjson.ParseBytes(defaultSuffixConfig), &cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// default modelKey
|
||||
require.Equal(t, "model", cfg.modelKey)
|
||||
// headers
|
||||
require.Equal(t, "x-provider", cfg.addProviderHeader)
|
||||
require.Equal(t, "x-model", cfg.modelToHeader)
|
||||
// default enabled path suffixes should contain common openai paths
|
||||
require.Contains(t, cfg.enableOnPathSuffix, "/completions")
|
||||
require.Contains(t, cfg.enableOnPathSuffix, "/embeddings")
|
||||
})
|
||||
|
||||
t.Run("custom enableOnPathSuffix", func(t *testing.T) {
|
||||
jsonData := []byte(`{
|
||||
"modelKey": "my_model",
|
||||
"addProviderHeader": "x-prov",
|
||||
"modelToHeader": "x-mod",
|
||||
"enableOnPathSuffix": ["/foo", "/bar"]
|
||||
}`)
|
||||
var cfg ModelRouterConfig
|
||||
err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "my_model", cfg.modelKey)
|
||||
require.Equal(t, "x-prov", cfg.addProviderHeader)
|
||||
require.Equal(t, "x-mod", cfg.modelToHeader)
|
||||
require.Equal(t, []string{"/foo", "/bar"}, cfg.enableOnPathSuffix)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestHeaders(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("skip when path not matched", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
originalHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/other"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"content-length", "123"},
|
||||
}
|
||||
action := host.CallOnHttpRequestHeaders(originalHeaders)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
_, found := getHeader(newHeaders, "content-length")
|
||||
require.True(t, found, "content-length should be kept when path not enabled")
|
||||
})
|
||||
|
||||
t.Run("process when path and content-type match", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
originalHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"content-length", "123"},
|
||||
}
|
||||
action := host.CallOnHttpRequestHeaders(originalHeaders)
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
_, found := getHeader(newHeaders, "content-length")
|
||||
require.False(t, found, "content-length should be removed when buffering body")
|
||||
})
|
||||
|
||||
t.Run("do not process for unsupported content-type", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
originalHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "text/plain"},
|
||||
{"content-length", "123"},
|
||||
}
|
||||
action := host.CallOnHttpRequestHeaders(originalHeaders)
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
_, found := getHeader(newHeaders, "content-length")
|
||||
require.False(t, found, "content-length should not be removed for unsupported content-type")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestBody_JSON(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("set headers and rewrite model when provider/model format", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
origBody := []byte(`{
|
||||
"model": "openai/gpt-4o",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(origBody)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
// model should be rewritten to only the model part
|
||||
require.Equal(t, "gpt-4o", gjson.GetBytes(processed, "model").String())
|
||||
|
||||
headers := host.GetRequestHeaders()
|
||||
hv, found := getHeader(headers, "x-model")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "openai/gpt-4o", hv)
|
||||
pv, found := getHeader(headers, "x-provider")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "openai", pv)
|
||||
})
|
||||
|
||||
t.Run("no change when model not provided", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
origBody := []byte(`{
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`)
|
||||
action := host.CallOnHttpRequestBody(origBody)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processed := host.GetRequestBody()
|
||||
// body should remain nil or unchanged as plugin does nothing
|
||||
if processed != nil {
|
||||
require.JSONEq(t, string(origBody), string(processed))
|
||||
}
|
||||
_, found := getHeader(host.GetRequestHeaders(), "x-provider")
|
||||
require.False(t, found)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestBody_Multipart(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
|
||||
// model field
|
||||
modelWriter, err := writer.CreateFormField("model")
|
||||
require.NoError(t, err)
|
||||
_, err = modelWriter.Write([]byte("openai/gpt-4o"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// another field to ensure others are preserved
|
||||
fileWriter, err := writer.CreateFormField("prompt")
|
||||
require.NoError(t, err)
|
||||
_, err = fileWriter.Write([]byte("hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
contentType := "multipart/form-data; boundary=" + writer.Boundary()
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-type", contentType},
|
||||
})
|
||||
|
||||
action := host.CallOnHttpRequestBody(buf.Bytes())
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processed := host.GetRequestBody()
|
||||
require.NotNil(t, processed)
|
||||
|
||||
// Parse multipart body again to verify fields
|
||||
reader := multipart.NewReader(bytes.NewReader(processed), writer.Boundary())
|
||||
|
||||
foundModel := false
|
||||
foundPrompt := false
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
name := part.FormName()
|
||||
data, err := io.ReadAll(part)
|
||||
require.NoError(t, err)
|
||||
|
||||
switch name {
|
||||
case "model":
|
||||
foundModel = true
|
||||
require.Equal(t, "gpt-4o", string(data))
|
||||
case "prompt":
|
||||
foundPrompt = true
|
||||
require.Equal(t, "hello", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, foundModel)
|
||||
require.True(t, foundPrompt)
|
||||
|
||||
headers := host.GetRequestHeaders()
|
||||
hv, found := getHeader(headers, "x-model")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "openai/gpt-4o", hv)
|
||||
pv, found := getHeader(headers, "x-provider")
|
||||
require.True(t, found)
|
||||
require.Equal(t, "openai", pv)
|
||||
})
|
||||
}
|
||||
206
plugins/wasm-go/extensions/traffic-editor/README.md
Normal file
206
plugins/wasm-go/extensions/traffic-editor/README.md
Normal file
@@ -0,0 +1,206 @@
|
||||
---
|
||||
title: 请求响应编辑
|
||||
keywords: [ higress,request,response,edit ]
|
||||
description: 请求响应编辑插件使用说明
|
||||
---
|
||||
|
||||
## 功能说明
|
||||
|
||||
`traffic-editor` 插件可以对请求/响应头进行修改,支持的修改操作类型包括删除、重命名、更新、添加、追加、映射、去重。
|
||||
|
||||
## 运行属性
|
||||
|
||||
插件执行阶段:`默认阶段`
|
||||
插件执行优先级:`100`
|
||||
|
||||
## 配置字段
|
||||
|
||||
| 字段名 | 类型 | 必填 | 说明 |
|
||||
|--------------------|-----------------------------------------|----|---------------------|
|
||||
| defaultConfig | object (CommandSet) | 否 | 默认命令集配置,无条件执行的编辑操作 |
|
||||
| conditionalConfigs | array of object (ConditionalCommandSet) | 否 | 条件命令集配置,按条件执行不同编辑操作 |
|
||||
|
||||
### CommandSet 结构
|
||||
|
||||
| 字段名 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|----------------|---------------------------|----|-------|------------|
|
||||
| disableReroute | bool | 否 | false | 是否禁用自动路由重选 |
|
||||
| commands | array of object (Command) | 是 | - | 编辑命令列表 |
|
||||
|
||||
### ConditionalCommandSet 结构
|
||||
|
||||
| 字段名 | 类型 | 必填 | 说明 |
|
||||
|------------|---------------------------|----|---------------------|
|
||||
| conditions | array | 是 | 条件列表,见下表 |
|
||||
| commands | array of object (Command) | 是 | 命令列表,结构同 CommandSet |
|
||||
|
||||
#### Command 结构
|
||||
|
||||
| 字段名 | 类型 | 必填 | 说明 |
|
||||
|------|--------|----|-------------------|
|
||||
| type | string | 是 | 命令类型。其他配置字段由类型决定。 |
|
||||
|
||||
##### set 命令
|
||||
|
||||
功能为将某个字段设置为指定值。`type` 字段值为 `set`。
|
||||
|
||||
其它字段如下:
|
||||
|
||||
| 字段名 | 类型 | 必填 | 说明 |
|
||||
|--------|--------------|----|--------|
|
||||
| target | object (Ref) | 是 | 目标字段信息 |
|
||||
| value | string | 是 | 要设置的值 |
|
||||
|
||||
##### concat 命令
|
||||
|
||||
功能为将多个值拼接后赋值给目标字段。`type` 字段值为 `concat`。
|
||||
|
||||
其它字段如下:
|
||||
|
||||
| 字段名 | 类型 | 必填 | 说明 |
|
||||
|--------|-----------------------|----|--------------------------|
|
||||
| target | object (Ref) | 是 | 目标字段信息 |
|
||||
| values | array of (string/Ref) | 是 | 要拼接的值列表,可以是字符串或字段引用(Ref) |
|
||||
|
||||
##### copy 命令
|
||||
|
||||
功能为将源字段的值复制到目标字段。`type` 字段值为 `copy`。
|
||||
|
||||
其它字段如下:
|
||||
|
||||
| 字段名 | 类型 | 必填 | 说明 |
|
||||
|--------|--------------|----|--------|
|
||||
| source | object (Ref) | 是 | 源字段信息 |
|
||||
| target | object (Ref) | 是 | 目标字段信息 |
|
||||
|
||||
##### delete 命令
|
||||
|
||||
功能为删除指定字段。`type` 字段值为 `delete`。
|
||||
|
||||
其它字段如下:
|
||||
|
||||
| 字段名 | 类型 | 必填 | 说明 |
|
||||
|--------|--------------|----|----------|
|
||||
| target | object (Ref) | 是 | 要删除的字段信息 |
|
||||
|
||||
##### rename 命令
|
||||
|
||||
功能为将字段重命名。`type` 字段值为 `rename`。
|
||||
|
||||
其它字段如下:
|
||||
|
||||
| 字段名 | 类型 | 必填 | 说明 |
|
||||
|--------|--------------|----|-------|
|
||||
| source | object (Ref) | 是 | 原字段信息 |
|
||||
| target | object (Ref) | 是 | 新字段信息 |
|
||||
|
||||
#### Condition 结构
|
||||
|
||||
| 字段名 | 类型 | 必填 | 说明 |
|
||||
|------|--------|----|-------------------|
|
||||
| type | string | 是 | 条件类型。其他配置字段由类型决定。 |
|
||||
|
||||
##### equals 条件
|
||||
|
||||
判断某字段值是否等于指定值。`type` 字段值为 `equals`。
|
||||
|
||||
| 字段名 | 类型 | 必填 | 说明 |
|
||||
|--------|--------------|----|---------|
|
||||
| value1 | object (Ref) | 是 | 参与比较的字段 |
|
||||
| value2 | string | 是 | 目标值 |
|
||||
|
||||
##### prefix 条件
|
||||
|
||||
判断某字段值是否以指定前缀开头。`type` 字段值为 `prefix`。
|
||||
|
||||
| 字段名 | 类型 | 必填 | 说明 |
|
||||
|--------|--------------|----|---------|
|
||||
| value | object (Ref) | 是 | 参与比较的字段 |
|
||||
| prefix | string | 是 | 前缀字符串 |
|
||||
|
||||
##### suffix 条件
|
||||
|
||||
判断某字段值是否以指定后缀结尾。`type` 字段值为 `suffix`。
|
||||
|
||||
| 字段名 | 类型 | 必填 | 说明 |
|
||||
|--------|--------------|----|---------|
|
||||
| value | object (Ref) | 是 | 参与比较的字段 |
|
||||
| suffix | string | 是 | 后缀字符串 |
|
||||
|
||||
##### contains 条件
|
||||
|
||||
判断某字段值是否包含指定子串。`type` 字段值为 `contains`。
|
||||
|
||||
| 字段名 | 类型 | 必填 | 说明 |
|
||||
|--------|--------------|----|---------|
|
||||
| value | object (Ref) | 是 | 参与比较的字段 |
|
||||
| substr | string | 是 | 子串 |
|
||||
|
||||
##### regex 条件
|
||||
|
||||
判断某字段值是否匹配指定正则表达式。`type` 字段值为 `regex`。
|
||||
|
||||
| 字段名 | 类型 | 必填 | 说明 |
|
||||
|---------|--------------|----|---------|
|
||||
| value | object (Ref) | 是 | 参与比较的字段 |
|
||||
| pattern | string | 是 | 正则表达式 |
|
||||
|
||||
#### Ref 结构
|
||||
|
||||
用于标识一个请求或响应中的字段。
|
||||
|
||||
| 字段名 | 类型 | 必填 | 说明 |
|
||||
|------|--------|----|--------------------------------------------------------------|
|
||||
| type | string | 是 | 字段类型。可选值有:`request_header`、`request_query`、`response_header` |
|
||||
| name | string | 是 | 字段名称 |
|
||||
|
||||
### 示例配置
|
||||
|
||||
```json
|
||||
{
|
||||
"defaultConfig": {
|
||||
"disableReroute": false,
|
||||
"commands": [
|
||||
{
|
||||
"type": "set",
|
||||
"target": {
|
||||
"type": "request_header",
|
||||
"name": "x-user"
|
||||
},
|
||||
"value": "admin"
|
||||
},
|
||||
{
|
||||
"type": "delete",
|
||||
"target": {
|
||||
"type": "request_header",
|
||||
"name": "x-dummy"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"conditionalConfigs": [
|
||||
{
|
||||
"conditions": [
|
||||
{
|
||||
"type": "equals",
|
||||
"value1": {
|
||||
"type": "request_query",
|
||||
"name": "id"
|
||||
},
|
||||
"value2": "1"
|
||||
}
|
||||
],
|
||||
"commands": [
|
||||
{
|
||||
"type": "set",
|
||||
"target": {
|
||||
"type": "response_header",
|
||||
"name": "x-id"
|
||||
},
|
||||
"value": "1"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
212
plugins/wasm-go/extensions/traffic-editor/README_EN.md
Normal file
212
plugins/wasm-go/extensions/traffic-editor/README_EN.md
Normal file
@@ -0,0 +1,212 @@
|
||||
---
|
||||
title: Request/Response Editor
|
||||
keywords: [higress,request,response,edit]
|
||||
description: Usage guide for the request/response editor plugin
|
||||
---
|
||||
|
||||
## Features
|
||||
|
||||
The `traffic-editor` plugin allows you to modify request/response headers. Supported operations include delete, rename, update, add, append, map, and deduplicate.
|
||||
|
||||
## Runtime Properties
|
||||
|
||||
Plugin execution phase: `UNSPECIFIED`
|
||||
Plugin execution priority: `100`
|
||||
|
||||
## Configuration Fields
|
||||
|
||||
| Field Name | Type | Required | Description |
|
||||
|--------------------|-------------------------------------------|----------|-----------------------------------|
|
||||
| defaultConfig | object (CommandSet) | No | Default command set, executed unconditionally |
|
||||
| conditionalConfigs | array of object (ConditionalCommandSet) | No | Conditional command sets, executed based on conditions |
|
||||
|
||||
### CommandSet Structure
|
||||
|
||||
| Field Name | Type | Required | Default | Description |
|
||||
|----------------|----------------------------|----------|---------|----------------------------|
|
||||
| disableReroute | bool | No | false | Whether to disable automatic route selection |
|
||||
| commands | array of object (Command) | Yes | - | List of edit commands |
|
||||
|
||||
### ConditionalCommandSet Structure
|
||||
|
||||
| Field Name | Type | Required | Description |
|
||||
|-------------|----------------------------|----------|-----------------------------------|
|
||||
| conditions | array | Yes | List of conditions, see below |
|
||||
| commands | array of object (Command) | Yes | List of commands, same as CommandSet |
|
||||
|
||||
#### Command Structure
|
||||
|
||||
| Field Name | Type | Required | Description |
|
||||
|------------|--------|----------|------------------------------|
|
||||
| type | string | Yes | Command type, other fields depend on type |
|
||||
|
||||
##### set Command
|
||||
|
||||
Sets a field to a specified value. `type` field value is `set`.
|
||||
|
||||
Other fields:
|
||||
|
||||
| Field Name | Type | Required | Description |
|
||||
|------------|---------------|----------|------------------|
|
||||
| target | object (Ref) | Yes | Target field info|
|
||||
| value | string | Yes | Value to set |
|
||||
|
||||
##### concat Command
|
||||
|
||||
Concatenates multiple values and assigns to the target field. `type` field value is `concat`.
|
||||
|
||||
Other fields:
|
||||
|
||||
| Field Name | Type | Required | Description |
|
||||
|------------|-----------------------|----------|----------------------------------------------|
|
||||
| target | object (Ref) | Yes | Target field info |
|
||||
| values | array of (string/Ref) | Yes | Values to concatenate, can be string or Ref |
|
||||
|
||||
##### copy Command
|
||||
|
||||
Copies the value from the source field to the target field. `type` field value is `copy`.
|
||||
|
||||
Other fields:
|
||||
|
||||
| Field Name | Type | Required | Description |
|
||||
|------------|---------------|----------|------------------|
|
||||
| source | object (Ref) | Yes | Source field info|
|
||||
| target | object (Ref) | Yes | Target field info|
|
||||
|
||||
##### delete Command
|
||||
|
||||
Deletes the specified field. `type` field value is `delete`.
|
||||
|
||||
Other fields:
|
||||
|
||||
| Field Name | Type | Required | Description |
|
||||
|------------|---------------|----------|------------------|
|
||||
| target | object (Ref) | Yes | Field to delete |
|
||||
|
||||
##### rename Command
|
||||
|
||||
Renames a field. `type` field value is `rename`.
|
||||
|
||||
Other fields:
|
||||
|
||||
| Field Name | Type | Required | Description |
|
||||
|------------|---------------|----------|------------------|
|
||||
| source | object (Ref) | Yes | Original field info|
|
||||
| target | object (Ref) | Yes | New field info |
|
||||
|
||||
#### Condition Structure
|
||||
|
||||
| Field Name | Type | Required | Description |
|
||||
|------------|--------|----------|------------------------------|
|
||||
| type | string | Yes | Condition type, other fields depend on type |
|
||||
|
||||
##### equals Condition
|
||||
|
||||
Checks if a field value equals the specified value. `type` field value is `equals`.
|
||||
|
||||
| Field Name | Type | Required | Description |
|
||||
|------------|---------------|----------|------------------|
|
||||
| value1 | object (Ref) | Yes | Field to compare |
|
||||
| value2 | string | Yes | Target value |
|
||||
|
||||
##### prefix Condition
|
||||
|
||||
Checks if a field value starts with the specified prefix. `type` field value is `prefix`.
|
||||
|
||||
| Field Name | Type | Required | Description |
|
||||
|------------|---------------|----------|------------------|
|
||||
| value | object (Ref) | Yes | Field to compare |
|
||||
| prefix | string | Yes | Prefix string |
|
||||
|
||||
##### suffix Condition
|
||||
|
||||
Checks if a field value ends with the specified suffix. `type` field value is `suffix`.
|
||||
|
||||
| Field Name | Type | Required | Description |
|
||||
|------------|---------------|----------|------------------|
|
||||
| value | object (Ref) | Yes | Field to compare |
|
||||
| suffix | string | Yes | Suffix string |
|
||||
|
||||
##### contains Condition
|
||||
|
||||
Checks if a field value contains the specified substring. `type` field value is `contains`.
|
||||
|
||||
| Field Name | Type | Required | Description |
|
||||
|------------|---------------|----------|------------------|
|
||||
| value | object (Ref) | Yes | Field to compare |
|
||||
| substr | string | Yes | Substring |
|
||||
|
||||
##### regex Condition
|
||||
|
||||
Checks if a field value matches the specified regular expression. `type` field value is `regex`.
|
||||
|
||||
| Field Name | Type | Required | Description |
|
||||
|------------|---------------|----------|------------------|
|
||||
| value | object (Ref) | Yes | Field to compare |
|
||||
| pattern | string | Yes | Regular expression|
|
||||
|
||||
#### Ref Structure
|
||||
|
||||
Used to identify a field in the request or response.
|
||||
|
||||
| Field Name | Type | Required | Description |
|
||||
|------------|--------|----------|------------------------------------------------------------------|
|
||||
| type | string | Yes | Field type: `request_header`, `request_query`, `response_header` |
|
||||
| name | string | Yes | Field name |
|
||||
|
||||
### Example Configuration
|
||||
|
||||
```json
|
||||
{
|
||||
"defaultConfig": {
|
||||
"disableReroute": false,
|
||||
"commands": [
|
||||
{ "type": "set", "target": { "type": "request_header", "name": "x-user" }, "value": "admin" },
|
||||
{ "type": "delete", "target": { "type": "request_header", "name": "x-remove" } }
|
||||
]
|
||||
},
|
||||
"conditionalConfigs": [
|
||||
{
|
||||
"conditions": [
|
||||
{ "type": "equals", "value1": { "type": "request_query", "name": "role" }, "value2": "admin" }
|
||||
],
|
||||
"commands": [
|
||||
{ "type": "set", "target": { "type": "response_header", "name": "x-status" }, "value": "is-admin" }
|
||||
]
|
||||
},
|
||||
{
|
||||
"conditions": [
|
||||
{ "type": "prefix", "value": { "type": "request_header", "name": "x-path" }, "prefix": "/api/" }
|
||||
],
|
||||
"commands": [
|
||||
{ "type": "rename", "source": { "type": "request_header", "name": "x-old" }, "target": { "type": "request_header", "name": "x-new" } }
|
||||
]
|
||||
},
|
||||
{
|
||||
"conditions": [
|
||||
{ "type": "suffix", "value": { "type": "request_header", "name": "x-path" }, "suffix": ".json" }
|
||||
],
|
||||
"commands": [
|
||||
{ "type": "copy", "source": { "type": "request_query", "name": "id" }, "target": { "type": "response_header", "name": "x-id" } }
|
||||
]
|
||||
},
|
||||
{
|
||||
"conditions": [
|
||||
{ "type": "contains", "value": { "type": "request_header", "name": "x-info" }, "substr": "test" }
|
||||
],
|
||||
"commands": [
|
||||
{ "type": "concat", "target": { "type": "response_header", "name": "x-token" }, "values": ["prefix-", { "type": "request_query", "name": "token" }] }
|
||||
]
|
||||
},
|
||||
{
|
||||
"conditions": [
|
||||
{ "type": "regex", "value": { "type": "request_query", "name": "email" }, "pattern": "^.+@example\\.com$" }
|
||||
],
|
||||
"commands": [
|
||||
{ "type": "delete", "target": { "type": "response_header", "name": "x-temp" } }
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
1
plugins/wasm-go/extensions/traffic-editor/VERSION
Normal file
1
plugins/wasm-go/extensions/traffic-editor/VERSION
Normal file
@@ -0,0 +1 @@
|
||||
1.0.0-alpha
|
||||
37
plugins/wasm-go/extensions/traffic-editor/config.go
Normal file
37
plugins/wasm-go/extensions/traffic-editor/config.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/traffic-editor/pkg"
|
||||
)
|
||||
|
||||
type PluginConfig struct {
|
||||
DefaultConfig *pkg.CommandSet `json:"defaultConfig,omitempty"`
|
||||
ConditionalConfigs []*pkg.ConditionalCommandSet `json:"conditionalConfigs,omitempty"`
|
||||
}
|
||||
|
||||
func (c *PluginConfig) FromJson(json gjson.Result) error {
|
||||
c.DefaultConfig = nil
|
||||
defaultConfigJson := json.Get("defaultConfig")
|
||||
if defaultConfigJson.Exists() && defaultConfigJson.IsObject() {
|
||||
c.DefaultConfig = &pkg.CommandSet{}
|
||||
if err := c.DefaultConfig.FromJson(defaultConfigJson); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.ConditionalConfigs = nil
|
||||
conditionalConfigsJson := json.Get("conditionalConfigs")
|
||||
if conditionalConfigsJson.Exists() && conditionalConfigsJson.IsArray() {
|
||||
for _, item := range conditionalConfigsJson.Array() {
|
||||
config := &pkg.ConditionalCommandSet{}
|
||||
if err := config.FromJson(item); err != nil {
|
||||
return err
|
||||
}
|
||||
c.ConditionalConfigs = append(c.ConditionalConfigs, config)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
version: '3.7'
|
||||
services:
|
||||
envoy:
|
||||
image: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/gateway:v2.1.6
|
||||
entrypoint: /usr/local/bin/envoy
|
||||
# 注意这里对wasm开启了debug级别日志,正式部署时则默认info级别
|
||||
command: -c /etc/envoy/envoy.yaml --component-log-level wasm:debug
|
||||
#depends_on:
|
||||
# - httpbin
|
||||
networks:
|
||||
- wasmtest
|
||||
ports:
|
||||
- "10000:10000"
|
||||
volumes:
|
||||
- ./envoy.yaml:/etc/envoy/envoy.yaml
|
||||
- ./plugin.wasm:/etc/envoy/main.wasm
|
||||
|
||||
httpbin:
|
||||
image: kong/httpbin:latest
|
||||
networks:
|
||||
- wasmtest
|
||||
ports:
|
||||
- "12345:80"
|
||||
|
||||
networks:
|
||||
wasmtest: {}
|
||||
24
plugins/wasm-go/extensions/traffic-editor/go.mod
Normal file
24
plugins/wasm-go/extensions/traffic-editor/go.mod
Normal file
@@ -0,0 +1,24 @@
|
||||
module github.com/alibaba/higress/plugins/wasm-go/extensions/traffic-editor
|
||||
|
||||
go 1.24.1
|
||||
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/higress-group/wasm-go v1.0.2
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/match v1.2.0 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
31
plugins/wasm-go/extensions/traffic-editor/go.sum
Normal file
31
plugins/wasm-go/extensions/traffic-editor/go.sum
Normal file
@@ -0,0 +1,31 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2 h1:8fQqR+wHts8tP+v7GYxmsCNyW5nAjn9wPYV0/+Seqzg=
|
||||
github.com/higress-group/wasm-go v1.0.2/go.mod h1:882/J8ccU4i+LeyFKmeicbHWAYLj8y7YZr60zk0OOCI=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
|
||||
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
|
||||
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
22
plugins/wasm-go/extensions/traffic-editor/http.go
Normal file
22
plugins/wasm-go/extensions/traffic-editor/http.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package main
|
||||
|
||||
import "strings"
|
||||
|
||||
func headerSlice2Map(headerSlice [][2]string) map[string][]string {
|
||||
headerMap := make(map[string][]string)
|
||||
for _, header := range headerSlice {
|
||||
k, v := strings.ToLower(header[0]), header[1]
|
||||
headerMap[k] = append(headerMap[k], v)
|
||||
}
|
||||
return headerMap
|
||||
}
|
||||
|
||||
func headerMap2Slice(headerMap map[string][]string) [][2]string {
|
||||
headerSlice := make([][2]string, 0, len(headerMap))
|
||||
for k, vs := range headerMap {
|
||||
for _, v := range vs {
|
||||
headerSlice = append(headerSlice, [2]string{k, v})
|
||||
}
|
||||
}
|
||||
return headerSlice
|
||||
}
|
||||
177
plugins/wasm-go/extensions/traffic-editor/main.go
Normal file
177
plugins/wasm-go/extensions/traffic-editor/main.go
Normal file
@@ -0,0 +1,177 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/traffic-editor/pkg"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
const (
|
||||
ctxKeyEditorContext = "editorContext"
|
||||
)
|
||||
|
||||
func main() {}
|
||||
|
||||
func init() {
|
||||
wrapper.SetCtx(
|
||||
"traffic-editor",
|
||||
wrapper.ParseConfig(parseConfig),
|
||||
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
|
||||
wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
|
||||
)
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *PluginConfig) (err error) {
|
||||
if err := config.FromJson(json); err != nil {
|
||||
return fmt.Errorf("failed to parse plugin config: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig) types.Action {
|
||||
log.Debugf("onHttpRequestHeaders called with config")
|
||||
|
||||
editorContext := pkg.NewEditorContext()
|
||||
if headers, err := proxywasm.GetHttpRequestHeaders(); err == nil {
|
||||
editorContext.SetRequestHeaders(headerSlice2Map(headers))
|
||||
} else {
|
||||
log.Errorf("failed to get request headers: %v", err)
|
||||
}
|
||||
saveEditorContext(ctx, editorContext)
|
||||
|
||||
effectiveCommandSet := findEffectiveCommandSet(editorContext, &config)
|
||||
if effectiveCommandSet == nil {
|
||||
log.Debugf("no effective command set found for request %s", ctx.Path())
|
||||
return types.ActionContinue
|
||||
}
|
||||
if len(effectiveCommandSet.Commands) == 0 {
|
||||
log.Debugf("the effective command set found for request %s is empty", ctx.Path())
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
log.Debugf("an effective command set found for request %s with %d commands", ctx.Path(), len(effectiveCommandSet.Commands))
|
||||
editorContext.SetEffectiveCommandSet(effectiveCommandSet)
|
||||
editorContext.SetCommandExecutors(effectiveCommandSet.CreatExecutors())
|
||||
|
||||
// Make sure the editor context is clean before executing any command.
|
||||
editorContext.ResetDirtyFlags()
|
||||
|
||||
if effectiveCommandSet.DisableReroute {
|
||||
ctx.DisableReroute()
|
||||
}
|
||||
|
||||
executeCommands(editorContext, pkg.StageRequestHeaders)
|
||||
|
||||
if err := saveRequestHeaderChanges(editorContext); err != nil {
|
||||
log.Errorf("failed to save request header changes: %v", err)
|
||||
}
|
||||
|
||||
// Make sure the editor context is clean before continue.
|
||||
editorContext.ResetDirtyFlags()
|
||||
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig) types.Action {
|
||||
log.Debugf("onHttpResponseHeaders called with config")
|
||||
|
||||
editorContext := loadEditorContext(ctx)
|
||||
if editorContext.GetEffectiveCommandSet() == nil {
|
||||
log.Debugf("no effective command set found for request %s", ctx.Path())
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
if headers, err := proxywasm.GetHttpResponseHeaders(); err == nil {
|
||||
editorContext.SetResponseHeaders(headerSlice2Map(headers))
|
||||
} else {
|
||||
log.Errorf("failed to get response headers: %v", err)
|
||||
}
|
||||
|
||||
// Make sure the editor context is clean before executing any command.
|
||||
editorContext.ResetDirtyFlags()
|
||||
|
||||
executeCommands(editorContext, pkg.StageResponseHeaders)
|
||||
if err := saveResponseHeaderChanges(editorContext); err != nil {
|
||||
log.Errorf("failed to save response header changes: %v", err)
|
||||
}
|
||||
|
||||
// Make sure the editor context is clean before continue.
|
||||
editorContext.ResetDirtyFlags()
|
||||
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func findEffectiveCommandSet(editorContext pkg.EditorContext, config *PluginConfig) *pkg.CommandSet {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
if len(config.ConditionalConfigs) != 0 {
|
||||
for i, conditionalConfig := range config.ConditionalConfigs {
|
||||
log.Debugf("Evaluating conditional config %d: %+v", i, conditionalConfig)
|
||||
if conditionalConfig.Matches(editorContext) {
|
||||
log.Debugf("Use the conditional command set %d", i)
|
||||
return &conditionalConfig.CommandSet
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Debugf("Use the default command set")
|
||||
return config.DefaultConfig
|
||||
}
|
||||
|
||||
func executeCommands(editorContext pkg.EditorContext, stage pkg.Stage) {
|
||||
for _, executor := range editorContext.GetCommandExecutors() {
|
||||
if err := executor.Run(editorContext, stage); err != nil {
|
||||
log.Errorf("failed to execute a %s command in stage %s: %v", executor.GetCommand().GetType(), pkg.Stage2String[stage], err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func saveRequestHeaderChanges(editorContext pkg.EditorContext) error {
|
||||
if !editorContext.IsRequestHeadersDirty() {
|
||||
log.Debugf("no request header change to save")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("saving request header changes: %v", editorContext.GetRequestHeaders())
|
||||
headerSlice := headerMap2Slice(editorContext.GetRequestHeaders())
|
||||
return proxywasm.ReplaceHttpRequestHeaders(headerSlice)
|
||||
}
|
||||
|
||||
func saveResponseHeaderChanges(editorContext pkg.EditorContext) error {
|
||||
if !editorContext.IsResponseHeadersDirty() {
|
||||
log.Debugf("no response header change to save")
|
||||
return nil
|
||||
}
|
||||
log.Debugf("saving response header changes: %v", editorContext.GetResponseHeaders())
|
||||
headerSlice := headerMap2Slice(editorContext.GetResponseHeaders())
|
||||
return proxywasm.ReplaceHttpResponseHeaders(headerSlice)
|
||||
}
|
||||
|
||||
func loadEditorContext(ctx wrapper.HttpContext) pkg.EditorContext {
|
||||
editorContext, _ := ctx.GetContext(ctxKeyEditorContext).(pkg.EditorContext)
|
||||
return editorContext
|
||||
}
|
||||
|
||||
func saveEditorContext(ctx wrapper.HttpContext, editorContext pkg.EditorContext) {
|
||||
ctx.SetContext(ctxKeyEditorContext, editorContext)
|
||||
}
|
||||
306
plugins/wasm-go/extensions/traffic-editor/main_test.go
Normal file
306
plugins/wasm-go/extensions/traffic-editor/main_test.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSample(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("default config only", func(t *testing.T) {
|
||||
host, status := test.NewTestHost([]byte(`
|
||||
{
|
||||
"defaultConfig": {
|
||||
"commands": [
|
||||
{
|
||||
"type": "set",
|
||||
"target": {
|
||||
"type": "request_header",
|
||||
"name": "x-test"
|
||||
},
|
||||
"value": "123456"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
`))
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/get"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
expectedNewHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/get"},
|
||||
{":method", "POST"},
|
||||
{"x-test", "123456"},
|
||||
{"Content-Type", "application/json"},
|
||||
}
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
require.True(t, compareHeaders(expectedNewHeaders, newHeaders), "expected headers: %v, got: %v", expectedNewHeaders, newHeaders)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetMultipleRequestHeaders(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
host, status := test.NewTestHost([]byte(`{
|
||||
"defaultConfig": {
|
||||
"commands": [
|
||||
{"type": "set", "target": {"type": "request_header", "name": "x-a"}, "value": "aaa"},
|
||||
{"type": "set", "target": {"type": "request_header", "name": "x-b"}, "value": "bbb"}
|
||||
]
|
||||
}
|
||||
}`))
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
originalHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/get"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-c", "ccc"},
|
||||
}
|
||||
action := host.CallOnHttpRequestHeaders(originalHeaders)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
expectedHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/get"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-a", "aaa"},
|
||||
{"x-b", "bbb"},
|
||||
{"x-c", "ccc"},
|
||||
}
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
require.True(t, compareHeaders(expectedHeaders, newHeaders))
|
||||
})
|
||||
}
|
||||
|
||||
func TestConditionalConfigMatch(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
host, status := test.NewTestHost([]byte(`{
|
||||
"defaultConfig": {
|
||||
"commands": [
|
||||
{"type": "set", "target": {"type": "request_header", "name": "x-def"}, "value": "default"}
|
||||
]
|
||||
},
|
||||
"conditionalConfigs": [
|
||||
{
|
||||
"conditions": [
|
||||
{"type": "equals", "value1": {"type": "request_header", "name": "x-cond"}, "value2": "match"}
|
||||
],
|
||||
"commands": [
|
||||
{"type": "set", "target": {"type": "request_header", "name": "x-special"}, "value": "special"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`))
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
originalHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/data"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-cond", "match"},
|
||||
}
|
||||
action := host.CallOnHttpRequestHeaders(originalHeaders)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
expectedHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/data"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-cond", "match"},
|
||||
{"x-special", "special"},
|
||||
}
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
require.True(t, compareHeaders(expectedHeaders, newHeaders))
|
||||
})
|
||||
}
|
||||
|
||||
func TestConditionalConfigNoMatch(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
host, status := test.NewTestHost([]byte(`{
|
||||
"defaultConfig": {
|
||||
"commands": [
|
||||
{"type": "set", "target": {"type": "request_header", "name": "x-def"}, "value": "default"}
|
||||
]
|
||||
},
|
||||
"conditionalConfigs": [
|
||||
{
|
||||
"conditions": [
|
||||
{"type": "equals", "value1": {"type": "request_header", "name": "x-cond"}, "value2": "match"}
|
||||
],
|
||||
"commands": [
|
||||
{"type": "set", "target": {"type": "request_header", "name": "x-special"}, "value": "special"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`))
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
originalHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/get"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-cond", "notmatch"},
|
||||
}
|
||||
action := host.CallOnHttpRequestHeaders(originalHeaders)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
expectedHeaders := [][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/get"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-cond", "notmatch"},
|
||||
{"x-def", "default"},
|
||||
}
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
require.True(t, compareHeaders(expectedHeaders, newHeaders))
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetResponseHeader(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
host, status := test.NewTestHost([]byte(`{
|
||||
"defaultConfig": {
|
||||
"commands": [
|
||||
{"type": "set", "target": {"type": "response_header", "name": "x-res"}, "value": "respval"}
|
||||
]
|
||||
}
|
||||
}`))
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
action := host.CallOnHttpResponseHeaders([][2]string{{"x-origin", "originval"}})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
newHeaders := host.GetResponseHeaders()
|
||||
require.True(t, compareHeaders([][2]string{{"x-origin", "originval"}, {"x-res", "respval"}}, newHeaders))
|
||||
})
|
||||
}
|
||||
|
||||
func TestPathQueryParseAndHeaderChange(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
host, status := test.NewTestHost([]byte(`{
|
||||
"defaultConfig": {
|
||||
"commands": [
|
||||
{"type": "set", "target": {"type": "request_query", "name": "foo"}, "value": "bar"}
|
||||
]
|
||||
}
|
||||
}`))
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{":path", "/get?foo=old&baz=1"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
newHeaders := host.GetRequestHeaders()
|
||||
found := false
|
||||
for _, h := range newHeaders {
|
||||
if h[0] == ":path" && strings.Contains(h[1], "foo=bar") {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
require.True(t, found, "path header should be updated with foo=bar")
|
||||
})
|
||||
}
|
||||
|
||||
func TestConditionSetMultiStage(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
host, status := test.NewTestHost([]byte(`{
|
||||
"conditionalConfigs": [
|
||||
{
|
||||
"conditions": [
|
||||
{"type": "equals", "value1": {"type": "request_header", "name": "x-a"}, "value2": "aaa"}
|
||||
],
|
||||
"commands": [
|
||||
{"type": "set", "target": {"type": "response_header", "name": "x-b"}, "value": "bbb"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`))
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
actionReq := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{":path", "/get?foo=old&baz=1"},
|
||||
{"x-a", "aaa"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, actionReq)
|
||||
actionResp := host.CallOnHttpResponseHeaders([][2]string{{"content-type", "application/json"}})
|
||||
require.Equal(t, types.ActionContinue, actionResp)
|
||||
newHeaders := host.GetResponseHeaders()
|
||||
require.True(t, compareHeaders([][2]string{{"x-b", "bbb"}, {"content-type", "application/json"}}, newHeaders))
|
||||
})
|
||||
}
|
||||
|
||||
func TestConditionSetMultiStage2(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
host, status := test.NewTestHost([]byte(`{
|
||||
"conditionalConfigs": [
|
||||
{
|
||||
"conditions": [
|
||||
{"type": "equals", "value1": {"type": "request_header", "name": "x-a"}, "value2": "aaa"}
|
||||
],
|
||||
"commands": [
|
||||
{"type": "copy", "source": {"type": "request_header", "name": "x-b"}, "target": {"type": "response_header", "name": "x-c"}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`))
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
actionReq := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{":path", "/get?foo=old&baz=1"},
|
||||
{"x-a", "aaa"},
|
||||
{"x-b", "bbb"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, actionReq)
|
||||
actionResp := host.CallOnHttpResponseHeaders([][2]string{{"content-type", "application/json"}})
|
||||
require.Equal(t, types.ActionContinue, actionResp)
|
||||
newHeaders := host.GetResponseHeaders()
|
||||
require.True(t, compareHeaders([][2]string{{"x-c", "bbb"}, {"content-type", "application/json"}}, newHeaders))
|
||||
})
|
||||
}
|
||||
func compareHeaders(headers1, headers2 [][2]string) bool {
|
||||
if len(headers1) != len(headers2) {
|
||||
return false
|
||||
}
|
||||
m1 := make(map[string]string, len(headers1))
|
||||
m2 := make(map[string]string, len(headers2))
|
||||
for _, h := range headers1 {
|
||||
m1[strings.ToLower(h[0])] = h[1]
|
||||
}
|
||||
for _, h := range headers2 {
|
||||
m2[strings.ToLower(h[0])] = h[1]
|
||||
}
|
||||
if len(m1) != len(m2) {
|
||||
return false
|
||||
}
|
||||
for k, v := range m1 {
|
||||
if mv, ok := m2[k]; !ok || mv != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
515
plugins/wasm-go/extensions/traffic-editor/pkg/command.go
Normal file
515
plugins/wasm-go/extensions/traffic-editor/pkg/command.go
Normal file
@@ -0,0 +1,515 @@
|
||||
package pkg
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
commandTypeSet = "set"
|
||||
commandTypeConcat = "concat"
|
||||
commandTypeCopy = "copy"
|
||||
commandTypeDelete = "delete"
|
||||
commandTypeRename = "rename"
|
||||
)
|
||||
|
||||
var (
|
||||
commandFactories = map[string]func(gjson.Result) (Command, error){
|
||||
"set": newSetCommand,
|
||||
"concat": newConcatCommand,
|
||||
"copy": newCopyCommand,
|
||||
"delete": newDeleteCommand,
|
||||
"rename": newRenameCommand,
|
||||
}
|
||||
)
|
||||
|
||||
type CommandSet struct {
|
||||
DisableReroute bool `json:"disableReroute"`
|
||||
Commands []Command `json:"commands,omitempty"`
|
||||
RelatedStages map[Stage]bool `json:"-"`
|
||||
}
|
||||
|
||||
func (s *CommandSet) FromJson(json gjson.Result) error {
|
||||
relatedStages := map[Stage]bool{}
|
||||
if commandsJson := json.Get("commands"); commandsJson.Exists() && commandsJson.IsArray() {
|
||||
for _, item := range commandsJson.Array() {
|
||||
if command, err := NewCommand(item); err != nil {
|
||||
return fmt.Errorf("failed to create command from json: %v\n %v", err, item)
|
||||
} else {
|
||||
s.Commands = append(s.Commands, command)
|
||||
for _, ref := range command.GetRefs() {
|
||||
relatedStages[ref.GetStage()] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.RelatedStages = relatedStages
|
||||
if disableReroute := json.Get("disableReroute"); disableReroute.Exists() {
|
||||
s.DisableReroute = disableReroute.Bool()
|
||||
} else {
|
||||
s.DisableReroute = false
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *CommandSet) CreatExecutors() []Executor {
|
||||
executors := make([]Executor, 0, len(s.Commands))
|
||||
for _, command := range s.Commands {
|
||||
executor := command.CreateExecutor()
|
||||
executors = append(executors, executor)
|
||||
}
|
||||
return executors
|
||||
}
|
||||
|
||||
type ConditionalCommandSet struct {
|
||||
ConditionSet
|
||||
CommandSet
|
||||
}
|
||||
|
||||
func (s *ConditionalCommandSet) FromJson(json gjson.Result) error {
|
||||
if err := s.ConditionSet.FromJson(json); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.CommandSet.FromJson(json); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Command interface {
|
||||
GetType() string
|
||||
GetRefs() []*Ref
|
||||
CreateExecutor() Executor
|
||||
}
|
||||
|
||||
type Executor interface {
|
||||
GetCommand() Command
|
||||
Run(editorContext EditorContext, stage Stage) error
|
||||
}
|
||||
|
||||
func NewCommand(json gjson.Result) (Command, error) {
|
||||
t := json.Get("type").String()
|
||||
if t == "" {
|
||||
return nil, errors.New("command type is required")
|
||||
}
|
||||
if constructor, ok := commandFactories[t]; ok && constructor != nil {
|
||||
return constructor(json)
|
||||
} else {
|
||||
return nil, errors.New("unknown command type: " + t)
|
||||
}
|
||||
}
|
||||
|
||||
type baseExecutor struct {
|
||||
finished bool
|
||||
}
|
||||
|
||||
// setCommand
|
||||
func newSetCommand(json gjson.Result) (Command, error) {
|
||||
var targetRef *Ref
|
||||
var err error
|
||||
if t := json.Get("target"); !t.Exists() {
|
||||
return nil, errors.New("setCommand: target field is required")
|
||||
} else {
|
||||
targetRef, err = NewRef(t)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setCommand: failed to create ref from target field: %v\n %v", err, t.Raw)
|
||||
}
|
||||
}
|
||||
var value string
|
||||
if v := json.Get("value"); !v.Exists() {
|
||||
return nil, errors.New("setCommand: value field is required")
|
||||
} else {
|
||||
value = v.String()
|
||||
if value == "" {
|
||||
return nil, errors.New("setCommand: value cannot be empty")
|
||||
}
|
||||
}
|
||||
return &setCommand{
|
||||
targetRef: targetRef,
|
||||
value: value,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type setCommand struct {
|
||||
targetRef *Ref
|
||||
value string
|
||||
}
|
||||
|
||||
func (c *setCommand) GetType() string {
|
||||
return commandTypeSet
|
||||
}
|
||||
|
||||
func (c *setCommand) GetRefs() []*Ref {
|
||||
return []*Ref{c.targetRef}
|
||||
}
|
||||
|
||||
func (c *setCommand) CreateExecutor() Executor {
|
||||
return &setExecutor{command: c}
|
||||
}
|
||||
|
||||
type setExecutor struct {
|
||||
baseExecutor
|
||||
command *setCommand
|
||||
}
|
||||
|
||||
func (e *setExecutor) GetCommand() Command {
|
||||
return e.command
|
||||
}
|
||||
|
||||
func (e *setExecutor) Run(editorContext EditorContext, stage Stage) error {
|
||||
if e.finished {
|
||||
return nil
|
||||
}
|
||||
|
||||
command := e.command
|
||||
log.Debugf("setCommand: checking stage %s for target %s", Stage2String[stage], command.targetRef)
|
||||
if command.targetRef.GetStage() == stage {
|
||||
log.Debugf("setCommand: set %s to %s", command.targetRef, command.value)
|
||||
editorContext.SetRefValue(command.targetRef, command.value)
|
||||
e.finished = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// concatCommand
|
||||
func newConcatCommand(json gjson.Result) (Command, error) {
|
||||
var targetRef *Ref
|
||||
var err error
|
||||
if t := json.Get("target"); !t.Exists() {
|
||||
return nil, errors.New("concatCommand: target field is required")
|
||||
} else {
|
||||
targetRef, err = NewRef(t)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("concatCommand: failed to create ref from target field: %v\n %v", err, t.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
valuesJson := json.Get("values")
|
||||
if !valuesJson.Exists() || !valuesJson.IsArray() {
|
||||
return nil, errors.New("concatCommand: values field is required and must be an array")
|
||||
}
|
||||
|
||||
values := make([]interface{}, 0, len(valuesJson.Array()))
|
||||
for _, item := range valuesJson.Array() {
|
||||
var value interface{}
|
||||
if item.IsObject() {
|
||||
valueRef, err := NewRef(item)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("concatCommand: failed to create ref from values field: %v\n %v", err, item.Raw)
|
||||
}
|
||||
if valueRef.GetStage() > targetRef.GetStage() {
|
||||
return nil, fmt.Errorf("concatCommand: the processing stage of value [%s] cannot be after the stage of target [%s]", Stage2String[valueRef.GetStage()], Stage2String[targetRef.GetStage()])
|
||||
}
|
||||
value = valueRef
|
||||
} else {
|
||||
value = item.String()
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
return &concatCommand{
|
||||
targetRef: targetRef,
|
||||
values: values,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type concatCommand struct {
|
||||
targetRef *Ref
|
||||
values []interface{}
|
||||
}
|
||||
|
||||
func (c *concatCommand) GetType() string {
|
||||
return commandTypeConcat
|
||||
}
|
||||
|
||||
func (c *concatCommand) GetRefs() []*Ref {
|
||||
refs := []*Ref{c.targetRef}
|
||||
if c.values != nil && len(c.values) != 0 {
|
||||
for _, value := range c.values {
|
||||
if ref, ok := value.(*Ref); ok {
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
}
|
||||
}
|
||||
return refs
|
||||
}
|
||||
|
||||
func (c *concatCommand) CreateExecutor() Executor {
|
||||
return &concatExecutor{command: c}
|
||||
}
|
||||
|
||||
type concatExecutor struct {
|
||||
baseExecutor
|
||||
command *concatCommand
|
||||
values []string
|
||||
}
|
||||
|
||||
func (e *concatExecutor) GetCommand() Command {
|
||||
return e.command
|
||||
}
|
||||
|
||||
func (e *concatExecutor) Run(editorContext EditorContext, stage Stage) error {
|
||||
if e.finished {
|
||||
return nil
|
||||
}
|
||||
|
||||
command := e.command
|
||||
|
||||
if e.values == nil {
|
||||
e.values = make([]string, len(command.values))
|
||||
}
|
||||
|
||||
for i, value := range command.values {
|
||||
if value == nil || e.values[i] != "" {
|
||||
continue
|
||||
}
|
||||
v := ""
|
||||
if s, ok := value.(string); ok {
|
||||
v = s
|
||||
} else if ref, ok := value.(*Ref); ok && ref.GetStage() == stage {
|
||||
v = editorContext.GetRefValue(ref)
|
||||
}
|
||||
e.values[i] = v
|
||||
}
|
||||
|
||||
if command.targetRef.GetStage() == stage {
|
||||
result := ""
|
||||
for _, v := range e.values {
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
result += v
|
||||
}
|
||||
log.Debugf("concatCommand: set %s to %s", command.targetRef, result)
|
||||
editorContext.SetRefValue(command.targetRef, result)
|
||||
e.finished = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// copyCommand
|
||||
func newCopyCommand(json gjson.Result) (Command, error) {
|
||||
var sourceRef *Ref
|
||||
var targetRef *Ref
|
||||
var err error
|
||||
if t := json.Get("source"); !t.Exists() {
|
||||
return nil, errors.New("copyCommand: source field is required")
|
||||
} else {
|
||||
sourceRef, err = NewRef(t)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("copyCommand: failed to create ref from source field: %v\n %v", err, t.Raw)
|
||||
}
|
||||
}
|
||||
if t := json.Get("target"); !t.Exists() {
|
||||
return nil, errors.New("copyCommand: target field is required")
|
||||
} else {
|
||||
targetRef, err = NewRef(t)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("copyCommand: failed to create ref from target field: %v\n %v", err, t.Raw)
|
||||
}
|
||||
}
|
||||
if sourceRef.GetStage() > targetRef.GetStage() {
|
||||
return nil, fmt.Errorf("copyCommand: the processing stage of source [%s] cannot be after the stage of target [%s]", Stage2String[sourceRef.GetStage()], Stage2String[targetRef.GetStage()])
|
||||
}
|
||||
return ©Command{
|
||||
sourceRef: sourceRef,
|
||||
targetRef: targetRef,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type copyCommand struct {
|
||||
sourceRef *Ref
|
||||
targetRef *Ref
|
||||
}
|
||||
|
||||
func (c *copyCommand) GetType() string {
|
||||
return commandTypeCopy
|
||||
}
|
||||
|
||||
func (c *copyCommand) GetRefs() []*Ref {
|
||||
return []*Ref{c.sourceRef, c.targetRef}
|
||||
}
|
||||
|
||||
func (c *copyCommand) CreateExecutor() Executor {
|
||||
return ©Executor{command: c}
|
||||
}
|
||||
|
||||
type copyExecutor struct {
|
||||
baseExecutor
|
||||
command *copyCommand
|
||||
valueToCopy string
|
||||
}
|
||||
|
||||
func (e *copyExecutor) GetCommand() Command {
|
||||
return e.command
|
||||
}
|
||||
|
||||
func (e *copyExecutor) Run(editorContext EditorContext, stage Stage) error {
|
||||
if e.finished {
|
||||
return nil
|
||||
}
|
||||
|
||||
command := e.command
|
||||
|
||||
if command.sourceRef.GetStage() == stage {
|
||||
e.valueToCopy = editorContext.GetRefValue(command.sourceRef)
|
||||
log.Debugf("copyCommand: valueToCopy=%s", e.valueToCopy)
|
||||
}
|
||||
|
||||
if e.valueToCopy == "" {
|
||||
log.Debug("copyCommand: valueToCopy is empty. skip.")
|
||||
e.finished = true
|
||||
return nil
|
||||
}
|
||||
|
||||
if command.targetRef.GetStage() == stage {
|
||||
editorContext.SetRefValue(command.targetRef, e.valueToCopy)
|
||||
log.Debugf("copyCommand: set %s to %s", e.valueToCopy, command.targetRef)
|
||||
e.finished = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteCommand
|
||||
func newDeleteCommand(json gjson.Result) (Command, error) {
|
||||
var targetRef *Ref
|
||||
var err error
|
||||
if t := json.Get("target"); !t.Exists() {
|
||||
return nil, errors.New("deleteCommand: target field is required")
|
||||
} else {
|
||||
targetRef, err = NewRef(t)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("deleteCommand: failed to create ref from target field: %v\n %v", err, t.Raw)
|
||||
}
|
||||
}
|
||||
return &deleteCommand{
|
||||
targetRef: targetRef,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type deleteCommand struct {
|
||||
targetRef *Ref
|
||||
}
|
||||
|
||||
func (c *deleteCommand) GetType() string {
|
||||
return commandTypeDelete
|
||||
}
|
||||
|
||||
func (c *deleteCommand) GetRefs() []*Ref {
|
||||
return []*Ref{c.targetRef}
|
||||
}
|
||||
|
||||
func (c *deleteCommand) CreateExecutor() Executor {
|
||||
return &deleteExecutor{command: c}
|
||||
}
|
||||
|
||||
type deleteExecutor struct {
|
||||
baseExecutor
|
||||
command *deleteCommand
|
||||
}
|
||||
|
||||
func (e *deleteExecutor) GetCommand() Command {
|
||||
return e.command
|
||||
}
|
||||
|
||||
func (e *deleteExecutor) Run(editorContext EditorContext, stage Stage) error {
|
||||
if e.finished {
|
||||
return nil
|
||||
}
|
||||
|
||||
command := e.command
|
||||
log.Debugf("deleteCommand: checking stage %s for target %s", Stage2String[stage], command.targetRef)
|
||||
|
||||
if command.targetRef.GetStage() == stage {
|
||||
log.Debugf("deleteCommand: delete %s", command.targetRef)
|
||||
editorContext.DeleteRefValues(command.targetRef)
|
||||
e.finished = true
|
||||
log.Debugf("deleteCommand: finished deleting %s", command.targetRef)
|
||||
} else {
|
||||
log.Debugf("deleteCommand: stage %s does not match targetRef stage %s, skipping.", Stage2String[stage], Stage2String[command.targetRef.GetStage()])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// renameCommand
|
||||
func newRenameCommand(json gjson.Result) (Command, error) {
|
||||
var targetRef *Ref
|
||||
var err error
|
||||
if t := json.Get("target"); !t.Exists() {
|
||||
return nil, errors.New("renameCommand: target field is required")
|
||||
} else {
|
||||
targetRef, err = NewRef(t)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("renameCommand: failed to create ref from target field: %v\n %v", err, t.Raw)
|
||||
}
|
||||
}
|
||||
newName := json.Get("newName").String()
|
||||
if newName == "" {
|
||||
return nil, errors.New("renameCommand: newName field is required")
|
||||
}
|
||||
return &renameCommand{
|
||||
targetRef: targetRef,
|
||||
newName: newName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type renameCommand struct {
|
||||
targetRef *Ref
|
||||
newName string
|
||||
}
|
||||
|
||||
func (c *renameCommand) GetType() string {
|
||||
return commandTypeRename
|
||||
}
|
||||
|
||||
func (c *renameCommand) GetRefs() []*Ref {
|
||||
return []*Ref{c.targetRef}
|
||||
}
|
||||
|
||||
func (c *renameCommand) CreateExecutor() Executor {
|
||||
return &renameExecutor{command: c}
|
||||
}
|
||||
|
||||
type renameExecutor struct {
|
||||
baseExecutor
|
||||
command *renameCommand
|
||||
}
|
||||
|
||||
func (e *renameExecutor) GetCommand() Command {
|
||||
return e.command
|
||||
}
|
||||
|
||||
func (e *renameExecutor) Run(editorContext EditorContext, stage Stage) error {
|
||||
if e.finished {
|
||||
return nil
|
||||
}
|
||||
|
||||
command := e.command
|
||||
log.Debugf("renameCommand: checking stage %s for target %s", Stage2String[stage], command.targetRef)
|
||||
|
||||
if command.targetRef.GetStage() == stage {
|
||||
if command.newName == command.targetRef.Name {
|
||||
log.Debugf("renameCommand: skip renaming %s to itself", command.targetRef)
|
||||
} else {
|
||||
values := editorContext.GetRefValues(command.targetRef)
|
||||
log.Debugf("renameCommand: rename %s to %s value=%v", command.targetRef, command.newName, values)
|
||||
editorContext.SetRefValues(&Ref{
|
||||
Type: command.targetRef.Type,
|
||||
Name: command.newName,
|
||||
}, values)
|
||||
editorContext.DeleteRefValues(command.targetRef)
|
||||
log.Debugf("renameCommand: finished renaming %s to %s", command.targetRef, command.newName)
|
||||
}
|
||||
e.finished = true
|
||||
} else {
|
||||
log.Debugf("renameCommand: stage %s does not match targetRef stage %s, skipping.", Stage2String[stage], Stage2String[command.targetRef.GetStage()])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
309
plugins/wasm-go/extensions/traffic-editor/pkg/command_test.go
Normal file
309
plugins/wasm-go/extensions/traffic-editor/pkg/command_test.go
Normal file
@@ -0,0 +1,309 @@
|
||||
package pkg
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestNewSetCommand_Success(t *testing.T) {
|
||||
jsonStr := `{"type":"set","target":{"type":"request_header","name":"foo"},"value":"bar"}`
|
||||
json := gjson.Parse(jsonStr)
|
||||
cmd, err := newSetCommand(json)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if cmd.GetType() != "set" {
|
||||
t.Errorf("expected type 'set', got %s", cmd.GetType())
|
||||
}
|
||||
refs := cmd.GetRefs()
|
||||
if len(refs) != 1 {
|
||||
t.Errorf("expected 1 ref, got %d", len(refs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSetCommand_MissingTarget(t *testing.T) {
|
||||
jsonStr := `{"type":"set","value":"bar"}`
|
||||
json := gjson.Parse(jsonStr)
|
||||
_, err := newSetCommand(json)
|
||||
if err == nil || err.Error() != "setCommand: target field is required" {
|
||||
t.Errorf("expected target field error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSetCommand_MissingValue(t *testing.T) {
|
||||
jsonStr := `{"type":"set","target":{"type":"request_header","name":"foo"}}`
|
||||
json := gjson.Parse(jsonStr)
|
||||
_, err := newSetCommand(json)
|
||||
if err == nil || err.Error() != "setCommand: value field is required" {
|
||||
t.Errorf("expected value field error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConcatCommand_Success(t *testing.T) {
|
||||
jsonStr := `{"type":"concat","target":{"type":"request_header","name":"foo"},"values":["a","b"]}`
|
||||
json := gjson.Parse(jsonStr)
|
||||
cmd, err := newConcatCommand(json)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if cmd.GetType() != "concat" {
|
||||
t.Errorf("expected type 'concat', got %s", cmd.GetType())
|
||||
}
|
||||
refs := cmd.GetRefs()
|
||||
if len(refs) < 1 {
|
||||
t.Errorf("expected at least 1 ref, got %d", len(refs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConcatCommand_MissingTarget(t *testing.T) {
|
||||
jsonStr := `{"type":"concat","values":["a","b"]}`
|
||||
json := gjson.Parse(jsonStr)
|
||||
_, err := newConcatCommand(json)
|
||||
if err == nil || err.Error() != "concatCommand: target field is required" {
|
||||
t.Errorf("expected target field error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConcatCommand_MissingValues(t *testing.T) {
|
||||
jsonStr := `{"type":"concat","target":{"type":"request_header","name":"foo"}}`
|
||||
json := gjson.Parse(jsonStr)
|
||||
_, err := newConcatCommand(json)
|
||||
if err == nil || err.Error() != "concatCommand: values field is required and must be an array" {
|
||||
t.Errorf("expected values field error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCopyCommand_Success(t *testing.T) {
|
||||
jsonStr := `{"type":"copy","source":{"type":"request_header","name":"foo"},"target":{"type":"request_header","name":"bar"}}`
|
||||
json := gjson.Parse(jsonStr)
|
||||
cmd, err := newCopyCommand(json)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if cmd.GetType() != "copy" {
|
||||
t.Errorf("expected type 'copy', got %s", cmd.GetType())
|
||||
}
|
||||
refs := cmd.GetRefs()
|
||||
if len(refs) != 2 {
|
||||
t.Errorf("expected 2 refs, got %d", len(refs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCopyCommand_MissingSource(t *testing.T) {
|
||||
jsonStr := `{"type":"copy","target":{"type":"request_header","name":"bar"}}`
|
||||
json := gjson.Parse(jsonStr)
|
||||
_, err := newCopyCommand(json)
|
||||
if err == nil || err.Error() != "copyCommand: source field is required" {
|
||||
t.Errorf("expected source field error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCopyCommand_MissingTarget(t *testing.T) {
|
||||
jsonStr := `{"type":"copy","source":{"type":"request_header","name":"foo"}}`
|
||||
json := gjson.Parse(jsonStr)
|
||||
_, err := newCopyCommand(json)
|
||||
if err == nil || err.Error() != "copyCommand: target field is required" {
|
||||
t.Errorf("expected target field error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCopyCommand_SourceStageAfterTarget(t *testing.T) {
|
||||
jsonStr := `{"type":"copy","source":{"type":"response_header","name":"foo"},"target":{"type":"request_header","name":"bar"}}`
|
||||
json := gjson.Parse(jsonStr)
|
||||
_, err := newCopyCommand(json)
|
||||
if err == nil || err.Error() != "copyCommand: the processing stage of source [response_headers] cannot be after the stage of target [request_headers]" {
|
||||
t.Errorf("expected source stage field error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDeleteCommand_Success(t *testing.T) {
|
||||
jsonStr := `{"type":"delete","target":{"type":"request_header","name":"foo"}}`
|
||||
json := gjson.Parse(jsonStr)
|
||||
cmd, err := newDeleteCommand(json)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if cmd.GetType() != "delete" {
|
||||
t.Errorf("expected type 'delete', got %s", cmd.GetType())
|
||||
}
|
||||
refs := cmd.GetRefs()
|
||||
if len(refs) != 1 {
|
||||
t.Errorf("expected 1 ref, got %d", len(refs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDeleteCommand_MissingTarget(t *testing.T) {
|
||||
jsonStr := `{"type":"delete"}`
|
||||
json := gjson.Parse(jsonStr)
|
||||
_, err := newDeleteCommand(json)
|
||||
if err == nil || err.Error() != "deleteCommand: target field is required" {
|
||||
t.Errorf("expected target field error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRenameCommand_Success(t *testing.T) {
|
||||
jsonStr := `{"type":"rename","target":{"type":"request_header","name":"foo"},"newName":"bar"}`
|
||||
json := gjson.Parse(jsonStr)
|
||||
cmd, err := newRenameCommand(json)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if cmd.GetType() != "rename" {
|
||||
t.Errorf("expected type 'rename', got %s", cmd.GetType())
|
||||
}
|
||||
refs := cmd.GetRefs()
|
||||
if len(refs) != 1 {
|
||||
t.Errorf("expected 1 ref, got %d", len(refs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRenameCommand_MissingTarget(t *testing.T) {
|
||||
jsonStr := `{"type":"rename","newName":"bar"}`
|
||||
json := gjson.Parse(jsonStr)
|
||||
_, err := newRenameCommand(json)
|
||||
if err == nil || err.Error() != "renameCommand: target field is required" {
|
||||
t.Errorf("expected target field error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRenameCommand_MissingNewName(t *testing.T) {
|
||||
jsonStr := `{"type":"rename","target":{"type":"request_header","name":"foo"}}`
|
||||
json := gjson.Parse(jsonStr)
|
||||
_, err := newRenameCommand(json)
|
||||
if err == nil || err.Error() != "renameCommand: newName field is required" {
|
||||
t.Errorf("expected newName field error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetExecutor_Run_SingleStage(t *testing.T) {
|
||||
ref := &Ref{Type: RefTypeRequestHeader, Name: "foo"}
|
||||
cmd := &setCommand{targetRef: ref, value: "bar"}
|
||||
executor := cmd.CreateExecutor()
|
||||
ctx := NewEditorContext()
|
||||
stage := StageRequestHeaders
|
||||
|
||||
err := executor.Run(ctx, stage)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if ctx.GetRefValue(ref) != "bar" {
|
||||
t.Errorf("expected value 'bar', got %s", ctx.GetRefValue(ref))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcatExecutor_Run_SingleStage(t *testing.T) {
|
||||
ref := &Ref{Type: RefTypeRequestHeader, Name: "foo"}
|
||||
srcRef := &Ref{Type: RefTypeRequestHeader, Name: "test"}
|
||||
cmd := &concatCommand{targetRef: ref, values: []interface{}{"a", srcRef, "b"}}
|
||||
executor := cmd.CreateExecutor()
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRefValue(srcRef, "-")
|
||||
stage := StageRequestHeaders
|
||||
|
||||
err := executor.Run(ctx, stage)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if ctx.GetRefValue(ref) != "a-b" {
|
||||
t.Errorf("expected value 'a-b', got %s", ctx.GetRefValue(ref))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcatExecutor_Run_MultiStages(t *testing.T) {
|
||||
ref := &Ref{Type: RefTypeResponseHeader, Name: "foo"}
|
||||
srcRef := &Ref{Type: RefTypeRequestHeader, Name: "test"}
|
||||
cmd := &concatCommand{targetRef: ref, values: []interface{}{"a", srcRef, "b"}}
|
||||
executor := cmd.CreateExecutor()
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRefValue(srcRef, "-")
|
||||
|
||||
err := executor.Run(ctx, StageRequestHeaders)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
err = executor.Run(ctx, StageResponseHeaders)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if ctx.GetRefValue(ref) != "a-b" {
|
||||
t.Errorf("expected value 'a-b', got %s", ctx.GetRefValue(ref))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyExecutor_Run_SingleStage(t *testing.T) {
|
||||
source := &Ref{Type: RefTypeRequestHeader, Name: "foo"}
|
||||
target := &Ref{Type: RefTypeRequestHeader, Name: "bar"}
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRefValue(source, "baz")
|
||||
cmd := ©Command{sourceRef: source, targetRef: target}
|
||||
executor := cmd.CreateExecutor()
|
||||
stage := StageRequestHeaders
|
||||
|
||||
err := executor.Run(ctx, stage)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if ctx.GetRefValue(target) != "baz" {
|
||||
t.Errorf("expected value 'baz' for target, got %s", ctx.GetRefValue(target))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyExecutor_Run_MultiStages(t *testing.T) {
|
||||
source := &Ref{Type: RefTypeRequestHeader, Name: "foo"}
|
||||
target := &Ref{Type: RefTypeResponseHeader, Name: "bar"}
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRefValue(source, "baz")
|
||||
cmd := ©Command{sourceRef: source, targetRef: target}
|
||||
executor := cmd.CreateExecutor()
|
||||
|
||||
err := executor.Run(ctx, StageRequestHeaders)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
err = executor.Run(ctx, StageResponseHeaders)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if ctx.GetRefValue(target) != "baz" {
|
||||
t.Errorf("expected value 'baz' for target, got %s", ctx.GetRefValue(target))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteExecutor_Run(t *testing.T) {
|
||||
ref := &Ref{Type: RefTypeRequestHeader, Name: "foo"}
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRefValue(ref, "bar")
|
||||
cmd := &deleteCommand{targetRef: ref}
|
||||
executor := cmd.CreateExecutor()
|
||||
stage := StageRequestHeaders
|
||||
|
||||
err := executor.Run(ctx, stage)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if ctx.GetRefValue(ref) != "" {
|
||||
t.Errorf("expected value to be deleted, got %s", ctx.GetRefValue(ref))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenameExecutor_Run(t *testing.T) {
|
||||
ref := &Ref{Type: RefTypeRequestHeader, Name: "foo"}
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRefValue(ref, "bar")
|
||||
cmd := &renameCommand{targetRef: ref, newName: "baz"}
|
||||
executor := cmd.CreateExecutor()
|
||||
stage := StageRequestHeaders
|
||||
|
||||
err := executor.Run(ctx, stage)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
newRef := &Ref{Type: ref.Type, Name: "baz"}
|
||||
if ctx.GetRefValue(newRef) != "bar" {
|
||||
t.Errorf("expected value 'bar' for new name, got %s", ctx.GetRefValue(newRef))
|
||||
}
|
||||
if ctx.GetRefValue(ref) != "" {
|
||||
t.Errorf("expected old name to be deleted, got %s", ctx.GetRefValue(ref))
|
||||
}
|
||||
}
|
||||
325
plugins/wasm-go/extensions/traffic-editor/pkg/condition.go
Normal file
325
plugins/wasm-go/extensions/traffic-editor/pkg/condition.go
Normal file
@@ -0,0 +1,325 @@
|
||||
package pkg
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
conditionTypeEquals = "equals"
|
||||
conditionTypePrefix = "prefix"
|
||||
conditionTypeSuffix = "suffix"
|
||||
conditionTypeContains = "contains"
|
||||
conditionTypeRegex = "regex"
|
||||
)
|
||||
|
||||
var (
|
||||
conditionFactories = map[string]func(gjson.Result) (Condition, error){
|
||||
conditionTypeEquals: newEqualsCondition,
|
||||
conditionTypePrefix: newPrefixCondition,
|
||||
conditionTypeSuffix: newSuffixCondition,
|
||||
conditionTypeContains: newContainsCondition,
|
||||
conditionTypeRegex: newRegexCondition,
|
||||
}
|
||||
)
|
||||
|
||||
type ConditionSet struct {
|
||||
Conditions []Condition `json:"conditions,omitempty"`
|
||||
RelatedStages map[Stage]bool `json:"-"`
|
||||
}
|
||||
|
||||
func (s *ConditionSet) FromJson(json gjson.Result) error {
|
||||
relatedStages := map[Stage]bool{}
|
||||
s.Conditions = nil
|
||||
if conditionsJson := json.Get("conditions"); conditionsJson.Exists() && conditionsJson.IsArray() {
|
||||
for _, item := range conditionsJson.Array() {
|
||||
if condition, err := CreateCondition(item); err != nil {
|
||||
return fmt.Errorf("failed to create condition from json: %v\n %v", err, item)
|
||||
} else {
|
||||
s.Conditions = append(s.Conditions, condition)
|
||||
for _, ref := range condition.GetRefs() {
|
||||
relatedStages[ref.GetStage()] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.RelatedStages = relatedStages
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ConditionSet) Matches(editorContext EditorContext) bool {
|
||||
if len(s.Conditions) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, condition := range s.Conditions {
|
||||
if !condition.Evaluate(editorContext) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type Condition interface {
|
||||
GetType() string
|
||||
GetRefs() []*Ref
|
||||
Evaluate(ctx EditorContext) bool
|
||||
}
|
||||
|
||||
func CreateCondition(json gjson.Result) (Condition, error) {
|
||||
t := json.Get("type").String()
|
||||
if t == "" {
|
||||
return nil, errors.New("condition type is required")
|
||||
}
|
||||
if constructor, ok := conditionFactories[t]; !ok || constructor == nil {
|
||||
return nil, errors.New("unknown condition type: " + t)
|
||||
} else if condition, err := constructor(json); err != nil {
|
||||
return nil, fmt.Errorf("failed to create condition with type %s: %v", t, err)
|
||||
} else {
|
||||
for _, ref := range condition.GetRefs() {
|
||||
if ref.GetStage() >= StageResponseHeaders {
|
||||
return nil, fmt.Errorf("condition only supports request refs")
|
||||
}
|
||||
}
|
||||
return condition, nil
|
||||
}
|
||||
}
|
||||
|
||||
// equalsCondition
|
||||
func newEqualsCondition(json gjson.Result) (Condition, error) {
|
||||
value1 := json.Get("value1")
|
||||
if value1.Type != gjson.JSON {
|
||||
return nil, errors.New("equalsCondition: value1 field type must be JSON object")
|
||||
}
|
||||
value1Ref, err := NewRef(value1)
|
||||
if err != nil {
|
||||
return nil, errors.New("equalsCondition: failed to create value1 ref: " + err.Error())
|
||||
}
|
||||
value2 := json.Get("value2").String()
|
||||
return &equalsCondition{
|
||||
value1Ref: value1Ref,
|
||||
value2: value2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type equalsCondition struct {
|
||||
value1Ref *Ref
|
||||
value2 string
|
||||
}
|
||||
|
||||
func (c *equalsCondition) GetType() string {
|
||||
return conditionTypeEquals
|
||||
}
|
||||
|
||||
func (c *equalsCondition) GetRefs() []*Ref {
|
||||
return []*Ref{c.value1Ref}
|
||||
}
|
||||
|
||||
func (c *equalsCondition) Evaluate(ctx EditorContext) bool {
|
||||
log.Debugf("Evaluating equals condition: value1Ref=%v, value2=%s", c.value1Ref, c.value2)
|
||||
ref1Values := ctx.GetRefValues(c.value1Ref)
|
||||
if len(ref1Values) == 0 {
|
||||
log.Debugf("No values found for ref1: %v", c.value1Ref)
|
||||
return false
|
||||
}
|
||||
for _, value1 := range ref1Values {
|
||||
if value1 == c.value2 {
|
||||
log.Debugf("Condition matched: %s == %s", value1, c.value2)
|
||||
return true
|
||||
}
|
||||
}
|
||||
log.Debugf("No matches found for condition: value1Ref=%v, value2=%s", c.value1Ref, c.value2)
|
||||
return false
|
||||
}
|
||||
|
||||
// prefixCondition
|
||||
func newPrefixCondition(json gjson.Result) (Condition, error) {
|
||||
value := json.Get("value")
|
||||
if value.Type != gjson.JSON {
|
||||
return nil, errors.New("prefixCondition: value field type must be JSON object")
|
||||
}
|
||||
valueRef, err := NewRef(value)
|
||||
if err != nil {
|
||||
return nil, errors.New("prefixCondition: failed to create value ref: " + err.Error())
|
||||
}
|
||||
prefix := json.Get("prefix").String()
|
||||
return &prefixCondition{
|
||||
valueRef: valueRef,
|
||||
prefix: prefix,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type prefixCondition struct {
|
||||
valueRef *Ref
|
||||
prefix string
|
||||
}
|
||||
|
||||
func (c *prefixCondition) GetType() string {
|
||||
return conditionTypePrefix
|
||||
}
|
||||
|
||||
func (c *prefixCondition) GetRefs() []*Ref {
|
||||
return []*Ref{c.valueRef}
|
||||
}
|
||||
|
||||
func (c *prefixCondition) Evaluate(ctx EditorContext) bool {
|
||||
log.Debugf("Evaluating prefix condition: valueRef=%v, prefix=%s", c.valueRef, c.prefix)
|
||||
refValues := ctx.GetRefValues(c.valueRef)
|
||||
if len(refValues) == 0 {
|
||||
log.Debugf("No values found for ref: %v", c.valueRef)
|
||||
return false
|
||||
}
|
||||
for _, value := range refValues {
|
||||
if strings.HasPrefix(value, c.prefix) {
|
||||
log.Debugf("Condition matched: %s starts with %s", value, c.prefix)
|
||||
return true
|
||||
}
|
||||
}
|
||||
log.Debugf("No matches found for condition: valueRef=%v, prefix=%s", c.valueRef, c.prefix)
|
||||
return false
|
||||
}
|
||||
|
||||
// suffixCondition
|
||||
func newSuffixCondition(json gjson.Result) (Condition, error) {
|
||||
value := json.Get("value")
|
||||
if value.Type != gjson.JSON {
|
||||
return nil, errors.New("suffixCondition: value field type must be JSON object")
|
||||
}
|
||||
valueRef, err := NewRef(value)
|
||||
if err != nil {
|
||||
return nil, errors.New("suffixCondition: failed to create value ref: " + err.Error())
|
||||
}
|
||||
suffix := json.Get("suffix").String()
|
||||
return &suffixCondition{
|
||||
valueRef: valueRef,
|
||||
suffix: suffix,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type suffixCondition struct {
|
||||
valueRef *Ref
|
||||
suffix string
|
||||
}
|
||||
|
||||
func (c *suffixCondition) GetType() string {
|
||||
return conditionTypeSuffix
|
||||
}
|
||||
|
||||
func (c *suffixCondition) GetRefs() []*Ref {
|
||||
return []*Ref{c.valueRef}
|
||||
}
|
||||
func (c *suffixCondition) Evaluate(ctx EditorContext) bool {
|
||||
log.Debugf("Evaluating suffix condition: valueRef=%v, prefix=%s", c.valueRef, c.suffix)
|
||||
refValues := ctx.GetRefValues(c.valueRef)
|
||||
if len(refValues) == 0 {
|
||||
log.Debugf("No values found for ref: %v", c.valueRef)
|
||||
return false
|
||||
}
|
||||
for _, value := range refValues {
|
||||
if strings.HasSuffix(value, c.suffix) {
|
||||
log.Debugf("Condition matched: %s ends with %s", value, c.suffix)
|
||||
return true
|
||||
}
|
||||
}
|
||||
log.Debugf("No matches found for condition: valueRef=%v, prefix=%s", c.valueRef, c.suffix)
|
||||
return false
|
||||
}
|
||||
|
||||
// containsCondition
|
||||
func newContainsCondition(json gjson.Result) (Condition, error) {
|
||||
value := json.Get("value")
|
||||
if value.Type != gjson.JSON {
|
||||
return nil, errors.New("containsCondition: value field type must be JSON object")
|
||||
}
|
||||
valueRef, err := NewRef(value)
|
||||
if err != nil {
|
||||
return nil, errors.New("containsCondition: failed to create value ref: " + err.Error())
|
||||
}
|
||||
part := json.Get("part").String()
|
||||
return &containsCondition{
|
||||
valueRef: valueRef,
|
||||
part: part,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type containsCondition struct {
|
||||
valueRef *Ref
|
||||
part string
|
||||
}
|
||||
|
||||
func (c *containsCondition) GetType() string {
|
||||
return conditionTypeContains
|
||||
}
|
||||
|
||||
func (c *containsCondition) GetRefs() []*Ref {
|
||||
return []*Ref{c.valueRef}
|
||||
}
|
||||
|
||||
func (c *containsCondition) Evaluate(ctx EditorContext) bool {
|
||||
refValues := ctx.GetRefValues(c.valueRef)
|
||||
if len(refValues) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, value := range refValues {
|
||||
if strings.Contains(value, c.part) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// regexCondition
|
||||
func newRegexCondition(json gjson.Result) (Condition, error) {
|
||||
value := json.Get("value")
|
||||
if value.Type != gjson.JSON {
|
||||
return nil, errors.New("regexCondition: value field type must be JSON object")
|
||||
}
|
||||
valueRef, err := NewRef(value)
|
||||
if err != nil {
|
||||
return nil, errors.New("regexCondition: failed to create value ref: " + err.Error())
|
||||
}
|
||||
patternStr := json.Get("pattern").String()
|
||||
pattern, err := regexp.Compile(patternStr)
|
||||
if err != nil {
|
||||
return nil, errors.New("regexCondition: failed to compile pattern: " + err.Error())
|
||||
}
|
||||
return ®exCondition{
|
||||
valueRef: valueRef,
|
||||
pattern: pattern,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type regexCondition struct {
|
||||
valueRef *Ref
|
||||
pattern *regexp.Regexp
|
||||
}
|
||||
|
||||
func (c *regexCondition) GetType() string {
|
||||
return conditionTypeRegex
|
||||
}
|
||||
|
||||
func (c *regexCondition) Evaluate(ctx EditorContext) bool {
|
||||
log.Debugf("Evaluating regex condition: valueRef=%v, pattern=%s", c.valueRef, c.pattern.String())
|
||||
refValues := ctx.GetRefValues(c.valueRef)
|
||||
if len(refValues) == 0 {
|
||||
log.Debugf("No values found for ref: %v", c.valueRef)
|
||||
return false
|
||||
}
|
||||
for _, value := range refValues {
|
||||
if c.pattern.MatchString(value) {
|
||||
log.Debugf("Condition matched: %s matches %s", value, c.pattern.String())
|
||||
return true
|
||||
}
|
||||
}
|
||||
log.Debugf("No matches found for condition: valueRef=%v, pattern=%s", c.valueRef, c.pattern.String())
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *regexCondition) GetRefs() []*Ref {
|
||||
return []*Ref{c.valueRef}
|
||||
}
|
||||
217
plugins/wasm-go/extensions/traffic-editor/pkg/condition_test.go
Normal file
217
plugins/wasm-go/extensions/traffic-editor/pkg/condition_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package pkg
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// --- equalsCondition tests ---
|
||||
func TestEqualsCondition_Match(t *testing.T) {
|
||||
json := gjson.Parse(`{"type":"equals","value1":{"type":"request_header","name":"x-test"},"value2":"abc"}`)
|
||||
cond, err := CreateCondition(json)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCondition failed: %v", err)
|
||||
}
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRequestHeaders(map[string][]string{"x-test": {"abc"}})
|
||||
if !cond.Evaluate(ctx) {
|
||||
t.Error("equalsCondition should match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEqualsCondition_NoMatch(t *testing.T) {
|
||||
json := gjson.Parse(`{"type":"equals","value1":{"type":"request_header","name":"x-test"},"value2":"abc"}`)
|
||||
cond, _ := CreateCondition(json)
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRequestHeaders(map[string][]string{"x-test": {"def"}})
|
||||
if cond.Evaluate(ctx) {
|
||||
t.Error("equalsCondition should not match")
|
||||
}
|
||||
}
|
||||
|
||||
// --- prefixCondition tests ---
|
||||
func TestPrefixCondition_Match(t *testing.T) {
|
||||
json := gjson.Parse(`{"type":"prefix","value":{"type":"request_query","name":"foo"},"prefix":"bar"}`)
|
||||
cond, err := CreateCondition(json)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCondition failed: %v", err)
|
||||
}
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRequestQueries(map[string][]string{"foo": {"barbaz"}})
|
||||
if !cond.Evaluate(ctx) {
|
||||
t.Error("prefixCondition should match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixCondition_NoMatch(t *testing.T) {
|
||||
json := gjson.Parse(`{"type":"prefix","value":{"type":"request_query","name":"foo"},"prefix":"bar"}`)
|
||||
cond, _ := CreateCondition(json)
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRequestQueries(map[string][]string{"foo": {"bazbar"}})
|
||||
if cond.Evaluate(ctx) {
|
||||
t.Error("prefixCondition should not match")
|
||||
}
|
||||
}
|
||||
|
||||
// --- suffixCondition tests ---
|
||||
func TestSuffixCondition_Match(t *testing.T) {
|
||||
json := gjson.Parse(`{"type":"suffix","value":{"type":"request_header","name":"x-end"},"suffix":"xyz"}`)
|
||||
cond, err := CreateCondition(json)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCondition failed: %v", err)
|
||||
}
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRequestHeaders(map[string][]string{"x-end": {"123xyz"}})
|
||||
if !cond.Evaluate(ctx) {
|
||||
t.Error("suffixCondition should match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuffixCondition_NoMatch(t *testing.T) {
|
||||
json := gjson.Parse(`{"type":"suffix","value":{"type":"request_header","name":"x-end"},"suffix":"xyz"}`)
|
||||
cond, _ := CreateCondition(json)
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRequestHeaders(map[string][]string{"x-end": {"xyz123"}})
|
||||
if cond.Evaluate(ctx) {
|
||||
t.Error("suffixCondition should not match")
|
||||
}
|
||||
}
|
||||
|
||||
// --- containsCondition tests ---
|
||||
func TestContainsCondition_Match(t *testing.T) {
|
||||
json := gjson.Parse(`{"type":"contains","value":{"type":"request_query","name":"foo"},"part":"baz"}`)
|
||||
cond, err := CreateCondition(json)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCondition failed: %v", err)
|
||||
}
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRequestQueries(map[string][]string{"foo": {"barbaz"}})
|
||||
if !cond.Evaluate(ctx) {
|
||||
t.Error("containsCondition should match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsCondition_NoMatch(t *testing.T) {
|
||||
json := gjson.Parse(`{"type":"contains","value":{"type":"request_query","name":"foo"},"part":"baz"}`)
|
||||
cond, _ := CreateCondition(json)
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRequestQueries(map[string][]string{"foo": {"bar"}})
|
||||
if cond.Evaluate(ctx) {
|
||||
t.Error("containsCondition should not match")
|
||||
}
|
||||
}
|
||||
|
||||
// --- regexCondition tests ---
|
||||
func TestRegexCondition_Match(t *testing.T) {
|
||||
json := gjson.Parse(`{"type":"regex","value":{"type":"request_header","name":"x-reg"},"pattern":"^abc.*"}`)
|
||||
cond, err := CreateCondition(json)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCondition failed: %v", err)
|
||||
}
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRequestHeaders(map[string][]string{"x-reg": {"abcdef"}})
|
||||
if !cond.Evaluate(ctx) {
|
||||
t.Error("regexCondition should match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegexCondition_NoMatch(t *testing.T) {
|
||||
json := gjson.Parse(`{"type":"regex","value":{"type":"request_header","name":"x-reg"},"pattern":"^abc.*"}`)
|
||||
cond, _ := CreateCondition(json)
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRequestHeaders(map[string][]string{"x-reg": {"defabc"}})
|
||||
if cond.Evaluate(ctx) {
|
||||
t.Error("regexCondition should not match")
|
||||
}
|
||||
}
|
||||
|
||||
// --- CreateCondition error cases ---
|
||||
func TestCreateCondition_UnknownType(t *testing.T) {
|
||||
json := gjson.Parse(`{"type":"unknown","value1":{"type":"request_header","name":"x-test"},"value2":"abc"}`)
|
||||
_, err := CreateCondition(json)
|
||||
if err == nil {
|
||||
t.Error("CreateCondition should fail for unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateCondition_MissingType(t *testing.T) {
|
||||
json := gjson.Parse(`{"value1":{"type":"request_header","name":"x-test"},"value2":"abc"}`)
|
||||
_, err := CreateCondition(json)
|
||||
if err == nil {
|
||||
t.Error("CreateCondition should fail for missing type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateCondition_InvalidRefType(t *testing.T) {
|
||||
json := gjson.Parse(`{"type":"equals","value1":{"type":"invalid_type","name":"x-test"},"value2":"abc"}`)
|
||||
_, err := CreateCondition(json)
|
||||
if err == nil {
|
||||
t.Error("CreateCondition should fail for invalid ref type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateCondition_MissingRefName(t *testing.T) {
|
||||
json := gjson.Parse(`{"type":"equals","value1":{"type":"request_header"},"value2":"abc"}`)
|
||||
_, err := CreateCondition(json)
|
||||
if err == nil {
|
||||
t.Error("CreateCondition should fail for missing ref name")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ConditionSet tests ---
|
||||
func TestConditionSet_Matches_AllMatch(t *testing.T) {
|
||||
json := gjson.Parse(`{"conditions":[{"type":"equals","value1":{"type":"request_header","name":"x-test"},"value2":"abc"},{"type":"prefix","value":{"type":"request_query","name":"foo"},"prefix":"bar"}]}`)
|
||||
var set ConditionSet
|
||||
if err := set.FromJson(json); err != nil {
|
||||
t.Fatalf("FromJson failed: %v", err)
|
||||
}
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRequestHeaders(map[string][]string{"x-test": {"abc"}})
|
||||
ctx.SetRequestQueries(map[string][]string{"foo": {"barbaz"}})
|
||||
if !set.Matches(ctx) {
|
||||
t.Error("ConditionSet should match when all conditions match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConditionSet_Matches_OneNoMatch(t *testing.T) {
|
||||
json := gjson.Parse(`{"conditions":[{"type":"equals","value1":{"type":"request_header","name":"x-test"},"value2":"abc"},{"type":"prefix","value":{"type":"request_query","name":"foo"},"prefix":"bar"}]}`)
|
||||
var set ConditionSet
|
||||
if err := set.FromJson(json); err != nil {
|
||||
t.Fatalf("FromJson failed: %v", err)
|
||||
}
|
||||
ctx := NewEditorContext()
|
||||
ctx.SetRequestHeaders(map[string][]string{"x-test": {"abc"}})
|
||||
ctx.SetRequestQueries(map[string][]string{"foo": {"baz"}})
|
||||
if set.Matches(ctx) {
|
||||
t.Error("ConditionSet should not match if one condition does not match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConditionSet_Matches_Empty(t *testing.T) {
|
||||
json := gjson.Parse(`{"conditions":[]}`)
|
||||
var set ConditionSet
|
||||
if err := set.FromJson(json); err != nil {
|
||||
t.Fatalf("FromJson failed: %v", err)
|
||||
}
|
||||
ctx := NewEditorContext()
|
||||
if !set.Matches(ctx) {
|
||||
t.Error("ConditionSet with no conditions should always match")
|
||||
}
|
||||
}
|
||||
|
||||
// --- GetType/GetRefs coverage ---
|
||||
func TestCondition_GetTypeAndRefs(t *testing.T) {
|
||||
json := gjson.Parse(`{"type":"equals","value1":{"type":"request_header","name":"x-test"},"value2":"abc"}`)
|
||||
cond, err := CreateCondition(json)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCondition failed: %v", err)
|
||||
}
|
||||
if cond.GetType() != "equals" {
|
||||
t.Error("GetType should return 'equals'")
|
||||
}
|
||||
refs := cond.GetRefs()
|
||||
if len(refs) != 1 || refs[0].Type != "request_header" || refs[0].Name != "x-test" {
|
||||
t.Error("GetRefs should return correct ref")
|
||||
}
|
||||
}
|
||||
310
plugins/wasm-go/extensions/traffic-editor/pkg/context.go
Normal file
310
plugins/wasm-go/extensions/traffic-editor/pkg/context.go
Normal file
@@ -0,0 +1,310 @@
|
||||
package pkg
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
)
|
||||
|
||||
type Stage int
|
||||
|
||||
const (
|
||||
StageInvalid Stage = iota
|
||||
StageRequestHeaders
|
||||
StageRequestBody
|
||||
StageResponseHeaders
|
||||
StageResponseBody
|
||||
|
||||
pathHeader = ":path"
|
||||
)
|
||||
|
||||
var (
|
||||
OrderedStages = []Stage{
|
||||
StageRequestHeaders,
|
||||
StageRequestBody,
|
||||
StageResponseHeaders,
|
||||
StageResponseBody,
|
||||
}
|
||||
Stage2String = map[Stage]string{
|
||||
StageRequestHeaders: "request_headers",
|
||||
StageRequestBody: "request_body",
|
||||
StageResponseHeaders: "response_headers",
|
||||
StageResponseBody: "response_body",
|
||||
}
|
||||
)
|
||||
|
||||
type EditorContext interface {
|
||||
GetEffectiveCommandSet() *CommandSet
|
||||
SetEffectiveCommandSet(cmdSet *CommandSet)
|
||||
GetCommandExecutors() []Executor
|
||||
SetCommandExecutors(executors []Executor)
|
||||
GetCurrentStage() Stage
|
||||
SetCurrentStage(stage Stage)
|
||||
|
||||
GetRequestPath() string
|
||||
SetRequestPath(path string)
|
||||
GetRequestHeader(key string) []string
|
||||
GetRequestHeaders() map[string][]string
|
||||
SetRequestHeaders(map[string][]string)
|
||||
GetRequestQuery(key string) []string
|
||||
GetRequestQueries() map[string][]string
|
||||
SetRequestQueries(map[string][]string)
|
||||
GetResponseHeader(key string) []string
|
||||
GetResponseHeaders() map[string][]string
|
||||
SetResponseHeaders(map[string][]string)
|
||||
|
||||
GetRefValue(ref *Ref) string
|
||||
GetRefValues(ref *Ref) []string
|
||||
SetRefValue(ref *Ref, value string)
|
||||
SetRefValues(ref *Ref, values []string)
|
||||
DeleteRefValues(ref *Ref)
|
||||
|
||||
IsRequestHeadersDirty() bool
|
||||
IsResponseHeadersDirty() bool
|
||||
ResetDirtyFlags()
|
||||
}
|
||||
|
||||
func NewEditorContext() EditorContext {
|
||||
return &editorContext{}
|
||||
}
|
||||
|
||||
type editorContext struct {
|
||||
effectiveCommandSet *CommandSet
|
||||
commandExecutors []Executor
|
||||
|
||||
currentStage Stage
|
||||
|
||||
requestPath string
|
||||
requestHeaders map[string][]string
|
||||
requestQueries map[string][]string
|
||||
responseHeaders map[string][]string
|
||||
|
||||
requestHeadersDirty bool
|
||||
responseHeadersDirty bool
|
||||
}
|
||||
|
||||
func (ctx *editorContext) GetEffectiveCommandSet() *CommandSet {
|
||||
return ctx.effectiveCommandSet
|
||||
}
|
||||
|
||||
func (ctx *editorContext) SetEffectiveCommandSet(cmdSet *CommandSet) {
|
||||
ctx.effectiveCommandSet = cmdSet
|
||||
}
|
||||
|
||||
func (ctx *editorContext) GetCommandExecutors() []Executor {
|
||||
return ctx.commandExecutors
|
||||
}
|
||||
|
||||
func (ctx *editorContext) SetCommandExecutors(executors []Executor) {
|
||||
ctx.commandExecutors = executors
|
||||
}
|
||||
|
||||
func (ctx *editorContext) GetCurrentStage() Stage {
|
||||
return ctx.currentStage
|
||||
}
|
||||
|
||||
func (ctx *editorContext) SetCurrentStage(stage Stage) {
|
||||
ctx.currentStage = stage
|
||||
}
|
||||
|
||||
func (ctx *editorContext) GetRequestPath() string {
|
||||
return ctx.requestPath
|
||||
}
|
||||
|
||||
func (ctx *editorContext) SetRequestPath(path string) {
|
||||
ctx.requestPath = path
|
||||
ctx.savePathToHeader()
|
||||
}
|
||||
|
||||
func (ctx *editorContext) GetRequestHeader(key string) []string {
|
||||
if ctx.requestHeaders == nil {
|
||||
return nil
|
||||
}
|
||||
return ctx.requestHeaders[strings.ToLower(key)]
|
||||
}
|
||||
|
||||
func (ctx *editorContext) GetRequestHeaders() map[string][]string {
|
||||
return maps.Clone(ctx.requestHeaders)
|
||||
}
|
||||
|
||||
func (ctx *editorContext) SetRequestHeaders(headers map[string][]string) {
|
||||
ctx.requestHeaders = headers
|
||||
ctx.loadPathFromHeader()
|
||||
ctx.requestHeadersDirty = true
|
||||
}
|
||||
|
||||
func (ctx *editorContext) GetRequestQuery(key string) []string {
|
||||
if ctx.requestQueries == nil {
|
||||
return nil
|
||||
}
|
||||
return ctx.requestQueries[key]
|
||||
}
|
||||
|
||||
func (ctx *editorContext) GetRequestQueries() map[string][]string {
|
||||
return maps.Clone(ctx.requestQueries)
|
||||
}
|
||||
|
||||
func (ctx *editorContext) SetRequestQueries(queries map[string][]string) {
|
||||
ctx.requestQueries = queries
|
||||
ctx.savePathToHeader()
|
||||
}
|
||||
|
||||
func (ctx *editorContext) GetResponseHeader(key string) []string {
|
||||
if ctx.responseHeaders == nil {
|
||||
return nil
|
||||
}
|
||||
return ctx.responseHeaders[strings.ToLower(key)]
|
||||
}
|
||||
|
||||
func (ctx *editorContext) GetResponseHeaders() map[string][]string {
|
||||
return maps.Clone(ctx.responseHeaders)
|
||||
}
|
||||
|
||||
func (ctx *editorContext) SetResponseHeaders(headers map[string][]string) {
|
||||
ctx.responseHeaders = headers
|
||||
ctx.responseHeadersDirty = true
|
||||
}
|
||||
|
||||
func (ctx *editorContext) GetRefValue(ref *Ref) string {
|
||||
values := ctx.GetRefValues(ref)
|
||||
if len(values) == 0 {
|
||||
return ""
|
||||
}
|
||||
return values[0]
|
||||
}
|
||||
|
||||
func (ctx *editorContext) GetRefValues(ref *Ref) []string {
|
||||
if ref == nil {
|
||||
return nil
|
||||
}
|
||||
switch ref.Type {
|
||||
case RefTypeRequestHeader:
|
||||
return ctx.GetRequestHeader(strings.ToLower(ref.Name))
|
||||
case RefTypeRequestQuery:
|
||||
return ctx.GetRequestQuery(ref.Name)
|
||||
case RefTypeResponseHeader:
|
||||
return ctx.GetResponseHeader(strings.ToLower(ref.Name))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *editorContext) SetRefValue(ref *Ref, value string) {
|
||||
if ref == nil {
|
||||
return
|
||||
}
|
||||
ctx.SetRefValues(ref, []string{value})
|
||||
}
|
||||
|
||||
func (ctx *editorContext) SetRefValues(ref *Ref, values []string) {
|
||||
if ref == nil {
|
||||
return
|
||||
}
|
||||
switch ref.Type {
|
||||
case RefTypeRequestHeader:
|
||||
if ctx.requestHeaders == nil {
|
||||
ctx.requestHeaders = make(map[string][]string)
|
||||
}
|
||||
loweredRefName := strings.ToLower(ref.Name)
|
||||
ctx.requestHeaders[loweredRefName] = values
|
||||
ctx.requestHeadersDirty = true
|
||||
if loweredRefName == pathHeader {
|
||||
ctx.loadPathFromHeader()
|
||||
}
|
||||
break
|
||||
case RefTypeRequestQuery:
|
||||
if ctx.requestQueries == nil {
|
||||
ctx.requestQueries = make(map[string][]string)
|
||||
}
|
||||
ctx.requestQueries[ref.Name] = values
|
||||
ctx.savePathToHeader()
|
||||
break
|
||||
case RefTypeResponseHeader:
|
||||
if ctx.responseHeaders == nil {
|
||||
ctx.responseHeaders = make(map[string][]string)
|
||||
}
|
||||
ctx.responseHeaders[strings.ToLower(ref.Name)] = values
|
||||
ctx.responseHeadersDirty = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *editorContext) DeleteRefValues(ref *Ref) {
|
||||
if ref == nil {
|
||||
return
|
||||
}
|
||||
switch ref.Type {
|
||||
case RefTypeRequestHeader:
|
||||
delete(ctx.requestHeaders, strings.ToLower(ref.Name))
|
||||
ctx.requestHeadersDirty = true
|
||||
break
|
||||
case RefTypeRequestQuery:
|
||||
delete(ctx.requestQueries, ref.Name)
|
||||
ctx.savePathToHeader()
|
||||
break
|
||||
case RefTypeResponseHeader:
|
||||
delete(ctx.responseHeaders, strings.ToLower(ref.Name))
|
||||
ctx.responseHeadersDirty = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *editorContext) IsRequestHeadersDirty() bool {
|
||||
return ctx.requestHeadersDirty
|
||||
}
|
||||
|
||||
func (ctx *editorContext) IsResponseHeadersDirty() bool {
|
||||
return ctx.responseHeadersDirty
|
||||
}
|
||||
|
||||
func (ctx *editorContext) ResetDirtyFlags() {
|
||||
ctx.requestHeadersDirty = false
|
||||
ctx.responseHeadersDirty = false
|
||||
}
|
||||
|
||||
func (ctx *editorContext) savePathToHeader() {
|
||||
u, err := url.Parse(ctx.requestPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to build the new path with query strings: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
query := url.Values{}
|
||||
for k, vs := range ctx.requestQueries {
|
||||
for _, v := range vs {
|
||||
query.Add(k, v)
|
||||
}
|
||||
}
|
||||
u.RawQuery = query.Encode()
|
||||
ctx.SetRefValue(&Ref{Type: RefTypeRequestHeader, Name: pathHeader}, u.String())
|
||||
}
|
||||
|
||||
func (ctx *editorContext) loadPathFromHeader() {
|
||||
paths := ctx.GetRequestHeader(pathHeader)
|
||||
|
||||
if len(paths) == 0 || paths[0] == "" {
|
||||
log.Warn("the request has an empty path")
|
||||
ctx.requestPath = ""
|
||||
ctx.requestQueries = make(map[string][]string)
|
||||
return
|
||||
}
|
||||
|
||||
path := paths[0]
|
||||
queries := make(map[string][]string)
|
||||
|
||||
u, err := url.Parse(path)
|
||||
if err != nil {
|
||||
log.Warnf("unable to parse the request path: %s", path)
|
||||
ctx.requestPath = ""
|
||||
ctx.requestQueries = make(map[string][]string)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.requestPath = u.Path
|
||||
for k, vs := range u.Query() {
|
||||
queries[k] = vs
|
||||
}
|
||||
ctx.requestQueries = queries
|
||||
}
|
||||
218
plugins/wasm-go/extensions/traffic-editor/pkg/context_test.go
Normal file
218
plugins/wasm-go/extensions/traffic-editor/pkg/context_test.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package pkg
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newTestRef(t, name string) *Ref {
|
||||
return &Ref{Type: t, Name: name}
|
||||
}
|
||||
|
||||
func TestEditorContext_CommandSetAndExecutors(t *testing.T) {
|
||||
ctx := NewEditorContext().(*editorContext)
|
||||
cmdSet := &CommandSet{}
|
||||
ctx.SetEffectiveCommandSet(cmdSet)
|
||||
if ctx.GetEffectiveCommandSet() != cmdSet {
|
||||
t.Errorf("EffectiveCommandSet not set/get correctly")
|
||||
}
|
||||
|
||||
executors := []Executor{nil, nil}
|
||||
ctx.SetCommandExecutors(executors)
|
||||
if !reflect.DeepEqual(ctx.GetCommandExecutors(), executors) {
|
||||
t.Errorf("CommandExecutors not set/get correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorContext_Stage(t *testing.T) {
|
||||
ctx := NewEditorContext().(*editorContext)
|
||||
ctx.SetCurrentStage(StageRequestHeaders)
|
||||
if ctx.GetCurrentStage() != StageRequestHeaders {
|
||||
t.Errorf("CurrentStage not set/get correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorContext_RequestPath(t *testing.T) {
|
||||
ctx := NewEditorContext().(*editorContext)
|
||||
ctx.SetRequestPath("/foo/bar")
|
||||
if ctx.GetRequestPath() != "/foo/bar" {
|
||||
t.Errorf("RequestPath not set/get correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorContext_RequestHeaders(t *testing.T) {
|
||||
ctx := NewEditorContext().(*editorContext)
|
||||
headers := map[string][]string{"foo": {"bar"}, "baz": {"qux"}}
|
||||
ctx.SetRequestHeaders(headers)
|
||||
if !reflect.DeepEqual(ctx.GetRequestHeaders(), headers) {
|
||||
t.Errorf("RequestHeaders not set/get correctly")
|
||||
}
|
||||
if !ctx.IsRequestHeadersDirty() {
|
||||
t.Errorf("RequestHeadersDirty not set correctly")
|
||||
}
|
||||
if got := ctx.GetRequestHeader("foo"); !reflect.DeepEqual(got, []string{"bar"}) {
|
||||
t.Errorf("GetRequestHeader failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorContext_RequestQueries(t *testing.T) {
|
||||
ctx := NewEditorContext().(*editorContext)
|
||||
queries := map[string][]string{"foo": {"bar"}, "baz": {"qux"}}
|
||||
ctx.SetRequestQueries(queries)
|
||||
if !reflect.DeepEqual(ctx.GetRequestQueries(), queries) {
|
||||
t.Errorf("RequestQueries not set/get correctly")
|
||||
}
|
||||
if !ctx.IsRequestHeadersDirty() {
|
||||
t.Errorf("RequestHeadersDirty not set correctly")
|
||||
}
|
||||
if got := ctx.GetRequestQuery("foo"); !reflect.DeepEqual(got, []string{"bar"}) {
|
||||
t.Errorf("GetRequestQuery failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorContext_ResponseHeaders(t *testing.T) {
|
||||
ctx := NewEditorContext().(*editorContext)
|
||||
headers := map[string][]string{"foo": {"bar"}, "baz": {"qux"}}
|
||||
ctx.SetResponseHeaders(headers)
|
||||
if !reflect.DeepEqual(ctx.GetResponseHeaders(), headers) {
|
||||
t.Errorf("ResponseHeaders not set/get correctly")
|
||||
}
|
||||
if !ctx.IsResponseHeadersDirty() {
|
||||
t.Errorf("ResponseHeadersDirty not set correctly")
|
||||
}
|
||||
if got := ctx.GetResponseHeader("foo"); !reflect.DeepEqual(got, []string{"bar"}) {
|
||||
t.Errorf("GetResponseHeader failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorContext_RefValueAndValues(t *testing.T) {
|
||||
ctx := NewEditorContext().(*editorContext)
|
||||
rh := newTestRef(RefTypeRequestHeader, "foo")
|
||||
rq := newTestRef(RefTypeRequestQuery, "bar")
|
||||
rh2 := newTestRef(RefTypeResponseHeader, "baz")
|
||||
|
||||
ctx.SetRefValue(rh, "v1")
|
||||
ctx.SetRefValues(rq, []string{"v2", "v3"})
|
||||
ctx.SetRefValues(rh2, []string{"v4"})
|
||||
|
||||
if v := ctx.GetRefValue(rh); v != "v1" {
|
||||
t.Errorf("GetRefValue(RequestHeader) failed: %v", v)
|
||||
}
|
||||
if v := ctx.GetRefValues(rq); !reflect.DeepEqual(v, []string{"v2", "v3"}) {
|
||||
t.Errorf("GetRefValues(RequestQuery) failed: %v", v)
|
||||
}
|
||||
if v := ctx.GetRefValues(rh2); !reflect.DeepEqual(v, []string{"v4"}) {
|
||||
t.Errorf("GetRefValues(ResponseHeader) failed: %v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorContext_DeleteRefValues(t *testing.T) {
|
||||
ctx := NewEditorContext().(*editorContext)
|
||||
rh := newTestRef(RefTypeRequestHeader, "foo")
|
||||
rq := newTestRef(RefTypeRequestQuery, "bar")
|
||||
rh2 := newTestRef(RefTypeResponseHeader, "baz")
|
||||
|
||||
ctx.SetRefValue(rh, "v1")
|
||||
ctx.SetRefValues(rq, []string{"v2", "v3"})
|
||||
ctx.SetRefValues(rh2, []string{"v4"})
|
||||
|
||||
ctx.DeleteRefValues(rh)
|
||||
ctx.DeleteRefValues(rq)
|
||||
ctx.DeleteRefValues(rh2)
|
||||
|
||||
if v := ctx.GetRefValues(rh); len(v) != 0 {
|
||||
t.Errorf("DeleteRefValues(RequestHeader) failed: %v", v)
|
||||
}
|
||||
if v := ctx.GetRefValues(rq); len(v) != 0 {
|
||||
t.Errorf("DeleteRefValues(RequestQuery) failed: %v", v)
|
||||
}
|
||||
if v := ctx.GetRefValues(rh2); len(v) != 0 {
|
||||
t.Errorf("DeleteRefValues(ResponseHeader) failed: %v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorContext_ResetDirtyFlags(t *testing.T) {
|
||||
ctx := NewEditorContext().(*editorContext)
|
||||
ctx.SetRequestHeaders(map[string][]string{"foo": {"bar"}})
|
||||
ctx.SetRequestQueries(map[string][]string{"foo": {"bar"}})
|
||||
ctx.SetResponseHeaders(map[string][]string{"foo": {"bar"}})
|
||||
ctx.ResetDirtyFlags()
|
||||
if ctx.IsRequestHeadersDirty() || ctx.IsRequestHeadersDirty() || ctx.IsResponseHeadersDirty() {
|
||||
t.Errorf("ResetDirtyFlags failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorContext_IsRequestHeadersDirty_SetHeaders(t *testing.T) {
|
||||
ctx := NewEditorContext().(*editorContext)
|
||||
if ctx.IsRequestHeadersDirty() {
|
||||
t.Errorf("RequestHeadersDirty should be false initially")
|
||||
}
|
||||
ctx.SetRequestHeaders(map[string][]string{"foo": {"bar"}})
|
||||
if !ctx.IsRequestHeadersDirty() {
|
||||
t.Errorf("RequestHeadersDirty should be true after SetRequestHeaders")
|
||||
}
|
||||
ctx.ResetDirtyFlags()
|
||||
if ctx.IsRequestHeadersDirty() {
|
||||
t.Errorf("RequestHeadersDirty should be false after ResetDirtyFlags")
|
||||
}
|
||||
ref := newTestRef(RefTypeRequestHeader, "foo")
|
||||
ctx.SetRefValue(ref, "baz")
|
||||
if !ctx.IsRequestHeadersDirty() {
|
||||
t.Errorf("RequestHeadersDirty should be true after SetRefValue")
|
||||
}
|
||||
ctx.ResetDirtyFlags()
|
||||
ctx.DeleteRefValues(ref)
|
||||
if !ctx.IsRequestHeadersDirty() {
|
||||
t.Errorf("RequestHeadersDirty should be true after DeleteRefValues")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorContext_IsRequestHeadersDirty_SetQueries(t *testing.T) {
|
||||
ctx := NewEditorContext().(*editorContext)
|
||||
if ctx.IsRequestHeadersDirty() {
|
||||
t.Errorf("RequestQueriesDirty should be false initially")
|
||||
}
|
||||
ctx.SetRequestQueries(map[string][]string{"foo": {"bar"}})
|
||||
if !ctx.IsRequestHeadersDirty() {
|
||||
t.Errorf("RequestQueriesDirty should be true after SetRequestQueries")
|
||||
}
|
||||
ctx.ResetDirtyFlags()
|
||||
if ctx.IsRequestHeadersDirty() {
|
||||
t.Errorf("RequestQueriesDirty should be false after ResetDirtyFlags")
|
||||
}
|
||||
ref := newTestRef(RefTypeRequestQuery, "foo")
|
||||
ctx.SetRefValues(ref, []string{"baz"})
|
||||
if !ctx.IsRequestHeadersDirty() {
|
||||
t.Errorf("RequestQueriesDirty should be true after SetRefValues")
|
||||
}
|
||||
ctx.ResetDirtyFlags()
|
||||
ctx.DeleteRefValues(ref)
|
||||
if !ctx.IsRequestHeadersDirty() {
|
||||
t.Errorf("RequestQueriesDirty should be true after DeleteRefValues")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorContext_IsResponseHeadersDirty(t *testing.T) {
|
||||
ctx := NewEditorContext().(*editorContext)
|
||||
if ctx.IsResponseHeadersDirty() {
|
||||
t.Errorf("ResponseHeadersDirty should be false initially")
|
||||
}
|
||||
ctx.SetResponseHeaders(map[string][]string{"foo": {"bar"}})
|
||||
if !ctx.IsResponseHeadersDirty() {
|
||||
t.Errorf("ResponseHeadersDirty should be true after SetResponseHeaders")
|
||||
}
|
||||
ctx.ResetDirtyFlags()
|
||||
if ctx.IsResponseHeadersDirty() {
|
||||
t.Errorf("ResponseHeadersDirty should be false after ResetDirtyFlags")
|
||||
}
|
||||
ref := newTestRef(RefTypeResponseHeader, "foo")
|
||||
ctx.SetRefValues(ref, []string{"baz"})
|
||||
if !ctx.IsResponseHeadersDirty() {
|
||||
t.Errorf("ResponseHeadersDirty should be true after SetRefValues")
|
||||
}
|
||||
ctx.ResetDirtyFlags()
|
||||
ctx.DeleteRefValues(ref)
|
||||
if !ctx.IsResponseHeadersDirty() {
|
||||
t.Errorf("ResponseHeadersDirty should be true after DeleteRefValues")
|
||||
}
|
||||
}
|
||||
26
plugins/wasm-go/extensions/traffic-editor/pkg/mock_test.go
Normal file
26
plugins/wasm-go/extensions/traffic-editor/pkg/mock_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package pkg
|
||||
|
||||
import (
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Initialize mock logger for testing
|
||||
log.SetPluginLog(&mockLogger{})
|
||||
}
|
||||
|
||||
type mockLogger struct{}
|
||||
|
||||
func (m *mockLogger) Trace(msg string) {}
|
||||
func (m *mockLogger) Tracef(format string, args ...interface{}) {}
|
||||
func (m *mockLogger) Debug(msg string) {}
|
||||
func (m *mockLogger) Debugf(format string, args ...interface{}) {}
|
||||
func (m *mockLogger) Info(msg string) {}
|
||||
func (m *mockLogger) Infof(format string, args ...interface{}) {}
|
||||
func (m *mockLogger) Warn(msg string) {}
|
||||
func (m *mockLogger) Warnf(format string, args ...interface{}) {}
|
||||
func (m *mockLogger) Error(msg string) {}
|
||||
func (m *mockLogger) Errorf(format string, args ...interface{}) {}
|
||||
func (m *mockLogger) Critical(msg string) {}
|
||||
func (m *mockLogger) Criticalf(format string, args ...interface{}) {}
|
||||
func (m *mockLogger) ResetID(pluginID string) {}
|
||||
64
plugins/wasm-go/extensions/traffic-editor/pkg/ref.go
Normal file
64
plugins/wasm-go/extensions/traffic-editor/pkg/ref.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package pkg
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
RefTypeRequestHeader = "request_header"
|
||||
RefTypeRequestQuery = "request_query"
|
||||
RefTypeResponseHeader = "response_header"
|
||||
)
|
||||
|
||||
var (
|
||||
refType2Stage = map[string]Stage{
|
||||
RefTypeRequestHeader: StageRequestHeaders,
|
||||
RefTypeRequestQuery: StageRequestHeaders,
|
||||
RefTypeResponseHeader: StageResponseHeaders,
|
||||
}
|
||||
)
|
||||
|
||||
type Ref struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
|
||||
stage Stage
|
||||
}
|
||||
|
||||
func NewRef(json gjson.Result) (*Ref, error) {
|
||||
ref := &Ref{}
|
||||
|
||||
if t := json.Get("type").String(); t != "" {
|
||||
ref.Type = t
|
||||
} else {
|
||||
return nil, errors.New("missing type field")
|
||||
}
|
||||
|
||||
if _, ok := refType2Stage[ref.Type]; !ok {
|
||||
return nil, fmt.Errorf("unknown ref type: %s", ref.Type)
|
||||
}
|
||||
|
||||
if name := json.Get("name").String(); name != "" {
|
||||
ref.Name = name
|
||||
} else {
|
||||
return nil, errors.New("missing name field")
|
||||
}
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
func (r *Ref) GetStage() Stage {
|
||||
if r.stage == 0 {
|
||||
if stage, ok := refType2Stage[r.Type]; ok {
|
||||
r.stage = stage
|
||||
}
|
||||
}
|
||||
return r.stage
|
||||
}
|
||||
|
||||
func (r *Ref) String() string {
|
||||
return fmt.Sprintf("%s/%s", r.Type, r.Name)
|
||||
}
|
||||
@@ -9,7 +9,7 @@ replace amap-tools => ../amap-tools
|
||||
require (
|
||||
amap-tools v0.0.0-00010101000000-000000000000
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||
github.com/higress-group/wasm-go v1.0.9-0.20251223122142-eae11e33a500
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9
|
||||
github.com/stretchr/testify v1.9.0
|
||||
quark-search v0.0.0-00010101000000-000000000000
|
||||
)
|
||||
|
||||
@@ -24,6 +24,10 @@ github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.9-0.20251223122142-eae11e33a500 h1:4BKKZ3BreIaIGub88nlvzihTK1uJmZYYoQ7r7Xkgb5Q=
|
||||
github.com/higress-group/wasm-go v1.0.9-0.20251223122142-eae11e33a500/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115083526-76699a1df2c1 h1:+usoX0B1cwECTA2qf73IaLGyCIMVopIMev5cBWGgEZk=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115083526-76699a1df2c1/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 h1:sUuUXZwr50l3W1St7MESlFmxmUAu+QUNNfJXx4P6bas=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
|
||||
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
|
||||
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
|
||||
|
||||
@@ -145,22 +145,43 @@ func (w *watcher) generateServiceEntry(host string) *v1alpha3.ServiceEntry {
|
||||
for _, ep := range strings.Split(w.Domain, common.CommaSeparator) {
|
||||
var endpoint *v1alpha3.WorkloadEntry
|
||||
if w.Type == string(registry.Static) {
|
||||
pair := strings.Split(ep, common.ColonSeparator)
|
||||
if len(pair) != 2 {
|
||||
log.Errorf("invalid endpoint:%s with static type", ep)
|
||||
return nil
|
||||
var ip string
|
||||
var portStr string
|
||||
// Support IPv6 format: [2001:db8::1]:8080
|
||||
if strings.HasPrefix(ep, "[") {
|
||||
// IPv6 format: [IPv6]:port
|
||||
lastBracket := strings.LastIndex(ep, "]")
|
||||
if lastBracket == -1 {
|
||||
log.Errorf("invalid IPv6 endpoint format:%s with static type", ep)
|
||||
return nil
|
||||
}
|
||||
ip = ep[1:lastBracket] // Extract IPv6 address without brackets
|
||||
if lastBracket+1 >= len(ep) || ep[lastBracket+1] != ':' {
|
||||
log.Errorf("invalid IPv6 endpoint format:%s with static type, missing colon after bracket", ep)
|
||||
return nil
|
||||
}
|
||||
portStr = ep[lastBracket+2:] // Extract port after "]:"
|
||||
} else {
|
||||
// IPv4 format: 192.168.1.1:8080
|
||||
pair := strings.Split(ep, common.ColonSeparator)
|
||||
if len(pair) != 2 {
|
||||
log.Errorf("invalid endpoint:%s with static type", ep)
|
||||
return nil
|
||||
}
|
||||
ip = pair[0]
|
||||
portStr = pair[1]
|
||||
}
|
||||
port, err := strconv.ParseUint(pair[1], 10, 32)
|
||||
port, err := strconv.ParseUint(portStr, 10, 32)
|
||||
if err != nil {
|
||||
log.Errorf("invalid port:%s of endpoint:%s", pair[1], ep)
|
||||
log.Errorf("invalid port:%s of endpoint:%s", portStr, ep)
|
||||
return nil
|
||||
}
|
||||
if net.ParseIP(pair[0]) == nil {
|
||||
log.Errorf("invalid ip:%s of endpoint:%s", pair[0], ep)
|
||||
if net.ParseIP(ip) == nil {
|
||||
log.Errorf("invalid ip:%s of endpoint:%s", ip, ep)
|
||||
return nil
|
||||
}
|
||||
endpoint = &v1alpha3.WorkloadEntry{
|
||||
Address: pair[0],
|
||||
Address: ip,
|
||||
Ports: map[string]uint32{protocol: uint32(port)},
|
||||
}
|
||||
} else if w.Type == string(registry.DNS) {
|
||||
@@ -247,3 +268,4 @@ func (w *watcher) getSni(se *v1alpha3.ServiceEntry) string {
|
||||
func (w *watcher) GetRegistryType() string {
|
||||
return w.RegistryConfig.Type
|
||||
}
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ type McpServerRule struct {
|
||||
MatchRoute []string `json:"_match_route_,omitempty"`
|
||||
Server *ServerConfig `json:"server,omitempty"`
|
||||
Tools []*McpTool `json:"tools,omitempty"`
|
||||
AllowTools []string `json:"allowTools,omitempty"`
|
||||
AllowTools []string `json:"allowTools"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
|
||||
@@ -434,7 +434,7 @@ func (w *watcher) processToolConfig(dataId, data string, credentials map[string]
|
||||
rule.Server.SecuritySchemes = toolsDescription.SecuritySchemes
|
||||
}
|
||||
|
||||
var allowTools []string
|
||||
allowTools := []string{}
|
||||
for _, t := range toolsDescription.Tools {
|
||||
convertTool := &provider.McpTool{Name: t.Name, Description: t.Description}
|
||||
|
||||
@@ -813,6 +813,9 @@ func generateServiceEntry(host string, services *model.Service) *v1alpha3.Servic
|
||||
endpoints := make([]*v1alpha3.WorkloadEntry, 0)
|
||||
|
||||
for _, service := range services.Hosts {
|
||||
if !service.Healthy || !service.Enable {
|
||||
continue
|
||||
}
|
||||
protocol := common.HTTP
|
||||
if service.Metadata != nil && service.Metadata["protocol"] != "" {
|
||||
protocol = common.ParseProtocol(service.Metadata["protocol"])
|
||||
|
||||
@@ -17,6 +17,7 @@ package v2
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -556,10 +557,19 @@ func (w *watcher) generateServiceEntry(host string, services []model.Instance) *
|
||||
if !isValidIP(service.Ip) {
|
||||
isDnsService = true
|
||||
}
|
||||
// Calculate weight from Nacos instance
|
||||
// Nacos weight is float64, need to convert to uint32 for Istio
|
||||
// Use math.Round to preserve fractional weights (e.g., 0.5, 1.5)
|
||||
// If weight is 0 or negative, use default weight 1
|
||||
weight := uint32(1)
|
||||
if service.Weight > 0 {
|
||||
weight = uint32(math.Round(service.Weight))
|
||||
}
|
||||
endpoint := &v1alpha3.WorkloadEntry{
|
||||
Address: service.Ip,
|
||||
Ports: map[string]uint32{port.Protocol: port.Number},
|
||||
Labels: service.Metadata,
|
||||
Weight: weight,
|
||||
}
|
||||
endpoints = append(endpoints, endpoint)
|
||||
}
|
||||
|
||||
200
registry/nacos/v2/watcher_test.go
Normal file
200
registry/nacos/v2/watcher_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package v2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/nacos-group/nacos-sdk-go/v2/model"
|
||||
"istio.io/api/networking/v1alpha3"
|
||||
)
|
||||
|
||||
func Test_generateServiceEntry_Weight(t *testing.T) {
|
||||
w := &watcher{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
services []model.Instance
|
||||
expectedWeights []uint32
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "normal integer weights",
|
||||
services: []model.Instance{
|
||||
{Ip: "192.168.1.1", Port: 8080, Weight: 5.0},
|
||||
{Ip: "192.168.1.2", Port: 8080, Weight: 3.0},
|
||||
{Ip: "192.168.1.3", Port: 8080, Weight: 2.0},
|
||||
},
|
||||
expectedWeights: []uint32{5, 3, 2},
|
||||
description: "Integer weights should be converted correctly",
|
||||
},
|
||||
{
|
||||
name: "fractional weights with rounding",
|
||||
services: []model.Instance{
|
||||
{Ip: "192.168.1.1", Port: 8080, Weight: 5.4},
|
||||
{Ip: "192.168.1.2", Port: 8080, Weight: 3.5},
|
||||
{Ip: "192.168.1.3", Port: 8080, Weight: 2.6},
|
||||
},
|
||||
expectedWeights: []uint32{5, 4, 3},
|
||||
description: "Fractional weights should be rounded to nearest integer",
|
||||
},
|
||||
{
|
||||
name: "zero weight defaults to 1",
|
||||
services: []model.Instance{
|
||||
{Ip: "192.168.1.1", Port: 8080, Weight: 0.0},
|
||||
{Ip: "192.168.1.2", Port: 8080, Weight: 5.0},
|
||||
},
|
||||
expectedWeights: []uint32{1, 5},
|
||||
description: "Zero weight should default to 1",
|
||||
},
|
||||
{
|
||||
name: "negative weight defaults to 1",
|
||||
services: []model.Instance{
|
||||
{Ip: "192.168.1.1", Port: 8080, Weight: -1.0},
|
||||
{Ip: "192.168.1.2", Port: 8080, Weight: 3.0},
|
||||
},
|
||||
expectedWeights: []uint32{1, 3},
|
||||
description: "Negative weight should default to 1",
|
||||
},
|
||||
{
|
||||
name: "very small fractional weight rounds to 0 then defaults to 1",
|
||||
services: []model.Instance{
|
||||
{Ip: "192.168.1.1", Port: 8080, Weight: 0.4},
|
||||
{Ip: "192.168.1.2", Port: 8080, Weight: 0.5},
|
||||
{Ip: "192.168.1.3", Port: 8080, Weight: 0.6},
|
||||
},
|
||||
expectedWeights: []uint32{1, 1, 1},
|
||||
description: "Weights less than 0.5 round to 0, then default to 1; 0.5 and above round to 1",
|
||||
},
|
||||
{
|
||||
name: "large weights",
|
||||
services: []model.Instance{
|
||||
{Ip: "192.168.1.1", Port: 8080, Weight: 100.0},
|
||||
{Ip: "192.168.1.2", Port: 8080, Weight: 50.5},
|
||||
},
|
||||
expectedWeights: []uint32{100, 51},
|
||||
description: "Large weights should be handled correctly",
|
||||
},
|
||||
{
|
||||
name: "mixed weights",
|
||||
services: []model.Instance{
|
||||
{Ip: "192.168.1.1", Port: 8080, Weight: 0.0},
|
||||
{Ip: "192.168.1.2", Port: 8080, Weight: 1.5},
|
||||
{Ip: "192.168.1.3", Port: 8080, Weight: -5.0},
|
||||
{Ip: "192.168.1.4", Port: 8080, Weight: 10.7},
|
||||
},
|
||||
expectedWeights: []uint32{1, 2, 1, 11},
|
||||
description: "Mixed zero, negative, and fractional weights",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
se := w.generateServiceEntry("test-host", tc.services)
|
||||
|
||||
if se == nil {
|
||||
t.Fatal("generateServiceEntry returned nil")
|
||||
}
|
||||
|
||||
if len(se.Endpoints) != len(tc.expectedWeights) {
|
||||
t.Fatalf("expected %d endpoints, got %d", len(tc.expectedWeights), len(se.Endpoints))
|
||||
}
|
||||
|
||||
for i, endpoint := range se.Endpoints {
|
||||
if endpoint.Weight != tc.expectedWeights[i] {
|
||||
t.Errorf("endpoint[%d]: expected weight %d, got %d (original weight: %f) - %s",
|
||||
i, tc.expectedWeights[i], endpoint.Weight, tc.services[i].Weight, tc.description)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_generateServiceEntry_WeightFieldSet(t *testing.T) {
|
||||
w := &watcher{}
|
||||
|
||||
services := []model.Instance{
|
||||
{Ip: "192.168.1.1", Port: 8080, Weight: 5.0, Metadata: map[string]string{"zone": "a"}},
|
||||
}
|
||||
|
||||
se := w.generateServiceEntry("test-host", services)
|
||||
|
||||
if se == nil {
|
||||
t.Fatal("generateServiceEntry returned nil")
|
||||
}
|
||||
|
||||
if len(se.Endpoints) != 1 {
|
||||
t.Fatalf("expected 1 endpoint, got %d", len(se.Endpoints))
|
||||
}
|
||||
|
||||
endpoint := se.Endpoints[0]
|
||||
|
||||
// Verify all fields are set correctly
|
||||
if endpoint.Address != "192.168.1.1" {
|
||||
t.Errorf("expected address 192.168.1.1, got %s", endpoint.Address)
|
||||
}
|
||||
|
||||
if endpoint.Weight != 5 {
|
||||
t.Errorf("expected weight 5, got %d", endpoint.Weight)
|
||||
}
|
||||
|
||||
if endpoint.Labels == nil || endpoint.Labels["zone"] != "a" {
|
||||
t.Errorf("expected labels with zone=a, got %v", endpoint.Labels)
|
||||
}
|
||||
|
||||
if endpoint.Ports == nil {
|
||||
t.Error("expected ports to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_generateServiceEntry_EmptyServices(t *testing.T) {
|
||||
w := &watcher{}
|
||||
|
||||
se := w.generateServiceEntry("test-host", []model.Instance{})
|
||||
|
||||
if se == nil {
|
||||
t.Fatal("generateServiceEntry returned nil")
|
||||
}
|
||||
|
||||
if len(se.Endpoints) != 0 {
|
||||
t.Errorf("expected 0 endpoints for empty services, got %d", len(se.Endpoints))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_generateServiceEntry_DNSResolution(t *testing.T) {
|
||||
w := &watcher{}
|
||||
|
||||
services := []model.Instance{
|
||||
{Ip: "example.com", Port: 8080, Weight: 5.0},
|
||||
}
|
||||
|
||||
se := w.generateServiceEntry("test-host", services)
|
||||
|
||||
if se == nil {
|
||||
t.Fatal("generateServiceEntry returned nil")
|
||||
}
|
||||
|
||||
if se.Resolution != v1alpha3.ServiceEntry_DNS {
|
||||
t.Errorf("expected DNS resolution for domain name, got %v", se.Resolution)
|
||||
}
|
||||
|
||||
if len(se.Endpoints) != 1 {
|
||||
t.Fatalf("expected 1 endpoint, got %d", len(se.Endpoints))
|
||||
}
|
||||
|
||||
if se.Endpoints[0].Weight != 5 {
|
||||
t.Errorf("expected weight 5 for DNS endpoint, got %d", se.Endpoints[0].Weight)
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,7 @@
|
||||
package nacos
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -352,10 +353,19 @@ func (w *watcher) generateServiceEntry(host string, services []model.SubscribeSe
|
||||
portList = append(portList, port)
|
||||
}
|
||||
}
|
||||
// Calculate weight from Nacos instance
|
||||
// Nacos weight is float64, need to convert to uint32 for Istio
|
||||
// Use math.Round to preserve fractional weights (e.g., 0.5, 1.5)
|
||||
// If weight is 0 or negative, use default weight 1
|
||||
weight := uint32(1)
|
||||
if service.Weight > 0 {
|
||||
weight = uint32(math.Round(service.Weight))
|
||||
}
|
||||
endpoint := v1alpha3.WorkloadEntry{
|
||||
Address: service.Ip,
|
||||
Ports: map[string]uint32{port.Protocol: port.Number},
|
||||
Labels: service.Metadata,
|
||||
Weight: weight,
|
||||
}
|
||||
endpoints = append(endpoints, &endpoint)
|
||||
}
|
||||
|
||||
173
registry/nacos/watcher_test.go
Normal file
173
registry/nacos/watcher_test.go
Normal file
@@ -0,0 +1,173 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package nacos
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/nacos-group/nacos-sdk-go/model"
|
||||
)
|
||||
|
||||
func Test_generateServiceEntry_Weight(t *testing.T) {
|
||||
w := &watcher{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
services []model.SubscribeService
|
||||
expectedWeights []uint32
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "normal integer weights",
|
||||
services: []model.SubscribeService{
|
||||
{Ip: "192.168.1.1", Port: 8080, Weight: 5.0},
|
||||
{Ip: "192.168.1.2", Port: 8080, Weight: 3.0},
|
||||
{Ip: "192.168.1.3", Port: 8080, Weight: 2.0},
|
||||
},
|
||||
expectedWeights: []uint32{5, 3, 2},
|
||||
description: "Integer weights should be converted correctly",
|
||||
},
|
||||
{
|
||||
name: "fractional weights with rounding",
|
||||
services: []model.SubscribeService{
|
||||
{Ip: "192.168.1.1", Port: 8080, Weight: 5.4},
|
||||
{Ip: "192.168.1.2", Port: 8080, Weight: 3.5},
|
||||
{Ip: "192.168.1.3", Port: 8080, Weight: 2.6},
|
||||
},
|
||||
expectedWeights: []uint32{5, 4, 3},
|
||||
description: "Fractional weights should be rounded to nearest integer",
|
||||
},
|
||||
{
|
||||
name: "zero weight defaults to 1",
|
||||
services: []model.SubscribeService{
|
||||
{Ip: "192.168.1.1", Port: 8080, Weight: 0.0},
|
||||
{Ip: "192.168.1.2", Port: 8080, Weight: 5.0},
|
||||
},
|
||||
expectedWeights: []uint32{1, 5},
|
||||
description: "Zero weight should default to 1",
|
||||
},
|
||||
{
|
||||
name: "negative weight defaults to 1",
|
||||
services: []model.SubscribeService{
|
||||
{Ip: "192.168.1.1", Port: 8080, Weight: -1.0},
|
||||
{Ip: "192.168.1.2", Port: 8080, Weight: 3.0},
|
||||
},
|
||||
expectedWeights: []uint32{1, 3},
|
||||
description: "Negative weight should default to 1",
|
||||
},
|
||||
{
|
||||
name: "very small fractional weight rounds to 0 then defaults to 1",
|
||||
services: []model.SubscribeService{
|
||||
{Ip: "192.168.1.1", Port: 8080, Weight: 0.4},
|
||||
{Ip: "192.168.1.2", Port: 8080, Weight: 0.5},
|
||||
{Ip: "192.168.1.3", Port: 8080, Weight: 0.6},
|
||||
},
|
||||
expectedWeights: []uint32{1, 1, 1},
|
||||
description: "Weights less than 0.5 round to 0, then default to 1; 0.5 and above round to 1",
|
||||
},
|
||||
{
|
||||
name: "large weights",
|
||||
services: []model.SubscribeService{
|
||||
{Ip: "192.168.1.1", Port: 8080, Weight: 100.0},
|
||||
{Ip: "192.168.1.2", Port: 8080, Weight: 50.5},
|
||||
},
|
||||
expectedWeights: []uint32{100, 51},
|
||||
description: "Large weights should be handled correctly",
|
||||
},
|
||||
{
|
||||
name: "mixed weights",
|
||||
services: []model.SubscribeService{
|
||||
{Ip: "192.168.1.1", Port: 8080, Weight: 0.0},
|
||||
{Ip: "192.168.1.2", Port: 8080, Weight: 1.5},
|
||||
{Ip: "192.168.1.3", Port: 8080, Weight: -5.0},
|
||||
{Ip: "192.168.1.4", Port: 8080, Weight: 10.7},
|
||||
},
|
||||
expectedWeights: []uint32{1, 2, 1, 11},
|
||||
description: "Mixed zero, negative, and fractional weights",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
se := w.generateServiceEntry("test-host", tc.services)
|
||||
|
||||
if se == nil {
|
||||
t.Fatal("generateServiceEntry returned nil")
|
||||
}
|
||||
|
||||
if len(se.Endpoints) != len(tc.expectedWeights) {
|
||||
t.Fatalf("expected %d endpoints, got %d", len(tc.expectedWeights), len(se.Endpoints))
|
||||
}
|
||||
|
||||
for i, endpoint := range se.Endpoints {
|
||||
if endpoint.Weight != tc.expectedWeights[i] {
|
||||
t.Errorf("endpoint[%d]: expected weight %d, got %d (original weight: %f) - %s",
|
||||
i, tc.expectedWeights[i], endpoint.Weight, tc.services[i].Weight, tc.description)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_generateServiceEntry_WeightFieldSet(t *testing.T) {
|
||||
w := &watcher{}
|
||||
|
||||
services := []model.SubscribeService{
|
||||
{Ip: "192.168.1.1", Port: 8080, Weight: 5.0, Metadata: map[string]string{"zone": "a"}},
|
||||
}
|
||||
|
||||
se := w.generateServiceEntry("test-host", services)
|
||||
|
||||
if se == nil {
|
||||
t.Fatal("generateServiceEntry returned nil")
|
||||
}
|
||||
|
||||
if len(se.Endpoints) != 1 {
|
||||
t.Fatalf("expected 1 endpoint, got %d", len(se.Endpoints))
|
||||
}
|
||||
|
||||
endpoint := se.Endpoints[0]
|
||||
|
||||
// Verify all fields are set correctly
|
||||
if endpoint.Address != "192.168.1.1" {
|
||||
t.Errorf("expected address 192.168.1.1, got %s", endpoint.Address)
|
||||
}
|
||||
|
||||
if endpoint.Weight != 5 {
|
||||
t.Errorf("expected weight 5, got %d", endpoint.Weight)
|
||||
}
|
||||
|
||||
if endpoint.Labels == nil || endpoint.Labels["zone"] != "a" {
|
||||
t.Errorf("expected labels with zone=a, got %v", endpoint.Labels)
|
||||
}
|
||||
|
||||
if endpoint.Ports == nil {
|
||||
t.Error("expected ports to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_generateServiceEntry_EmptyServices(t *testing.T) {
|
||||
w := &watcher{}
|
||||
|
||||
se := w.generateServiceEntry("test-host", []model.SubscribeService{})
|
||||
|
||||
if se == nil {
|
||||
t.Fatal("generateServiceEntry returned nil")
|
||||
}
|
||||
|
||||
if len(se.Endpoints) != 0 {
|
||||
t.Errorf("expected 0 endpoints for empty services, got %d", len(se.Endpoints))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user