feat: add AI quota plugin (#1200)

This commit is contained in:
Jun
2024-08-14 13:43:31 +08:00
committed by GitHub
parent daa374d9a4
commit d31c978ed3
6 changed files with 582 additions and 0 deletions

View File

@@ -0,0 +1,58 @@
# 功能说明
`ai-quota` 插件实现给特定 consumer 根据分配的固定 quota 进行 quota 策略限流,同时支持 quota 管理能力,包括查询 quota、刷新 quota、增减 quota。
`ai-quota` 插件需要配合认证插件(比如 `key-auth`、`jwt-auth` 等)获取认证身份的 consumer 名称,同时需要配合 `ai-statistics` 插件获取 AI Token 统计信息。
# 配置说明
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|--------------------|-----------------|--------------------------------------| ---- |--------------------------------------------|
| `redis_key_prefix` | string | 选填 | chat_quota: | quota redis key 前缀 |
| `admin_consumer` | string | 必填 | | 具有 quota 管理权限的 consumer 名称 |
| `admin_path` | string | 选填 | /quota | 管理 quota 请求 path 前缀 |
| `redis` | object | 是 | | redis相关配置 |
`redis`中每一项的配置字段说明
| 配置项 | 类型 | 必填 | 默认值 | 说明 |
| ------------ | ------ | ---- | ---------------------------------------------------------- | --------------------------- |
| service_name | string | 必填 | - | redis 服务名称,带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local |
| service_port | int | 否 | 服务类型为固定地址(static service)时默认值为80,其他为6379 | 输入redis服务的服务端口 |
| username | string | 否 | - | redis用户名 |
| password | string | 否 | - | redis密码 |
| timeout | int | 否 | 1000 | redis连接超时时间,单位毫秒 |
# 配置示例
## 识别请求参数 apikey进行区别限流
```yaml
redis_key_prefix: "chat_quota:"
admin_consumer: consumer3
admin_path: /quota
redis:
service_name: redis-service.default.svc.cluster.local
service_port: 6379
timeout: 2000
```
## 刷新 quota
如果当前请求 url 的后缀符合 admin_path例如插件在 example.com/v1/chat/completions 这个路由上生效,那么更新 quota 可以通过
curl https://example.com/v1/chat/completions/quota/refresh -H "Authorization: Bearer credential3" -d "consumer=consumer1&quota=10000"
Redis 中 key 为 chat_quota:consumer1 的值就会被刷新为 10000
## 查询 quota
查询特定用户的 quota 可以通过 curl https://example.com/v1/chat/completions/quota?consumer=consumer1 -H "Authorization: Bearer credential3"
将返回: {"quota": 10000, "consumer": "consumer1"}
## 增减 quota
增减特定用户的 quota 可以通过 curl https://example.com/v1/chat/completions/quota/delta -d "consumer=consumer1&value=100" -H "Authorization: Bearer credential3"
这样 Redis 中 Key 为 chat_quota:consumer1 的值就会增加100可以支持负数则减去对应值。

View File

@@ -0,0 +1,20 @@
module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-quota
go 1.19
//replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v1.4.3-0.20240808022948-34f5722d93de
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/tidwall/gjson v1.17.3
github.com/tidwall/resp v0.1.1
)
require (
github.com/google/uuid v1.3.0 // indirect
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
github.com/magefile/mage v1.14.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
)

View File

