feat: Support statistics images/audio/responses API Token usage (#2542)

Signed-off-by: Xijun Dai <daixijun1990@gmail.com>
This commit is contained in:
Xijun Dai
2025-07-16 10:34:09 +08:00
committed by GitHub
parent ce271849de
commit 8346b4a4a2
13 changed files with 193 additions and 262 deletions

View File

@@ -6,7 +6,7 @@ toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.0
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa
github.com/tidwall/gjson v1.18.0
github.com/tidwall/resp v0.1.1
)

View File

@@ -6,6 +6,8 @@ github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa h1:GnYXjsG9/nRJ4+GQeJBKS8/a28N8yAt1pUmGZcxnHd4=
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=

View File

@@ -1,7 +1,6 @@
package main
import (
"bytes"
"encoding/json"
"errors"
"fmt"
@@ -10,13 +9,15 @@ import (
"strconv"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-quota/util"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/tokenusage"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/resp"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-quota/util"
)
const (
@@ -45,10 +46,10 @@ func main() {}
func init() {
wrapper.SetCtx(
pluginName,
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingResponseBody),
wrapper.ParseConfig(parseConfig),
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
wrapper.ProcessRequestBody(onHttpRequestBody),
wrapper.ProcessStreamingResponseBody(onHttpStreamingResponseBody),
)
}
@@ -75,7 +76,7 @@ type RedisInfo struct {
Database int `required:"false" yaml:"database" json:"database"`
}
func parseConfig(json gjson.Result, config *QuotaConfig, log log.Log) error {
func parseConfig(json gjson.Result, config *QuotaConfig) error {
log.Debugf("parse config()")
// admin
config.AdminPath = json.Get("admin_path").String()
@@ -129,7 +130,7 @@ func parseConfig(json gjson.Result, config *QuotaConfig, log log.Log) error {
return config.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(database))
}
func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log log.Log) types.Action {
func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig) types.Action {
log.Debugf("onHttpRequestHeaders()")
// get tokens
consumer, err := proxywasm.GetHttpRequestHeader("x-mse-consumer")
@@ -142,7 +143,7 @@ func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log l
rawPath := context.Path()
path, _ := url.Parse(rawPath)
chatMode, adminMode := getOperationMode(path.Path, config.AdminPath, log)
chatMode, adminMode := getOperationMode(path.Path, config.AdminPath)
context.SetContext("chatMode", chatMode)
context.SetContext("adminMode", adminMode)
context.SetContext("consumer", consumer)
@@ -153,7 +154,7 @@ func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log l
if chatMode == ChatModeAdmin {
// query quota
if adminMode == AdminModeQuery {
return queryQuota(context, config, consumer, path, log)
return queryQuota(context, config, consumer, path)
}
if adminMode == AdminModeRefresh || adminMode == AdminModeDelta {
context.BufferRequestBody()
@@ -186,7 +187,7 @@ func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log l
return types.HeaderStopAllIterationAndWatermark
}
func onHttpRequestBody(ctx wrapper.HttpContext, config QuotaConfig, body []byte, log log.Log) types.Action {
func onHttpRequestBody(ctx wrapper.HttpContext, config QuotaConfig, body []byte) types.Action {
log.Debugf("onHttpRequestBody()")
chatMode, ok := ctx.GetContext("chatMode").(ChatMode)
if !ok {
@@ -205,16 +206,16 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config QuotaConfig, body []byte,
}
if adminMode == AdminModeRefresh {
return refreshQuota(ctx, config, adminConsumer, string(body), log)
return refreshQuota(ctx, config, adminConsumer, string(body))
}
if adminMode == AdminModeDelta {
return deltaQuota(ctx, config, adminConsumer, string(body), log)
return deltaQuota(ctx, config, adminConsumer, string(body))
}
return types.ActionContinue
}
func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, data []byte, endOfStream bool, log log.Log) []byte {
func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, data []byte, endOfStream bool) []byte {
chatMode, ok := ctx.GetContext("chatMode").(ChatMode)
if !ok {
return data
@@ -222,11 +223,9 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, da
if chatMode == ChatModeNone || chatMode == ChatModeAdmin {
return data
}
var inputToken, outputToken int64
var consumer string
if inputToken, outputToken, ok := getUsage(data); ok {
ctx.SetContext("input_token", inputToken)
ctx.SetContext("output_token", outputToken)
if usage := tokenusage.GetTokenUsage(ctx, data); usage.TotalToken > 0 {
ctx.SetContext(tokenusage.CtxKeyInputToken, usage.InputToken)
ctx.SetContext(tokenusage.CtxKeyOutputToken, usage.OutputToken)
}
// chat completion mode
@@ -234,39 +233,19 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, da
return data
}
if ctx.GetContext("input_token") == nil || ctx.GetContext("output_token") == nil || ctx.GetContext("consumer") == nil {
if ctx.GetContext(tokenusage.CtxKeyInputToken) == nil || ctx.GetContext(tokenusage.CtxKeyOutputToken) == nil || ctx.GetContext("consumer") == nil {
return data
}
inputToken = ctx.GetContext("input_token").(int64)
outputToken = ctx.GetContext("output_token").(int64)
consumer = ctx.GetContext("consumer").(string)
inputToken := ctx.GetContext(tokenusage.CtxKeyInputToken).(int64)
outputToken := ctx.GetContext(tokenusage.CtxKeyOutputToken).(int64)
consumer := ctx.GetContext("consumer").(string)
totalToken := int(inputToken + outputToken)
log.Debugf("update consumer:%s, totalToken:%d", consumer, totalToken)
config.redisClient.DecrBy(config.RedisKeyPrefix+consumer, totalToken, nil)
return data
}
func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bool) {
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
for _, chunk := range chunks {
// the feature strings are used to identify the usage data, like:
// {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}}
if !bytes.Contains(chunk, []byte("prompt_tokens")) || !bytes.Contains(chunk, []byte("completion_tokens")) {
continue
}
inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens")
outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens")
if inputTokenObj.Exists() && outputTokenObj.Exists() {
inputTokenUsage = inputTokenObj.Int()
outputTokenUsage = outputTokenObj.Int()
ok = true
return
}
}
return
}
func deniedNoKeyAuthData() types.Action {
util.SendResponse(http.StatusUnauthorized, "ai-quota.no_key", "text/plain", "Request denied by ai quota check. No Key Authentication information found.")
return types.ActionContinue
@@ -277,7 +256,7 @@ func deniedUnauthorizedConsumer() types.Action {
return types.ActionContinue
}
func getOperationMode(path string, adminPath string, log log.Log) (ChatMode, AdminMode) {
func getOperationMode(path string, adminPath string) (ChatMode, AdminMode) {
fullAdminPath := "/v1/chat/completions" + adminPath
if strings.HasSuffix(path, fullAdminPath+"/refresh") {
return ChatModeAdmin, AdminModeRefresh
@@ -294,7 +273,7 @@ func getOperationMode(path string, adminPath string, log log.Log) (ChatMode, Adm
return ChatModeNone, AdminModeNone
}
func refreshQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string, log log.Log) types.Action {
func refreshQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string) types.Action {
// check consumer
if adminConsumer != config.AdminConsumer {
util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.")
@@ -328,7 +307,8 @@ func refreshQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer str
return types.ActionPause
}
func queryQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, url *url.URL, log log.Log) types.Action {
func queryQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, url *url.URL) types.Action {
// check consumer
if adminConsumer != config.AdminConsumer {
util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.")
@@ -371,7 +351,8 @@ func queryQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer strin
}
return types.ActionPause
}
func deltaQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string, log log.Log) types.Action {
func deltaQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string) types.Action {
// check consumer
if adminConsumer != config.AdminConsumer {
util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.")

View File

@@ -7,7 +7,7 @@ toolchain go1.24.4
require (
github.com/antchfx/xmlquery v1.4.4
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
)
@@ -19,6 +19,6 @@ require (
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/text v0.23.0 // indirect
)

View File

@@ -11,10 +11,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802 h1:oaeYQ7bMtPL9gG2yZzxu0VXWLx5/C1RctyBbcpwG49I=
github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa h1:GnYXjsG9/nRJ4+GQeJBKS8/a28N8yAt1pUmGZcxnHd4=
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
@@ -51,8 +49,9 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -88,8 +87,9 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=

View File

@@ -15,7 +15,6 @@
package main
import (
"bytes"
_ "embed"
"errors"
"fmt"
@@ -27,11 +26,10 @@ import (
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/arxiv"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/bing"
@@ -86,16 +84,16 @@ func main() {}
func init() {
wrapper.SetCtx(
"ai-search",
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
wrapper.ProcessStreamingResponseBodyBy(onStreamingResponseBody),
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
wrapper.ParseConfig(parseConfig),
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
wrapper.ProcessRequestBody(onHttpRequestBody),
wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
wrapper.ProcessStreamingResponseBody(onStreamingResponseBody),
wrapper.ProcessResponseBody(onHttpResponseBody),
)
}
func parseConfig(json gjson.Result, config *Config, log log.Log) error {
func parseConfig(json gjson.Result, config *Config) error {
config.defaultEnable = true // Default to true if not specified
if json.Get("defaultEnable").Exists() {
config.defaultEnable = json.Get("defaultEnable").Bool()
@@ -279,7 +277,7 @@ func parseConfig(json gjson.Result, config *Config, log log.Log) error {
return nil
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log log.Log) types.Action {
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action {
contentType, _ := proxywasm.GetHttpRequestHeader("content-type")
// The request does not have a body.
if contentType == "" {
@@ -296,7 +294,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log log.Log) t
return types.ActionContinue
}
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log log.Log) types.Action {
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
// Check if plugin should be enabled based on config and request
webSearchOptions := gjson.GetBytes(body, "web_search_options")
if !config.defaultEnable {
@@ -437,7 +435,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log
proxywasm.ResumeHttpRequest()
return
}
if types.ActionContinue == executeSearch(ctx, config, queryIndex, body, searchContexts, log) {
if types.ActionContinue == executeSearch(ctx, config, queryIndex, body, searchContexts) {
proxywasm.ResumeHttpRequest()
}
}, searchRewrite.timeoutMillisecond)
@@ -453,10 +451,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log
return executeSearch(ctx, config, queryIndex, body, []engine.SearchContext{{
Querys: []string{query},
Language: config.defaultLanguage,
}}, log)
}})
}
func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body []byte, searchContexts []engine.SearchContext, log log.Log) types.Action {
func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body []byte, searchContexts []engine.SearchContext) types.Action {
searchResultGroups := make([][]engine.SearchResult, len(config.engine))
var finished int
var searching int
@@ -559,7 +557,7 @@ func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body
return types.ActionContinue
}
func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config, log log.Log) types.Action {
func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config) types.Action {
if !config.needReference {
ctx.DontReadResponseBody()
return types.ActionContinue
@@ -576,7 +574,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config, log log.Log)
return types.ActionContinue
}
func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log log.Log) types.Action {
func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
references := ctx.GetStringContext("References", "")
if references == "" {
return types.ActionContinue
@@ -618,19 +616,13 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log
return types.ActionContinue
}
func unifySSEChunk(data []byte) []byte {
data = bytes.ReplaceAll(data, []byte("\r\n"), []byte("\n"))
data = bytes.ReplaceAll(data, []byte("\r"), []byte("\n"))
return data
}
const (
PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage"
BUFFER_CONTENT_CONTEXT_KEY = "bufferContent"
BUFFER_SIZE = 30
)
func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool, log log.Log) []byte {
func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool) []byte {
if ctx.GetBoolContext("ReferenceAppended", false) {
return chunk
}
@@ -638,7 +630,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
if references == "" {
return chunk
}
chunk = unifySSEChunk(chunk)
chunk = wrapper.UnifySSEChunk(chunk)
var partialMessage []byte
partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY)
log.Debugf("[handleStreamChunk] buffer content: %v", ctx.GetContext(BUFFER_CONTENT_CONTEXT_KEY))
@@ -651,7 +643,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
var newMessages []string
for i, msg := range messages {
if i < len(messages)-1 {
newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), config.referenceLocation == "tail", log)
newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), config.referenceLocation == "tail")
if newMsg != "" {
newMessages = append(newMessages, newMsg)
}
@@ -669,7 +661,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
}
}
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool, log log.Log) string {
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool) string {
log.Debugf("single sse message: %s", sseMessage)
subMessages := strings.Split(sseMessage, "\n")
var message string

View File

@@ -6,7 +6,7 @@ toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa
github.com/tidwall/gjson v1.18.0
)

View File

@@ -8,6 +8,8 @@ github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxX
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802 h1:oaeYQ7bMtPL9gG2yZzxu0VXWLx5/C1RctyBbcpwG49I=
github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa h1:GnYXjsG9/nRJ4+GQeJBKS8/a28N8yAt1pUmGZcxnHd4=
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=

View File

@@ -5,12 +5,14 @@ import (
"encoding/json"
"errors"
"fmt"
"regexp"
"strings"
"time"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/tokenusage"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
@@ -20,12 +22,12 @@ func main() {}
func init() {
wrapper.SetCtx(
"ai-statistics",
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingBody),
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
wrapper.ParseConfig(parseConfig),
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
wrapper.ProcessRequestBody(onHttpRequestBody),
wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
wrapper.ProcessStreamingResponseBody(onHttpStreamingBody),
wrapper.ProcessResponseBody(onHttpResponseBody),
)
}
@@ -41,6 +43,7 @@ const (
ClusterName = "cluster"
APIName = "api"
ConsumerKey = "x-mse-consumer"
RequestPath = "request_path"
// Source Type
FixedValue = "fixed_value"
@@ -51,9 +54,6 @@ const (
ResponseBody = "response_body"
// Inner metric & log attributes
Model = "model"
InputToken = "input_token"
OutputToken = "output_token"
LLMFirstTokenDuration = "llm_first_token_duration"
LLMServiceDuration = "llm_service_duration"
LLMDurationCount = "llm_duration_count"
@@ -146,7 +146,7 @@ func (config *AIStatisticsConfig) incrementCounter(metricName string, inc uint64
counter.Increment(inc)
}
func parseConfig(configJson gjson.Result, config *AIStatisticsConfig, log log.Log) error {
func parseConfig(configJson gjson.Result, config *AIStatisticsConfig) error {
// Parse tracing span attributes setting.
attributeConfigs := configJson.Get("attributes").Array()
config.attributes = make([]Attribute, len(attributeConfigs))
@@ -174,17 +174,20 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig, log log.Lo
return nil
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log log.Log) types.Action {
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig) types.Action {
route, _ := getRouteName()
cluster, _ := getClusterName()
api, api_error := getAPIName()
if api_error == nil {
api, apiError := getAPIName()
if apiError == nil {
route = api
}
ctx.SetContext(RouteName, route)
ctx.SetContext(ClusterName, cluster)
ctx.SetUserAttribute(APIName, api)
ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli())
if requestPath, _ := proxywasm.GetHttpRequestHeader(":path"); requestPath != "" {
ctx.SetContext(RequestPath, requestPath)
}
if consumer, _ := proxywasm.GetHttpRequestHeader(ConsumerKey); consumer != "" {
ctx.SetContext(ConsumerKey, consumer)
}
@@ -195,56 +198,71 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, lo
}
// Set user defined log & span attributes which type is fixed_value
setAttributeBySource(ctx, config, FixedValue, nil, log)
setAttributeBySource(ctx, config, FixedValue, nil)
// Set user defined log & span attributes which type is request_header
setAttributeBySource(ctx, config, RequestHeader, nil, log)
setAttributeBySource(ctx, config, RequestHeader, nil)
// Set span attributes for ARMS.
setSpanAttribute(ArmsSpanKind, "LLM", log)
setSpanAttribute(ArmsSpanKind, "LLM")
return types.ActionContinue
}
func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log log.Log) types.Action {
func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte) types.Action {
// Set user defined log & span attributes.
setAttributeBySource(ctx, config, RequestBody, body, log)
setAttributeBySource(ctx, config, RequestBody, body)
// Set span attributes for ARMS.
requestModel := gjson.GetBytes(body, "model").String()
if requestModel == "" {
requestModel = "UNKNOWN"
requestModel := "UNKNOWN"
if model := gjson.GetBytes(body, "model"); model.Exists() {
requestModel = model.String()
} else {
requestPath := ctx.GetStringContext(RequestPath, "")
if strings.Contains(requestPath, "generateContent") || strings.Contains(requestPath, "streamGenerateContent") { // Google Gemini GenerateContent
reg := regexp.MustCompile(`^.*/(?P<api_version>[^/]+)/models/(?P<model>[^:]+):\w+Content$`)
matches := reg.FindStringSubmatch(requestPath)
if len(matches) == 3 {
requestModel = matches[2]
}
}
}
setSpanAttribute(ArmsRequestModel, requestModel, log)
setSpanAttribute(ArmsRequestModel, requestModel)
// Set the number of conversation rounds
if gjson.GetBytes(body, "messages").Exists() {
userPromptCount := 0
for _, msg := range gjson.GetBytes(body, "messages").Array() {
userPromptCount := 0
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
for _, msg := range messages.Array() {
if msg.Get("role").String() == "user" {
userPromptCount += 1
}
}
ctx.SetUserAttribute(ChatRound, userPromptCount)
} else if contents := gjson.GetBytes(body, "contents"); contents.Exists() && contents.IsArray() { // Google Gemini GenerateContent
for _, content := range contents.Array() {
if !content.Get("role").Exists() || content.Get("role").String() == "user" {
userPromptCount += 1
}
}
}
ctx.SetUserAttribute(ChatRound, userPromptCount)
// Write log
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
return types.ActionContinue
}
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log log.Log) types.Action {
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig) types.Action {
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
if !strings.Contains(contentType, "text/event-stream") {
ctx.BufferResponseBody()
}
// Set user defined log & span attributes.
setAttributeBySource(ctx, config, ResponseHeader, nil, log)
setAttributeBySource(ctx, config, ResponseHeader, nil)
return types.ActionContinue
}
func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, data []byte, endOfStream bool, log log.Log) []byte {
func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, data []byte, endOfStream bool) []byte {
// Buffer stream body for record log & span attributes
if config.shouldBufferStreamingBody {
var streamingBodyBuffer []byte
streamingBodyBuffer, ok := ctx.GetContext(CtxStreamingBodyBuffer).([]byte)
if !ok {
streamingBodyBuffer = data
@@ -255,9 +273,13 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
}
ctx.SetUserAttribute(ResponseType, "stream")
chatID := gjson.GetBytes(data, "id").String()
if chatID != "" {
ctx.SetUserAttribute(ChatID, chatID)
if chatID := wrapper.GetValueFromBody(data, []string{
"id",
"response.id",
"responseId", // Gemini generateContent
"message.id", // anthropic messages
}); chatID != nil {
ctx.SetUserAttribute(ChatID, chatID.String())
}
// Get requestStartTime from http context
@@ -276,15 +298,12 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
// Set information about this request
if !config.disableOpenaiUsage {
if model, inputToken, outputToken, ok := getUsage(data); ok {
ctx.SetUserAttribute(Model, model)
ctx.SetUserAttribute(InputToken, inputToken)
ctx.SetUserAttribute(OutputToken, outputToken)
if usage := tokenusage.GetTokenUsage(ctx, data); usage.TotalToken > 0 {
// Set span attributes for ARMS.
setSpanAttribute(ArmsModelName, model, log)
setSpanAttribute(ArmsInputToken, inputToken, log)
setSpanAttribute(ArmsOutputToken, outputToken, log)
setSpanAttribute(ArmsTotalToken, inputToken+outputToken, log)
setSpanAttribute(ArmsTotalToken, usage.TotalToken)
setSpanAttribute(ArmsModelName, usage.Model)
setSpanAttribute(ArmsInputToken, usage.InputToken)
setSpanAttribute(ArmsOutputToken, usage.OutputToken)
}
}
// If the end of the stream is reached, record metrics/logs/spans.
@@ -298,19 +317,19 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
if !ok {
return data
}
setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer, log)
setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer)
}
// Write log
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
// Write metrics
writeMetric(ctx, config, log)
writeMetric(ctx, config)
}
return data
}
func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log log.Log) types.Action {
func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte) types.Action {
// Get requestStartTime from http context
requestStartTime, _ := ctx.GetContext(StatisticsRequestStartTime).(int64)
@@ -318,74 +337,41 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body
ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime)
ctx.SetUserAttribute(ResponseType, "normal")
chatID := gjson.GetBytes(body, "id").String()
if chatID != "" {
ctx.SetUserAttribute(ChatID, chatID)
if chatID := wrapper.GetValueFromBody(body, []string{
"id",
"response.id",
"responseId", // Gemini generateContent
"message.id", // anthropic messages
}); chatID != nil {
ctx.SetUserAttribute(ChatID, chatID.String())
}
// Set information about this request
if !config.disableOpenaiUsage {
if model, inputToken, outputToken, ok := getUsage(body); ok {
ctx.SetUserAttribute(Model, model)
ctx.SetUserAttribute(InputToken, inputToken)
ctx.SetUserAttribute(OutputToken, outputToken)
if usage := tokenusage.GetTokenUsage(ctx, body); usage.TotalToken > 0 {
// Set span attributes for ARMS.
setSpanAttribute(ArmsModelName, model, log)
setSpanAttribute(ArmsInputToken, inputToken, log)
setSpanAttribute(ArmsOutputToken, outputToken, log)
setSpanAttribute(ArmsTotalToken, inputToken+outputToken, log)
setSpanAttribute(ArmsModelName, usage.Model)
setSpanAttribute(ArmsInputToken, usage.InputToken)
setSpanAttribute(ArmsOutputToken, usage.OutputToken)
setSpanAttribute(ArmsTotalToken, usage.TotalToken)
}
}
// Set user defined log & span attributes.
setAttributeBySource(ctx, config, ResponseBody, body, log)
setAttributeBySource(ctx, config, ResponseBody, body)
// Write log
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
// Write metrics
writeMetric(ctx, config, log)
writeMetric(ctx, config)
return types.ActionContinue
}
func unifySSEChunk(data []byte) []byte {
data = bytes.ReplaceAll(data, []byte("\r\n"), []byte("\n"))
data = bytes.ReplaceAll(data, []byte("\r"), []byte("\n"))
return data
}
func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsage int64, ok bool) {
chunks := bytes.Split(bytes.TrimSpace(unifySSEChunk(data)), []byte("\n\n"))
for _, chunk := range chunks {
// the feature strings are used to identify the usage data, like:
// {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}}
if !bytes.Contains(chunk, []byte("prompt_tokens")) {
continue
}
if !bytes.Contains(chunk, []byte("completion_tokens")) {
continue
}
modelObj := gjson.GetBytes(chunk, "model")
if modelObj.Exists() {
model = modelObj.String()
} else {
model = "unknown"
}
inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens")
outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens")
if inputTokenObj.Exists() && outputTokenObj.Exists() {
inputTokenUsage = inputTokenObj.Int()
outputTokenUsage = outputTokenObj.Int()
ok = true
return
}
}
return
}
// fetches the tracing span value from the specified source.
func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log log.Log) {
func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte) {
for _, attribute := range config.attributes {
var key string
var value interface{}
@@ -401,7 +387,7 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so
case ResponseHeader:
value, _ = proxywasm.GetHttpResponseHeader(attribute.Value)
case ResponseStreamingBody:
value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log)
value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule)
case ResponseBody:
value = gjson.GetBytes(body, attribute.Value).Value()
default:
@@ -421,21 +407,21 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so
}
}
// for metrics
if key == Model || key == InputToken || key == OutputToken {
if key == tokenusage.CtxKeyModel || key == tokenusage.CtxKeyInputToken || key == tokenusage.CtxKeyOutputToken || key == tokenusage.CtxKeyTotalToken {
ctx.SetContext(key, value)
}
if attribute.ApplyToSpan {
if attribute.TraceSpanKey != "" {
key = attribute.TraceSpanKey
}
setSpanAttribute(key, value, log)
setSpanAttribute(key, value)
}
}
}
}
func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log log.Log) interface{} {
chunks := bytes.Split(bytes.TrimSpace(unifySSEChunk(data)), []byte("\n\n"))
func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string) interface{} {
chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n"))
var value interface{}
if rule == RuleFirst {
for _, chunk := range chunks {
@@ -469,7 +455,7 @@ func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, l
}
// Set the tracing span with value.
func setSpanAttribute(key string, value interface{}, log log.Log) {
func setSpanAttribute(key string, value interface{}) {
if value != "" {
traceSpanTag := wrapper.TraceSpanTagPrefix + key
if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(fmt.Sprint(value))); e != nil {
@@ -480,11 +466,10 @@ func setSpanAttribute(key string, value interface{}, log log.Log) {
}
}
func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log log.Log) {
func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig) {
// Generate usage metrics
var ok bool
var route, cluster, model string
var inputToken, outputToken uint64
consumer := ctx.GetStringContext(ConsumerKey, "none")
route, ok = ctx.GetContext(RouteName).(string)
if !ok {
@@ -501,31 +486,30 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log log.Log
return
}
if ctx.GetUserAttribute(Model) == nil || ctx.GetUserAttribute(InputToken) == nil || ctx.GetUserAttribute(OutputToken) == nil {
if ctx.GetUserAttribute(tokenusage.CtxKeyModel) == nil || ctx.GetUserAttribute(tokenusage.CtxKeyInputToken) == nil || ctx.GetUserAttribute(tokenusage.CtxKeyOutputToken) == nil || ctx.GetUserAttribute(tokenusage.CtxKeyTotalToken) == nil {
log.Warnf("get usage information failed, skip metric record")
return
}
model, ok = ctx.GetUserAttribute(Model).(string)
model, ok = ctx.GetUserAttribute(tokenusage.CtxKeyModel).(string)
if !ok {
log.Warnf("Model typd assert failed, skip metric record")
return
}
inputToken, ok = convertToUInt(ctx.GetUserAttribute(InputToken))
if !ok {
if inputToken, ok := convertToUInt(ctx.GetUserAttribute(tokenusage.CtxKeyInputToken)); ok {
config.incrementCounter(generateMetricName(route, cluster, model, consumer, tokenusage.CtxKeyInputToken), inputToken)
} else {
log.Warnf("InputToken typd assert failed, skip metric record")
return
}
outputToken, ok = convertToUInt(ctx.GetUserAttribute(OutputToken))
if !ok {
if outputToken, ok := convertToUInt(ctx.GetUserAttribute(tokenusage.CtxKeyOutputToken)); ok {
config.incrementCounter(generateMetricName(route, cluster, model, consumer, tokenusage.CtxKeyOutputToken), outputToken)
} else {
log.Warnf("OutputToken typd assert failed, skip metric record")
return
}
if inputToken == 0 || outputToken == 0 {
log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record")
return
if totalToken, ok := convertToUInt(ctx.GetUserAttribute(tokenusage.CtxKeyTotalToken)); ok {
config.incrementCounter(generateMetricName(route, cluster, model, consumer, tokenusage.CtxKeyTotalToken), totalToken)
} else {
log.Warnf("TotalToken typd assert failed, skip metric record")
}
config.incrementCounter(generateMetricName(route, cluster, model, consumer, InputToken), inputToken)
config.incrementCounter(generateMetricName(route, cluster, model, consumer, OutputToken), outputToken)
// Generate duration metrics
var llmFirstTokenDuration, llmServiceDuration uint64

View File

@@ -3,9 +3,8 @@ package main
import (
"errors"
"fmt"
"strings"
re "regexp"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/log"
@@ -86,7 +85,7 @@ type LimitConfigItem struct {
timeWindow int64 // 时间窗口大小
}
func initRedisClusterClient(json gjson.Result, config *ClusterKeyRateLimitConfig, log log.Log) error {
func initRedisClusterClient(json gjson.Result, config *ClusterKeyRateLimitConfig) error {
redisConfig := json.Get("redis")
if !redisConfig.Exists() {
return errors.New("missing redis in config")

View File

@@ -6,16 +6,14 @@ toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/resp v0.1.1
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837
)
require github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 // indirect
require (
github.com/google/uuid v1.6.0 // indirect
github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 // indirect
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1
)

View File

@@ -1,17 +1,12 @@
github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 h1:Wi5Tgn8K+jDcBYL+dIMS1+qXYH2r7tpRAyBgqrWfQtw=
github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56/go.mod h1:8BhOLuqtSuT5NZtZMwfvEibi09RO3u79uqfHZzfDTR4=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa h1:GnYXjsG9/nRJ4+GQeJBKS8/a28N8yAt1pUmGZcxnHd4=
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
@@ -24,4 +19,3 @@ github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYg
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837 h1:DjHnADS2r2zynZ3WkCFAQ+PNYngMSNceRROi0pO6c3M=
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837/go.mod h1:9vp0bxqozzQwcjBwenEXfKVq8+mYbwHkQ1NF9Ap0DMw=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -15,7 +15,6 @@
package main
import (
"bytes"
"fmt"
"net"
"net/url"
@@ -25,6 +24,7 @@ import (
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/tokenusage"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/resp"
@@ -35,9 +35,9 @@ func main() {}
func init() {
wrapper.SetCtx(
"ai-token-ratelimit",
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingBody),
wrapper.ParseConfig(parseConfig),
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
wrapper.ProcessStreamingResponseBody(onHttpStreamingBody),
)
}
@@ -84,8 +84,8 @@ type LimitRedisContext struct {
window int64
}
func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log log.Log) error {
err := initRedisClusterClient(json, config, log)
func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig) error {
err := initRedisClusterClient(json, config)
if err != nil {
return err
}
@@ -98,9 +98,9 @@ func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log log.L
return nil
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, log log.Log) types.Action {
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig) types.Action {
// 判断是否命中限流规则
val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems, log)
val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems)
if ruleItem == nil || configItem == nil {
return types.ActionContinue
}
@@ -146,18 +146,17 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon
return types.ActionPause
}
func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool, log log.Log) []byte {
var inputToken, outputToken int64
if inputToken, outputToken, ok := getUsage(data); ok {
ctx.SetContext("input_token", inputToken)
ctx.SetContext("output_token", outputToken)
func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool) []byte {
if usage := tokenusage.GetTokenUsage(ctx, data); usage.TotalToken > 0 {
ctx.SetContext(tokenusage.CtxKeyInputToken, usage.InputToken)
ctx.SetContext(tokenusage.CtxKeyOutputToken, usage.OutputToken)
}
if endOfStream {
if ctx.GetContext("input_token") == nil || ctx.GetContext("output_token") == nil {
if ctx.GetContext(tokenusage.CtxKeyInputToken) == nil || ctx.GetContext(tokenusage.CtxKeyOutputToken) == nil {
return data
}
inputToken = ctx.GetContext("input_token").(int64)
outputToken = ctx.GetContext("output_token").(int64)
inputToken := ctx.GetContext(tokenusage.CtxKeyInputToken).(int64)
outputToken := ctx.GetContext(tokenusage.CtxKeyOutputToken).(int64)
limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext)
if !ok {
return data
@@ -172,29 +171,9 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConf
return data
}
func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bool) {
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
for _, chunk := range chunks {
// the feature strings are used to identify the usage data, like:
// {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}}
if !bytes.Contains(chunk, []byte("prompt_tokens")) || !bytes.Contains(chunk, []byte("completion_tokens")) {
continue
}
inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens")
outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens")
if inputTokenObj.Exists() && outputTokenObj.Exists() {
inputTokenUsage = inputTokenObj.Int()
outputTokenUsage = outputTokenObj.Int()
ok = true
return
}
}
return
}
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log log.Log) (string, *LimitRuleItem, *LimitConfigItem) {
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem) (string, *LimitRuleItem, *LimitConfigItem) {
for _, rule := range ruleItems {
val, ruleItem, configItem := hitRateRuleItem(ctx, rule, log)
val, ruleItem, configItem := hitRateRuleItem(ctx, rule)
if ruleItem != nil && configItem != nil {
return val, ruleItem, configItem
}
@@ -202,46 +181,46 @@ func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRule
return "", nil, nil
}
func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log log.Log) (string, *LimitRuleItem, *LimitConfigItem) {
func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem) (string, *LimitRuleItem, *LimitConfigItem) {
switch rule.limitType {
// 根据HTTP请求头限流
case limitByHeaderType, limitByPerHeaderType:
val, err := proxywasm.GetHttpRequestHeader(rule.key)
if err != nil {
return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", rule.key, err)
return logDebugAndReturnEmpty("failed to get request header %s: %v", rule.key, err)
}
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
// 根据HTTP请求参数限流
case limitByParamType, limitByPerParamType:
parse, err := url.Parse(ctx.Path())
if err != nil {
return logDebugAndReturnEmpty(log, "failed to parse request path: %v", err)
return logDebugAndReturnEmpty("failed to parse request path: %v", err)
}
query, err := url.ParseQuery(parse.RawQuery)
if err != nil {
return logDebugAndReturnEmpty(log, "failed to parse query params: %v", err)
return logDebugAndReturnEmpty("failed to parse query params: %v", err)
}
val, ok := query[rule.key]
if !ok {
return logDebugAndReturnEmpty(log, "request param %s is empty", rule.key)
return logDebugAndReturnEmpty("request param %s is empty", rule.key)
}
return val[0], &rule, findMatchingItem(rule.limitType, rule.configItems, val[0])
// 根据consumer限流
case limitByConsumerType, limitByPerConsumerType:
val, err := proxywasm.GetHttpRequestHeader(ConsumerHeader)
if err != nil {
return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", ConsumerHeader, err)
return logDebugAndReturnEmpty("failed to get request header %s: %v", ConsumerHeader, err)
}
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
// 根据cookie中key值限流
case limitByCookieType, limitByPerCookieType:
cookie, err := proxywasm.GetHttpRequestHeader(CookieHeader)
if err != nil {
return logDebugAndReturnEmpty(log, "failed to get request cookie : %v", err)
return logDebugAndReturnEmpty("failed to get request cookie : %v", err)
}
val := extractCookieValueByKey(cookie, rule.key)
if val == "" {
return logDebugAndReturnEmpty(log, "cookie key '%s' extracted from cookie '%s' is empty.", rule.key, cookie)
return logDebugAndReturnEmpty("cookie key '%s' extracted from cookie '%s' is empty.", rule.key, cookie)
}
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
// 根据客户端IP限流
@@ -261,7 +240,7 @@ func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log log.Log) (
return "", nil, nil
}
func logDebugAndReturnEmpty(log log.Log, errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) {
func logDebugAndReturnEmpty(errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) {
log.Debugf(errMsg, args...)
return "", nil, nil
}