From 1ea87f0e7a61e5abf54b9e218da39dfca2fa412d Mon Sep 17 00:00:00 2001 From: rinfx <893383980@qq.com> Date: Wed, 19 Jun 2024 13:46:59 +0800 Subject: [PATCH] add plugin: ai-token-ratelimit (#1015) --- .../extensions/ai-token-ratelimit/.gitignore | 2 + .../extensions/ai-token-ratelimit/README.md | 186 +++++++++++ .../extensions/ai-token-ratelimit/config.go | 297 +++++++++++++++++ .../extensions/ai-token-ratelimit/go.mod | 25 ++ .../extensions/ai-token-ratelimit/go.sum | 31 ++ .../extensions/ai-token-ratelimit/main.go | 303 ++++++++++++++++++ .../extensions/ai-token-ratelimit/utils.go | 60 ++++ 7 files changed, 904 insertions(+) create mode 100644 plugins/wasm-go/extensions/ai-token-ratelimit/.gitignore create mode 100644 plugins/wasm-go/extensions/ai-token-ratelimit/README.md create mode 100644 plugins/wasm-go/extensions/ai-token-ratelimit/config.go create mode 100644 plugins/wasm-go/extensions/ai-token-ratelimit/go.mod create mode 100644 plugins/wasm-go/extensions/ai-token-ratelimit/go.sum create mode 100644 plugins/wasm-go/extensions/ai-token-ratelimit/main.go create mode 100644 plugins/wasm-go/extensions/ai-token-ratelimit/utils.go diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/.gitignore b/plugins/wasm-go/extensions/ai-token-ratelimit/.gitignore new file mode 100644 index 000000000..32841a534 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/.gitignore @@ -0,0 +1,2 @@ +main.wasm +config.yaml \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/README.md b/plugins/wasm-go/extensions/ai-token-ratelimit/README.md new file mode 100644 index 000000000..740191454 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/README.md @@ -0,0 +1,186 @@ +# 功能说明 + +`ai-token-ratelimit`插件基于特定键值实现 token 限流,键值来源可以是 URL 参数、HTTP 请求头、客户端 IP 地址、consumer 名称、cookie中 key 名称。 + + + +# 配置说明 + +| 配置项 | 类型 | 必填 | 默认值 | 说明 | +| ----------------------- | ------ | ---- | ------
|---------------------------------------------------------------------------| +| rule_name | string | 是 | - | 限流规则名称,根据限流规则名称+限流类型+限流key名称+限流key对应的实际值来拼装redis key | +| rule_items | array of object | 是 | - | 限流规则项,按照rule_items下的排列顺序,匹配第一个rule_item后命中限流规则,后续规则将被忽略 | +| rejected_code | int | 否 | 429 | 请求被限流时,返回的HTTP状态码 | +| rejected_msg | string | 否 | Too many requests | 请求被限流时,返回的响应体 | +| redis | object | 是 | - | redis相关配置 | + +`rule_items`中每一项的配置字段说明 + +| 配置项 | 类型 | 必填 | 默认值 | 说明 | +| --------------------- | --------------- | -------------------------- | ------ | ------------------------------------------------------------ | +| limit_by_header | string | 否,`limit_by_*`中选填一项 | - | 配置获取限流键值的来源 HTTP 请求头名称 | +| limit_by_param | string | 否,`limit_by_*`中选填一项 | - | 配置获取限流键值的来源 URL 参数名称 | +| limit_by_consumer | string | 否,`limit_by_*`中选填一项 | - | 根据 consumer 名称进行限流,无需添加实际值 | +| limit_by_cookie | string | 否,`limit_by_*`中选填一项 | - | 配置获取限流键值的来源 Cookie中 key 名称 | +| limit_by_per_header | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 HTTP 请求头,并对每个请求头分别计算限流,配置获取限流键值的来源 HTTP 请求头名称,配置`limit_keys`时支持正则表达式或`*` | +| limit_by_per_param | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 URL 参数,并对每个参数分别计算限流,配置获取限流键值的来源 URL 参数名称,配置`limit_keys`时支持正则表达式或`*` | +| limit_by_per_consumer | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 consumer,并对每个 consumer 分别计算限流,根据 consumer 名称进行限流,无需添加实际值,配置`limit_keys`时支持正则表达式或`*` | +| limit_by_per_cookie | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 Cookie,并对每个 Cookie 分别计算限流,配置获取限流键值的来源 Cookie中 key 名称,配置`limit_keys`时支持正则表达式或`*` | +| limit_by_per_ip | string | 否,`limit_by_*`中选填一项 | - | 按规则匹配特定 IP,并对每个 IP 分别计算限流,配置获取限流键值的来源 IP 参数名称,从请求头获取,以`from-header-对应的header名`,示例:`from-header-x-forwarded-for`,直接获取对端socket ip,配置为`from-remote-addr` | +| limit_keys | array of object | 是 | - | 配置匹配键值后的限流次数 | + +`limit_keys`中每一项的配置字段说明 + +| 配置项 | 类型 | 必填 | 默认值 | 说明 | +| ---------------- | ------ | ------------------------------------------------------------ | ------ | 
------------------------------------------------------------ | +| key | string | 是 | - | 匹配的键值,`limit_by_per_header`,`limit_by_per_param`,`limit_by_per_consumer`,`limit_by_per_cookie` 类型支持配置正则表达式(以regexp:开头后面跟正则表达式)或者*(代表所有),正则表达式示例:`regexp:^d.*`(以d开头的所有字符串);`limit_by_per_ip`支持配置 IP 地址或 IP 段 | +| token_per_second | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每秒请求token数 | +| token_per_minute | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每分钟请求token数 | +| token_per_hour | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每小时请求token数 | +| token_per_day | int | 否,`token_per_second`,`token_per_minute`,`token_per_hour`,`token_per_day` 中选填一项 | - | 允许每天请求token数 | + +`redis`中每一项的配置字段说明 + +| 配置项 | 类型 | 必填 | 默认值 | 说明 | +| ------------ | ------ | ---- | ---------------------------------------------------------- | --------------------------- | +| service_name | string | 必填 | - | redis 服务名称,带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local | +| service_port | int | 否 | 服务类型为固定地址(static service)默认值为80,其他为6379 | 输入redis服务的服务端口 | +| username | string | 否 | - | redis用户名 | +| password | string | 否 | - | redis密码 | +| timeout | int | 否 | 1000 | redis连接超时时间,单位毫秒 | + + + +# 配置示例 + +## 识别请求参数 apikey,进行区别限流 + +```yaml +rule_name: default_rule +rule_items: + - limit_by_param: apikey + limit_keys: + - key: 9a342114-ba8a-11ec-b1bf-00163e1250b5 + token_per_minute: 10 + - key: a6a6d7f2-ba8a-11ec-bec2-00163e1250b5 + token_per_hour: 100 + - limit_by_per_param: apikey + limit_keys: + # 正则表达式,匹配以a开头的所有字符串,每个apikey每秒限流10个token + - key: "regexp:^a.*" + token_per_second: 10 + # 正则表达式,匹配以b开头的所有字符串,每个apikey每分钟限流100个token + - key: "regexp:^b.*" + token_per_minute: 100 + # 兜底用,匹配所有请求,每个apikey每小时限流1000个token + - key: "*" + token_per_hour: 1000 +redis: + service_name: redis.static +``` + + + +## 识别请求头 x-ca-key,进行区别限流 + +```yaml +rule_name: default_rule +rule_items:
+ - limit_by_header: x-ca-key + limit_keys: + - key: 102234 + token_per_minute: 10 + - key: 308239 + token_per_hour: 10 + - limit_by_per_header: x-ca-key + limit_keys: + # 正则表达式,匹配以a开头的所有字符串,每个x-ca-key每秒限流10个token + - key: "regexp:^a.*" + token_per_second: 10 + # 正则表达式,匹配以b开头的所有字符串,每个x-ca-key每分钟限流100个token + - key: "regexp:^b.*" + token_per_minute: 100 + # 兜底用,匹配所有请求,每个x-ca-key每小时限流1000个token + - key: "*" + token_per_hour: 1000 +redis: + service_name: redis.static +``` + + + +## 根据请求头 x-forwarded-for 获取对端IP,进行区别限流 + +```yaml +rule_name: default_rule +rule_items: + - limit_by_per_ip: from-header-x-forwarded-for + limit_keys: + # 精确ip + - key: 1.1.1.1 + token_per_day: 10 + # ip段,符合这个ip段的ip,每个ip每天限流100个token + - key: 1.1.1.0/24 + token_per_day: 100 + # 兜底用,即默认每个ip每天限流1000个token + - key: 0.0.0.0/0 + token_per_day: 1000 +redis: + service_name: redis.static +``` + +## 识别consumer,进行区别限流 + +```yaml +rule_name: default_rule +rule_items: + - limit_by_consumer: '' + limit_keys: + - key: consumer1 + token_per_second: 10 + - key: consumer2 + token_per_hour: 100 + - limit_by_per_consumer: '' + limit_keys: + # 正则表达式,匹配以a开头的所有字符串,每个consumer每秒限流10个token + - key: "regexp:^a.*" + token_per_second: 10 + # 正则表达式,匹配以b开头的所有字符串,每个consumer每分钟限流100个token + - key: "regexp:^b.*" + token_per_minute: 100 + # 兜底用,匹配所有请求,每个consumer每小时限流1000个token + - key: "*" + token_per_hour: 1000 +redis: + service_name: redis.static +``` + + + +## 识别cookie中的键值对,进行区别限流 + +```yaml +rule_name: default_rule +rule_items: + - limit_by_cookie: key1 + limit_keys: + - key: value1 + token_per_minute: 10 + - key: value2 + token_per_hour: 100 + - limit_by_per_cookie: key1 + limit_keys: + # 正则表达式,匹配以a开头的所有字符串,每个cookie中的value每秒限流10个token + - key: "regexp:^a.*" + token_per_second: 10 + # 正则表达式,匹配以b开头的所有字符串,每个cookie中的value每分钟限流100个token + - key: "regexp:^b.*" + token_per_minute: 100 + # 兜底用,匹配所有请求,每个cookie中的value每小时限流1000个token + - key: "*" + token_per_hour: 1000 +rejected_code: 200 +rejected_msg: '{"code":-1,"msg":"Too many requests"}' +redis: + service_name:
redis.static +``` \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/config.go b/plugins/wasm-go/extensions/ai-token-ratelimit/config.go new file mode 100644 index 000000000..9668f1861 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/config.go @@ -0,0 +1,297 @@ +package main + +import ( + "errors" + "fmt" + "strings" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" + re "github.com/wasilibs/go-re2" + "github.com/zmap/go-iptree/iptree" +) + +// 限流规则项类型 +type limitRuleItemType string + +// 限流配置项key类型 +type limitConfigItemType string + +const ( + limitByHeaderType limitRuleItemType = "limit_by_header" + limitByParamType limitRuleItemType = "limit_by_param" + limitByConsumerType limitRuleItemType = "limit_by_consumer" + limitByCookieType limitRuleItemType = "limit_by_cookie" + limitByPerHeaderType limitRuleItemType = "limit_by_per_header" + limitByPerParamType limitRuleItemType = "limit_by_per_param" + limitByPerConsumerType limitRuleItemType = "limit_by_per_consumer" + limitByPerCookieType limitRuleItemType = "limit_by_per_cookie" + limitByPerIpType limitRuleItemType = "limit_by_per_ip" + + exactType limitConfigItemType = "exact" // 精确匹配 + regexpType limitConfigItemType = "regexp" // 正则表达式 + allType limitConfigItemType = "*" // 匹配所有情况 + ipNetType limitConfigItemType = "ipNet" // ip段 + + RemoteAddrSourceType = "remote-addr" + HeaderSourceType = "header" + + DefaultRejectedCode uint32 = 429 + DefaultRejectedMsg string = "Too many requests" + + Second int64 = 1 + SecondsPerMinute = 60 * Second + SecondsPerHour = 60 * SecondsPerMinute + SecondsPerDay = 24 * SecondsPerHour +) + +var timeWindows = map[string]int64{ + "token_per_second": Second, + "token_per_minute": SecondsPerMinute, + "token_per_hour": SecondsPerHour, + "token_per_day": SecondsPerDay, +} + +type ClusterKeyRateLimitConfig struct { + ruleName string // 限流规则名称 + ruleItems []LimitRuleItem // 限流规则项 + 
showLimitQuotaHeader bool // 响应头中是否显示X-RateLimit-Limit和X-RateLimit-Remaining + rejectedCode uint32 // 当请求超过阈值被拒绝时,返回的HTTP状态码 + rejectedMsg string // 当请求超过阈值被拒绝时,返回的响应体 + redisClient wrapper.RedisClient +} + +type LimitRuleItem struct { + limitType limitRuleItemType // 限流类型 + key string // 根据该key值进行限流,limit_by_consumer和limit_by_per_consumer两种类型为ConsumerHeader,其他类型为对应的key值 + limitByPerIp LimitByPerIp // 对端ip地址或ip段 + configItems []LimitConfigItem // 限流配置项 +} + +type LimitByPerIp struct { + sourceType string // ip来源类型 + headerName string // 根据该请求头获取客户端ip +} + +type LimitConfigItem struct { + configType limitConfigItemType // 限流配置项key类型 + key string // 限流key + ipNet *iptree.IPTree // 限流key转换的ip地址或者ip段,仅用于itemType为ipNetType + regexp *re.Regexp // 正则表达式,仅用于itemType为regexpType + count int64 // 指定时间窗口内的总请求数量阈值 + timeWindow int64 // 时间窗口大小 +} + +func initRedisClusterClient(json gjson.Result, config *ClusterKeyRateLimitConfig) error { + redisConfig := json.Get("redis") + if !redisConfig.Exists() { + return errors.New("missing redis in config") + } + serviceName := redisConfig.Get("service_name").String() + if serviceName == "" { + return errors.New("redis service name must not be empty") + } + servicePort := int(redisConfig.Get("service_port").Int()) + if servicePort == 0 { + if strings.HasSuffix(serviceName, ".static") { + // use default logic port which is 80 for static service + servicePort = 80 + } else { + servicePort = 6379 + } + } + username := redisConfig.Get("username").String() + password := redisConfig.Get("password").String() + timeout := int(redisConfig.Get("timeout").Int()) + if timeout == 0 { + timeout = 1000 + } + config.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ + FQDN: serviceName, + Port: int64(servicePort), + }) + return config.redisClient.Init(username, password, int64(timeout)) +} + +func parseClusterKeyRateLimitConfig(json gjson.Result, config *ClusterKeyRateLimitConfig) error { + ruleName := json.Get("rule_name") + if 
!ruleName.Exists() { + return errors.New("missing rule_name in config") + } + config.ruleName = ruleName.String() + + // 初始化ruleItems + err := initRuleItems(json, config) + if err != nil { + return err + } + + rejectedCode := json.Get("rejected_code") + if rejectedCode.Exists() { + config.rejectedCode = uint32(rejectedCode.Uint()) + } else { + config.rejectedCode = DefaultRejectedCode + } + rejectedMsg := json.Get("rejected_msg") + if rejectedCode.Exists() { + config.rejectedMsg = rejectedMsg.String() + } else { + config.rejectedMsg = DefaultRejectedMsg + } + return nil +} + +func initRuleItems(json gjson.Result, config *ClusterKeyRateLimitConfig) error { + ruleItemsResult := json.Get("rule_items") + if !ruleItemsResult.Exists() { + return errors.New("missing rule_items in config") + } + if len(ruleItemsResult.Array()) == 0 { + return errors.New("config rule_items cannot be empty") + } + var ruleItems []LimitRuleItem + for _, item := range ruleItemsResult.Array() { + var ruleItem LimitRuleItem + + // 根据配置区分限流类型 + var limitType limitRuleItemType + setLimitByKeyIfExists := func(field gjson.Result, limitTypeStr limitRuleItemType) { + if field.Exists() && field.String() != "" { + ruleItem.key = field.String() + limitType = limitTypeStr + } + } + setLimitByKeyIfExists(item.Get("limit_by_header"), limitByHeaderType) + setLimitByKeyIfExists(item.Get("limit_by_param"), limitByParamType) + setLimitByKeyIfExists(item.Get("limit_by_cookie"), limitByCookieType) + setLimitByKeyIfExists(item.Get("limit_by_per_header"), limitByPerHeaderType) + setLimitByKeyIfExists(item.Get("limit_by_per_param"), limitByPerParamType) + setLimitByKeyIfExists(item.Get("limit_by_per_cookie"), limitByPerCookieType) + + limitByConsumer := item.Get("limit_by_consumer") + if limitByConsumer.Exists() { + ruleItem.key = ConsumerHeader + limitType = limitByConsumerType + } + limitByPerConsumer := item.Get("limit_by_per_consumer") + if limitByPerConsumer.Exists() { + ruleItem.key = ConsumerHeader + 
limitType = limitByPerConsumerType + } + + limitByPerIpResult := item.Get("limit_by_per_ip") + if limitByPerIpResult.Exists() && limitByPerIpResult.String() != "" { + limitByPerIp := limitByPerIpResult.String() + ruleItem.key = limitByPerIp + if strings.HasPrefix(limitByPerIp, "from-header-") { + headerName := limitByPerIp[len("from-header-"):] + if headerName == "" { + return errors.New("limit_by_per_ip parse error: empty after 'from-header-'") + } + ruleItem.limitByPerIp = LimitByPerIp{ + sourceType: HeaderSourceType, + headerName: headerName, + } + } else if limitByPerIp == "from-remote-addr" { + ruleItem.limitByPerIp = LimitByPerIp{ + sourceType: RemoteAddrSourceType, + headerName: "", + } + } else { + return errors.New("the 'limit_by_per_ip' restriction must start with 'from-header-' or be exactly 'from-remote-addr'") + } + limitType = limitByPerIpType + } + + if limitType == "" { + return errors.New("only one of 'limit_by_header' and 'limit_by_param' and 'limit_by_consumer' and 'limit_by_cookie' and 'limit_by_per_header' and 'limit_by_per_param' and 'limit_by_per_consumer' and 'limit_by_per_cookie' and 'limit_by_per_ip' can be set") + } + ruleItem.limitType = limitType + + // 初始化configItems + err := initConfigItems(item, &ruleItem) + if err != nil { + return err + } + + ruleItems = append(ruleItems, ruleItem) + } + config.ruleItems = ruleItems + return nil +} + +func initConfigItems(json gjson.Result, rule *LimitRuleItem) error { + limitKeys := json.Get("limit_keys") + if !limitKeys.Exists() { + return errors.New("missing limit_keys in config") + } + if len(limitKeys.Array()) == 0 { + return errors.New("config limit_keys cannot be empty") + } + var configItems []LimitConfigItem + for _, item := range limitKeys.Array() { + key := item.Get("key") + if !key.Exists() || key.String() == "" { + return errors.New("limit_keys key is required") + } + + var ( + itemKey = key.String() + itemType limitConfigItemType + ipNet *iptree.IPTree + regexp *re.Regexp + ) + if 
rule.limitType == limitByPerIpType { + var err error + ipNet, err = parseIPNet(itemKey) + if err != nil { + return fmt.Errorf("failed to parse IPNet for key '%s': %w", itemKey, err) + } + itemType = ipNetType + } else if rule.limitType == limitByPerHeaderType || + rule.limitType == limitByPerParamType || + rule.limitType == limitByPerConsumerType || + rule.limitType == limitByPerCookieType { + if itemKey == "*" { + itemType = allType + } else if strings.HasPrefix(itemKey, "regexp:") { + regexpStr := itemKey[len("regexp:"):] + var err error + regexp, err = re.Compile(regexpStr) + if err != nil { + return fmt.Errorf("failed to compile regex for key '%s': %w", itemKey, err) + } + itemType = regexpType + } else { + return fmt.Errorf("the '%s' restriction must start with 'regexp:' or be exactly '*'", rule.limitType) + } + } else { + itemType = exactType + } + + if configItem, err := createConfigItemFromRate(item, itemType, itemKey, ipNet, regexp); err != nil { + return err + } else if configItem != nil { + configItems = append(configItems, *configItem) + } + } + rule.configItems = configItems + return nil +} + +func createConfigItemFromRate(item gjson.Result, itemType limitConfigItemType, key string, ipNet *iptree.IPTree, regexp *re.Regexp) (*LimitConfigItem, error) { + for timeWindowKey, duration := range timeWindows { + q := item.Get(timeWindowKey) + if q.Exists() && q.Int() > 0 { + return &LimitConfigItem{ + configType: itemType, + key: key, + ipNet: ipNet, + regexp: regexp, + count: q.Int(), + timeWindow: duration, + }, nil + } + } + return nil, errors.New("one of 'token_per_second', 'token_per_minute', 'token_per_hour', or 'token_per_day' must be set for key: " + key) +} diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/go.mod b/plugins/wasm-go/extensions/ai-token-ratelimit/go.mod new file mode 100644 index 000000000..74a576d48 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/go.mod @@ -0,0 +1,25 @@ +module ai-token-ratelimit + +go 1.18 + 
+require ( + github.com/alibaba/higress/plugins/wasm-go v1.4.1-0.20240617024146-5f150179637c + github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc + github.com/tidwall/gjson v1.14.3 + github.com/wasilibs/go-re2 v1.5.3 + github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837 +) + +require ( + github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 // indirect + github.com/tetratelabs/wazero v1.7.1 // indirect +) + +require ( + github.com/google/uuid v1.3.0 // indirect + github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect + github.com/magefile/mage v1.14.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/resp v0.1.1 +) diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum b/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum new file mode 100644 index 000000000..5800e0dc8 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum @@ -0,0 +1,31 @@ +github.com/alibaba/higress/plugins/wasm-go v1.4.1-0.20240617024146-5f150179637c h1:wKCSg4rYfwkZrMk7tYY7navjgcHCMZjcgFrCsjLQBmg= +github.com/alibaba/higress/plugins/wasm-go v1.4.1-0.20240617024146-5f150179637c/go.mod h1:10jQXKsYFUF7djs+Oy7t82f4dbie9pISfP9FJwpPLuk= +github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 h1:Wi5Tgn8K+jDcBYL+dIMS1+qXYH2r7tpRAyBgqrWfQtw= +github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56/go.mod h1:8BhOLuqtSuT5NZtZMwfvEibi09RO3u79uqfHZzfDTR4= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= +github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod 
h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= +github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/tetratelabs/wazero v1.7.1 h1:QtSfd6KLc41DIMpDYlJdoMc6k7QTN246DM2+n2Y/Dx8= +github.com/tetratelabs/wazero v1.7.1/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y= +github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= +github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= +github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= +github.com/wasilibs/go-re2 v1.5.3 h1:wiuTcgDZdLhu8NG8oqF5sF5Q3yIU14lPAvXqeYzDK3g= +github.com/wasilibs/go-re2 v1.5.3/go.mod h1:PzpVPsBdFC7vM8QJbbEnOeTmwA0DGE783d/Gex8eCV8= +github.com/wasilibs/nottinygc v0.4.0 h1:h1TJMihMC4neN6Zq+WKpLxgd9xCFMw7O9ETLwY2exJQ= +github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837 h1:DjHnADS2r2zynZ3WkCFAQ+PNYngMSNceRROi0pO6c3M= +github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837/go.mod h1:9vp0bxqozzQwcjBwenEXfKVq8+mYbwHkQ1NF9Ap0DMw= 
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/main.go b/plugins/wasm-go/extensions/ai-token-ratelimit/main.go new file mode 100644 index 000000000..a74019f63 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/main.go @@ -0,0 +1,303 @@ +// Copyright (c) 2024 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "fmt" + "net" + "net/url" + "strconv" + "strings" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/tidwall/gjson" + "github.com/tidwall/resp" +) + +func main() { + wrapper.SetCtx( + "ai-token-ratelimit", + wrapper.ParseConfigBy(parseConfig), + wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), + wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingBody), + ) +} + +const ( + ClusterRateLimitFormat string = "higress-token-ratelimit:%s:%s:%s:%s" + RequestPhaseFixedWindowScript string = ` + local ttl = redis.call('ttl', KEYS[1]) + if ttl < 0 then + redis.call('set', KEYS[1], ARGV[1], 'EX', ARGV[2]) + return {ARGV[1], ARGV[1], ARGV[2]} + end + return {ARGV[1], redis.call('get', KEYS[1]), ttl} + ` + ResponsePhaseFixedWindowScript string = ` + local ttl = redis.call('ttl', KEYS[1]) + if ttl < 0 then + redis.call('set', KEYS[1], ARGV[1]-ARGV[3], 'EX', ARGV[2]) + 
return {ARGV[1], ARGV[1]-ARGV[3], ARGV[2]} + end + return {ARGV[1], redis.call('decrby', KEYS[1], ARGV[3]), ttl} + ` + + LimitRedisContextKey string = "LimitRedisContext" + + ConsumerHeader string = "x-mse-consumer" // LimitByConsumer从该request header获取consumer的名字 + CookieHeader string = "cookie" + + RateLimitLimitHeader string = "X-RateLimit-Limit" // 限制的总请求数 + RateLimitRemainingHeader string = "X-RateLimit-Remaining" // 剩余还可以发送的请求数 + RateLimitResetHeader string = "X-RateLimit-Reset" // 限流重置时间(触发限流时返回) +) + +type LimitContext struct { + count int + remaining int + reset int +} + +type LimitRedisContext struct { + key string + count int64 + window int64 +} + +func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log wrapper.Log) error { + err := initRedisClusterClient(json, config) + if err != nil { + return err + } + err = parseClusterKeyRateLimitConfig(json, config) + if err != nil { + return err + } + return nil +} + +func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, log wrapper.Log) types.Action { + // 判断是否命中限流规则 + val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems, log) + if ruleItem == nil || configItem == nil { + return types.ActionContinue + } + + // 构建redis限流key和参数 + limitKey := fmt.Sprintf(ClusterRateLimitFormat, config.ruleName, ruleItem.limitType, ruleItem.key, val) + keys := []interface{}{limitKey} + args := []interface{}{configItem.count, configItem.timeWindow} + + limitRedisContext := LimitRedisContext{ + key: limitKey, + count: configItem.count, + window: configItem.timeWindow, + } + ctx.SetContext(LimitRedisContextKey, limitRedisContext) + + // 执行限流逻辑 + err := config.redisClient.Eval(RequestPhaseFixedWindowScript, 1, keys, args, func(response resp.Value) { + resultArray := response.Array() + if len(resultArray) != 3 { + log.Errorf("redis response parse error, response: %v", response) + return + } + context := LimitContext{ + count: resultArray[0].Integer(), + remaining: 
resultArray[1].Integer(), + reset: resultArray[2].Integer(), + } + if context.remaining < 0 { + // 触发限流 + rejected(config, context) + } else { + proxywasm.ResumeHttpRequest() + } + }) + if err != nil { + log.Errorf("redis call failed: %v", err) + return types.ActionContinue + } + return types.ActionPause +} + +func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool, log wrapper.Log) []byte { + if !endOfStream { + return data + } + inputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.input_token"}) + if err != nil { + return data + } + outputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.output_token"}) + if err != nil { + return data + } + inputToken, err := strconv.Atoi(string(inputTokenStr)) + if err != nil { + return data + } + outputToken, err := strconv.Atoi(string(outputTokenStr)) + if err != nil { + return data + } + limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext) + if !ok { + return data + } + keys := []interface{}{limitRedisContext.key} + args := []interface{}{limitRedisContext.count, limitRedisContext.window, inputToken + outputToken} + + err = config.redisClient.Eval(ResponsePhaseFixedWindowScript, 1, keys, args, func(response resp.Value) { + if response.Error() != nil { + log.Errorf("call Eval error: %v", response.Error()) + } + proxywasm.ResumeHttpResponse() + }) + if err != nil { + log.Errorf("redis call failed: %v", err) + return data + } else { + return data + } +} + +func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) { + for _, rule := range ruleItems { + val, ruleItem, configItem := hitRateRuleItem(ctx, rule, log) + if ruleItem != nil && configItem != nil { + return val, ruleItem, configItem + } + } + return "", nil, nil +} + +func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log wrapper.Log) (string, 
*LimitRuleItem, *LimitConfigItem) { + switch rule.limitType { + // 根据HTTP请求头限流 + case limitByHeaderType, limitByPerHeaderType: + val, err := proxywasm.GetHttpRequestHeader(rule.key) + if err != nil { + return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", rule.key, err) + } + return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val) + // 根据HTTP请求参数限流 + case limitByParamType, limitByPerParamType: + parse, err := url.Parse(ctx.Path()) + if err != nil { + return logDebugAndReturnEmpty(log, "failed to parse request path: %v", err) + } + query, err := url.ParseQuery(parse.RawQuery) + if err != nil { + return logDebugAndReturnEmpty(log, "failed to parse query params: %v", err) + } + val, ok := query[rule.key] + if !ok { + return logDebugAndReturnEmpty(log, "request param %s is empty", rule.key) + } + return val[0], &rule, findMatchingItem(rule.limitType, rule.configItems, val[0]) + // 根据consumer限流 + case limitByConsumerType, limitByPerConsumerType: + val, err := proxywasm.GetHttpRequestHeader(ConsumerHeader) + if err != nil { + return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", ConsumerHeader, err) + } + return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val) + // 根据cookie中key值限流 + case limitByCookieType, limitByPerCookieType: + cookie, err := proxywasm.GetHttpRequestHeader(CookieHeader) + if err != nil { + return logDebugAndReturnEmpty(log, "failed to get request cookie : %v", err) + } + val := extractCookieValueByKey(cookie, rule.key) + if val == "" { + return logDebugAndReturnEmpty(log, "cookie key '%s' extracted from cookie '%s' is empty.", rule.key, cookie) + } + return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val) + // 根据客户端IP限流 + case limitByPerIpType: + realIp, err := getDownStreamIp(rule) + if err != nil { + log.Warnf("failed to get down stream ip: %v", err) + return "", &rule, nil + } + for _, item := range rule.configItems { + if _, found, _ := 
item.ipNet.Get(realIp); !found { + continue + } + return realIp.String(), &rule, &item + } + } + return "", nil, nil +} + +func logDebugAndReturnEmpty(log wrapper.Log, errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) { + log.Debugf(errMsg, args...) + return "", nil, nil +} + +func findMatchingItem(limitType limitRuleItemType, items []LimitConfigItem, key string) *LimitConfigItem { + for _, item := range items { + // per类型,检查allType和regexpType + if limitType == limitByPerHeaderType || + limitType == limitByPerParamType || + limitType == limitByPerConsumerType || + limitType == limitByPerCookieType { + if item.configType == allType || (item.configType == regexpType && item.regexp.MatchString(key)) { + return &item + } + } + // 其他类型,直接比较key + if item.key == key { + return &item + } + } + return nil +} + +func getDownStreamIp(rule LimitRuleItem) (net.IP, error) { + var ( + realIpStr string + err error + ) + if rule.limitByPerIp.sourceType == HeaderSourceType { + realIpStr, err = proxywasm.GetHttpRequestHeader(rule.limitByPerIp.headerName) + if err == nil { + realIpStr = strings.Split(strings.Trim(realIpStr, " "), ",")[0] + } + } else { + var bs []byte + bs, err = proxywasm.GetProperty([]string{"source", "address"}) + realIpStr = string(bs) + } + if err != nil { + return nil, err + } + ip := parseIP(realIpStr) + realIP := net.ParseIP(ip) + if realIP == nil { + return nil, fmt.Errorf("invalid ip[%s]", ip) + } + return realIP, nil +} + +func rejected(config ClusterKeyRateLimitConfig, context LimitContext) { + headers := make(map[string][]string) + headers[RateLimitResetHeader] = []string{strconv.Itoa(context.reset)} + _ = proxywasm.SendHttpResponse( + config.rejectedCode, reconvertHeaders(headers), []byte(config.rejectedMsg), -1) +} diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/utils.go b/plugins/wasm-go/extensions/ai-token-ratelimit/utils.go new file mode 100644 index 000000000..e1908a26b --- /dev/null +++ 
b/plugins/wasm-go/extensions/ai-token-ratelimit/utils.go @@ -0,0 +1,60 @@ +package main + +import ( + "fmt" + "sort" + "strings" + + "github.com/zmap/go-iptree/iptree" +) + +// parseIPNet 解析Ip段配置 +func parseIPNet(key string) (*iptree.IPTree, error) { + tree := iptree.New() + err := tree.AddByString(key, 0) + if err != nil { + return nil, fmt.Errorf("invalid IP[%s]", key) + } + return tree, nil +} + +// parseIP 解析IP +func parseIP(source string) string { + if strings.Contains(source, ".") { + // parse ipv4 + return strings.Split(source, ":")[0] + } + // parse ipv6 + if strings.Contains(source, "]") { + return strings.Split(source, "]")[0][1:] + } + return source +} + +// reconvertHeaders headers: map[string][]string -> [][2]string +func reconvertHeaders(hs map[string][]string) [][2]string { + var ret [][2]string + for k, vs := range hs { + for _, v := range vs { + ret = append(ret, [2]string{k, v}) + } + } + sort.SliceStable(ret, func(i, j int) bool { + return ret[i][0] < ret[j][0] + }) + return ret +} + +// extractCookieValueByKey 从cookie中提取key对应的value +func extractCookieValueByKey(cookie string, key string) (value string) { + pairs := strings.Split(cookie, ";") + for _, pair := range pairs { + pair = strings.TrimSpace(pair) + kv := strings.Split(pair, "=") + if kv[0] == key { + value = kv[1] + break + } + } + return value +}