@@ -0,0 +1,22 @@
github.com/alibaba/higress/plugins/wasm-go v1.4.3-0.20240808022948-34f5722d93de h1:lDLqj7Hw41ox8VdsP7oCTPhjPa3+QJUCKApcLh2a45Y=
github.com/alibaba/higress/plugins/wasm-go v1.4.3-0.20240808022948-34f5722d93de/go.mod h1:359don/ahMxpfeLMzr29Cjwcu8IywTTDUzWlBPRNLHw=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -0,0 +1,399 @@
package main
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-quota/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/resp"
)
const (
// pluginName is the name passed to wrapper.SetCtx when the plugin is registered in main.
pluginName = "ai-quota"
)
// ChatMode classifies a request by its path, as decided by getOperationMode.
type ChatMode string
const (
// ChatModeCompletion marks a request whose path ends with "/v1/chat/completions".
ChatModeCompletion ChatMode = "completion"
// ChatModeAdmin marks a request whose path ends with the admin path (or one of its sub-paths).
ChatModeAdmin ChatMode = "admin"
// ChatModeNone marks any other request; the plugin lets it through untouched.
ChatModeNone ChatMode = "none"
)
// AdminMode identifies which quota management operation an admin request asks for.
type AdminMode string
const (
// AdminModeRefresh is selected for paths ending in "<admin path>/refresh"; handled by refreshQuota.
AdminModeRefresh AdminMode = "refresh"
// AdminModeQuery is selected for paths ending in the admin path itself; handled by queryQuota.
AdminModeQuery AdminMode = "query"
// AdminModeDelta is selected for paths ending in "<admin path>/delta"; handled by deltaQuota.
AdminModeDelta AdminMode = "delta"
// AdminModeNone is returned for requests that are not admin requests.
AdminModeNone AdminMode = "none"
)
// main registers the plugin with the wasm-go wrapper under pluginName, wiring
// parseConfig as the configuration parser and the handlers for request
// headers, request body and streaming response body.
func main() {
wrapper.SetCtx(
pluginName,
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingResponseBody),
)
}
// QuotaConfig is the plugin configuration; parseConfig populates it from the
// plugin's JSON configuration.
type QuotaConfig struct {
// redisInfo holds the Redis connection settings read from the "redis" config object.
redisInfo RedisInfo `yaml:"redis"`
// RedisKeyPrefix is prepended to a consumer name to build its Redis quota key
// (parseConfig defaults it to "chat_quota:").
RedisKeyPrefix string `yaml:"redis_key_prefix"`
// AdminConsumer is the only consumer allowed to run the quota admin operations.
AdminConsumer string `yaml:"admin_consumer"`
// AdminPath is the path suffix that marks admin requests (parseConfig defaults it to "/quota").
AdminPath string `yaml:"admin_path"`
// NOTE(review): credential2Name is neither read nor written by the code shown here; confirm it is still needed.
credential2Name map[string]string `yaml:"-"`
// redisClient is created and initialized by parseConfig and used by all quota operations.
redisClient wrapper.RedisClient
}
// Consumer pairs a consumer name with a credential.
// NOTE(review): this type is not referenced by the code shown here; confirm it is still needed.
type Consumer struct {
Name string `yaml:"name"`
Credential string `yaml:"credential"`
}
// RedisInfo holds the Redis connection settings taken from the "redis" object
// of the plugin configuration (see parseConfig for the applied defaults).
type RedisInfo struct {
// ServiceName is the Redis service name; parseConfig rejects an empty value.
ServiceName string `required:"true" yaml:"service_name" json:"service_name"`
// ServicePort is the Redis service port; when unset, parseConfig uses 80 for
// service names ending in ".static" and 6379 otherwise.
ServicePort int `required:"false" yaml:"service_port" json:"service_port"`
// Username is the optional Redis user name passed to the client's Init.
Username string `required:"false" yaml:"username" json:"username"`
// Password is the optional Redis password passed to the client's Init.
Password string `required:"false" yaml:"password" json:"password"`
// Timeout is passed to the client's Init; parseConfig defaults it to 1000
// (milliseconds according to the plugin README).
Timeout int `required:"false" yaml:"timeout" json:"timeout"`
}
// parseConfig populates config from the plugin's JSON configuration and
// initializes the Redis client.
//
// Defaults: admin_path "/quota", redis_key_prefix "chat_quota:", redis timeout
// 1000, and redis service_port 80 for service names ending in ".static" or
// 6379 otherwise. It returns an error when admin_consumer, the redis object or
// redis.service_name is missing, or when the Redis client fails to initialize.
func parseConfig(json gjson.Result, config *QuotaConfig, log wrapper.Log) error {
	log.Debugf("parse config()")

	// stringOr and intOr return the configured value, or fallback when the
	// value is absent/empty (or zero for integers).
	stringOr := func(v gjson.Result, fallback string) string {
		if s := v.String(); s != "" {
			return s
		}
		return fallback
	}
	intOr := func(v gjson.Result, fallback int) int {
		if n := int(v.Int()); n != 0 {
			return n
		}
		return fallback
	}

	// Admin settings.
	config.AdminPath = stringOr(json.Get("admin_path"), "/quota")
	config.AdminConsumer = json.Get("admin_consumer").String()
	if config.AdminConsumer == "" {
		return errors.New("missing admin_consumer in config")
	}

	// Redis settings.
	config.RedisKeyPrefix = stringOr(json.Get("redis_key_prefix"), "chat_quota:")
	redisJSON := json.Get("redis")
	if !redisJSON.Exists() {
		return errors.New("missing redis in config")
	}
	name := redisJSON.Get("service_name").String()
	if name == "" {
		return errors.New("redis service name must not be empty")
	}

	// Static services are reached through the default logic port 80.
	defaultPort := 6379
	if strings.HasSuffix(name, ".static") {
		defaultPort = 80
	}
	port := intOr(redisJSON.Get("service_port"), defaultPort)
	user := redisJSON.Get("username").String()
	pass := redisJSON.Get("password").String()
	timeoutMs := intOr(redisJSON.Get("timeout"), 1000)

	config.redisInfo = RedisInfo{
		ServiceName: name,
		ServicePort: port,
		Username:    user,
		Password:    pass,
		Timeout:     timeoutMs,
	}
	config.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
		FQDN: name,
		Port: int64(port),
	})
	return config.redisClient.Init(user, pass, int64(timeoutMs))
}
// onHttpRequestHeaders classifies the request and enforces the quota check.
//
// The consumer name is read from the "x-mse-consumer" request header; a
// missing header is answered with 401 and an empty one with 403. The chat
// mode, admin mode and consumer are stored in the HTTP context for the later
// body/response handlers.
//
//   - Requests that are neither chat completions nor admin calls continue untouched.
//   - Admin query requests are answered by queryQuota; refresh/delta requests
//     buffer the body and are handled in onHttpRequestBody.
//   - Chat completion requests are paused while the consumer's quota is read
//     from Redis; they are resumed when the quota is positive and rejected
//     with 403 otherwise (including on Redis errors or a missing key).
func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log wrapper.Log) types.Action {
	log.Debugf("onHttpRequestHeaders()")
	// get tokens
	consumer, err := proxywasm.GetHttpRequestHeader("x-mse-consumer")
	if err != nil {
		return deniedNoKeyAuthData()
	}
	if consumer == "" {
		return deniedUnauthorizedConsumer()
	}

	rawPath := context.Path()
	path, err := url.Parse(rawPath)
	if err != nil {
		// A malformed path (e.g. an invalid percent-escape) yields a nil URL;
		// reject the request instead of dereferencing it.
		log.Debugf("invalid request path %q: %v", rawPath, err)
		util.SendResponse(http.StatusBadRequest, "ai-quota.bad_path", "text/plain", "Request denied by ai quota check. Invalid request path.")
		return types.ActionContinue
	}

	chatMode, adminMode := getOperationMode(path.Path, config.AdminPath, log)
	context.SetContext("chatMode", chatMode)
	context.SetContext("adminMode", adminMode)
	context.SetContext("consumer", consumer)
	log.Debugf("chatMode:%s, adminMode:%s, consumer:%s", chatMode, adminMode, consumer)

	switch chatMode {
	case ChatModeNone:
		return types.ActionContinue
	case ChatModeAdmin:
		switch adminMode {
		case AdminModeQuery:
			// query quota
			return queryQuota(context, config, consumer, path, log)
		case AdminModeRefresh, AdminModeDelta:
			context.BufferRequestBody()
			return types.HeaderStopIteration
		}
		return types.ActionContinue
	}

	// there is no need to read request body when it is on chat completion mode
	context.DontReadRequestBody()
	// check quota here
	err = config.redisClient.Get(config.RedisKeyPrefix+consumer, func(response resp.Value) {
		quota := response.Integer()
		isDenied := response.Error() != nil || response.IsNull() || quota <= 0
		log.Debugf("get consumer:%s quota:%d isDenied:%t", consumer, quota, isDenied)
		if isDenied {
			util.SendResponse(http.StatusForbidden, "ai-quota.noquota", "text/plain", "Request denied by ai quota check, No quota left")
			return
		}
		proxywasm.ResumeHttpRequest()
	})
	if err != nil {
		// The Redis call could not be dispatched, so the callback will never
		// run; answer now rather than leaving the request paused.
		util.SendResponse(http.StatusServiceUnavailable, "ai-quota.error", "text/plain", fmt.Sprintf("redis error:%v", err))
		return types.ActionContinue
	}
	return types.HeaderStopAllIterationAndWatermark
}
// onHttpRequestBody runs the body-carrying admin operations.
//
// It reads the chat mode, admin mode and consumer stored in the context by
// onHttpRequestHeaders. Only admin requests are handled: a refresh request is
// passed to refreshQuota and a delta request to deltaQuota, with the raw body
// as a string. Every other case lets the request continue.
func onHttpRequestBody(ctx wrapper.HttpContext, config QuotaConfig, body []byte, log wrapper.Log) types.Action {
	log.Debugf("onHttpRequestBody()")

	mode, hasMode := ctx.GetContext("chatMode").(ChatMode)
	if !hasMode || mode == ChatModeNone || mode == ChatModeCompletion {
		return types.ActionContinue
	}
	operation, hasOperation := ctx.GetContext("adminMode").(AdminMode)
	if !hasOperation {
		return types.ActionContinue
	}
	caller, hasCaller := ctx.GetContext("consumer").(string)
	if !hasCaller {
		return types.ActionContinue
	}

	switch operation {
	case AdminModeRefresh:
		return refreshQuota(ctx, config, caller, string(body), log)
	case AdminModeDelta:
		return deltaQuota(ctx, config, caller, string(body), log)
	default:
		return types.ActionContinue
	}
}
// onHttpStreamingResponseBody deducts the consumed tokens from the consumer's
// quota once a chat completion response has finished streaming.
//
// The response data is always returned unchanged. On the final chunk of a
// chat completion request, the input and output token counts are read from
// the "wasm.input_token" and "wasm.output_token" filter-state properties and
// their sum is subtracted from the consumer's Redis quota key. If either
// property is missing or not an integer, nothing is deducted.
func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, data []byte, endOfStream bool, log wrapper.Log) []byte {
	chatMode, ok := ctx.GetContext("chatMode").(ChatMode)
	if !ok {
		return data
	}
	if chatMode == ChatModeNone || chatMode == ChatModeAdmin {
		return data
	}
	// chat completion mode: token usage is only settled at the end of the stream
	if !endOfStream {
		return data
	}

	inputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.input_token"})
	if err != nil {
		return data
	}
	outputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.output_token"})
	if err != nil {
		return data
	}
	inputToken, err := strconv.Atoi(string(inputTokenStr))
	if err != nil {
		return data
	}
	outputToken, err := strconv.Atoi(string(outputTokenStr))
	if err != nil {
		return data
	}

	consumer, ok := ctx.GetContext("consumer").(string)
	if !ok {
		return data
	}
	totalToken := inputToken + outputToken
	log.Debugf("update consumer:%s, totalToken:%d", consumer, totalToken)
	// The deduction is fire-and-forget (no callback), but a failure to dispatch
	// it means the quota was not charged, so surface it in the log.
	if err := config.redisClient.DecrBy(config.RedisKeyPrefix+consumer, totalToken, nil); err != nil {
		log.Errorf("failed to deduct quota, consumer:%s, totalToken:%d, error:%v", consumer, totalToken, err)
	}
	return data
}
// deniedNoKeyAuthData answers the request with 401 Unauthorized because no
// authentication information was found, and returns types.ActionContinue.
func deniedNoKeyAuthData() types.Action {
	const (
		detail  = "ai-quota.no_key"
		message = "Request denied by ai quota check. No Key Authentication information found."
	)
	_ = util.SendResponse(http.StatusUnauthorized, detail, "text/plain", message)
	return types.ActionContinue
}
// deniedUnauthorizedConsumer answers the request with 403 Forbidden because
// the consumer is not authorized, and returns types.ActionContinue.
func deniedUnauthorizedConsumer() types.Action {
	const (
		detail  = "ai-quota.unauthorized"
		message = "Request denied by ai quota check. Unauthorized consumer."
	)
	_ = util.SendResponse(http.StatusForbidden, detail, "text/plain", message)
	return types.ActionContinue
}
// getOperationMode classifies a request path by its suffix.
//
// With base = "/v1/chat/completions" + adminPath, a path ending in
// base+"/refresh" or base+"/delta" is an admin refresh or delta request, a
// path ending in base is an admin query, and a path ending in
// "/v1/chat/completions" is a chat completion. Anything else is
// (ChatModeNone, AdminModeNone). The log parameter is unused.
func getOperationMode(path string, adminPath string, log wrapper.Log) (ChatMode, AdminMode) {
	const completionsPath = "/v1/chat/completions"
	adminBase := completionsPath + adminPath

	switch {
	case strings.HasSuffix(path, adminBase+"/refresh"):
		return ChatModeAdmin, AdminModeRefresh
	case strings.HasSuffix(path, adminBase+"/delta"):
		return ChatModeAdmin, AdminModeDelta
	case strings.HasSuffix(path, adminBase):
		return ChatModeAdmin, AdminModeQuery
	case strings.HasSuffix(path, completionsPath):
		return ChatModeCompletion, AdminModeNone
	default:
		return ChatModeNone, AdminModeNone
	}
}
// refreshQuota overwrites a consumer's quota in Redis.
//
// adminConsumer must equal config.AdminConsumer, otherwise the request is
// answered with 403. body is a form-encoded string carrying "consumer" (the
// target consumer, non-empty) and "quota" (an integer); invalid input is
// answered with 403. The quota is written to the key
// config.RedisKeyPrefix+consumer, and the request is answered with 200 on
// success or 503 on a Redis error.
//
// It returns types.ActionPause while the Redis call is in flight and
// types.ActionContinue when a response was already sent.
func refreshQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string, log wrapper.Log) types.Action {
	// check consumer
	if adminConsumer != config.AdminConsumer {
		util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.")
		return types.ActionContinue
	}

	// A malformed body still yields whatever pairs could be parsed; missing
	// fields are caught by the validation below.
	form, _ := url.ParseQuery(body)
	queryConsumer := form.Get("consumer")
	quota, err := strconv.Atoi(form.Get("quota"))
	if queryConsumer == "" || err != nil {
		util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. consumer can't be empty and quota must be integer.")
		return types.ActionContinue
	}

	key := config.RedisKeyPrefix + queryConsumer
	err = config.redisClient.Set(key, quota, func(response resp.Value) {
		log.Debugf("Redis set key = %s quota = %d", key, quota)
		if err := response.Error(); err != nil {
			util.SendResponse(http.StatusServiceUnavailable, "ai-quota.error", "text/plain", fmt.Sprintf("redis error:%v", err))
			return
		}
		util.SendResponse(http.StatusOK, "ai-quota.refreshquota", "text/plain", "refresh quota successful")
	})
	if err != nil {
		// Report the error returned by Set itself (previously the nil Atoi
		// error was formatted, producing "redis error:<nil>").
		util.SendResponse(http.StatusServiceUnavailable, "ai-quota.error", "text/plain", fmt.Sprintf("redis error:%v", err))
		return types.ActionContinue
	}
	return types.ActionPause
}
// queryQuota answers an admin request for a consumer's current quota.
//
// adminConsumer must equal config.AdminConsumer and the request URL must carry
// a non-empty "consumer" query parameter; otherwise the request is answered
// with 403. The quota is read from the key config.RedisKeyPrefix+consumer and
// returned as JSON with "consumer" and "quota" fields (a missing key reports
// quota 0); a Redis error is answered with 503.
//
// It returns types.ActionPause while the Redis call is in flight and
// types.ActionContinue when a response was already sent.
func queryQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, url *url.URL, log wrapper.Log) types.Action {
	// Only the configured admin consumer may query quotas.
	if config.AdminConsumer != adminConsumer {
		util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.")
		return types.ActionContinue
	}

	// The target consumer comes from the query string.
	target := url.Query().Get("consumer")
	if target == "" {
		util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. consumer can't be empty.")
		return types.ActionContinue
	}

	type quotaReply struct {
		Consumer string `json:"consumer"`
		Quota    int    `json:"quota"`
	}
	dispatchErr := config.redisClient.Get(config.RedisKeyPrefix+target, func(response resp.Value) {
		if redisErr := response.Error(); redisErr != nil {
			util.SendResponse(http.StatusServiceUnavailable, "ai-quota.error", "text/plain", fmt.Sprintf("redis error:%v", redisErr))
			return
		}
		// A missing key is reported as a quota of zero.
		var remaining int
		if !response.IsNull() {
			remaining = response.Integer()
		}
		payload, _ := json.Marshal(quotaReply{Consumer: target, Quota: remaining})
		util.SendResponse(http.StatusOK, "ai-quota.queryquota", "application/json", string(payload))
	})
	if dispatchErr != nil {
		util.SendResponse(http.StatusServiceUnavailable, "ai-quota.error", "text/plain", fmt.Sprintf("redis error:%v", dispatchErr))
		return types.ActionContinue
	}
	return types.ActionPause
}
// deltaQuota increases or decreases a consumer's quota in Redis.
//
// adminConsumer must equal config.AdminConsumer, otherwise the request is
// answered with 403. body is a form-encoded string carrying "consumer" (the
// target consumer, non-empty) and "value" (an integer); invalid input is
// answered with 403. A non-negative value is applied with IncrBy and a
// negative one with DecrBy of its magnitude, on the key
// config.RedisKeyPrefix+consumer. The request is answered with 200 on success
// or 503 on a Redis error.
//
// It returns types.ActionPause while the Redis call is in flight and
// types.ActionContinue when a response was already sent.
func deltaQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string, log wrapper.Log) types.Action {
	// Only the configured admin consumer may change quotas.
	if config.AdminConsumer != adminConsumer {
		util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.")
		return types.ActionContinue
	}

	form, _ := url.ParseQuery(body)
	target := form.Get("consumer")
	delta, parseErr := strconv.Atoi(form.Get("value"))
	if parseErr != nil || target == "" {
		util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. consumer can't be empty and value must be integer.")
		return types.ActionContinue
	}

	key := config.RedisKeyPrefix + target
	// onReply builds the Redis callback shared by both directions; only the
	// debug log line and the logged amount differ.
	onReply := func(logFormat string, amount int) func(resp.Value) {
		return func(response resp.Value) {
			log.Debugf(logFormat, key, amount)
			if redisErr := response.Error(); redisErr != nil {
				util.SendResponse(http.StatusServiceUnavailable, "ai-quota.error", "text/plain", fmt.Sprintf("redis error:%v", redisErr))
				return
			}
			util.SendResponse(http.StatusOK, "ai-quota.deltaquota", "text/plain", "delta quota successful")
		}
	}

	var dispatchErr error
	if delta < 0 {
		magnitude := 0 - delta
		dispatchErr = config.redisClient.DecrBy(key, magnitude, onReply("Redis Decr key = %s value = %d", magnitude))
	} else {
		dispatchErr = config.redisClient.IncrBy(key, delta, onReply("Redis Incr key = %s value = %d", delta))
	}
	if dispatchErr != nil {
		util.SendResponse(http.StatusServiceUnavailable, "ai-quota.error", "text/plain", fmt.Sprintf("redis error:%v", dispatchErr))
		return types.ActionContinue
	}
	return types.ActionPause
}

