Compare commits

...

24 Commits

Author SHA1 Message Date
zikunchang
f2fcd68ef8 feature: Support getting the API key from the request header when provider.apiTokens is not configured. (#3394)
Co-authored-by: 澄潭 <zty98751@alibaba-inc.com>
2026-01-28 14:03:24 +08:00
rinfx
cbcc3ecf43 bugfix for model-mapper & model-router (#3370) 2026-01-28 10:52:45 +08:00
澄潭
a92c89ce61 fix: remove duplicate loadBalancerClass definition in service.yaml (#3400) 2026-01-27 18:48:39 +08:00
ThxCode-Chen
819f773297 feat: support upstream ipv6 static address (#3384)
Co-authored-by: EricaLiu <30773688+Erica177@users.noreply.github.com>
2026-01-26 17:30:09 +08:00
aias00
255f0bde76 feat: Map Nacos instance weights to Istio WorkloadEntry weights in watchers (#3342)
Co-authored-by: EricaLiu <30773688+Erica177@users.noreply.github.com>
2026-01-23 15:56:58 +08:00
woody
a2eb599eff Implement Vertex Raw mode support in AI Proxy (#3375) 2026-01-21 14:45:06 +08:00
rinfx
3a28a9b6a7 update wasm-go dependency (#3367) 2026-01-20 15:13:59 +08:00
woody
399d2f372e add support for image generation in Vertex AI provider (#3335) 2026-01-19 16:40:29 +08:00
TianHao Zhang
ac69eb5b27 fix concurrent SSE connections returning wrong endpoint (#3341) 2026-01-19 10:22:50 +08:00
johnlanni
9d8a1c2e95 Fix the issue of backend errors not being propagated in streamable proxy mode 2026-01-15 20:36:49 +08:00
johnlanni
fb71d7b33d fix(mcp): remove accept-encoding header to prevent response compression 2026-01-15 16:43:14 +08:00
aias00
eb7b22d2b9 fix: skip unhealthy or disabled services form nacos and always marshal AllowTools field (#3220)
Co-authored-by: EricaLiu <30773688+Erica177@users.noreply.github.com>
2026-01-15 10:46:21 +08:00
woody
f1a5f18c78 feat/ai proxy vertex ai compatible (#3324) 2026-01-14 10:13:00 +08:00
韩贤涛
e7010256fe feat: add authentication wrapper for debug endpoints (#3318) 2026-01-14 09:30:51 +08:00
rinfx
5e787b3258 Replace model-router and model-mapper with Go implementation (#3317) 2026-01-13 20:14:29 +08:00
woody
23fbe0e9e9 feat(vertex): 为 ai-proxy 插件的 Vertex AI Provider 添加 Express Mode 支持 || feat(vertex): Add Express Mode support to Vertex AI Provider of ai-proxy plug-in (#3301) 2026-01-13 20:00:05 +08:00
qshuai
72c87b3e15 docs: unknown config entry <show_limit_quota_header> in ai-token-ratelimit plugin (#3241) 2026-01-10 11:07:43 +08:00
CZJCC
78d4b33424 feat(ai-proxy): add Bearer Token authentication support for Bedrock p… (#3305) 2026-01-07 19:39:20 +08:00
澄潭
b09793c3d4 Update README.md 2026-01-04 10:45:03 +08:00
澄潭
5d7a30783f Update README.md 2026-01-04 09:33:30 +08:00
nixidexiangjiao
b98b51ef06 feat(ai-load-balancer): enhance global least request load balancer (#3255) 2025-12-29 09:28:56 +08:00
johnlanni
9c11c5406f update helm README.md
Change-Id: Ic216d36c4cb0e570c9084b63c9f250c9ab6f4cec
2025-12-26 17:35:49 +08:00
Wilson Wu
10ca6d9515 feat: add topology spread constraints for gateway and controller (#3171)
Signed-off-by: Wilson Wu <iwilsonwu@gmail.com>
2025-12-26 17:30:31 +08:00
Kent Dong
08a7204085 feat: Add traffic-editor plugin (#2825) 2025-12-26 17:29:55 +08:00
73 changed files with 8800 additions and 147 deletions

View File

@@ -45,7 +45,7 @@ Higress was born within Alibaba to solve the issues of Tengine reload affecting
You can click the button below to install the enterprise version of Higress:
[![Deploy on AlibabaCloud](https://img.alicdn.com/imgextra/i1/O1CN01e6vwe71EWTHoZEcpK_!!6000000000359-55-tps-170-40.svg)](https://www.aliyun.com/product/apigateway?spm=higress-github.topbar.0.0.0)
[![Deploy on AlibabaCloud](https://img.alicdn.com/imgextra/i1/O1CN01e6vwe71EWTHoZEcpK_!!6000000000359-55-tps-170-40.svg)](https://www.aliyun.com/product/api-gateway?spm=higress-github.topbar.0.0.0)
If you use open-source Higress and wish to obtain enterprise-level support, you can contact the project maintainer johnlanni's email: **zty98751@alibaba-inc.com** or social media accounts (WeChat ID: **nomadao**, DingTalk ID: **chengtanzty**). Please note **Higress** when adding as a friend :)
@@ -119,7 +119,16 @@ If you are deploying on the cloud, it is recommended to use the [Enterprise Edit
Higress can function as a feature-rich ingress controller, which is compatible with many annotations of K8s' nginx ingress controller.
[Gateway API](https://gateway-api.sigs.k8s.io/) support is coming soon and will support smooth migration from Ingress API to Gateway API.
[Gateway API](https://gateway-api.sigs.k8s.io/) is already supported, and it supports a smooth migration from Ingress API to Gateway API.
Compared to ingress-nginx, the resource overhead has significantly decreased, and the speed at which route changes take effect has improved by ten times.
> The following resource overhead comparison comes from [sealos](https://github.com/labring).
>
> For details, you can read this [article](https://sealos.io/blog/sealos-envoy-vs-nginx-2000-tenants) to understand how sealos migrates the monitoring of **tens of thousands of ingress** resources from nginx ingress to higress.
![](https://img.alicdn.com/imgextra/i1/O1CN01bhEtb229eeMNBWmdP_!!6000000008093-2-tps-750-547.png)
- **Microservice gateway**:

View File

@@ -250,6 +250,10 @@ template:
tolerations:
{{- toYaml . | nindent 6 }}
{{- end }}
{{- with .Values.gateway.topologySpreadConstraints }}
topologySpreadConstraints:
{{- toYaml . | nindent 6 }}
{{- end }}
volumes:
- emptyDir: {}
name: workload-socket

View File

@@ -301,6 +301,10 @@ spec:
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.controller.topologySpreadConstraints }}
topologySpreadConstraints:
{{- toYaml . | nindent 8 }}
{{- end }}
volumes:
- name: log
emptyDir: {}

View File

@@ -24,9 +24,6 @@ spec:
{{- end }}
{{- with .Values.gateway.service.externalTrafficPolicy }}
externalTrafficPolicy: "{{ . }}"
{{- end }}
{{- with .Values.gateway.service.loadBalancerClass}}
loadBalancerClass: "{{ . }}"
{{- end }}
type: {{ .Values.gateway.service.type }}
ports:

View File

@@ -524,6 +524,8 @@ gateway:
affinity: {}
topologySpreadConstraints: []
# -- If specified, the gateway will act as a network gateway for the given network.
networkGateway: ""
@@ -631,6 +633,8 @@ controller:
affinity: {}
topologySpreadConstraints: []
autoscaling:
enabled: false
minReplicas: 1

View File

@@ -83,6 +83,7 @@ The command removes all the Kubernetes components associated with the chart and
| controller.serviceAccount.name | string | `""` | If not set and create is true, a name is generated using the fullname template |
| controller.tag | string | `""` | |
| controller.tolerations | list | `[]` | |
| controller.topologySpreadConstraints | list | `[]` | |
| downstream | object | `{"connectionBufferLimits":32768,"http2":{"initialConnectionWindowSize":1048576,"initialStreamWindowSize":65535,"maxConcurrentStreams":100},"idleTimeout":180,"maxRequestHeadersKb":60,"routeTimeout":0}` | Downstream config settings |
| gateway.affinity | object | `{}` | |
| gateway.annotations | object | `{}` | Annotations to apply to all resources |
@@ -152,6 +153,7 @@ The command removes all the Kubernetes components associated with the chart and
| gateway.serviceAccount.name | string | `""` | The name of the service account to use. If not set, the release name is used |
| gateway.tag | string | `""` | |
| gateway.tolerations | list | `[]` | |
| gateway.topologySpreadConstraints | list | `[]` | |
| gateway.unprivilegedPortSupported | string | `nil` | |
| global.autoscalingv2API | bool | `true` | whether to use autoscaling/v2 template for HPA settings for internal usage only, not to be configured by users. |
| global.caAddress | string | `""` | The customized CA address to retrieve certificates for the pods in the cluster. CSR clients such as the Istio Agent and ingress gateways can use this to specify the CA endpoint. If not set explicitly, default to the Istio discovery address. |

View File

@@ -16,12 +16,13 @@ package bootstrap
import (
"fmt"
"istio.io/istio/pkg/config/mesh/meshwatcher"
"istio.io/istio/pkg/kube/krt"
"net"
"net/http"
"time"
"istio.io/istio/pkg/config/mesh/meshwatcher"
"istio.io/istio/pkg/kube/krt"
prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
@@ -436,10 +437,17 @@ func (s *Server) initHttpServer() error {
}
s.xdsServer.AddDebugHandlers(s.httpMux, nil, true, nil)
s.httpMux.HandleFunc("/ready", s.readyHandler)
s.httpMux.HandleFunc("/registry/watcherStatus", s.registryWatcherStatusHandler)
s.httpMux.HandleFunc("/registry/watcherStatus", s.withConditionalAuth(s.registryWatcherStatusHandler))
return nil
}
func (s *Server) withConditionalAuth(handler http.HandlerFunc) http.HandlerFunc {
if features.DebugAuth {
return s.xdsServer.AllowAuthenticatedOrLocalhost(handler)
}
return handler
}
// readyHandler checks whether the http server is ready
func (s *Server) readyHandler(w http.ResponseWriter, _ *http.Request) {
for name, fn := range s.readinessProbes {

View File

@@ -26,8 +26,8 @@ type config struct {
matchList []common.MatchRule
enableUserLevelServer bool
rateLimitConfig *handler.MCPRatelimitConfig
defaultServer *common.SSEServer
redisClient *common.RedisClient
sharedMCPServer *common.MCPServer // Created once, thread-safe with sync.RWMutex
}
func (c *config) Destroy() {
@@ -110,6 +110,9 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
}
GlobalSSEPathSuffix = ssePathSuffix
// Create shared MCPServer once during config parsing (thread-safe with sync.RWMutex)
conf.sharedMCPServer = common.NewMCPServer(DefaultServerName, Version)
return conf, nil
}
@@ -125,9 +128,6 @@ func (p *Parser) Merge(parent interface{}, child interface{}) interface{} {
if childConfig.rateLimitConfig != nil {
newConfig.rateLimitConfig = childConfig.rateLimitConfig
}
if childConfig.defaultServer != nil {
newConfig.defaultServer = childConfig.defaultServer
}
return &newConfig
}

View File

@@ -37,6 +37,7 @@ type filter struct {
skipRequestBody bool
skipResponseBody bool
cachedResponseBody []byte
sseServer *common.SSEServer // SSE server instance for this filter (per-request, not shared)
userLevelConfig bool
mcpConfigHandler *handler.MCPConfigHandler
@@ -135,11 +136,13 @@ func (f *filter) processMcpRequestHeadersForRestUpstream(header api.RequestHeade
trimmed += "?" + rq
}
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
// Create SSE server instance for this filter (per-request, not shared)
// MCPServer is shared (thread-safe), but SSEServer must be per-request (contains request-specific messageEndpoint)
f.sseServer = common.NewSSEServer(f.config.sharedMCPServer,
common.WithSSEEndpoint(GlobalSSEPathSuffix),
common.WithMessageEndpoint(trimmed),
common.WithRedisClient(f.config.redisClient))
f.serverName = f.config.defaultServer.GetServerName()
f.serverName = f.sseServer.GetServerName()
body := "SSE connection create"
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
}
@@ -275,9 +278,9 @@ func (f *filter) encodeDataFromRestUpstream(buffer api.BufferInstance, endStream
if f.serverName != "" {
if f.config.redisClient != nil {
// handle default server
// handle SSE server for this filter instance
buffer.Reset()
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
f.sseServer.HandleSSE(f.callbacks, f.stopChan)
return api.Running
} else {
_ = buffer.SetString(RedisNotEnabledResponseBody)

View File

@@ -16,40 +16,91 @@ import (
)
const (
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
RedisLua = `local seed = KEYS[1]
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
RedisLastCleanKeyFormat = "higress:global_least_request_table:last_clean_time:%s:%s"
RedisLua = `local seed = tonumber(KEYS[1])
local hset_key = KEYS[2]
local current_target = KEYS[3]
local current_count = 0
local last_clean_key = KEYS[3]
local clean_interval = tonumber(KEYS[4])
local current_target = KEYS[5]
local healthy_count = tonumber(KEYS[6])
local enable_detail_log = KEYS[7]
math.randomseed(seed)
local function randomBool()
return math.random() >= 0.5
end
-- 1. Selection
local current_count = 0
local same_count_hits = 0
if redis.call('HEXISTS', hset_key, current_target) == 1 then
current_count = redis.call('HGET', hset_key, current_target)
for i = 4, #KEYS do
if redis.call('HEXISTS', hset_key, KEYS[i]) == 1 then
local count = redis.call('HGET', hset_key, KEYS[i])
if tonumber(count) < tonumber(current_count) then
current_target = KEYS[i]
current_count = count
elseif count == current_count and randomBool() then
current_target = KEYS[i]
end
end
end
for i = 8, 8 + healthy_count - 1 do
local host = KEYS[i]
local count = 0
local val = redis.call('HGET', hset_key, host)
if val then
count = tonumber(val) or 0
end
if same_count_hits == 0 or count < current_count then
current_target = host
current_count = count
same_count_hits = 1
elseif count == current_count then
same_count_hits = same_count_hits + 1
if math.random(same_count_hits) == 1 then
current_target = host
end
end
end
redis.call("HINCRBY", hset_key, current_target, 1)
local new_count = redis.call("HGET", hset_key, current_target)
return current_target`
-- Collect host counts for logging
local host_details = {}
if enable_detail_log == "1" then
local fields = {}
for i = 8, #KEYS do
table.insert(fields, KEYS[i])
end
if #fields > 0 then
local values = redis.call('HMGET', hset_key, (table.unpack or unpack)(fields))
for i, val in ipairs(values) do
table.insert(host_details, fields[i])
table.insert(host_details, tostring(val or 0))
end
end
end
-- 2. Cleanup
local current_time = math.floor(seed / 1000000)
local last_clean_time = tonumber(redis.call('GET', last_clean_key) or 0)
if current_time - last_clean_time >= clean_interval then
local all_keys = redis.call('HKEYS', hset_key)
if #all_keys > 0 then
-- Create a lookup table for current hosts (from index 8 onwards)
local current_hosts = {}
for i = 8, #KEYS do
current_hosts[KEYS[i]] = true
end
-- Remove keys not in current hosts
for _, host in ipairs(all_keys) do
if not current_hosts[host] then
redis.call('HDEL', hset_key, host)
end
end
end
redis.call('SET', last_clean_key, current_time)
end
return {current_target, new_count, host_details}`
)
type GlobalLeastRequestLoadBalancer struct {
redisClient wrapper.RedisClient
redisClient wrapper.RedisClient
maxRequestCount int64
cleanInterval int64 // seconds
enableDetailLog bool
}
func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoadBalancer, error) {
@@ -72,6 +123,18 @@ func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoa
}
// database default is 0
database := json.Get("database").Int()
lb.maxRequestCount = json.Get("maxRequestCount").Int()
lb.cleanInterval = json.Get("cleanInterval").Int()
if lb.cleanInterval == 0 {
lb.cleanInterval = 60 * 60 // default 60 minutes
} else {
lb.cleanInterval = lb.cleanInterval * 60 // convert minutes to seconds
}
lb.enableDetailLog = true
if val := json.Get("enableDetailLog"); val.Exists() {
lb.enableDetailLog = val.Bool()
}
log.Infof("redis client init, serviceFQDN: %s, servicePort: %d, timeout: %d, database: %d, maxRequestCount: %d, cleanInterval: %d minutes, enableDetailLog: %v", serviceFQDN, servicePort, timeout, database, lb.maxRequestCount, lb.cleanInterval/60, lb.enableDetailLog)
return lb, lb.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(int(database)))
}
@@ -100,9 +163,11 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
ctx.SetContext("error", true)
return types.ActionContinue
}
allHostMap := make(map[string]struct{})
// Only healthy host can be selected
healthyHostArray := []string{}
for _, hostInfo := range hostInfos {
allHostMap[hostInfo[0]] = struct{}{}
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
healthyHostArray = append(healthyHostArray, hostInfo[0])
}
@@ -113,10 +178,37 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
}
randomIndex := rand.Intn(len(healthyHostArray))
hostSelected := healthyHostArray[randomIndex]
keys := []interface{}{time.Now().UnixMicro(), fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected}
// KEYS structure: [seed, hset_key, last_clean_key, clean_interval, host_selected, healthy_count, ...healthy_hosts, enableDetailLog, ...unhealthy_hosts]
keys := []interface{}{
time.Now().UnixMicro(),
fmt.Sprintf(RedisKeyFormat, routeName, clusterName),
fmt.Sprintf(RedisLastCleanKeyFormat, routeName, clusterName),
lb.cleanInterval,
hostSelected,
len(healthyHostArray),
"0",
}
if lb.enableDetailLog {
keys[6] = "1"
}
for _, v := range healthyHostArray {
keys = append(keys, v)
}
// Append unhealthy hosts (those in allHostMap but not in healthyHostArray)
for host := range allHostMap {
isHealthy := false
for _, hh := range healthyHostArray {
if host == hh {
isHealthy = true
break
}
}
if !isHealthy {
keys = append(keys, host)
}
}
err = lb.redisClient.Eval(RedisLua, len(keys), keys, []interface{}{}, func(response resp.Value) {
if err := response.Error(); err != nil {
log.Errorf("HGetAll failed: %+v", err)
@@ -124,17 +216,54 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
proxywasm.ResumeHttpRequest()
return
}
hostSelected = response.String()
valArray := response.Array()
if len(valArray) < 2 {
log.Errorf("redis eval lua result format error, expect at least [host, count], got: %+v", valArray)
ctx.SetContext("error", true)
proxywasm.ResumeHttpRequest()
return
}
hostSelected = valArray[0].String()
currentCount := valArray[1].Integer()
// detail log
if lb.enableDetailLog && len(valArray) >= 3 {
detailLogStr := "host and count: "
details := valArray[2].Array()
for i := 0; i+1 < len(details); i += 2 {
h := details[i].String()
c := details[i+1].String()
detailLogStr += fmt.Sprintf("{%s: %s}, ", h, c)
}
log.Debugf("host_selected: %s + 1, %s", hostSelected, detailLogStr)
}
// check rate limit
if !lb.checkRateLimit(hostSelected, int64(currentCount), ctx, routeName, clusterName) {
ctx.SetContext("error", true)
log.Warnf("host_selected: %s, current_count: %d, exceed max request limit %d", hostSelected, currentCount, lb.maxRequestCount)
// return 429
proxywasm.SendHttpResponse(429, [][2]string{}, []byte("Exceeded maximum request limit from ai-load-balancer."), -1)
ctx.DontReadResponseBody()
return
}
if err := proxywasm.SetUpstreamOverrideHost([]byte(hostSelected)); err != nil {
ctx.SetContext("error", true)
log.Errorf("override upstream host failed, fallback to default lb policy, error informations: %+v", err)
proxywasm.ResumeHttpRequest()
return
}
log.Debugf("host_selected: %s", hostSelected)
// finally resume the request
ctx.SetContext("host_selected", hostSelected)
proxywasm.ResumeHttpRequest()
})
if err != nil {
ctx.SetContext("error", true)
log.Errorf("redis eval failed, fallback to default lb policy, error informations: %+v", err)
return types.ActionContinue
}
return types.ActionPause
@@ -161,7 +290,10 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpCo
if host_selected == "" {
log.Errorf("get host_selected failed")
} else {
lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
err := lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
if err != nil {
log.Errorf("host_selected: %s - 1, failed to update count from redis: %v", host_selected, err)
}
}
}
}

View File

@@ -0,0 +1,220 @@
-- Mocking Redis environment
local redis_data = {
hset = {},
kv = {}
}
local redis = {
call = function(cmd, ...)
local args = {...}
if cmd == "HGET" then
local key, field = args[1], args[2]
return redis_data.hset[field]
elseif cmd == "HSET" then
local key, field, val = args[1], args[2], args[3]
redis_data.hset[field] = val
elseif cmd == "HINCRBY" then
local key, field, increment = args[1], args[2], args[3]
local val = tonumber(redis_data.hset[field] or 0)
redis_data.hset[field] = tostring(val + increment)
return redis_data.hset[field]
elseif cmd == "HKEYS" then
local keys = {}
for k, _ in pairs(redis_data.hset) do
table.insert(keys, k)
end
return keys
elseif cmd == "HDEL" then
local key, field = args[1], args[2]
redis_data.hset[field] = nil
elseif cmd == "GET" then
return redis_data.kv[args[1]]
elseif cmd == "HMGET" then
local key = args[1]
local res = {}
for i = 2, #args do
table.insert(res, redis_data.hset[args[i]])
end
return res
elseif cmd == "SET" then
redis_data.kv[args[1]] = args[2]
end
end
}
-- The actual logic from lb_policy.go
local function run_lb_logic(KEYS)
local seed = tonumber(KEYS[1])
local hset_key = KEYS[2]
local last_clean_key = KEYS[3]
local clean_interval = tonumber(KEYS[4])
local current_target = KEYS[5]
local healthy_count = tonumber(KEYS[6])
local enable_detail_log = KEYS[7]
math.randomseed(seed)
-- 1. Selection
local current_count = 0
local same_count_hits = 0
for i = 8, 8 + healthy_count - 1 do
local host = KEYS[i]
local count = 0
local val = redis.call('HGET', hset_key, host)
if val then
count = tonumber(val) or 0
end
if same_count_hits == 0 or count < current_count then
current_target = host
current_count = count
same_count_hits = 1
elseif count == current_count then
same_count_hits = same_count_hits + 1
if math.random(same_count_hits) == 1 then
current_target = host
end
end
end
redis.call("HINCRBY", hset_key, current_target, 1)
local new_count = redis.call("HGET", hset_key, current_target)
-- Collect host counts for logging
local host_details = {}
if enable_detail_log == "1" then
local fields = {}
for i = 8, #KEYS do
table.insert(fields, KEYS[i])
end
if #fields > 0 then
local values = redis.call('HMGET', hset_key, (table.unpack or unpack)(fields))
for i, val in ipairs(values) do
table.insert(host_details, fields[i])
table.insert(host_details, tostring(val or 0))
end
end
end
-- 2. Cleanup
local current_time = math.floor(seed / 1000000)
local last_clean_time = tonumber(redis.call('GET', last_clean_key) or 0)
if current_time - last_clean_time >= clean_interval then
local all_keys = redis.call('HKEYS', hset_key)
if #all_keys > 0 then
-- Create a lookup table for current hosts (from index 8 onwards)
local current_hosts = {}
for i = 8, #KEYS do
current_hosts[KEYS[i]] = true
end
-- Remove keys not in current hosts
for _, host in ipairs(all_keys) do
if not current_hosts[host] then
redis.call('HDEL', hset_key, host)
end
end
end
redis.call('SET', last_clean_key, current_time)
end
return {current_target, new_count, host_details}
end
-- --- Test 1: Load Balancing Distribution ---
print("--- Test 1: Load Balancing Distribution ---")
local hosts = {"host1", "host2", "host3", "host4", "host5"}
local iterations = 100000
local results = {}
for _, h in ipairs(hosts) do results[h] = 0 end
-- Reset redis
redis_data.hset = {}
for _, h in ipairs(hosts) do redis_data.hset[h] = "0" end
print(string.format("Running %d iterations with %d hosts (all counts started at 0)...", iterations, #hosts))
for i = 1, iterations do
local initial_host = hosts[math.random(#hosts)]
-- KEYS structure: [seed, hset_key, last_clean_key, clean_interval, host_selected, healthy_count, enable_detail_log, ...healthy_hosts]
local keys = {i * 1000000, "table_key", "clean_key", 3600, initial_host, #hosts, "1"}
for _, h in ipairs(hosts) do table.insert(keys, h) end
local res = run_lb_logic(keys)
local selected = res[1]
results[selected] = results[selected] + 1
end
for _, h in ipairs(hosts) do
local percentage = (results[h] / iterations) * 100
print(string.format("%s: %6d (%.2f%%)", h, results[h], percentage))
end
-- --- Test 2: IP Cleanup Logic ---
print("\n--- Test 2: IP Cleanup Logic ---")
local function test_cleanup()
redis_data.hset = {
["host1"] = "10",
["host2"] = "5",
["old_ip_1"] = "1",
["old_ip_2"] = "1",
}
redis_data.kv["clean_key"] = "1000" -- Last cleaned at 1000s
local current_hosts = {"host1", "host2"}
local current_time_ms = 1000 * 1000000 + 500 * 1000000 -- 1500s (interval is 300s, let's say)
local clean_interval = 300
print("Initial Redis IPs:", table.concat((function() local res={} for k,_ in pairs(redis_data.hset) do table.insert(res, k) end return res end)(), ", "))
-- Run logic (seed is microtime)
local keys = {current_time_ms, "table_key", "clean_key", clean_interval, "host1", #current_hosts, "1"}
for _, h in ipairs(current_hosts) do table.insert(keys, h) end
run_lb_logic(keys)
print("After Cleanup Redis IPs:", table.concat((function() local res={} for k,_ in pairs(redis_data.hset) do table.insert(res, k) end table.sort(res) return res end)(), ", "))
local exists_old1 = redis_data.hset["old_ip_1"] ~= nil
local exists_old2 = redis_data.hset["old_ip_2"] ~= nil
if not exists_old1 and not exists_old2 then
print("Success: Outdated IPs removed.")
else
print("Failure: Outdated IPs still exist.")
end
print("New last_clean_time:", redis_data.kv["clean_key"])
end
test_cleanup()
-- --- Test 3: No Cleanup if Interval Not Reached ---
print("\n--- Test 3: No Cleanup if Interval Not Reached ---")
local function test_no_cleanup()
redis_data.hset = {
["host1"] = "10",
["old_ip_1"] = "1",
}
redis_data.kv["clean_key"] = "1000"
local current_hosts = {"host1"}
local current_time_ms = 1000 * 1000000 + 100 * 1000000 -- 1100s (interval 300s, not reached)
local clean_interval = 300
local keys = {current_time_ms, "table_key", "clean_key", clean_interval, "host1", #current_hosts, "0"}
for _, h in ipairs(current_hosts) do table.insert(keys, h) end
run_lb_logic(keys)
if redis_data.hset["old_ip_1"] then
print("Success: Cleanup not triggered as expected.")
else
print("Failure: Cleanup triggered unexpectedly.")
end
end
test_no_cleanup()

View File

@@ -0,0 +1,24 @@
package global_least_request
import (
"fmt"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
func (lb GlobalLeastRequestLoadBalancer) checkRateLimit(hostSelected string, currentCount int64, ctx wrapper.HttpContext, routeName string, clusterName string) bool {
// 如果没有配置最大请求数,直接通过
if lb.maxRequestCount <= 0 {
return true
}
// 如果当前请求数大于最大请求数,则限流
// 注意Lua脚本已经加了1所以这里比较的是加1后的值
if currentCount > lb.maxRequestCount {
// 恢复 Redis 计数
lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected, -1, nil)
return false
}
return true
}

View File

@@ -26,6 +26,8 @@ description: AI 代理插件配置参考
> 请求路径后缀匹配 `/v1/embeddings` 时,对应文本向量场景,会用 OpenAI 的文本向量协议解析请求 Body再转换为对应 LLM 厂商的文本向量协议
> 请求路径后缀匹配 `/v1/images/generations` 时,对应文生图场景,会用 OpenAI 的图片生成协议解析请求 Body再转换为对应 LLM 厂商的图片生成协议
## 运行属性
插件执行阶段:`默认阶段`
@@ -309,7 +311,9 @@ Dify 所对应的 `type` 为 `dify`。它特有的配置字段如下:
#### Google Vertex AI
Google Vertex AI 所对应的 type 为 vertex。它特有的配置字段如下
Google Vertex AI 所对应的 type 为 vertex。支持两种认证模式
**标准模式**(使用 Service Account
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------|
@@ -320,25 +324,56 @@ Google Vertex AI 所对应的 type 为 vertex。它特有的配置字段如下
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
| `vertexTokenRefreshAhead` | number | 非必填 | - | Vertex access token刷新提前时间(单位秒) |
**Express Mode**(使用 API Key简化配置
Express Mode 是 Vertex AI 推出的简化访问模式,只需 API Key 即可快速开始使用,无需配置 Service Account。详见 [Vertex AI Express Mode 文档](https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview)。
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------|
| `apiTokens` | array of string | 必填 | - | Express Mode 使用的 API Key从 Google Cloud Console 的 API & Services > Credentials 获取 |
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
**OpenAI 兼容模式**(使用 Vertex AI Chat Completions API
Vertex AI 提供了 OpenAI 兼容的 Chat Completions API 端点,可以直接使用 OpenAI 格式的请求和响应,无需进行协议转换。详见 [Vertex AI OpenAI 兼容性文档](https://cloud.google.com/vertex-ai/generative-ai/docs/migrate/openai/overview)。
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------|
| `vertexOpenAICompatible` | boolean | 非必填 | false | 启用 OpenAI 兼容模式。启用后将使用 Vertex AI 的 OpenAI-compatible Chat Completions API |
| `vertexAuthKey` | string | 必填 | - | 用于认证的 Google Service Account JSON Key |
| `vertexRegion` | string | 必填 | - | Google Cloud 区域(如 us-central1, europe-west4 等) |
| `vertexProjectId` | string | 必填 | - | Google Cloud 项目 ID |
| `vertexAuthServiceName` | string | 必填 | - | 用于 OAuth2 认证的服务名称 |
**注意**OpenAI 兼容模式与 Express Mode 互斥,不能同时配置 `apiTokens``vertexOpenAICompatible`
#### AWS Bedrock
AWS Bedrock 所对应的 type 为 bedrock。它特有的配置字段如下
AWS Bedrock 所对应的 type 为 bedrock。它支持两种认证方式
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|---------------------------|--------|------|-----|------------------------------|
| `modelVersion` | string | 非必填 | - | 用于指定 Triton Server 中 model version |
| `tritonDomain` | string | 非必填 | - | Triton Server 部署的指定请求 Domain |
1. **AWS Signature V4 认证**:使用 `awsAccessKey``awsSecretKey` 进行 AWS 标准签名认证
2. **Bearer Token 认证**:使用 `apiTokens` 配置 AWS Bearer Token适用于 IAM Identity Center 等场景)
**注意**:两种认证方式二选一,如果同时配置了 `apiTokens`,将优先使用 Bearer Token 认证方式。
它特有的配置字段如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|---------------------------|---------------|-------------------|-------|---------------------------------------------------|
| `apiTokens` | array of string | 与 ak/sk 二选一 | - | AWS Bearer Token用于 Bearer Token 认证方式 |
| `awsAccessKey` | string | 与 apiTokens 二选一 | - | AWS Access Key用于 AWS Signature V4 认证 |
| `awsSecretKey` | string | 与 apiTokens 二选一 | - | AWS Secret Access Key用于 AWS Signature V4 认证 |
| `awsRegion` | string | 必填 | - | AWS 区域例如us-east-1 |
| `bedrockAdditionalFields` | map | 非必填 | - | Bedrock 额外模型请求参数 |
#### NVIDIA Triton Inference Server
NVIDIA Triton Inference Server 所对应的 type 为 triton。它特有的配置字段如下
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|---------------------------|--------|------|-----|------------------------------|
| `awsAccessKey` | string | 必填 | - | AWS Access Key用于身份认证 |
| `awsSecretKey` | string | 必填 | - | AWS Secret Access Key用于身份认证 |
| `awsRegion` | string | 必填 | - | AWS 区域例如us-east-1 |
| `bedrockAdditionalFields` | map | 非必填 | - | Bedrock 额外模型请求参数 |
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|----------------------|--------|--------|-------|------------------------------------------|
| `tritonModelVersion` | string | 必填 | - | 用于指定 Triton Server 中 model version |
| `tritonDomain` | string | 必填 | - | Triton Server 部署的指定请求 Domain |
## 用法示例
@@ -1947,7 +1982,7 @@ provider:
}
```
### 使用 OpenAI 协议代理 Google Vertex 服务
### 使用 OpenAI 协议代理 Google Vertex 服务(标准模式)
**配置信息**
@@ -2009,8 +2044,236 @@ provider:
}
```
### 使用 OpenAI 协议代理 Google Vertex 服务Express Mode
Express Mode 是 Vertex AI 的简化访问模式,只需 API Key 即可快速开始使用。
**配置信息**
```yaml
provider:
type: vertex
apiTokens:
- "YOUR_API_KEY"
```
**请求示例**
```json
{
"model": "gemini-2.5-flash",
"messages": [
{
"role": "user",
"content": "你好,你是谁?"
}
],
"stream": false
}
```
**响应示例**
```json
{
"id": "chatcmpl-0000000000000",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "你好!我是 Gemini由 Google 开发的人工智能助手。有什么我可以帮您的吗?"
},
"finish_reason": "stop"
}
],
"created": 1729986750,
"model": "gemini-2.5-flash",
"object": "chat.completion",
"usage": {
"prompt_tokens": 10,
"completion_tokens": 25,
"total_tokens": 35
}
}
```
### 使用 OpenAI 协议代理 Google Vertex 服务OpenAI 兼容模式)
OpenAI 兼容模式使用 Vertex AI 的 OpenAI-compatible Chat Completions API请求和响应都使用 OpenAI 格式,无需进行协议转换。
**配置信息**
```yaml
provider:
type: vertex
vertexOpenAICompatible: true
vertexAuthKey: |
{
"type": "service_account",
"project_id": "your-project-id",
"private_key_id": "your-private-key-id",
"private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
"client_email": "your-service-account@your-project.iam.gserviceaccount.com",
"token_uri": "https://oauth2.googleapis.com/token"
}
vertexRegion: us-central1
vertexProjectId: your-project-id
vertexAuthServiceName: your-auth-service-name
modelMapping:
"gpt-4": "gemini-2.0-flash"
"*": "gemini-1.5-flash"
```
**请求示例**
```json
{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "你好,你是谁?"
}
],
"stream": false
}
```
**响应示例**
```json
{
"id": "chatcmpl-abc123",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "你好!我是由 Google 开发的 Gemini 模型。我可以帮助回答问题、提供信息和进行对话。有什么我可以帮您的吗?"
},
"finish_reason": "stop"
}
],
"created": 1729986750,
"model": "gemini-2.0-flash",
"object": "chat.completion",
"usage": {
"prompt_tokens": 12,
"completion_tokens": 35,
"total_tokens": 47
}
}
```
### 使用 OpenAI 协议代理 Google Vertex 图片生成服务
Vertex AI 支持使用 Gemini 模型进行图片生成。通过 ai-proxy 插件,可以使用 OpenAI 的 `/v1/images/generations` 接口协议来调用 Vertex AI 的图片生成能力。
**配置信息**
```yaml
provider:
type: vertex
apiTokens:
- "YOUR_API_KEY"
modelMapping:
"dall-e-3": "gemini-2.0-flash-exp"
geminiSafetySetting:
HARM_CATEGORY_HARASSMENT: "OFF"
HARM_CATEGORY_HATE_SPEECH: "OFF"
HARM_CATEGORY_SEXUALLY_EXPLICIT: "OFF"
HARM_CATEGORY_DANGEROUS_CONTENT: "OFF"
```
**使用 curl 请求**
```bash
curl -X POST "http://your-gateway-address/v1/images/generations" \
-H "Content-Type: application/json" \
-d '{
"model": "gemini-2.0-flash-exp",
"prompt": "一只可爱的橘猫在阳光下打盹",
"size": "1024x1024"
}'
```
**使用 OpenAI Python SDK**
```python
from openai import OpenAI
client = OpenAI(
api_key="any-value", # 可以是任意值,认证由网关处理
base_url="http://your-gateway-address/v1"
)
response = client.images.generate(
model="gemini-2.0-flash-exp",
prompt="一只可爱的橘猫在阳光下打盹",
size="1024x1024",
n=1
)
# 获取生成的图片base64 编码)
image_data = response.data[0].b64_json
print(f"Generated image (base64): {image_data[:100]}...")
```
**响应示例**
```json
{
"created": 1729986750,
"data": [
{
"b64_json": "iVBORw0KGgoAAAANSUhEUgAABAAAAAQACAIAAADwf7zUAAAA..."
}
],
"usage": {
"total_tokens": 1356,
"input_tokens": 13,
"output_tokens": 1120
}
}
```
**支持的尺寸参数**
Vertex AI 支持的宽高比(aspectRatio):`1:1`、`3:2`、`2:3`、`3:4`、`4:3`、`4:5`、`5:4`、`9:16`、`16:9`、`21:9`
Vertex AI 支持的分辨率(imageSize):`1k`、`2k`、`4k`
| OpenAI size 参数 | Vertex AI aspectRatio | Vertex AI imageSize |
|------------------|----------------------|---------------------|
| 256x256 | 1:1 | 1k |
| 512x512 | 1:1 | 1k |
| 1024x1024 | 1:1 | 1k |
| 1792x1024 | 16:9 | 2k |
| 1024x1792 | 9:16 | 2k |
| 2048x2048 | 1:1 | 2k |
| 4096x4096 | 1:1 | 4k |
| 1536x1024 | 3:2 | 2k |
| 1024x1536 | 2:3 | 2k |
| 1024x768 | 4:3 | 1k |
| 768x1024 | 3:4 | 1k |
| 1280x1024 | 5:4 | 1k |
| 1024x1280 | 4:5 | 1k |
| 2560x1080 | 21:9 | 2k |
**注意事项**
- 图片生成使用 Gemini 模型(如 `gemini-2.0-flash-exp`、`gemini-3-pro-image-preview`),不同模型的可用性可能因区域而异
- 返回的图片数据为 base64 编码格式(`b64_json`)
- 可以通过 `geminiSafetySetting` 配置内容安全过滤级别
- 如果需要使用模型映射(如将 `dall-e-3` 映射到 Gemini 模型),可以配置 `modelMapping`
### 使用 OpenAI 协议代理 AWS Bedrock 服务
AWS Bedrock 支持两种认证方式:
#### 方式一:使用 AWS Access Key/Secret Key 认证AWS Signature V4
**配置信息**
```yaml
@@ -2018,7 +2281,21 @@ provider:
type: bedrock
awsAccessKey: "YOUR_AWS_ACCESS_KEY_ID"
awsSecretKey: "YOUR_AWS_SECRET_ACCESS_KEY"
awsRegion: "YOUR_AWS_REGION"
awsRegion: "us-east-1"
bedrockAdditionalFields:
top_k: 200
```
#### 方式二:使用 Bearer Token 认证(适用于 IAM Identity Center 等场景)
**配置信息**
```yaml
provider:
type: bedrock
apiTokens:
- "YOUR_AWS_BEARER_TOKEN"
awsRegion: "us-east-1"
bedrockAdditionalFields:
top_k: 200
```
@@ -2027,7 +2304,7 @@ provider:
```json
{
"model": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-5-haiku-20241022-v1:0",
"model": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
"messages": [
{
"role": "user",

View File

@@ -25,6 +25,8 @@ The plugin now supports **automatic protocol detection**, allowing seamless comp
> When the request path suffix matches `/v1/embeddings`, it corresponds to text vector scenarios. The request body will be parsed using OpenAI's text vector protocol and then converted to the corresponding LLM vendor's text vector protocol.
> When the request path suffix matches `/v1/images/generations`, it corresponds to text-to-image scenarios. The request body will be parsed using OpenAI's image generation protocol and then converted to the corresponding LLM vendor's image generation protocol.
## Execution Properties
Plugin execution phase: `Default Phase`
Plugin execution priority: `100`
@@ -255,7 +257,9 @@ For DeepL, the corresponding `type` is `deepl`. Its unique configuration field i
| `targetLang` | string | Required | - | The target language required by the DeepL translation service |
#### Google Vertex AI
For Vertex, the corresponding `type` is `vertex`. Its unique configuration field is:
For Vertex, the corresponding `type` is `vertex`. It supports two authentication modes:
**Standard Mode** (using Service Account):
| Name | Data Type | Requirement | Default | Description |
|-----------------------------|---------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
@@ -266,16 +270,47 @@ For Vertex, the corresponding `type` is `vertex`. Its unique configuration field
| `vertexGeminiSafetySetting` | map of string | Optional | - | Gemini model content safety filtering settings. |
| `vertexTokenRefreshAhead` | number | Optional | - | Vertex access token refresh ahead time in seconds |
**Express Mode** (using API Key, simplified configuration):
Express Mode is a simplified access mode introduced by Vertex AI. You can quickly get started with just an API Key, without configuring a Service Account. See [Vertex AI Express Mode documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview).
| Name | Data Type | Requirement | Default | Description |
|-----------------------------|------------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `apiTokens` | array of string | Required | - | API Key for Express Mode, obtained from Google Cloud Console under API & Services > Credentials |
| `vertexGeminiSafetySetting` | map of string | Optional | - | Gemini model content safety filtering settings. |
**OpenAI Compatible Mode** (using Vertex AI Chat Completions API):
Vertex AI provides an OpenAI-compatible Chat Completions API endpoint, allowing you to use OpenAI format requests and responses directly without protocol conversion. See [Vertex AI OpenAI Compatibility documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/migrate/openai/overview).
| Name | Data Type | Requirement | Default | Description |
|-----------------------------|------------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `vertexOpenAICompatible` | boolean | Optional | false | Enable OpenAI compatible mode. When enabled, uses Vertex AI's OpenAI-compatible Chat Completions API |
| `vertexAuthKey` | string | Required | - | Google Service Account JSON Key for authentication |
| `vertexRegion` | string | Required | - | Google Cloud region (e.g., us-central1, europe-west4) |
| `vertexProjectId` | string | Required | - | Google Cloud Project ID |
| `vertexAuthServiceName` | string | Required | - | Service name for OAuth2 authentication |
**Note**: OpenAI Compatible Mode and Express Mode are mutually exclusive. You cannot configure both `apiTokens` and `vertexOpenAICompatible` at the same time.
#### AWS Bedrock
For AWS Bedrock, the corresponding `type` is `bedrock`. Its unique configuration field is:
For AWS Bedrock, the corresponding `type` is `bedrock`. It supports two authentication methods:
| Name | Data Type | Requirement | Default | Description |
|---------------------------|-----------|-------------|---------|---------------------------------------------------------|
| `awsAccessKey` | string | Required | - | AWS Access Key used for authentication |
| `awsSecretKey` | string | Required | - | AWS Secret Access Key used for authentication |
| `awsRegion` | string | Required | - | AWS region, e.g., us-east-1 |
| `bedrockAdditionalFields` | map | Optional | - | Additional inference parameters that the model supports |
1. **AWS Signature V4 Authentication**: Uses `awsAccessKey` and `awsSecretKey` for standard AWS signature authentication
2. **Bearer Token Authentication**: Uses `apiTokens` to configure AWS Bearer Token (suitable for IAM Identity Center and similar scenarios)
**Note**: Choose one of the two authentication methods. If `apiTokens` is configured, Bearer Token authentication will be used preferentially.
Its unique configuration fields are:
| Name | Data Type | Requirement | Default | Description |
|---------------------------|-----------------|--------------------------|---------|-------------------------------------------------------------------|
| `apiTokens` | array of string | Either this or ak/sk | - | AWS Bearer Token for Bearer Token authentication |
| `awsAccessKey` | string | Either this or apiTokens | - | AWS Access Key for AWS Signature V4 authentication |
| `awsSecretKey` | string | Either this or apiTokens | - | AWS Secret Access Key for AWS Signature V4 authentication |
| `awsRegion` | string | Required | - | AWS region, e.g., us-east-1 |
| `bedrockAdditionalFields` | map | Optional | - | Additional inference parameters that the model supports |
## Usage Examples
@@ -1720,7 +1755,7 @@ provider:
}
```
### Utilizing OpenAI Protocol Proxy for Google Vertex Services
### Utilizing OpenAI Protocol Proxy for Google Vertex Services (Standard Mode)
**Configuration Information**
```yaml
provider:
@@ -1778,14 +1813,250 @@ provider:
}
```
### Utilizing OpenAI Protocol Proxy for Google Vertex Services (Express Mode)
Express Mode is a simplified access mode for Vertex AI. You only need an API Key to get started quickly.
**Configuration Information**
```yaml
provider:
type: vertex
apiTokens:
- "YOUR_API_KEY"
```
**Request Example**
```json
{
"model": "gemini-2.5-flash",
"messages": [
{
"role": "user",
"content": "Who are you?"
}
],
"stream": false
}
```
**Response Example**
```json
{
"id": "chatcmpl-0000000000000",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! I am Gemini, an AI assistant developed by Google. How can I help you today?"
},
"finish_reason": "stop"
}
],
"created": 1729986750,
"model": "gemini-2.5-flash",
"object": "chat.completion",
"usage": {
"prompt_tokens": 10,
"completion_tokens": 25,
"total_tokens": 35
}
}
```
### Utilizing OpenAI Protocol Proxy for Google Vertex Services (OpenAI Compatible Mode)
OpenAI Compatible Mode uses Vertex AI's OpenAI-compatible Chat Completions API. Both requests and responses use OpenAI format, requiring no protocol conversion.
**Configuration Information**
```yaml
provider:
type: vertex
vertexOpenAICompatible: true
vertexAuthKey: |
{
"type": "service_account",
"project_id": "your-project-id",
"private_key_id": "your-private-key-id",
"private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
"client_email": "your-service-account@your-project.iam.gserviceaccount.com",
"token_uri": "https://oauth2.googleapis.com/token"
}
vertexRegion: us-central1
vertexProjectId: your-project-id
vertexAuthServiceName: your-auth-service-name
modelMapping:
"gpt-4": "gemini-2.0-flash"
"*": "gemini-1.5-flash"
```
**Request Example**
```json
{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "Hello, who are you?"
}
],
"stream": false
}
```
**Response Example**
```json
{
"id": "chatcmpl-abc123",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! I am Gemini, an AI model developed by Google. I can help answer questions, provide information, and engage in conversations. How can I assist you today?"
},
"finish_reason": "stop"
}
],
"created": 1729986750,
"model": "gemini-2.0-flash",
"object": "chat.completion",
"usage": {
"prompt_tokens": 12,
"completion_tokens": 35,
"total_tokens": 47
}
}
```
### Utilizing OpenAI Protocol Proxy for Google Vertex Image Generation
Vertex AI supports image generation using Gemini models. Through the ai-proxy plugin, you can use OpenAI's `/v1/images/generations` API to call Vertex AI's image generation capabilities.
**Configuration Information**
```yaml
provider:
type: vertex
apiTokens:
- "YOUR_API_KEY"
modelMapping:
"dall-e-3": "gemini-2.0-flash-exp"
geminiSafetySetting:
HARM_CATEGORY_HARASSMENT: "OFF"
HARM_CATEGORY_HATE_SPEECH: "OFF"
HARM_CATEGORY_SEXUALLY_EXPLICIT: "OFF"
HARM_CATEGORY_DANGEROUS_CONTENT: "OFF"
```
**Using curl**
```bash
curl -X POST "http://your-gateway-address/v1/images/generations" \
-H "Content-Type: application/json" \
-d '{
"model": "gemini-2.0-flash-exp",
"prompt": "A cute orange cat napping in the sunshine",
"size": "1024x1024"
}'
```
**Using OpenAI Python SDK**
```python
from openai import OpenAI
client = OpenAI(
api_key="any-value", # Can be any value, authentication is handled by the gateway
base_url="http://your-gateway-address/v1"
)
response = client.images.generate(
model="gemini-2.0-flash-exp",
prompt="A cute orange cat napping in the sunshine",
size="1024x1024",
n=1
)
# Get the generated image (base64 encoded)
image_data = response.data[0].b64_json
print(f"Generated image (base64): {image_data[:100]}...")
```
**Response Example**
```json
{
"created": 1729986750,
"data": [
{
"b64_json": "iVBORw0KGgoAAAANSUhEUgAABAAAAAQACAIAAADwf7zUAAAA..."
}
],
"usage": {
"total_tokens": 1356,
"input_tokens": 13,
"output_tokens": 1120
}
}
```
**Supported Size Parameters**
Vertex AI supported aspect ratios: `1:1`, `3:2`, `2:3`, `3:4`, `4:3`, `4:5`, `5:4`, `9:16`, `16:9`, `21:9`
Vertex AI supported resolutions (imageSize): `1k`, `2k`, `4k`
| OpenAI size parameter | Vertex AI aspectRatio | Vertex AI imageSize |
|-----------------------|----------------------|---------------------|
| 256x256 | 1:1 | 1k |
| 512x512 | 1:1 | 1k |
| 1024x1024 | 1:1 | 1k |
| 1792x1024 | 16:9 | 2k |
| 1024x1792 | 9:16 | 2k |
| 2048x2048 | 1:1 | 2k |
| 4096x4096 | 1:1 | 4k |
| 1536x1024 | 3:2 | 2k |
| 1024x1536 | 2:3 | 2k |
| 1024x768 | 4:3 | 1k |
| 768x1024 | 3:4 | 1k |
| 1280x1024 | 5:4 | 1k |
| 1024x1280 | 4:5 | 1k |
| 2560x1080 | 21:9 | 2k |
**Notes**
- Image generation uses Gemini models (e.g., `gemini-2.0-flash-exp`, `gemini-3-pro-image-preview`). Model availability may vary by region
- The returned image data is in base64 encoded format (`b64_json`)
- Content safety filtering levels can be configured via `geminiSafetySetting`
- If you need model mapping (e.g., mapping `dall-e-3` to a Gemini model), configure `modelMapping`
### Utilizing OpenAI Protocol Proxy for AWS Bedrock Services
AWS Bedrock supports two authentication methods:
#### Method 1: Using AWS Access Key/Secret Key Authentication (AWS Signature V4)
**Configuration Information**
```yaml
provider:
type: bedrock
awsAccessKey: "YOUR_AWS_ACCESS_KEY_ID"
awsSecretKey: "YOUR_AWS_SECRET_ACCESS_KEY"
awsRegion: "YOUR_AWS_REGION"
awsRegion: "us-east-1"
bedrockAdditionalFields:
top_k: 200
```
#### Method 2: Using Bearer Token Authentication (suitable for IAM Identity Center and similar scenarios)
**Configuration Information**
```yaml
provider:
type: bedrock
apiTokens:
- "YOUR_AWS_BEARER_TOKEN"
awsRegion: "us-east-1"
bedrockAdditionalFields:
top_k: 200
```
@@ -1793,7 +2064,7 @@ provider:
**Request Example**
```json
{
"model": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-5-haiku-20241022-v1:0",
"model": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
"messages": [
{
"role": "user",

View File

@@ -8,7 +8,7 @@ toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
)

View File

@@ -4,8 +4,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c h1:DdVPyaMHSYBqO5jwB9Wl3PqsBGIf4u29BHMI0uIVB1Y=
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d h1:LgYbzEBtg0+LEqoebQeMVgAB6H5SgqG+KN+gBhNfKbM=
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=

View File

@@ -128,3 +128,25 @@ func TestGeneric(t *testing.T) {
test.RunGenericOnHttpRequestHeadersTests(t)
test.RunGenericOnHttpRequestBodyTests(t)
}
// TestVertex runs the Vertex provider test suites: config parsing,
// Express Mode (API-key auth) request/response and streaming handling,
// image generation, and Raw mode pass-through handling.
func TestVertex(t *testing.T) {
	test.RunVertexParseConfigTests(t)
	// Express Mode tests (API-key authentication via apiTokens).
	test.RunVertexExpressModeOnHttpRequestHeadersTests(t)
	test.RunVertexExpressModeOnHttpRequestBodyTests(t)
	test.RunVertexExpressModeOnHttpResponseBodyTests(t)
	test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
	test.RunVertexExpressModeImageGenerationRequestBodyTests(t)
	test.RunVertexExpressModeImageGenerationResponseBodyTests(t)
	// Vertex Raw mode tests (native Vertex AI REST API paths, no protocol conversion).
	test.RunVertexRawModeOnHttpRequestHeadersTests(t)
	test.RunVertexRawModeOnHttpRequestBodyTests(t)
	test.RunVertexRawModeOnHttpResponseBodyTests(t)
}
// TestBedrock runs the Bedrock provider test suites: config parsing and
// request/response header and body handling.
func TestBedrock(t *testing.T) {
	test.RunBedrockParseConfigTests(t)
	test.RunBedrockOnHttpRequestHeadersTests(t)
	test.RunBedrockOnHttpRequestBodyTests(t)
	test.RunBedrockOnHttpResponseHeadersTests(t)
	test.RunBedrockOnHttpResponseBodyTests(t)
}

View File

@@ -43,8 +43,11 @@ const (
type bedrockProviderInitializer struct{}
func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if len(config.awsAccessKey) == 0 || len(config.awsSecretKey) == 0 {
return errors.New("missing bedrock access authentication parameters")
hasAkSk := len(config.awsAccessKey) > 0 && len(config.awsSecretKey) > 0
hasApiToken := len(config.apiTokens) > 0
if !hasAkSk && !hasApiToken {
return errors.New("missing bedrock access authentication parameters: either apiTokens or (awsAccessKey + awsSecretKey) is required")
}
if len(config.awsRegion) == 0 {
return errors.New("missing bedrock region parameters")
@@ -634,6 +637,13 @@ func (b *bedrockProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion))
// If apiTokens is configured, set Bearer token authentication here
// This follows the same pattern as other providers (qwen, zhipuai, etc.)
// AWS SigV4 authentication is handled in setAuthHeaders because it requires the request body
if len(b.config.apiTokens) > 0 {
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+b.config.GetApiTokenInUse(ctx))
}
}
func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
@@ -659,18 +669,18 @@ func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
case ApiNameChatCompletion:
return b.onChatCompletionResponseBody(ctx, body)
case ApiNameImageGeneration:
return b.onImageGenerationResponseBody(ctx, body)
return b.onImageGenerationResponseBody(body)
}
return nil, errUnsupportedApiName
}
func (b *bedrockProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
func (b *bedrockProvider) onImageGenerationResponseBody(body []byte) ([]byte, error) {
bedrockResponse := &bedrockImageGenerationResponse{}
if err := json.Unmarshal(body, bedrockResponse); err != nil {
log.Errorf("unable to unmarshal bedrock image gerneration response: %v", err)
return nil, fmt.Errorf("unable to unmarshal bedrock image generation response: %v", err)
}
response := b.buildBedrockImageGenerationResponse(ctx, bedrockResponse)
response := b.buildBedrockImageGenerationResponse(bedrockResponse)
return json.Marshal(response)
}
@@ -710,7 +720,7 @@ func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageG
return requestBytes, err
}
func (b *bedrockProvider) buildBedrockImageGenerationResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
func (b *bedrockProvider) buildBedrockImageGenerationResponse(bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
data := make([]imageGenerationData, len(bedrockResponse.Images))
for i, image := range bedrockResponse.Images {
data[i] = imageGenerationData{
@@ -1138,6 +1148,13 @@ func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
}
func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) {
// Bearer token authentication is already set in TransformRequestHeaders
// This function only handles AWS SigV4 authentication which requires the request body
if len(b.config.apiTokens) > 0 {
return
}
// Use AWS Signature V4 authentication
t := time.Now().UTC()
amzDate := t.Format("20060102T150405Z")
dateStamp := t.Format("20060102")

View File

@@ -7,6 +7,7 @@ import (
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
@@ -134,8 +135,63 @@ func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
} else {
util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain)
}
var token string
// 1. If apiTokens is configured, use it first
if len(m.config.apiTokens) > 0 {
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
token = m.config.GetApiTokenInUse(ctx)
if token == "" {
log.Warnf("[openaiProvider.TransformRequestHeaders] apiTokens count > 0 but GetApiTokenInUse returned empty")
}
} else {
// If no apiToken is configured, try to extract from original request headers
// 2. If authHeaderKey is configured, use the specified header
if m.config.authHeaderKey != "" {
if apiKey, err := proxywasm.GetHttpRequestHeader(m.config.authHeaderKey); err == nil && apiKey != "" {
token = apiKey
log.Debugf("[openaiProvider.TransformRequestHeaders] Using token from configured header: %s", m.config.authHeaderKey)
}
}
// 3. If authHeaderKey is not configured, check default headers in priority order
if token == "" {
defaultHeaders := []string{"x-api-key", "x-authorization"}
for _, headerName := range defaultHeaders {
if apiKey, err := proxywasm.GetHttpRequestHeader(headerName); err == nil && apiKey != "" {
token = apiKey
log.Debugf("[openaiProvider.TransformRequestHeaders] Using token from %s header", headerName)
break
}
}
}
// 4. Finally check Authorization header
if token == "" {
if auth, err := proxywasm.GetHttpRequestHeader("Authorization"); err == nil && auth != "" {
// Extract token from "Bearer <token>" format
if strings.HasPrefix(auth, "Bearer ") {
token = strings.TrimPrefix(auth, "Bearer ")
log.Debugf("[openaiProvider.TransformRequestHeaders] Using token from Authorization header (Bearer format)")
} else {
token = auth
log.Debugf("[openaiProvider.TransformRequestHeaders] Using token from Authorization header (no Bearer prefix)")
}
}
}
}
// 5. Set Authorization header (avoid duplicate Bearer prefix)
if token != "" {
// Check if token already contains Bearer prefix
if !strings.HasPrefix(token, "Bearer ") {
token = "Bearer " + token
}
util.OverwriteRequestAuthorizationHeader(headers, token)
log.Debugf("[openaiProvider.TransformRequestHeaders] Set Authorization header successfully")
} else {
log.Warnf("[openaiProvider.TransformRequestHeaders] No auth token available - neither configured in apiTokens nor in request headers")
}
headers.Del("Content-Length")
}

View File

@@ -70,6 +70,7 @@ const (
ApiNameGeminiStreamGenerateContent ApiName = "gemini/v1beta/streamgeneratecontent"
ApiNameAnthropicMessages ApiName = "anthropic/v1/messages"
ApiNameAnthropicComplete ApiName = "anthropic/v1/complete"
ApiNameVertexRaw ApiName = "vertex/raw"
// OpenAI
PathOpenAIPrefix = "/v1"
@@ -387,12 +388,18 @@ type ProviderConfig struct {
// @Title zh-CN Vertex token刷新提前时间
// @Description zh-CN 用于Google服务账号认证access token过期时间判定提前刷新单位为秒默认值为60秒
vertexTokenRefreshAhead int64 `required:"false" yaml:"vertexTokenRefreshAhead" json:"vertexTokenRefreshAhead"`
// @Title zh-CN Vertex AI OpenAI兼容模式
// @Description zh-CN 启用后将使用Vertex AI的OpenAI兼容API请求和响应均使用OpenAI格式无需协议转换。与Express Mode(apiTokens)互斥。
vertexOpenAICompatible bool `required:"false" yaml:"vertexOpenAICompatible" json:"vertexOpenAICompatible"`
// @Title zh-CN 翻译服务需指定的目标语种
// @Description zh-CN 翻译结果的语种目前仅适用于DeepL服务。
targetLang string `required:"false" yaml:"targetLang" json:"targetLang"`
// @Title zh-CN 指定服务返回的响应需满足的JSON Schema
// @Description zh-CN 目前仅适用于OpenAI部分模型服务。参考https://platform.openai.com/docs/guides/structured-outputs
responseJsonSchema map[string]interface{} `required:"false" yaml:"responseJsonSchema" json:"responseJsonSchema"`
// @Title zh-CN 自定义认证Header名称
// @Description zh-CN 用于从请求中提取认证token的自定义header名称。如不配置则按默认优先级检查 x-api-key、x-authorization、anthropic-api-key 和 Authorization header。
authHeaderKey string `required:"false" yaml:"authHeaderKey" json:"authHeaderKey"`
// @Title zh-CN 自定义大模型参数配置
// @Description zh-CN 用于填充或者覆盖大模型调用时的参数
customSettings []CustomSetting
@@ -540,6 +547,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
if c.vertexTokenRefreshAhead == 0 {
c.vertexTokenRefreshAhead = 60
}
c.vertexOpenAICompatible = json.Get("vertexOpenAICompatible").Bool()
c.targetLang = json.Get("targetLang").String()
if schemaValue, ok := json.Get("responseJsonSchema").Value().(map[string]interface{}); ok {

View File

@@ -21,14 +21,21 @@ import (
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
vertexAuthDomain = "oauth2.googleapis.com"
vertexDomain = "aiplatform.googleapis.com"
// /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION}
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s"
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s"
// Express Mode 路径模板 (不含 project/location)
vertexExpressPathTemplate = "/v1/publishers/google/models/%s:%s"
vertexExpressPathAnthropicTemplate = "/v1/publishers/anthropic/models/%s:%s"
// OpenAI-compatible endpoint 路径模板
// /v1beta1/projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/openapi/chat/completions
vertexOpenAICompatiblePathTemplate = "/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions"
vertexChatCompletionAction = "generateContent"
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
vertexAnthropicMessageAction = "rawPredict"
@@ -36,12 +43,44 @@ const (
vertexEmbeddingAction = "predict"
vertexGlobalRegion = "global"
contextClaudeMarker = "isClaudeRequest"
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
contextVertexRawMarker = "isVertexRawRequest"
vertexAnthropicVersion = "vertex-2023-10-16"
)
// vertexRawPathRegex 匹配原生 Vertex AI REST API 路径
// 格式: [任意前缀]/{api-version}/projects/{project}/locations/{location}/publishers/{publisher}/models/{model}:{action}
// 允许任意 basePath 前缀,兼容 basePathHandling 配置
var vertexRawPathRegex = regexp.MustCompile(`^.*/([^/]+)/projects/([^/]+)/locations/([^/]+)/publishers/([^/]+)/models/([^/:]+):([^/?]+)`)
type vertexProviderInitializer struct{}
func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error {
// Express Mode: 如果配置了 apiTokens则使用 API Key 认证
if len(config.apiTokens) > 0 {
// Express Mode 与 OpenAI 兼容模式互斥
if config.vertexOpenAICompatible {
return errors.New("vertexOpenAICompatible is not compatible with Express Mode (apiTokens)")
}
// Express Mode 不需要其他配置
return nil
}
// OpenAI 兼容模式: 需要 OAuth 认证配置
if config.vertexOpenAICompatible {
if config.vertexAuthKey == "" {
return errors.New("missing vertexAuthKey in vertex provider config for OpenAI compatible mode")
}
if config.vertexRegion == "" || config.vertexProjectId == "" {
return errors.New("missing vertexRegion or vertexProjectId in vertex provider config for OpenAI compatible mode")
}
if config.vertexAuthServiceName == "" {
return errors.New("missing vertexAuthServiceName in vertex provider config for OpenAI compatible mode")
}
return nil
}
// 标准模式: 保持原有验证逻辑
if config.vertexAuthKey == "" {
return errors.New("missing vertexAuthKey in vertex provider config")
}
@@ -56,26 +95,47 @@ func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error
func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): vertexPathTemplate,
string(ApiNameEmbeddings): vertexPathTemplate,
string(ApiNameChatCompletion): vertexPathTemplate,
string(ApiNameEmbeddings): vertexPathTemplate,
string(ApiNameImageGeneration): vertexPathTemplate,
string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换
}
}
func (v *vertexProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(v.DefaultCapabilities())
return &vertexProvider{
config: config,
client: wrapper.NewClusterClient(wrapper.DnsCluster{
Domain: vertexAuthDomain,
ServiceName: config.vertexAuthServiceName,
Port: 443,
}),
provider := &vertexProvider{
config: config,
contextCache: createContextCache(&config),
claude: &claudeProvider{
config: config,
contextCache: createContextCache(&config),
},
}, nil
}
// 仅标准模式需要 OAuth 客户端Express Mode 通过 apiTokens 配置)
if !provider.isExpressMode() {
provider.client = wrapper.NewClusterClient(wrapper.DnsCluster{
Domain: vertexAuthDomain,
ServiceName: config.vertexAuthServiceName,
Port: 443,
})
}
return provider, nil
}
// isExpressMode reports whether Express Mode is enabled.
// Express Mode (API-key authentication) is in effect whenever apiTokens is configured.
func (v *vertexProvider) isExpressMode() bool {
	return len(v.config.apiTokens) > 0
}
// isOpenAICompatibleMode reports whether the OpenAI-compatible mode is
// enabled, i.e. requests go to Vertex AI's OpenAI-compatible Chat
// Completions API instead of the native generateContent API.
func (v *vertexProvider) isOpenAICompatibleMode() bool {
	return v.config.vertexOpenAICompatible
}
type vertexProvider struct {
@@ -90,6 +150,12 @@ func (v *vertexProvider) GetProviderType() string {
}
func (v *vertexProvider) GetApiName(path string) ApiName {
// 优先匹配原生 Vertex AI REST API 路径,支持任意 basePath 前缀
// 格式: [任意前缀]/{api-version}/projects/{project}/locations/{location}/publishers/{publisher}/models/{model}:{action}
// 必须在其他 action 检查之前,因为 :predict、:generateContent 等 action 会被其他规则匹配
if vertexRawPathRegex.MatchString(path) {
return ApiNameVertexRaw
}
if strings.HasSuffix(path, vertexChatCompletionAction) || strings.HasSuffix(path, vertexChatCompletionStreamAction) {
return ApiNameChatCompletion
}
@@ -106,11 +172,19 @@ func (v *vertexProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
func (v *vertexProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
var finalVertexDomain string
if v.config.vertexRegion != vertexGlobalRegion {
finalVertexDomain = fmt.Sprintf("%s-%s", v.config.vertexRegion, vertexDomain)
} else {
if v.isExpressMode() {
// Express Mode: 固定域名,不带 region 前缀
finalVertexDomain = vertexDomain
} else {
// 标准模式: 带 region 前缀
if v.config.vertexRegion != vertexGlobalRegion {
finalVertexDomain = fmt.Sprintf("%s-%s", v.config.vertexRegion, vertexDomain)
} else {
finalVertexDomain = vertexDomain
}
}
util.OverwriteRequestHostHeader(headers, finalVertexDomain)
}
@@ -150,12 +224,66 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if !v.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
// Vertex Raw 模式: 透传请求体,只做 OAuth 认证
// 用于直接访问 Vertex AI REST API不做协议转换
// 注意:此检查必须在 IsOriginal() 之前,因为 Vertex Raw 模式通常与 original 协议一起使用
if apiName == ApiNameVertexRaw {
ctx.SetContext(contextVertexRawMarker, true)
// Express Mode 不需要 OAuth 认证
if v.isExpressMode() {
return types.ActionContinue, nil
}
// 标准模式需要获取 OAuth token
cached, err := v.getToken()
if cached {
return types.ActionContinue, nil
}
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
if v.config.IsOriginal() {
return types.ActionContinue, nil
}
headers := util.GetRequestHeaders()
// OpenAI 兼容模式: 不转换请求体,只设置路径和进行模型映射
if v.isOpenAICompatibleMode() {
ctx.SetContext(contextOpenAICompatibleMarker, true)
body, err := v.onOpenAICompatibleRequestBody(ctx, apiName, body, headers)
headers.Set("Content-Length", fmt.Sprint(len(body)))
util.ReplaceRequestHeaders(headers)
_ = proxywasm.ReplaceHttpRequestBody(body)
if err != nil {
return types.ActionContinue, err
}
// OpenAI 兼容模式需要 OAuth token
cached, err := v.getToken()
if cached {
return types.ActionContinue, nil
}
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
headers.Set("Content-Length", fmt.Sprint(len(body)))
if v.isExpressMode() {
// Express Mode: 不需要 Authorization headerAPI Key 已在 URL 中
headers.Del("Authorization")
util.ReplaceRequestHeaders(headers)
_ = proxywasm.ReplaceHttpRequestBody(body)
return types.ActionContinue, err
}
// 标准模式: 需要获取 OAuth token
util.ReplaceRequestHeaders(headers)
_ = proxywasm.ReplaceHttpRequestBody(body)
if err != nil {
@@ -172,13 +300,44 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
}
func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
if apiName == ApiNameChatCompletion {
switch apiName {
case ApiNameChatCompletion:
return v.onChatCompletionRequestBody(ctx, body, headers)
} else {
case ApiNameEmbeddings:
return v.onEmbeddingsRequestBody(ctx, body, headers)
case ApiNameImageGeneration:
return v.onImageGenerationRequestBody(ctx, body, headers)
default:
return body, nil
}
}
// onOpenAICompatibleRequestBody handles a request in OpenAI-compatible mode.
// The body keeps its OpenAI wire format; only model mapping and the request
// path are applied here — no protocol conversion is performed.
// Returns the (possibly model-rewritten) body, or an error if the API is
// unsupported or the body cannot be parsed.
func (v *vertexProvider) onOpenAICompatibleRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
	if apiName != ApiNameChatCompletion {
		return nil, errors.New("OpenAI compatible mode only supports chat completions API")
	}
	// Parse the request so the configured model mapping can be applied.
	request := &chatCompletionRequest{}
	if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
		return nil, err
	}
	// Point the request at the OpenAI-compatible endpoint.
	path := v.getOpenAICompatibleRequestPath()
	util.OverwriteRequestPathHeader(headers, path)
	// If model mapping produced a model name, write it back into the body.
	// A failed write keeps the original body instead of being silently ignored.
	if request.Model != "" {
		if updated, err := sjson.SetBytes(body, "model", request.Model); err == nil {
			body = updated
		} else {
			log.Errorf("unable to update model field in request body: %v", err)
		}
	}
	// Keep the OpenAI format and return the (possibly updated) body as-is.
	return body, nil
}
func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
request := &chatCompletionRequest{}
err := v.config.parseRequestAndMapModel(ctx, request, body)
@@ -219,7 +378,126 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
return json.Marshal(vertexRequest)
}
// onImageGenerationRequestBody converts an OpenAI-style image generation
// request into the Vertex AI generateContent format and rewrites the
// request path to the corresponding model endpoint.
func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
	request := &imageGenerationRequest{}
	if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
		return nil, err
	}
	// Image generation uses the non-streaming endpoint: the complete
	// response is needed before it can be transformed.
	path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
	util.OverwriteRequestPathHeader(headers, path)
	vertexRequest := v.buildVertexImageGenerationRequest(request)
	return json.Marshal(vertexRequest)
}
// buildVertexImageGenerationRequest builds a Vertex AI generateContent
// request that asks the model to produce an image for the given prompt,
// carrying over the configured Gemini safety settings, the requested
// size (mapped to aspectRatio/imageSize) and the output MIME type.
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest {
	// Propagate the configured Gemini safety settings.
	safetySettings := make([]vertexChatSafetySetting, 0, len(v.config.geminiSafetySetting))
	for category, threshold := range v.config.geminiSafetySetting {
		safetySettings = append(safetySettings, vertexChatSafetySetting{
			Category:  category,
			Threshold: threshold,
		})
	}

	// Translate the OpenAI "WxH" size string into Vertex AI parameters.
	aspectRatio, imageSize := v.parseImageSize(request.Size)

	// Map the requested output format to a MIME type. The switch default
	// covers both an empty and an unrecognized OutputFormat, so no extra
	// emptiness check is needed.
	mimeType := "image/png"
	switch request.OutputFormat {
	case "jpeg", "jpg":
		mimeType = "image/jpeg"
	case "webp":
		mimeType = "image/webp"
	}

	return &vertexChatRequest{
		Contents: []vertexChatContent{{
			Role: roleUser,
			Parts: []vertexPart{{
				Text: request.Prompt,
			}},
		}},
		SafetySettings: safetySettings,
		GenerationConfig: vertexChatGenerationConfig{
			Temperature:        1.0,
			MaxOutputTokens:    32768,
			ResponseModalities: []string{"TEXT", "IMAGE"},
			ImageConfig: &vertexImageConfig{
				AspectRatio: aspectRatio,
				ImageSize:   imageSize,
				ImageOutputOptions: &vertexImageOutputOptions{
					MimeType: mimeType,
				},
				PersonGeneration: "ALLOW_ALL",
			},
		},
	}
}
// openAIToVertexImageSize maps OpenAI-style "WxH" size strings to the
// Vertex AI (aspectRatio, imageSize) pair. Hoisted to package level so the
// table is built once instead of on every parseImageSize call.
// Vertex AI supported aspectRatio: 1:1, 3:2, 2:3, 3:4, 4:3, 4:5, 5:4, 9:16, 16:9, 21:9
// Vertex AI supported imageSize: 1k, 2k, 4k
var openAIToVertexImageSize = map[string]struct {
	aspectRatio string
	imageSize   string
}{
	// OpenAI DALL-E standard sizes
	"256x256":   {"1:1", "1k"},
	"512x512":   {"1:1", "1k"},
	"1024x1024": {"1:1", "1k"},
	"1792x1024": {"16:9", "2k"},
	"1024x1792": {"9:16", "2k"},
	// Extended square sizes
	"2048x2048": {"1:1", "2k"},
	"4096x4096": {"1:1", "4k"},
	// 3:2 and 2:3 ratios
	"1536x1024": {"3:2", "2k"},
	"1024x1536": {"2:3", "2k"},
	// 4:3 and 3:4 ratios
	"1024x768":  {"4:3", "1k"},
	"768x1024":  {"3:4", "1k"},
	"1365x1024": {"4:3", "1k"},
	"1024x1365": {"3:4", "1k"},
	// 5:4 and 4:5 ratios
	"1280x1024": {"5:4", "1k"},
	"1024x1280": {"4:5", "1k"},
	// 21:9 ultra-wide ratio
	"2560x1080": {"21:9", "2k"},
}

// parseImageSize translates an OpenAI-format size string (e.g. "1024x1024")
// into the Vertex AI aspectRatio and imageSize values. Empty or unrecognized
// sizes fall back to the defaults "1:1" / "1k".
func (v *vertexProvider) parseImageSize(size string) (aspectRatio, imageSize string) {
	if mapping, ok := openAIToVertexImageSize[size]; ok {
		return mapping.aspectRatio, mapping.imageSize
	}
	return "1:1", "1k"
}
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
// OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列
// Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON将非 ASCII 字符编码为 \uXXXX
if ctx.GetContext(contextOpenAICompatibleMarker) != nil && ctx.GetContext(contextOpenAICompatibleMarker).(bool) {
return util.DecodeUnicodeEscapesInSSE(chunk), nil
}
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk)
}
@@ -260,13 +538,25 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
}
func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
// OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列
// Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON将非 ASCII 字符编码为 \uXXXX
if ctx.GetContext(contextOpenAICompatibleMarker) != nil && ctx.GetContext(contextOpenAICompatibleMarker).(bool) {
return util.DecodeUnicodeEscapes(body), nil
}
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
return v.claude.TransformResponseBody(ctx, apiName, body)
}
if apiName == ApiNameChatCompletion {
switch apiName {
case ApiNameChatCompletion:
return v.onChatCompletionResponseBody(ctx, body)
} else {
case ApiNameEmbeddings:
return v.onEmbeddingsResponseBody(ctx, body)
case ApiNameImageGeneration:
return v.onImageGenerationResponseBody(ctx, body)
default:
return body, nil
}
}
@@ -359,6 +649,54 @@ func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertex
return &response
}
// onImageGenerationResponseBody converts a Vertex AI image generation
// response into the OpenAI image generation response format.
// Fields are extracted with gjson rather than a full json.Unmarshal to
// avoid allocating and copying the potentially large base64 image payloads.
func (v *vertexProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
	response := v.buildImageGenerationResponseFromJSON(body)
	return json.Marshal(response)
}
// buildImageGenerationResponseFromJSON 使用 gjson 从原始 JSON 中提取图片生成响应
// 相比 json.Unmarshal 完整反序列化,这种方式内存效率更高
// buildImageGenerationResponseFromJSON extracts an OpenAI-format image
// generation response from the raw Vertex AI JSON using gjson.
// Compared to fully unmarshalling the body, this is much cheaper in memory:
// the large base64 image payloads are never copied into intermediate structs.
func (v *vertexProvider) buildImageGenerationResponseFromJSON(body []byte) *imageGenerationResponse {
	result := gjson.ParseBytes(body)

	data := make([]imageGenerationData, 0)
	// Walk every candidate and collect the inline image data parts.
	candidates := result.Get("candidates")
	candidates.ForEach(func(_, candidate gjson.Result) bool {
		parts := candidate.Get("content.parts")
		parts.ForEach(func(_, part gjson.Result) bool {
			// Skip reasoning parts (thought: true).
			if part.Get("thought").Bool() {
				return true
			}
			// Collect non-empty inline image data.
			inlineData := part.Get("inlineData.data")
			if inlineData.Exists() && inlineData.String() != "" {
				data = append(data, imageGenerationData{
					B64: inlineData.String(),
				})
			}
			return true
		})
		return true
	})

	// Extract usage information.
	usage := result.Get("usageMetadata")
	return &imageGenerationResponse{
		// Unix() yields seconds directly — same value as UnixMilli()/1000,
		// without the redundant division.
		Created: time.Now().Unix(),
		Data:    data,
		Usage: &imageGenerationUsage{
			TotalTokens:  int(usage.Get("totalTokenCount").Int()),
			InputTokens:  int(usage.Get("promptTokenCount").Int()),
			OutputTokens: int(usage.Get("candidatesTokenCount").Int()),
		},
	}
}
func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
var choice chatCompletionChoice
choice.Delta = &chatMessage{}
@@ -422,19 +760,62 @@ func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string
} else {
action = vertexAnthropicMessageAction
}
return fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
if v.isExpressMode() {
// Express Mode: 简化路径 + API Key 参数
basePath := fmt.Sprintf(vertexExpressPathAnthropicTemplate, modelId, action)
apiKey := v.config.GetRandomToken()
// 如果 action 已经包含 ?,使用 & 拼接
var fullPath string
if strings.Contains(action, "?") {
fullPath = basePath + "&key=" + apiKey
} else {
fullPath = basePath + "?key=" + apiKey
}
return fullPath
}
path := fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
return path
}
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
action := ""
if apiName == ApiNameEmbeddings {
switch apiName {
case ApiNameEmbeddings:
action = vertexEmbeddingAction
} else if stream {
action = vertexChatCompletionStreamAction
} else {
case ApiNameImageGeneration:
// 图片生成使用非流式端点,需要完整响应
action = vertexChatCompletionAction
default:
if stream {
action = vertexChatCompletionStreamAction
} else {
action = vertexChatCompletionAction
}
}
return fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
if v.isExpressMode() {
// Express Mode: 简化路径 + API Key 参数
basePath := fmt.Sprintf(vertexExpressPathTemplate, modelId, action)
apiKey := v.config.GetRandomToken()
// 如果 action 已经包含 ?(如 streamGenerateContent?alt=sse使用 & 拼接
var fullPath string
if strings.Contains(action, "?") {
fullPath = basePath + "&key=" + apiKey
} else {
fullPath = basePath + "?key=" + apiKey
}
return fullPath
}
path := fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
return path
}
// getOpenAICompatibleRequestPath returns the request path of the Vertex AI
// OpenAI-compatible chat completions endpoint for the configured project
// and region.
func (v *vertexProvider) getOpenAICompatibleRequestPath() string {
	return fmt.Sprintf(vertexOpenAICompatiblePathTemplate, v.config.vertexProjectId, v.config.vertexRegion)
}
func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest {
@@ -521,7 +902,7 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest)
})
}
case contentTypeImageUrl:
vpart, err := convertImageContent(part.ImageUrl.Url)
vpart, err := convertMediaContent(part.ImageUrl.Url)
if err != nil {
log.Errorf("unable to convert image content: %v", err)
} else {
@@ -636,12 +1017,25 @@ type vertexChatSafetySetting struct {
}
type vertexChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
ThinkingConfig vertexThinkingConfig `json:"thinkingConfig,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
ThinkingConfig vertexThinkingConfig `json:"thinkingConfig,omitempty"`
ResponseModalities []string `json:"responseModalities,omitempty"`
ImageConfig *vertexImageConfig `json:"imageConfig,omitempty"`
}
// vertexImageConfig carries the image-generation options embedded in the
// Vertex AI generationConfig payload.
type vertexImageConfig struct {
	AspectRatio        string                    `json:"aspectRatio,omitempty"`
	ImageSize          string                    `json:"imageSize,omitempty"`
	ImageOutputOptions *vertexImageOutputOptions `json:"imageOutputOptions,omitempty"`
	PersonGeneration   string                    `json:"personGeneration,omitempty"`
}
// vertexImageOutputOptions selects the output encoding (MIME type) of
// generated images.
type vertexImageOutputOptions struct {
	MimeType string `json:"mimeType,omitempty"`
}
type vertexThinkingConfig struct {
@@ -852,32 +1246,106 @@ func setCachedAccessToken(key string, accessToken string, expireTime int64) erro
return proxywasm.SetSharedData(key, data, cas)
}
func convertImageContent(imageUrl string) (vertexPart, error) {
// convertMediaContent 将 OpenAI 格式的媒体 URL 转换为 Vertex AI 格式
// 支持图片、视频、音频等多种媒体类型
func convertMediaContent(mediaUrl string) (vertexPart, error) {
part := vertexPart{}
if strings.HasPrefix(imageUrl, "http") {
arr := strings.Split(imageUrl, ".")
mimeType := "image/" + arr[len(arr)-1]
if strings.HasPrefix(mediaUrl, "http") {
mimeType := detectMimeTypeFromURL(mediaUrl)
part.FileData = &fileData{
MimeType: mimeType,
FileUri: imageUrl,
FileUri: mediaUrl,
}
return part, nil
} else {
// Base64 data URL 格式: data:<mimeType>;base64,<data>
re := regexp.MustCompile(`^data:([^;]+);base64,`)
matches := re.FindStringSubmatch(imageUrl)
matches := re.FindStringSubmatch(mediaUrl)
if len(matches) < 2 {
return part, fmt.Errorf("invalid base64 format")
return part, fmt.Errorf("invalid base64 format, expected data:<mimeType>;base64,<data>")
}
mimeType := matches[1] // e.g. image/png
mimeType := matches[1] // e.g. image/png, video/mp4, audio/mp3
parts := strings.Split(mimeType, "/")
if len(parts) < 2 {
return part, fmt.Errorf("invalid mimeType")
return part, fmt.Errorf("invalid mimeType: %s", mimeType)
}
part.InlineData = &blob{
MimeType: mimeType,
Data: strings.TrimPrefix(imageUrl, matches[0]),
Data: strings.TrimPrefix(mediaUrl, matches[0]),
}
return part, nil
}
}
// mediaMimeTypes maps lowercase file extensions to MIME types for the
// image, video, audio and document formats supported here. Hoisted to
// package level so the table is built once instead of on every call.
var mediaMimeTypes = map[string]string{
	// Image formats
	"jpg":  "image/jpeg",
	"jpeg": "image/jpeg",
	"png":  "image/png",
	"gif":  "image/gif",
	"webp": "image/webp",
	"bmp":  "image/bmp",
	"svg":  "image/svg+xml",
	"ico":  "image/x-icon",
	"heic": "image/heic",
	"heif": "image/heif",
	"tiff": "image/tiff",
	"tif":  "image/tiff",
	// Video formats
	"mp4":  "video/mp4",
	"mpeg": "video/mpeg",
	"mpg":  "video/mpeg",
	"mov":  "video/quicktime",
	"avi":  "video/x-msvideo",
	"wmv":  "video/x-ms-wmv",
	"webm": "video/webm",
	"mkv":  "video/x-matroska",
	"flv":  "video/x-flv",
	"3gp":  "video/3gpp",
	"3g2":  "video/3gpp2",
	"m4v":  "video/x-m4v",
	// Audio formats
	"mp3":  "audio/mpeg",
	"wav":  "audio/wav",
	"ogg":  "audio/ogg",
	"flac": "audio/flac",
	"aac":  "audio/aac",
	"m4a":  "audio/mp4",
	"wma":  "audio/x-ms-wma",
	"opus": "audio/opus",
	// Document formats
	"pdf": "application/pdf",
}

// detectMimeTypeFromURL guesses the MIME type of a media URL from its file
// extension (case-insensitive). Query parameters and fragment identifiers
// are stripped first. Unknown or missing extensions fall back to
// "application/octet-stream".
func detectMimeTypeFromURL(url string) string {
	// Drop query parameters and fragment identifiers.
	if idx := strings.Index(url, "?"); idx != -1 {
		url = url[:idx]
	}
	if idx := strings.Index(url, "#"); idx != -1 {
		url = url[:idx]
	}
	// Keep only the last path segment.
	if slash := strings.LastIndex(url, "/"); slash != -1 {
		url = url[slash+1:]
	}
	// Extract the extension; a missing or trailing dot means no extension.
	dot := strings.LastIndex(url, ".")
	if dot == -1 || dot == len(url)-1 {
		return "application/octet-stream"
	}
	if mimeType, ok := mediaMimeTypes[strings.ToLower(url[dot+1:])]; ok {
		return mimeType
	}
	return "application/octet-stream"
}

View File

@@ -0,0 +1,527 @@
package test
import (
"encoding/json"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// Test config: Basic Bedrock config with AWS Access Key/Secret Key (AWS Signature V4).
// Marshalling a static literal cannot fail, so the error is deliberately ignored.
var basicBedrockConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type":         "bedrock",
			"awsAccessKey": "test-ak-for-unit-test",
			"awsSecretKey": "test-sk-for-unit-test",
			"awsRegion":    "us-east-1",
			"modelMapping": map[string]string{
				"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
			},
		},
	})
	return data
}()

// Test config: Bedrock config with Bearer Token authentication.
var bedrockApiTokenConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type": "bedrock",
			"apiTokens": []string{
				"test-token-for-unit-test",
			},
			"awsRegion": "us-east-1",
			"modelMapping": map[string]string{
				"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
			},
		},
	})
	return data
}()

// Test config: Bedrock config with multiple Bearer Tokens and a
// per-model mapping entry in addition to the wildcard.
var bedrockMultiTokenConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type": "bedrock",
			"apiTokens": []string{
				"test-token-1-for-unit-test",
				"test-token-2-for-unit-test",
			},
			"awsRegion": "us-west-2",
			"modelMapping": map[string]string{
				"gpt-4": "anthropic.claude-3-opus-20240229-v1:0",
				"*":     "anthropic.claude-3-haiku-20240307-v1:0",
			},
		},
	})
	return data
}()

// Test config: Bedrock config with additional model-specific fields.
var bedrockWithAdditionalFieldsConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type":         "bedrock",
			"awsAccessKey": "test-ak-for-unit-test",
			"awsSecretKey": "test-sk-for-unit-test",
			"awsRegion":    "us-east-1",
			"bedrockAdditionalFields": map[string]interface{}{
				"top_k": 200,
			},
			"modelMapping": map[string]string{
				"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
			},
		},
	})
	return data
}()

// Test config: Invalid config - missing both apiTokens and ak/sk.
var bedrockInvalidConfigMissingAuth = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type":      "bedrock",
			"awsRegion": "us-east-1",
			"modelMapping": map[string]string{
				"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
			},
		},
	})
	return data
}()

// Test config: Invalid config - missing region.
var bedrockInvalidConfigMissingRegion = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type": "bedrock",
			"apiTokens": []string{
				"test-token-for-unit-test",
			},
			"modelMapping": map[string]string{
				"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
			},
		},
	})
	return data
}()

// Test config: Invalid config - only has access key without secret key.
var bedrockInvalidConfigPartialAkSk = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type":         "bedrock",
			"awsAccessKey": "test-ak-for-unit-test",
			"awsRegion":    "us-east-1",
			"modelMapping": map[string]string{
				"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
			},
		},
	})
	return data
}()
// RunBedrockParseConfigTests verifies Bedrock provider config parsing:
// valid configs (ak/sk, single/multiple bearer tokens, additional fields)
// must start the plugin successfully, while configs missing auth, region,
// or the secret half of an ak/sk pair must fail plugin start.
func RunBedrockParseConfigTests(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		// Test basic Bedrock config with AWS Signature V4 authentication
		t.Run("basic bedrock config with ak/sk", func(t *testing.T) {
			host, status := test.NewTestHost(basicBedrockConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})
		// Test Bedrock config with Bearer Token authentication
		t.Run("bedrock config with api token", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockApiTokenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})
		// Test Bedrock config with multiple tokens
		t.Run("bedrock config with multiple tokens", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockMultiTokenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})
		// Test Bedrock config with additional fields
		t.Run("bedrock config with additional fields", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockWithAdditionalFieldsConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})
		// Test invalid config - missing authentication
		t.Run("bedrock invalid config missing auth", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockInvalidConfigMissingAuth)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusFailed, status)
		})
		// Test invalid config - missing region
		t.Run("bedrock invalid config missing region", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockInvalidConfigMissingRegion)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusFailed, status)
		})
		// Test invalid config - partial ak/sk (only access key, no secret key)
		t.Run("bedrock invalid config partial ak/sk", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockInvalidConfigPartialAkSk)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusFailed, status)
		})
	})
}
// RunBedrockOnHttpRequestHeadersTests verifies request-header processing:
// for both ak/sk and bearer-token configs, the :authority header must be
// rewritten to the regional Bedrock runtime domain and header iteration
// must be stopped until the body is processed.
func RunBedrockOnHttpRequestHeadersTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Test Bedrock request headers processing with AWS Signature V4
		t.Run("bedrock chat completion request headers with ak/sk", func(t *testing.T) {
			host, status := test.NewTestHost(basicBedrockConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Set request headers
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Verify request headers
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			// Verify Host is changed to Bedrock service domain
			hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
			require.True(t, hasHost, "Host header should exist")
			require.Contains(t, hostValue, "bedrock-runtime.us-east-1.amazonaws.com", "Host should be changed to Bedrock service domain")
		})
		// Test Bedrock request headers processing with Bearer Token
		t.Run("bedrock chat completion request headers with api token", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockApiTokenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Set request headers
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Verify request headers
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			// Verify Host is changed to Bedrock service domain
			hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
			require.True(t, hasHost, "Host header should exist")
			require.Contains(t, hostValue, "bedrock-runtime.us-east-1.amazonaws.com", "Host should be changed to Bedrock service domain")
		})
	})
}
// RunBedrockOnHttpRequestBodyTests verifies request-body processing:
// bearer-token configs must emit a Bearer Authorization header, ak/sk
// configs must emit an AWS Signature V4 Authorization plus X-Amz-Date, and
// in all cases the path must be rewritten to the Bedrock /model/.../converse
// (or /converse-stream for streaming) endpoint.
func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Test Bedrock request body processing with Bearer Token authentication
		t.Run("bedrock chat completion request body with api token", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockApiTokenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Set request headers
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Set request body
			requestBody := `{
				"model": "gpt-4",
				"messages": [
					{
						"role": "user",
						"content": "Hello, how are you?"
					}
				],
				"temperature": 0.7
			}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Verify request headers for Bearer Token authentication
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			// Verify Authorization header uses Bearer token
			authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
			require.True(t, hasAuth, "Authorization header should exist")
			require.Contains(t, authValue, "Bearer ", "Authorization should use Bearer token")
			require.Contains(t, authValue, "test-token-for-unit-test", "Authorization should contain the configured token")
			// Verify path is transformed to Bedrock format
			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath, "Path header should exist")
			require.Contains(t, pathValue, "/model/", "Path should contain Bedrock model path")
			require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint")
		})
		// Test Bedrock request body processing with AWS Signature V4 authentication
		t.Run("bedrock chat completion request body with ak/sk", func(t *testing.T) {
			host, status := test.NewTestHost(basicBedrockConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Set request headers
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Set request body
			requestBody := `{
				"model": "gpt-4",
				"messages": [
					{
						"role": "user",
						"content": "Hello, how are you?"
					}
				],
				"temperature": 0.7
			}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Verify request headers for AWS Signature V4 authentication
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			// Verify Authorization header uses AWS Signature
			authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
			require.True(t, hasAuth, "Authorization header should exist")
			require.Contains(t, authValue, "AWS4-HMAC-SHA256", "Authorization should use AWS4-HMAC-SHA256 signature")
			require.Contains(t, authValue, "Credential=", "Authorization should contain Credential")
			require.Contains(t, authValue, "Signature=", "Authorization should contain Signature")
			// Verify X-Amz-Date header exists
			dateValue, hasDate := test.GetHeaderValue(requestHeaders, "X-Amz-Date")
			require.True(t, hasDate, "X-Amz-Date header should exist for AWS Signature V4")
			require.NotEmpty(t, dateValue, "X-Amz-Date should not be empty")
			// Verify path is transformed to Bedrock format
			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath, "Path header should exist")
			require.Contains(t, pathValue, "/model/", "Path should contain Bedrock model path")
			require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint")
		})
		// Test Bedrock streaming request
		t.Run("bedrock streaming request", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockApiTokenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Set request headers
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Set streaming request body
			requestBody := `{
				"model": "gpt-4",
				"messages": [
					{
						"role": "user",
						"content": "Hello"
					}
				],
				"stream": true
			}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Verify path is transformed to Bedrock streaming format
			requestHeaders := host.GetRequestHeaders()
			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath, "Path header should exist")
			require.Contains(t, pathValue, "/model/", "Path should contain Bedrock model path")
			require.Contains(t, pathValue, "/converse-stream", "Path should contain converse-stream endpoint for streaming")
		})
	})
}
// RunBedrockOnHttpResponseHeadersTests verifies that response headers pass
// through unchanged (status preserved) after a full request round-trip
// through the Bedrock provider.
func RunBedrockOnHttpResponseHeadersTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Test Bedrock response headers processing
		t.Run("bedrock response headers", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockApiTokenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Set request headers
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Set request body
			requestBody := `{
				"model": "gpt-4",
				"messages": [
					{
						"role": "user",
						"content": "Hello"
					}
				]
			}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Process response headers
			action = host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"Content-Type", "application/json"},
				{"X-Amzn-Requestid", "test-request-id-12345"},
			})
			require.Equal(t, types.ActionContinue, action)
			// Verify response headers
			responseHeaders := host.GetResponseHeaders()
			require.NotNil(t, responseHeaders)
			// Verify status code
			statusValue, hasStatus := test.GetHeaderValue(responseHeaders, ":status")
			require.True(t, hasStatus, "Status header should exist")
			require.Equal(t, "200", statusValue, "Status should be 200")
		})
	})
}
// RunBedrockOnHttpResponseBodyTests verifies that a Bedrock converse
// response body is transformed into OpenAI chat-completion format
// (choices + usage present). The upstream response property must be set
// so the plugin treats the response as coming from the upstream.
func RunBedrockOnHttpResponseBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Test Bedrock response body processing
		t.Run("bedrock response body", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockApiTokenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Set request headers
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Set request body
			requestBody := `{
				"model": "gpt-4",
				"messages": [
					{
						"role": "user",
						"content": "Hello"
					}
				]
			}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Set response property to ensure IsResponseFromUpstream() returns true
			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
			// Process response headers (must include :status 200 for body processing)
			action = host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.ActionContinue, action)
			// Process response body (Bedrock format)
			responseBody := `{
				"output": {
					"message": {
						"role": "assistant",
						"content": [
							{
								"text": "Hello! How can I help you today?"
							}
						]
					}
				},
				"stopReason": "end_turn",
				"usage": {
					"inputTokens": 10,
					"outputTokens": 15,
					"totalTokens": 25
				}
			}`
			action = host.CallOnHttpResponseBody([]byte(responseBody))
			require.Equal(t, types.ActionContinue, action)
			// Verify response body is transformed to OpenAI format
			transformedResponseBody := host.GetResponseBody()
			require.NotNil(t, transformedResponseBody)
			var responseMap map[string]interface{}
			err := json.Unmarshal(transformedResponseBody, &responseMap)
			require.NoError(t, err)
			// Verify choices exist in transformed response
			choices, exists := responseMap["choices"]
			require.True(t, exists, "Choices should exist in response body")
			require.NotNil(t, choices, "Choices should not be nil")
			// Verify usage exists
			usage, exists := responseMap["usage"]
			require.True(t, exists, "Usage should exist in response body")
			require.NotNil(t, usage, "Usage should not be nil")
		})
	})
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,10 @@
package util
import "regexp"
import (
	"regexp"
	"strconv"
	"strings"
	"unicode/utf16"
	"unicode/utf8"
)
func StripPrefix(s string, prefix string) string {
if len(prefix) != 0 && len(s) >= len(prefix) && s[0:len(prefix)] == prefix {
@@ -18,3 +22,43 @@ func MatchStatus(status string, patterns []string) bool {
}
return false
}
// unicodeEscapeRegex matches either a UTF-16 surrogate pair written as two
// consecutive \uXXXX escapes (high surrogate D800-DBFF followed by low
// surrogate DC00-DFFF), or a single \uXXXX escape. The pair alternative is
// listed first so Go's leftmost-first alternation prefers it.
var unicodeEscapeRegex = regexp.MustCompile(`\\u[dD][89abAB][0-9a-fA-F]{2}\\u[dD][c-fC-F][0-9a-fA-F]{2}|\\u([0-9a-fA-F]{4})`)

// DecodeUnicodeEscapes decodes Unicode escape sequences (\uXXXX) in a string to
// UTF-8 characters. This is useful when a JSON response contains ASCII-safe
// encoded non-ASCII characters.
//
// Behavior notes:
//   - Surrogate pairs (e.g. \ud83d\ude00) are combined into the single code
//     point they encode instead of producing two invalid runes.
//   - Escapes that decode to ASCII (< U+0080) are left untouched: decoding
//     e.g. \u0022 (double quote) or \u000a (newline) inside a JSON payload
//     would corrupt its structure.
//   - Lone surrogates are left encoded instead of becoming U+FFFD.
func DecodeUnicodeEscapes(input []byte) []byte {
	return unicodeEscapeRegex.ReplaceAllFunc(input, func(match []byte) []byte {
		if len(match) == 12 {
			// Surrogate pair: \uD8xx\uDCxx -> one supplementary-plane rune.
			hi, errHi := strconv.ParseUint(string(match[2:6]), 16, 32)
			lo, errLo := strconv.ParseUint(string(match[8:12]), 16, 32)
			if errHi != nil || errLo != nil {
				return match // unreachable given the regex; kept for safety
			}
			if r := utf16.DecodeRune(rune(hi), rune(lo)); r != utf8.RuneError {
				return []byte(string(r))
			}
			return match
		}
		codePoint, err := strconv.ParseUint(string(match[2:6]), 16, 32)
		if err != nil {
			return match // return original if parse fails
		}
		r := rune(codePoint)
		// Keep ASCII escapes (JSON structural/control characters) and lone
		// surrogates encoded; only decode printable non-ASCII code points.
		if r < 0x80 || (r >= 0xD800 && r <= 0xDFFF) {
			return match
		}
		return []byte(string(r))
	})
}
// DecodeUnicodeEscapesInSSE decodes Unicode escape sequences in SSE formatted
// data. Only the JSON payload of lines starting with "data: " is decoded;
// every other line (comments, event names, "[DONE]" markers, ...) is copied
// through verbatim. Line structure — including the absence of a trailing
// newline — is preserved exactly.
func DecodeUnicodeEscapesInSSE(input []byte) []byte {
	var out strings.Builder
	out.Grow(len(input))
	lines := strings.Split(string(input), "\n")
	last := len(lines) - 1
	for i, line := range lines {
		if payload, ok := strings.CutPrefix(line, "data: "); ok {
			// Decode escapes only inside the data payload.
			out.WriteString("data: ")
			out.Write(DecodeUnicodeEscapes([]byte(payload)))
		} else {
			out.WriteString(line)
		}
		// Re-insert separators between lines, but no newline after the last
		// split fragment (mirrors the original input's trailing bytes).
		if i != last {
			out.WriteByte('\n')
		}
	}
	return []byte(out.String())
}

View File

@@ -0,0 +1,108 @@
package util
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestDecodeUnicodeEscapes(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "Chinese characters",
input: `\u4e2d\u6587\u6d4b\u8bd5`,
expected: `中文测试`,
},
{
name: "Mixed content",
input: `Hello \u4e16\u754c World`,
expected: `Hello 世界 World`,
},
{
name: "No escape sequences",
input: `Hello World`,
expected: `Hello World`,
},
{
name: "JSON with Unicode escapes",
input: `{"content":"\u76c8\u5229\u80fd\u529b"}`,
expected: `{"content":"盈利能力"}`,
},
{
name: "Full width parentheses",
input: `\uff08\u76c8\u5229\uff09`,
expected: `(盈利)`,
},
{
name: "Empty string",
input: ``,
expected: ``,
},
{
name: "Invalid escape sequence (not modified)",
input: `\u00GG`,
expected: `\u00GG`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := DecodeUnicodeEscapes([]byte(tt.input))
assert.Equal(t, tt.expected, string(result))
})
}
}
// TestDecodeUnicodeEscapesInSSE verifies that only "data: " payload lines are
// decoded while comment/event/[DONE] lines and newline structure are preserved.
func TestDecodeUnicodeEscapesInSSE(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{
			name: "SSE data with Unicode escapes",
			input: `data: {"choices":[{"delta":{"content":"\u4e2d\u6587"}}]}
`,
			expected: `data: {"choices":[{"delta":{"content":"中文"}}]}
`,
		},
		{
			// Every data line is decoded independently; [DONE] has no escapes
			// and must pass through unchanged.
			name: "Multiple SSE data lines",
			input: `data: {"content":"\u4e2d\u6587"}
data: {"content":"\u82f1\u6587"}
data: [DONE]
`,
			expected: `data: {"content":"中文"}
data: {"content":"英文"}
data: [DONE]
`,
		},
		{
			// Lines without the "data: " prefix are copied verbatim.
			name:     "Non-data lines unchanged",
			input:    ": comment\nevent: message\ndata: test\n",
			expected: ": comment\nevent: message\ndata: test\n",
		},
		{
			// Regression fixture captured from a real Vertex AI stream chunk.
			name: "Real Vertex AI response format",
			input: `data: {"choices":[{"delta":{"content":"\uff08\u76c8\u5229\u80fd\u529b\uff09","role":"assistant"},"index":0}],"created":1768307454,"id":"test","model":"gemini","object":"chat.completion.chunk"}
`,
			expected: `data: {"choices":[{"delta":{"content":"(盈利能力)","role":"assistant"},"index":0}],"created":1768307454,"id":"test","model":"gemini","object":"chat.completion.chunk"}
`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := DecodeUnicodeEscapesInSSE([]byte(tt.input))
			assert.Equal(t, tt.expected, string(result))
		})
	}
}

View File

@@ -5,21 +5,21 @@ go 1.24.1
toolchain go1.24.4
require (
github.com/google/uuid v1.6.0
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -2,14 +2,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd h1:acTs8sqXf+qP+IypxFg3cu5Cluj7VT5BI+IDRlY5sag=
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d h1:LgYbzEBtg0+LEqoebQeMVgAB6H5SgqG+KN+gBhNfKbM=
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=

View File

@@ -6,7 +6,7 @@ toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
)

View File

@@ -4,8 +4,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c h1:DdVPyaMHSYBqO5jwB9Wl3PqsBGIf4u29BHMI0uIVB1Y=
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d h1:LgYbzEBtg0+LEqoebQeMVgAB6H5SgqG+KN+gBhNfKbM=
github.com/higress-group/wasm-go v1.0.10-0.20260120033417-1c84f010156d/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=

View File

@@ -83,7 +83,6 @@ global_threshold:
token_per_minute: 1000 # 自定义规则组每分钟1000个token
redis:
service_name: redis.static
show_limit_quota_header: true
```
### 识别请求参数 apikey进行区别限流

View File

@@ -89,7 +89,6 @@ global_threshold:
token_per_minute: 1000 # 1000 tokens per minute for the custom rule group
redis:
service_name: redis.static
show_limit_quota_header: true
```
### Identify request parameter apikey for differentiated rate limiting

View File

@@ -0,0 +1,2 @@
# Build the plugin as a WebAssembly module (wasip1 target, c-shared build mode).
build-go:
	GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o main.wasm main.go

View File

@@ -0,0 +1,61 @@
# 功能说明
`model-mapper`插件实现了基于LLM协议中的model参数映射的功能
# 配置字段
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
| `modelKey` | string | 选填 | model | 请求body中model参数的位置 |
| `modelMapping` | map of string | 选填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
| `enableOnPathSuffix` | array of string | 选填 | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | 只对这些特定路径后缀的请求生效 |
## 效果说明
如下配置
```yaml
modelMapping:
'gpt-4-*': "qwen-max"
'gpt-4o': "qwen-vl-plus"
'*': "qwen-turbo"
```
开启后,`gpt-4-` 开头的模型参数会被改写为 `qwen-max`, `gpt-4o` 会被改写为 `qwen-vl-plus`,其他所有模型会被改写为 `qwen-turbo`
例如原本的请求是:
```json
{
"model": "gpt-4o",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "higress项目主仓库的github地址是什么"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
经过这个插件后,原始的 LLM 请求体将被改成:
```json
{
"model": "qwen-vl-plus",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "higress项目主仓库的github地址是什么"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```

View File

@@ -0,0 +1,61 @@
# Function Description
The `model-mapper` plugin implements model parameter mapping functionality based on the LLM protocol.
# Configuration Fields
| Name | Type | Requirement | Default Value | Description |
| --- | --- | --- | --- | --- |
| `modelKey` | string | Optional | model | The position of the model parameter in the request body. |
| `modelMapping` | map of string | Optional | - | AI model mapping table, used to map the model name in the request to the model name supported by the service provider.<br/>1. Supports prefix matching. For example, use "gpt-3-*" to match all names starting with "gpt-3-";<br/>2. Supports using "*" as a key to configure a generic fallback mapping;<br/>3. If the target mapping name is an empty string "", it indicates keeping the original model name. |
| `enableOnPathSuffix` | array of string | Optional | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | Only effective for requests with these specific path suffixes. |
## Effect Description
Configuration example:
```yaml
modelMapping:
'gpt-4-*': "qwen-max"
'gpt-4o': "qwen-vl-plus"
'*': "qwen-turbo"
```
After enabling, model parameters starting with `gpt-4-` will be replaced with `qwen-max`, `gpt-4o` will be replaced with `qwen-vl-plus`, and all other models will be replaced with `qwen-turbo`.
For example, the original request is:
```json
{
"model": "gpt-4o",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the github address of the main repository of the higress project"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
After processing by this plugin, the original LLM request body will be modified to:
```json
{
"model": "qwen-vl-plus",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the github address of the main repository of the higress project"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```

View File

@@ -0,0 +1,24 @@
module github.com/alibaba/higress/plugins/wasm-go/extensions/model-mapper
go 1.24.1
toolchain go1.24.7
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -0,0 +1,30 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c h1:DdVPyaMHSYBqO5jwB9Wl3PqsBGIf4u29BHMI0uIVB1Y=
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,195 @@
package main
import (
"encoding/json"
"errors"
"sort"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
	// DefaultMaxBodyBytes caps how much request body the plugin buffers (100MB).
	DefaultMaxBodyBytes = 100 * 1024 * 1024 // 100MB
)

// main is required for the wasm build target; all wiring happens in init.
func main() {}

// init registers the plugin's config parser and HTTP callbacks with the
// wasm-go wrapper runtime.
func init() {
	wrapper.SetCtx(
		"model-mapper",
		wrapper.ParseConfig(parseConfig),
		wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
		wrapper.ProcessRequestBody(onHttpRequestBody),
		// NOTE(review): presumably these bound plugin memory by rebuilding the
		// VM after 1000 requests or 200MB — confirm against wrapper docs.
		wrapper.WithRebuildAfterRequests[Config](1000),
		wrapper.WithRebuildMaxMemBytes[Config](200*1024*1024),
	)
}
// ModelMapping is a single prefix-based rewrite rule: a model name starting
// with Prefix is replaced by Target (a Target of "" keeps the original name).
type ModelMapping struct {
	Prefix string
	Target string
}

// Config holds the parsed plugin configuration.
type Config struct {
	// modelKey is the gjson/sjson path of the model field in the request body.
	modelKey string
	// exactModelMapping maps a full model name to its replacement.
	exactModelMapping map[string]string
	// prefixModelMapping holds "name*" rules, stored in sorted key order.
	prefixModelMapping []ModelMapping
	// defaultModel is the "*" fallback mapping ("" means no fallback).
	defaultModel string
	// enableOnPathSuffix lists request path suffixes the plugin applies to.
	enableOnPathSuffix []string
}
// parseConfig populates config from the plugin's JSON configuration.
// Mapping keys are classified into the "*" fallback, "prefix*" rules, and
// exact names. Keys are processed in sorted order to replicate the
// alphabetical iteration of the previous C++ (nlohmann::json) implementation.
func parseConfig(json gjson.Result, config *Config) error {
	if config.modelKey = json.Get("modelKey").String(); config.modelKey == "" {
		config.modelKey = "model"
	}

	mapping := json.Get("modelMapping")
	if mapping.Exists() && !mapping.IsObject() {
		return errors.New("modelMapping must be an object")
	}
	config.exactModelMapping = map[string]string{}
	config.prefixModelMapping = []ModelMapping{}

	// Collect all keys, then walk them alphabetically.
	targets := map[string]string{}
	var keys []string
	mapping.ForEach(func(k, v gjson.Result) bool {
		keys = append(keys, k.String())
		targets[k.String()] = v.String()
		return true
	})
	sort.Strings(keys)
	for _, k := range keys {
		v := targets[k]
		switch {
		case k == "*":
			config.defaultModel = v
		case strings.HasSuffix(k, "*"):
			config.prefixModelMapping = append(config.prefixModelMapping, ModelMapping{
				Prefix: strings.TrimSuffix(k, "*"),
				Target: v,
			})
		default:
			config.exactModelMapping[k] = v
		}
	}

	suffixes := json.Get("enableOnPathSuffix")
	if !suffixes.Exists() {
		// Default to the common OpenAI-compatible endpoint suffixes.
		config.enableOnPathSuffix = []string{
			"/completions",
			"/embeddings",
			"/images/generations",
			"/audio/speech",
			"/fine_tuning/jobs",
			"/moderations",
			"/image-synthesis",
			"/video-synthesis",
			"/rerank",
			"/messages",
		}
		return nil
	}
	if !suffixes.IsArray() {
		return errors.New("enableOnPathSuffix must be an array")
	}
	for _, item := range suffixes.Array() {
		config.enableOnPathSuffix = append(config.enableOnPathSuffix, item.String())
	}
	return nil
}
// onHttpRequestHeaders decides whether this request should be inspected:
// only requests whose path (query string excluded) ends with one of the
// configured suffixes, and that actually carry a body, are processed.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action {
	path, err := proxywasm.GetHttpRequestHeader(":path")
	if err != nil {
		return types.ActionContinue
	}
	// Match suffixes against the path only, not the query string.
	if qs := strings.Index(path, "?"); qs >= 0 {
		path = path[:qs]
	}
	enabled := false
	for _, suffix := range config.enableOnPathSuffix {
		if strings.HasSuffix(path, suffix) {
			enabled = true
			break
		}
	}
	if !enabled || !ctx.HasRequestBody() {
		ctx.DontReadRequestBody()
		return types.ActionContinue
	}
	// The body may be rewritten, so the original content-length would be wrong.
	proxywasm.RemoveHttpRequestHeader("content-length")
	// Allow buffering request bodies up to 100MB.
	ctx.SetRequestBodyBufferLimit(DefaultMaxBodyBytes)
	// Hold header processing until the body phase has run.
	return types.HeaderStopIteration
}
// onHttpRequestBody rewrites the model field of the JSON request body
// according to the configured mappings. Resolution order: exact match first,
// then the first matching prefix rule, then the "*" fallback. A mapping whose
// target is "" keeps the original model name, and the body is only replaced
// when the resolved name actually differs.
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
	if len(body) == 0 {
		return types.ActionContinue
	}
	if !json.Valid(body) {
		log.Error("invalid json body")
		return types.ActionContinue
	}
	oldModel := gjson.GetBytes(body, config.modelKey).String()
	// Start from the "*" fallback; with no fallback, keep the original name.
	newModel := config.defaultModel
	if newModel == "" {
		newModel = oldModel
	}
	// Exact match wins over prefix rules.
	if target, ok := config.exactModelMapping[oldModel]; ok {
		newModel = target
	} else {
		// First matching prefix rule (rules are in sorted key order).
		for _, mapping := range config.prefixModelMapping {
			if strings.HasPrefix(oldModel, mapping.Prefix) {
				newModel = mapping.Target
				break
			}
		}
	}
	// newModel == "" means "keep the original model name" per the config contract.
	if newModel != "" && newModel != oldModel {
		newBody, err := sjson.SetBytes(body, config.modelKey, newModel)
		if err != nil {
			log.Errorf("failed to update model: %v", err)
			return types.ActionContinue
		}
		// Surface host-side failures instead of silently dropping the rewrite.
		if err := proxywasm.ReplaceHttpRequestBody(newBody); err != nil {
			log.Errorf("failed to replace request body: %v", err)
			return types.ActionContinue
		}
		log.Debugf("model mapped, before: %s, after: %s", oldModel, newModel)
	}
	return types.ActionContinue
}

View File

@@ -0,0 +1,250 @@
package main
import (
"encoding/json"
"strings"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// Basic configs for wasm test host
var (
	// basicConfig: a single exact mapping, enabled on one path suffix.
	basicConfig = func() json.RawMessage {
		data, _ := json.Marshal(map[string]interface{}{
			"modelKey": "model",
			"modelMapping": map[string]string{
				"gpt-3.5-turbo": "gpt-4",
			},
			"enableOnPathSuffix": []string{
				"/v1/chat/completions",
			},
		})
		return data
	}()
	// customConfig: nested modelKey plus "*" fallback, prefix and exact rules.
	customConfig = func() json.RawMessage {
		data, _ := json.Marshal(map[string]interface{}{
			"modelKey": "request.model",
			"modelMapping": map[string]string{
				"*":          "gpt-4o",
				"gpt-3.5*":   "gpt-4-mini",
				"gpt-3.5-t":  "gpt-4-turbo",
				"gpt-3.5-t1": "gpt-4-turbo-1",
			},
			"enableOnPathSuffix": []string{
				"/v1/chat/completions",
				"/v1/embeddings",
			},
		})
		return data
	}()
)
// TestParseConfig exercises parseConfig: defaults, mapping classification
// (exact / prefix / "*" fallback), and the two validation error paths.
func TestParseConfig(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		t.Run("basic config with defaults", func(t *testing.T) {
			var cfg Config
			jsonData := []byte(`{
				"modelMapping": {
					"gpt-3.5-turbo": "gpt-4",
					"gpt-4*": "gpt-4o-mini",
					"*": "gpt-4o"
				}
			}`)
			err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
			require.NoError(t, err)
			// default modelKey
			require.Equal(t, "model", cfg.modelKey)
			// exact mapping
			require.Equal(t, "gpt-4", cfg.exactModelMapping["gpt-3.5-turbo"])
			// prefix mapping ("gpt-4*" -> prefix "gpt-4")
			require.Len(t, cfg.prefixModelMapping, 1)
			require.Equal(t, "gpt-4", cfg.prefixModelMapping[0].Prefix)
			// default model taken from the "*" key
			require.Equal(t, "gpt-4o", cfg.defaultModel)
			// default enabled path suffixes
			require.Contains(t, cfg.enableOnPathSuffix, "/completions")
			require.Contains(t, cfg.enableOnPathSuffix, "/embeddings")
		})
		t.Run("custom modelKey and enableOnPathSuffix", func(t *testing.T) {
			var cfg Config
			jsonData := []byte(`{
				"modelKey": "request.model",
				"modelMapping": {
					"gpt-3.5-turbo": "gpt-4",
					"gpt-3.5*": "gpt-4-mini"
				},
				"enableOnPathSuffix": ["/v1/chat/completions", "/v1/embeddings"]
			}`)
			err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
			require.NoError(t, err)
			require.Equal(t, "request.model", cfg.modelKey)
			require.Equal(t, "gpt-4", cfg.exactModelMapping["gpt-3.5-turbo"])
			require.Len(t, cfg.prefixModelMapping, 1)
			require.Equal(t, "gpt-3.5", cfg.prefixModelMapping[0].Prefix)
			require.Equal(t, "gpt-4-mini", cfg.prefixModelMapping[0].Target)
			// Explicit suffix list fully replaces the defaults.
			require.Equal(t, 2, len(cfg.enableOnPathSuffix))
			require.Contains(t, cfg.enableOnPathSuffix, "/v1/chat/completions")
			require.Contains(t, cfg.enableOnPathSuffix, "/v1/embeddings")
		})
		t.Run("modelMapping must be object", func(t *testing.T) {
			var cfg Config
			jsonData := []byte(`{
				"modelMapping": "invalid"
			}`)
			err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
			require.Error(t, err)
		})
		t.Run("enableOnPathSuffix must be array", func(t *testing.T) {
			var cfg Config
			jsonData := []byte(`{
				"enableOnPathSuffix": "not-array"
			}`)
			err := parseConfig(gjson.ParseBytes(jsonData), &cfg)
			require.Error(t, err)
		})
	})
}
// TestOnHttpRequestHeaders checks the header phase's path gating: unmatched
// paths pass through untouched, matched paths get content-length stripped.
func TestOnHttpRequestHeaders(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		t.Run("skip when path not matched", func(t *testing.T) {
			host, status := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			originalHeaders := [][2]string{
				{":authority", "example.com"},
				{":path", "/v1/other"},
				{":method", "POST"},
				{"content-type", "application/json"},
				{"content-length", "123"},
			}
			action := host.CallOnHttpRequestHeaders(originalHeaders)
			require.Equal(t, types.ActionContinue, action)
			newHeaders := host.GetRequestHeaders()
			// content-length should still exist because path is not enabled
			foundContentLength := false
			for _, h := range newHeaders {
				if strings.ToLower(h[0]) == "content-length" {
					foundContentLength = true
					break
				}
			}
			require.True(t, foundContentLength)
		})
		t.Run("process when path and content-type match", func(t *testing.T) {
			host, status := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			originalHeaders := [][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"content-type", "application/json"},
				{"content-length", "123"},
			}
			action := host.CallOnHttpRequestHeaders(originalHeaders)
			// Plugin pauses header iteration to wait for the body phase.
			require.Equal(t, types.HeaderStopIteration, action)
			newHeaders := host.GetRequestHeaders()
			// content-length should be removed (body may be rewritten)
			for _, h := range newHeaders {
				require.NotEqual(t, strings.ToLower(h[0]), "content-length")
			}
		})
	})
}
// TestOnHttpRequestBody_ModelMapping covers the three rewrite paths:
// exact mapping, "*" fallback when the model key is absent, and prefix rules
// against a nested modelKey ("request.model").
func TestOnHttpRequestBody_ModelMapping(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		t.Run("exact mapping", func(t *testing.T) {
			host, status := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"content-type", "application/json"},
			})
			origBody := []byte(`{
				"model": "gpt-3.5-turbo",
				"messages": [{"role": "user", "content": "hello"}]
			}`)
			action := host.CallOnHttpRequestBody(origBody)
			require.Equal(t, types.ActionContinue, action)
			processed := host.GetRequestBody()
			require.NotNil(t, processed)
			require.Equal(t, "gpt-4", gjson.GetBytes(processed, "model").String())
		})
		t.Run("default model when key missing", func(t *testing.T) {
			// use customConfig where default model is set with "*"
			host, status := test.NewTestHost(customConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"content-type", "application/json"},
			})
			origBody := []byte(`{
				"request": {
					"messages": [{"role": "user", "content": "hello"}]
				}
			}`)
			action := host.CallOnHttpRequestBody(origBody)
			require.Equal(t, types.ActionContinue, action)
			processed := host.GetRequestBody()
			require.NotNil(t, processed)
			// default model should be set at request.model
			require.Equal(t, "gpt-4o", gjson.GetBytes(processed, "request.model").String())
		})
		t.Run("prefix mapping takes effect", func(t *testing.T) {
			host, status := test.NewTestHost(customConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"content-type", "application/json"},
			})
			// "gpt-3.5-turbo-16k" has no exact mapping but matches "gpt-3.5*".
			origBody := []byte(`{
				"request": {
					"model": "gpt-3.5-turbo-16k",
					"messages": [{"role": "user", "content": "hello"}]
				}
			}`)
			action := host.CallOnHttpRequestBody(origBody)
			require.Equal(t, types.ActionContinue, action)
			processed := host.GetRequestBody()
			require.NotNil(t, processed)
			require.Equal(t, "gpt-4-mini", gjson.GetBytes(processed, "request.model").String())
		})
	})
}

