mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 20:57:32 +08:00
feat: Support statistics images/audio/responses API Token usage (#2542)
Signed-off-by: Xijun Dai <daixijun1990@gmail.com>
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
@@ -25,6 +24,7 @@ import (
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/tokenusage"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/resp"
|
||||
@@ -35,9 +35,9 @@ func main() {}
|
||||
func init() {
|
||||
wrapper.SetCtx(
|
||||
"ai-token-ratelimit",
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||
wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingBody),
|
||||
wrapper.ParseConfig(parseConfig),
|
||||
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
|
||||
wrapper.ProcessStreamingResponseBody(onHttpStreamingBody),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -84,8 +84,8 @@ type LimitRedisContext struct {
|
||||
window int64
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log log.Log) error {
|
||||
err := initRedisClusterClient(json, config, log)
|
||||
func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig) error {
|
||||
err := initRedisClusterClient(json, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -98,9 +98,9 @@ func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log log.L
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, log log.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig) types.Action {
|
||||
// 判断是否命中限流规则
|
||||
val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems, log)
|
||||
val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems)
|
||||
if ruleItem == nil || configItem == nil {
|
||||
return types.ActionContinue
|
||||
}
|
||||
@@ -146,18 +146,17 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool, log log.Log) []byte {
|
||||
var inputToken, outputToken int64
|
||||
if inputToken, outputToken, ok := getUsage(data); ok {
|
||||
ctx.SetContext("input_token", inputToken)
|
||||
ctx.SetContext("output_token", outputToken)
|
||||
func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool) []byte {
|
||||
if usage := tokenusage.GetTokenUsage(ctx, data); usage.TotalToken > 0 {
|
||||
ctx.SetContext(tokenusage.CtxKeyInputToken, usage.InputToken)
|
||||
ctx.SetContext(tokenusage.CtxKeyOutputToken, usage.OutputToken)
|
||||
}
|
||||
if endOfStream {
|
||||
if ctx.GetContext("input_token") == nil || ctx.GetContext("output_token") == nil {
|
||||
if ctx.GetContext(tokenusage.CtxKeyInputToken) == nil || ctx.GetContext(tokenusage.CtxKeyOutputToken) == nil {
|
||||
return data
|
||||
}
|
||||
inputToken = ctx.GetContext("input_token").(int64)
|
||||
outputToken = ctx.GetContext("output_token").(int64)
|
||||
inputToken := ctx.GetContext(tokenusage.CtxKeyInputToken).(int64)
|
||||
outputToken := ctx.GetContext(tokenusage.CtxKeyOutputToken).(int64)
|
||||
limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext)
|
||||
if !ok {
|
||||
return data
|
||||
@@ -172,29 +171,9 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConf
|
||||
return data
|
||||
}
|
||||
|
||||
func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bool) {
|
||||
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
|
||||
for _, chunk := range chunks {
|
||||
// the feature strings are used to identify the usage data, like:
|
||||
// {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}}
|
||||
if !bytes.Contains(chunk, []byte("prompt_tokens")) || !bytes.Contains(chunk, []byte("completion_tokens")) {
|
||||
continue
|
||||
}
|
||||
inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens")
|
||||
outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens")
|
||||
if inputTokenObj.Exists() && outputTokenObj.Exists() {
|
||||
inputTokenUsage = inputTokenObj.Int()
|
||||
outputTokenUsage = outputTokenObj.Int()
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log log.Log) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
for _, rule := range ruleItems {
|
||||
val, ruleItem, configItem := hitRateRuleItem(ctx, rule, log)
|
||||
val, ruleItem, configItem := hitRateRuleItem(ctx, rule)
|
||||
if ruleItem != nil && configItem != nil {
|
||||
return val, ruleItem, configItem
|
||||
}
|
||||
@@ -202,46 +181,46 @@ func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRule
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log log.Log) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
switch rule.limitType {
|
||||
// 根据HTTP请求头限流
|
||||
case limitByHeaderType, limitByPerHeaderType:
|
||||
val, err := proxywasm.GetHttpRequestHeader(rule.key)
|
||||
if err != nil {
|
||||
return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", rule.key, err)
|
||||
return logDebugAndReturnEmpty("failed to get request header %s: %v", rule.key, err)
|
||||
}
|
||||
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
|
||||
// 根据HTTP请求参数限流
|
||||
case limitByParamType, limitByPerParamType:
|
||||
parse, err := url.Parse(ctx.Path())
|
||||
if err != nil {
|
||||
return logDebugAndReturnEmpty(log, "failed to parse request path: %v", err)
|
||||
return logDebugAndReturnEmpty("failed to parse request path: %v", err)
|
||||
}
|
||||
query, err := url.ParseQuery(parse.RawQuery)
|
||||
if err != nil {
|
||||
return logDebugAndReturnEmpty(log, "failed to parse query params: %v", err)
|
||||
return logDebugAndReturnEmpty("failed to parse query params: %v", err)
|
||||
}
|
||||
val, ok := query[rule.key]
|
||||
if !ok {
|
||||
return logDebugAndReturnEmpty(log, "request param %s is empty", rule.key)
|
||||
return logDebugAndReturnEmpty("request param %s is empty", rule.key)
|
||||
}
|
||||
return val[0], &rule, findMatchingItem(rule.limitType, rule.configItems, val[0])
|
||||
// 根据consumer限流
|
||||
case limitByConsumerType, limitByPerConsumerType:
|
||||
val, err := proxywasm.GetHttpRequestHeader(ConsumerHeader)
|
||||
if err != nil {
|
||||
return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", ConsumerHeader, err)
|
||||
return logDebugAndReturnEmpty("failed to get request header %s: %v", ConsumerHeader, err)
|
||||
}
|
||||
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
|
||||
// 根据cookie中key值限流
|
||||
case limitByCookieType, limitByPerCookieType:
|
||||
cookie, err := proxywasm.GetHttpRequestHeader(CookieHeader)
|
||||
if err != nil {
|
||||
return logDebugAndReturnEmpty(log, "failed to get request cookie : %v", err)
|
||||
return logDebugAndReturnEmpty("failed to get request cookie : %v", err)
|
||||
}
|
||||
val := extractCookieValueByKey(cookie, rule.key)
|
||||
if val == "" {
|
||||
return logDebugAndReturnEmpty(log, "cookie key '%s' extracted from cookie '%s' is empty.", rule.key, cookie)
|
||||
return logDebugAndReturnEmpty("cookie key '%s' extracted from cookie '%s' is empty.", rule.key, cookie)
|
||||
}
|
||||
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
|
||||
// 根据客户端IP限流
|
||||
@@ -261,7 +240,7 @@ func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log log.Log) (
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
func logDebugAndReturnEmpty(log log.Log, errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
func logDebugAndReturnEmpty(errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
log.Debugf(errMsg, args...)
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user