add plugin: ai-token-ratelimit (#1015)

This commit is contained in:
rinfx
2024-06-19 13:46:59 +08:00
committed by GitHub
parent 7164653446
commit 1ea87f0e7a
7 changed files with 904 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
main.wasm
config.yaml

View File

@@ -0,0 +1,186 @@
# 功能说明
`ai-token-ratelimit`插件实现了基于特定键值实现token限流键值来源可以是 URL 参数、HTTP 请求头、客户端 IP 地址、consumer 名称、cookie中 key 名称
# 配置说明
| 配置项 | 类型 | 必填 | 默认值 | 说明 |
| ----------------------- | ------ | ---- | ------ |---------------------------------------------------------------------------|
| rule_name | string | 是 | - | 限流规则名称,根据限流规则名称+限流类型+限流key名称+限流key对应的实际值来拼装redis key |
| rule_items | array of object | 是 | - | 限流规则项按照rule_items下的排列顺序匹配第一个rule_item后命中限流规则后续规则将被忽略 |
| rejected_code | int | 否 | 429 | 请求被限流时返回的HTTP状态码 |
| rejected_msg | string | 否 | Too many requests | 请求被限流时,返回的响应体 |
| redis | object | 是 | - | redis相关配置 |
`rule_items`中每一项的配置字段说明
| 配置项 | 类型 | 必填 | 默认值 | 说明 |
| --------------------- | --------------- | -------------------------- | ------ | ------------------------------------------------------------ |
| limit_by_header | string | 否,`limit_by_*`中选填一项 | - | 配置获取限流键值的来源 HTTP 请求头名称 |
| limit_by_param | string | 否,`limit_by_*`中选填一项 | - | 配置获取限流键值的来源 URL 参数名称 |
| limit_by_consumer | string | 否,`limit_by_*`中选填一项 | - | 根据 consumer 名称进行限流,无需添加实际值 |
| limit_by_cookie | string | 否,`limit_by_*`中选填一项 | - | 配置获取限流键值的来源 Cookie中 key 名称 |
| limit_by_per_header | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 HTTP 请求头,并对每个请求头分别计算限流,配置获取限流键值的来源 HTTP 请求头名称,配置`limit_keys`时支持正则表达式或`*` |
| limit_by_per_param | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 URL 参数,并对每个参数分别计算限流,配置获取限流键值的来源 URL 参数名称,配置`limit_keys`时支持正则表达式或`*` |
| limit_by_per_consumer | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 consumer并对每个 consumer 分别计算限流,根据 consumer 名称进行限流,无需添加实际值,配置`limit_keys`时支持正则表达式或`*` |
| limit_by_per_cookie | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 Cookie并对每个 Cookie 分别计算限流,配置获取限流键值的来源 Cookie中 key 名称,配置`limit_keys`时支持正则表达式或`*` |
| limit_by_per_ip | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 IP并对每个 IP 分别计算限流,配置获取限流键值的来源 IP 参数名称,从请求头获取,以`from-header-对应的header名`,示例:`from-header-x-forwarded-for`直接获取对端socket ip配置为`from-remote-addr` |
| limit_keys | array of object | 是 | - | 配置匹配键值后的限流次数 |
`limit_keys`中每一项的配置字段说明
| 配置项 | 类型 | 必填 | 默认值 | 说明 |
| ---------------- | ------ | ------------------------------------------------------------ | ------ | ------------------------------------------------------------ |
| key | string | 是 | - | 匹配的键值,`limit_by_per_header`,`limit_by_per_param`,`limit_by_per_consumer`,`limit_by_per_cookie` 类型支持配置正则表达式以regexp:开头后面跟正则表达式)或者*(代表所有),正则表达式示例:`regexp:^d.*`以d开头的所有字符串`limit_by_per_ip`支持配置 IP 地址或 IP 段 |
| token_per_second | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每秒请求token数 |
| token_per_minute | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每分钟请求token数 |
| token_per_hour | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每小时请求token数 |
| token_per_day | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每天请求token数 |
`redis`中每一项的配置字段说明
| 配置项 | 类型 | 必填 | 默认值 | 说明 |
| ------------ | ------ | ---- | ---------------------------------------------------------- | --------------------------- |
| service_name | string | 必填 | - | redis 服务名称,带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local |
| service_port | int | 否 | 服务类型为固定地址static service默认值为80其他为6379 | 输入redis服务的服务端口 |
| username | string | 否 | - | redis用户名 |
| password | string | 否 | - | redis密码 |
| timeout | int | 否 | 1000 | redis连接超时时间单位毫秒 |
# 配置示例
## 识别请求参数 apikey进行区别限流
```yaml
rule_name: default_rule
rule_items:
- limit_by_param: apikey
limit_keys:
- key: 9a342114-ba8a-11ec-b1bf-00163e1250b5
token_per_minute: 10
- key: a6a6d7f2-ba8a-11ec-bec2-00163e1250b5
token_per_hour: 100
- limit_by_per_param: apikey
limit_keys:
# 正则表达式匹配以a开头的所有字符串每个apikey对应的请求10qds
- key: "regexp:^a.*"
token_per_second: 10
# 正则表达式匹配以b开头的所有字符串每个apikey对应的请求100qd
- key: "regexp:^b.*"
token_per_minute: 100
# 兜底用匹配所有请求每个apikey对应的请求1000qdh
- key: "*"
token_per_hour: 1000
redis:
service_name: redis.static
```
## 识别请求头 x-ca-key进行区别限流
```yaml
rule_name: default_rule
rule_items:
- limit_by_header: x-ca-key
limit_keys:
- key: 102234
token_per_minute: 10
- key: 308239
token_per_hour: 10
- limit_by_per_header: x-ca-key
limit_keys:
# 正则表达式匹配以a开头的所有字符串每个apikey对应的请求10qds
- key: "regexp:^a.*"
token_per_second: 10
# 正则表达式匹配以b开头的所有字符串每个apikey对应的请求100qd
- key: "regexp:^b.*"
token_per_minute: 100
# 兜底用匹配所有请求每个apikey对应的请求1000qdh
- key: "*"
token_per_hour: 1000
redis:
service_name: redis.static
```
## 根据请求头 x-forwarded-for 获取对端IP进行区别限流
```yaml
rule_name: default_rule
rule_items:
- limit_by_per_ip: from-header-x-forwarded-for
limit_keys:
# 精确ip
- key: 1.1.1.1
token_per_day: 10
# ip段符合这个ip段的ip每个ip 100qpd
- key: 1.1.1.0/24
token_per_day: 100
# 兜底用即默认每个ip 1000qpd
- key: 0.0.0.0/0
token_per_day: 1000
redis:
service_name: redis.static
```
## 识别consumer进行区别限流
```yaml
rule_name: default_rule
rule_items:
- limit_by_consumer: ''
limit_keys:
- key: consumer1
token_per_second: 10
- key: consumer2
token_per_hour: 100
- limit_by_per_consumer: ''
limit_keys:
# 正则表达式匹配以a开头的所有字符串每个consumer对应的请求10qds
- key: "regexp:^a.*"
token_per_second: 10
# 正则表达式匹配以b开头的所有字符串每个consumer对应的请求100qd
- key: "regexp:^b.*"
token_per_minute: 100
# 兜底用匹配所有请求每个consumer对应的请求1000qdh
- key: "*"
token_per_hour: 1000
redis:
service_name: redis.static
```
## 识别cookie中的键值对进行区别限流
```yaml
rule_name: default_rule
rule_items:
- limit_by_cookie: key1
limit_keys:
- key: value1
token_per_minute: 10
- key: value2
token_per_hour: 100
- limit_by_per_cookie: key1
limit_keys:
# 正则表达式匹配以a开头的所有字符串每个cookie中的value对应的请求10qds
- key: "regexp:^a.*"
token_per_second: 10
# 正则表达式匹配以b开头的所有字符串每个cookie中的value对应的请求100qd
- key: "regexp:^b.*"
token_per_minute: 100
# 兜底用匹配所有请求每个cookie中的value对应的请求1000qdh
- key: "*"
token_per_hour: 1000
rejected_code: 200
rejected_msg: '{"code":-1,"msg":"Too many requests"}'
redis:
service_name: redis.static
```

