Files
higress/plugins/wasm-go/extensions/cluster-key-rate-limit/main_test.go

675 lines
21 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Copyright (c) 2024 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"encoding/json"
"testing"
"cluster-key-rate-limit/config"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// globalThresholdConfig is a test configuration that uses a global_threshold
// of 1000 queries per minute instead of per-key rule items. It enables quota
// headers, sets rejected_code 429 / rejected_msg "Too many requests", and
// points at redis.static:6379 with a timeout of 1000.
var globalThresholdConfig = func() json.RawMessage {
	settings := map[string]interface{}{
		"rule_name": "routeA-global-limit-rule",
		"global_threshold": map[string]interface{}{
			"query_per_minute": 1000,
		},
		"redis": map[string]interface{}{
			"service_name": "redis.static",
			"service_port": 6379,
			"timeout":      1000,
		},
		"show_limit_quota_header": true,
		"rejected_code":           429,
		"rejected_msg":            "Too many requests",
	}
	// The map holds only strings, ints, bools and nested maps, which
	// json.Marshal always encodes successfully, so the error is dropped.
	encoded, _ := json.Marshal(settings)
	return encoded
}()
// paramLimitConfig is a test configuration that limits by the "apikey" request
// parameter. It defines two keys: one allowed 10 queries per minute and one
// allowed 100 queries per hour. Quota headers are enabled.
var paramLimitConfig = func() json.RawMessage {
	apikeyLimits := []map[string]interface{}{
		{
			"key":              "9a342114-ba8a-11ec-b1bf-00163e1250b5",
			"query_per_minute": 10,
		},
		{
			"key":            "a6a6d7f2-ba8a-11ec-bec2-00163e1250b5",
			"query_per_hour": 100,
		},
	}
	settings := map[string]interface{}{
		"rule_name": "routeA-request-param-limit-rule",
		"rule_items": []map[string]interface{}{
			{
				"limit_by_param": "apikey",
				"limit_keys":     apikeyLimits,
			},
		},
		"redis": map[string]interface{}{
			"service_name": "redis.static",
			"service_port": 6379,
		},
		"show_limit_quota_header": true,
	}
	// Only JSON-safe values are present, so json.Marshal cannot fail here.
	encoded, _ := json.Marshal(settings)
	return encoded
}()
// headerLimitConfig is a test configuration that limits by the "x-ca-key"
// request header. Key "102234" is allowed 10 queries per minute and key
// "308239" 10 queries per hour. Quota headers are enabled.
var headerLimitConfig = func() json.RawMessage {
	headerLimits := []map[string]interface{}{
		{
			"key":              "102234",
			"query_per_minute": 10,
		},
		{
			"key":            "308239",
			"query_per_hour": 10,
		},
	}
	settings := map[string]interface{}{
		"rule_name": "routeA-request-header-limit-rule",
		"rule_items": []map[string]interface{}{
			{
				"limit_by_header": "x-ca-key",
				"limit_keys":      headerLimits,
			},
		},
		"redis": map[string]interface{}{
			"service_name": "redis.static",
			"service_port": 6379,
		},
		"show_limit_quota_header": true,
	}
	// Only JSON-safe values are present, so json.Marshal cannot fail here.
	encoded, _ := json.Marshal(settings)
	return encoded
}()
// consumerLimitConfig is a test configuration that limits by consumer
// (limit_by_consumer is set to an empty string). "consumer1" is allowed 10
// queries per second and "consumer2" 100 queries per hour. Quota headers are
// enabled.
var consumerLimitConfig = func() json.RawMessage {
	consumerLimits := []map[string]interface{}{
		{
			"key":              "consumer1",
			"query_per_second": 10,
		},
		{
			"key":            "consumer2",
			"query_per_hour": 100,
		},
	}
	settings := map[string]interface{}{
		"rule_name": "routeA-consumer-limit-rule",
		"rule_items": []map[string]interface{}{
			{
				"limit_by_consumer": "",
				"limit_keys":        consumerLimits,
			},
		},
		"redis": map[string]interface{}{
			"service_name": "redis.static",
			"service_port": 6379,
		},
		"show_limit_quota_header": true,
	}
	// Only JSON-safe values are present, so json.Marshal cannot fail here.
	encoded, _ := json.Marshal(settings)
	return encoded
}()
// cookieLimitConfig is a test configuration that limits by the "key1" cookie.
// Value "value1" is allowed 10 queries per minute and "value2" 100 queries per
// hour. It overrides the rejection to rejected_code 200 with a JSON body in
// rejected_msg. Quota headers are enabled.
var cookieLimitConfig = func() json.RawMessage {
	cookieLimits := []map[string]interface{}{
		{
			"key":              "value1",
			"query_per_minute": 10,
		},
		{
			"key":            "value2",
			"query_per_hour": 100,
		},
	}
	settings := map[string]interface{}{
		"rule_name": "routeA-cookie-limit-rule",
		"rule_items": []map[string]interface{}{
			{
				"limit_by_cookie": "key1",
				"limit_keys":      cookieLimits,
			},
		},
		"redis": map[string]interface{}{
			"service_name": "redis.static",
			"service_port": 6379,
		},
		"show_limit_quota_header": true,
		"rejected_code":           200,
		"rejected_msg":            `{"code":-1,"msg":"Too many requests"}`,
	}
	// Only JSON-safe values are present, so json.Marshal cannot fail here.
	encoded, _ := json.Marshal(settings)
	return encoded
}()
// ipLimitConfig is a test configuration that limits per client IP, with the
// source configured as "from-header-x-forwarded-for". Three keys (a single
// address, a /24 CIDR and 0.0.0.0/0) get daily limits of 10, 100 and 1000
// queries. Quota headers are enabled.
var ipLimitConfig = func() json.RawMessage {
	ipLimits := []map[string]interface{}{
		{
			"key":           "1.1.1.1",
			"query_per_day": 10,
		},
		{
			"key":           "1.1.1.0/24",
			"query_per_day": 100,
		},
		{
			"key":           "0.0.0.0/0",
			"query_per_day": 1000,
		},
	}
	settings := map[string]interface{}{
		"rule_name": "routeA-client-ip-limit-rule",
		"rule_items": []map[string]interface{}{
			{
				"limit_by_per_ip": "from-header-x-forwarded-for",
				"limit_keys":      ipLimits,
			},
		},
		"redis": map[string]interface{}{
			"service_name": "redis.static",
			"service_port": 6379,
		},
		"show_limit_quota_header": true,
	}
	// Only JSON-safe values are present, so json.Marshal cannot fail here.
	encoded, _ := json.Marshal(settings)
	return encoded
}()
// regexpLimitConfig is a test configuration that limits per value of the
// "apikey" parameter, using "regexp:"-prefixed keys plus a "*" catch-all key:
// 10 queries per second, 100 per minute and 1000 per hour respectively. Quota
// headers are enabled.
var regexpLimitConfig = func() json.RawMessage {
	patternLimits := []map[string]interface{}{
		{
			"key":              "regexp:^a.*",
			"query_per_second": 10,
		},
		{
			"key":              "regexp:^b.*",
			"query_per_minute": 100,
		},
		{
			"key":            "*",
			"query_per_hour": 1000,
		},
	}
	settings := map[string]interface{}{
		"rule_name": "routeA-regexp-limit-rule",
		"rule_items": []map[string]interface{}{
			{
				"limit_by_per_param": "apikey",
				"limit_keys":         patternLimits,
			},
		},
		"redis": map[string]interface{}{
			"service_name": "redis.static",
			"service_port": 6379,
		},
		"show_limit_quota_header": true,
	}
	// Only JSON-safe values are present, so json.Marshal cannot fail here.
	encoded, _ := json.Marshal(settings)
	return encoded
}()
// TestParseConfig checks that every sample configuration starts the plugin
// successfully and decodes into the expected config.ClusterKeyRateLimitConfig
// fields. Each case is a named subtest whose verify function receives the
// decoded match config.
func TestParseConfig(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		cases := []struct {
			name   string
			raw    json.RawMessage
			verify func(t *testing.T, c *config.ClusterKeyRateLimitConfig)
		}{
			{
				name: "global threshold config",
				raw:  globalThresholdConfig,
				verify: func(t *testing.T, c *config.ClusterKeyRateLimitConfig) {
					require.Equal(t, "routeA-global-limit-rule", c.RuleName)
					require.NotNil(t, c.GlobalThreshold)
					// query_per_minute: 1000 is expected to become Count=1000 over a
					// 60-second TimeWindow.
					require.Equal(t, int64(1000), c.GlobalThreshold.Count)
					require.Equal(t, int64(60), c.GlobalThreshold.TimeWindow)
					require.True(t, c.ShowLimitQuotaHeader)
					require.Equal(t, uint32(429), c.RejectedCode)
					require.Equal(t, "Too many requests", c.RejectedMsg)
				},
			},
			{
				name: "param limit config",
				raw:  paramLimitConfig,
				verify: func(t *testing.T, c *config.ClusterKeyRateLimitConfig) {
					require.Equal(t, "routeA-request-param-limit-rule", c.RuleName)
					require.Len(t, c.RuleItems, 1)
					require.Equal(t, config.LimitByParamType, c.RuleItems[0].LimitType)
					require.Equal(t, "apikey", c.RuleItems[0].Key)
					require.Len(t, c.RuleItems[0].ConfigItems, 2)
					require.True(t, c.ShowLimitQuotaHeader)
				},
			},
			{
				name: "header limit config",
				raw:  headerLimitConfig,
				verify: func(t *testing.T, c *config.ClusterKeyRateLimitConfig) {
					require.Equal(t, "routeA-request-header-limit-rule", c.RuleName)
					require.Len(t, c.RuleItems, 1)
					require.Equal(t, config.LimitByHeaderType, c.RuleItems[0].LimitType)
					require.Equal(t, "x-ca-key", c.RuleItems[0].Key)
					require.Len(t, c.RuleItems[0].ConfigItems, 2)
					require.True(t, c.ShowLimitQuotaHeader)
				},
			},
			{
				name: "consumer limit config",
				raw:  consumerLimitConfig,
				verify: func(t *testing.T, c *config.ClusterKeyRateLimitConfig) {
					require.Equal(t, "routeA-consumer-limit-rule", c.RuleName)
					require.Len(t, c.RuleItems, 1)
					require.Equal(t, config.LimitByConsumerType, c.RuleItems[0].LimitType)
					require.Len(t, c.RuleItems[0].ConfigItems, 2)
					require.True(t, c.ShowLimitQuotaHeader)
				},
			},
			{
				name: "cookie limit config",
				raw:  cookieLimitConfig,
				verify: func(t *testing.T, c *config.ClusterKeyRateLimitConfig) {
					require.Equal(t, "routeA-cookie-limit-rule", c.RuleName)
					require.Len(t, c.RuleItems, 1)
					require.Equal(t, config.LimitByCookieType, c.RuleItems[0].LimitType)
					require.Equal(t, "key1", c.RuleItems[0].Key)
					require.Len(t, c.RuleItems[0].ConfigItems, 2)
					require.True(t, c.ShowLimitQuotaHeader)
				},
			},
			{
				name: "ip limit config",
				raw:  ipLimitConfig,
				verify: func(t *testing.T, c *config.ClusterKeyRateLimitConfig) {
					require.Equal(t, "routeA-client-ip-limit-rule", c.RuleName)
					require.Len(t, c.RuleItems, 1)
					require.Equal(t, config.LimitByPerIpType, c.RuleItems[0].LimitType)
					require.NotNil(t, c.RuleItems[0].LimitByPerIp)
					// "from-header-x-forwarded-for" should select the header source.
					require.Equal(t, config.HeaderSourceType, c.RuleItems[0].LimitByPerIp.SourceType)
					require.True(t, c.ShowLimitQuotaHeader)
				},
			},
			{
				name: "regexp limit config",
				raw:  regexpLimitConfig,
				verify: func(t *testing.T, c *config.ClusterKeyRateLimitConfig) {
					require.Equal(t, "routeA-regexp-limit-rule", c.RuleName)
					require.Len(t, c.RuleItems, 1)
					require.Equal(t, config.LimitByPerParamType, c.RuleItems[0].LimitType)
					require.Equal(t, "apikey", c.RuleItems[0].Key)
					require.Len(t, c.RuleItems[0].ConfigItems, 3)
					// The first key is "regexp:^a.*", so it must parse as a regexp item.
					require.Equal(t, config.RegexpType, c.RuleItems[0].ConfigItems[0].ConfigType)
					require.True(t, c.ShowLimitQuotaHeader)
				},
			},
		}

		for _, tc := range cases {
			t.Run(tc.name, func(t *testing.T) {
				host, status := test.NewTestHost(tc.raw)
				defer host.Reset()
				require.Equal(t, types.OnPluginStartStatusOK, status)

				matched, err := host.GetMatchConfig()
				require.NoError(t, err)
				require.NotNil(t, matched)

				tc.verify(t, matched.(*config.ClusterKeyRateLimitConfig))
			})
		}
	})
}
// TestOnHttpRequestHeaders exercises the request-header phase for each sample
// configuration.
//
// In every case the plugin must pause the request with
// HeaderStopAllIterationAndWatermark while its Redis call is pending. The test
// then replays a mocked Redis reply.
//
// The mocked reply mirrors the plugin's fixed-window script:
// [limit, remaining, ttl-seconds]. The plugin admits the request while
// remaining >= 0 and rejects it once remaining is negative. Keep these mocks in
// sync with the script in main.go.
func TestOnHttpRequestHeaders(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Cases whose mocked Redis reply leaves quota remaining, so the request
		// is admitted.
		admitted := []struct {
			name       string
			config     json.RawMessage
			headers    [][2]string
			redisReply []interface{}
		}{
			{
				name:   "global threshold request headers",
				config: globalThresholdConfig,
				headers: [][2]string{
					{":authority", "example.com"},
					{":path", "/api/test"},
					{":method", "GET"},
				},
				redisReply: []interface{}{1000, 999, 60},
			},
			{
				// The limit key is carried in the query string.
				name:   "param limit request headers",
				config: paramLimitConfig,
				headers: [][2]string{
					{":authority", "example.com"},
					{":path", "/api/test?apikey=9a342114-ba8a-11ec-b1bf-00163e1250b5"},
					{":method", "GET"},
				},
				redisReply: []interface{}{10, 9, 60},
			},
			{
				// The limit key is carried in the x-ca-key header.
				name:   "header limit request headers",
				config: headerLimitConfig,
				headers: [][2]string{
					{":authority", "example.com"},
					{":path", "/api/test"},
					{":method", "GET"},
					{"x-ca-key", "102234"},
				},
				redisReply: []interface{}{10, 9, 60},
			},
			{
				// The consumer identity is carried in the x-mse-consumer header.
				name:   "consumer limit request headers",
				config: consumerLimitConfig,
				headers: [][2]string{
					{":authority", "example.com"},
					{":path", "/api/test"},
					{":method", "GET"},
					{"x-mse-consumer", "consumer1"},
				},
				redisReply: []interface{}{10, 9, 1},
			},
			{
				// The limit key is carried in the key1 cookie.
				name:   "cookie limit request headers",
				config: cookieLimitConfig,
				headers: [][2]string{
					{":authority", "example.com"},
					{":path", "/api/test"},
					{":method", "GET"},
					{"cookie", "key1=value1; other=value"},
				},
				redisReply: []interface{}{10, 9, 60},
			},
			{
				// The client IP is taken from x-forwarded-for.
				name:   "ip limit request headers",
				config: ipLimitConfig,
				headers: [][2]string{
					{":authority", "example.com"},
					{":path", "/api/test"},
					{":method", "GET"},
					{"x-forwarded-for", "1.1.1.1"},
				},
				redisReply: []interface{}{10, 9, 86400},
			},
		}

		for _, tc := range admitted {
			t.Run(tc.name, func(t *testing.T) {
				host, status := test.NewTestHost(tc.config)
				defer host.Reset()
				require.Equal(t, types.OnPluginStartStatusOK, status)

				// The limit decision needs a Redis round trip, so headers are paused.
				action := host.CallOnHttpRequestHeaders(tc.headers)
				require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)

				host.CallOnRedisCall(0, test.CreateRedisRespArray(tc.redisReply))
				host.CompleteHttp()
			})
		}

		// When the budget is exhausted the plugin must answer locally with the
		// configured rejection.
		t.Run("rate limit exceeded", func(t *testing.T) {
			host, status := test.NewTestHost(globalThresholdConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/api/test"},
				{":method", "GET"},
			})
			require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)

			// The second element is the *remaining* quota, not the current count.
			// A value such as 1001 would still count as quota left and the request
			// would be admitted. -1 models the first request past the 1000/minute
			// budget, which triggers rejection.
			resp := test.CreateRedisRespArray([]interface{}{1000, -1, 60})
			host.CallOnRedisCall(0, resp)

			localResponse := host.GetLocalResponse()
			require.NotNil(t, localResponse)
			require.Equal(t, uint32(429), localResponse.StatusCode)
			require.Contains(t, string(localResponse.Data), "Too many requests")
			host.CompleteHttp()
		})
	})
}
// TestOnHttpResponseHeaders verifies that the quota headers
// (x-ratelimit-limit / x-ratelimit-remaining) are added to the response only
// when show_limit_quota_header is enabled.
func TestOnHttpResponseHeaders(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// runAdmittedRequest starts a plugin host from pluginConfig and sends one
		// GET request through it. It answers the Redis call with the reply
		// {1000, 999, 60} and passes a 200 JSON response through, requiring
		// ActionContinue. It then hands the resulting response headers to inspect
		// before completing the HTTP stream.
		runAdmittedRequest := func(t *testing.T, pluginConfig json.RawMessage, inspect func(headers [][2]string)) {
			host, status := test.NewTestHost(pluginConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/api/test"},
				{":method", "GET"},
			})
			host.CallOnRedisCall(0, test.CreateRedisRespArray([]interface{}{1000, 999, 60}))

			action := host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			})
			require.Equal(t, types.ActionContinue, action)

			inspect(host.GetResponseHeaders())
			host.CompleteHttp()
		}

		t.Run("show limit quota headers", func(t *testing.T) {
			runAdmittedRequest(t, globalThresholdConfig, func(headers [][2]string) {
				require.True(t, test.HasHeader(headers, "x-ratelimit-limit"))
				require.True(t, test.HasHeader(headers, "x-ratelimit-remaining"))
			})
		})

		t.Run("hide limit quota headers", func(t *testing.T) {
			// Global 1000/minute threshold with show_limit_quota_header turned off.
			hideQuotaSettings := map[string]interface{}{
				"rule_name": "routeA-global-limit-rule",
				"global_threshold": map[string]interface{}{
					"query_per_minute": 1000,
				},
				"redis": map[string]interface{}{
					"service_name": "redis.static",
					"service_port": 6379,
				},
				"show_limit_quota_header": false,
			}
			// Only JSON-safe values are present, so json.Marshal cannot fail here.
			hideQuotaConfig, _ := json.Marshal(hideQuotaSettings)

			runAdmittedRequest(t, hideQuotaConfig, func(headers [][2]string) {
				require.False(t, test.HasHeader(headers, "x-ratelimit-limit"))
				require.False(t, test.HasHeader(headers, "x-ratelimit-remaining"))
			})
		})
	})
}
// TestCompleteFlow walks one admitted request through the whole plugin
// lifecycle: the request headers, the Redis rate-limit reply, and the response
// headers that should carry the quota information.
func TestCompleteFlow(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		t.Run("complete rate limit flow", func(t *testing.T) {
			host, status := test.NewTestHost(globalThresholdConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Step 1: request headers. The plugin must wait for Redis, so the
			// request is paused.
			reqAction := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/api/test"},
				{":method", "GET"},
			})
			require.Equal(t, types.HeaderStopAllIterationAndWatermark, reqAction)

			// Step 2: Redis replies with a result that admits the request.
			redisReply := test.CreateRedisRespArray([]interface{}{1000, 1, 60})
			host.CallOnRedisCall(0, redisReply)

			// Step 3: the response headers pass through and gain the quota headers.
			respAction := host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			})
			require.Equal(t, types.ActionContinue, respAction)

			headers := host.GetResponseHeaders()
			for _, name := range []string{"x-ratelimit-limit", "x-ratelimit-remaining"} {
				require.True(t, test.HasHeader(headers, name))
			}

			host.CompleteHttp()
		})
	})
}