From 6a1557f6ac5d26c2a5c65f805d8aef17dfb1fd39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9F=A9=E8=B4=A4=E6=B6=9B?= <601803023@qq.com> Date: Mon, 28 Jul 2025 08:14:35 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20ai-token-ratelimit=20support=20setting?= =?UTF-8?q?=20global=20rate=20limit=20thresholds=20for=20routes=E2=80=8B?= =?UTF-8?q?=20(#2667)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../extensions/ai-token-ratelimit/README.md | 53 ++- .../ai-token-ratelimit/README_EN.md | 125 ++++-- .../extensions/ai-token-ratelimit/config.go | 307 -------------- .../ai-token-ratelimit/config/config.go | 392 ++++++++++++++++++ .../ai-token-ratelimit/config/config_test.go | 218 ++++++++++ .../extensions/ai-token-ratelimit/go.mod | 4 + .../extensions/ai-token-ratelimit/go.sum | 2 + .../extensions/ai-token-ratelimit/main.go | 203 ++++----- .../ai-token-ratelimit/util/utils.go | 87 ++++ .../extensions/ai-token-ratelimit/utils.go | 60 --- .../cluster-key-rate-limit/config/config.go | 37 +- .../config/config_test.go | 25 ++ .../extensions/cluster-key-rate-limit/go.sum | 2 - .../extensions/cluster-key-rate-limit/main.go | 30 +- 14 files changed, 981 insertions(+), 564 deletions(-) delete mode 100644 plugins/wasm-go/extensions/ai-token-ratelimit/config.go create mode 100644 plugins/wasm-go/extensions/ai-token-ratelimit/config/config.go create mode 100644 plugins/wasm-go/extensions/ai-token-ratelimit/config/config_test.go create mode 100644 plugins/wasm-go/extensions/ai-token-ratelimit/util/utils.go delete mode 100644 plugins/wasm-go/extensions/ai-token-ratelimit/utils.go diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/README.md b/plugins/wasm-go/extensions/ai-token-ratelimit/README.md index c3c0c2334..c8e724fad 100644 --- a/plugins/wasm-go/extensions/ai-token-ratelimit/README.md +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/README.md @@ -4,10 +4,12 @@ keywords: [ AI网关, AI token限流 ] description: AI Token限流插件配置参考 --- - ## 功能说明 -`ai-token-ratelimit`插件实现了基于特定键值实现token限流,键值来源可以是 URL 参数、HTTP 请求头、客户端 IP 地址、consumer 名称、cookie中 key 名称 +`ai-token-ratelimit`插件基于 Redis 实现了 AI Token 限流功能,支持以下两种限流模式: + +- **规则级全局限流**:依据相同的`rule_name`与`global_threshold`配置,为自定义规则组设置全局 token 限流阈值 +- **Key 级动态限流**:根据请求中的动态 Key(包括 URL 参数、请求头、客户端 IP、Consumer 名称或 Cookie 字段等)进行分组 token 限流 ## 运行属性 @@ -19,12 +21,22 @@ description: AI Token限流插件配置参考 | 配置项 | 类型 | 必填 | 默认值 | 说明 | | ----------------------- | ------ | ---- | ------ |---------------------------------------------------------------------------| | rule_name | string | 是 | - | 限流规则名称,根据限流规则名称+限流类型+限流key名称+限流key对应的实际值来拼装redis key | -| rule_items | array of object | 是 | - | 限流规则项,按照rule_items下的排列顺序,匹配第一个rule_item后命中限流规则,后续规则将被忽略 | +| global_threshold | Object | 否,`global_threshold` 或 `rule_items` 选填一项 | - | 对整个自定义规则组进行限流 | +| rule_items | array of object | 否,`global_threshold` 或 `rule_items` 选填一项 | - | 限流规则项,按照rule_items下的排列顺序,匹配第一个rule_item后命中限流规则,后续规则将被忽略 | | rejected_code | int | 否 | 429 | 请求被限流时,返回的HTTP状态码 | | rejected_msg | string | 否 | Too many requests | 请求被限流时,返回的响应体 | -| redis | object | 是 | - | redis相关配置 | +| redis | object | 是 | - | redis相关配置 | -`rule_items`中每一项的配置字段说明 +`global_threshold` 中每一项的配置字段说明。 + +| 配置项 | 类型 | 必填 | 默认值 | 说明 | +| ---------------- | ---- | ------------------------------------------------------------ | ------ | --------------------- | +| token_per_second | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每秒请求token数 | +| token_per_minute | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每分钟请求token数 | +| token_per_hour | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每小时请求token数 | +| token_per_day | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每天请求token数 | + +`rule_items`中每一项的配置字段说明。 | 配置项 | 类型 | 必填 | 默认值 | 说明 | | --------------------- | --------------- | -------------------------- | ------ | ------------------------------------------------------------ | @@ -39,7 +51,7 @@ description: AI Token限流插件配置参考 | limit_by_per_ip | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 IP,并对每个 IP 分别计算限流,配置获取限流键值的来源 IP 参数名称,从请求头获取,以`from-header-对应的header名`,示例:`from-header-x-forwarded-for`,直接获取对端socket ip,配置为`from-remote-addr` | | limit_keys | array of object | 是 | - | 配置匹配键值后的限流次数 | -`limit_keys`中每一项的配置字段说明 +`limit_keys`中每一项的配置字段说明。 | 配置项 | 类型 | 必填 | 默认值 | 说明 | | ---------------- | ------ | ------------------------------------------------------------ | ------ | ------------------------------------------------------------ | @@ -49,7 +61,7 @@ description: AI Token限流插件配置参考 | token_per_hour | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每小时请求token数 | | token_per_day | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每天请求token数 | -`redis`中每一项的配置字段说明 +`redis`中每一项的配置字段说明。 | 配置项 | 类型 | 必填 | 默认值 | 说明 | | ------------ | ------ | ---- | ---------------------------------------------------------- | --------------------------- | @@ -63,6 +75,17 @@ description: AI Token限流插件配置参考 ## 配置示例 +### 自定义规则组全局限流 + +```yaml +rule_name: routeA-global-limit-rule +global_threshold: + token_per_minute: 1000 # 自定义规则组每分钟1000个token +redis: + service_name: redis.static +show_limit_quota_header: true +``` + ### 识别请求参数 apikey,进行区别限流 ```yaml @@ -89,8 +112,6 @@ redis: service_name: redis.static ``` - - ### 识别请求头 x-ca-key,进行区别限流 ```yaml @@ -98,7 +119,7 @@ rule_name: default_rule rule_items: - limit_by_header: x-ca-key limit_keys: - - key: 102234 + - key: 102234 token_per_minute: 10 - key: 308239 token_per_hour: 10 @@ -112,13 +133,11 @@ rule_items: token_per_minute: 100 # 兜底用,匹配所有请求,每个apikey对应的请求1000qdh - key: "*" - token_per_hour: 1000 + token_per_hour: 1000 redis: service_name: redis.static ``` - - ### 根据请求头 x-forwarded-for 获取对端IP,进行区别限流 ```yaml @@ -160,13 +179,11 @@ rule_items: token_per_minute: 100 # 兜底用,匹配所有请求,每个consumer对应的请求1000qdh - key: "*" - token_per_hour: 1000 + token_per_hour: 1000 redis: service_name: redis.static ``` - - ### 识别cookie中的键值对,进行区别限流 ```yaml @@ -188,7 +205,7 @@ rule_items: token_per_minute: 100 # 兜底用,匹配所有请求,每个cookie中的value对应的请求1000qdh - key: "*" - token_per_hour: 1000 + token_per_hour: 1000 rejected_code: 200 rejected_msg: '{"code":-1,"msg":"Too many requests"}' redis: @@ -198,6 +215,7 @@ redis: ## 完整示例 AI Token 限流插件依赖 Redis 记录剩余可用的 token 数,因此首先需要部署 Redis 服务。 + ```yaml apiVersion: apps/v1 kind: Deployment @@ -286,6 +304,7 @@ spec: phase: UNSPECIFIED_PHASE priority: 600 ``` + 注意,AI Token 限流插件中的 Redis 配置项 `service_name` 来自 McpBridge 中配置的服务来源,另外我们还需要在 McpBridge 中配置通义千问服务的访问地址。 ```yaml diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/README_EN.md b/plugins/wasm-go/extensions/ai-token-ratelimit/README_EN.md index cf502198e..b6ca17308 100644 --- a/plugins/wasm-go/extensions/ai-token-ratelimit/README_EN.md +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/README_EN.md @@ -3,56 +3,95 @@ title: AI Token Rate Limiting keywords: [ AI Gateway, AI Token Rate Limiting ] description: AI Token Rate Limiting Plugin Configuration Reference --- -## Function Description -The `ai-token-ratelimit` plugin implements token rate limiting based on specific key values. The key values can come from URL parameters, HTTP request headers, client IP addresses, consumer names, or key names in cookies. -## Runtime Attributes -Plugin execution phase: `default phase` +## Function Description + +The `ai-token-ratelimit` plugin implements AI Token rate limiting based on Redis, supporting the following two rate limiting modes: + +- **Rule-level Global Rate Limiting**: Sets a global token rate limit threshold for custom rule groups based on the same `rule_name` and `global_threshold` configurations. +- **Key-level Dynamic Rate Limiting**: Performs grouped token rate limiting based on dynamic keys in requests (including URL parameters, request headers, client IP, Consumer name, or Cookie fields, etc.). + + +## Runtime Properties + +Plugin execution phase: `Default Phase` Plugin execution priority: `600` + ## Configuration Description -| Configuration Item | Type | Required | Default Value | Description | -| ----------------------- | ----------------- | -------- | ------------- | ----------------------------------------------------------------------------- | -| rule_name | string | Yes | - | Name of the rate limiting rule, used to assemble the redis key based on the rule name + rate limiting type + rate limiting key name + actual value corresponding to the rate limiting key | -| rule_items | array of object | Yes | - | Rate limiting rule items. After matching the first rule_item, subsequent rules will be ignored based on the order in `rule_items` | -| rejected_code | int | No | 429 | The HTTP status code returned when the request is rate limited | -| rejected_msg | string | No | Too many requests | The response body returned when the request is rate limited | -| redis | object | Yes | - | Redis related configuration | -Field descriptions for each item in `rule_items` -| Configuration Item | Type | Required | Default Value | Description | -| ------------------------ | ----------------- | --------------------------- | ------------- | ------------------------------------------------------------ | -| limit_by_header | string | No, optionally select one in `limit_by_*` | - | Configure the source HTTP header name for obtaining the rate limiting key value | -| limit_by_param | string | No, optionally select one in `limit_by_*` | - | Configure the source URL parameter name for obtaining the rate limiting key value | -| limit_by_consumer | string | No, optionally select one in `limit_by_*` | - | Rate limit by consumer name, no actual value needs to be added | -| limit_by_cookie | string | No, optionally select one in `limit_by_*` | - | Configure the source key name in cookies for obtaining the rate limiting key value | -| limit_by_per_header | string | No, optionally select one in `limit_by_*` | - | Match specific HTTP request headers according to rules and calculate rate limiting separately for each header. Configure the source HTTP header name for obtaining the rate limiting key value. Supports regular expressions or `*` when configuring `limit_keys` | -| limit_by_per_param | string | No, optionally select one in `limit_by_*` | - | Match specific URL parameters according to rules and calculate rate limiting separately for each parameter. Configure the source URL parameter name for obtaining the rate limiting key value. Supports regular expressions or `*` when configuring `limit_keys` | -| limit_by_per_consumer | string | No, optionally select one in `limit_by_*` | - | Match specific consumers according to rules and calculate rate limiting separately for each consumer. Rate limit by consumer name, no actual value needs to be added. Supports regular expressions or `*` when configuring `limit_keys` | -| limit_by_per_cookie | string | No, optionally select one in `limit_by_*` | - | Match specific cookies according to rules and calculate rate limiting separately for each cookie. Configure the source key name in cookies for obtaining the rate limiting key value. Supports regular expressions or `*` when configuring `limit_keys` | -| limit_by_per_ip | string | No, optionally select one in `limit_by_*` | - | Match specific IPs according to rules and calculate rate limiting separately for each IP. Configure the source IP parameter name for obtaining the rate limiting key value from request headers, `from-header-
`, such as `from-header-x-forwarded-for`. Directly get the remote socket IP by configuring `from-remote-addr` | -| limit_keys | array of object | Yes | - | Configure the number of rate limit requests after matching keys | +| Configuration Item | Type | Required | Default Value | Description | +|--------------------------|----------------|----------|---------------|-------------------------------------------------------------------------------------------------| +| rule_name | string | Yes | - | Name of the rate limiting rule. The Redis key is assembled based on the rate limiting rule name + rate limiting type + rate limiting key name + actual value corresponding to the rate limiting key. | +| global_threshold | Object | No, either `global_threshold` or `rule_items` is required | - | Rate limits the entire custom rule group | +| rule_items | array of object| No, either `global_threshold` or `rule_items` is required | - | Rate limiting rule items. The first matching `rule_item` in the order of `rule_items` triggers the rate limiting rule, and subsequent rules are ignored. | +| rejected_code | int | No | 429 | HTTP status code returned when a request is rate-limited | +| rejected_msg | string | No | Too many requests | Response body returned when a request is rate-limited | +| redis | object | Yes | - | Redis-related configurations | -Field descriptions for each item in `limit_keys` -| Configuration Item | Type | Required | Default Value | Description | -| ----------------------- | ----------------- | ------------------------------------------- | ------------- | ----------------------------------------------- | -| key | string | Yes | - | Matched key value. Types `limit_by_per_header`, `limit_by_per_param`, `limit_by_per_consumer`, `limit_by_per_cookie` support configuring regular expressions (beginning with regexp: followed by the regex) or `*` (representing all). Example regex: `regexp:^d.*` (all strings starting with d); `limit_by_per_ip` supports configuring IP addresses or IP segments | -| token_per_second | int | No, optionally select one in `token_per_second`, `token_per_minute`, `token_per_hour`, `token_per_day` | - | Allowed number of token requests per second | -| token_per_minute | int | No, optionally select one in `token_per_second`, `token_per_minute`, `token_per_hour`, `token_per_day` | - | Allowed number of token requests per minute | -| token_per_hour | int | No, optionally select one in `token_per_second`, `token_per_minute`, `token_per_hour`, `token_per_day` | - | Allowed number of token requests per hour | -| token_per_day | int | No, optionally select one in `token_per_second`, `token_per_minute`, `token_per_hour`, `token_per_day` | - | Allowed number of token requests per day | -Field descriptions for each item in `redis` -| Configuration Item | Type | Required | Default Value | Description | -| ----------------------- | ----------------- | -------- | --------------------------------------------------------------- | ----------------------------------------------- | -| service_name | string | Required | - | Full FQDN name of the redis service, including service type, e.g., my-redis.dns, redis.my-ns.svc.cluster.local | -| service_port | int | No | Default value for static addresses (static service) is 80; otherwise, it is 6379 | Input the service port of the redis service | -| username | string | No | - | Redis username | -| password | string | No | - | Redis password | -| timeout | int | No | 1000 | Redis connection timeout in milliseconds | -| database | int | No | 0 | The database ID used, for example, configured as 1, corresponds to `SELECT 1`. | +### Description of Configuration Fields in `global_threshold` + +| Configuration Item | Type | Required | Default Value | Description | +|-----------------------|------|----------|---------------|-----------------------------------------------| +| token_per_second | int | No, one of `token_per_second`, `token_per_minute`, `token_per_hour`, `token_per_day` is required | - | Allowed number of request tokens per second | +| token_per_minute | int | No, one of `token_per_second`, `token_per_minute`, `token_per_hour`, `token_per_day` is required | - | Allowed number of request tokens per minute | +| token_per_hour | int | No, one of `token_per_second`, `token_per_minute`, `token_per_hour`, `token_per_day` is required | - | Allowed number of request tokens per hour | +| token_per_day | int | No, one of `token_per_second`, `token_per_minute`, `token_per_hour`, `token_per_day` is required | - | Allowed number of request tokens per day | + + +### Description of Configuration Fields in `rule_items` + +| Configuration Item | Type | Required | Default Value | Description | +|-----------------------------|-----------------|----------|---------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| limit_by_header | string | No, one of `limit_by_*` is required | - | Configures the source of the rate limiting key value as the HTTP request header name | +| limit_by_param | string | No, one of `limit_by_*` is required | - | Configures the source of the rate limiting key value as the URL parameter name | +| limit_by_consumer | string | No, one of `limit_by_*` is required | - | Performs rate limiting based on the consumer name; no actual value needs to be added | +| limit_by_cookie | string | No, one of `limit_by_*` is required | - | Configures the source of the rate limiting key value as the key name in the Cookie | +| limit_by_per_header | string | No, one of `limit_by_*` is required | - | Matches specific HTTP request headers by rule and calculates rate limits for each header separately. Configures the source of the rate limiting key value as the HTTP request header name. Regular expressions or `*` are supported when configuring `limit_keys`. | +| limit_by_per_param | string | No, one of `limit_by_*` is required | - | Matches specific URL parameters by rule and calculates rate limits for each parameter separately. Configures the source of the rate limiting key value as the URL parameter name. Regular expressions or `*` are supported when configuring `limit_keys`. | +| limit_by_per_consumer | string | No, one of `limit_by_*` is required | - | Matches specific consumers by rule and calculates rate limits for each consumer separately. Performs rate limiting based on the consumer name; no actual value needs to be added. Regular expressions or `*` are supported when configuring `limit_keys`. | +| limit_by_per_cookie | string | No, one of `limit_by_*` is required | - | Matches specific Cookies by rule and calculates rate limits for each Cookie separately. Configures the source of the rate limiting key value as the key name in the Cookie. Regular expressions or `*` are supported when configuring `limit_keys`. | +| limit_by_per_ip | string | No, one of `limit_by_*` is required | - | Matches specific IPs by rule and calculates rate limits for each IP separately. Configures the source of the rate limiting key value as the IP parameter name, obtained from the request header in the format `from-header-corresponding_header_name` (e.g., `from-header-x-forwarded-for`), or directly obtains the peer socket IP by configuring `from-remote-addr`. | +| limit_keys | array of object | Yes | - | Configures the rate limiting count after matching the key value | + + +### Description of Configuration Fields in `limit_keys` + +| Configuration Item | Type | Required | Default Value | Description | +|-----------------------|--------|----------|---------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| key | string | Yes | - | The matched key value. For types `limit_by_per_header`, `limit_by_per_param`, `limit_by_per_consumer`, and `limit_by_per_cookie`, regular expressions (starting with `regexp:` followed by the regular expression, e.g., `regexp:^d.*` for all strings starting with "d") or `*` (representing all) are supported. For `limit_by_per_ip`, IP addresses or IP segments are supported. | +| token_per_second | int | No, one of `token_per_second`, `token_per_minute`, `token_per_hour`, `token_per_day` is required | - | Allowed number of request tokens per second | +| token_per_minute | int | No, one of `token_per_second`, `token_per_minute`, `token_per_hour`, `token_per_day` is required | - | Allowed number of request tokens per minute | +| token_per_hour | int | No, one of `token_per_second`, `token_per_minute`, `token_per_hour`, `token_per_day` is required | - | Allowed number of request tokens per hour | +| token_per_day | int | No, one of `token_per_second`, `token_per_minute`, `token_per_hour`, `token_per_day` is required | - | Allowed number of request tokens per day | + + +### Description of Configuration Fields in `redis` + +| Configuration Item | Type | Required | Default Value | Description | +|--------------------|--------|----------|---------------|-------------------------------------------------------------------------------------------------| +| service_name | string | Yes | - | Redis service name, a complete FQDN with service type, e.g., my-redis.dns, redis.my-ns.svc.cluster.local | +| service_port | int | No | 80 for static services, 6379 for others | Enter the service port of the Redis service | +| username | string | No | - | Redis username | +| password | string | No | - | Redis password | +| timeout | int | No | 1000 | Redis connection timeout in milliseconds | +| database | int | No | 0 | The database ID to use, e.g., configuring 1 corresponds to `SELECT 1` | + + +## Configuration Example + +### Custom Rule Group Global Rate Limiting + +```yaml +rule_name: routeA-global-limit-rule +global_threshold: + token_per_minute: 1000 # 1000 tokens per minute for the custom rule group +redis: + service_name: redis.static +show_limit_quota_header: true +``` -## Configuration Examples ### Identify request parameter apikey for differentiated rate limiting ```yaml rule_name: default_rule @@ -83,7 +122,7 @@ rule_name: default_rule rule_items: - limit_by_header: x-ca-key limit_keys: - - key: 102234 + - key: 102234 token_per_minute: 10 - key: 308239 token_per_hour: 10 diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/config.go b/plugins/wasm-go/extensions/ai-token-ratelimit/config.go deleted file mode 100644 index 8f87d952c..000000000 --- a/plugins/wasm-go/extensions/ai-token-ratelimit/config.go +++ /dev/null @@ -1,307 +0,0 @@ -package main - -import ( - "errors" - "fmt" - re "regexp" - "strings" - - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" - "github.com/higress-group/wasm-go/pkg/log" - "github.com/higress-group/wasm-go/pkg/wrapper" - "github.com/tidwall/gjson" - "github.com/zmap/go-iptree/iptree" -) - -// 限流规则项类型 -type limitRuleItemType string - -// 限流配置项key类型 -type limitConfigItemType string - -const ( - limitByHeaderType limitRuleItemType = "limit_by_header" - limitByParamType limitRuleItemType = "limit_by_param" - limitByConsumerType limitRuleItemType = "limit_by_consumer" - limitByCookieType limitRuleItemType = "limit_by_cookie" - limitByPerHeaderType limitRuleItemType = "limit_by_per_header" - limitByPerParamType limitRuleItemType = "limit_by_per_param" - limitByPerConsumerType limitRuleItemType = "limit_by_per_consumer" - limitByPerCookieType limitRuleItemType = "limit_by_per_cookie" - limitByPerIpType limitRuleItemType = "limit_by_per_ip" - - exactType limitConfigItemType = "exact" // 精确匹配 - regexpType limitConfigItemType = "regexp" // 正则表达式 - allType limitConfigItemType = "*" // 匹配所有情况 - ipNetType limitConfigItemType = "ipNet" // ip段 - - RemoteAddrSourceType = "remote-addr" - HeaderSourceType = "header" - - DefaultRejectedCode uint32 = 429 - DefaultRejectedMsg string = "Too many requests" - - Second int64 = 1 - SecondsPerMinute = 60 * Second - SecondsPerHour = 60 * SecondsPerMinute - SecondsPerDay = 24 * SecondsPerHour -) - -var timeWindows = map[string]int64{ - "token_per_second": Second, - "token_per_minute": SecondsPerMinute, - "token_per_hour": SecondsPerHour, - "token_per_day": SecondsPerDay, -} - -type ClusterKeyRateLimitConfig struct { - ruleName string // 限流规则名称 - ruleItems []LimitRuleItem // 限流规则项 - showLimitQuotaHeader bool // 响应头中是否显示X-RateLimit-Limit和X-RateLimit-Remaining - rejectedCode uint32 // 当请求超过阈值被拒绝时,返回的HTTP状态码 - rejectedMsg string // 当请求超过阈值被拒绝时,返回的响应体 - redisClient wrapper.RedisClient - counterMetrics map[string]proxywasm.MetricCounter // Metrics -} - -type LimitRuleItem struct { - limitType limitRuleItemType // 限流类型 - key string // 根据该key值进行限流,limit_by_consumer和limit_by_per_consumer两种类型为ConsumerHeader,其他类型为对应的key值 - limitByPerIp LimitByPerIp // 对端ip地址或ip段 - configItems []LimitConfigItem // 限流配置项 -} - -type LimitByPerIp struct { - sourceType string // ip来源类型 - headerName string // 根据该请求头获取客户端ip -} - -type LimitConfigItem struct { - configType limitConfigItemType // 限流配置项key类型 - key string // 限流key - ipNet *iptree.IPTree // 限流key转换的ip地址或者ip段,仅用于itemType为ipNetType - regexp *re.Regexp // 正则表达式,仅用于itemType为regexpType - count int64 // 指定时间窗口内的总请求数量阈值 - timeWindow int64 // 时间窗口大小 -} - -func initRedisClusterClient(json gjson.Result, config *ClusterKeyRateLimitConfig) error { - redisConfig := json.Get("redis") - if !redisConfig.Exists() { - return errors.New("missing redis in config") - } - serviceName := redisConfig.Get("service_name").String() - if serviceName == "" { - return errors.New("redis service name must not be empty") - } - servicePort := int(redisConfig.Get("service_port").Int()) - if servicePort == 0 { - if strings.HasSuffix(serviceName, ".static") { - // use default logic port which is 80 for static service - servicePort = 80 - } else { - servicePort = 6379 - } - } - username := redisConfig.Get("username").String() - password := redisConfig.Get("password").String() - timeout := int(redisConfig.Get("timeout").Int()) - if timeout == 0 { - timeout = 1000 - } - config.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ - FQDN: serviceName, - Port: int64(servicePort), - }) - database := int(redisConfig.Get("database").Int()) - err := config.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(database)) - if config.redisClient.Ready() { - log.Info("redis init successfully") - } else { - log.Error("redis init failed, will try later") - } - return err -} - -func parseClusterKeyRateLimitConfig(json gjson.Result, config *ClusterKeyRateLimitConfig) error { - ruleName := json.Get("rule_name") - if !ruleName.Exists() { - return errors.New("missing rule_name in config") - } - config.ruleName = ruleName.String() - - // 初始化ruleItems - err := initRuleItems(json, config) - if err != nil { - return err - } - - rejectedCode := json.Get("rejected_code") - if rejectedCode.Exists() { - config.rejectedCode = uint32(rejectedCode.Uint()) - } else { - config.rejectedCode = DefaultRejectedCode - } - rejectedMsg := json.Get("rejected_msg") - if rejectedMsg.Exists() { - config.rejectedMsg = rejectedMsg.String() - } else { - config.rejectedMsg = DefaultRejectedMsg - } - return nil -} - -func initRuleItems(json gjson.Result, config *ClusterKeyRateLimitConfig) error { - ruleItemsResult := json.Get("rule_items") - if !ruleItemsResult.Exists() { - return errors.New("missing rule_items in config") - } - if len(ruleItemsResult.Array()) == 0 { - return errors.New("config rule_items cannot be empty") - } - var ruleItems []LimitRuleItem - for _, item := range ruleItemsResult.Array() { - var ruleItem LimitRuleItem - - // 根据配置区分限流类型 - var limitType limitRuleItemType - setLimitByKeyIfExists := func(field gjson.Result, limitTypeStr limitRuleItemType) { - if field.Exists() && field.String() != "" { - ruleItem.key = field.String() - limitType = limitTypeStr - } - } - setLimitByKeyIfExists(item.Get("limit_by_header"), limitByHeaderType) - setLimitByKeyIfExists(item.Get("limit_by_param"), limitByParamType) - setLimitByKeyIfExists(item.Get("limit_by_cookie"), limitByCookieType) - setLimitByKeyIfExists(item.Get("limit_by_per_header"), limitByPerHeaderType) - setLimitByKeyIfExists(item.Get("limit_by_per_param"), limitByPerParamType) - setLimitByKeyIfExists(item.Get("limit_by_per_cookie"), limitByPerCookieType) - - limitByConsumer := item.Get("limit_by_consumer") - if limitByConsumer.Exists() { - ruleItem.key = ConsumerHeader - limitType = limitByConsumerType - } - limitByPerConsumer := item.Get("limit_by_per_consumer") - if limitByPerConsumer.Exists() { - ruleItem.key = ConsumerHeader - limitType = limitByPerConsumerType - } - - limitByPerIpResult := item.Get("limit_by_per_ip") - if limitByPerIpResult.Exists() && limitByPerIpResult.String() != "" { - limitByPerIp := limitByPerIpResult.String() - ruleItem.key = limitByPerIp - if strings.HasPrefix(limitByPerIp, "from-header-") { - headerName := limitByPerIp[len("from-header-"):] - if headerName == "" { - return errors.New("limit_by_per_ip parse error: empty after 'from-header-'") - } - ruleItem.limitByPerIp = LimitByPerIp{ - sourceType: HeaderSourceType, - headerName: headerName, - } - } else if limitByPerIp == "from-remote-addr" { - ruleItem.limitByPerIp = LimitByPerIp{ - sourceType: RemoteAddrSourceType, - headerName: "", - } - } else { - return errors.New("the 'limit_by_per_ip' restriction must start with 'from-header-' or be exactly 'from-remote-addr'") - } - limitType = limitByPerIpType - } - - if limitType == "" { - return errors.New("only one of 'limit_by_header' and 'limit_by_param' and 'limit_by_consumer' and 'limit_by_cookie' and 'limit_by_per_header' and 'limit_by_per_param' and 'limit_by_per_consumer' and 'limit_by_per_cookie' and 'limit_by_per_ip' can be set") - } - ruleItem.limitType = limitType - - // 初始化configItems - err := initConfigItems(item, &ruleItem) - if err != nil { - return err - } - - ruleItems = append(ruleItems, ruleItem) - } - config.ruleItems = ruleItems - return nil -} - -func initConfigItems(json gjson.Result, rule *LimitRuleItem) error { - limitKeys := json.Get("limit_keys") - if !limitKeys.Exists() { - return errors.New("missing limit_keys in config") - } - if len(limitKeys.Array()) == 0 { - return errors.New("config limit_keys cannot be empty") - } - var configItems []LimitConfigItem - for _, item := range limitKeys.Array() { - key := item.Get("key") - if !key.Exists() || key.String() == "" { - return errors.New("limit_keys key is required") - } - - var ( - itemKey = key.String() - itemType limitConfigItemType - ipNet *iptree.IPTree - regexp *re.Regexp - ) - if rule.limitType == limitByPerIpType { - var err error - ipNet, err = parseIPNet(itemKey) - if err != nil { - return fmt.Errorf("failed to parse IPNet for key '%s': %w", itemKey, err) - } - itemType = ipNetType - } else if rule.limitType == limitByPerHeaderType || - rule.limitType == limitByPerParamType || - rule.limitType == limitByPerConsumerType || - rule.limitType == limitByPerCookieType { - if itemKey == "*" { - itemType = allType - } else if strings.HasPrefix(itemKey, "regexp:") { - regexpStr := itemKey[len("regexp:"):] - var err error - regexp, err = re.Compile(regexpStr) - if err != nil { - return fmt.Errorf("failed to compile regex for key '%s': %w", itemKey, err) - } - itemType = regexpType - } else { - return fmt.Errorf("the '%s' restriction must start with 'regexp:' or be exactly '*'", rule.limitType) - } - } else { - itemType = exactType - } - - if configItem, err := createConfigItemFromRate(item, itemType, itemKey, ipNet, regexp); err != nil { - return err - } else if configItem != nil { - configItems = append(configItems, *configItem) - } - } - rule.configItems = configItems - return nil -} - -func createConfigItemFromRate(item gjson.Result, itemType limitConfigItemType, key string, ipNet *iptree.IPTree, regexp *re.Regexp) (*LimitConfigItem, error) { - for timeWindowKey, duration := range timeWindows { - q := item.Get(timeWindowKey) - if q.Exists() && q.Int() > 0 { - return &LimitConfigItem{ - configType: itemType, - key: key, - ipNet: ipNet, - regexp: regexp, - count: q.Int(), - timeWindow: duration, - }, nil - } - } - return nil, errors.New("one of 'token_per_second', 'token_per_minute', 'token_per_hour', or 'token_per_day' must be set for key: " + key) -} diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/config/config.go b/plugins/wasm-go/extensions/ai-token-ratelimit/config/config.go new file mode 100644 index 000000000..c550f2c98 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/config/config.go @@ -0,0 +1,392 @@ +package config + +import ( + "errors" + "fmt" + re "regexp" + "strings" + + "ai-token-ratelimit/util" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" + "github.com/zmap/go-iptree/iptree" +) + +// LimitRuleItemType 限流规则项类型 +type LimitRuleItemType string + +// LimitConfigItemType 限流配置项key类型 +type LimitConfigItemType string + +const ( + LimitByHeaderType LimitRuleItemType = "limit_by_header" + LimitByParamType LimitRuleItemType = "limit_by_param" + LimitByConsumerType LimitRuleItemType = "limit_by_consumer" + LimitByCookieType LimitRuleItemType = "limit_by_cookie" + LimitByPerHeaderType LimitRuleItemType = "limit_by_per_header" + LimitByPerParamType LimitRuleItemType = "limit_by_per_param" + LimitByPerConsumerType LimitRuleItemType = "limit_by_per_consumer" + LimitByPerCookieType LimitRuleItemType = "limit_by_per_cookie" + LimitByPerIpType LimitRuleItemType = "limit_by_per_ip" + + ExactType LimitConfigItemType = "exact" // 精确匹配 + RegexpType LimitConfigItemType = "regexp" // 正则表达式 + AllType LimitConfigItemType = "*" // 匹配所有情况 + IpNetType LimitConfigItemType = "ipNet" // ip段 + + RemoteAddrSourceType = "remote-addr" + HeaderSourceType = "header" + + DefaultRejectedCode uint32 = 429 + DefaultRejectedMsg string = "Too many requests" + + Second int64 = 1 + SecondsPerMinute = 60 * Second + SecondsPerHour = 60 * SecondsPerMinute + SecondsPerDay = 24 * SecondsPerHour +) + +var timeWindows = map[string]int64{ + "token_per_second": Second, + "token_per_minute": SecondsPerMinute, + "token_per_hour": SecondsPerHour, + "token_per_day": SecondsPerDay, +} + +type AiTokenRateLimitConfig struct { + RuleName string // 限流规则名称 + GlobalThreshold *GlobalThreshold // 全局限流配置 + RuleItems []LimitRuleItem // 限流规则项 + RejectedCode uint32 // 当请求超过阈值被拒绝时,返回的HTTP状态码 + RejectedMsg string // 当请求超过阈值被拒绝时,返回的响应体 + RedisClient wrapper.RedisClient + CounterMetrics map[string]proxywasm.MetricCounter // Metrics +} + +type GlobalThreshold struct { + Count int64 // 时间窗口内的token数 + TimeWindow int64 // 时间窗口大小(秒) +} + +type LimitRuleItem struct { + LimitType LimitRuleItemType // 限流类型 + Key string // 根据该key值进行限流,limit_by_consumer和limit_by_per_consumer两种类型为ConsumerHeader,其他类型为对应的key值 + LimitByPerIp LimitByPerIp // 对端ip地址或ip段 + ConfigItems []LimitConfigItem // 限流配置项 +} + +type LimitByPerIp struct { + SourceType string // ip来源类型 + HeaderName string // 根据该请求头获取客户端ip +} + +type LimitConfigItem struct { + ConfigType LimitConfigItemType // 限流配置项key类型 + Key string // 限流key + IpNet *iptree.IPTree // 限流key转换的ip地址或者ip段,仅用于itemType为ipNetType + Regexp *re.Regexp // 正则表达式,仅用于itemType为regexpType + Count int64 // 指定时间窗口内的token数 + TimeWindow int64 // 时间窗口大小 +} + +func (cfg *AiTokenRateLimitConfig) IncrementCounter(metricName string, inc uint64) { + if inc == 0 { + return + } + counter, ok := cfg.CounterMetrics[metricName] + if !ok { + counter = proxywasm.DefineCounterMetric(metricName) + cfg.CounterMetrics[metricName] = counter + } + counter.Increment(inc) +} + +func InitRedisClusterClient(json gjson.Result, config *AiTokenRateLimitConfig) error { + redisConfig := json.Get("redis") + if !redisConfig.Exists() { + return errors.New("missing redis in config") + } + + serviceName := redisConfig.Get("service_name").String() + if serviceName == "" { + return errors.New("redis service name must not be empty") + } + + servicePort := int(redisConfig.Get("service_port").Int()) + if servicePort == 0 { + if strings.HasSuffix(serviceName, ".static") { + // use default logic port which is 80 for static service + servicePort = 80 + } else { + servicePort = 6379 + } + } + + username := redisConfig.Get("username").String() + password := redisConfig.Get("password").String() + timeout := int(redisConfig.Get("timeout").Int()) + if timeout == 0 { + timeout = 1000 + } + + config.RedisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ + FQDN: serviceName, + Port: int64(servicePort), + }) + database := int(redisConfig.Get("database").Int()) + err := config.RedisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(database)) + if config.RedisClient.Ready() { + log.Info("redis init successfully") + } else { + log.Error("redis init failed, will try later") + } + return err +} + +func ParseAiTokenRateLimitConfig(json gjson.Result, config *AiTokenRateLimitConfig) error { + ruleName := json.Get("rule_name") + if !ruleName.Exists() { + return errors.New("missing rule_name in config") + } + config.RuleName = ruleName.String() + + // 初始化限流规则 + err := initLimitRule(json, config) + if err != nil { + return err + } + + rejectedCode := json.Get("rejected_code") + if rejectedCode.Exists() { + config.RejectedCode = uint32(rejectedCode.Uint()) + } else { + config.RejectedCode = DefaultRejectedCode + } + rejectedMsg := json.Get("rejected_msg") + if rejectedMsg.Exists() { + config.RejectedMsg = rejectedMsg.String() + } else { + config.RejectedMsg = DefaultRejectedMsg + } + return nil +} + +func initLimitRule(json gjson.Result, config *AiTokenRateLimitConfig) error { + globalThresholdResult := json.Get("global_threshold") + ruleItemsResult := json.Get("rule_items") + + hasGlobal := globalThresholdResult.Exists() + hasRule := ruleItemsResult.Exists() + if !hasGlobal && !hasRule { + return errors.New("at least one of 'global_threshold' or 'rule_items' must be set") + } else if hasGlobal && hasRule { + return errors.New("'global_threshold' and 'rule_items' cannot be set at the same time") + } + + // 处理全局限流配置 + if hasGlobal { + threshold, err := parseGlobalThreshold(globalThresholdResult) + if err != nil { + return fmt.Errorf("failed to parse global_threshold: %w", err) + } + config.GlobalThreshold = threshold + return nil + } + + // 处理条件限流规则 + items := ruleItemsResult.Array() + if len(items) == 0 { + return errors.New("config rule_items cannot be empty") + } + + var ruleItems []LimitRuleItem + // 用于记录已出现的LimitType和Key的组合 + seenLimitRules := make(map[string]bool) + + for _, item := range items { + ruleItem, err := parseLimitRuleItem(item) + if err != nil { + return fmt.Errorf("failed to parse rule_item in rule_items: %w", err) + } + + // 构造LimitType和Key的唯一标识 + ruleKey := string(ruleItem.LimitType) + ":" + ruleItem.Key + + // 检查是否有重复的LimitType和Key组合 + if seenLimitRules[ruleKey] { + log.Warnf("duplicate rule found: %s='%s' in rule_items", ruleItem.LimitType, ruleItem.Key) + } else { + seenLimitRules[ruleKey] = true + } + + ruleItems = append(ruleItems, *ruleItem) + } + config.RuleItems = ruleItems + return nil +} + +func parseGlobalThreshold(item gjson.Result) (*GlobalThreshold, error) { + for timeWindowKey, duration := range timeWindows { + q := item.Get(timeWindowKey) + if q.Exists() { + count := q.Int() + if count <= 0 { + return nil, fmt.Errorf("'%s' must be a positive integer, got %d", timeWindowKey, count) + } + return &GlobalThreshold{ + Count: count, + TimeWindow: duration, + }, nil + } + } + return nil, errors.New("one of 'token_per_second', 'token_per_minute', 'token_per_hour', or 'token_per_day' must be set for global_threshold") +} + +func parseLimitRuleItem(item gjson.Result) (*LimitRuleItem, error) { + var ruleItem LimitRuleItem + + // 根据配置区分限流类型 + var limitType LimitRuleItemType + setLimitByKeyIfExists := func(field gjson.Result, limitTypeStr LimitRuleItemType) { + if field.Exists() && field.String() != "" { + ruleItem.Key = field.String() + limitType = limitTypeStr + } + } + setLimitByKeyIfExists(item.Get("limit_by_header"), LimitByHeaderType) + setLimitByKeyIfExists(item.Get("limit_by_param"), LimitByParamType) + setLimitByKeyIfExists(item.Get("limit_by_cookie"), LimitByCookieType) + setLimitByKeyIfExists(item.Get("limit_by_per_header"), LimitByPerHeaderType) + setLimitByKeyIfExists(item.Get("limit_by_per_param"), LimitByPerParamType) + setLimitByKeyIfExists(item.Get("limit_by_per_cookie"), LimitByPerCookieType) + + limitByConsumer := item.Get("limit_by_consumer") + if limitByConsumer.Exists() { + ruleItem.Key = util.ConsumerHeader + limitType = LimitByConsumerType + } + limitByPerConsumer := item.Get("limit_by_per_consumer") + if limitByPerConsumer.Exists() { + ruleItem.Key = util.ConsumerHeader + limitType = LimitByPerConsumerType + } + + limitByPerIpResult := item.Get("limit_by_per_ip") + if limitByPerIpResult.Exists() && limitByPerIpResult.String() != "" { + limitByPerIp := limitByPerIpResult.String() + ruleItem.Key = limitByPerIp + if strings.HasPrefix(limitByPerIp, "from-header-") { + headerName := limitByPerIp[len("from-header-"):] + if headerName == "" { + return nil, errors.New("limit_by_per_ip parse error: empty after 'from-header-'") + } + ruleItem.LimitByPerIp = LimitByPerIp{ + SourceType: HeaderSourceType, + HeaderName: headerName, + } + } else if limitByPerIp == "from-remote-addr" { + ruleItem.LimitByPerIp = LimitByPerIp{ + SourceType: RemoteAddrSourceType, + HeaderName: "", + } + } else { + return nil, errors.New("the 'limit_by_per_ip' restriction must start with 'from-header-' or be exactly 'from-remote-addr'") + } + limitType = LimitByPerIpType + } + + if limitType == "" { + return nil, errors.New("only one of 'limit_by_header' and 'limit_by_param' and 'limit_by_consumer' and 'limit_by_cookie' and 'limit_by_per_header' and 'limit_by_per_param' and 'limit_by_per_consumer' and 'limit_by_per_cookie' and 'limit_by_per_ip' can be set") + } + ruleItem.LimitType = limitType + + // 初始化configItems + err := initConfigItems(item, &ruleItem) + if err != nil { + return nil, err + } + + return &ruleItem, nil +} + +func initConfigItems(json gjson.Result, rule *LimitRuleItem) error { + limitKeys := json.Get("limit_keys") + if !limitKeys.Exists() { + return errors.New("missing limit_keys in config") + } + if len(limitKeys.Array()) == 0 { + return errors.New("config limit_keys cannot be empty") + } + var configItems []LimitConfigItem + for _, item := range limitKeys.Array() { + key := item.Get("key") + if !key.Exists() || key.String() == "" { + return errors.New("limit_keys key is required") + } + + var ( + itemKey = key.String() + itemType LimitConfigItemType + ipNet *iptree.IPTree + regexp *re.Regexp + ) + if rule.LimitType == LimitByPerIpType { + var err error + ipNet, err = util.ParseIPNet(itemKey) + if err != nil { + return fmt.Errorf("failed to parse IPNet for key '%s': %w", itemKey, err) + } + itemType = IpNetType + } else if rule.LimitType == LimitByPerHeaderType || + rule.LimitType == LimitByPerParamType || + rule.LimitType == LimitByPerConsumerType || + rule.LimitType == LimitByPerCookieType { + if itemKey == "*" { + itemType = AllType + } else if strings.HasPrefix(itemKey, "regexp:") { + regexpStr := itemKey[len("regexp:"):] + var err error + regexp, err = re.Compile(regexpStr) + if err != nil { + return fmt.Errorf("failed to compile regex for key '%s': %w", itemKey, err) + } + itemType = RegexpType + } else { + return fmt.Errorf("the '%s' restriction must start with 'regexp:' or be exactly '*'", rule.LimitType) + } + } else { + itemType = ExactType + } + + if configItem, err := createConfigItemFromRate(item, itemType, itemKey, ipNet, regexp); err != nil { + return err + } else if configItem != nil { + configItems = append(configItems, *configItem) + } + } + rule.ConfigItems = configItems + return nil +} + +func createConfigItemFromRate(item gjson.Result, itemType LimitConfigItemType, key string, ipNet *iptree.IPTree, regexp *re.Regexp) (*LimitConfigItem, error) { + for timeWindowKey, duration := range timeWindows { + q := item.Get(timeWindowKey) + if q.Exists() { + count := q.Int() + if count <= 0 { + return nil, fmt.Errorf("'%s' must be a positive integer for key '%s', got %d", timeWindowKey, key, count) + } + return &LimitConfigItem{ + ConfigType: itemType, + Key: key, + IpNet: ipNet, + Regexp: regexp, + Count: count, + TimeWindow: duration, + }, nil + } + } + return nil, errors.New("one of 'token_per_second', 'token_per_minute', 'token_per_hour', or 'token_per_day' must be set for key: " + key) +} diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/config/config_test.go b/plugins/wasm-go/extensions/ai-token-ratelimit/config/config_test.go new file mode 100644 index 000000000..9a5565e36 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/config/config_test.go @@ -0,0 +1,218 @@ +package config + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tidwall/gjson" +) + +func TestParseAiTokenRateLimitConfig(t *testing.T) { + tests := []struct { + name string + json string + expected AiTokenRateLimitConfig + expectedErr error + }{ + { + name: "MissingRuleName", + json: `{}`, + expectedErr: errors.New("missing rule_name in config"), + }, + { + name: "GlobalThreshold_InvalidThreshold", + json: `{ + "rule_name": "invalid-threshold", + "global_threshold": { + "token_per_minute": -100 + } + }`, + expectedErr: errors.New("failed to parse global_threshold: 'token_per_minute' must be a positive integer, got -100"), + }, + { + name: "GlobalThreshold_QueryPerSecond", + json: `{ + "rule_name": "global-route-limit", + "global_threshold": { + "token_per_second": 100 + } + }`, + expected: AiTokenRateLimitConfig{ + RuleName: "global-route-limit", + GlobalThreshold: &GlobalThreshold{ + Count: 100, + TimeWindow: Second, + }, + RejectedCode: DefaultRejectedCode, + RejectedMsg: DefaultRejectedMsg, + }, + }, + { + name: "GlobalThreshold_QueryPerMinute", + json: `{ + "rule_name": "global-route-limit", + "global_threshold": { + "token_per_minute": 1000 + } + }`, + expected: AiTokenRateLimitConfig{ + RuleName: "global-route-limit", + GlobalThreshold: &GlobalThreshold{ + Count: 1000, + TimeWindow: SecondsPerMinute, + }, + RejectedCode: DefaultRejectedCode, + RejectedMsg: DefaultRejectedMsg, + }, + }, + { + name: "RuleItems_InvalidThreshold", + json: `{ + "rule_name": "invalid-threshold", + "rule_items": [ + { + "limit_by_header": "x-test", + "limit_keys": [ + {"key": "key1", "token_per_minute": -100} + ] + } + ] + }`, + expectedErr: errors.New("failed to parse rule_item in rule_items: 'token_per_minute' must be a positive integer for key 'key1', got -100"), + }, + { + name: "RuleItems_SingleRule", + json: `{ + "rule_name": "rule-based-limit", + "rule_items": [ + { + "limit_by_header": "x-test", + "limit_keys": [ + {"key": "key1", "token_per_second": 10} + ] + } + ] + }`, + expected: AiTokenRateLimitConfig{ + RuleName: "rule-based-limit", + RuleItems: []LimitRuleItem{ + { + LimitType: LimitByHeaderType, + Key: "x-test", + ConfigItems: []LimitConfigItem{ + { + ConfigType: ExactType, + Key: "key1", + Count: 10, + TimeWindow: Second, + }, + }, + }, + }, + RejectedCode: DefaultRejectedCode, + RejectedMsg: DefaultRejectedMsg, + }, + }, + { + name: "RuleItems_MultipleRules", + json: `{ + "rule_name": "multi-rule-limit", + "rule_items": [ + { + "limit_by_param": "user_id", + "limit_keys": [ + {"key": "123", "token_per_hour": 50} + ] + }, + { + "limit_by_per_cookie": "session_id", + "limit_keys": [ + {"key": "*", "token_per_day": 100} + ] + } + ] + }`, + expected: AiTokenRateLimitConfig{ + RuleName: "multi-rule-limit", + RuleItems: []LimitRuleItem{ + { + LimitType: LimitByParamType, + Key: "user_id", + ConfigItems: []LimitConfigItem{ + { + ConfigType: ExactType, + Key: "123", + Count: 50, + TimeWindow: SecondsPerHour, + }, + }, + }, + { + LimitType: LimitByPerCookieType, + Key: "session_id", + ConfigItems: []LimitConfigItem{ + { + ConfigType: AllType, + Key: "*", + Count: 100, + TimeWindow: SecondsPerDay, + }, + }, + }, + }, + RejectedCode: DefaultRejectedCode, + RejectedMsg: DefaultRejectedMsg, + }, + }, + { + name: "Conflict_GlobalThresholdAndRuleItems", + json: `{ + "rule_name": "test-conflict", + "global_threshold": {"token_per_second": 100}, + "rule_items": [{"limit_by_header": "x-test"}] + }`, + expectedErr: errors.New("'global_threshold' and 'rule_items' cannot be set at the same time"), + }, + { + name: "Missing_GlobalThresholdAndRuleItems", + json: `{ + "rule_name": "test-missing" + }`, + expectedErr: errors.New("at least one of 'global_threshold' or 'rule_items' must be set"), + }, + { + name: "Custom_RejectedCodeAndMessage", + json: `{ + "rule_name": "custom-reject", + "rejected_code": 403, + "rejected_msg": "Forbidden", + "global_threshold": {"token_per_second": 100} + }`, + expected: AiTokenRateLimitConfig{ + RuleName: "custom-reject", + GlobalThreshold: &GlobalThreshold{ + Count: 100, + TimeWindow: Second, + }, + RejectedCode: 403, + RejectedMsg: "Forbidden", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var config AiTokenRateLimitConfig + result := gjson.Parse(tt.json) + err := ParseAiTokenRateLimitConfig(result, &config) + + if tt.expectedErr != nil { + assert.EqualError(t, err, tt.expectedErr.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, config) + } + }) + } +} diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/go.mod b/plugins/wasm-go/extensions/ai-token-ratelimit/go.mod index d793a85d0..3a82ee7f6 100644 --- a/plugins/wasm-go/extensions/ai-token-ratelimit/go.mod +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/go.mod @@ -7,6 +7,7 @@ toolchain go1.24.4 require ( github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 github.com/higress-group/wasm-go v1.0.1 + github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.18.0 github.com/tidwall/resp v0.1.1 github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837 @@ -14,7 +15,10 @@ require ( require ( github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum b/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum index e10108c5e..a7c15fd23 100644 --- a/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum @@ -23,5 +23,7 @@ github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837 h1:DjHnADS2r2zynZ3WkCFAQ+PNYngMSNceRROi0pO6c3M= github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837/go.mod h1:9vp0bxqozzQwcjBwenEXfKVq8+mYbwHkQ1NF9Ap0DMw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/main.go b/plugins/wasm-go/extensions/ai-token-ratelimit/main.go index 1f9a610b0..5d5669c23 100644 --- a/plugins/wasm-go/extensions/ai-token-ratelimit/main.go +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/main.go @@ -21,6 +21,8 @@ import ( "strconv" "strings" + "ai-token-ratelimit/config" + "ai-token-ratelimit/util" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/log" @@ -42,8 +44,12 @@ func init() { } const ( - ClusterRateLimitFormat string = "higress-token-ratelimit:%s:%s:%d:%d:%s:%s" // ruleName, limitType, timewindow, windowsize, key, val - RequestPhaseFixedWindowScript string = ` + RedisKeyPrefix string = "higress-token-ratelimit" + // AiTokenGlobalRateLimitFormat 全局限流模式 redis key 为 RedisKeyPrefix:限流规则名称:global_threshold:时间窗口:窗口内限流数 + AiTokenGlobalRateLimitFormat = RedisKeyPrefix + ":%s:global_threshold:%d:%d" + // AiTokenRateLimitFormat 规则限流模式 redis key 为 RedisKeyPrefix:限流规则名称:限流类型:时间窗口:窗口内限流数:限流key名称:限流key对应的实际值 + AiTokenRateLimitFormat = RedisKeyPrefix + ":%s:%s:%d:%d:%s:%s" + RequestPhaseFixedWindowScript = ` local ttl = redis.call('ttl', KEYS[1]) if ttl < 0 then redis.call('set', KEYS[1], ARGV[1], 'EX', ARGV[2]) @@ -51,7 +57,7 @@ const ( end return {ARGV[1], redis.call('get', KEYS[1]), ttl} ` - ResponsePhaseFixedWindowScript string = ` + ResponsePhaseFixedWindowScript = ` local ttl = redis.call('ttl', KEYS[1]) if ttl < 0 then redis.call('set', KEYS[1], ARGV[1]-ARGV[3], 'EX', ARGV[2]) @@ -60,14 +66,11 @@ const ( return {ARGV[1], redis.call('decrby', KEYS[1], ARGV[3]), ttl} ` - LimitRedisContextKey string = "LimitRedisContext" + LimitRedisContextKey = "LimitRedisContext" - ConsumerHeader string = "x-mse-consumer" // LimitByConsumer从该request header获取consumer的名字 - CookieHeader string = "cookie" + CookieHeader = "cookie" - RateLimitLimitHeader string = "X-TokenRateLimit-Limit" // 限制的总请求数 - RateLimitRemainingHeader string = "X-TokenRateLimit-Remaining" // 剩余还可以发送的请求数 - RateLimitResetHeader string = "X-TokenRateLimit-Reset" // 限流重置时间(触发限流时返回) + RateLimitResetHeader = "X-TokenRateLimit-Reset" // 限流重置时间(触发限流时返回) TokenRateLimitCount = "token_ratelimit_count" // metric name ) @@ -84,42 +87,52 @@ type LimitRedisContext struct { window int64 } -func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig) error { - err := initRedisClusterClient(json, config) +func parseConfig(json gjson.Result, cfg *config.AiTokenRateLimitConfig) error { + err := config.InitRedisClusterClient(json, cfg) if err != nil { return err } - err = parseClusterKeyRateLimitConfig(json, config) + err = config.ParseAiTokenRateLimitConfig(json, cfg) if err != nil { return err } // Metric settings - config.counterMetrics = make(map[string]proxywasm.MetricCounter) + cfg.CounterMetrics = make(map[string]proxywasm.MetricCounter) return nil } -func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig) types.Action { +func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.AiTokenRateLimitConfig) types.Action { ctx.DisableReroute() - // 判断是否命中限流规则 - val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems) - if ruleItem == nil || configItem == nil { - return types.ActionContinue + limitKey, count, timeWindow := "", int64(0), int64(0) + + if cfg.GlobalThreshold != nil { + // 全局限流模式 + limitKey = fmt.Sprintf(AiTokenGlobalRateLimitFormat, cfg.RuleName, cfg.GlobalThreshold.TimeWindow, cfg.GlobalThreshold.Count) + count = cfg.GlobalThreshold.Count + timeWindow = cfg.GlobalThreshold.TimeWindow + } else { + // 规则限流模式 + val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, cfg.RuleItems) + if ruleItem == nil || configItem == nil { + // 没有匹配到限流规则直接返回 + return types.ActionContinue + } + + limitKey = fmt.Sprintf(AiTokenRateLimitFormat, cfg.RuleName, ruleItem.LimitType, configItem.TimeWindow, configItem.Count, ruleItem.Key, val) + count = configItem.Count + timeWindow = configItem.TimeWindow } - // 构建redis限流key和参数 - limitKey := fmt.Sprintf(ClusterRateLimitFormat, config.ruleName, ruleItem.limitType, configItem.timeWindow, configItem.count, ruleItem.key, val) - keys := []interface{}{limitKey} - args := []interface{}{configItem.count, configItem.timeWindow} - - limitRedisContext := LimitRedisContext{ + ctx.SetContext(LimitRedisContextKey, LimitRedisContext{ key: limitKey, - count: configItem.count, - window: configItem.timeWindow, - } - ctx.SetContext(LimitRedisContextKey, limitRedisContext) + count: count, + window: timeWindow, + }) // 执行限流逻辑 - err := config.redisClient.Eval(RequestPhaseFixedWindowScript, 1, keys, args, func(response resp.Value) { + keys := []interface{}{limitKey} + args := []interface{}{count, timeWindow} + err := cfg.RedisClient.Eval(RequestPhaseFixedWindowScript, 1, keys, args, func(response resp.Value) { resultArray := response.Array() if len(resultArray) != 3 { log.Errorf("redis response parse error, response: %v", response) @@ -135,7 +148,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon // 触发限流 ctx.SetUserAttribute("token_ratelimit_status", "limited") ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) - rejected(config, context) + rejected(cfg, context) } else { proxywasm.ResumeHttpRequest() } @@ -147,7 +160,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon return types.HeaderStopAllIterationAndWatermark } -func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool) []byte { +func onHttpStreamingBody(ctx wrapper.HttpContext, cfg config.AiTokenRateLimitConfig, data []byte, endOfStream bool) []byte { if usage := tokenusage.GetTokenUsage(ctx, data); usage.TotalToken > 0 { ctx.SetContext(tokenusage.CtxKeyInputToken, usage.InputToken) ctx.SetContext(tokenusage.CtxKeyOutputToken, usage.OutputToken) @@ -164,7 +177,7 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConf } keys := []interface{}{limitRedisContext.key} args := []interface{}{limitRedisContext.count, limitRedisContext.window, inputToken + outputToken} - err := config.redisClient.Eval(ResponsePhaseFixedWindowScript, 1, keys, args, nil) + err := cfg.RedisClient.Eval(ResponsePhaseFixedWindowScript, 1, keys, args, nil) if err != nil { log.Errorf("redis call failed: %v", err) } @@ -172,27 +185,29 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConf return data } -func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem) (string, *LimitRuleItem, *LimitConfigItem) { - for _, rule := range ruleItems { - val, ruleItem, configItem := hitRateRuleItem(ctx, rule) - if ruleItem != nil && configItem != nil { - return val, ruleItem, configItem +func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []config.LimitRuleItem) (string, *config.LimitRuleItem, *config.LimitConfigItem) { + if len(ruleItems) > 0 { + for _, rule := range ruleItems { + val, ruleItem, configItem := hitRateRuleItem(ctx, rule) + if ruleItem != nil && configItem != nil { + return val, ruleItem, configItem + } } } return "", nil, nil } -func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem) (string, *LimitRuleItem, *LimitConfigItem) { - switch rule.limitType { +func hitRateRuleItem(ctx wrapper.HttpContext, rule config.LimitRuleItem) (string, *config.LimitRuleItem, *config.LimitConfigItem) { + switch rule.LimitType { // 根据HTTP请求头限流 - case limitByHeaderType, limitByPerHeaderType: - val, err := proxywasm.GetHttpRequestHeader(rule.key) + case config.LimitByHeaderType, config.LimitByPerHeaderType: + val, err := proxywasm.GetHttpRequestHeader(rule.Key) if err != nil { - return logDebugAndReturnEmpty("failed to get request header %s: %v", rule.key, err) + return logDebugAndReturnEmpty("failed to get request header %s: %v", rule.Key, err) } - return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val) + return val, &rule, findMatchingItem(rule.LimitType, rule.ConfigItems, val) // 根据HTTP请求参数限流 - case limitByParamType, limitByPerParamType: + case config.LimitByParamType, config.LimitByPerParamType: parse, err := url.Parse(ctx.Path()) if err != nil { return logDebugAndReturnEmpty("failed to parse request path: %v", err) @@ -201,38 +216,38 @@ func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem) (string, *Limi if err != nil { return logDebugAndReturnEmpty("failed to parse query params: %v", err) } - val, ok := query[rule.key] + val, ok := query[rule.Key] if !ok { - return logDebugAndReturnEmpty("request param %s is empty", rule.key) + return logDebugAndReturnEmpty("request param %s is empty", rule.Key) } - return val[0], &rule, findMatchingItem(rule.limitType, rule.configItems, val[0]) + return val[0], &rule, findMatchingItem(rule.LimitType, rule.ConfigItems, val[0]) // 根据consumer限流 - case limitByConsumerType, limitByPerConsumerType: - val, err := proxywasm.GetHttpRequestHeader(ConsumerHeader) + case config.LimitByConsumerType, config.LimitByPerConsumerType: + val, err := proxywasm.GetHttpRequestHeader(util.ConsumerHeader) if err != nil { - return logDebugAndReturnEmpty("failed to get request header %s: %v", ConsumerHeader, err) + return logDebugAndReturnEmpty("failed to get request header %s: %v", util.ConsumerHeader, err) } - return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val) + return val, &rule, findMatchingItem(rule.LimitType, rule.ConfigItems, val) // 根据cookie中key值限流 - case limitByCookieType, limitByPerCookieType: + case config.LimitByCookieType, config.LimitByPerCookieType: cookie, err := proxywasm.GetHttpRequestHeader(CookieHeader) if err != nil { return logDebugAndReturnEmpty("failed to get request cookie : %v", err) } - val := extractCookieValueByKey(cookie, rule.key) + val := util.ExtractCookieValueByKey(cookie, rule.Key) if val == "" { - return logDebugAndReturnEmpty("cookie key '%s' extracted from cookie '%s' is empty.", rule.key, cookie) + return logDebugAndReturnEmpty("cookie key '%s' extracted from cookie '%s' is empty.", rule.Key, cookie) } - return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val) + return val, &rule, findMatchingItem(rule.LimitType, rule.ConfigItems, val) // 根据客户端IP限流 - case limitByPerIpType: + case config.LimitByPerIpType: realIp, err := getDownStreamIp(rule) if err != nil { log.Warnf("failed to get down stream ip: %v", err) return "", &rule, nil } - for _, item := range rule.configItems { - if _, found, _ := item.ipNet.Get(realIp); !found { + for _, item := range rule.ConfigItems { + if _, found, _ := item.IpNet.Get(realIp); !found { continue } return realIp.String(), &rule, &item @@ -241,37 +256,37 @@ func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem) (string, *Limi return "", nil, nil } -func logDebugAndReturnEmpty(errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) { +func logDebugAndReturnEmpty(errMsg string, args ...interface{}) (string, *config.LimitRuleItem, *config.LimitConfigItem) { log.Debugf(errMsg, args...) return "", nil, nil } -func findMatchingItem(limitType limitRuleItemType, items []LimitConfigItem, key string) *LimitConfigItem { +func findMatchingItem(limitType config.LimitRuleItemType, items []config.LimitConfigItem, key string) *config.LimitConfigItem { for _, item := range items { // per类型,检查allType和regexpType - if limitType == limitByPerHeaderType || - limitType == limitByPerParamType || - limitType == limitByPerConsumerType || - limitType == limitByPerCookieType { - if item.configType == allType || (item.configType == regexpType && item.regexp.MatchString(key)) { + if limitType == config.LimitByPerHeaderType || + limitType == config.LimitByPerParamType || + limitType == config.LimitByPerConsumerType || + limitType == config.LimitByPerCookieType { + if item.ConfigType == config.AllType || (item.ConfigType == config.RegexpType && item.Regexp.MatchString(key)) { return &item } } // 其他类型,直接比较key - if item.key == key { + if item.Key == key { return &item } } return nil } -func getDownStreamIp(rule LimitRuleItem) (net.IP, error) { +func getDownStreamIp(rule config.LimitRuleItem) (net.IP, error) { var ( realIpStr string err error ) - if rule.limitByPerIp.sourceType == HeaderSourceType { - realIpStr, err = proxywasm.GetHttpRequestHeader(rule.limitByPerIp.headerName) + if rule.LimitByPerIp.SourceType == config.HeaderSourceType { + realIpStr, err = proxywasm.GetHttpRequestHeader(rule.LimitByPerIp.HeaderName) if err == nil { realIpStr = strings.Split(strings.Trim(realIpStr, " "), ",")[0] } @@ -283,7 +298,7 @@ func getDownStreamIp(rule LimitRuleItem) (net.IP, error) { if err != nil { return nil, err } - ip := parseIP(realIpStr) + ip := util.ParseIP(realIpStr) realIP := net.ParseIP(ip) if realIP == nil { return nil, fmt.Errorf("invalid ip[%s]", ip) @@ -291,54 +306,18 @@ func getDownStreamIp(rule LimitRuleItem) (net.IP, error) { return realIP, nil } -func (config *ClusterKeyRateLimitConfig) incrementCounter(metricName string, inc uint64) { - if inc == 0 { - return - } - counter, ok := config.counterMetrics[metricName] - if !ok { - counter = proxywasm.DefineCounterMetric(metricName) - config.counterMetrics[metricName] = counter - } - counter.Increment(inc) -} - func generateMetricName(route, cluster, model, consumer, metricName string) string { return fmt.Sprintf("route.%s.upstream.%s.model.%s.consumer.%s.metric.%s", route, cluster, model, consumer, metricName) } -func getRouteName() (string, error) { - if raw, err := proxywasm.GetProperty([]string{"route_name"}); err != nil { - return "-", err - } else { - return string(raw), nil - } -} - -func getClusterName() (string, error) { - if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err != nil { - return "-", err - } else { - return string(raw), nil - } -} - -func getConsumer() (string, error) { - if consumer, err := proxywasm.GetHttpRequestHeader(ConsumerHeader); err != nil { - return "none", err - } else { - return consumer, nil - } -} - -func rejected(config ClusterKeyRateLimitConfig, context LimitContext) { +func rejected(cfg config.AiTokenRateLimitConfig, context LimitContext) { headers := make(map[string][]string) headers[RateLimitResetHeader] = []string{strconv.Itoa(context.reset)} _ = proxywasm.SendHttpResponseWithDetail( - config.rejectedCode, "ai-token-ratelimit.rejected", reconvertHeaders(headers), []byte(config.rejectedMsg), -1) + cfg.RejectedCode, "ai-token-ratelimit.rejected", util.ReconvertHeaders(headers), []byte(cfg.RejectedMsg), -1) - route, _ := getRouteName() - cluster, _ := getClusterName() - consumer, _ := getConsumer() - config.incrementCounter(generateMetricName(route, cluster, "none", consumer, TokenRateLimitCount), 1) + route, _ := util.GetRouteName() + cluster, _ := util.GetClusterName() + consumer, _ := util.GetConsumer() + cfg.IncrementCounter(generateMetricName(route, cluster, "none", consumer, TokenRateLimitCount), 1) } diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/util/utils.go b/plugins/wasm-go/extensions/ai-token-ratelimit/util/utils.go new file mode 100644 index 000000000..3556a32ed --- /dev/null +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/util/utils.go @@ -0,0 +1,87 @@ +package util + +import ( + "fmt" + "sort" + "strings" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/zmap/go-iptree/iptree" +) + +const ConsumerHeader = "x-mse-consumer" // LimitByConsumer从该request header获取consumer的名字 + +// ParseIPNet 解析Ip段配置 +func ParseIPNet(key string) (*iptree.IPTree, error) { + tree := iptree.New() + err := tree.AddByString(key, 0) + if err != nil { + return nil, fmt.Errorf("invalid IP[%s]", key) + } + return tree, nil +} + +// ParseIP 解析IP +func ParseIP(source string) string { + if strings.Contains(source, ".") { + // parse ipv4 + return strings.Split(source, ":")[0] + } + // parse ipv6 + if strings.Contains(source, "]") { + return strings.Split(source, "]")[0][1:] + } + return source +} + +// ReconvertHeaders headers: map[string][]string -> [][2]string +func ReconvertHeaders(hs map[string][]string) [][2]string { + var ret [][2]string + for k, vs := range hs { + for _, v := range vs { + ret = append(ret, [2]string{k, v}) + } + } + sort.SliceStable(ret, func(i, j int) bool { + return ret[i][0] < ret[j][0] + }) + return ret +} + +// ExtractCookieValueByKey 从cookie中提取key对应的value +func ExtractCookieValueByKey(cookie string, key string) (value string) { + pairs := strings.Split(cookie, ";") + for _, pair := range pairs { + pair = strings.TrimSpace(pair) + kv := strings.Split(pair, "=") + if kv[0] == key { + value = kv[1] + break + } + } + return value +} + +func GetRouteName() (string, error) { + if raw, err := proxywasm.GetProperty([]string{"route_name"}); err != nil { + return "-", err + } else { + return string(raw), nil + } +} + +func GetClusterName() (string, error) { + if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err != nil { + return "-", err + } else { + return string(raw), nil + } +} + +func GetConsumer() (string, error) { + if consumer, err := proxywasm.GetHttpRequestHeader(ConsumerHeader); err != nil { + return "none", err + } else { + return consumer, nil + } +} diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/utils.go b/plugins/wasm-go/extensions/ai-token-ratelimit/utils.go deleted file mode 100644 index e1908a26b..000000000 --- a/plugins/wasm-go/extensions/ai-token-ratelimit/utils.go +++ /dev/null @@ -1,60 +0,0 @@ -package main - -import ( - "fmt" - "sort" - "strings" - - "github.com/zmap/go-iptree/iptree" -) - -// parseIPNet 解析Ip段配置 -func parseIPNet(key string) (*iptree.IPTree, error) { - tree := iptree.New() - err := tree.AddByString(key, 0) - if err != nil { - return nil, fmt.Errorf("invalid IP[%s]", key) - } - return tree, nil -} - -// parseIP 解析IP -func parseIP(source string) string { - if strings.Contains(source, ".") { - // parse ipv4 - return strings.Split(source, ":")[0] - } - // parse ipv6 - if strings.Contains(source, "]") { - return strings.Split(source, "]")[0][1:] - } - return source -} - -// reconvertHeaders headers: map[string][]string -> [][2]string -func reconvertHeaders(hs map[string][]string) [][2]string { - var ret [][2]string - for k, vs := range hs { - for _, v := range vs { - ret = append(ret, [2]string{k, v}) - } - } - sort.SliceStable(ret, func(i, j int) bool { - return ret[i][0] < ret[j][0] - }) - return ret -} - -// extractCookieValueByKey 从cookie中提取key对应的value -func extractCookieValueByKey(cookie string, key string) (value string) { - pairs := strings.Split(cookie, ";") - for _, pair := range pairs { - pair = strings.TrimSpace(pair) - kv := strings.Split(pair, "=") - if kv[0] == key { - value = kv[1] - break - } - } - return value -} diff --git a/plugins/wasm-go/extensions/cluster-key-rate-limit/config/config.go b/plugins/wasm-go/extensions/cluster-key-rate-limit/config/config.go index 019ddb214..6e1e8c417 100644 --- a/plugins/wasm-go/extensions/cluster-key-rate-limit/config/config.go +++ b/plugins/wasm-go/extensions/cluster-key-rate-limit/config/config.go @@ -3,12 +3,11 @@ package config import ( "errors" "fmt" + re "regexp" "strings" "cluster-key-rate-limit/util" - - re "regexp" - + "github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" "github.com/zmap/go-iptree/iptree" @@ -191,11 +190,25 @@ func initLimitRule(json gjson.Result, config *ClusterKeyRateLimitConfig) error { } var ruleItems []LimitRuleItem + // 用于记录已出现的LimitType和Key的组合 + seenLimitRules := make(map[string]bool) + for _, item := range items { ruleItem, err := parseLimitRuleItem(item) if err != nil { return fmt.Errorf("failed to parse rule_item in rule_items: %w", err) } + + // 构造LimitType和Key的唯一标识 + ruleKey := string(ruleItem.LimitType) + ":" + ruleItem.Key + + // 检查是否有重复的LimitType和Key组合 + if seenLimitRules[ruleKey] { + log.Warnf("duplicate rule found: %s='%s' in rule_items", ruleItem.LimitType, ruleItem.Key) + } else { + seenLimitRules[ruleKey] = true + } + ruleItems = append(ruleItems, *ruleItem) } config.RuleItems = ruleItems @@ -205,9 +218,13 @@ func initLimitRule(json gjson.Result, config *ClusterKeyRateLimitConfig) error { func parseGlobalThreshold(item gjson.Result) (*GlobalThreshold, error) { for timeWindowKey, duration := range timeWindows { q := item.Get(timeWindowKey) - if q.Exists() && q.Int() > 0 { + if q.Exists() { + count := q.Int() + if count <= 0 { + return nil, fmt.Errorf("'%s' must be a positive integer, got %d", timeWindowKey, count) + } return &GlobalThreshold{ - Count: q.Int(), + Count: count, TimeWindow: duration, }, nil } @@ -276,7 +293,7 @@ func parseLimitRuleItem(item gjson.Result) (*LimitRuleItem, error) { // 初始化configItems err := initConfigItems(item, &ruleItem) if err != nil { - return nil, fmt.Errorf("failed to init config items: %w", err) + return nil, err } return &ruleItem, nil @@ -344,13 +361,17 @@ func initConfigItems(json gjson.Result, rule *LimitRuleItem) error { func createConfigItemFromRate(item gjson.Result, itemType LimitConfigItemType, key string, ipNet *iptree.IPTree, regexp *re.Regexp) (*LimitConfigItem, error) { for timeWindowKey, duration := range timeWindows { q := item.Get(timeWindowKey) - if q.Exists() && q.Int() > 0 { + if q.Exists() { + count := q.Int() + if count <= 0 { + return nil, fmt.Errorf("'%s' must be a positive integer for key '%s', got %d", timeWindowKey, key, count) + } return &LimitConfigItem{ ConfigType: itemType, Key: key, IpNet: ipNet, Regexp: regexp, - Count: q.Int(), + Count: count, TimeWindow: duration, }, nil } diff --git a/plugins/wasm-go/extensions/cluster-key-rate-limit/config/config_test.go b/plugins/wasm-go/extensions/cluster-key-rate-limit/config/config_test.go index 26bad8fb8..b69d38559 100644 --- a/plugins/wasm-go/extensions/cluster-key-rate-limit/config/config_test.go +++ b/plugins/wasm-go/extensions/cluster-key-rate-limit/config/config_test.go @@ -20,6 +20,16 @@ func TestParseClusterKeyRateLimitConfig(t *testing.T) { json: `{}`, expectedErr: errors.New("missing rule_name in config"), }, + { + name: "GlobalThreshold_InvalidThreshold", + json: `{ + "rule_name": "invalid-threshold", + "global_threshold": { + "query_per_minute": -100 + } + }`, + expectedErr: errors.New("failed to parse global_threshold: 'query_per_minute' must be a positive integer, got -100"), + }, { name: "GlobalThreshold_QueryPerSecond", json: `{ @@ -56,6 +66,21 @@ func TestParseClusterKeyRateLimitConfig(t *testing.T) { RejectedMsg: DefaultRejectedMsg, }, }, + { + name: "RuleItems_InvalidThreshold", + json: `{ + "rule_name": "invalid-threshold", + "rule_items": [ + { + "limit_by_header": "x-test", + "limit_keys": [ + {"key": "key1", "query_per_minute": -100} + ] + } + ] + }`, + expectedErr: errors.New("failed to parse rule_item in rule_items: 'query_per_minute' must be a positive integer for key 'key1', got -100"), + }, { name: "RuleItems_SingleRule", json: `{ diff --git a/plugins/wasm-go/extensions/cluster-key-rate-limit/go.sum b/plugins/wasm-go/extensions/cluster-key-rate-limit/go.sum index b82bc088f..a7c15fd23 100644 --- a/plugins/wasm-go/extensions/cluster-key-rate-limit/go.sum +++ b/plugins/wasm-go/extensions/cluster-key-rate-limit/go.sum @@ -6,8 +6,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= -github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw= -github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M= github.com/higress-group/wasm-go v1.0.1 h1:T1m++qTEANp8+jwE0sxltwtaTKmrHCkLOp1m9N+YeqY= github.com/higress-group/wasm-go v1.0.1/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/plugins/wasm-go/extensions/cluster-key-rate-limit/main.go b/plugins/wasm-go/extensions/cluster-key-rate-limit/main.go index ac0913abb..400590743 100644 --- a/plugins/wasm-go/extensions/cluster-key-rate-limit/main.go +++ b/plugins/wasm-go/extensions/cluster-key-rate-limit/main.go @@ -44,12 +44,12 @@ func init() { } const ( - // ClusterKeyPrefix 集群限流插件在 Redis 中 key 的统一前缀 - ClusterKeyPrefix = "higress-cluster-key-rate-limit" - // ClusterGlobalRateLimitFormat 全局限流模式 redis key 为 ClusterKeyPrefix:限流规则名称:global_threshold:时间窗口:窗口内限流数 - ClusterGlobalRateLimitFormat = ClusterKeyPrefix + ":%s:global_threshold:%d:%d" - // ClusterRateLimitFormat 规则限流模式 redis key 为 ClusterKeyPrefix:限流规则名称:限流类型:时间窗口:窗口内限流数:限流key名称:限流key对应的实际值 - ClusterRateLimitFormat = ClusterKeyPrefix + ":%s:%s:%d:%d:%s:%s" + // RedisKeyPrefix 集群限流插件在 Redis 中 key 的统一前缀 + RedisKeyPrefix = "higress-cluster-key-rate-limit" + // ClusterGlobalRateLimitFormat 全局限流模式 redis key 为 RedisKeyPrefix:限流规则名称:global_threshold:时间窗口:窗口内限流数 + ClusterGlobalRateLimitFormat = RedisKeyPrefix + ":%s:global_threshold:%d:%d" + // ClusterRateLimitFormat 规则限流模式 redis key 为 RedisKeyPrefix:限流规则名称:限流类型:时间窗口:窗口内限流数:限流key名称:限流key对应的实际值 + ClusterRateLimitFormat = RedisKeyPrefix + ":%s:%s:%d:%d:%s:%s" FixedWindowScript = ` local ttl = redis.call('ttl', KEYS[1]) if ttl < 0 then @@ -86,24 +86,24 @@ func parseConfig(json gjson.Result, cfg *config.ClusterKeyRateLimitConfig) error return nil } -func onHttpRequestHeaders(ctx wrapper.HttpContext, config config.ClusterKeyRateLimitConfig) types.Action { +func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.ClusterKeyRateLimitConfig) types.Action { ctx.DisableReroute() limitKey, count, timeWindow := "", int64(0), int64(0) - if config.GlobalThreshold != nil { + if cfg.GlobalThreshold != nil { // 全局限流模式 - limitKey = fmt.Sprintf(ClusterGlobalRateLimitFormat, config.RuleName, config.GlobalThreshold.TimeWindow, config.GlobalThreshold.Count) - count = config.GlobalThreshold.Count - timeWindow = config.GlobalThreshold.TimeWindow + limitKey = fmt.Sprintf(ClusterGlobalRateLimitFormat, cfg.RuleName, cfg.GlobalThreshold.TimeWindow, cfg.GlobalThreshold.Count) + count = cfg.GlobalThreshold.Count + timeWindow = cfg.GlobalThreshold.TimeWindow } else { // 规则限流模式 - val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.RuleItems) + val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, cfg.RuleItems) if ruleItem == nil || configItem == nil { // 没有匹配到限流规则直接返回 return types.ActionContinue } - limitKey = fmt.Sprintf(ClusterRateLimitFormat, config.RuleName, ruleItem.LimitType, configItem.TimeWindow, configItem.Count, ruleItem.Key, val) + limitKey = fmt.Sprintf(ClusterRateLimitFormat, cfg.RuleName, ruleItem.LimitType, configItem.TimeWindow, configItem.Count, ruleItem.Key, val) count = configItem.Count timeWindow = configItem.TimeWindow } @@ -111,7 +111,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config config.ClusterKeyRateL // 执行限流逻辑 keys := []interface{}{limitKey} args := []interface{}{count, timeWindow} - err := config.RedisClient.Eval(FixedWindowScript, 1, keys, args, func(response resp.Value) { + err := cfg.RedisClient.Eval(FixedWindowScript, 1, keys, args, func(response resp.Value) { resultArray := response.Array() if len(resultArray) != 3 { log.Errorf("redis response parse error, response: %v", response) @@ -125,7 +125,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config config.ClusterKeyRateL } if context.remaining < 0 { // 触发限流 - rejected(config, context) + rejected(cfg, context) } else { ctx.SetContext(LimitContextKey, context) proxywasm.ResumeHttpRequest()