View File

@@ -0,0 +1,297 @@
package main
import (
"errors"
"fmt"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
re "github.com/wasilibs/go-re2"
"github.com/zmap/go-iptree/iptree"
)
// 限流规则项类型
type limitRuleItemType string
// 限流配置项key类型
type limitConfigItemType string
const (
limitByHeaderType limitRuleItemType = "limit_by_header"
limitByParamType limitRuleItemType = "limit_by_param"
limitByConsumerType limitRuleItemType = "limit_by_consumer"
limitByCookieType limitRuleItemType = "limit_by_cookie"
limitByPerHeaderType limitRuleItemType = "limit_by_per_header"
limitByPerParamType limitRuleItemType = "limit_by_per_param"
limitByPerConsumerType limitRuleItemType = "limit_by_per_consumer"
limitByPerCookieType limitRuleItemType = "limit_by_per_cookie"
limitByPerIpType limitRuleItemType = "limit_by_per_ip"
exactType limitConfigItemType = "exact" // 精确匹配
regexpType limitConfigItemType = "regexp" // 正则表达式
allType limitConfigItemType = "*" // 匹配所有情况
ipNetType limitConfigItemType = "ipNet" // ip段
RemoteAddrSourceType = "remote-addr"
HeaderSourceType = "header"
DefaultRejectedCode uint32 = 429
DefaultRejectedMsg string = "Too many requests"
Second int64 = 1
SecondsPerMinute = 60 * Second
SecondsPerHour = 60 * SecondsPerMinute
SecondsPerDay = 24 * SecondsPerHour
)
var timeWindows = map[string]int64{
"token_per_second": Second,
"token_per_minute": SecondsPerMinute,
"token_per_hour": SecondsPerHour,
"token_per_day": SecondsPerDay,
}
type ClusterKeyRateLimitConfig struct {
ruleName string // 限流规则名称
ruleItems []LimitRuleItem // 限流规则项
showLimitQuotaHeader bool // 响应头中是否显示X-RateLimit-Limit和X-RateLimit-Remaining
rejectedCode uint32 // 当请求超过阈值被拒绝时,返回的HTTP状态码
rejectedMsg string // 当请求超过阈值被拒绝时,返回的响应体
redisClient wrapper.RedisClient
}
type LimitRuleItem struct {
limitType limitRuleItemType // 限流类型
key string // 根据该key值进行限流,limit_by_consumer和limit_by_per_consumer两种类型为ConsumerHeader,其他类型为对应的key值
limitByPerIp LimitByPerIp // 对端ip地址或ip段
configItems []LimitConfigItem // 限流配置项
}
type LimitByPerIp struct {
sourceType string // ip来源类型
headerName string // 根据该请求头获取客户端ip
}
type LimitConfigItem struct {
configType limitConfigItemType // 限流配置项key类型
key string // 限流key
ipNet *iptree.IPTree // 限流key转换的ip地址或者ip段,仅用于itemType为ipNetType
regexp *re.Regexp // 正则表达式,仅用于itemType为regexpType
count int64 // 指定时间窗口内的总请求数量阈值
timeWindow int64 // 时间窗口大小
}
func initRedisClusterClient(json gjson.Result, config *ClusterKeyRateLimitConfig) error {
redisConfig := json.Get("redis")
if !redisConfig.Exists() {
return errors.New("missing redis in config")
}
serviceName := redisConfig.Get("service_name").String()
if serviceName == "" {
return errors.New("redis service name must not be empty")
}
servicePort := int(redisConfig.Get("service_port").Int())
if servicePort == 0 {
if strings.HasSuffix(serviceName, ".static") {
// use default logic port which is 80 for static service
servicePort = 80
} else {
servicePort = 6379
}
}
username := redisConfig.Get("username").String()
password := redisConfig.Get("password").String()
timeout := int(redisConfig.Get("timeout").Int())
if timeout == 0 {
timeout = 1000
}
config.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
FQDN: serviceName,
Port: int64(servicePort),
})
return config.redisClient.Init(username, password, int64(timeout))
}
func parseClusterKeyRateLimitConfig(json gjson.Result, config *ClusterKeyRateLimitConfig) error {
ruleName := json.Get("rule_name")
if !ruleName.Exists() {
return errors.New("missing rule_name in config")
}
config.ruleName = ruleName.String()
// 初始化ruleItems
err := initRuleItems(json, config)
if err != nil {
return err
}
rejectedCode := json.Get("rejected_code")
if rejectedCode.Exists() {
config.rejectedCode = uint32(rejectedCode.Uint())
} else {
config.rejectedCode = DefaultRejectedCode
}
rejectedMsg := json.Get("rejected_msg")
if rejectedCode.Exists() {
config.rejectedMsg = rejectedMsg.String()
} else {
config.rejectedMsg = DefaultRejectedMsg
}
return nil
}
func initRuleItems(json gjson.Result, config *ClusterKeyRateLimitConfig) error {
ruleItemsResult := json.Get("rule_items")
if !ruleItemsResult.Exists() {
return errors.New("missing rule_items in config")
}
if len(ruleItemsResult.Array()) == 0 {
return errors.New("config rule_items cannot be empty")
}
var ruleItems []LimitRuleItem
for _, item := range ruleItemsResult.Array() {
var ruleItem LimitRuleItem
// 根据配置区分限流类型
var limitType limitRuleItemType
setLimitByKeyIfExists := func(field gjson.Result, limitTypeStr limitRuleItemType) {
if field.Exists() && field.String() != "" {
ruleItem.key = field.String()
limitType = limitTypeStr
}
}
setLimitByKeyIfExists(item.Get("limit_by_header"), limitByHeaderType)
setLimitByKeyIfExists(item.Get("limit_by_param"), limitByParamType)
setLimitByKeyIfExists(item.Get("limit_by_cookie"), limitByCookieType)
setLimitByKeyIfExists(item.Get("limit_by_per_header"), limitByPerHeaderType)
setLimitByKeyIfExists(item.Get("limit_by_per_param"), limitByPerParamType)
setLimitByKeyIfExists(item.Get("limit_by_per_cookie"), limitByPerCookieType)
limitByConsumer := item.Get("limit_by_consumer")
if limitByConsumer.Exists() {
ruleItem.key = ConsumerHeader
limitType = limitByConsumerType
}
limitByPerConsumer := item.Get("limit_by_per_consumer")
if limitByPerConsumer.Exists() {
ruleItem.key = ConsumerHeader
limitType = limitByPerConsumerType
}
limitByPerIpResult := item.Get("limit_by_per_ip")
if limitByPerIpResult.Exists() && limitByPerIpResult.String() != "" {
limitByPerIp := limitByPerIpResult.String()
ruleItem.key = limitByPerIp
if strings.HasPrefix(limitByPerIp, "from-header-") {
headerName := limitByPerIp[len("from-header-"):]
if headerName == "" {
return errors.New("limit_by_per_ip parse error: empty after 'from-header-'")
}
ruleItem.limitByPerIp = LimitByPerIp{
sourceType: HeaderSourceType,
headerName: headerName,
}
} else if limitByPerIp == "from-remote-addr" {
ruleItem.limitByPerIp = LimitByPerIp{
sourceType: RemoteAddrSourceType,
headerName: "",
}
} else {
return errors.New("the 'limit_by_per_ip' restriction must start with 'from-header-' or be exactly 'from-remote-addr'")
}
limitType = limitByPerIpType
}
if limitType == "" {
return errors.New("only one of 'limit_by_header' and 'limit_by_param' and 'limit_by_consumer' and 'limit_by_cookie' and 'limit_by_per_header' and 'limit_by_per_param' and 'limit_by_per_consumer' and 'limit_by_per_cookie' and 'limit_by_per_ip' can be set")
}
ruleItem.limitType = limitType
// 初始化configItems
err := initConfigItems(item, &ruleItem)
if err != nil {
return err
}
ruleItems = append(ruleItems, ruleItem)
}
config.ruleItems = ruleItems
return nil
}
func initConfigItems(json gjson.Result, rule *LimitRuleItem) error {
limitKeys := json.Get("limit_keys")
if !limitKeys.Exists() {
return errors.New("missing limit_keys in config")
}
if len(limitKeys.Array()) == 0 {
return errors.New("config limit_keys cannot be empty")
}
var configItems []LimitConfigItem
for _, item := range limitKeys.Array() {
key := item.Get("key")
if !key.Exists() || key.String() == "" {
return errors.New("limit_keys key is required")
}
var (
itemKey = key.String()
itemType limitConfigItemType
ipNet *iptree.IPTree
regexp *re.Regexp
)
if rule.limitType == limitByPerIpType {
var err error
ipNet, err = parseIPNet(itemKey)
if err != nil {
return fmt.Errorf("failed to parse IPNet for key '%s': %w", itemKey, err)
}
itemType = ipNetType
} else if rule.limitType == limitByPerHeaderType ||
rule.limitType == limitByPerParamType ||
rule.limitType == limitByPerConsumerType ||
rule.limitType == limitByPerCookieType {
if itemKey == "*" {
itemType = allType
} else if strings.HasPrefix(itemKey, "regexp:") {
regexpStr := itemKey[len("regexp:"):]
var err error
regexp, err = re.Compile(regexpStr)
if err != nil {
return fmt.Errorf("failed to compile regex for key '%s': %w", itemKey, err)
}
itemType = regexpType
} else {
return fmt.Errorf("the '%s' restriction must start with 'regexp:' or be exactly '*'", rule.limitType)
}
} else {
itemType = exactType
}
if configItem, err := createConfigItemFromRate(item, itemType, itemKey, ipNet, regexp); err != nil {
return err
} else if configItem != nil {
configItems = append(configItems, *configItem)
}
}
rule.configItems = configItems
return nil
}
func createConfigItemFromRate(item gjson.Result, itemType limitConfigItemType, key string, ipNet *iptree.IPTree, regexp *re.Regexp) (*LimitConfigItem, error) {
for timeWindowKey, duration := range timeWindows {
q := item.Get(timeWindowKey)
if q.Exists() && q.Int() > 0 {
return &LimitConfigItem{
configType: itemType,
key: key,
ipNet: ipNet,
regexp: regexp,
count: q.Int(),
timeWindow: duration,
}, nil
}
}
return nil, errors.New("one of 'token_per_second', 'token_per_minute', 'token_per_hour', or 'token_per_day' must be set for key: " + key)
}