View File

@@ -0,0 +1,61 @@
apiVersion: extensions.higress.io/v1alpha1
kind: WasmPlugin
metadata:
name: ai-quota
namespace: higress-system
spec:
defaultConfig: {}
defaultConfigDisable: true
matchRules:
- config:
redis_key_prefix: "chat_quota:"
admin_consumer: consumer3
admin_path: /quota
redis:
service_name: redis-service.default.svc.cluster.local
service_port: 6379
timeout: 2000
configDisable: false
ingress:
- qwen
phase: UNSPECIFIED_PHASE
priority: 280
url: oci://registry.cn-hangzhou.aliyuncs.com/2456868764/ai-quota:1.0.8
---
apiVersion: extensions.higress.io/v1alpha1
kind: WasmPlugin
metadata:
name: ai-statistics
namespace: higress-system
spec:
defaultConfig:
enable: true
defaultConfigDisable: false
phase: UNSPECIFIED_PHASE
priority: 250
url: oci://higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/ai-statistics:1.0.0
---
apiVersion: extensions.higress.io/v1alpha1
kind: WasmPlugin
metadata:
name: wasm-keyauth
namespace: higress-system
spec:
defaultConfig:
consumers:
- credential: "Bearer credential1"
name: consumer1
- credential: "Bearer credential2"
name: consumer2
- credential: "Bearer credential3"
name: consumer3
global_auth: true
keys:
- authorization
in_header: true
defaultConfigDisable: false
priority: 300
url: oci://higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/key-auth:1.0.0
imagePullPolicy: Always