View File

@@ -0,0 +1,2 @@
# Build the plugin as a WebAssembly module (wasip1 target, c-shared build mode).
build-go:
	GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o main.wasm main.go

View File

@@ -0,0 +1,98 @@
## 功能说明
`model-router`插件实现了基于LLM协议中的model参数路由的功能
## 配置字段
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
| `modelKey` | string | 选填 | model | 请求body中model参数的位置 |
| `addProviderHeader` | string | 选填 | - | 从model参数中解析出的provider名字放到哪个请求header中 |
| `modelToHeader` | string | 选填 | - | 直接将model参数放到哪个请求header中 |
| `enableOnPathSuffix` | array of string | 选填 | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | 只对这些特定路径后缀的请求生效,可以配置为 "*" 以匹配所有路径 |
## 运行属性
插件执行阶段:认证阶段
插件执行优先级:900
## 效果说明
### 基于 model 参数进行路由
需要做如下配置:
```yaml
modelToHeader: x-higress-llm-model
```
插件会将请求中 model 参数提取出来,设置到 x-higress-llm-model 这个请求 header 中,用于后续路由,举例来说,原生的 LLM 请求体是:
```json
{
"model": "qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "higress项目主仓库的github地址是什么"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
经过这个插件后,将添加下面这个请求头(可以用于路由匹配)
x-higress-llm-model: qwen-long
### 提取 model 参数中的 provider 字段用于路由
> 注意这种模式需要客户端在 model 参数中通过`/`分隔的方式,来指定 provider
需要做如下配置:
```yaml
addProviderHeader: x-higress-llm-provider
```
插件会将请求中 model 参数的 provider 部分(如果有)提取出来,设置到 x-higress-llm-provider 这个请求 header 中,用于后续路由,并将 model 参数重写为模型名称部分。举例来说,原生的 LLM 请求体是:
```json
{
"model": "dashscope/qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "higress项目主仓库的github地址是什么"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
经过这个插件后,将添加下面这个请求头(可以用于路由匹配)
x-higress-llm-provider: dashscope
原始的 LLM 请求体将被改成:
```json
{
"model": "qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "higress项目主仓库的github地址是什么"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```

View File

@@ -0,0 +1,97 @@
## Feature Description
The `model-router` plugin implements routing functionality based on the model parameter in LLM protocols.
## Configuration Fields
| Name | Data Type | Requirement | Default Value | Description |
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
| `modelKey` | string | Optional | model | Location of the model parameter in the request body |
| `addProviderHeader` | string | Optional | - | Which request header to add the provider name parsed from the model parameter |
| `modelToHeader` | string | Optional | - | Which request header to directly add the model parameter to |
| `enableOnPathSuffix` | array of string | Optional | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis","/rerank","/messages"] | Only effective for requests with these specific path suffixes, can be configured as "*" to match all paths |
## Runtime Properties
Plugin execution phase: Authentication phase
Plugin execution priority: 900
## Effect Description
### Routing Based on Model Parameter
The following configuration is needed:
```yaml
modelToHeader: x-higress-llm-model
```
The plugin extracts the model parameter from the request and sets it to the x-higress-llm-model request header for subsequent routing. For example, the original LLM request body is:
```json
{
"model": "qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address of the Higress project's main repository?"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
After processing by this plugin, the following request header will be added (can be used for route matching):
x-higress-llm-model: qwen-long
### Extracting Provider Field from Model Parameter for Routing
> Note that this mode requires the client to specify the provider in the model parameter using the `/` delimiter
The following configuration is needed:
```yaml
addProviderHeader: x-higress-llm-provider
```
The plugin extracts the provider part (if any) from the model parameter in the request, sets it to the x-higress-llm-provider request header for subsequent routing, and rewrites the model parameter to only contain the model name part. For example, the original LLM request body is:
```json
{
"model": "dashscope/qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address of the Higress project's main repository?"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
After processing by this plugin, the following request header will be added (can be used for route matching):
x-higress-llm-provider: dashscope
The original LLM request body will be changed to:
```json
{
"model": "qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address of the Higress project's main repository?"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}

View File

@@ -0,0 +1,24 @@
module model-router
go 1.24.1
toolchain go1.24.7
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -0,0 +1,30 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c h1:DdVPyaMHSYBqO5jwB9Wl3PqsBGIf4u29BHMI0uIVB1Y=
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,259 @@
package main
import (
"bytes"
"encoding/json"
"io"
"mime"
"mime/multipart"
"net/http"
"net/textproto"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
	// DefaultMaxBodyBytes caps how much request body the plugin buffers
	// for inspection (passed to SetRequestBodyBufferLimit below).
	DefaultMaxBodyBytes = 100 * 1024 * 1024 // 100MB
)
// main is intentionally empty: a proxy-wasm module registers its handlers
// in init() and is driven entirely by host callbacks.
func main() {}
func init() {
	// Register the plugin with the wasm-go wrapper: configuration parsing
	// plus the request header and request body hooks.
	// NOTE(review): the WithRebuild* options appear to bound VM reuse
	// (rebuild after 1000 requests or 200MB of memory) — confirm against
	// the wasm-go wrapper documentation.
	wrapper.SetCtx(
		"model-router",
		wrapper.ParseConfig(parseConfig),
		wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
		wrapper.ProcessRequestBody(onHttpRequestBody),
		wrapper.WithRebuildAfterRequests[ModelRouterConfig](1000),
		wrapper.WithRebuildMaxMemBytes[ModelRouterConfig](200*1024*1024),
	)
}
// ModelRouterConfig holds the parsed plugin configuration.
type ModelRouterConfig struct {
	modelKey           string   // body field carrying the model name; defaults to "model"
	addProviderHeader  string   // header to receive the provider prefix parsed from "provider/model"; empty disables
	modelToHeader      string   // header to receive the raw model value; empty disables
	enableOnPathSuffix []string // path suffixes the plugin is active for; "*" matches all paths
}
// parseConfig populates config from the plugin's JSON configuration,
// falling back to "model" for modelKey and to the built-in list of
// OpenAI-style path suffixes when enableOnPathSuffix is not a JSON array.
func parseConfig(json gjson.Result, config *ModelRouterConfig) error {
	if key := json.Get("modelKey").String(); key != "" {
		config.modelKey = key
	} else {
		config.modelKey = "model"
	}
	config.addProviderHeader = json.Get("addProviderHeader").String()
	config.modelToHeader = json.Get("modelToHeader").String()

	if suffixes := json.Get("enableOnPathSuffix"); suffixes.Exists() && suffixes.IsArray() {
		for _, entry := range suffixes.Array() {
			config.enableOnPathSuffix = append(config.enableOnPathSuffix, entry.String())
		}
		return nil
	}
	// No explicit list configured: enable on the common LLM API suffixes.
	config.enableOnPathSuffix = []string{
		"/completions",
		"/embeddings",
		"/images/generations",
		"/audio/speech",
		"/fine_tuning/jobs",
		"/moderations",
		"/image-synthesis",
		"/video-synthesis",
		"/rerank",
		"/messages",
	}
	return nil
}
// onHttpRequestHeaders gates the plugin per request: when the request path
// matches a configured suffix and a body is present, it prepares body
// buffering (dropping content-length, since the body may be rewritten);
// otherwise the request passes through untouched.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ModelRouterConfig) types.Action {
	rawPath, err := proxywasm.GetHttpRequestHeader(":path")
	if err != nil {
		return types.ActionContinue
	}
	// Match against the path only, ignoring any query string.
	path, _, _ := strings.Cut(rawPath, "?")

	matched := false
	for _, suffix := range config.enableOnPathSuffix {
		if suffix == "*" || strings.HasSuffix(path, suffix) {
			matched = true
			break
		}
	}
	if !matched || !ctx.HasRequestBody() {
		ctx.DontReadRequestBody()
		return types.ActionContinue
	}

	// Buffer the body for inspection; the rewrite may change its length.
	proxywasm.RemoveHttpRequestHeader("content-length")
	ctx.SetRequestBodyBufferLimit(DefaultMaxBodyBytes) // 100MB cap
	return types.HeaderStopIteration
}
// onHttpRequestBody dispatches body handling by content type: JSON and
// multipart/form-data bodies are inspected, anything else passes through.
func onHttpRequestBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []byte) types.Action {
	contentType, err := proxywasm.GetHttpRequestHeader("content-type")
	if err != nil {
		return types.ActionContinue
	}
	switch {
	case strings.Contains(contentType, "application/json"):
		return handleJsonBody(ctx, config, body)
	case strings.Contains(contentType, "multipart/form-data"):
		return handleMultipartBody(ctx, config, body, contentType)
	default:
		return types.ActionContinue
	}
}
// handleJsonBody reads the model field from a JSON body, mirrors its raw
// value into modelToHeader (when configured), and — when addProviderHeader
// is configured and the value looks like "provider/model" — publishes the
// provider in that header and rewrites the body's model field to the bare
// model name.
func handleJsonBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []byte) types.Action {
	if !json.Valid(body) {
		log.Error("invalid json body")
		return types.ActionContinue
	}
	modelValue := gjson.GetBytes(body, config.modelKey).String()
	if modelValue == "" {
		return types.ActionContinue
	}
	if config.modelToHeader != "" {
		_ = proxywasm.ReplaceHttpRequestHeader(config.modelToHeader, modelValue)
	}
	if config.addProviderHeader == "" {
		return types.ActionContinue
	}
	provider, model, hasProvider := strings.Cut(modelValue, "/")
	if !hasProvider {
		log.Debugf("model route to provider not work, model: %s", modelValue)
		return types.ActionContinue
	}
	_ = proxywasm.ReplaceHttpRequestHeader(config.addProviderHeader, provider)
	newBody, err := sjson.SetBytes(body, config.modelKey, model)
	if err != nil {
		log.Errorf("failed to update model in json body: %v", err)
		return types.ActionContinue
	}
	_ = proxywasm.ReplaceHttpRequestBody(newBody)
	log.Debugf("model route to provider: %s, model: %s", provider, model)
	return types.ActionContinue
}
// handleMultipartBody extracts the model form field from a multipart body,
// mirrors its raw value into modelToHeader (when configured), and — when
// addProviderHeader is configured and the value has a "provider/model"
// shape — publishes the provider in that header, rewrites the field to the
// bare model name, and replaces the request body with the rebuilt payload.
// Any parse/rebuild failure leaves the original body untouched.
func handleMultipartBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []byte, contentType string) types.Action {
	_, params, err := mime.ParseMediaType(contentType)
	if err != nil {
		log.Errorf("failed to parse content type: %v", err)
		return types.ActionContinue
	}
	boundary, ok := params["boundary"]
	if !ok {
		log.Errorf("no boundary in content type")
		return types.ActionContinue
	}
	reader := multipart.NewReader(bytes.NewReader(body), boundary)
	var newBody bytes.Buffer
	writer := multipart.NewWriter(&newBody)
	// Keep the original boundary so the content-type header stays valid.
	writer.SetBoundary(boundary)
	modified := false
	for {
		part, err := reader.NextPart()
		if err == io.EOF {
			break
		}
		if err != nil {
			log.Errorf("failed to read multipart part: %v", err)
			return types.ActionContinue
		}
		partContent, err := io.ReadAll(part)
		if err != nil {
			log.Errorf("failed to read part content: %v", err)
			return types.ActionContinue
		}
		// By default the part is copied through unchanged; only the model
		// field may have its content rewritten below.
		content := partContent
		if part.FormName() == config.modelKey {
			modelValue := string(partContent)
			if config.modelToHeader != "" {
				_ = proxywasm.ReplaceHttpRequestHeader(config.modelToHeader, modelValue)
			}
			if config.addProviderHeader != "" {
				parts := strings.SplitN(modelValue, "/", 2)
				if len(parts) == 2 {
					provider, model := parts[0], parts[1]
					_ = proxywasm.ReplaceHttpRequestHeader(config.addProviderHeader, provider)
					content = []byte(model)
					modified = true
					log.Debugf("model route to provider: %s, model: %s", provider, model)
				} else {
					log.Debugf("model route to provider not work, model: %s", modelValue)
				}
			}
		}
		if err := writeMultipartPart(writer, part.Header, content); err != nil {
			log.Errorf("failed to rebuild multipart part: %v", err)
			return types.ActionContinue
		}
	}
	// Close writes the terminating boundary; if it fails the rebuilt body
	// is invalid, so do not swap it in.
	if err := writer.Close(); err != nil {
		log.Errorf("failed to finalize multipart body: %v", err)
		return types.ActionContinue
	}
	if modified {
		_ = proxywasm.ReplaceHttpRequestBody(newBody.Bytes())
	}
	return types.ActionContinue
}

// writeMultipartPart appends one part to the rebuilt multipart body,
// carrying over the original part headers and writing the given content.
func writeMultipartPart(writer *multipart.Writer, header textproto.MIMEHeader, content []byte) error {
	// Defensive shallow copy of the header map before handing it back to
	// the writer.
	h := make(http.Header, len(header))
	for k, v := range header {
		h[k] = v
	}
	pw, err := writer.CreatePart(textproto.MIMEHeader(h))
	if err != nil {
		return err
	}
	_, err = pw.Write(content)
	return err
}

View File

@@ -0,0 +1,288 @@
package main
import (
"bytes"
"encoding/json"
"io"
"mime/multipart"
"strings"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// Basic configs for the wasm test host.
var (
	// basicConfig enables the plugin only for /v1/chat/completions and
	// routes the model value into the x-model / x-provider headers.
	basicConfig = func() json.RawMessage {
		data, _ := json.Marshal(map[string]interface{}{
			"modelKey":          "model",
			"addProviderHeader": "x-provider",
			"modelToHeader":     "x-model",
			"enableOnPathSuffix": []string{
				"/v1/chat/completions",
			},
		})
		return data
	}()
	// defaultSuffixConfig omits enableOnPathSuffix so parseConfig falls
	// back to its built-in suffix list.
	defaultSuffixConfig = func() json.RawMessage {
		data, _ := json.Marshal(map[string]interface{}{
			"modelKey":          "model",
			"addProviderHeader": "x-provider",
			"modelToHeader":     "x-model",
		})
		return data
	}()
)
// getHeader performs a case-insensitive lookup of key in the header pair
// list and reports the first matching value, if any.
func getHeader(headers [][2]string, key string) (string, bool) {
	for i := range headers {
		if strings.EqualFold(headers[i][0], key) {
			return headers[i][1], true
		}
	}
	return "", false
}
// TestParseConfig verifies both the default handling (missing modelKey and
// enableOnPathSuffix) and explicit overrides of the plugin configuration.
func TestParseConfig(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		t.Run("basic config with defaults", func(t *testing.T) {
			cfg := ModelRouterConfig{}
			require.NoError(t, parseConfig(gjson.ParseBytes(defaultSuffixConfig), &cfg))
			// modelKey falls back to "model" when absent.
			require.Equal(t, "model", cfg.modelKey)
			require.Equal(t, "x-provider", cfg.addProviderHeader)
			require.Equal(t, "x-model", cfg.modelToHeader)
			// Omitting enableOnPathSuffix yields the built-in list, which
			// includes the common OpenAI-style suffixes.
			require.Contains(t, cfg.enableOnPathSuffix, "/completions")
			require.Contains(t, cfg.enableOnPathSuffix, "/embeddings")
		})
		t.Run("custom enableOnPathSuffix", func(t *testing.T) {
			raw := `{
				"modelKey": "my_model",
				"addProviderHeader": "x-prov",
				"modelToHeader": "x-mod",
				"enableOnPathSuffix": ["/foo", "/bar"]
			}`
			cfg := ModelRouterConfig{}
			require.NoError(t, parseConfig(gjson.Parse(raw), &cfg))
			require.Equal(t, "my_model", cfg.modelKey)
			require.Equal(t, "x-prov", cfg.addProviderHeader)
			require.Equal(t, "x-mod", cfg.modelToHeader)
			require.Equal(t, []string{"/foo", "/bar"}, cfg.enableOnPathSuffix)
		})
	})
}
// TestOnHttpRequestHeaders covers the header-phase gating logic: whether
// the plugin begins buffering the body (removing content-length) or lets
// the request pass through untouched.
func TestOnHttpRequestHeaders(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		t.Run("skip when path not matched", func(t *testing.T) {
			host, status := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			originalHeaders := [][2]string{
				{":authority", "example.com"},
				{":path", "/v1/other"},
				{":method", "POST"},
				{"content-type", "application/json"},
				{"content-length", "123"},
			}
			action := host.CallOnHttpRequestHeaders(originalHeaders)
			// Path does not match the configured suffix: no buffering.
			require.Equal(t, types.ActionContinue, action)
			newHeaders := host.GetRequestHeaders()
			_, found := getHeader(newHeaders, "content-length")
			require.True(t, found, "content-length should be kept when path not enabled")
		})
		t.Run("process when path and content-type match", func(t *testing.T) {
			host, status := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			originalHeaders := [][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"content-type", "application/json"},
				{"content-length", "123"},
			}
			action := host.CallOnHttpRequestHeaders(originalHeaders)
			// Matching path with a body: buffering starts, headers pause.
			require.Equal(t, types.HeaderStopIteration, action)
			newHeaders := host.GetRequestHeaders()
			_, found := getHeader(newHeaders, "content-length")
			require.False(t, found, "content-length should be removed when buffering body")
		})
		t.Run("do not process for unsupported content-type", func(t *testing.T) {
			host, status := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			originalHeaders := [][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"content-type", "text/plain"},
				{"content-length", "123"},
			}
			action := host.CallOnHttpRequestHeaders(originalHeaders)
			// The header phase gates on path only; content-type is not
			// consulted until the body phase, so buffering still starts.
			require.Equal(t, types.HeaderStopIteration, action)
			newHeaders := host.GetRequestHeaders()
			_, found := getHeader(newHeaders, "content-length")
			// Fixed assertion message: the original text claimed the header
			// "should not be removed", contradicting require.False (which
			// asserts it WAS removed).
			require.False(t, found, "content-length is removed even for unsupported content-type; header phase ignores content-type")
		})
	})
}
// TestOnHttpRequestBody_JSON verifies the JSON body path: the model value
// is mirrored into x-model, a "provider/model" value is split into the
// x-provider header plus a rewritten model field, and a body without a
// model field is left untouched.
func TestOnHttpRequestBody_JSON(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		t.Run("set headers and rewrite model when provider/model format", func(t *testing.T) {
			host, status := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Headers must be fed first so the body phase sees content-type.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"content-type", "application/json"},
			})
			origBody := []byte(`{
				"model": "openai/gpt-4o",
				"messages": [{"role": "user", "content": "hello"}]
			}`)
			action := host.CallOnHttpRequestBody(origBody)
			require.Equal(t, types.ActionContinue, action)
			processed := host.GetRequestBody()
			require.NotNil(t, processed)
			// model should be rewritten to only the model part
			require.Equal(t, "gpt-4o", gjson.GetBytes(processed, "model").String())
			headers := host.GetRequestHeaders()
			// modelToHeader carries the RAW value, including the provider.
			hv, found := getHeader(headers, "x-model")
			require.True(t, found)
			require.Equal(t, "openai/gpt-4o", hv)
			pv, found := getHeader(headers, "x-provider")
			require.True(t, found)
			require.Equal(t, "openai", pv)
		})
		t.Run("no change when model not provided", func(t *testing.T) {
			host, status := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"content-type", "application/json"},
			})
			origBody := []byte(`{
				"messages": [{"role": "user", "content": "hello"}]
			}`)
			action := host.CallOnHttpRequestBody(origBody)
			require.Equal(t, types.ActionContinue, action)
			processed := host.GetRequestBody()
			// The plugin should not touch the body: the host may report it
			// as nil (untouched) or echo the original bytes back.
			if processed != nil {
				require.JSONEq(t, string(origBody), string(processed))
			}
			_, found := getHeader(host.GetRequestHeaders(), "x-provider")
			require.False(t, found)
		})
	})
}
// TestOnHttpRequestBody_Multipart verifies the multipart/form-data path:
// the model form field is rewritten to the bare model name, other fields
// are preserved, and the x-model / x-provider headers are populated.
func TestOnHttpRequestBody_Multipart(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		host, status := test.NewTestHost(basicConfig)
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)
		// Build a multipart body containing a "provider/model" model field.
		var buf bytes.Buffer
		writer := multipart.NewWriter(&buf)
		// model field
		modelWriter, err := writer.CreateFormField("model")
		require.NoError(t, err)
		_, err = modelWriter.Write([]byte("openai/gpt-4o"))
		require.NoError(t, err)
		// another field to ensure others are preserved
		fileWriter, err := writer.CreateFormField("prompt")
		require.NoError(t, err)
		_, err = fileWriter.Write([]byte("hello"))
		require.NoError(t, err)
		err = writer.Close()
		require.NoError(t, err)
		contentType := "multipart/form-data; boundary=" + writer.Boundary()
		// Headers must be fed first so the body phase sees the boundary.
		host.CallOnHttpRequestHeaders([][2]string{
			{":authority", "example.com"},
			{":path", "/v1/chat/completions"},
			{":method", "POST"},
			{"content-type", contentType},
		})
		action := host.CallOnHttpRequestBody(buf.Bytes())
		require.Equal(t, types.ActionContinue, action)
		processed := host.GetRequestBody()
		require.NotNil(t, processed)
		// Parse multipart body again to verify fields; the plugin keeps
		// the original boundary, so the same writer boundary works here.
		reader := multipart.NewReader(bytes.NewReader(processed), writer.Boundary())
		foundModel := false
		foundPrompt := false
		for {
			part, err := reader.NextPart()
			if err != nil {
				break
			}
			name := part.FormName()
			data, err := io.ReadAll(part)
			require.NoError(t, err)
			switch name {
			case "model":
				// Rewritten to the bare model name.
				foundModel = true
				require.Equal(t, "gpt-4o", string(data))
			case "prompt":
				// Untouched sibling field.
				foundPrompt = true
				require.Equal(t, "hello", string(data))
			}
		}
		require.True(t, foundModel)
		require.True(t, foundPrompt)
		headers := host.GetRequestHeaders()
		// x-model carries the raw value; x-provider carries the prefix.
		hv, found := getHeader(headers, "x-model")
		require.True(t, found)
		require.Equal(t, "openai/gpt-4o", hv)
		pv, found := getHeader(headers, "x-provider")
		require.True(t, found)
		require.Equal(t, "openai", pv)
	})
}

View File

@@ -0,0 +1,206 @@
---
title: 请求响应编辑
keywords: [ higress,request,response,edit ]
description: 请求响应编辑插件使用说明
---
## 功能说明
`traffic-editor` 插件可以对请求/响应头进行修改,支持的修改操作类型包括删除、重命名、更新、添加、追加、映射、去重。
## 运行属性
插件执行阶段:`默认阶段`
插件执行优先级:`100`
## 配置字段
| 字段名 | 类型 | 必填 | 说明 |
|--------------------|-----------------------------------------|----|---------------------|
| defaultConfig | object (CommandSet) | 否 | 默认命令集配置,无条件执行的编辑操作 |
| conditionalConfigs | array of object (ConditionalCommandSet) | 否 | 条件命令集配置,按条件执行不同编辑操作 |
### CommandSet 结构
| 字段名 | 类型 | 必填 | 默认值 | 说明 |
|----------------|---------------------------|----|-------|------------|
| disableReroute | bool | 否 | false | 是否禁用自动路由重选 |
| commands | array of object (Command) | 是 | - | 编辑命令列表 |
### ConditionalCommandSet 结构
| 字段名 | 类型 | 必填 | 说明 |
|------------|---------------------------|----|---------------------|
| conditions | array | 是 | 条件列表,见下表 |
| commands | array of object (Command) | 是 | 命令列表,结构同 CommandSet |
#### Command 结构
| 字段名 | 类型 | 必填 | 说明 |
|------|--------|----|-------------------|
| type | string | 是 | 命令类型。其他配置字段由类型决定。 |
##### set 命令
功能为将某个字段设置为指定值。`type` 字段值为 `set`
其它字段如下:
| 字段名 | 类型 | 必填 | 说明 |
|--------|--------------|----|--------|
| target | object (Ref) | 是 | 目标字段信息 |
| value | string | 是 | 要设置的值 |
##### concat 命令
功能为将多个值拼接后赋值给目标字段。`type` 字段值为 `concat`
其它字段如下:
| 字段名 | 类型 | 必填 | 说明 |
|--------|-----------------------|----|--------------------------|
| target | object (Ref) | 是 | 目标字段信息 |
| values | array of (string/Ref) | 是 | 要拼接的值列表,每个值可以是字符串或字段引用(Ref) |
##### copy 命令
功能为将源字段的值复制到目标字段。`type` 字段值为 `copy`
其它字段如下:
| 字段名 | 类型 | 必填 | 说明 |
|--------|--------------|----|--------|
| source | object (Ref) | 是 | 源字段信息 |
| target | object (Ref) | 是 | 目标字段信息 |
##### delete 命令
功能为删除指定字段。`type` 字段值为 `delete`
其它字段如下:
| 字段名 | 类型 | 必填 | 说明 |
|--------|--------------|----|----------|
| target | object (Ref) | 是 | 要删除的字段信息 |
##### rename 命令
功能为将字段重命名。`type` 字段值为 `rename`
其它字段如下:
| 字段名 | 类型 | 必填 | 说明 |
|--------|--------------|----|-------|
| source | object (Ref) | 是 | 原字段信息 |
| target | object (Ref) | 是 | 新字段信息 |
#### Condition 结构
| 字段名 | 类型 | 必填 | 说明 |
|------|--------|----|-------------------|
| type | string | 是 | 条件类型。其他配置字段由类型决定。 |
##### equals 条件
判断某字段值是否等于指定值。`type` 字段值为 `equals`
| 字段名 | 类型 | 必填 | 说明 |
|--------|--------------|----|---------|
| value1 | object (Ref) | 是 | 参与比较的字段 |
| value2 | string | 是 | 目标值 |
##### prefix 条件
判断某字段值是否以指定前缀开头。`type` 字段值为 `prefix`
| 字段名 | 类型 | 必填 | 说明 |
|--------|--------------|----|---------|
| value | object (Ref) | 是 | 参与比较的字段 |
| prefix | string | 是 | 前缀字符串 |
##### suffix 条件
判断某字段值是否以指定后缀结尾。`type` 字段值为 `suffix`
| 字段名 | 类型 | 必填 | 说明 |
|--------|--------------|----|---------|
| value | object (Ref) | 是 | 参与比较的字段 |
| suffix | string | 是 | 后缀字符串 |
##### contains 条件
判断某字段值是否包含指定子串。`type` 字段值为 `contains`
| 字段名 | 类型 | 必填 | 说明 |
|--------|--------------|----|---------|
| value | object (Ref) | 是 | 参与比较的字段 |
| substr | string | 是 | 子串 |
##### regex 条件
判断某字段值是否匹配指定正则表达式。`type` 字段值为 `regex`
| 字段名 | 类型 | 必填 | 说明 |
|---------|--------------|----|---------|
| value | object (Ref) | 是 | 参与比较的字段 |
| pattern | string | 是 | 正则表达式 |
#### Ref 结构
用于标识一个请求或响应中的字段。
| 字段名 | 类型 | 必填 | 说明 |
|------|--------|----|--------------------------------------------------------------|
| type | string | 是 | 字段类型。可选值有:`request_header`、`request_query`、`response_header` |
| name | string | 是 | 字段名称 |
### 示例配置
```json
{
"defaultConfig": {
"disableReroute": false,
"commands": [
{
"type": "set",
"target": {
"type": "request_header",
"name": "x-user"
},
"value": "admin"
},
{
"type": "delete",
"target": {
"type": "request_header",
"name": "x-dummy"
}
}
]
},
"conditionalConfigs": [
{
"conditions": [
{
"type": "equals",
"value1": {
"type": "request_query",
"name": "id"
},
"value2": "1"
}
],
"commands": [
{
"type": "set",
"target": {
"type": "response_header",
"name": "x-id"
},
"value": "1"
}
]
}
]
}
```

View File

@@ -0,0 +1,212 @@
---
title: Request/Response Editor
keywords: [higress,request,response,edit]
description: Usage guide for the request/response editor plugin
---
## Features
The `traffic-editor` plugin allows you to modify request/response headers. Supported operations include delete, rename, update, add, append, map, and deduplicate.
## Runtime Properties
Plugin execution phase: `UNSPECIFIED`
Plugin execution priority: `100`
## Configuration Fields
| Field Name | Type | Required | Description |
|--------------------|-------------------------------------------|----------|-----------------------------------|
| defaultConfig | object (CommandSet) | No | Default command set, executed unconditionally |
| conditionalConfigs | array of object (ConditionalCommandSet) | No | Conditional command sets, executed based on conditions |
### CommandSet Structure
| Field Name | Type | Required | Default | Description |
|----------------|----------------------------|----------|---------|----------------------------|
| disableReroute | bool | No | false | Whether to disable automatic route selection |
| commands | array of object (Command) | Yes | - | List of edit commands |
### ConditionalCommandSet Structure
| Field Name | Type | Required | Description |
|-------------|----------------------------|----------|-----------------------------------|
| conditions | array | Yes | List of conditions, see below |
| commands | array of object (Command) | Yes | List of commands, same as CommandSet |
#### Command Structure
| Field Name | Type | Required | Description |
|------------|--------|----------|------------------------------|
| type | string | Yes | Command type, other fields depend on type |
##### set Command
Sets a field to a specified value. `type` field value is `set`.
Other fields:
| Field Name | Type | Required | Description |
|------------|---------------|----------|------------------|
| target | object (Ref) | Yes | Target field info|
| value | string | Yes | Value to set |
##### concat Command
Concatenates multiple values and assigns to the target field. `type` field value is `concat`.
Other fields:
| Field Name | Type | Required | Description |
|------------|-----------------------|----------|----------------------------------------------|
| target | object (Ref) | Yes | Target field info |
| values | array of (string/Ref) | Yes | Values to concatenate, can be string or Ref |
##### copy Command
Copies the value from the source field to the target field. `type` field value is `copy`.
Other fields:
| Field Name | Type | Required | Description |
|------------|---------------|----------|------------------|
| source | object (Ref) | Yes | Source field info|
| target | object (Ref) | Yes | Target field info|
##### delete Command
Deletes the specified field. `type` field value is `delete`.
Other fields:
| Field Name | Type | Required | Description |
|------------|---------------|----------|------------------|
| target | object (Ref) | Yes | Field to delete |
##### rename Command
Renames a field. `type` field value is `rename`.
Other fields:
| Field Name | Type | Required | Description |
|------------|---------------|----------|------------------|
| source | object (Ref) | Yes | Original field info|
| target | object (Ref) | Yes | New field info |
#### Condition Structure
| Field Name | Type | Required | Description |
|------------|--------|----------|------------------------------|
| type | string | Yes | Condition type, other fields depend on type |
##### equals Condition
Checks if a field value equals the specified value. `type` field value is `equals`.
| Field Name | Type | Required | Description |
|------------|---------------|----------|------------------|
| value1 | object (Ref) | Yes | Field to compare |
| value2 | string | Yes | Target value |
##### prefix Condition
Checks if a field value starts with the specified prefix. `type` field value is `prefix`.
| Field Name | Type | Required | Description |
|------------|---------------|----------|------------------|
| value | object (Ref) | Yes | Field to compare |
| prefix | string | Yes | Prefix string |
##### suffix Condition
Checks if a field value ends with the specified suffix. `type` field value is `suffix`.
| Field Name | Type | Required | Description |
|------------|---------------|----------|------------------|
| value | object (Ref) | Yes | Field to compare |
| suffix | string | Yes | Suffix string |
##### contains Condition
Checks if a field value contains the specified substring. `type` field value is `contains`.
| Field Name | Type | Required | Description |
|------------|---------------|----------|------------------|
| value | object (Ref) | Yes | Field to compare |
| substr | string | Yes | Substring |
##### regex Condition
Checks if a field value matches the specified regular expression. `type` field value is `regex`.
| Field Name | Type | Required | Description |
|------------|---------------|----------|------------------|
| value | object (Ref) | Yes | Field to compare |
| pattern | string | Yes | Regular expression|
#### Ref Structure
Used to identify a field in the request or response.
| Field Name | Type | Required | Description |
|------------|--------|----------|------------------------------------------------------------------|
| type | string | Yes | Field type: `request_header`, `request_query`, `response_header` |
| name | string | Yes | Field name |
### Example Configuration
```json
{
"defaultConfig": {
"disableReroute": false,
"commands": [
{ "type": "set", "target": { "type": "request_header", "name": "x-user" }, "value": "admin" },
{ "type": "delete", "target": { "type": "request_header", "name": "x-remove" } }
]
},
"conditionalConfigs": [
{
"conditions": [
{ "type": "equals", "value1": { "type": "request_query", "name": "role" }, "value2": "admin" }
],
"commands": [
{ "type": "set", "target": { "type": "response_header", "name": "x-status" }, "value": "is-admin" }
]
},
{
"conditions": [
{ "type": "prefix", "value": { "type": "request_header", "name": "x-path" }, "prefix": "/api/" }
],
"commands": [
{ "type": "rename", "source": { "type": "request_header", "name": "x-old" }, "target": { "type": "request_header", "name": "x-new" } }
]
},
{
"conditions": [
{ "type": "suffix", "value": { "type": "request_header", "name": "x-path" }, "suffix": ".json" }
],
"commands": [
{ "type": "copy", "source": { "type": "request_query", "name": "id" }, "target": { "type": "response_header", "name": "x-id" } }
]
},
{
"conditions": [
{ "type": "contains", "value": { "type": "request_header", "name": "x-info" }, "substr": "test" }
],
"commands": [
{ "type": "concat", "target": { "type": "response_header", "name": "x-token" }, "values": ["prefix-", { "type": "request_query", "name": "token" }] }
]
},
{
"conditions": [
{ "type": "regex", "value": { "type": "request_query", "name": "email" }, "pattern": "^.+@example\\.com$" }
],
"commands": [
{ "type": "delete", "target": { "type": "response_header", "name": "x-temp" } }
]
}
]
}
```

View File

@@ -0,0 +1 @@
1.0.0-alpha

View File

@@ -0,0 +1,37 @@
package main
import (
"github.com/tidwall/gjson"
"github.com/alibaba/higress/plugins/wasm-go/extensions/traffic-editor/pkg"
)
// PluginConfig is the top-level traffic-editor configuration: an optional
// unconditional command set plus optional condition-guarded command sets.
type PluginConfig struct {
	DefaultConfig      *pkg.CommandSet              `json:"defaultConfig,omitempty"` // commands executed unconditionally
	ConditionalConfigs []*pkg.ConditionalCommandSet `json:"conditionalConfigs,omitempty"` // commands executed when their conditions match
}
// FromJson populates the plugin configuration from the given JSON document.
// Previously parsed state is cleared before each section is re-parsed, and the
// first parse error aborts the whole operation.
func (c *PluginConfig) FromJson(json gjson.Result) error {
	c.DefaultConfig = nil
	if defaultJson := json.Get("defaultConfig"); defaultJson.Exists() && defaultJson.IsObject() {
		c.DefaultConfig = &pkg.CommandSet{}
		if err := c.DefaultConfig.FromJson(defaultJson); err != nil {
			return err
		}
	}
	c.ConditionalConfigs = nil
	if listJson := json.Get("conditionalConfigs"); listJson.Exists() && listJson.IsArray() {
		for _, itemJson := range listJson.Array() {
			conditional := &pkg.ConditionalCommandSet{}
			if err := conditional.FromJson(itemJson); err != nil {
				return err
			}
			c.ConditionalConfigs = append(c.ConditionalConfigs, conditional)
		}
	}
	return nil
}

View File

@@ -0,0 +1,26 @@
version: '3.7'
services:
envoy:
image: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/gateway:v2.1.6
entrypoint: /usr/local/bin/envoy
    # NOTE: wasm logging is set to debug level here; for production deployments use the default info level.
command: -c /etc/envoy/envoy.yaml --component-log-level wasm:debug
#depends_on:
# - httpbin
networks:
- wasmtest
ports:
- "10000:10000"
volumes:
- ./envoy.yaml:/etc/envoy/envoy.yaml
- ./plugin.wasm:/etc/envoy/main.wasm
httpbin:
image: kong/httpbin:latest
networks:
- wasmtest
ports:
- "12345:80"
networks:
wasmtest: {}

View File

@@ -0,0 +1,24 @@
module github.com/alibaba/higress/plugins/wasm-go/extensions/traffic-editor
go 1.24.1
toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/higress-group/wasm-go v1.0.2
github.com/stretchr/testify v1.11.1
github.com/tidwall/gjson v1.18.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -0,0 +1,31 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2 h1:8fQqR+wHts8tP+v7GYxmsCNyW5nAjn9wPYV0/+Seqzg=
github.com/higress-group/wasm-go v1.0.2/go.mod h1:882/J8ccU4i+LeyFKmeicbHWAYLj8y7YZr60zk0OOCI=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,22 @@
package main
import "strings"
// headerSlice2Map converts a list of (name, value) header pairs into a map
// keyed by the lower-cased header name; repeated names accumulate their
// values in insertion order.
func headerSlice2Map(headerSlice [][2]string) map[string][]string {
	result := make(map[string][]string)
	for _, pair := range headerSlice {
		key := strings.ToLower(pair[0])
		result[key] = append(result[key], pair[1])
	}
	return result
}
// headerMap2Slice flattens a header map back into (name, value) pairs. Values
// of the same name keep their slice order; the order of distinct names follows
// Go's map iteration and is therefore unspecified.
func headerMap2Slice(headerMap map[string][]string) [][2]string {
	result := make([][2]string, 0, len(headerMap))
	for name, values := range headerMap {
		for _, value := range values {
			result = append(result, [2]string{name, value})
		}
	}
	return result
}

View File

@@ -0,0 +1,177 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"fmt"
"github.com/tidwall/gjson"
"github.com/alibaba/higress/plugins/wasm-go/extensions/traffic-editor/pkg"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
const (
ctxKeyEditorContext = "editorContext"
)
// main is intentionally empty: the plugin lifecycle is driven by the wasm
// host through the handlers registered in init().
func main() {}

// init registers the plugin with the wasm-go wrapper: its name, the config
// parser, and the request/response header processing callbacks.
func init() {
	wrapper.SetCtx(
		"traffic-editor",
		wrapper.ParseConfig(parseConfig),
		wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
		wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
	)
}
// parseConfig feeds the raw JSON plugin configuration into config, wrapping
// any parse failure with context.
func parseConfig(json gjson.Result, config *PluginConfig) (err error) {
	parseErr := config.FromJson(json)
	if parseErr == nil {
		return nil
	}
	return fmt.Errorf("failed to parse plugin config: %v", parseErr)
}
// onHttpRequestHeaders snapshots the request headers into a fresh editor
// context, selects the effective command set for this request, and runs the
// request-header stage of every command.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig) types.Action {
	log.Debugf("onHttpRequestHeaders called with config")
	editorContext := pkg.NewEditorContext()
	if headers, err := proxywasm.GetHttpRequestHeaders(); err == nil {
		editorContext.SetRequestHeaders(headerSlice2Map(headers))
	} else {
		log.Errorf("failed to get request headers: %v", err)
	}
	// Stored before any early return so onHttpResponseHeaders can retrieve the
	// same context (and its effective command set) later in the stream.
	saveEditorContext(ctx, editorContext)
	effectiveCommandSet := findEffectiveCommandSet(editorContext, &config)
	if effectiveCommandSet == nil {
		log.Debugf("no effective command set found for request %s", ctx.Path())
		return types.ActionContinue
	}
	if len(effectiveCommandSet.Commands) == 0 {
		log.Debugf("the effective command set found for request %s is empty", ctx.Path())
		return types.ActionContinue
	}
	log.Debugf("an effective command set found for request %s with %d commands", ctx.Path(), len(effectiveCommandSet.Commands))
	editorContext.SetEffectiveCommandSet(effectiveCommandSet)
	// Executors are stateful per-request objects, created fresh from the
	// (shared) parsed command set.
	editorContext.SetCommandExecutors(effectiveCommandSet.CreatExecutors())
	// Make sure the editor context is clean before executing any command.
	editorContext.ResetDirtyFlags()
	if effectiveCommandSet.DisableReroute {
		ctx.DisableReroute()
	}
	executeCommands(editorContext, pkg.StageRequestHeaders)
	if err := saveRequestHeaderChanges(editorContext); err != nil {
		log.Errorf("failed to save request header changes: %v", err)
	}
	// Make sure the editor context is clean before continue.
	editorContext.ResetDirtyFlags()
	return types.ActionContinue
}
// onHttpResponseHeaders runs the response-header stage of the command set
// selected during onHttpRequestHeaders: it snapshots the response headers into
// the stored editor context, executes the pending commands, and writes back
// any modified headers.
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig) types.Action {
	log.Debugf("onHttpResponseHeaders called with config")
	editorContext := loadEditorContext(ctx)
	// Defensive guard: if the request-header phase never stored a context for
	// this stream, loadEditorContext returns the zero value, and invoking
	// GetEffectiveCommandSet on a nil interface would panic.
	if editorContext == nil {
		log.Debugf("no editor context found for request %s", ctx.Path())
		return types.ActionContinue
	}
	if editorContext.GetEffectiveCommandSet() == nil {
		log.Debugf("no effective command set found for request %s", ctx.Path())
		return types.ActionContinue
	}
	if headers, err := proxywasm.GetHttpResponseHeaders(); err == nil {
		editorContext.SetResponseHeaders(headerSlice2Map(headers))
	} else {
		log.Errorf("failed to get response headers: %v", err)
	}
	// Make sure the editor context is clean before executing any command.
	editorContext.ResetDirtyFlags()
	executeCommands(editorContext, pkg.StageResponseHeaders)
	if err := saveResponseHeaderChanges(editorContext); err != nil {
		log.Errorf("failed to save response header changes: %v", err)
	}
	// Make sure the editor context is clean before continue.
	editorContext.ResetDirtyFlags()
	return types.ActionContinue
}
// findEffectiveCommandSet returns the command set to apply for this request:
// the first conditional config whose conditions match the editor context, or
// the default config (possibly nil) when none matches.
func findEffectiveCommandSet(editorContext pkg.EditorContext, config *PluginConfig) *pkg.CommandSet {
	if config == nil {
		return nil
	}
	for i, conditionalConfig := range config.ConditionalConfigs {
		log.Debugf("Evaluating conditional config %d: %+v", i, conditionalConfig)
		if !conditionalConfig.Matches(editorContext) {
			continue
		}
		log.Debugf("Use the conditional command set %d", i)
		return &conditionalConfig.CommandSet
	}
	log.Debugf("Use the default command set")
	return config.DefaultConfig
}
// executeCommands runs every command executor attached to the editor context
// for the given processing stage. A failing executor is logged and does not
// stop the remaining ones.
func executeCommands(editorContext pkg.EditorContext, stage pkg.Stage) {
	executors := editorContext.GetCommandExecutors()
	for i := range executors {
		err := executors[i].Run(editorContext, stage)
		if err == nil {
			continue
		}
		log.Errorf("failed to execute a %s command in stage %s: %v", executors[i].GetCommand().GetType(), pkg.Stage2String[stage], err)
	}
}
// saveRequestHeaderChanges writes the (possibly edited) request headers back
// to the host, but only when the editor context reports them as dirty.
func saveRequestHeaderChanges(editorContext pkg.EditorContext) error {
	if editorContext.IsRequestHeadersDirty() {
		log.Debugf("saving request header changes: %v", editorContext.GetRequestHeaders())
		return proxywasm.ReplaceHttpRequestHeaders(headerMap2Slice(editorContext.GetRequestHeaders()))
	}
	log.Debugf("no request header change to save")
	return nil
}
// saveResponseHeaderChanges writes the (possibly edited) response headers back
// to the host, but only when the editor context reports them as dirty.
func saveResponseHeaderChanges(editorContext pkg.EditorContext) error {
	if editorContext.IsResponseHeadersDirty() {
		log.Debugf("saving response header changes: %v", editorContext.GetResponseHeaders())
		return proxywasm.ReplaceHttpResponseHeaders(headerMap2Slice(editorContext.GetResponseHeaders()))
	}
	log.Debugf("no response header change to save")
	return nil
}
// loadEditorContext retrieves the editor context previously stored on the
// HTTP context; when nothing was stored (or the stored value has another
// type), the failed type assertion yields the zero value.
func loadEditorContext(ctx wrapper.HttpContext) pkg.EditorContext {
	editorContext, _ := ctx.GetContext(ctxKeyEditorContext).(pkg.EditorContext)
	return editorContext
}

// saveEditorContext stashes the editor context on the HTTP context so later
// phases of the same stream can reuse it.
func saveEditorContext(ctx wrapper.HttpContext, editorContext pkg.EditorContext) {
	ctx.SetContext(ctxKeyEditorContext, editorContext)
}

View File

@@ -0,0 +1,306 @@
package main
import (
"strings"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// TestSample verifies that a defaultConfig-only configuration applies its
// "set" command to the request headers while leaving existing headers intact.
func TestSample(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		t.Run("default config only", func(t *testing.T) {
			host, status := test.NewTestHost([]byte(`
{
  "defaultConfig": {
    "commands": [
      {
        "type": "set",
        "target": {
          "type": "request_header",
          "name": "x-test"
        },
        "value": "123456"
      }
    ]
  }
}
`))
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/get"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.ActionContinue, action)
			// The plugin should add x-test=123456 on top of the original headers.
			expectedNewHeaders := [][2]string{
				{":authority", "example.com"},
				{":path", "/get"},
				{":method", "POST"},
				{"x-test", "123456"},
				{"Content-Type", "application/json"},
			}
			newHeaders := host.GetRequestHeaders()
			require.True(t, compareHeaders(expectedNewHeaders, newHeaders), "expected headers: %v, got: %v", expectedNewHeaders, newHeaders)
		})
	})
}
// TestSetMultipleRequestHeaders verifies that several "set" commands in one
// default command set are all applied and pre-existing headers survive.
func TestSetMultipleRequestHeaders(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		host, status := test.NewTestHost([]byte(`{
  "defaultConfig": {
    "commands": [
      {"type": "set", "target": {"type": "request_header", "name": "x-a"}, "value": "aaa"},
      {"type": "set", "target": {"type": "request_header", "name": "x-b"}, "value": "bbb"}
    ]
  }
}`))
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)
		originalHeaders := [][2]string{
			{":authority", "example.com"},
			{":path", "/get"},
			{":method", "POST"},
			{"Content-Type", "application/json"},
			{"x-c", "ccc"},
		}
		action := host.CallOnHttpRequestHeaders(originalHeaders)
		require.Equal(t, types.ActionContinue, action)
		// Both x-a and x-b are added; the unrelated x-c stays untouched.
		expectedHeaders := [][2]string{
			{":authority", "example.com"},
			{":path", "/get"},
			{":method", "POST"},
			{"Content-Type", "application/json"},
			{"x-a", "aaa"},
			{"x-b", "bbb"},
			{"x-c", "ccc"},
		}
		newHeaders := host.GetRequestHeaders()
		require.True(t, compareHeaders(expectedHeaders, newHeaders))
	})
}
// TestConditionalConfigMatch verifies that when a conditional config's
// "equals" condition matches a request header, its commands are applied
// INSTEAD of the default command set (x-special is set, x-def is not).
func TestConditionalConfigMatch(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		host, status := test.NewTestHost([]byte(`{
  "defaultConfig": {
    "commands": [
      {"type": "set", "target": {"type": "request_header", "name": "x-def"}, "value": "default"}
    ]
  },
  "conditionalConfigs": [
    {
      "conditions": [
        {"type": "equals", "value1": {"type": "request_header", "name": "x-cond"}, "value2": "match"}
      ],
      "commands": [
        {"type": "set", "target": {"type": "request_header", "name": "x-special"}, "value": "special"}
      ]
    }
  ]
}`))
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)
		originalHeaders := [][2]string{
			{":authority", "example.com"},
			{":path", "/data"},
			{":method", "POST"},
			{"Content-Type", "application/json"},
			{"x-cond", "match"},
		}
		action := host.CallOnHttpRequestHeaders(originalHeaders)
		require.Equal(t, types.ActionContinue, action)
		expectedHeaders := [][2]string{
			{":authority", "example.com"},
			{":path", "/data"},
			{":method", "POST"},
			{"Content-Type", "application/json"},
			{"x-cond", "match"},
			{"x-special", "special"},
		}
		newHeaders := host.GetRequestHeaders()
		require.True(t, compareHeaders(expectedHeaders, newHeaders))
	})
}

// TestConditionalConfigNoMatch verifies the fallback: when no conditional
// config matches, the default command set runs (x-def is set, x-special is not).
func TestConditionalConfigNoMatch(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		host, status := test.NewTestHost([]byte(`{
  "defaultConfig": {
    "commands": [
      {"type": "set", "target": {"type": "request_header", "name": "x-def"}, "value": "default"}
    ]
  },
  "conditionalConfigs": [
    {
      "conditions": [
        {"type": "equals", "value1": {"type": "request_header", "name": "x-cond"}, "value2": "match"}
      ],
      "commands": [
        {"type": "set", "target": {"type": "request_header", "name": "x-special"}, "value": "special"}
      ]
    }
  ]
}`))
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)
		originalHeaders := [][2]string{
			{":authority", "example.com"},
			{":path", "/get"},
			{":method", "POST"},
			{"Content-Type", "application/json"},
			{"x-cond", "notmatch"},
		}
		action := host.CallOnHttpRequestHeaders(originalHeaders)
		require.Equal(t, types.ActionContinue, action)
		expectedHeaders := [][2]string{
			{":authority", "example.com"},
			{":path", "/get"},
			{":method", "POST"},
			{"Content-Type", "application/json"},
			{"x-cond", "notmatch"},
			{"x-def", "default"},
		}
		newHeaders := host.GetRequestHeaders()
		require.True(t, compareHeaders(expectedHeaders, newHeaders))
	})
}
// TestSetResponseHeader verifies a "set" command targeting a response_header
// is applied during the response phase.
// NOTE(review): only CallOnHttpResponseHeaders is driven here — presumably the
// test host runs the request phase implicitly; confirm against the harness.
func TestSetResponseHeader(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		host, status := test.NewTestHost([]byte(`{
  "defaultConfig": {
    "commands": [
      {"type": "set", "target": {"type": "response_header", "name": "x-res"}, "value": "respval"}
    ]
  }
}`))
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)
		action := host.CallOnHttpResponseHeaders([][2]string{{"x-origin", "originval"}})
		require.Equal(t, types.ActionContinue, action)
		newHeaders := host.GetResponseHeaders()
		require.True(t, compareHeaders([][2]string{{"x-origin", "originval"}, {"x-res", "respval"}}, newHeaders))
	})
}

// TestPathQueryParseAndHeaderChange verifies that setting a request_query
// value rewrites the query string inside the :path pseudo-header.
func TestPathQueryParseAndHeaderChange(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		host, status := test.NewTestHost([]byte(`{
  "defaultConfig": {
    "commands": [
      {"type": "set", "target": {"type": "request_query", "name": "foo"}, "value": "bar"}
    ]
  }
}`))
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)
		action := host.CallOnHttpRequestHeaders([][2]string{
			{":authority", "example.com"},
			{":method", "POST"},
			{"Content-Type", "application/json"},
			{":path", "/get?foo=old&baz=1"},
		})
		require.Equal(t, types.ActionContinue, action)
		newHeaders := host.GetRequestHeaders()
		// Only assert the rewritten query parameter; the rest of :path is not pinned.
		found := false
		for _, h := range newHeaders {
			if h[0] == ":path" && strings.Contains(h[1], "foo=bar") {
				found = true
			}
		}
		require.True(t, found, "path header should be updated with foo=bar")
	})
}
// TestConditionSetMultiStage verifies a cross-stage command: the condition is
// evaluated against a request header, but the "set" writes a response header,
// so the effect only shows up after the response phase runs.
func TestConditionSetMultiStage(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		host, status := test.NewTestHost([]byte(`{
  "conditionalConfigs": [
    {
      "conditions": [
        {"type": "equals", "value1": {"type": "request_header", "name": "x-a"}, "value2": "aaa"}
      ],
      "commands": [
        {"type": "set", "target": {"type": "response_header", "name": "x-b"}, "value": "bbb"}
      ]
    }
  ]
}`))
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)
		actionReq := host.CallOnHttpRequestHeaders([][2]string{
			{":authority", "example.com"},
			{":method", "POST"},
			{"Content-Type", "application/json"},
			{":path", "/get?foo=old&baz=1"},
			{"x-a", "aaa"},
		})
		require.Equal(t, types.ActionContinue, actionReq)
		actionResp := host.CallOnHttpResponseHeaders([][2]string{{"content-type", "application/json"}})
		require.Equal(t, types.ActionContinue, actionResp)
		newHeaders := host.GetResponseHeaders()
		require.True(t, compareHeaders([][2]string{{"x-b", "bbb"}, {"content-type", "application/json"}}, newHeaders))
	})
}

// TestConditionSetMultiStage2 verifies a cross-stage "copy": the source value
// is read from a request header during the request phase and written to a
// response header during the response phase.
func TestConditionSetMultiStage2(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		host, status := test.NewTestHost([]byte(`{
  "conditionalConfigs": [
    {
      "conditions": [
        {"type": "equals", "value1": {"type": "request_header", "name": "x-a"}, "value2": "aaa"}
      ],
      "commands": [
        {"type": "copy", "source": {"type": "request_header", "name": "x-b"}, "target": {"type": "response_header", "name": "x-c"}}
      ]
    }
  ]
}`))
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)
		actionReq := host.CallOnHttpRequestHeaders([][2]string{
			{":authority", "example.com"},
			{":method", "POST"},
			{"Content-Type", "application/json"},
			{":path", "/get?foo=old&baz=1"},
			{"x-a", "aaa"},
			{"x-b", "bbb"},
		})
		require.Equal(t, types.ActionContinue, actionReq)
		actionResp := host.CallOnHttpResponseHeaders([][2]string{{"content-type", "application/json"}})
		require.Equal(t, types.ActionContinue, actionResp)
		newHeaders := host.GetResponseHeaders()
		require.True(t, compareHeaders([][2]string{{"x-c", "bbb"}, {"content-type", "application/json"}}, newHeaders))
	})
}
// compareHeaders reports whether two header lists contain exactly the same
// entries, ignoring order and header-name case. Entries are compared as a
// multiset, so duplicate names are NOT collapsed (the previous map-based
// comparison kept only the last value per name and could report two different
// lists with duplicate headers as equal).
func compareHeaders(headers1, headers2 [][2]string) bool {
	if len(headers1) != len(headers2) {
		return false
	}
	// Count each (lower-cased name, value) pair from the first list...
	counts := make(map[[2]string]int, len(headers1))
	for _, h := range headers1 {
		counts[[2]string{strings.ToLower(h[0]), h[1]}]++
	}
	// ...and consume the counts with the second list; any deficit is a mismatch.
	for _, h := range headers2 {
		key := [2]string{strings.ToLower(h[0]), h[1]}
		counts[key]--
		if counts[key] < 0 {
			return false
		}
	}
	// Lengths are equal and no pair went negative, so the multisets match.
	return true
}

View File

@@ -0,0 +1,515 @@
package pkg
import (
"errors"
"fmt"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/tidwall/gjson"
)
// Command type identifiers as they appear in the "type" field of a command's
// JSON definition.
const (
	commandTypeSet    = "set"
	commandTypeConcat = "concat"
	commandTypeCopy   = "copy"
	commandTypeDelete = "delete"
	commandTypeRename = "rename"
)

var (
	// commandFactories maps a command type identifier to its constructor.
	// Keyed by the commandType* constants (instead of duplicated string
	// literals) so the registry cannot drift from the identifiers that
	// Command.GetType() reports.
	commandFactories = map[string]func(gjson.Result) (Command, error){
		commandTypeSet:    newSetCommand,
		commandTypeConcat: newConcatCommand,
		commandTypeCopy:   newCopyCommand,
		commandTypeDelete: newDeleteCommand,
		commandTypeRename: newRenameCommand,
	}
)
// CommandSet is an ordered list of editing commands plus execution options.
type CommandSet struct {
	// DisableReroute mirrors the "disableReroute" JSON flag; callers use it to
	// skip route re-computation after request headers are edited.
	DisableReroute bool `json:"disableReroute"`
	// Commands are executed in declaration order.
	Commands []Command `json:"commands,omitempty"`
	// RelatedStages records every processing stage referenced by any command's
	// refs; derived in FromJson, never serialized.
	RelatedStages map[Stage]bool `json:"-"`
}
// FromJson populates the command set from JSON: it parses every entry under
// "commands", records which processing stages those commands reference, and
// reads the optional "disableReroute" flag (false when absent).
func (s *CommandSet) FromJson(json gjson.Result) error {
	stages := map[Stage]bool{}
	commandsJson := json.Get("commands")
	if commandsJson.Exists() && commandsJson.IsArray() {
		for _, item := range commandsJson.Array() {
			command, err := NewCommand(item)
			if err != nil {
				return fmt.Errorf("failed to create command from json: %v\n %v", err, item)
			}
			s.Commands = append(s.Commands, command)
			for _, ref := range command.GetRefs() {
				stages[ref.GetStage()] = true
			}
		}
	}
	s.RelatedStages = stages
	// gjson returns false for a missing key, matching the documented default.
	s.DisableReroute = json.Get("disableReroute").Bool()
	return nil
}
// CreatExecutors builds a fresh, stateful executor for every command in the
// set, in command order. (The method name's typo is kept: it is part of the
// public interface used by callers.)
func (s *CommandSet) CreatExecutors() []Executor {
	result := make([]Executor, 0, len(s.Commands))
	for _, cmd := range s.Commands {
		result = append(result, cmd.CreateExecutor())
	}
	return result
}
// ConditionalCommandSet is a CommandSet guarded by a ConditionSet: its
// commands apply only when the conditions match the current request.
type ConditionalCommandSet struct {
	ConditionSet
	CommandSet
}

// FromJson populates both the embedded condition set and command set from the
// same JSON object; the first error aborts parsing.
func (s *ConditionalCommandSet) FromJson(json gjson.Result) error {
	if err := s.ConditionSet.FromJson(json); err != nil {
		return err
	}
	if err := s.CommandSet.FromJson(json); err != nil {
		return err
	}
	return nil
}

// Command is a single parsed editing instruction.
type Command interface {
	// GetType returns the command type identifier (one of the commandType* constants).
	GetType() string
	// GetRefs returns every header/query reference the command reads or writes.
	GetRefs() []*Ref
	// CreateExecutor creates a fresh, stateful executor for one request.
	CreateExecutor() Executor
}

// Executor carries the per-request execution state of a single command; Run
// is invoked once per processing stage.
type Executor interface {
	GetCommand() Command
	Run(editorContext EditorContext, stage Stage) error
}
// NewCommand builds a Command from its JSON definition, dispatching on the
// "type" field via the commandFactories registry.
func NewCommand(json gjson.Result) (Command, error) {
	commandType := json.Get("type").String()
	if commandType == "" {
		return nil, errors.New("command type is required")
	}
	factory := commandFactories[commandType]
	if factory == nil {
		return nil, errors.New("unknown command type: " + commandType)
	}
	return factory(json)
}
// baseExecutor holds the state shared by all executors: whether the command
// has already completed and should be skipped at later stages.
type baseExecutor struct {
	finished bool
}

// setCommand

// newSetCommand parses a "set" command: write a fixed, non-empty value to the
// target ref. It fails when "target" or "value" is missing, or the value is empty.
func newSetCommand(json gjson.Result) (Command, error) {
	var targetRef *Ref
	var err error
	if t := json.Get("target"); !t.Exists() {
		return nil, errors.New("setCommand: target field is required")
	} else {
		targetRef, err = NewRef(t)
		if err != nil {
			return nil, fmt.Errorf("setCommand: failed to create ref from target field: %v\n %v", err, t.Raw)
		}
	}
	var value string
	if v := json.Get("value"); !v.Exists() {
		return nil, errors.New("setCommand: value field is required")
	} else {
		value = v.String()
		if value == "" {
			return nil, errors.New("setCommand: value cannot be empty")
		}
	}
	return &setCommand{
		targetRef: targetRef,
		value:     value,
	}, nil
}

// setCommand writes a constant value to a single target ref.
type setCommand struct {
	targetRef *Ref
	value     string
}

// GetType returns the "set" type identifier.
func (c *setCommand) GetType() string {
	return commandTypeSet
}

// GetRefs returns the single target ref.
func (c *setCommand) GetRefs() []*Ref {
	return []*Ref{c.targetRef}
}

// CreateExecutor returns a fresh per-request executor.
func (c *setCommand) CreateExecutor() Executor {
	return &setExecutor{command: c}
}

type setExecutor struct {
	baseExecutor
	command *setCommand
}

func (e *setExecutor) GetCommand() Command {
	return e.command
}

// Run performs the set once the target ref's stage is reached, then marks the
// command as finished so later stages skip it.
func (e *setExecutor) Run(editorContext EditorContext, stage Stage) error {
	if e.finished {
		return nil
	}
	command := e.command
	log.Debugf("setCommand: checking stage %s for target %s", Stage2String[stage], command.targetRef)
	if command.targetRef.GetStage() == stage {
		log.Debugf("setCommand: set %s to %s", command.targetRef, command.value)
		editorContext.SetRefValue(command.targetRef, command.value)
		e.finished = true
	}
	return nil
}
// concatCommand

// newConcatCommand parses a "concat" command: join a list of parts (literal
// strings and/or refs) and write the result to the target ref. Every ref part
// must resolve at or before the target's stage so the value is known by the
// time it is written.
func newConcatCommand(json gjson.Result) (Command, error) {
	var targetRef *Ref
	var err error
	if t := json.Get("target"); !t.Exists() {
		return nil, errors.New("concatCommand: target field is required")
	} else {
		targetRef, err = NewRef(t)
		if err != nil {
			return nil, fmt.Errorf("concatCommand: failed to create ref from target field: %v\n %v", err, t.Raw)
		}
	}
	valuesJson := json.Get("values")
	if !valuesJson.Exists() || !valuesJson.IsArray() {
		return nil, errors.New("concatCommand: values field is required and must be an array")
	}
	values := make([]interface{}, 0, len(valuesJson.Array()))
	for _, item := range valuesJson.Array() {
		var value interface{}
		// JSON objects are refs; anything else is kept as a literal string.
		if item.IsObject() {
			valueRef, err := NewRef(item)
			if err != nil {
				return nil, fmt.Errorf("concatCommand: failed to create ref from values field: %v\n %v", err, item.Raw)
			}
			if valueRef.GetStage() > targetRef.GetStage() {
				return nil, fmt.Errorf("concatCommand: the processing stage of value [%s] cannot be after the stage of target [%s]", Stage2String[valueRef.GetStage()], Stage2String[targetRef.GetStage()])
			}
			value = valueRef
		} else {
			value = item.String()
		}
		values = append(values, value)
	}
	return &concatCommand{
		targetRef: targetRef,
		values:    values,
	}, nil
}

// concatCommand joins literal strings and resolved refs into the target ref.
type concatCommand struct {
	targetRef *Ref
	// values holds string literals and/or *Ref entries, in concatenation order.
	values []interface{}
}

// GetType returns the "concat" type identifier.
func (c *concatCommand) GetType() string {
	return commandTypeConcat
}
// GetRefs returns the target ref plus every ref appearing in the value list
// (plain string parts carry no ref).
func (c *concatCommand) GetRefs() []*Ref {
	refs := []*Ref{c.targetRef}
	// Ranging over a nil or empty slice is a no-op, so no guard is needed.
	for _, value := range c.values {
		if ref, ok := value.(*Ref); ok {
			refs = append(refs, ref)
		}
	}
	return refs
}
// CreateExecutor returns a fresh per-request executor.
func (c *concatCommand) CreateExecutor() Executor {
	return &concatExecutor{command: c}
}

// concatExecutor resolves value parts as their stages are processed and
// writes the concatenation once the target's stage is reached.
type concatExecutor struct {
	baseExecutor
	command *concatCommand
	// values caches each resolved part; index-aligned with command.values.
	values []string
}

func (e *concatExecutor) GetCommand() Command {
	return e.command
}

// Run resolves any not-yet-resolved parts whose ref matches the current stage
// and, once the target's stage is reached, joins the non-empty parts and
// writes the result.
// NOTE(review): a part already cached as non-empty is never re-resolved, and
// refs that resolve to "" stay eligible for re-resolution at later matching
// stages — confirm this "" handling is intentional.
func (e *concatExecutor) Run(editorContext EditorContext, stage Stage) error {
	if e.finished {
		return nil
	}
	command := e.command
	if e.values == nil {
		e.values = make([]string, len(command.values))
	}
	for i, value := range command.values {
		if value == nil || e.values[i] != "" {
			continue
		}
		v := ""
		if s, ok := value.(string); ok {
			v = s
		} else if ref, ok := value.(*Ref); ok && ref.GetStage() == stage {
			v = editorContext.GetRefValue(ref)
		}
		e.values[i] = v
	}
	if command.targetRef.GetStage() == stage {
		result := ""
		// Empty parts (unresolved refs or empty literals) are silently skipped.
		for _, v := range e.values {
			if v == "" {
				continue
			}
			result += v
		}
		log.Debugf("concatCommand: set %s to %s", command.targetRef, result)
		editorContext.SetRefValue(command.targetRef, result)
		e.finished = true
	}
	return nil
}
// copyCommand

// newCopyCommand parses a "copy" command: read the source ref's value and
// write it to the target ref. The source's stage must not come after the
// target's, so the value is available by the time it is written.
func newCopyCommand(json gjson.Result) (Command, error) {
	var sourceRef *Ref
	var targetRef *Ref
	var err error
	if t := json.Get("source"); !t.Exists() {
		return nil, errors.New("copyCommand: source field is required")
	} else {
		sourceRef, err = NewRef(t)
		if err != nil {
			return nil, fmt.Errorf("copyCommand: failed to create ref from source field: %v\n %v", err, t.Raw)
		}
	}
	if t := json.Get("target"); !t.Exists() {
		return nil, errors.New("copyCommand: target field is required")
	} else {
		targetRef, err = NewRef(t)
		if err != nil {
			return nil, fmt.Errorf("copyCommand: failed to create ref from target field: %v\n %v", err, t.Raw)
		}
	}
	if sourceRef.GetStage() > targetRef.GetStage() {
		return nil, fmt.Errorf("copyCommand: the processing stage of source [%s] cannot be after the stage of target [%s]", Stage2String[sourceRef.GetStage()], Stage2String[targetRef.GetStage()])
	}
	return &copyCommand{
		sourceRef: sourceRef,
		targetRef: targetRef,
	}, nil
}

// copyCommand duplicates the value of one ref into another.
type copyCommand struct {
	sourceRef *Ref
	targetRef *Ref
}

// GetType returns the "copy" type identifier.
func (c *copyCommand) GetType() string {
	return commandTypeCopy
}

// GetRefs returns the source and target refs.
func (c *copyCommand) GetRefs() []*Ref {
	return []*Ref{c.sourceRef, c.targetRef}
}

// CreateExecutor returns a fresh per-request executor.
func (c *copyCommand) CreateExecutor() Executor {
	return &copyExecutor{command: c}
}

// copyExecutor holds the value captured at the source's stage until the
// target's stage is reached.
type copyExecutor struct {
	baseExecutor
	command *copyCommand
	// valueToCopy is populated when the source ref's stage is processed.
	valueToCopy string
}

func (e *copyExecutor) GetCommand() Command {
	return e.command
}
// Run reads the source value when the source ref's stage is processed and
// writes it once the target ref's stage is reached (newCopyCommand guarantees
// the source's stage is not after the target's).
//
// Fix: the empty-value "skip" is now decided only right after reading the
// source. Previously, e.valueToCopy == "" finished the command at ANY stage
// before the source's stage had run — so a copy whose source and target are
// both response-stage refs was aborted during the request-header stage and
// never executed.
func (e *copyExecutor) Run(editorContext EditorContext, stage Stage) error {
	if e.finished {
		return nil
	}
	command := e.command
	if command.sourceRef.GetStage() == stage {
		e.valueToCopy = editorContext.GetRefValue(command.sourceRef)
		log.Debugf("copyCommand: valueToCopy=%s", e.valueToCopy)
		if e.valueToCopy == "" {
			// Nothing to copy: finish so later stages skip this command.
			log.Debug("copyCommand: valueToCopy is empty. skip.")
			e.finished = true
			return nil
		}
	}
	if command.targetRef.GetStage() == stage && e.valueToCopy != "" {
		editorContext.SetRefValue(command.targetRef, e.valueToCopy)
		log.Debugf("copyCommand: set %s to %s", e.valueToCopy, command.targetRef)
		e.finished = true
	}
	return nil
}
// deleteCommand

// newDeleteCommand parses a "delete" command: remove the target ref's entry
// at its stage. Only the "target" field is required.
func newDeleteCommand(json gjson.Result) (Command, error) {
	targetJson := json.Get("target")
	if !targetJson.Exists() {
		return nil, errors.New("deleteCommand: target field is required")
	}
	targetRef, err := NewRef(targetJson)
	if err != nil {
		return nil, fmt.Errorf("deleteCommand: failed to create ref from target field: %v\n %v", err, targetJson.Raw)
	}
	return &deleteCommand{targetRef: targetRef}, nil
}
// deleteCommand removes all values of a single target ref.
type deleteCommand struct {
	targetRef *Ref
}

// GetType returns the "delete" type identifier.
func (c *deleteCommand) GetType() string {
	return commandTypeDelete
}

// GetRefs returns the single target ref.
func (c *deleteCommand) GetRefs() []*Ref {
	return []*Ref{c.targetRef}
}

// CreateExecutor returns a fresh per-request executor.
func (c *deleteCommand) CreateExecutor() Executor {
	return &deleteExecutor{command: c}
}

type deleteExecutor struct {
	baseExecutor
	command *deleteCommand
}

func (e *deleteExecutor) GetCommand() Command {
	return e.command
}
// Run deletes the target ref's values once its stage is reached, then marks
// the command finished so later stages skip it.
func (e *deleteExecutor) Run(editorContext EditorContext, stage Stage) error {
	if e.finished {
		return nil
	}
	command := e.command
	log.Debugf("deleteCommand: checking stage %s for target %s", Stage2String[stage], command.targetRef)
	if command.targetRef.GetStage() != stage {
		log.Debugf("deleteCommand: stage %s does not match targetRef stage %s, skipping.", Stage2String[stage], Stage2String[command.targetRef.GetStage()])
		return nil
	}
	log.Debugf("deleteCommand: delete %s", command.targetRef)
	editorContext.DeleteRefValues(command.targetRef)
	e.finished = true
	log.Debugf("deleteCommand: finished deleting %s", command.targetRef)
	return nil
}
// renameCommand

// newRenameCommand parses a "rename" command: move the target ref's values to
// a new name of the same ref type. Both "target" and a non-empty "newName"
// are required.
func newRenameCommand(json gjson.Result) (Command, error) {
	targetJson := json.Get("target")
	if !targetJson.Exists() {
		return nil, errors.New("renameCommand: target field is required")
	}
	targetRef, err := NewRef(targetJson)
	if err != nil {
		return nil, fmt.Errorf("renameCommand: failed to create ref from target field: %v\n %v", err, targetJson.Raw)
	}
	newName := json.Get("newName").String()
	if newName == "" {
		return nil, errors.New("renameCommand: newName field is required")
	}
	return &renameCommand{targetRef: targetRef, newName: newName}, nil
}
// renameCommand moves all values of the target ref to a new name within the
// same ref type.
type renameCommand struct {
	targetRef *Ref
	// newName is the name the values are moved to.
	newName string
}

// GetType returns the "rename" type identifier.
func (c *renameCommand) GetType() string {
	return commandTypeRename
}

// GetRefs returns the single (source) target ref.
func (c *renameCommand) GetRefs() []*Ref {
	return []*Ref{c.targetRef}
}

// CreateExecutor returns a fresh per-request executor.
func (c *renameCommand) CreateExecutor() Executor {
	return &renameExecutor{command: c}
}

type renameExecutor struct {
	baseExecutor
	command *renameCommand
}

func (e *renameExecutor) GetCommand() Command {
	return e.command
}
// Run performs the rename once the target ref's stage is reached: the values
// are copied to the new name and the old entry is deleted. Renaming a ref to
// its current name is a no-op; either way the command is marked finished.
func (e *renameExecutor) Run(editorContext EditorContext, stage Stage) error {
	if e.finished {
		return nil
	}
	command := e.command
	log.Debugf("renameCommand: checking stage %s for target %s", Stage2String[stage], command.targetRef)
	if command.targetRef.GetStage() == stage {
		if command.newName == command.targetRef.Name {
			log.Debugf("renameCommand: skip renaming %s to itself", command.targetRef)
		} else {
			values := editorContext.GetRefValues(command.targetRef)
			log.Debugf("renameCommand: rename %s to %s value=%v", command.targetRef, command.newName, values)
			// Write under the new name first, then drop the old entry.
			editorContext.SetRefValues(&Ref{
				Type: command.targetRef.Type,
				Name: command.newName,
			}, values)
			editorContext.DeleteRefValues(command.targetRef)
			log.Debugf("renameCommand: finished renaming %s to %s", command.targetRef, command.newName)
		}
		e.finished = true
	} else {
		log.Debugf("renameCommand: stage %s does not match targetRef stage %s, skipping.", Stage2String[stage], Stage2String[command.targetRef.GetStage()])
	}
	return nil
}

View File

@@ -0,0 +1,309 @@
package pkg
import (
"testing"
"github.com/tidwall/gjson"
)
// TestNewSetCommand_Success verifies that a well-formed "set" definition
// parses into a command with the right type and a single ref.
func TestNewSetCommand_Success(t *testing.T) {
	jsonStr := `{"type":"set","target":{"type":"request_header","name":"foo"},"value":"bar"}`
	json := gjson.Parse(jsonStr)
	cmd, err := newSetCommand(json)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if cmd.GetType() != "set" {
		t.Errorf("expected type 'set', got %s", cmd.GetType())
	}
	refs := cmd.GetRefs()
	if len(refs) != 1 {
		t.Errorf("expected 1 ref, got %d", len(refs))
	}
}

// TestNewSetCommand_MissingTarget verifies the exact error for a "set"
// definition without a target.
func TestNewSetCommand_MissingTarget(t *testing.T) {
	jsonStr := `{"type":"set","value":"bar"}`
	json := gjson.Parse(jsonStr)
	_, err := newSetCommand(json)
	if err == nil || err.Error() != "setCommand: target field is required" {
		t.Errorf("expected target field error, got %v", err)
	}
}

// TestNewSetCommand_MissingValue verifies the exact error for a "set"
// definition without a value.
func TestNewSetCommand_MissingValue(t *testing.T) {
	jsonStr := `{"type":"set","target":{"type":"request_header","name":"foo"}}`
	json := gjson.Parse(jsonStr)
	_, err := newSetCommand(json)
	if err == nil || err.Error() != "setCommand: value field is required" {
		t.Errorf("expected value field error, got %v", err)
	}
}
func TestNewConcatCommand_Success(t *testing.T) {
jsonStr := `{"type":"concat","target":{"type":"request_header","name":"foo"},"values":["a","b"]}`
json := gjson.Parse(jsonStr)
cmd, err := newConcatCommand(json)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if cmd.GetType() != "concat" {
t.Errorf("expected type 'concat', got %s", cmd.GetType())
}
refs := cmd.GetRefs()
if len(refs) < 1 {
t.Errorf("expected at least 1 ref, got %d", len(refs))
}
}
func TestNewConcatCommand_MissingTarget(t *testing.T) {
jsonStr := `{"type":"concat","values":["a","b"]}`
json := gjson.Parse(jsonStr)
_, err := newConcatCommand(json)
if err == nil || err.Error() != "concatCommand: target field is required" {
t.Errorf("expected target field error, got %v", err)
}
}
func TestNewConcatCommand_MissingValues(t *testing.T) {
jsonStr := `{"type":"concat","target":{"type":"request_header","name":"foo"}}`
json := gjson.Parse(jsonStr)
_, err := newConcatCommand(json)
if err == nil || err.Error() != "concatCommand: values field is required and must be an array" {
t.Errorf("expected values field error, got %v", err)
}
}
func TestNewCopyCommand_Success(t *testing.T) {
jsonStr := `{"type":"copy","source":{"type":"request_header","name":"foo"},"target":{"type":"request_header","name":"bar"}}`
json := gjson.Parse(jsonStr)
cmd, err := newCopyCommand(json)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if cmd.GetType() != "copy" {
t.Errorf("expected type 'copy', got %s", cmd.GetType())
}
refs := cmd.GetRefs()
if len(refs) != 2 {
t.Errorf("expected 2 refs, got %d", len(refs))
}
}
func TestNewCopyCommand_MissingSource(t *testing.T) {
jsonStr := `{"type":"copy","target":{"type":"request_header","name":"bar"}}`
json := gjson.Parse(jsonStr)
_, err := newCopyCommand(json)
if err == nil || err.Error() != "copyCommand: source field is required" {
t.Errorf("expected source field error, got %v", err)
}
}
func TestNewCopyCommand_MissingTarget(t *testing.T) {
jsonStr := `{"type":"copy","source":{"type":"request_header","name":"foo"}}`
json := gjson.Parse(jsonStr)
_, err := newCopyCommand(json)
if err == nil || err.Error() != "copyCommand: target field is required" {
t.Errorf("expected target field error, got %v", err)
}
}
func TestNewCopyCommand_SourceStageAfterTarget(t *testing.T) {
jsonStr := `{"type":"copy","source":{"type":"response_header","name":"foo"},"target":{"type":"request_header","name":"bar"}}`
json := gjson.Parse(jsonStr)
_, err := newCopyCommand(json)
if err == nil || err.Error() != "copyCommand: the processing stage of source [response_headers] cannot be after the stage of target [request_headers]" {
t.Errorf("expected source stage field error, got %v", err)
}
}
func TestNewDeleteCommand_Success(t *testing.T) {
jsonStr := `{"type":"delete","target":{"type":"request_header","name":"foo"}}`
json := gjson.Parse(jsonStr)
cmd, err := newDeleteCommand(json)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if cmd.GetType() != "delete" {
t.Errorf("expected type 'delete', got %s", cmd.GetType())
}
refs := cmd.GetRefs()
if len(refs) != 1 {
t.Errorf("expected 1 ref, got %d", len(refs))
}
}
func TestNewDeleteCommand_MissingTarget(t *testing.T) {
jsonStr := `{"type":"delete"}`
json := gjson.Parse(jsonStr)
_, err := newDeleteCommand(json)
if err == nil || err.Error() != "deleteCommand: target field is required" {
t.Errorf("expected target field error, got %v", err)
}
}
func TestNewRenameCommand_Success(t *testing.T) {
jsonStr := `{"type":"rename","target":{"type":"request_header","name":"foo"},"newName":"bar"}`
json := gjson.Parse(jsonStr)
cmd, err := newRenameCommand(json)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if cmd.GetType() != "rename" {
t.Errorf("expected type 'rename', got %s", cmd.GetType())
}
refs := cmd.GetRefs()
if len(refs) != 1 {
t.Errorf("expected 1 ref, got %d", len(refs))
}
}
func TestNewRenameCommand_MissingTarget(t *testing.T) {
jsonStr := `{"type":"rename","newName":"bar"}`
json := gjson.Parse(jsonStr)
_, err := newRenameCommand(json)
if err == nil || err.Error() != "renameCommand: target field is required" {
t.Errorf("expected target field error, got %v", err)
}
}
func TestNewRenameCommand_MissingNewName(t *testing.T) {
jsonStr := `{"type":"rename","target":{"type":"request_header","name":"foo"}}`
json := gjson.Parse(jsonStr)
_, err := newRenameCommand(json)
if err == nil || err.Error() != "renameCommand: newName field is required" {
t.Errorf("expected newName field error, got %v", err)
}
}
// set executor applied at its single matching stage writes the value.
func TestSetExecutor_Run_SingleStage(t *testing.T) {
	ref := &Ref{Type: RefTypeRequestHeader, Name: "foo"}
	cmd := &setCommand{targetRef: ref, value: "bar"}
	executor := cmd.CreateExecutor()
	ctx := NewEditorContext()
	stage := StageRequestHeaders
	err := executor.Run(ctx, stage)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if ctx.GetRefValue(ref) != "bar" {
		t.Errorf("expected value 'bar', got %s", ctx.GetRefValue(ref))
	}
}
// concat executor joins literal values and ref values within one stage.
func TestConcatExecutor_Run_SingleStage(t *testing.T) {
	ref := &Ref{Type: RefTypeRequestHeader, Name: "foo"}
	srcRef := &Ref{Type: RefTypeRequestHeader, Name: "test"}
	cmd := &concatCommand{targetRef: ref, values: []interface{}{"a", srcRef, "b"}}
	executor := cmd.CreateExecutor()
	ctx := NewEditorContext()
	ctx.SetRefValue(srcRef, "-")
	stage := StageRequestHeaders
	err := executor.Run(ctx, stage)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if ctx.GetRefValue(ref) != "a-b" {
		t.Errorf("expected value 'a-b', got %s", ctx.GetRefValue(ref))
	}
}
// concat executor collects a request-stage source, then writes the
// response-stage target on the later stage run.
func TestConcatExecutor_Run_MultiStages(t *testing.T) {
	ref := &Ref{Type: RefTypeResponseHeader, Name: "foo"}
	srcRef := &Ref{Type: RefTypeRequestHeader, Name: "test"}
	cmd := &concatCommand{targetRef: ref, values: []interface{}{"a", srcRef, "b"}}
	executor := cmd.CreateExecutor()
	ctx := NewEditorContext()
	ctx.SetRefValue(srcRef, "-")
	err := executor.Run(ctx, StageRequestHeaders)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	err = executor.Run(ctx, StageResponseHeaders)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if ctx.GetRefValue(ref) != "a-b" {
		t.Errorf("expected value 'a-b', got %s", ctx.GetRefValue(ref))
	}
}
// copy executor duplicates a value when source and target share a stage.
func TestCopyExecutor_Run_SingleStage(t *testing.T) {
	source := &Ref{Type: RefTypeRequestHeader, Name: "foo"}
	target := &Ref{Type: RefTypeRequestHeader, Name: "bar"}
	ctx := NewEditorContext()
	ctx.SetRefValue(source, "baz")
	cmd := &copyCommand{sourceRef: source, targetRef: target}
	executor := cmd.CreateExecutor()
	stage := StageRequestHeaders
	err := executor.Run(ctx, stage)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if ctx.GetRefValue(target) != "baz" {
		t.Errorf("expected value 'baz' for target, got %s", ctx.GetRefValue(target))
	}
}
// copy executor reads the request-stage source first, then writes the
// response-stage target when that stage arrives.
func TestCopyExecutor_Run_MultiStages(t *testing.T) {
	source := &Ref{Type: RefTypeRequestHeader, Name: "foo"}
	target := &Ref{Type: RefTypeResponseHeader, Name: "bar"}
	ctx := NewEditorContext()
	ctx.SetRefValue(source, "baz")
	cmd := &copyCommand{sourceRef: source, targetRef: target}
	executor := cmd.CreateExecutor()
	err := executor.Run(ctx, StageRequestHeaders)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	err = executor.Run(ctx, StageResponseHeaders)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if ctx.GetRefValue(target) != "baz" {
		t.Errorf("expected value 'baz' for target, got %s", ctx.GetRefValue(target))
	}
}
// delete executor removes an existing value at the matching stage.
func TestDeleteExecutor_Run(t *testing.T) {
	ref := &Ref{Type: RefTypeRequestHeader, Name: "foo"}
	ctx := NewEditorContext()
	ctx.SetRefValue(ref, "bar")
	cmd := &deleteCommand{targetRef: ref}
	executor := cmd.CreateExecutor()
	stage := StageRequestHeaders
	err := executor.Run(ctx, stage)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if ctx.GetRefValue(ref) != "" {
		t.Errorf("expected value to be deleted, got %s", ctx.GetRefValue(ref))
	}
}
// rename executor moves the value to the new name and deletes the old ref.
func TestRenameExecutor_Run(t *testing.T) {
	ref := &Ref{Type: RefTypeRequestHeader, Name: "foo"}
	ctx := NewEditorContext()
	ctx.SetRefValue(ref, "bar")
	cmd := &renameCommand{targetRef: ref, newName: "baz"}
	executor := cmd.CreateExecutor()
	stage := StageRequestHeaders
	err := executor.Run(ctx, stage)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	newRef := &Ref{Type: ref.Type, Name: "baz"}
	if ctx.GetRefValue(newRef) != "bar" {
		t.Errorf("expected value 'bar' for new name, got %s", ctx.GetRefValue(newRef))
	}
	if ctx.GetRefValue(ref) != "" {
		t.Errorf("expected old name to be deleted, got %s", ctx.GetRefValue(ref))
	}
}

View File

@@ -0,0 +1,325 @@
package pkg
import (
"errors"
"fmt"
"regexp"
"strings"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/tidwall/gjson"
)
// Supported condition type identifiers, matching the "type" field of a
// condition's JSON configuration.
const (
	conditionTypeEquals   = "equals"
	conditionTypePrefix   = "prefix"
	conditionTypeSuffix   = "suffix"
	conditionTypeContains = "contains"
	conditionTypeRegex    = "regex"
)
var (
	// conditionFactories maps each condition type to its constructor.
	// CreateCondition dispatches through this table.
	conditionFactories = map[string]func(gjson.Result) (Condition, error){
		conditionTypeEquals:   newEqualsCondition,
		conditionTypePrefix:   newPrefixCondition,
		conditionTypeSuffix:   newSuffixCondition,
		conditionTypeContains: newContainsCondition,
		conditionTypeRegex:    newRegexCondition,
	}
)
// ConditionSet is an AND-combined group of Conditions, together with the
// set of processing stages referenced by the conditions' refs.
type ConditionSet struct {
	Conditions []Condition `json:"conditions,omitempty"`
	// RelatedStages records every stage some condition ref belongs to;
	// rebuilt by FromJson and excluded from JSON marshaling.
	RelatedStages map[Stage]bool `json:"-"`
}
// FromJson (re)populates the set from the "conditions" array of the given
// JSON, replacing any previous conditions and recomputing RelatedStages.
// A missing or non-array "conditions" field yields an empty (always
// matching) set. Returns an error if any single condition fails to parse.
func (s *ConditionSet) FromJson(json gjson.Result) error {
	relatedStages := map[Stage]bool{}
	s.Conditions = nil
	if conditionsJson := json.Get("conditions"); conditionsJson.Exists() && conditionsJson.IsArray() {
		for _, item := range conditionsJson.Array() {
			if condition, err := CreateCondition(item); err != nil {
				return fmt.Errorf("failed to create condition from json: %v\n %v", err, item)
			} else {
				s.Conditions = append(s.Conditions, condition)
				for _, ref := range condition.GetRefs() {
					relatedStages[ref.GetStage()] = true
				}
			}
		}
	}
	s.RelatedStages = relatedStages
	return nil
}
// Matches reports whether every condition evaluates to true against the
// given context. An empty set always matches.
func (s *ConditionSet) Matches(editorContext EditorContext) bool {
	if len(s.Conditions) == 0 {
		return true
	}
	for _, condition := range s.Conditions {
		if !condition.Evaluate(editorContext) {
			return false
		}
	}
	return true
}
// Condition evaluates a boolean predicate against request data held in an
// EditorContext.
type Condition interface {
	// GetType returns the condition type identifier (see conditionType* consts).
	GetType() string
	// GetRefs returns the refs whose values the condition reads.
	GetRefs() []*Ref
	// Evaluate reports whether the condition holds for the given context.
	Evaluate(ctx EditorContext) bool
}
// CreateCondition builds a Condition from its JSON config, dispatching on
// the "type" field via conditionFactories. Conditions may only reference
// request-stage data; any ref at or after the response-headers stage is
// rejected.
func CreateCondition(json gjson.Result) (Condition, error) {
	t := json.Get("type").String()
	if t == "" {
		return nil, errors.New("condition type is required")
	}
	if constructor, ok := conditionFactories[t]; !ok || constructor == nil {
		return nil, errors.New("unknown condition type: " + t)
	} else if condition, err := constructor(json); err != nil {
		return nil, fmt.Errorf("failed to create condition with type %s: %v", t, err)
	} else {
		for _, ref := range condition.GetRefs() {
			if ref.GetStage() >= StageResponseHeaders {
				return nil, fmt.Errorf("condition only supports request refs")
			}
		}
		return condition, nil
	}
}
// equalsCondition
// newEqualsCondition builds a condition that matches when any value of the
// referenced request field equals the literal "value2" string.
func newEqualsCondition(json gjson.Result) (Condition, error) {
	value1 := json.Get("value1")
	if value1.Type != gjson.JSON {
		return nil, errors.New("equalsCondition: value1 field type must be JSON object")
	}
	value1Ref, err := NewRef(value1)
	if err != nil {
		return nil, errors.New("equalsCondition: failed to create value1 ref: " + err.Error())
	}
	value2 := json.Get("value2").String()
	return &equalsCondition{
		value1Ref: value1Ref,
		value2:    value2,
	}, nil
}
type equalsCondition struct {
	value1Ref *Ref // ref whose values are compared
	value2    string // literal to compare against
}
// GetType returns the condition type identifier.
func (c *equalsCondition) GetType() string {
	return conditionTypeEquals
}
// GetRefs returns the refs this condition reads.
func (c *equalsCondition) GetRefs() []*Ref {
	return []*Ref{c.value1Ref}
}
// Evaluate reports whether any value of the referenced field equals
// value2. A missing/empty ref never matches.
func (c *equalsCondition) Evaluate(ctx EditorContext) bool {
	log.Debugf("Evaluating equals condition: value1Ref=%v, value2=%s", c.value1Ref, c.value2)
	ref1Values := ctx.GetRefValues(c.value1Ref)
	if len(ref1Values) == 0 {
		log.Debugf("No values found for ref1: %v", c.value1Ref)
		return false
	}
	for _, value1 := range ref1Values {
		if value1 == c.value2 {
			log.Debugf("Condition matched: %s == %s", value1, c.value2)
			return true
		}
	}
	log.Debugf("No matches found for condition: value1Ref=%v, value2=%s", c.value1Ref, c.value2)
	return false
}
// prefixCondition
// newPrefixCondition builds a condition that matches when any value of the
// referenced request field starts with the configured prefix.
func newPrefixCondition(json gjson.Result) (Condition, error) {
	value := json.Get("value")
	if value.Type != gjson.JSON {
		return nil, errors.New("prefixCondition: value field type must be JSON object")
	}
	valueRef, err := NewRef(value)
	if err != nil {
		return nil, errors.New("prefixCondition: failed to create value ref: " + err.Error())
	}
	prefix := json.Get("prefix").String()
	return &prefixCondition{
		valueRef: valueRef,
		prefix:   prefix,
	}, nil
}
type prefixCondition struct {
	valueRef *Ref // ref whose values are tested
	prefix   string // required leading substring
}
// GetType returns the condition type identifier.
func (c *prefixCondition) GetType() string {
	return conditionTypePrefix
}
// GetRefs returns the refs this condition reads.
func (c *prefixCondition) GetRefs() []*Ref {
	return []*Ref{c.valueRef}
}
// Evaluate reports whether any value of the referenced field starts with
// the configured prefix. A missing/empty ref never matches.
func (c *prefixCondition) Evaluate(ctx EditorContext) bool {
	log.Debugf("Evaluating prefix condition: valueRef=%v, prefix=%s", c.valueRef, c.prefix)
	refValues := ctx.GetRefValues(c.valueRef)
	if len(refValues) == 0 {
		log.Debugf("No values found for ref: %v", c.valueRef)
		return false
	}
	for _, value := range refValues {
		if strings.HasPrefix(value, c.prefix) {
			log.Debugf("Condition matched: %s starts with %s", value, c.prefix)
			return true
		}
	}
	log.Debugf("No matches found for condition: valueRef=%v, prefix=%s", c.valueRef, c.prefix)
	return false
}
// suffixCondition
// newSuffixCondition builds a condition that matches when any value of the
// referenced request field ends with the configured suffix.
func newSuffixCondition(json gjson.Result) (Condition, error) {
	value := json.Get("value")
	if value.Type != gjson.JSON {
		return nil, errors.New("suffixCondition: value field type must be JSON object")
	}
	valueRef, err := NewRef(value)
	if err != nil {
		return nil, errors.New("suffixCondition: failed to create value ref: " + err.Error())
	}
	suffix := json.Get("suffix").String()
	return &suffixCondition{
		valueRef: valueRef,
		suffix:   suffix,
	}, nil
}
type suffixCondition struct {
	valueRef *Ref // ref whose values are tested
	suffix   string // required trailing substring
}
// GetType returns the condition type identifier.
func (c *suffixCondition) GetType() string {
	return conditionTypeSuffix
}
// GetRefs returns the refs this condition reads.
func (c *suffixCondition) GetRefs() []*Ref {
	return []*Ref{c.valueRef}
}
// Evaluate reports whether any value of the referenced field ends with the
// configured suffix. A missing/empty ref never matches.
//
// Fix: the debug messages previously labeled the field "prefix=%s" (a
// copy-paste from prefixCondition); they now correctly say "suffix=%s".
func (c *suffixCondition) Evaluate(ctx EditorContext) bool {
	log.Debugf("Evaluating suffix condition: valueRef=%v, suffix=%s", c.valueRef, c.suffix)
	refValues := ctx.GetRefValues(c.valueRef)
	if len(refValues) == 0 {
		log.Debugf("No values found for ref: %v", c.valueRef)
		return false
	}
	for _, value := range refValues {
		if strings.HasSuffix(value, c.suffix) {
			log.Debugf("Condition matched: %s ends with %s", value, c.suffix)
			return true
		}
	}
	log.Debugf("No matches found for condition: valueRef=%v, suffix=%s", c.valueRef, c.suffix)
	return false
}
// containsCondition
// newContainsCondition builds a condition that matches when any value of
// the referenced request field contains the configured substring ("part").
func newContainsCondition(json gjson.Result) (Condition, error) {
	value := json.Get("value")
	if value.Type != gjson.JSON {
		return nil, errors.New("containsCondition: value field type must be JSON object")
	}
	valueRef, err := NewRef(value)
	if err != nil {
		return nil, errors.New("containsCondition: failed to create value ref: " + err.Error())
	}
	part := json.Get("part").String()
	return &containsCondition{
		valueRef: valueRef,
		part:     part,
	}, nil
}
type containsCondition struct {
	valueRef *Ref // ref whose values are tested
	part     string // required substring
}
// GetType returns the condition type identifier.
func (c *containsCondition) GetType() string {
	return conditionTypeContains
}
// GetRefs returns the refs this condition reads.
func (c *containsCondition) GetRefs() []*Ref {
	return []*Ref{c.valueRef}
}
// Evaluate reports whether any value of the referenced field contains the
// configured substring. A missing/empty ref never matches.
//
// Consistency: debug logging added to match every sibling condition type
// (equals/prefix/suffix/regex), which all trace their evaluation.
func (c *containsCondition) Evaluate(ctx EditorContext) bool {
	log.Debugf("Evaluating contains condition: valueRef=%v, part=%s", c.valueRef, c.part)
	refValues := ctx.GetRefValues(c.valueRef)
	if len(refValues) == 0 {
		log.Debugf("No values found for ref: %v", c.valueRef)
		return false
	}
	for _, value := range refValues {
		if strings.Contains(value, c.part) {
			log.Debugf("Condition matched: %s contains %s", value, c.part)
			return true
		}
	}
	log.Debugf("No matches found for condition: valueRef=%v, part=%s", c.valueRef, c.part)
	return false
}
// regexCondition
// newRegexCondition builds a condition that matches when any value of the
// referenced request field matches the configured regular expression.
// The pattern is compiled once at construction time; an invalid pattern
// is rejected here rather than at evaluation time.
func newRegexCondition(json gjson.Result) (Condition, error) {
	value := json.Get("value")
	if value.Type != gjson.JSON {
		return nil, errors.New("regexCondition: value field type must be JSON object")
	}
	valueRef, err := NewRef(value)
	if err != nil {
		return nil, errors.New("regexCondition: failed to create value ref: " + err.Error())
	}
	patternStr := json.Get("pattern").String()
	pattern, err := regexp.Compile(patternStr)
	if err != nil {
		return nil, errors.New("regexCondition: failed to compile pattern: " + err.Error())
	}
	return &regexCondition{
		valueRef: valueRef,
		pattern:  pattern,
	}, nil
}
type regexCondition struct {
	valueRef *Ref // ref whose values are tested
	pattern  *regexp.Regexp // pre-compiled match pattern
}
// GetType returns the condition type identifier.
func (c *regexCondition) GetType() string {
	return conditionTypeRegex
}
// GetRefs returns the refs this condition reads. (Declared before Evaluate
// for consistency with the other condition types in this file, which all
// follow the GetType/GetRefs/Evaluate order.)
func (c *regexCondition) GetRefs() []*Ref {
	return []*Ref{c.valueRef}
}
// Evaluate reports whether any value of the referenced field matches the
// compiled pattern. A missing/empty ref never matches.
func (c *regexCondition) Evaluate(ctx EditorContext) bool {
	log.Debugf("Evaluating regex condition: valueRef=%v, pattern=%s", c.valueRef, c.pattern.String())
	refValues := ctx.GetRefValues(c.valueRef)
	if len(refValues) == 0 {
		log.Debugf("No values found for ref: %v", c.valueRef)
		return false
	}
	for _, value := range refValues {
		if c.pattern.MatchString(value) {
			log.Debugf("Condition matched: %s matches %s", value, c.pattern.String())
			return true
		}
	}
	log.Debugf("No matches found for condition: valueRef=%v, pattern=%s", c.valueRef, c.pattern.String())
	return false
}

View File

@@ -0,0 +1,217 @@
package pkg
import (
"testing"
"github.com/tidwall/gjson"
)
// --- equalsCondition tests ---
// equals matches when the header value equals value2 exactly.
func TestEqualsCondition_Match(t *testing.T) {
	json := gjson.Parse(`{"type":"equals","value1":{"type":"request_header","name":"x-test"},"value2":"abc"}`)
	cond, err := CreateCondition(json)
	if err != nil {
		t.Fatalf("CreateCondition failed: %v", err)
	}
	ctx := NewEditorContext()
	ctx.SetRequestHeaders(map[string][]string{"x-test": {"abc"}})
	if !cond.Evaluate(ctx) {
		t.Error("equalsCondition should match")
	}
}
// equals rejects a different header value.
func TestEqualsCondition_NoMatch(t *testing.T) {
	json := gjson.Parse(`{"type":"equals","value1":{"type":"request_header","name":"x-test"},"value2":"abc"}`)
	cond, _ := CreateCondition(json)
	ctx := NewEditorContext()
	ctx.SetRequestHeaders(map[string][]string{"x-test": {"def"}})
	if cond.Evaluate(ctx) {
		t.Error("equalsCondition should not match")
	}
}
// --- prefixCondition tests ---
// prefix matches a query value starting with the configured prefix.
func TestPrefixCondition_Match(t *testing.T) {
	json := gjson.Parse(`{"type":"prefix","value":{"type":"request_query","name":"foo"},"prefix":"bar"}`)
	cond, err := CreateCondition(json)
	if err != nil {
		t.Fatalf("CreateCondition failed: %v", err)
	}
	ctx := NewEditorContext()
	ctx.SetRequestQueries(map[string][]string{"foo": {"barbaz"}})
	if !cond.Evaluate(ctx) {
		t.Error("prefixCondition should match")
	}
}
// prefix rejects a value that merely contains (not starts with) the prefix.
func TestPrefixCondition_NoMatch(t *testing.T) {
	json := gjson.Parse(`{"type":"prefix","value":{"type":"request_query","name":"foo"},"prefix":"bar"}`)
	cond, _ := CreateCondition(json)
	ctx := NewEditorContext()
	ctx.SetRequestQueries(map[string][]string{"foo": {"bazbar"}})
	if cond.Evaluate(ctx) {
		t.Error("prefixCondition should not match")
	}
}
// --- suffixCondition tests ---
// suffix matches a header value ending with the configured suffix.
func TestSuffixCondition_Match(t *testing.T) {
	json := gjson.Parse(`{"type":"suffix","value":{"type":"request_header","name":"x-end"},"suffix":"xyz"}`)
	cond, err := CreateCondition(json)
	if err != nil {
		t.Fatalf("CreateCondition failed: %v", err)
	}
	ctx := NewEditorContext()
	ctx.SetRequestHeaders(map[string][]string{"x-end": {"123xyz"}})
	if !cond.Evaluate(ctx) {
		t.Error("suffixCondition should match")
	}
}
// suffix rejects a value that starts (not ends) with the suffix string.
func TestSuffixCondition_NoMatch(t *testing.T) {
	json := gjson.Parse(`{"type":"suffix","value":{"type":"request_header","name":"x-end"},"suffix":"xyz"}`)
	cond, _ := CreateCondition(json)
	ctx := NewEditorContext()
	ctx.SetRequestHeaders(map[string][]string{"x-end": {"xyz123"}})
	if cond.Evaluate(ctx) {
		t.Error("suffixCondition should not match")
	}
}
// --- containsCondition tests ---
// contains matches a query value containing the configured part.
func TestContainsCondition_Match(t *testing.T) {
	json := gjson.Parse(`{"type":"contains","value":{"type":"request_query","name":"foo"},"part":"baz"}`)
	cond, err := CreateCondition(json)
	if err != nil {
		t.Fatalf("CreateCondition failed: %v", err)
	}
	ctx := NewEditorContext()
	ctx.SetRequestQueries(map[string][]string{"foo": {"barbaz"}})
	if !cond.Evaluate(ctx) {
		t.Error("containsCondition should match")
	}
}
// contains rejects a value without the configured part.
func TestContainsCondition_NoMatch(t *testing.T) {
	json := gjson.Parse(`{"type":"contains","value":{"type":"request_query","name":"foo"},"part":"baz"}`)
	cond, _ := CreateCondition(json)
	ctx := NewEditorContext()
	ctx.SetRequestQueries(map[string][]string{"foo": {"bar"}})
	if cond.Evaluate(ctx) {
		t.Error("containsCondition should not match")
	}
}
// --- regexCondition tests ---
// regex matches a header value satisfying the anchored pattern.
func TestRegexCondition_Match(t *testing.T) {
	json := gjson.Parse(`{"type":"regex","value":{"type":"request_header","name":"x-reg"},"pattern":"^abc.*"}`)
	cond, err := CreateCondition(json)
	if err != nil {
		t.Fatalf("CreateCondition failed: %v", err)
	}
	ctx := NewEditorContext()
	ctx.SetRequestHeaders(map[string][]string{"x-reg": {"abcdef"}})
	if !cond.Evaluate(ctx) {
		t.Error("regexCondition should match")
	}
}
// regex rejects a value that only matches unanchored.
func TestRegexCondition_NoMatch(t *testing.T) {
	json := gjson.Parse(`{"type":"regex","value":{"type":"request_header","name":"x-reg"},"pattern":"^abc.*"}`)
	cond, _ := CreateCondition(json)
	ctx := NewEditorContext()
	ctx.SetRequestHeaders(map[string][]string{"x-reg": {"defabc"}})
	if cond.Evaluate(ctx) {
		t.Error("regexCondition should not match")
	}
}
// --- CreateCondition error cases ---
// An unregistered condition type must be rejected.
func TestCreateCondition_UnknownType(t *testing.T) {
	json := gjson.Parse(`{"type":"unknown","value1":{"type":"request_header","name":"x-test"},"value2":"abc"}`)
	_, err := CreateCondition(json)
	if err == nil {
		t.Error("CreateCondition should fail for unknown type")
	}
}
// A condition without a "type" field must be rejected.
func TestCreateCondition_MissingType(t *testing.T) {
	json := gjson.Parse(`{"value1":{"type":"request_header","name":"x-test"},"value2":"abc"}`)
	_, err := CreateCondition(json)
	if err == nil {
		t.Error("CreateCondition should fail for missing type")
	}
}
// A ref with an unknown type inside a condition must be rejected.
func TestCreateCondition_InvalidRefType(t *testing.T) {
	json := gjson.Parse(`{"type":"equals","value1":{"type":"invalid_type","name":"x-test"},"value2":"abc"}`)
	_, err := CreateCondition(json)
	if err == nil {
		t.Error("CreateCondition should fail for invalid ref type")
	}
}
// A ref missing its "name" field inside a condition must be rejected.
func TestCreateCondition_MissingRefName(t *testing.T) {
	json := gjson.Parse(`{"type":"equals","value1":{"type":"request_header"},"value2":"abc"}`)
	_, err := CreateCondition(json)
	if err == nil {
		t.Error("CreateCondition should fail for missing ref name")
	}
}
// --- ConditionSet tests ---
// A set matches only when every condition matches (AND semantics).
func TestConditionSet_Matches_AllMatch(t *testing.T) {
	json := gjson.Parse(`{"conditions":[{"type":"equals","value1":{"type":"request_header","name":"x-test"},"value2":"abc"},{"type":"prefix","value":{"type":"request_query","name":"foo"},"prefix":"bar"}]}`)
	var set ConditionSet
	if err := set.FromJson(json); err != nil {
		t.Fatalf("FromJson failed: %v", err)
	}
	ctx := NewEditorContext()
	ctx.SetRequestHeaders(map[string][]string{"x-test": {"abc"}})
	ctx.SetRequestQueries(map[string][]string{"foo": {"barbaz"}})
	if !set.Matches(ctx) {
		t.Error("ConditionSet should match when all conditions match")
	}
}
// One failing condition makes the whole set fail.
func TestConditionSet_Matches_OneNoMatch(t *testing.T) {
	json := gjson.Parse(`{"conditions":[{"type":"equals","value1":{"type":"request_header","name":"x-test"},"value2":"abc"},{"type":"prefix","value":{"type":"request_query","name":"foo"},"prefix":"bar"}]}`)
	var set ConditionSet
	if err := set.FromJson(json); err != nil {
		t.Fatalf("FromJson failed: %v", err)
	}
	ctx := NewEditorContext()
	ctx.SetRequestHeaders(map[string][]string{"x-test": {"abc"}})
	ctx.SetRequestQueries(map[string][]string{"foo": {"baz"}})
	if set.Matches(ctx) {
		t.Error("ConditionSet should not match if one condition does not match")
	}
}
// An empty condition set always matches.
func TestConditionSet_Matches_Empty(t *testing.T) {
	json := gjson.Parse(`{"conditions":[]}`)
	var set ConditionSet
	if err := set.FromJson(json); err != nil {
		t.Fatalf("FromJson failed: %v", err)
	}
	ctx := NewEditorContext()
	if !set.Matches(ctx) {
		t.Error("ConditionSet with no conditions should always match")
	}
}
// --- GetType/GetRefs coverage ---
// GetType/GetRefs expose the parsed type and the single parsed ref.
func TestCondition_GetTypeAndRefs(t *testing.T) {
	json := gjson.Parse(`{"type":"equals","value1":{"type":"request_header","name":"x-test"},"value2":"abc"}`)
	cond, err := CreateCondition(json)
	if err != nil {
		t.Fatalf("CreateCondition failed: %v", err)
	}
	if cond.GetType() != "equals" {
		t.Error("GetType should return 'equals'")
	}
	refs := cond.GetRefs()
	if len(refs) != 1 || refs[0].Type != "request_header" || refs[0].Name != "x-test" {
		t.Error("GetRefs should return correct ref")
	}
}

View File

@@ -0,0 +1,310 @@
package pkg
import (
"maps"
"net/url"
"strings"
"github.com/higress-group/wasm-go/pkg/log"
)
// Stage identifies a processing phase of the HTTP filter lifecycle.
type Stage int
const (
	StageInvalid Stage = iota
	StageRequestHeaders
	StageRequestBody
	StageResponseHeaders
	StageResponseBody
	// pathHeader is the HTTP/2 pseudo-header that carries the request path
	// (including the query string).
	pathHeader = ":path"
)
var (
	// OrderedStages lists the valid stages in processing order.
	OrderedStages = []Stage{
		StageRequestHeaders,
		StageRequestBody,
		StageResponseHeaders,
		StageResponseBody,
	}
	// Stage2String maps each valid stage to its configuration-string name;
	// StageInvalid is intentionally absent.
	Stage2String = map[Stage]string{
		StageRequestHeaders:  "request_headers",
		StageRequestBody:     "request_body",
		StageResponseHeaders: "response_headers",
		StageResponseBody:    "response_body",
	}
)
// EditorContext holds the per-request state shared by all commands and
// conditions: the effective command set and executors, the current stage,
// cached request path/headers/queries and response headers, plus dirty
// flags indicating which header maps were modified.
type EditorContext interface {
	GetEffectiveCommandSet() *CommandSet
	SetEffectiveCommandSet(cmdSet *CommandSet)
	GetCommandExecutors() []Executor
	SetCommandExecutors(executors []Executor)
	GetCurrentStage() Stage
	SetCurrentStage(stage Stage)
	// Request path (without query) and its derived views.
	GetRequestPath() string
	SetRequestPath(path string)
	GetRequestHeader(key string) []string
	GetRequestHeaders() map[string][]string
	SetRequestHeaders(map[string][]string)
	GetRequestQuery(key string) []string
	GetRequestQueries() map[string][]string
	SetRequestQueries(map[string][]string)
	GetResponseHeader(key string) []string
	GetResponseHeaders() map[string][]string
	SetResponseHeaders(map[string][]string)
	// Ref-based access dispatching on the ref type.
	GetRefValue(ref *Ref) string
	GetRefValues(ref *Ref) []string
	SetRefValue(ref *Ref, value string)
	SetRefValues(ref *Ref, values []string)
	DeleteRefValues(ref *Ref)
	// Dirty tracking for write-back of modified header maps.
	IsRequestHeadersDirty() bool
	IsResponseHeadersDirty() bool
	ResetDirtyFlags()
}
// NewEditorContext returns an empty context; maps are allocated lazily.
func NewEditorContext() EditorContext {
	return &editorContext{}
}
// editorContext is the sole EditorContext implementation.
type editorContext struct {
	effectiveCommandSet *CommandSet
	commandExecutors    []Executor
	currentStage        Stage
	// requestPath caches the path portion of the ":path" header;
	// requestQueries caches its parsed query parameters. The two are kept
	// in sync with requestHeaders[":path"] by savePathToHeader /
	// loadPathFromHeader.
	requestPath    string
	requestHeaders map[string][]string
	requestQueries map[string][]string
	responseHeaders map[string][]string
	// Set whenever the corresponding header map is modified; cleared by
	// ResetDirtyFlags.
	requestHeadersDirty  bool
	responseHeadersDirty bool
}
// GetEffectiveCommandSet returns the command set selected for this request.
func (ctx *editorContext) GetEffectiveCommandSet() *CommandSet {
	return ctx.effectiveCommandSet
}
// SetEffectiveCommandSet stores the command set selected for this request.
func (ctx *editorContext) SetEffectiveCommandSet(cmdSet *CommandSet) {
	ctx.effectiveCommandSet = cmdSet
}
// GetCommandExecutors returns the executors created for this request.
func (ctx *editorContext) GetCommandExecutors() []Executor {
	return ctx.commandExecutors
}
// SetCommandExecutors stores the executors created for this request.
func (ctx *editorContext) SetCommandExecutors(executors []Executor) {
	ctx.commandExecutors = executors
}
// GetCurrentStage returns the stage currently being processed.
func (ctx *editorContext) GetCurrentStage() Stage {
	return ctx.currentStage
}
// SetCurrentStage records the stage currently being processed.
func (ctx *editorContext) SetCurrentStage(stage Stage) {
	ctx.currentStage = stage
}
// GetRequestPath returns the cached request path (query string excluded).
func (ctx *editorContext) GetRequestPath() string {
	return ctx.requestPath
}
// SetRequestPath updates the cached path and rebuilds the ":path" header
// from it plus the cached queries.
func (ctx *editorContext) SetRequestPath(path string) {
	ctx.requestPath = path
	ctx.savePathToHeader()
}
// GetRequestHeader returns the values of a request header; lookup is
// case-insensitive (keys are stored lowercased).
func (ctx *editorContext) GetRequestHeader(key string) []string {
	if ctx.requestHeaders == nil {
		return nil
	}
	return ctx.requestHeaders[strings.ToLower(key)]
}
// GetRequestHeaders returns a shallow clone of the request header map.
func (ctx *editorContext) GetRequestHeaders() map[string][]string {
	return maps.Clone(ctx.requestHeaders)
}
// SetRequestHeaders replaces the request header map (caller must supply
// lowercased keys), re-parses the path/query cache from ":path", and marks
// the request headers dirty.
func (ctx *editorContext) SetRequestHeaders(headers map[string][]string) {
	ctx.requestHeaders = headers
	ctx.loadPathFromHeader()
	ctx.requestHeadersDirty = true
}
// GetRequestQuery returns the values of a query parameter (case-sensitive).
func (ctx *editorContext) GetRequestQuery(key string) []string {
	if ctx.requestQueries == nil {
		return nil
	}
	return ctx.requestQueries[key]
}
// GetRequestQueries returns a shallow clone of the query parameter map.
func (ctx *editorContext) GetRequestQueries() map[string][]string {
	return maps.Clone(ctx.requestQueries)
}
// SetRequestQueries replaces the query parameter map and rebuilds the
// ":path" header (which also marks the request headers dirty).
func (ctx *editorContext) SetRequestQueries(queries map[string][]string) {
	ctx.requestQueries = queries
	ctx.savePathToHeader()
}
// GetResponseHeader returns the values of a response header; lookup is
// case-insensitive (keys are stored lowercased).
func (ctx *editorContext) GetResponseHeader(key string) []string {
	if ctx.responseHeaders == nil {
		return nil
	}
	return ctx.responseHeaders[strings.ToLower(key)]
}
// GetResponseHeaders returns a shallow clone of the response header map.
func (ctx *editorContext) GetResponseHeaders() map[string][]string {
	return maps.Clone(ctx.responseHeaders)
}
// SetResponseHeaders replaces the response header map and marks it dirty.
func (ctx *editorContext) SetResponseHeaders(headers map[string][]string) {
	ctx.responseHeaders = headers
	ctx.responseHeadersDirty = true
}
// GetRefValue returns the first value of ref, or "" when it has none.
func (ctx *editorContext) GetRefValue(ref *Ref) string {
	values := ctx.GetRefValues(ref)
	if len(values) == 0 {
		return ""
	}
	return values[0]
}
// GetRefValues returns all values for ref, dispatching on its type
// (request header / request query / response header). A nil ref or an
// unknown type yields nil.
func (ctx *editorContext) GetRefValues(ref *Ref) []string {
	if ref == nil {
		return nil
	}
	switch ref.Type {
	case RefTypeRequestHeader:
		// Header keys are stored lowercased; normalize before lookup.
		return ctx.GetRequestHeader(strings.ToLower(ref.Name))
	case RefTypeRequestQuery:
		return ctx.GetRequestQuery(ref.Name)
	case RefTypeResponseHeader:
		return ctx.GetResponseHeader(strings.ToLower(ref.Name))
	default:
		return nil
	}
}
// SetRefValue sets a single value for ref, replacing any existing values.
func (ctx *editorContext) SetRefValue(ref *Ref, value string) {
	if ref == nil {
		return
	}
	ctx.SetRefValues(ref, []string{value})
}
// SetRefValues replaces the values of ref in the collection matching its
// type and marks the affected headers dirty. Writing the ":path"
// pseudo-header re-parses the path/query cache; writing a query parameter
// rebuilds the ":path" header. Nil refs and unknown types are ignored.
//
// Note: the redundant `break` statements previously ending each case were
// removed — Go switch cases never fall through (staticcheck S1023).
func (ctx *editorContext) SetRefValues(ref *Ref, values []string) {
	if ref == nil {
		return
	}
	switch ref.Type {
	case RefTypeRequestHeader:
		if ctx.requestHeaders == nil {
			ctx.requestHeaders = make(map[string][]string)
		}
		loweredRefName := strings.ToLower(ref.Name)
		ctx.requestHeaders[loweredRefName] = values
		ctx.requestHeadersDirty = true
		if loweredRefName == pathHeader {
			// The path header changed directly; refresh the parsed cache.
			ctx.loadPathFromHeader()
		}
	case RefTypeRequestQuery:
		if ctx.requestQueries == nil {
			ctx.requestQueries = make(map[string][]string)
		}
		ctx.requestQueries[ref.Name] = values
		ctx.savePathToHeader()
	case RefTypeResponseHeader:
		if ctx.responseHeaders == nil {
			ctx.responseHeaders = make(map[string][]string)
		}
		ctx.responseHeaders[strings.ToLower(ref.Name)] = values
		ctx.responseHeadersDirty = true
	}
}
// DeleteRefValues removes all values of ref from the collection matching
// its type and marks the affected headers dirty. Deleting a query
// parameter rebuilds the ":path" header. Nil refs and unknown types are
// ignored. (delete on a nil map is a safe no-op, so no nil checks needed.)
//
// Note: the redundant `break` statements previously ending each case were
// removed — Go switch cases never fall through (staticcheck S1023).
func (ctx *editorContext) DeleteRefValues(ref *Ref) {
	if ref == nil {
		return
	}
	switch ref.Type {
	case RefTypeRequestHeader:
		delete(ctx.requestHeaders, strings.ToLower(ref.Name))
		ctx.requestHeadersDirty = true
	case RefTypeRequestQuery:
		delete(ctx.requestQueries, ref.Name)
		ctx.savePathToHeader()
	case RefTypeResponseHeader:
		delete(ctx.responseHeaders, strings.ToLower(ref.Name))
		ctx.responseHeadersDirty = true
	}
}
// IsRequestHeadersDirty reports whether the request headers were modified
// since the last ResetDirtyFlags.
func (ctx *editorContext) IsRequestHeadersDirty() bool {
	return ctx.requestHeadersDirty
}
// IsResponseHeadersDirty reports whether the response headers were
// modified since the last ResetDirtyFlags.
func (ctx *editorContext) IsResponseHeadersDirty() bool {
	return ctx.responseHeadersDirty
}
// ResetDirtyFlags clears both dirty flags (typically after write-back).
func (ctx *editorContext) ResetDirtyFlags() {
	ctx.requestHeadersDirty = false
	ctx.responseHeadersDirty = false
}
// savePathToHeader rebuilds the ":path" header from the cached request
// path plus the cached query parameters, going through SetRefValue (which
// also marks the request headers dirty). Logs and returns without writing
// when the cached path itself cannot be parsed as a URL.
func (ctx *editorContext) savePathToHeader() {
	u, err := url.Parse(ctx.requestPath)
	if err != nil {
		log.Errorf("failed to build the new path with query strings: %v", err)
		return
	}
	query := url.Values{}
	for k, vs := range ctx.requestQueries {
		for _, v := range vs {
			query.Add(k, v)
		}
	}
	// Encode sorts parameters by key, so the rebuilt query order is
	// deterministic but may differ from the original wire order.
	u.RawQuery = query.Encode()
	ctx.SetRefValue(&Ref{Type: RefTypeRequestHeader, Name: pathHeader}, u.String())
}
// loadPathFromHeader re-parses the cached request path and query map from
// the current ":path" header value. On a missing, empty, or unparsable
// path both caches are reset to empty.
func (ctx *editorContext) loadPathFromHeader() {
	paths := ctx.GetRequestHeader(pathHeader)
	if len(paths) == 0 || paths[0] == "" {
		log.Warn("the request has an empty path")
		ctx.requestPath = ""
		ctx.requestQueries = make(map[string][]string)
		return
	}
	path := paths[0]
	u, err := url.Parse(path)
	if err != nil {
		log.Warnf("unable to parse the request path: %s", path)
		ctx.requestPath = ""
		ctx.requestQueries = make(map[string][]string)
		return
	}
	ctx.requestPath = u.Path
	// url.Values has underlying type map[string][]string, so the parsed
	// query map can be stored directly — the previous element-by-element
	// copy (and the map allocated before the parse succeeded) was redundant.
	ctx.requestQueries = u.Query()
}

View File

@@ -0,0 +1,218 @@
package pkg
import (
"reflect"
"testing"
)
// newTestRef builds a Ref for tests from its type and name. The first
// parameter is named refType rather than "t" to avoid clashing with the
// conventional *testing.T identifier used throughout this file.
func newTestRef(refType, name string) *Ref {
	return &Ref{Type: refType, Name: name}
}
func TestEditorContext_CommandSetAndExecutors(t *testing.T) {
ctx := NewEditorContext().(*editorContext)
cmdSet := &CommandSet{}
ctx.SetEffectiveCommandSet(cmdSet)
if ctx.GetEffectiveCommandSet() != cmdSet {
t.Errorf("EffectiveCommandSet not set/get correctly")
}
executors := []Executor{nil, nil}
ctx.SetCommandExecutors(executors)
if !reflect.DeepEqual(ctx.GetCommandExecutors(), executors) {
t.Errorf("CommandExecutors not set/get correctly")
}
}
// TestEditorContext_Stage checks the current-stage round trip.
func TestEditorContext_Stage(t *testing.T) {
	editor := NewEditorContext().(*editorContext)
	editor.SetCurrentStage(StageRequestHeaders)
	if got := editor.GetCurrentStage(); got != StageRequestHeaders {
		t.Errorf("CurrentStage not set/get correctly")
	}
}
// TestEditorContext_RequestPath checks the request-path round trip.
func TestEditorContext_RequestPath(t *testing.T) {
	editor := NewEditorContext().(*editorContext)
	editor.SetRequestPath("/foo/bar")
	if got := editor.GetRequestPath(); got != "/foo/bar" {
		t.Errorf("RequestPath not set/get correctly")
	}
}
// TestEditorContext_RequestHeaders checks header set/get, the dirty flag
// set as a side effect, and single-key lookup.
func TestEditorContext_RequestHeaders(t *testing.T) {
	editor := NewEditorContext().(*editorContext)
	input := map[string][]string{"foo": {"bar"}, "baz": {"qux"}}
	editor.SetRequestHeaders(input)
	if !reflect.DeepEqual(editor.GetRequestHeaders(), input) {
		t.Errorf("RequestHeaders not set/get correctly")
	}
	if !editor.IsRequestHeadersDirty() {
		t.Errorf("RequestHeadersDirty not set correctly")
	}
	if values := editor.GetRequestHeader("foo"); !reflect.DeepEqual(values, []string{"bar"}) {
		t.Errorf("GetRequestHeader failed")
	}
}
// TestEditorContext_RequestQueries checks query set/get, the request-header
// dirty flag set as a side effect, and single-key lookup.
func TestEditorContext_RequestQueries(t *testing.T) {
	editor := NewEditorContext().(*editorContext)
	input := map[string][]string{"foo": {"bar"}, "baz": {"qux"}}
	editor.SetRequestQueries(input)
	if !reflect.DeepEqual(editor.GetRequestQueries(), input) {
		t.Errorf("RequestQueries not set/get correctly")
	}
	if !editor.IsRequestHeadersDirty() {
		t.Errorf("RequestHeadersDirty not set correctly")
	}
	if values := editor.GetRequestQuery("foo"); !reflect.DeepEqual(values, []string{"bar"}) {
		t.Errorf("GetRequestQuery failed")
	}
}
// TestEditorContext_ResponseHeaders checks response-header set/get, the
// dirty flag set as a side effect, and single-key lookup.
func TestEditorContext_ResponseHeaders(t *testing.T) {
	editor := NewEditorContext().(*editorContext)
	input := map[string][]string{"foo": {"bar"}, "baz": {"qux"}}
	editor.SetResponseHeaders(input)
	if !reflect.DeepEqual(editor.GetResponseHeaders(), input) {
		t.Errorf("ResponseHeaders not set/get correctly")
	}
	if !editor.IsResponseHeadersDirty() {
		t.Errorf("ResponseHeadersDirty not set correctly")
	}
	if values := editor.GetResponseHeader("foo"); !reflect.DeepEqual(values, []string{"bar"}) {
		t.Errorf("GetResponseHeader failed")
	}
}
// TestEditorContext_RefValueAndValues checks single- and multi-value
// set/get through refs of each supported type.
func TestEditorContext_RefValueAndValues(t *testing.T) {
	editor := NewEditorContext().(*editorContext)
	headerRef := newTestRef(RefTypeRequestHeader, "foo")
	queryRef := newTestRef(RefTypeRequestQuery, "bar")
	respRef := newTestRef(RefTypeResponseHeader, "baz")
	editor.SetRefValue(headerRef, "v1")
	editor.SetRefValues(queryRef, []string{"v2", "v3"})
	editor.SetRefValues(respRef, []string{"v4"})
	if v := editor.GetRefValue(headerRef); v != "v1" {
		t.Errorf("GetRefValue(RequestHeader) failed: %v", v)
	}
	if v := editor.GetRefValues(queryRef); !reflect.DeepEqual(v, []string{"v2", "v3"}) {
		t.Errorf("GetRefValues(RequestQuery) failed: %v", v)
	}
	if v := editor.GetRefValues(respRef); !reflect.DeepEqual(v, []string{"v4"}) {
		t.Errorf("GetRefValues(ResponseHeader) failed: %v", v)
	}
}
// TestEditorContext_DeleteRefValues checks that deleting through a ref
// removes all stored values for each supported ref type.
func TestEditorContext_DeleteRefValues(t *testing.T) {
	editor := NewEditorContext().(*editorContext)
	headerRef := newTestRef(RefTypeRequestHeader, "foo")
	queryRef := newTestRef(RefTypeRequestQuery, "bar")
	respRef := newTestRef(RefTypeResponseHeader, "baz")
	editor.SetRefValue(headerRef, "v1")
	editor.SetRefValues(queryRef, []string{"v2", "v3"})
	editor.SetRefValues(respRef, []string{"v4"})
	editor.DeleteRefValues(headerRef)
	editor.DeleteRefValues(queryRef)
	editor.DeleteRefValues(respRef)
	if v := editor.GetRefValues(headerRef); len(v) != 0 {
		t.Errorf("DeleteRefValues(RequestHeader) failed: %v", v)
	}
	if v := editor.GetRefValues(queryRef); len(v) != 0 {
		t.Errorf("DeleteRefValues(RequestQuery) failed: %v", v)
	}
	if v := editor.GetRefValues(respRef); len(v) != 0 {
		t.Errorf("DeleteRefValues(ResponseHeader) failed: %v", v)
	}
}
// TestEditorContext_ResetDirtyFlags verifies that ResetDirtyFlags clears
// both the request- and response-header dirty flags.
func TestEditorContext_ResetDirtyFlags(t *testing.T) {
	ctx := NewEditorContext().(*editorContext)
	ctx.SetRequestHeaders(map[string][]string{"foo": {"bar"}})
	ctx.SetRequestQueries(map[string][]string{"foo": {"bar"}})
	ctx.SetResponseHeaders(map[string][]string{"foo": {"bar"}})
	ctx.ResetDirtyFlags()
	// The original condition called IsRequestHeadersDirty twice; only the
	// two distinct flags need checking (queries share the request flag).
	if ctx.IsRequestHeadersDirty() || ctx.IsResponseHeadersDirty() {
		t.Errorf("ResetDirtyFlags failed")
	}
}
// TestEditorContext_IsRequestHeadersDirty_SetHeaders walks the request
// header dirty flag through set, reset, ref write, and ref delete.
func TestEditorContext_IsRequestHeadersDirty_SetHeaders(t *testing.T) {
	editor := NewEditorContext().(*editorContext)
	if editor.IsRequestHeadersDirty() {
		t.Errorf("RequestHeadersDirty should be false initially")
	}
	editor.SetRequestHeaders(map[string][]string{"foo": {"bar"}})
	if !editor.IsRequestHeadersDirty() {
		t.Errorf("RequestHeadersDirty should be true after SetRequestHeaders")
	}
	editor.ResetDirtyFlags()
	if editor.IsRequestHeadersDirty() {
		t.Errorf("RequestHeadersDirty should be false after ResetDirtyFlags")
	}
	headerRef := newTestRef(RefTypeRequestHeader, "foo")
	editor.SetRefValue(headerRef, "baz")
	if !editor.IsRequestHeadersDirty() {
		t.Errorf("RequestHeadersDirty should be true after SetRefValue")
	}
	editor.ResetDirtyFlags()
	editor.DeleteRefValues(headerRef)
	if !editor.IsRequestHeadersDirty() {
		t.Errorf("RequestHeadersDirty should be true after DeleteRefValues")
	}
}
// TestEditorContext_IsRequestHeadersDirty_SetQueries walks the request
// header dirty flag through query set, reset, ref write, and ref delete
// (query changes share the request-header dirty flag).
func TestEditorContext_IsRequestHeadersDirty_SetQueries(t *testing.T) {
	editor := NewEditorContext().(*editorContext)
	if editor.IsRequestHeadersDirty() {
		t.Errorf("RequestQueriesDirty should be false initially")
	}
	editor.SetRequestQueries(map[string][]string{"foo": {"bar"}})
	if !editor.IsRequestHeadersDirty() {
		t.Errorf("RequestQueriesDirty should be true after SetRequestQueries")
	}
	editor.ResetDirtyFlags()
	if editor.IsRequestHeadersDirty() {
		t.Errorf("RequestQueriesDirty should be false after ResetDirtyFlags")
	}
	queryRef := newTestRef(RefTypeRequestQuery, "foo")
	editor.SetRefValues(queryRef, []string{"baz"})
	if !editor.IsRequestHeadersDirty() {
		t.Errorf("RequestQueriesDirty should be true after SetRefValues")
	}
	editor.ResetDirtyFlags()
	editor.DeleteRefValues(queryRef)
	if !editor.IsRequestHeadersDirty() {
		t.Errorf("RequestQueriesDirty should be true after DeleteRefValues")
	}
}
// TestEditorContext_IsResponseHeadersDirty walks the response header dirty
// flag through set, reset, ref write, and ref delete.
func TestEditorContext_IsResponseHeadersDirty(t *testing.T) {
	editor := NewEditorContext().(*editorContext)
	if editor.IsResponseHeadersDirty() {
		t.Errorf("ResponseHeadersDirty should be false initially")
	}
	editor.SetResponseHeaders(map[string][]string{"foo": {"bar"}})
	if !editor.IsResponseHeadersDirty() {
		t.Errorf("ResponseHeadersDirty should be true after SetResponseHeaders")
	}
	editor.ResetDirtyFlags()
	if editor.IsResponseHeadersDirty() {
		t.Errorf("ResponseHeadersDirty should be false after ResetDirtyFlags")
	}
	respRef := newTestRef(RefTypeResponseHeader, "foo")
	editor.SetRefValues(respRef, []string{"baz"})
	if !editor.IsResponseHeadersDirty() {
		t.Errorf("ResponseHeadersDirty should be true after SetRefValues")
	}
	editor.ResetDirtyFlags()
	editor.DeleteRefValues(respRef)
	if !editor.IsResponseHeadersDirty() {
		t.Errorf("ResponseHeadersDirty should be true after DeleteRefValues")
	}
}

View File

@@ -0,0 +1,26 @@
package pkg
import (
"github.com/higress-group/wasm-go/pkg/log"
)
// init installs a no-op logger so code under test that calls the wasm-go
// log package does not fail on a missing plugin logger.
func init() {
	// Initialize mock logger for testing
	log.SetPluginLog(&mockLogger{})
}
// mockLogger is a no-op implementation of the wasm-go plugin log interface;
// every method discards its arguments so tests run silently.
type mockLogger struct{}

func (m *mockLogger) Trace(msg string)                          {}
func (m *mockLogger) Tracef(format string, args ...interface{}) {}
func (m *mockLogger) Debug(msg string)                          {}
func (m *mockLogger) Debugf(format string, args ...interface{}) {}
func (m *mockLogger) Info(msg string)                           {}
func (m *mockLogger) Infof(format string, args ...interface{})  {}
func (m *mockLogger) Warn(msg string)                           {}
func (m *mockLogger) Warnf(format string, args ...interface{})  {}
func (m *mockLogger) Error(msg string)                          {}
func (m *mockLogger) Errorf(format string, args ...interface{}) {}
func (m *mockLogger) Critical(msg string)                       {}
func (m *mockLogger) Criticalf(format string, args ...interface{}) {}

// ResetID is part of the logger interface; the plugin ID is ignored here.
func (m *mockLogger) ResetID(pluginID string) {}

View File

@@ -0,0 +1,64 @@
package pkg
import (
"errors"
"fmt"
"github.com/tidwall/gjson"
)
// Ref type discriminators: where a referenced value lives in the exchange.
const (
	RefTypeRequestHeader  = "request_header"
	RefTypeRequestQuery   = "request_query"
	RefTypeResponseHeader = "response_header"
)

var (
	// refType2Stage maps each ref type to the processing stage in which its
	// value is available; it also serves as the set of known ref types.
	refType2Stage = map[string]Stage{
		RefTypeRequestHeader:  StageRequestHeaders,
		RefTypeRequestQuery:   StageRequestHeaders,
		RefTypeResponseHeader: StageResponseHeaders,
	}
)
// Ref identifies a named value in the request or response — a header or a
// query parameter — by its type discriminator and name.
type Ref struct {
	Type string `json:"type"`
	Name string `json:"name,omitempty"`
	// stage caches the Stage derived from Type; populated lazily by GetStage.
	stage Stage
}
// NewRef parses a Ref from its JSON representation, requiring a non-empty,
// known type and a non-empty name.
func NewRef(json gjson.Result) (*Ref, error) {
	refType := json.Get("type").String()
	if refType == "" {
		return nil, errors.New("missing type field")
	}
	if _, known := refType2Stage[refType]; !known {
		return nil, fmt.Errorf("unknown ref type: %s", refType)
	}
	refName := json.Get("name").String()
	if refName == "" {
		return nil, errors.New("missing name field")
	}
	return &Ref{Type: refType, Name: refName}, nil
}
// GetStage returns the processing stage associated with the ref's type,
// caching the result on first use.
// NOTE(review): the zero value of Stage is used as the "not yet cached"
// sentinel here — if any real stage constant equals the zero value, refs of
// that stage would hit the map lookup on every call; confirm the Stage
// constants start above zero.
func (r *Ref) GetStage() Stage {
	if r.stage == 0 {
		if stage, ok := refType2Stage[r.Type]; ok {
			r.stage = stage
		}
	}
	return r.stage
}
// String renders the ref as "type/name".
func (r *Ref) String() string {
	return r.Type + "/" + r.Name
}

View File

@@ -9,7 +9,7 @@ replace amap-tools => ../amap-tools
require (
amap-tools v0.0.0-00010101000000-000000000000
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
github.com/higress-group/wasm-go v1.0.9-0.20251223122142-eae11e33a500
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9
github.com/stretchr/testify v1.9.0
quark-search v0.0.0-00010101000000-000000000000
)

View File

@@ -24,6 +24,10 @@ github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.9-0.20251223122142-eae11e33a500 h1:4BKKZ3BreIaIGub88nlvzihTK1uJmZYYoQ7r7Xkgb5Q=
github.com/higress-group/wasm-go v1.0.9-0.20251223122142-eae11e33a500/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/higress-group/wasm-go v1.0.10-0.20260115083526-76699a1df2c1 h1:+usoX0B1cwECTA2qf73IaLGyCIMVopIMev5cBWGgEZk=
github.com/higress-group/wasm-go v1.0.10-0.20260115083526-76699a1df2c1/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 h1:sUuUXZwr50l3W1St7MESlFmxmUAu+QUNNfJXx4P6bas=
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=

View File

@@ -145,22 +145,43 @@ func (w *watcher) generateServiceEntry(host string) *v1alpha3.ServiceEntry {
for _, ep := range strings.Split(w.Domain, common.CommaSeparator) {
var endpoint *v1alpha3.WorkloadEntry
if w.Type == string(registry.Static) {
pair := strings.Split(ep, common.ColonSeparator)
if len(pair) != 2 {
log.Errorf("invalid endpoint:%s with static type", ep)
return nil
var ip string
var portStr string
// Support IPv6 format: [2001:db8::1]:8080
if strings.HasPrefix(ep, "[") {
// IPv6 format: [IPv6]:port
lastBracket := strings.LastIndex(ep, "]")
if lastBracket == -1 {
log.Errorf("invalid IPv6 endpoint format:%s with static type", ep)
return nil
}
ip = ep[1:lastBracket] // Extract IPv6 address without brackets
if lastBracket+1 >= len(ep) || ep[lastBracket+1] != ':' {
log.Errorf("invalid IPv6 endpoint format:%s with static type, missing colon after bracket", ep)
return nil
}
portStr = ep[lastBracket+2:] // Extract port after "]:"
} else {
// IPv4 format: 192.168.1.1:8080
pair := strings.Split(ep, common.ColonSeparator)
if len(pair) != 2 {
log.Errorf("invalid endpoint:%s with static type", ep)
return nil
}
ip = pair[0]
portStr = pair[1]
}
port, err := strconv.ParseUint(pair[1], 10, 32)
port, err := strconv.ParseUint(portStr, 10, 32)
if err != nil {
log.Errorf("invalid port:%s of endpoint:%s", pair[1], ep)
log.Errorf("invalid port:%s of endpoint:%s", portStr, ep)
return nil
}
if net.ParseIP(pair[0]) == nil {
log.Errorf("invalid ip:%s of endpoint:%s", pair[0], ep)
if net.ParseIP(ip) == nil {
log.Errorf("invalid ip:%s of endpoint:%s", ip, ep)
return nil
}
endpoint = &v1alpha3.WorkloadEntry{
Address: pair[0],
Address: ip,
Ports: map[string]uint32{protocol: uint32(port)},
}
} else if w.Type == string(registry.DNS) {
@@ -247,3 +268,4 @@ func (w *watcher) getSni(se *v1alpha3.ServiceEntry) string {
func (w *watcher) GetRegistryType() string {
return w.RegistryConfig.Type
}

View File

@@ -43,7 +43,7 @@ type McpServerRule struct {
MatchRoute []string `json:"_match_route_,omitempty"`
Server *ServerConfig `json:"server,omitempty"`
Tools []*McpTool `json:"tools,omitempty"`
AllowTools []string `json:"allowTools,omitempty"`
AllowTools []string `json:"allowTools"`
}
type ServerConfig struct {

View File

@@ -434,7 +434,7 @@ func (w *watcher) processToolConfig(dataId, data string, credentials map[string]
rule.Server.SecuritySchemes = toolsDescription.SecuritySchemes
}
var allowTools []string
allowTools := []string{}
for _, t := range toolsDescription.Tools {
convertTool := &provider.McpTool{Name: t.Name, Description: t.Description}
@@ -813,6 +813,9 @@ func generateServiceEntry(host string, services *model.Service) *v1alpha3.Servic
endpoints := make([]*v1alpha3.WorkloadEntry, 0)
for _, service := range services.Hosts {
if !service.Healthy || !service.Enable {
continue
}
protocol := common.HTTP
if service.Metadata != nil && service.Metadata["protocol"] != "" {
protocol = common.ParseProtocol(service.Metadata["protocol"])

View File

@@ -17,6 +17,7 @@ package v2
import (
"errors"
"fmt"
"math"
"net"
"strconv"
"strings"
@@ -556,10 +557,19 @@ func (w *watcher) generateServiceEntry(host string, services []model.Instance) *
if !isValidIP(service.Ip) {
isDnsService = true
}
// Calculate weight from Nacos instance
// Nacos weight is float64, need to convert to uint32 for Istio
// Use math.Round to preserve fractional weights (e.g., 0.5, 1.5)
// If weight is 0 or negative, use default weight 1
weight := uint32(1)
if service.Weight > 0 {
weight = uint32(math.Round(service.Weight))
}
endpoint := &v1alpha3.WorkloadEntry{
Address: service.Ip,
Ports: map[string]uint32{port.Protocol: port.Number},
Labels: service.Metadata,
Weight: weight,
}
endpoints = append(endpoints, endpoint)
}

View File

@@ -0,0 +1,200 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package v2
import (
"testing"
"github.com/nacos-group/nacos-sdk-go/v2/model"
"istio.io/api/networking/v1alpha3"
)
// Test_generateServiceEntry_Weight checks how Nacos float64 instance
// weights are mapped onto the uint32 Weight of Istio workload entries:
// rounding to the nearest integer, with non-positive results defaulting to 1.
func Test_generateServiceEntry_Weight(t *testing.T) {
	w := &watcher{}
	testCases := []struct {
		name            string
		services        []model.Instance
		expectedWeights []uint32
		description     string
	}{
		{
			name: "normal integer weights",
			services: []model.Instance{
				{Ip: "192.168.1.1", Port: 8080, Weight: 5.0},
				{Ip: "192.168.1.2", Port: 8080, Weight: 3.0},
				{Ip: "192.168.1.3", Port: 8080, Weight: 2.0},
			},
			expectedWeights: []uint32{5, 3, 2},
			description:     "Integer weights should be converted correctly",
		},
		{
			name: "fractional weights with rounding",
			services: []model.Instance{
				{Ip: "192.168.1.1", Port: 8080, Weight: 5.4},
				{Ip: "192.168.1.2", Port: 8080, Weight: 3.5},
				{Ip: "192.168.1.3", Port: 8080, Weight: 2.6},
			},
			expectedWeights: []uint32{5, 4, 3},
			description:     "Fractional weights should be rounded to nearest integer",
		},
		{
			name: "zero weight defaults to 1",
			services: []model.Instance{
				{Ip: "192.168.1.1", Port: 8080, Weight: 0.0},
				{Ip: "192.168.1.2", Port: 8080, Weight: 5.0},
			},
			expectedWeights: []uint32{1, 5},
			description:     "Zero weight should default to 1",
		},
		{
			name: "negative weight defaults to 1",
			services: []model.Instance{
				{Ip: "192.168.1.1", Port: 8080, Weight: -1.0},
				{Ip: "192.168.1.2", Port: 8080, Weight: 3.0},
			},
			expectedWeights: []uint32{1, 3},
			description:     "Negative weight should default to 1",
		},
		{
			// NOTE(review): math.Round(0.4) is 0, so expecting 1 for the
			// first instance assumes generateServiceEntry re-applies the
			// default of 1 after rounding; a plain `if Weight > 0` guard
			// alone would yield 0 — confirm against the implementation.
			name: "very small fractional weight rounds to 0 then defaults to 1",
			services: []model.Instance{
				{Ip: "192.168.1.1", Port: 8080, Weight: 0.4},
				{Ip: "192.168.1.2", Port: 8080, Weight: 0.5},
				{Ip: "192.168.1.3", Port: 8080, Weight: 0.6},
			},
			expectedWeights: []uint32{1, 1, 1},
			description:     "Weights less than 0.5 round to 0, then default to 1; 0.5 and above round to 1",
		},
		{
			name: "large weights",
			services: []model.Instance{
				{Ip: "192.168.1.1", Port: 8080, Weight: 100.0},
				{Ip: "192.168.1.2", Port: 8080, Weight: 50.5},
			},
			expectedWeights: []uint32{100, 51},
			description:     "Large weights should be handled correctly",
		},
		{
			name: "mixed weights",
			services: []model.Instance{
				{Ip: "192.168.1.1", Port: 8080, Weight: 0.0},
				{Ip: "192.168.1.2", Port: 8080, Weight: 1.5},
				{Ip: "192.168.1.3", Port: 8080, Weight: -5.0},
				{Ip: "192.168.1.4", Port: 8080, Weight: 10.7},
			},
			expectedWeights: []uint32{1, 2, 1, 11},
			description:     "Mixed zero, negative, and fractional weights",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			se := w.generateServiceEntry("test-host", tc.services)
			if se == nil {
				t.Fatal("generateServiceEntry returned nil")
			}
			if len(se.Endpoints) != len(tc.expectedWeights) {
				t.Fatalf("expected %d endpoints, got %d", len(tc.expectedWeights), len(se.Endpoints))
			}
			// Endpoint order follows the input instance order, so weights
			// can be compared positionally.
			for i, endpoint := range se.Endpoints {
				if endpoint.Weight != tc.expectedWeights[i] {
					t.Errorf("endpoint[%d]: expected weight %d, got %d (original weight: %f) - %s",
						i, tc.expectedWeights[i], endpoint.Weight, tc.services[i].Weight, tc.description)
				}
			}
		})
	}
}
// Test_generateServiceEntry_WeightFieldSet verifies that address, weight,
// labels, and ports are all populated on a generated endpoint.
func Test_generateServiceEntry_WeightFieldSet(t *testing.T) {
	w := &watcher{}
	instances := []model.Instance{
		{Ip: "192.168.1.1", Port: 8080, Weight: 5.0, Metadata: map[string]string{"zone": "a"}},
	}
	entry := w.generateServiceEntry("test-host", instances)
	if entry == nil {
		t.Fatal("generateServiceEntry returned nil")
	}
	if len(entry.Endpoints) != 1 {
		t.Fatalf("expected 1 endpoint, got %d", len(entry.Endpoints))
	}
	ep := entry.Endpoints[0]
	// Verify all fields are set correctly
	if ep.Address != "192.168.1.1" {
		t.Errorf("expected address 192.168.1.1, got %s", ep.Address)
	}
	if ep.Weight != 5 {
		t.Errorf("expected weight 5, got %d", ep.Weight)
	}
	if ep.Labels == nil || ep.Labels["zone"] != "a" {
		t.Errorf("expected labels with zone=a, got %v", ep.Labels)
	}
	if ep.Ports == nil {
		t.Error("expected ports to be set")
	}
}
// Test_generateServiceEntry_EmptyServices verifies that an empty instance
// list yields a service entry with no endpoints.
func Test_generateServiceEntry_EmptyServices(t *testing.T) {
	w := &watcher{}
	entry := w.generateServiceEntry("test-host", []model.Instance{})
	if entry == nil {
		t.Fatal("generateServiceEntry returned nil")
	}
	if got := len(entry.Endpoints); got != 0 {
		t.Errorf("expected 0 endpoints for empty services, got %d", got)
	}
}
// Test_generateServiceEntry_DNSResolution verifies that a domain-name
// address switches the entry to DNS resolution and still carries the weight.
func Test_generateServiceEntry_DNSResolution(t *testing.T) {
	w := &watcher{}
	instances := []model.Instance{
		{Ip: "example.com", Port: 8080, Weight: 5.0},
	}
	entry := w.generateServiceEntry("test-host", instances)
	if entry == nil {
		t.Fatal("generateServiceEntry returned nil")
	}
	if entry.Resolution != v1alpha3.ServiceEntry_DNS {
		t.Errorf("expected DNS resolution for domain name, got %v", entry.Resolution)
	}
	if len(entry.Endpoints) != 1 {
		t.Fatalf("expected 1 endpoint, got %d", len(entry.Endpoints))
	}
	if entry.Endpoints[0].Weight != 5 {
		t.Errorf("expected weight 5 for DNS endpoint, got %d", entry.Endpoints[0].Weight)
	}
}

View File

@@ -15,6 +15,7 @@
package nacos
import (
"math"
"strconv"
"strings"
"sync"
@@ -352,10 +353,19 @@ func (w *watcher) generateServiceEntry(host string, services []model.SubscribeSe
portList = append(portList, port)
}
}
// Calculate weight from Nacos instance
// Nacos weight is float64, need to convert to uint32 for Istio
// Use math.Round to preserve fractional weights (e.g., 0.5, 1.5)
// If weight is 0 or negative, use default weight 1
weight := uint32(1)
if service.Weight > 0 {
weight = uint32(math.Round(service.Weight))
}
endpoint := v1alpha3.WorkloadEntry{
Address: service.Ip,
Ports: map[string]uint32{port.Protocol: port.Number},
Labels: service.Metadata,
Weight: weight,
}
endpoints = append(endpoints, &endpoint)
}

View File

@@ -0,0 +1,173 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package nacos
import (
"testing"
"github.com/nacos-group/nacos-sdk-go/model"
)
// Test_generateServiceEntry_Weight checks how Nacos float64 instance
// weights are mapped onto the uint32 Weight of Istio workload entries:
// rounding to the nearest integer, with non-positive results defaulting to 1.
func Test_generateServiceEntry_Weight(t *testing.T) {
	w := &watcher{}
	testCases := []struct {
		name            string
		services        []model.SubscribeService
		expectedWeights []uint32
		description     string
	}{
		{
			name: "normal integer weights",
			services: []model.SubscribeService{
				{Ip: "192.168.1.1", Port: 8080, Weight: 5.0},
				{Ip: "192.168.1.2", Port: 8080, Weight: 3.0},
				{Ip: "192.168.1.3", Port: 8080, Weight: 2.0},
			},
			expectedWeights: []uint32{5, 3, 2},
			description:     "Integer weights should be converted correctly",
		},
		{
			name: "fractional weights with rounding",
			services: []model.SubscribeService{
				{Ip: "192.168.1.1", Port: 8080, Weight: 5.4},
				{Ip: "192.168.1.2", Port: 8080, Weight: 3.5},
				{Ip: "192.168.1.3", Port: 8080, Weight: 2.6},
			},
			expectedWeights: []uint32{5, 4, 3},
			description:     "Fractional weights should be rounded to nearest integer",
		},
		{
			name: "zero weight defaults to 1",
			services: []model.SubscribeService{
				{Ip: "192.168.1.1", Port: 8080, Weight: 0.0},
				{Ip: "192.168.1.2", Port: 8080, Weight: 5.0},
			},
			expectedWeights: []uint32{1, 5},
			description:     "Zero weight should default to 1",
		},
		{
			name: "negative weight defaults to 1",
			services: []model.SubscribeService{
				{Ip: "192.168.1.1", Port: 8080, Weight: -1.0},
				{Ip: "192.168.1.2", Port: 8080, Weight: 3.0},
			},
			expectedWeights: []uint32{1, 3},
			description:     "Negative weight should default to 1",
		},
		{
			// NOTE(review): math.Round(0.4) is 0, so expecting 1 for the
			// first instance assumes generateServiceEntry re-applies the
			// default of 1 after rounding; a plain `if Weight > 0` guard
			// alone would yield 0 — confirm against the implementation.
			name: "very small fractional weight rounds to 0 then defaults to 1",
			services: []model.SubscribeService{
				{Ip: "192.168.1.1", Port: 8080, Weight: 0.4},
				{Ip: "192.168.1.2", Port: 8080, Weight: 0.5},
				{Ip: "192.168.1.3", Port: 8080, Weight: 0.6},
			},
			expectedWeights: []uint32{1, 1, 1},
			description:     "Weights less than 0.5 round to 0, then default to 1; 0.5 and above round to 1",
		},
		{
			name: "large weights",
			services: []model.SubscribeService{
				{Ip: "192.168.1.1", Port: 8080, Weight: 100.0},
				{Ip: "192.168.1.2", Port: 8080, Weight: 50.5},
			},
			expectedWeights: []uint32{100, 51},
			description:     "Large weights should be handled correctly",
		},
		{
			name: "mixed weights",
			services: []model.SubscribeService{
				{Ip: "192.168.1.1", Port: 8080, Weight: 0.0},
				{Ip: "192.168.1.2", Port: 8080, Weight: 1.5},
				{Ip: "192.168.1.3", Port: 8080, Weight: -5.0},
				{Ip: "192.168.1.4", Port: 8080, Weight: 10.7},
			},
			expectedWeights: []uint32{1, 2, 1, 11},
			description:     "Mixed zero, negative, and fractional weights",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			se := w.generateServiceEntry("test-host", tc.services)
			if se == nil {
				t.Fatal("generateServiceEntry returned nil")
			}
			if len(se.Endpoints) != len(tc.expectedWeights) {
				t.Fatalf("expected %d endpoints, got %d", len(tc.expectedWeights), len(se.Endpoints))
			}
			// Endpoint order follows the input instance order, so weights
			// can be compared positionally.
			for i, endpoint := range se.Endpoints {
				if endpoint.Weight != tc.expectedWeights[i] {
					t.Errorf("endpoint[%d]: expected weight %d, got %d (original weight: %f) - %s",
						i, tc.expectedWeights[i], endpoint.Weight, tc.services[i].Weight, tc.description)
				}
			}
		})
	}
}
// Test_generateServiceEntry_WeightFieldSet verifies that address, weight,
// labels, and ports are all populated on a generated endpoint.
func Test_generateServiceEntry_WeightFieldSet(t *testing.T) {
	w := &watcher{}
	instances := []model.SubscribeService{
		{Ip: "192.168.1.1", Port: 8080, Weight: 5.0, Metadata: map[string]string{"zone": "a"}},
	}
	entry := w.generateServiceEntry("test-host", instances)
	if entry == nil {
		t.Fatal("generateServiceEntry returned nil")
	}
	if len(entry.Endpoints) != 1 {
		t.Fatalf("expected 1 endpoint, got %d", len(entry.Endpoints))
	}
	ep := entry.Endpoints[0]
	// Verify all fields are set correctly
	if ep.Address != "192.168.1.1" {
		t.Errorf("expected address 192.168.1.1, got %s", ep.Address)
	}
	if ep.Weight != 5 {
		t.Errorf("expected weight 5, got %d", ep.Weight)
	}
	if ep.Labels == nil || ep.Labels["zone"] != "a" {
		t.Errorf("expected labels with zone=a, got %v", ep.Labels)
	}
	if ep.Ports == nil {
		t.Error("expected ports to be set")
	}
}
// Test_generateServiceEntry_EmptyServices verifies that an empty instance
// list yields a service entry with no endpoints.
func Test_generateServiceEntry_EmptyServices(t *testing.T) {
	w := &watcher{}
	entry := w.generateServiceEntry("test-host", []model.SubscribeService{})
	if entry == nil {
		t.Fatal("generateServiceEntry returned nil")
	}
	if got := len(entry.Endpoints); got != 0 {
		t.Errorf("expected 0 endpoints for empty services, got %d", got)
	}
}