View File

@@ -0,0 +1,25 @@
module ai-token-ratelimit
go 1.18
require (
github.com/alibaba/higress/plugins/wasm-go v1.4.1-0.20240617024146-5f150179637c
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc
github.com/tidwall/gjson v1.14.3
github.com/wasilibs/go-re2 v1.5.3
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837
)
require (
github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 // indirect
github.com/tetratelabs/wazero v1.7.1 // indirect
)
require (
github.com/google/uuid v1.3.0 // indirect
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
github.com/magefile/mage v1.14.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/resp v0.1.1
)

View File

@@ -0,0 +1,31 @@
github.com/alibaba/higress/plugins/wasm-go v1.4.1-0.20240617024146-5f150179637c h1:wKCSg4rYfwkZrMk7tYY7navjgcHCMZjcgFrCsjLQBmg=
github.com/alibaba/higress/plugins/wasm-go v1.4.1-0.20240617024146-5f150179637c/go.mod h1:10jQXKsYFUF7djs+Oy7t82f4dbie9pISfP9FJwpPLuk=
github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 h1:Wi5Tgn8K+jDcBYL+dIMS1+qXYH2r7tpRAyBgqrWfQtw=
github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56/go.mod h1:8BhOLuqtSuT5NZtZMwfvEibi09RO3u79uqfHZzfDTR4=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tetratelabs/wazero v1.7.1 h1:QtSfd6KLc41DIMpDYlJdoMc6k7QTN246DM2+n2Y/Dx8=
github.com/tetratelabs/wazero v1.7.1/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/wasilibs/go-re2 v1.5.3 h1:wiuTcgDZdLhu8NG8oqF5sF5Q3yIU14lPAvXqeYzDK3g=
github.com/wasilibs/go-re2 v1.5.3/go.mod h1:PzpVPsBdFC7vM8QJbbEnOeTmwA0DGE783d/Gex8eCV8=
github.com/wasilibs/nottinygc v0.4.0 h1:h1TJMihMC4neN6Zq+WKpLxgd9xCFMw7O9ETLwY2exJQ=
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837 h1:DjHnADS2r2zynZ3WkCFAQ+PNYngMSNceRROi0pO6c3M=
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837/go.mod h1:9vp0bxqozzQwcjBwenEXfKVq8+mYbwHkQ1NF9Ap0DMw=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -0,0 +1,303 @@
// Copyright (c) 2024 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"fmt"
"net"
"net/url"
"strconv"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/resp"
)
func main() {
wrapper.SetCtx(
"ai-token-ratelimit",
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingBody),
)
}
const (
ClusterRateLimitFormat string = "higress-token-ratelimit:%s:%s:%s:%s"
RequestPhaseFixedWindowScript string = `
local ttl = redis.call('ttl', KEYS[1])
if ttl < 0 then
redis.call('set', KEYS[1], ARGV[1], 'EX', ARGV[2])
return {ARGV[1], ARGV[1], ARGV[2]}
end
return {ARGV[1], redis.call('get', KEYS[1]), ttl}
`
ResponsePhaseFixedWindowScript string = `
local ttl = redis.call('ttl', KEYS[1])
if ttl < 0 then
redis.call('set', KEYS[1], ARGV[1]-ARGV[3], 'EX', ARGV[2])
return {ARGV[1], ARGV[1]-ARGV[3], ARGV[2]}
end
return {ARGV[1], redis.call('decrby', KEYS[1], ARGV[3]), ttl}
`
LimitRedisContextKey string = "LimitRedisContext"
ConsumerHeader string = "x-mse-consumer" // LimitByConsumer从该request header获取consumer的名字
CookieHeader string = "cookie"
RateLimitLimitHeader string = "X-RateLimit-Limit" // 限制的总请求数
RateLimitRemainingHeader string = "X-RateLimit-Remaining" // 剩余还可以发送的请求数
RateLimitResetHeader string = "X-RateLimit-Reset" // 限流重置时间(触发限流时返回)
)
type LimitContext struct {
count int
remaining int
reset int
}
type LimitRedisContext struct {
key string
count int64
window int64
}
func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log wrapper.Log) error {
err := initRedisClusterClient(json, config)
if err != nil {
return err
}
err = parseClusterKeyRateLimitConfig(json, config)
if err != nil {
return err
}
return nil
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, log wrapper.Log) types.Action {
// 判断是否命中限流规则
val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems, log)
if ruleItem == nil || configItem == nil {
return types.ActionContinue
}
// 构建redis限流key和参数
limitKey := fmt.Sprintf(ClusterRateLimitFormat, config.ruleName, ruleItem.limitType, ruleItem.key, val)
keys := []interface{}{limitKey}
args := []interface{}{configItem.count, configItem.timeWindow}
limitRedisContext := LimitRedisContext{
key: limitKey,
count: configItem.count,
window: configItem.timeWindow,
}
ctx.SetContext(LimitRedisContextKey, limitRedisContext)
// 执行限流逻辑
err := config.redisClient.Eval(RequestPhaseFixedWindowScript, 1, keys, args, func(response resp.Value) {
resultArray := response.Array()
if len(resultArray) != 3 {
log.Errorf("redis response parse error, response: %v", response)
return
}
context := LimitContext{
count: resultArray[0].Integer(),
remaining: resultArray[1].Integer(),
reset: resultArray[2].Integer(),
}
if context.remaining < 0 {
// 触发限流
rejected(config, context)
} else {
proxywasm.ResumeHttpRequest()
}
})
if err != nil {
log.Errorf("redis call failed: %v", err)
return types.ActionContinue
}
return types.ActionPause
}
func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool, log wrapper.Log) []byte {
if !endOfStream {
return data
}
inputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.input_token"})
if err != nil {
return data
}
outputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.output_token"})
if err != nil {
return data
}
inputToken, err := strconv.Atoi(string(inputTokenStr))
if err != nil {
return data
}
outputToken, err := strconv.Atoi(string(outputTokenStr))
if err != nil {
return data
}
limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext)
if !ok {
return data
}
keys := []interface{}{limitRedisContext.key}
args := []interface{}{limitRedisContext.count, limitRedisContext.window, inputToken + outputToken}
err = config.redisClient.Eval(ResponsePhaseFixedWindowScript, 1, keys, args, func(response resp.Value) {
if response.Error() != nil {
log.Errorf("call Eval error: %v", response.Error())
}
proxywasm.ResumeHttpResponse()
})
if err != nil {
log.Errorf("redis call failed: %v", err)
return data
} else {
return data
}
}
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) {
for _, rule := range ruleItems {
val, ruleItem, configItem := hitRateRuleItem(ctx, rule, log)
if ruleItem != nil && configItem != nil {
return val, ruleItem, configItem
}
}
return "", nil, nil
}
func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) {
switch rule.limitType {
// 根据HTTP请求头限流
case limitByHeaderType, limitByPerHeaderType:
val, err := proxywasm.GetHttpRequestHeader(rule.key)
if err != nil {
return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", rule.key, err)
}
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
// 根据HTTP请求参数限流
case limitByParamType, limitByPerParamType:
parse, err := url.Parse(ctx.Path())
if err != nil {
return logDebugAndReturnEmpty(log, "failed to parse request path: %v", err)
}
query, err := url.ParseQuery(parse.RawQuery)
if err != nil {
return logDebugAndReturnEmpty(log, "failed to parse query params: %v", err)
}
val, ok := query[rule.key]
if !ok {
return logDebugAndReturnEmpty(log, "request param %s is empty", rule.key)
}
return val[0], &rule, findMatchingItem(rule.limitType, rule.configItems, val[0])
// 根据consumer限流
case limitByConsumerType, limitByPerConsumerType:
val, err := proxywasm.GetHttpRequestHeader(ConsumerHeader)
if err != nil {
return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", ConsumerHeader, err)
}
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
// 根据cookie中key值限流
case limitByCookieType, limitByPerCookieType:
cookie, err := proxywasm.GetHttpRequestHeader(CookieHeader)
if err != nil {
return logDebugAndReturnEmpty(log, "failed to get request cookie : %v", err)
}
val := extractCookieValueByKey(cookie, rule.key)
if val == "" {
return logDebugAndReturnEmpty(log, "cookie key '%s' extracted from cookie '%s' is empty.", rule.key, cookie)
}
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
// 根据客户端IP限流
case limitByPerIpType:
realIp, err := getDownStreamIp(rule)
if err != nil {
log.Warnf("failed to get down stream ip: %v", err)
return "", &rule, nil
}
for _, item := range rule.configItems {
if _, found, _ := item.ipNet.Get(realIp); !found {
continue
}
return realIp.String(), &rule, &item
}
}
return "", nil, nil
}
func logDebugAndReturnEmpty(log wrapper.Log, errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) {
log.Debugf(errMsg, args...)
return "", nil, nil
}
func findMatchingItem(limitType limitRuleItemType, items []LimitConfigItem, key string) *LimitConfigItem {
for _, item := range items {
// per类型,检查allType和regexpType
if limitType == limitByPerHeaderType ||
limitType == limitByPerParamType ||
limitType == limitByPerConsumerType ||
limitType == limitByPerCookieType {
if item.configType == allType || (item.configType == regexpType && item.regexp.MatchString(key)) {
return &item
}
}
// 其他类型,直接比较key
if item.key == key {
return &item
}
}
return nil
}
func getDownStreamIp(rule LimitRuleItem) (net.IP, error) {
var (
realIpStr string
err error
)
if rule.limitByPerIp.sourceType == HeaderSourceType {
realIpStr, err = proxywasm.GetHttpRequestHeader(rule.limitByPerIp.headerName)
if err == nil {
realIpStr = strings.Split(strings.Trim(realIpStr, " "), ",")[0]
}
} else {
var bs []byte
bs, err = proxywasm.GetProperty([]string{"source", "address"})
realIpStr = string(bs)
}
if err != nil {
return nil, err
}
ip := parseIP(realIpStr)
realIP := net.ParseIP(ip)
if realIP == nil {
return nil, fmt.Errorf("invalid ip[%s]", ip)
}
return realIP, nil
}
func rejected(config ClusterKeyRateLimitConfig, context LimitContext) {
headers := make(map[string][]string)
headers[RateLimitResetHeader] = []string{strconv.Itoa(context.reset)}
_ = proxywasm.SendHttpResponse(
config.rejectedCode, reconvertHeaders(headers), []byte(config.rejectedMsg), -1)
}

