feat: Support token usage statistics for the images/audio/responses APIs (#2542)

Signed-off-by: Xijun Dai <daixijun1990@gmail.com>
This commit is contained in:
Xijun Dai
2025-07-16 10:34:09 +08:00
committed by GitHub
parent ce271849de
commit 8346b4a4a2
13 changed files with 193 additions and 262 deletions

View File

@@ -15,7 +15,6 @@
package main
import (
"bytes"
"fmt"
"net"
"net/url"
@@ -25,6 +24,7 @@ import (
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/tokenusage"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/resp"
@@ -35,9 +35,9 @@ func main() {}
// init registers the "ai-token-ratelimit" plugin through wrapper.SetCtx, passing
// the config parser (parseConfig), the request-header handler
// (onHttpRequestHeaders) and the streaming response-body handler
// (onHttpStreamingBody).
// NOTE(review): this view lists both the ...By option constructors and the
// un-suffixed ones; it looks like a diff rendering with old and new lines
// interleaved — confirm which set the real file keeps.
func init() {
wrapper.SetCtx(
"ai-token-ratelimit",
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingBody),
wrapper.ParseConfig(parseConfig),
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
wrapper.ProcessStreamingResponseBody(onHttpStreamingBody),
)
}
@@ -84,8 +84,8 @@ type LimitRedisContext struct {
window int64
}
func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log log.Log) error {
err := initRedisClusterClient(json, config, log)
func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig) error {
err := initRedisClusterClient(json, config)
if err != nil {
return err
}
@@ -98,9 +98,9 @@ func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log log.L
return nil
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, log log.Log) types.Action {
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig) types.Action {
// 判断是否命中限流规则
val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems, log)
val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems)
if ruleItem == nil || configItem == nil {
return types.ActionContinue
}
@@ -146,18 +146,17 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon
return types.ActionPause
}
func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool, log log.Log) []byte {
var inputToken, outputToken int64
if inputToken, outputToken, ok := getUsage(data); ok {
ctx.SetContext("input_token", inputToken)
ctx.SetContext("output_token", outputToken)
func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool) []byte {
if usage := tokenusage.GetTokenUsage(ctx, data); usage.TotalToken > 0 {
ctx.SetContext(tokenusage.CtxKeyInputToken, usage.InputToken)
ctx.SetContext(tokenusage.CtxKeyOutputToken, usage.OutputToken)
}
if endOfStream {
if ctx.GetContext("input_token") == nil || ctx.GetContext("output_token") == nil {
if ctx.GetContext(tokenusage.CtxKeyInputToken) == nil || ctx.GetContext(tokenusage.CtxKeyOutputToken) == nil {
return data
}
inputToken = ctx.GetContext("input_token").(int64)
outputToken = ctx.GetContext("output_token").(int64)
inputToken := ctx.GetContext(tokenusage.CtxKeyInputToken).(int64)
outputToken := ctx.GetContext(tokenusage.CtxKeyOutputToken).(int64)
limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext)
if !ok {
return data
@@ -172,29 +171,9 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConf
return data
}
// getUsage scans a response body fragment for token usage figures.
//
// data is trimmed and split on blank lines ("\n\n"). Each piece must contain
// both "prompt_tokens" and "completion_tokens" before it is queried with gjson
// at usage.prompt_tokens and usage.completion_tokens, for example:
//
//	{"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}}
//
// The first piece in which both fields exist supplies the result. When no piece
// qualifies, ok is false and both counts are zero.
func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bool) {
	var (
		promptMarker     = []byte("prompt_tokens")
		completionMarker = []byte("completion_tokens")
	)
	for _, event := range bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) {
		// Cheap substring test before the JSON lookup.
		if bytes.Contains(event, promptMarker) && bytes.Contains(event, completionMarker) {
			prompt := gjson.GetBytes(event, "usage.prompt_tokens")
			completion := gjson.GetBytes(event, "usage.completion_tokens")
			if prompt.Exists() && completion.Exists() {
				return prompt.Int(), completion.Int(), true
			}
		}
	}
	return 0, 0, false
}
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log log.Log) (string, *LimitRuleItem, *LimitConfigItem) {
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem) (string, *LimitRuleItem, *LimitConfigItem) {
for _, rule := range ruleItems {
val, ruleItem, configItem := hitRateRuleItem(ctx, rule, log)
val, ruleItem, configItem := hitRateRuleItem(ctx, rule)
if ruleItem != nil && configItem != nil {
return val, ruleItem, configItem
}
@@ -202,46 +181,46 @@ func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRule
return "", nil, nil
}
func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log log.Log) (string, *LimitRuleItem, *LimitConfigItem) {
func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem) (string, *LimitRuleItem, *LimitConfigItem) {
switch rule.limitType {
// 根据HTTP请求头限流
case limitByHeaderType, limitByPerHeaderType:
val, err := proxywasm.GetHttpRequestHeader(rule.key)
if err != nil {
return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", rule.key, err)
return logDebugAndReturnEmpty("failed to get request header %s: %v", rule.key, err)
}
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
// 根据HTTP请求参数限流
case limitByParamType, limitByPerParamType:
parse, err := url.Parse(ctx.Path())
if err != nil {
return logDebugAndReturnEmpty(log, "failed to parse request path: %v", err)
return logDebugAndReturnEmpty("failed to parse request path: %v", err)
}
query, err := url.ParseQuery(parse.RawQuery)
if err != nil {
return logDebugAndReturnEmpty(log, "failed to parse query params: %v", err)
return logDebugAndReturnEmpty("failed to parse query params: %v", err)
}
val, ok := query[rule.key]
if !ok {
return logDebugAndReturnEmpty(log, "request param %s is empty", rule.key)
return logDebugAndReturnEmpty("request param %s is empty", rule.key)
}
return val[0], &rule, findMatchingItem(rule.limitType, rule.configItems, val[0])
// 根据consumer限流
case limitByConsumerType, limitByPerConsumerType:
val, err := proxywasm.GetHttpRequestHeader(ConsumerHeader)
if err != nil {
return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", ConsumerHeader, err)
return logDebugAndReturnEmpty("failed to get request header %s: %v", ConsumerHeader, err)
}
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
// 根据cookie中key值限流
case limitByCookieType, limitByPerCookieType:
cookie, err := proxywasm.GetHttpRequestHeader(CookieHeader)
if err != nil {
return logDebugAndReturnEmpty(log, "failed to get request cookie : %v", err)
return logDebugAndReturnEmpty("failed to get request cookie : %v", err)
}
val := extractCookieValueByKey(cookie, rule.key)
if val == "" {
return logDebugAndReturnEmpty(log, "cookie key '%s' extracted from cookie '%s' is empty.", rule.key, cookie)
return logDebugAndReturnEmpty("cookie key '%s' extracted from cookie '%s' is empty.", rule.key, cookie)
}
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
// 根据客户端IP限流
@@ -261,7 +240,7 @@ func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log log.Log) (
return "", nil, nil
}
// logDebugAndReturnEmpty passes errMsg and args to log.Debugf and then returns
// an empty string together with nil rule and config item pointers, so callers
// can log and bail out in a single return statement.
// NOTE(review): two signatures appear back to back here (with and without the
// log parameter); presumably old and new diff lines — confirm which one the
// real file keeps.
func logDebugAndReturnEmpty(log log.Log, errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) {
func logDebugAndReturnEmpty(errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) {
log.Debugf(errMsg, args...)
return "", nil, nil
}