View File

@@ -0,0 +1,22 @@
package util
import "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
// Header name and MIME type constants for building HTTP responses.
const (
// HeaderContentType is the header SendResponse sets on every local response.
HeaderContentType = "Content-Type"
// MimeTypeTextPlain is the MIME type for plain-text bodies.
MimeTypeTextPlain = "text/plain"
// MimeTypeApplicationJson is the MIME type for JSON bodies.
MimeTypeApplicationJson = "application/json"
)
func SendResponse(statusCode uint32, statusCodeDetails string, contentType, body string) error {
return proxywasm.SendHttpResponseWithDetail(statusCode, statusCodeDetails, CreateHeaders(HeaderContentType, contentType), []byte(body), -1)
}
// CreateHeaders converts a flat list of alternating keys and values into
// header pairs: CreateHeaders("k1", "v1", "k2", "v2") yields
// [["k1" "v1"] ["k2" "v2"]].
//
// A trailing key without a value is ignored rather than causing an
// out-of-range panic. The result is never nil.
func CreateHeaders(kvs ...string) [][2]string {
	headers := make([][2]string, 0, len(kvs)/2)
	// Stop while a full key/value pair is still available.
	for i := 0; i+1 < len(kvs); i += 2 {
		headers = append(headers, [2]string{kvs[i], kvs[i+1]})
	}
	return headers
}