View File

@@ -0,0 +1,60 @@
package main
import (
"fmt"
"sort"
"strings"
"github.com/zmap/go-iptree/iptree"
)
// parseIPNet 解析Ip段配置
func parseIPNet(key string) (*iptree.IPTree, error) {
tree := iptree.New()
err := tree.AddByString(key, 0)
if err != nil {
return nil, fmt.Errorf("invalid IP[%s]", key)
}
return tree, nil
}
// parseIP 解析IP
func parseIP(source string) string {
if strings.Contains(source, ".") {
// parse ipv4
return strings.Split(source, ":")[0]
}
// parse ipv6
if strings.Contains(source, "]") {
return strings.Split(source, "]")[0][1:]
}
return source
}
// reconvertHeaders headers: map[string][]string -> [][2]string
func reconvertHeaders(hs map[string][]string) [][2]string {
var ret [][2]string
for k, vs := range hs {
for _, v := range vs {
ret = append(ret, [2]string{k, v})
}
}
sort.SliceStable(ret, func(i, j int) bool {
return ret[i][0] < ret[j][0]
})
return ret
}
// extractCookieValueByKey 从cookie中提取key对应的value
func extractCookieValueByKey(cookie string, key string) (value string) {
pairs := strings.Split(cookie, ";")
for _, pair := range pairs {
pair = strings.TrimSpace(pair)
kv := strings.Split(pair, "=")
if kv[0] == key {
value = kv[1]
break
}
}
return value
}