mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 20:57:32 +08:00
fix: Optimization of Rate Limiting Logic for Cluster, AI Token and WASM Plugin (#2997)
This commit is contained in:
@@ -45,26 +45,55 @@ func init() {
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
RedisKeyPrefix string = "higress-token-ratelimit"
|
RedisKeyPrefix string = "higress-token-ratelimit"
|
||||||
// AiTokenGlobalRateLimitFormat 全局限流模式 redis key 为 RedisKeyPrefix:限流规则名称:global_threshold:时间窗口:窗口内限流数
|
// AiTokenGlobalRateLimitFormat 全局限流模式 redis key 为 RedisKeyPrefix:限流规则名称:global_threshold:时间窗口
|
||||||
AiTokenGlobalRateLimitFormat = RedisKeyPrefix + ":%s:global_threshold:%d:%d"
|
AiTokenGlobalRateLimitFormat = RedisKeyPrefix + ":%s:global_threshold:%d"
|
||||||
// AiTokenRateLimitFormat 规则限流模式 redis key 为 RedisKeyPrefix:限流规则名称:限流类型:时间窗口:窗口内限流数:限流key名称:限流key对应的实际值
|
// AiTokenRateLimitFormat 规则限流模式 redis key 为 RedisKeyPrefix:限流规则名称:限流类型:时间窗口:限流key名称:限流key对应的实际值
|
||||||
AiTokenRateLimitFormat = RedisKeyPrefix + ":%s:%s:%d:%d:%s:%s"
|
AiTokenRateLimitFormat = RedisKeyPrefix + ":%s:%s:%d:%s:%s"
|
||||||
RequestPhaseFixedWindowScript = `
|
RequestPhaseFixedWindowScript = `
|
||||||
local ttl = redis.call('ttl', KEYS[1])
|
local current = redis.call('get', KEYS[1])
|
||||||
if ttl < 0 then
|
local ttl = redis.call('ttl', KEYS[1])
|
||||||
redis.call('set', KEYS[1], ARGV[1], 'EX', ARGV[2])
|
local threshold = tonumber(ARGV[1])
|
||||||
return {ARGV[1], ARGV[1], ARGV[2]}
|
local window = tonumber(ARGV[2])
|
||||||
end
|
|
||||||
return {ARGV[1], redis.call('get', KEYS[1]), ttl}
|
-- 键不存在时,返回初始状态(计数0,窗口时间为过期时间)
|
||||||
|
if not current then
|
||||||
|
return {threshold, 0, window}
|
||||||
|
end
|
||||||
|
|
||||||
|
-- 修复异常过期时间(确保窗口有效)
|
||||||
|
if ttl < 0 then
|
||||||
|
ttl = window
|
||||||
|
end
|
||||||
|
|
||||||
|
-- 返回窗口状态:阈值、当前计数、剩余时间
|
||||||
|
return {threshold, tonumber(current), ttl}
|
||||||
`
|
`
|
||||||
ResponsePhaseFixedWindowScript = `
|
ResponsePhaseFixedWindowScript = `
|
||||||
local ttl = redis.call('ttl', KEYS[1])
|
local key = KEYS[1]
|
||||||
if ttl < 0 then
|
local threshold = tonumber(ARGV[1])
|
||||||
redis.call('set', KEYS[1], ARGV[1]-ARGV[3], 'EX', ARGV[2])
|
local window = tonumber(ARGV[2])
|
||||||
return {ARGV[1], ARGV[1]-ARGV[3], ARGV[2]}
|
local added = tonumber(ARGV[3]) -- 需要累加的token数量
|
||||||
end
|
|
||||||
return {ARGV[1], redis.call('decrby', KEYS[1], ARGV[3]), ttl}
|
local current = tonumber(redis.call('get', key) or "0")
|
||||||
`
|
|
||||||
|
-- 只有当前计数未超过阈值时才执行累加
|
||||||
|
if current <= threshold then
|
||||||
|
current = redis.call('incrby', key, added)
|
||||||
|
-- 第一次设置值时初始化过期时间
|
||||||
|
if current == added then
|
||||||
|
redis.call('expire', key, window)
|
||||||
|
else
|
||||||
|
-- 非首次设置时检查过期时间,确保窗口有效性
|
||||||
|
local ttl = redis.call('ttl', key)
|
||||||
|
if ttl < 0 then
|
||||||
|
redis.call('expire', key, window)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
-- 返回当前窗口状态:阈值、当前计数、剩余时间
|
||||||
|
return {threshold, current, redis.call('ttl', key)}
|
||||||
|
`
|
||||||
|
|
||||||
LimitRedisContextKey = "LimitRedisContext"
|
LimitRedisContextKey = "LimitRedisContext"
|
||||||
|
|
||||||
@@ -107,7 +136,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.AiTokenRateLimitCo
|
|||||||
|
|
||||||
if cfg.GlobalThreshold != nil {
|
if cfg.GlobalThreshold != nil {
|
||||||
// 全局限流模式
|
// 全局限流模式
|
||||||
limitKey = fmt.Sprintf(AiTokenGlobalRateLimitFormat, cfg.RuleName, cfg.GlobalThreshold.TimeWindow, cfg.GlobalThreshold.Count)
|
limitKey = fmt.Sprintf(AiTokenGlobalRateLimitFormat, cfg.RuleName, cfg.GlobalThreshold.TimeWindow)
|
||||||
count = cfg.GlobalThreshold.Count
|
count = cfg.GlobalThreshold.Count
|
||||||
timeWindow = cfg.GlobalThreshold.TimeWindow
|
timeWindow = cfg.GlobalThreshold.TimeWindow
|
||||||
} else {
|
} else {
|
||||||
@@ -118,7 +147,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.AiTokenRateLimitCo
|
|||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
|
|
||||||
limitKey = fmt.Sprintf(AiTokenRateLimitFormat, cfg.RuleName, ruleItem.LimitType, configItem.TimeWindow, configItem.Count, ruleItem.Key, val)
|
limitKey = fmt.Sprintf(AiTokenRateLimitFormat, cfg.RuleName, ruleItem.LimitType, configItem.TimeWindow, ruleItem.Key, val)
|
||||||
count = configItem.Count
|
count = configItem.Count
|
||||||
timeWindow = configItem.TimeWindow
|
timeWindow = configItem.TimeWindow
|
||||||
}
|
}
|
||||||
@@ -139,12 +168,15 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.AiTokenRateLimitCo
|
|||||||
proxywasm.ResumeHttpRequest()
|
proxywasm.ResumeHttpRequest()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 获取限流结果
|
||||||
|
threshold, current, ttl := resultArray[0].Integer(), resultArray[1].Integer(), resultArray[2].Integer()
|
||||||
context := LimitContext{
|
context := LimitContext{
|
||||||
count: resultArray[0].Integer(),
|
count: threshold,
|
||||||
remaining: resultArray[1].Integer(),
|
remaining: threshold - current,
|
||||||
reset: resultArray[2].Integer(),
|
reset: ttl,
|
||||||
}
|
}
|
||||||
if context.remaining < 0 {
|
if current > threshold {
|
||||||
// 触发限流
|
// 触发限流
|
||||||
ctx.SetUserAttribute("token_ratelimit_status", "limited")
|
ctx.SetUserAttribute("token_ratelimit_status", "limited")
|
||||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
|
|||||||
@@ -291,8 +291,8 @@ func TestOnHttpRequestHeaders(t *testing.T) {
|
|||||||
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
||||||
|
|
||||||
// 模拟 Redis 调用响应(允许请求)
|
// 模拟 Redis 调用响应(允许请求)
|
||||||
// 返回 [count, remaining, ttl] 格式
|
// 返回 [threshold, current, ttl] 格式
|
||||||
resp := test.CreateRedisRespArray([]interface{}{1000, 999, 60})
|
resp := test.CreateRedisRespArray([]interface{}{1000, 1, 60})
|
||||||
host.CallOnRedisCall(0, resp)
|
host.CallOnRedisCall(0, resp)
|
||||||
|
|
||||||
host.CompleteHttp()
|
host.CompleteHttp()
|
||||||
@@ -316,7 +316,7 @@ func TestOnHttpRequestHeaders(t *testing.T) {
|
|||||||
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
||||||
|
|
||||||
// 模拟 Redis 调用响应(允许请求)
|
// 模拟 Redis 调用响应(允许请求)
|
||||||
resp := test.CreateRedisRespArray([]interface{}{100, 99, 60})
|
resp := test.CreateRedisRespArray([]interface{}{100, 1, 60})
|
||||||
host.CallOnRedisCall(0, resp)
|
host.CallOnRedisCall(0, resp)
|
||||||
|
|
||||||
host.CompleteHttp()
|
host.CompleteHttp()
|
||||||
@@ -339,7 +339,7 @@ func TestOnHttpRequestHeaders(t *testing.T) {
|
|||||||
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
||||||
|
|
||||||
// 模拟 Redis 调用响应(允许请求)
|
// 模拟 Redis 调用响应(允许请求)
|
||||||
resp := test.CreateRedisRespArray([]interface{}{50, 49, 60})
|
resp := test.CreateRedisRespArray([]interface{}{50, 1, 60})
|
||||||
host.CallOnRedisCall(0, resp)
|
host.CallOnRedisCall(0, resp)
|
||||||
|
|
||||||
host.CompleteHttp()
|
host.CompleteHttp()
|
||||||
@@ -363,7 +363,7 @@ func TestOnHttpRequestHeaders(t *testing.T) {
|
|||||||
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
||||||
|
|
||||||
// 模拟 Redis 调用响应(允许请求)
|
// 模拟 Redis 调用响应(允许请求)
|
||||||
resp := test.CreateRedisRespArray([]interface{}{200, 199, 60})
|
resp := test.CreateRedisRespArray([]interface{}{200, 1, 60})
|
||||||
host.CallOnRedisCall(0, resp)
|
host.CallOnRedisCall(0, resp)
|
||||||
|
|
||||||
host.CompleteHttp()
|
host.CompleteHttp()
|
||||||
@@ -387,7 +387,7 @@ func TestOnHttpRequestHeaders(t *testing.T) {
|
|||||||
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
||||||
|
|
||||||
// 模拟 Redis 调用响应(允许请求)
|
// 模拟 Redis 调用响应(允许请求)
|
||||||
resp := test.CreateRedisRespArray([]interface{}{75, 74, 60})
|
resp := test.CreateRedisRespArray([]interface{}{75, 1, 60})
|
||||||
host.CallOnRedisCall(0, resp)
|
host.CallOnRedisCall(0, resp)
|
||||||
|
|
||||||
host.CompleteHttp()
|
host.CompleteHttp()
|
||||||
@@ -410,8 +410,8 @@ func TestOnHttpRequestHeaders(t *testing.T) {
|
|||||||
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
||||||
|
|
||||||
// 模拟 Redis 调用响应(触发限流)
|
// 模拟 Redis 调用响应(触发限流)
|
||||||
// 返回 [count, remaining, ttl] 格式,remaining < 0 表示触发限流
|
// 返回 [threshold, current, ttl] 格式,current > threshold 表示触发限流
|
||||||
resp := test.CreateRedisRespArray([]interface{}{1000, -1, 60})
|
resp := test.CreateRedisRespArray([]interface{}{1000, 1001, 60})
|
||||||
host.CallOnRedisCall(0, resp)
|
host.CallOnRedisCall(0, resp)
|
||||||
|
|
||||||
// 检查是否发送了限流响应
|
// 检查是否发送了限流响应
|
||||||
@@ -459,7 +459,7 @@ func TestOnHttpStreamingBody(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// 模拟 Redis 调用响应
|
// 模拟 Redis 调用响应
|
||||||
resp := test.CreateRedisRespArray([]interface{}{1000, 999, 60})
|
resp := test.CreateRedisRespArray([]interface{}{1000, 1, 60})
|
||||||
host.CallOnRedisCall(0, resp)
|
host.CallOnRedisCall(0, resp)
|
||||||
|
|
||||||
// 处理流式响应体
|
// 处理流式响应体
|
||||||
@@ -499,7 +499,7 @@ func TestOnHttpStreamingBody(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// 模拟 Redis 调用响应
|
// 模拟 Redis 调用响应
|
||||||
resp := test.CreateRedisRespArray([]interface{}{1000, 999, 60})
|
resp := test.CreateRedisRespArray([]interface{}{1000, 1, 60})
|
||||||
host.CallOnRedisCall(0, resp)
|
host.CallOnRedisCall(0, resp)
|
||||||
|
|
||||||
// 处理流式响应体
|
// 处理流式响应体
|
||||||
@@ -537,7 +537,7 @@ func TestCompleteFlow(t *testing.T) {
|
|||||||
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
||||||
|
|
||||||
// 2. 模拟 Redis 调用响应
|
// 2. 模拟 Redis 调用响应
|
||||||
resp := test.CreateRedisRespArray([]interface{}{100, 99, 60})
|
resp := test.CreateRedisRespArray([]interface{}{100, 1, 60})
|
||||||
host.CallOnRedisCall(0, resp)
|
host.CallOnRedisCall(0, resp)
|
||||||
|
|
||||||
// 3. 处理流式响应体
|
// 3. 处理流式响应体
|
||||||
|
|||||||
@@ -46,17 +46,30 @@ func init() {
|
|||||||
const (
|
const (
|
||||||
// RedisKeyPrefix 集群限流插件在 Redis 中 key 的统一前缀
|
// RedisKeyPrefix 集群限流插件在 Redis 中 key 的统一前缀
|
||||||
RedisKeyPrefix = "higress-cluster-key-rate-limit"
|
RedisKeyPrefix = "higress-cluster-key-rate-limit"
|
||||||
// ClusterGlobalRateLimitFormat 全局限流模式 redis key 为 RedisKeyPrefix:限流规则名称:global_threshold:时间窗口:窗口内限流数
|
// ClusterGlobalRateLimitFormat 全局限流模式 redis key 为 RedisKeyPrefix:限流规则名称:global_threshold:时间窗口
|
||||||
ClusterGlobalRateLimitFormat = RedisKeyPrefix + ":%s:global_threshold:%d:%d"
|
ClusterGlobalRateLimitFormat = RedisKeyPrefix + ":%s:global_threshold:%d"
|
||||||
// ClusterRateLimitFormat 规则限流模式 redis key 为 RedisKeyPrefix:限流规则名称:限流类型:时间窗口:窗口内限流数:限流key名称:限流key对应的实际值
|
// ClusterRateLimitFormat 规则限流模式 redis key 为 RedisKeyPrefix:限流规则名称:限流类型:时间窗口:限流key名称:限流key对应的实际值
|
||||||
ClusterRateLimitFormat = RedisKeyPrefix + ":%s:%s:%d:%d:%s:%s"
|
ClusterRateLimitFormat = RedisKeyPrefix + ":%s:%s:%d:%s:%s"
|
||||||
FixedWindowScript = `
|
FixedWindowScript = `
|
||||||
local ttl = redis.call('ttl', KEYS[1])
|
local key = KEYS[1]
|
||||||
if ttl < 0 then
|
local threshold = tonumber(ARGV[1])
|
||||||
redis.call('set', KEYS[1], ARGV[1] - 1, 'EX', ARGV[2])
|
local window = tonumber(ARGV[2])
|
||||||
return {ARGV[1], ARGV[1] - 1, ARGV[2]}
|
|
||||||
end
|
local current = tonumber(redis.call('get', key) or "0")
|
||||||
return {ARGV[1], redis.call('incrby', KEYS[1], -1), ttl}
|
|
||||||
|
-- 只有超过阈值时才停止累加,达到阈值时仍允许(此时是最后一次允许)
|
||||||
|
if current > threshold then
|
||||||
|
return {threshold, current, redis.call('ttl', key)}
|
||||||
|
end
|
||||||
|
|
||||||
|
-- 计数未超过阈值,执行累加
|
||||||
|
current = redis.call('incr', key)
|
||||||
|
-- 第一次累加时设置过期时间
|
||||||
|
if current == 1 then
|
||||||
|
redis.call('expire', key, window)
|
||||||
|
end
|
||||||
|
|
||||||
|
return {threshold, current, redis.call('ttl', key)}
|
||||||
`
|
`
|
||||||
|
|
||||||
LimitContextKey = "LimitContext" // 限流上下文信息
|
LimitContextKey = "LimitContext" // 限流上下文信息
|
||||||
@@ -92,7 +105,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.ClusterKeyRateLimi
|
|||||||
|
|
||||||
if cfg.GlobalThreshold != nil {
|
if cfg.GlobalThreshold != nil {
|
||||||
// 全局限流模式
|
// 全局限流模式
|
||||||
limitKey = fmt.Sprintf(ClusterGlobalRateLimitFormat, cfg.RuleName, cfg.GlobalThreshold.TimeWindow, cfg.GlobalThreshold.Count)
|
limitKey = fmt.Sprintf(ClusterGlobalRateLimitFormat, cfg.RuleName, cfg.GlobalThreshold.TimeWindow)
|
||||||
count = cfg.GlobalThreshold.Count
|
count = cfg.GlobalThreshold.Count
|
||||||
timeWindow = cfg.GlobalThreshold.TimeWindow
|
timeWindow = cfg.GlobalThreshold.TimeWindow
|
||||||
} else {
|
} else {
|
||||||
@@ -103,7 +116,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.ClusterKeyRateLimi
|
|||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
|
|
||||||
limitKey = fmt.Sprintf(ClusterRateLimitFormat, cfg.RuleName, ruleItem.LimitType, configItem.TimeWindow, configItem.Count, ruleItem.Key, val)
|
limitKey = fmt.Sprintf(ClusterRateLimitFormat, cfg.RuleName, ruleItem.LimitType, configItem.TimeWindow, ruleItem.Key, val)
|
||||||
count = configItem.Count
|
count = configItem.Count
|
||||||
timeWindow = configItem.TimeWindow
|
timeWindow = configItem.TimeWindow
|
||||||
}
|
}
|
||||||
@@ -118,12 +131,15 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.ClusterKeyRateLimi
|
|||||||
proxywasm.ResumeHttpRequest()
|
proxywasm.ResumeHttpRequest()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 获取限流结果
|
||||||
|
threshold, current, ttl := resultArray[0].Integer(), resultArray[1].Integer(), resultArray[2].Integer()
|
||||||
context := LimitContext{
|
context := LimitContext{
|
||||||
count: resultArray[0].Integer(),
|
count: threshold,
|
||||||
remaining: resultArray[1].Integer(),
|
remaining: threshold - current,
|
||||||
reset: resultArray[2].Integer(),
|
reset: ttl,
|
||||||
}
|
}
|
||||||
if context.remaining < 0 {
|
if current > threshold {
|
||||||
// 触发限流
|
// 触发限流
|
||||||
rejected(cfg, context)
|
rejected(cfg, context)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -15,10 +15,11 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"cluster-key-rate-limit/config"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"cluster-key-rate-limit/config"
|
||||||
|
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
"github.com/higress-group/wasm-go/pkg/test"
|
"github.com/higress-group/wasm-go/pkg/test"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -527,9 +528,16 @@ func TestOnHttpRequestHeaders(t *testing.T) {
|
|||||||
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
||||||
|
|
||||||
// 模拟 Redis 调用响应(触发限流)
|
// 模拟 Redis 调用响应(触发限流)
|
||||||
resp := test.CreateRedisRespArray([]interface{}{1000, -1, 60})
|
// 当前请求数(1001)超过阈值(1000),触发限流
|
||||||
|
resp := test.CreateRedisRespArray([]interface{}{1000, 1001, 60})
|
||||||
host.CallOnRedisCall(0, resp)
|
host.CallOnRedisCall(0, resp)
|
||||||
|
|
||||||
|
// 检查是否发送了限流响应
|
||||||
|
localResponse := host.GetLocalResponse()
|
||||||
|
require.NotNil(t, localResponse)
|
||||||
|
require.Equal(t, uint32(429), localResponse.StatusCode)
|
||||||
|
require.Contains(t, string(localResponse.Data), "Too many requests")
|
||||||
|
|
||||||
host.CompleteHttp()
|
host.CompleteHttp()
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -641,7 +649,7 @@ func TestCompleteFlow(t *testing.T) {
|
|||||||
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
|
||||||
|
|
||||||
// 2. 模拟 Redis 调用响应
|
// 2. 模拟 Redis 调用响应
|
||||||
resp := test.CreateRedisRespArray([]interface{}{1000, 999, 60})
|
resp := test.CreateRedisRespArray([]interface{}{1000, 1, 60})
|
||||||
host.CallOnRedisCall(0, resp)
|
host.CallOnRedisCall(0, resp)
|
||||||
|
|
||||||
// 3. 处理响应头
|
// 3. 处理响应头
|
||||||
|
|||||||
Reference in New Issue
Block a user