fix: Optimization of Rate Limiting Logic for Cluster, AI Token and WASM Plugin (#2997)

This commit is contained in:
韩贤涛
2025-10-15 17:24:42 +08:00
committed by GitHub
parent b026455701
commit 1f301be851
4 changed files with 109 additions and 53 deletions

View File

@@ -45,26 +45,55 @@ func init() {
const (
RedisKeyPrefix string = "higress-token-ratelimit"
// AiTokenGlobalRateLimitFormat 全局限流模式 redis key 为 RedisKeyPrefix:限流规则名称:global_threshold:时间窗口:窗口内限流数
AiTokenGlobalRateLimitFormat = RedisKeyPrefix + ":%s:global_threshold:%d:%d"
// AiTokenRateLimitFormat 规则限流模式 redis key 为 RedisKeyPrefix:限流规则名称:限流类型:时间窗口:窗口内限流数:限流key名称:限流key对应的实际值
AiTokenRateLimitFormat = RedisKeyPrefix + ":%s:%s:%d:%d:%s:%s"
// AiTokenGlobalRateLimitFormat 全局限流模式 redis key 为 RedisKeyPrefix:限流规则名称:global_threshold:时间窗口
AiTokenGlobalRateLimitFormat = RedisKeyPrefix + ":%s:global_threshold:%d"
// AiTokenRateLimitFormat 规则限流模式 redis key 为 RedisKeyPrefix:限流规则名称:限流类型:时间窗口:限流key名称:限流key对应的实际值
AiTokenRateLimitFormat = RedisKeyPrefix + ":%s:%s:%d:%s:%s"
RequestPhaseFixedWindowScript = `
-- Request-phase fixed-window script.
-- KEYS[1] is the counter key; ARGV[1] is the value the key is seeded with and ARGV[2] is its expiry.
local ttl = redis.call('ttl', KEYS[1])
-- A negative TTL (key missing, or present without an expiry) starts a new window:
-- the key is written with ARGV[1] and expires after ARGV[2].
if ttl < 0 then
redis.call('set', KEYS[1], ARGV[1], 'EX', ARGV[2])
return {ARGV[1], ARGV[1], ARGV[2]}
end
-- Window already open: reply with ARGV[1], the stored value and the remaining TTL.
return {ARGV[1], redis.call('get', KEYS[1]), ttl}
-- Request-phase probe for the fixed-window limiter. It only reads the key and never writes to it.
-- KEYS[1] = counter key, ARGV[1] = threshold, ARGV[2] = window length.
-- Reply: {threshold, current count, remaining time}.
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local span = tonumber(ARGV[2])
local used = redis.call('get', key)
local remaining = redis.call('ttl', key)
if used then
    -- A negative TTL is reported as a full window; the stored expiry itself is not modified here.
    if remaining < 0 then
        remaining = span
    end
    return {limit, tonumber(used), remaining}
end
-- No counter stored yet: report an untouched window (count 0, full window length).
return {limit, 0, span}
`
ResponsePhaseFixedWindowScript = `
-- Response-phase fixed-window script.
-- KEYS[1] is the counter key; ARGV[1] is the seed value, ARGV[2] the expiry and ARGV[3] the amount to subtract.
local ttl = redis.call('ttl', KEYS[1])
-- A negative TTL (key missing, or present without an expiry) starts a new window:
-- the key is written with ARGV[1]-ARGV[3] and expires after ARGV[2].
if ttl < 0 then
redis.call('set', KEYS[1], ARGV[1]-ARGV[3], 'EX', ARGV[2])
return {ARGV[1], ARGV[1]-ARGV[3], ARGV[2]}
end
-- Window already open: subtract ARGV[3] from the stored value and reply with the result and the remaining TTL.
return {ARGV[1], redis.call('decrby', KEYS[1], ARGV[3]), ttl}
`
-- Response-phase accounting for the fixed-window limiter.
-- KEYS[1] = counter key, ARGV[1] = threshold, ARGV[2] = window length used as the key expiry,
-- ARGV[3] = number of tokens to add to the counter.
-- Reply: {threshold, current count, remaining TTL}.
local key = KEYS[1]
local threshold = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local added = tonumber(ARGV[3])
-- GET yields false for a missing key, so a fresh window starts counting from 0.
local current = tonumber(redis.call('get', key) or "0")
-- Accumulate only while the counter has not yet exceeded the threshold.
if current <= threshold then
    current = redis.call('incrby', key, added)
end
-- TTL == -1 means the key exists without an expiry. That is the case right after INCRBY
-- created it, and also for a key that lost its expiry. Arm the window in both cases, even
-- when the counter is already over the threshold, so an over-limit key cannot live forever.
-- A key that already has an expiry is never re-armed, so the fixed window is not stretched.
local ttl = redis.call('ttl', key)
if ttl == -1 then
    redis.call('expire', key, window)
    ttl = redis.call('ttl', key)
end
return {threshold, current, ttl}
`
LimitRedisContextKey = "LimitRedisContext"
@@ -107,7 +136,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.AiTokenRateLimitCo
if cfg.GlobalThreshold != nil {
// 全局限流模式
limitKey = fmt.Sprintf(AiTokenGlobalRateLimitFormat, cfg.RuleName, cfg.GlobalThreshold.TimeWindow, cfg.GlobalThreshold.Count)
limitKey = fmt.Sprintf(AiTokenGlobalRateLimitFormat, cfg.RuleName, cfg.GlobalThreshold.TimeWindow)
count = cfg.GlobalThreshold.Count
timeWindow = cfg.GlobalThreshold.TimeWindow
} else {
@@ -118,7 +147,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.AiTokenRateLimitCo
return types.ActionContinue
}
limitKey = fmt.Sprintf(AiTokenRateLimitFormat, cfg.RuleName, ruleItem.LimitType, configItem.TimeWindow, configItem.Count, ruleItem.Key, val)
limitKey = fmt.Sprintf(AiTokenRateLimitFormat, cfg.RuleName, ruleItem.LimitType, configItem.TimeWindow, ruleItem.Key, val)
count = configItem.Count
timeWindow = configItem.TimeWindow
}
@@ -139,12 +168,15 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.AiTokenRateLimitCo
proxywasm.ResumeHttpRequest()
return
}
// 获取限流结果
threshold, current, ttl := resultArray[0].Integer(), resultArray[1].Integer(), resultArray[2].Integer()
context := LimitContext{
count: resultArray[0].Integer(),
remaining: resultArray[1].Integer(),
reset: resultArray[2].Integer(),
count: threshold,
remaining: threshold - current,
reset: ttl,
}
if context.remaining < 0 {
if current > threshold {
// 触发限流
ctx.SetUserAttribute("token_ratelimit_status", "limited")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)

View File

@@ -291,8 +291,8 @@ func TestOnHttpRequestHeaders(t *testing.T) {
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
// 模拟 Redis 调用响应(允许请求)
// 返回 [count, remaining, ttl] 格式
resp := test.CreateRedisRespArray([]interface{}{1000, 999, 60})
// 返回 [threshold, current, ttl] 格式
resp := test.CreateRedisRespArray([]interface{}{1000, 1, 60})
host.CallOnRedisCall(0, resp)
host.CompleteHttp()
@@ -316,7 +316,7 @@ func TestOnHttpRequestHeaders(t *testing.T) {
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
// 模拟 Redis 调用响应(允许请求)
resp := test.CreateRedisRespArray([]interface{}{100, 99, 60})
resp := test.CreateRedisRespArray([]interface{}{100, 1, 60})
host.CallOnRedisCall(0, resp)
host.CompleteHttp()
@@ -339,7 +339,7 @@ func TestOnHttpRequestHeaders(t *testing.T) {
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
// 模拟 Redis 调用响应(允许请求)
resp := test.CreateRedisRespArray([]interface{}{50, 49, 60})
resp := test.CreateRedisRespArray([]interface{}{50, 1, 60})
host.CallOnRedisCall(0, resp)
host.CompleteHttp()
@@ -363,7 +363,7 @@ func TestOnHttpRequestHeaders(t *testing.T) {
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
// 模拟 Redis 调用响应(允许请求)
resp := test.CreateRedisRespArray([]interface{}{200, 199, 60})
resp := test.CreateRedisRespArray([]interface{}{200, 1, 60})
host.CallOnRedisCall(0, resp)
host.CompleteHttp()
@@ -387,7 +387,7 @@ func TestOnHttpRequestHeaders(t *testing.T) {
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
// 模拟 Redis 调用响应(允许请求)
resp := test.CreateRedisRespArray([]interface{}{75, 74, 60})
resp := test.CreateRedisRespArray([]interface{}{75, 1, 60})
host.CallOnRedisCall(0, resp)
host.CompleteHttp()
@@ -410,8 +410,8 @@ func TestOnHttpRequestHeaders(t *testing.T) {
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
// 模拟 Redis 调用响应(触发限流)
// 返回 [count, remaining, ttl] 格式remaining < 0 表示触发限流
resp := test.CreateRedisRespArray([]interface{}{1000, -1, 60})
// 返回 [threshold, current, ttl] 格式current > threshold 表示触发限流
resp := test.CreateRedisRespArray([]interface{}{1000, 1001, 60})
host.CallOnRedisCall(0, resp)
// 检查是否发送了限流响应
@@ -459,7 +459,7 @@ func TestOnHttpStreamingBody(t *testing.T) {
})
// 模拟 Redis 调用响应
resp := test.CreateRedisRespArray([]interface{}{1000, 999, 60})
resp := test.CreateRedisRespArray([]interface{}{1000, 1, 60})
host.CallOnRedisCall(0, resp)
// 处理流式响应体
@@ -499,7 +499,7 @@ func TestOnHttpStreamingBody(t *testing.T) {
})
// 模拟 Redis 调用响应
resp := test.CreateRedisRespArray([]interface{}{1000, 999, 60})
resp := test.CreateRedisRespArray([]interface{}{1000, 1, 60})
host.CallOnRedisCall(0, resp)
// 处理流式响应体
@@ -537,7 +537,7 @@ func TestCompleteFlow(t *testing.T) {
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
// 2. 模拟 Redis 调用响应
resp := test.CreateRedisRespArray([]interface{}{100, 99, 60})
resp := test.CreateRedisRespArray([]interface{}{100, 1, 60})
host.CallOnRedisCall(0, resp)
// 3. 处理流式响应体