feat: Support statistics images/audio/responses API Token usage (#2542)

Signed-off-by: Xijun Dai <daixijun1990@gmail.com>
This commit is contained in:
Xijun Dai
2025-07-16 10:34:09 +08:00
committed by GitHub
parent ce271849de
commit 8346b4a4a2
13 changed files with 193 additions and 262 deletions

View File

@@ -6,7 +6,7 @@ toolchain go1.24.4
require ( require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.0 github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
github.com/tidwall/resp v0.1.1 github.com/tidwall/resp v0.1.1
) )

View File

@@ -6,6 +6,8 @@ github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw= github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M= github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa h1:GnYXjsG9/nRJ4+GQeJBKS8/a28N8yAt1pUmGZcxnHd4=
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=

View File

@@ -1,7 +1,6 @@
package main package main
import ( import (
"bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -10,13 +9,15 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-quota/util"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/tokenusage"
"github.com/higress-group/wasm-go/pkg/wrapper" "github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-quota/util"
) )
const ( const (
@@ -45,10 +46,10 @@ func main() {}
func init() { func init() {
wrapper.SetCtx( wrapper.SetCtx(
pluginName, pluginName,
wrapper.ParseConfigBy(parseConfig), wrapper.ParseConfig(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody), wrapper.ProcessRequestBody(onHttpRequestBody),
wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingResponseBody), wrapper.ProcessStreamingResponseBody(onHttpStreamingResponseBody),
) )
} }
@@ -75,7 +76,7 @@ type RedisInfo struct {
Database int `required:"false" yaml:"database" json:"database"` Database int `required:"false" yaml:"database" json:"database"`
} }
func parseConfig(json gjson.Result, config *QuotaConfig, log log.Log) error { func parseConfig(json gjson.Result, config *QuotaConfig) error {
log.Debugf("parse config()") log.Debugf("parse config()")
// admin // admin
config.AdminPath = json.Get("admin_path").String() config.AdminPath = json.Get("admin_path").String()
@@ -129,7 +130,7 @@ func parseConfig(json gjson.Result, config *QuotaConfig, log log.Log) error {
return config.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(database)) return config.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(database))
} }
func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log log.Log) types.Action { func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig) types.Action {
log.Debugf("onHttpRequestHeaders()") log.Debugf("onHttpRequestHeaders()")
// get tokens // get tokens
consumer, err := proxywasm.GetHttpRequestHeader("x-mse-consumer") consumer, err := proxywasm.GetHttpRequestHeader("x-mse-consumer")
@@ -142,7 +143,7 @@ func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log l
rawPath := context.Path() rawPath := context.Path()
path, _ := url.Parse(rawPath) path, _ := url.Parse(rawPath)
chatMode, adminMode := getOperationMode(path.Path, config.AdminPath, log) chatMode, adminMode := getOperationMode(path.Path, config.AdminPath)
context.SetContext("chatMode", chatMode) context.SetContext("chatMode", chatMode)
context.SetContext("adminMode", adminMode) context.SetContext("adminMode", adminMode)
context.SetContext("consumer", consumer) context.SetContext("consumer", consumer)
@@ -153,7 +154,7 @@ func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log l
if chatMode == ChatModeAdmin { if chatMode == ChatModeAdmin {
// query quota // query quota
if adminMode == AdminModeQuery { if adminMode == AdminModeQuery {
return queryQuota(context, config, consumer, path, log) return queryQuota(context, config, consumer, path)
} }
if adminMode == AdminModeRefresh || adminMode == AdminModeDelta { if adminMode == AdminModeRefresh || adminMode == AdminModeDelta {
context.BufferRequestBody() context.BufferRequestBody()
@@ -186,7 +187,7 @@ func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log l
return types.HeaderStopAllIterationAndWatermark return types.HeaderStopAllIterationAndWatermark
} }
func onHttpRequestBody(ctx wrapper.HttpContext, config QuotaConfig, body []byte, log log.Log) types.Action { func onHttpRequestBody(ctx wrapper.HttpContext, config QuotaConfig, body []byte) types.Action {
log.Debugf("onHttpRequestBody()") log.Debugf("onHttpRequestBody()")
chatMode, ok := ctx.GetContext("chatMode").(ChatMode) chatMode, ok := ctx.GetContext("chatMode").(ChatMode)
if !ok { if !ok {
@@ -205,16 +206,16 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config QuotaConfig, body []byte,
} }
if adminMode == AdminModeRefresh { if adminMode == AdminModeRefresh {
return refreshQuota(ctx, config, adminConsumer, string(body), log) return refreshQuota(ctx, config, adminConsumer, string(body))
} }
if adminMode == AdminModeDelta { if adminMode == AdminModeDelta {
return deltaQuota(ctx, config, adminConsumer, string(body), log) return deltaQuota(ctx, config, adminConsumer, string(body))
} }
return types.ActionContinue return types.ActionContinue
} }
func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, data []byte, endOfStream bool, log log.Log) []byte { func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, data []byte, endOfStream bool) []byte {
chatMode, ok := ctx.GetContext("chatMode").(ChatMode) chatMode, ok := ctx.GetContext("chatMode").(ChatMode)
if !ok { if !ok {
return data return data
@@ -222,11 +223,9 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, da
if chatMode == ChatModeNone || chatMode == ChatModeAdmin { if chatMode == ChatModeNone || chatMode == ChatModeAdmin {
return data return data
} }
var inputToken, outputToken int64 if usage := tokenusage.GetTokenUsage(ctx, data); usage.TotalToken > 0 {
var consumer string ctx.SetContext(tokenusage.CtxKeyInputToken, usage.InputToken)
if inputToken, outputToken, ok := getUsage(data); ok { ctx.SetContext(tokenusage.CtxKeyOutputToken, usage.OutputToken)
ctx.SetContext("input_token", inputToken)
ctx.SetContext("output_token", outputToken)
} }
// chat completion mode // chat completion mode
@@ -234,39 +233,19 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, da
return data return data
} }
if ctx.GetContext("input_token") == nil || ctx.GetContext("output_token") == nil || ctx.GetContext("consumer") == nil { if ctx.GetContext(tokenusage.CtxKeyInputToken) == nil || ctx.GetContext(tokenusage.CtxKeyOutputToken) == nil || ctx.GetContext("consumer") == nil {
return data return data
} }
inputToken = ctx.GetContext("input_token").(int64) inputToken := ctx.GetContext(tokenusage.CtxKeyInputToken).(int64)
outputToken = ctx.GetContext("output_token").(int64) outputToken := ctx.GetContext(tokenusage.CtxKeyOutputToken).(int64)
consumer = ctx.GetContext("consumer").(string) consumer := ctx.GetContext("consumer").(string)
totalToken := int(inputToken + outputToken) totalToken := int(inputToken + outputToken)
log.Debugf("update consumer:%s, totalToken:%d", consumer, totalToken) log.Debugf("update consumer:%s, totalToken:%d", consumer, totalToken)
config.redisClient.DecrBy(config.RedisKeyPrefix+consumer, totalToken, nil) config.redisClient.DecrBy(config.RedisKeyPrefix+consumer, totalToken, nil)
return data return data
} }
// getUsage scans a response payload for a chunk carrying token usage and
// reports the prompt and completion token counts of the first chunk that
// exposes both.
//
// The payload is trimmed and split on blank lines ("\n\n"). A chunk is only
// parsed when it contains both feature strings "prompt_tokens" and
// "completion_tokens", for example:
//
//	{"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}}
//
// ok is false, and both counts are zero, when no chunk provides both
// usage.prompt_tokens and usage.completion_tokens.
func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bool) {
	promptKey := []byte("prompt_tokens")
	completionKey := []byte("completion_tokens")

	for _, event := range bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) {
		// Cheap substring pre-filter before doing the JSON path lookups.
		if bytes.Contains(event, promptKey) && bytes.Contains(event, completionKey) {
			prompt := gjson.GetBytes(event, "usage.prompt_tokens")
			completion := gjson.GetBytes(event, "usage.completion_tokens")
			if prompt.Exists() && completion.Exists() {
				return prompt.Int(), completion.Int(), true
			}
		}
	}
	return 0, 0, false
}
func deniedNoKeyAuthData() types.Action { func deniedNoKeyAuthData() types.Action {
util.SendResponse(http.StatusUnauthorized, "ai-quota.no_key", "text/plain", "Request denied by ai quota check. No Key Authentication information found.") util.SendResponse(http.StatusUnauthorized, "ai-quota.no_key", "text/plain", "Request denied by ai quota check. No Key Authentication information found.")
return types.ActionContinue return types.ActionContinue
@@ -277,7 +256,7 @@ func deniedUnauthorizedConsumer() types.Action {
return types.ActionContinue return types.ActionContinue
} }
func getOperationMode(path string, adminPath string, log log.Log) (ChatMode, AdminMode) { func getOperationMode(path string, adminPath string) (ChatMode, AdminMode) {
fullAdminPath := "/v1/chat/completions" + adminPath fullAdminPath := "/v1/chat/completions" + adminPath
if strings.HasSuffix(path, fullAdminPath+"/refresh") { if strings.HasSuffix(path, fullAdminPath+"/refresh") {
return ChatModeAdmin, AdminModeRefresh return ChatModeAdmin, AdminModeRefresh
@@ -294,7 +273,7 @@ func getOperationMode(path string, adminPath string, log log.Log) (ChatMode, Adm
return ChatModeNone, AdminModeNone return ChatModeNone, AdminModeNone
} }
func refreshQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string, log log.Log) types.Action { func refreshQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string) types.Action {
// check consumer // check consumer
if adminConsumer != config.AdminConsumer { if adminConsumer != config.AdminConsumer {
util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.") util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.")
@@ -328,7 +307,8 @@ func refreshQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer str
return types.ActionPause return types.ActionPause
} }
func queryQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, url *url.URL, log log.Log) types.Action {
func queryQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, url *url.URL) types.Action {
// check consumer // check consumer
if adminConsumer != config.AdminConsumer { if adminConsumer != config.AdminConsumer {
util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.") util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.")
@@ -371,7 +351,8 @@ func queryQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer strin
} }
return types.ActionPause return types.ActionPause
} }
func deltaQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string, log log.Log) types.Action {
func deltaQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string) types.Action {
// check consumer // check consumer
if adminConsumer != config.AdminConsumer { if adminConsumer != config.AdminConsumer {
util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.") util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.")

View File

@@ -7,7 +7,7 @@ toolchain go1.24.4
require ( require (
github.com/antchfx/xmlquery v1.4.4 github.com/antchfx/xmlquery v1.4.4
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802 github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5 github.com/tidwall/sjson v1.2.5
) )
@@ -19,6 +19,6 @@ require (
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect github.com/tidwall/resp v0.1.1 // indirect
golang.org/x/net v0.33.0 // indirect golang.org/x/net v0.38.0 // indirect
golang.org/x/text v0.21.0 // indirect golang.org/x/text v0.23.0 // indirect
) )

View File

@@ -11,10 +11,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw= github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa h1:GnYXjsG9/nRJ4+GQeJBKS8/a28N8yAt1pUmGZcxnHd4=
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M= github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802 h1:oaeYQ7bMtPL9gG2yZzxu0VXWLx5/C1RctyBbcpwG49I=
github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
@@ -51,8 +49,9 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -88,8 +87,9 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=

View File

@@ -15,7 +15,6 @@
package main package main
import ( import (
"bytes"
_ "embed" _ "embed"
"errors" "errors"
"fmt" "fmt"
@@ -27,11 +26,10 @@ import (
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/arxiv" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/arxiv"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/bing" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/bing"
@@ -86,16 +84,16 @@ func main() {}
func init() { func init() {
wrapper.SetCtx( wrapper.SetCtx(
"ai-search", "ai-search",
wrapper.ParseConfigBy(parseConfig), wrapper.ParseConfig(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody), wrapper.ProcessRequestBody(onHttpRequestBody),
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders), wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
wrapper.ProcessStreamingResponseBodyBy(onStreamingResponseBody), wrapper.ProcessStreamingResponseBody(onStreamingResponseBody),
wrapper.ProcessResponseBodyBy(onHttpResponseBody), wrapper.ProcessResponseBody(onHttpResponseBody),
) )
} }
func parseConfig(json gjson.Result, config *Config, log log.Log) error { func parseConfig(json gjson.Result, config *Config) error {
config.defaultEnable = true // Default to true if not specified config.defaultEnable = true // Default to true if not specified
if json.Get("defaultEnable").Exists() { if json.Get("defaultEnable").Exists() {
config.defaultEnable = json.Get("defaultEnable").Bool() config.defaultEnable = json.Get("defaultEnable").Bool()
@@ -279,7 +277,7 @@ func parseConfig(json gjson.Result, config *Config, log log.Log) error {
return nil return nil
} }
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log log.Log) types.Action { func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action {
contentType, _ := proxywasm.GetHttpRequestHeader("content-type") contentType, _ := proxywasm.GetHttpRequestHeader("content-type")
// The request does not have a body. // The request does not have a body.
if contentType == "" { if contentType == "" {
@@ -296,7 +294,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log log.Log) t
return types.ActionContinue return types.ActionContinue
} }
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log log.Log) types.Action { func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
// Check if plugin should be enabled based on config and request // Check if plugin should be enabled based on config and request
webSearchOptions := gjson.GetBytes(body, "web_search_options") webSearchOptions := gjson.GetBytes(body, "web_search_options")
if !config.defaultEnable { if !config.defaultEnable {
@@ -437,7 +435,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log
proxywasm.ResumeHttpRequest() proxywasm.ResumeHttpRequest()
return return
} }
if types.ActionContinue == executeSearch(ctx, config, queryIndex, body, searchContexts, log) { if types.ActionContinue == executeSearch(ctx, config, queryIndex, body, searchContexts) {
proxywasm.ResumeHttpRequest() proxywasm.ResumeHttpRequest()
} }
}, searchRewrite.timeoutMillisecond) }, searchRewrite.timeoutMillisecond)
@@ -453,10 +451,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log
return executeSearch(ctx, config, queryIndex, body, []engine.SearchContext{{ return executeSearch(ctx, config, queryIndex, body, []engine.SearchContext{{
Querys: []string{query}, Querys: []string{query},
Language: config.defaultLanguage, Language: config.defaultLanguage,
}}, log) }})
} }
func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body []byte, searchContexts []engine.SearchContext, log log.Log) types.Action { func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body []byte, searchContexts []engine.SearchContext) types.Action {
searchResultGroups := make([][]engine.SearchResult, len(config.engine)) searchResultGroups := make([][]engine.SearchResult, len(config.engine))
var finished int var finished int
var searching int var searching int
@@ -559,7 +557,7 @@ func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body
return types.ActionContinue return types.ActionContinue
} }
func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config, log log.Log) types.Action { func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config) types.Action {
if !config.needReference { if !config.needReference {
ctx.DontReadResponseBody() ctx.DontReadResponseBody()
return types.ActionContinue return types.ActionContinue
@@ -576,7 +574,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config, log log.Log)
return types.ActionContinue return types.ActionContinue
} }
func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log log.Log) types.Action { func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
references := ctx.GetStringContext("References", "") references := ctx.GetStringContext("References", "")
if references == "" { if references == "" {
return types.ActionContinue return types.ActionContinue
@@ -618,19 +616,13 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log
return types.ActionContinue return types.ActionContinue
} }
// unifySSEChunk normalizes every line ending in an SSE chunk to "\n".
//
// Order matters: "\r\n" pairs are rewritten first so each collapses to a
// single "\n"; only then are the remaining lone "\r" bytes converted.
func unifySSEChunk(data []byte) []byte {
	lf := []byte("\n")
	for _, ending := range [][]byte{[]byte("\r\n"), []byte("\r")} {
		data = bytes.ReplaceAll(data, ending, lf)
	}
	return data
}
const ( const (
PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage" PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage"
BUFFER_CONTENT_CONTEXT_KEY = "bufferContent" BUFFER_CONTENT_CONTEXT_KEY = "bufferContent"
BUFFER_SIZE = 30 BUFFER_SIZE = 30
) )
func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool, log log.Log) []byte { func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool) []byte {
if ctx.GetBoolContext("ReferenceAppended", false) { if ctx.GetBoolContext("ReferenceAppended", false) {
return chunk return chunk
} }
@@ -638,7 +630,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
if references == "" { if references == "" {
return chunk return chunk
} }
chunk = unifySSEChunk(chunk) chunk = wrapper.UnifySSEChunk(chunk)
var partialMessage []byte var partialMessage []byte
partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY)
log.Debugf("[handleStreamChunk] buffer content: %v", ctx.GetContext(BUFFER_CONTENT_CONTEXT_KEY)) log.Debugf("[handleStreamChunk] buffer content: %v", ctx.GetContext(BUFFER_CONTENT_CONTEXT_KEY))
@@ -651,7 +643,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
var newMessages []string var newMessages []string
for i, msg := range messages { for i, msg := range messages {
if i < len(messages)-1 { if i < len(messages)-1 {
newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), config.referenceLocation == "tail", log) newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), config.referenceLocation == "tail")
if newMsg != "" { if newMsg != "" {
newMessages = append(newMessages, newMsg) newMessages = append(newMessages, newMsg)
} }
@@ -669,7 +661,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
} }
} }
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool, log log.Log) string { func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool) string {
log.Debugf("single sse message: %s", sseMessage) log.Debugf("single sse message: %s", sseMessage)
subMessages := strings.Split(sseMessage, "\n") subMessages := strings.Split(sseMessage, "\n")
var message string var message string

View File

@@ -6,7 +6,7 @@ toolchain go1.24.4
require ( require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802 github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
) )

View File

@@ -8,6 +8,8 @@ github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxX
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M= github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802 h1:oaeYQ7bMtPL9gG2yZzxu0VXWLx5/C1RctyBbcpwG49I= github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802 h1:oaeYQ7bMtPL9gG2yZzxu0VXWLx5/C1RctyBbcpwG49I=
github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M= github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa h1:GnYXjsG9/nRJ4+GQeJBKS8/a28N8yAt1pUmGZcxnHd4=
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=

View File

@@ -5,12 +5,14 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"regexp"
"strings" "strings"
"time" "time"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/tokenusage"
"github.com/higress-group/wasm-go/pkg/wrapper" "github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@@ -20,12 +22,12 @@ func main() {}
func init() { func init() {
wrapper.SetCtx( wrapper.SetCtx(
"ai-statistics", "ai-statistics",
wrapper.ParseConfigBy(parseConfig), wrapper.ParseConfig(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody), wrapper.ProcessRequestBody(onHttpRequestBody),
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders), wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingBody), wrapper.ProcessStreamingResponseBody(onHttpStreamingBody),
wrapper.ProcessResponseBodyBy(onHttpResponseBody), wrapper.ProcessResponseBody(onHttpResponseBody),
) )
} }
@@ -41,6 +43,7 @@ const (
ClusterName = "cluster" ClusterName = "cluster"
APIName = "api" APIName = "api"
ConsumerKey = "x-mse-consumer" ConsumerKey = "x-mse-consumer"
RequestPath = "request_path"
// Source Type // Source Type
FixedValue = "fixed_value" FixedValue = "fixed_value"
@@ -51,9 +54,6 @@ const (
ResponseBody = "response_body" ResponseBody = "response_body"
// Inner metric & log attributes // Inner metric & log attributes
Model = "model"
InputToken = "input_token"
OutputToken = "output_token"
LLMFirstTokenDuration = "llm_first_token_duration" LLMFirstTokenDuration = "llm_first_token_duration"
LLMServiceDuration = "llm_service_duration" LLMServiceDuration = "llm_service_duration"
LLMDurationCount = "llm_duration_count" LLMDurationCount = "llm_duration_count"
@@ -146,7 +146,7 @@ func (config *AIStatisticsConfig) incrementCounter(metricName string, inc uint64
counter.Increment(inc) counter.Increment(inc)
} }
func parseConfig(configJson gjson.Result, config *AIStatisticsConfig, log log.Log) error { func parseConfig(configJson gjson.Result, config *AIStatisticsConfig) error {
// Parse tracing span attributes setting. // Parse tracing span attributes setting.
attributeConfigs := configJson.Get("attributes").Array() attributeConfigs := configJson.Get("attributes").Array()
config.attributes = make([]Attribute, len(attributeConfigs)) config.attributes = make([]Attribute, len(attributeConfigs))
@@ -174,17 +174,20 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig, log log.Lo
return nil return nil
} }
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log log.Log) types.Action { func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig) types.Action {
route, _ := getRouteName() route, _ := getRouteName()
cluster, _ := getClusterName() cluster, _ := getClusterName()
api, api_error := getAPIName() api, apiError := getAPIName()
if api_error == nil { if apiError == nil {
route = api route = api
} }
ctx.SetContext(RouteName, route) ctx.SetContext(RouteName, route)
ctx.SetContext(ClusterName, cluster) ctx.SetContext(ClusterName, cluster)
ctx.SetUserAttribute(APIName, api) ctx.SetUserAttribute(APIName, api)
ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli()) ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli())
if requestPath, _ := proxywasm.GetHttpRequestHeader(":path"); requestPath != "" {
ctx.SetContext(RequestPath, requestPath)
}
if consumer, _ := proxywasm.GetHttpRequestHeader(ConsumerKey); consumer != "" { if consumer, _ := proxywasm.GetHttpRequestHeader(ConsumerKey); consumer != "" {
ctx.SetContext(ConsumerKey, consumer) ctx.SetContext(ConsumerKey, consumer)
} }
@@ -195,56 +198,71 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, lo
} }
// Set user defined log & span attributes which type is fixed_value // Set user defined log & span attributes which type is fixed_value
setAttributeBySource(ctx, config, FixedValue, nil, log) setAttributeBySource(ctx, config, FixedValue, nil)
// Set user defined log & span attributes which type is request_header // Set user defined log & span attributes which type is request_header
setAttributeBySource(ctx, config, RequestHeader, nil, log) setAttributeBySource(ctx, config, RequestHeader, nil)
// Set span attributes for ARMS. // Set span attributes for ARMS.
setSpanAttribute(ArmsSpanKind, "LLM", log) setSpanAttribute(ArmsSpanKind, "LLM")
return types.ActionContinue return types.ActionContinue
} }
func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log log.Log) types.Action { func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte) types.Action {
// Set user defined log & span attributes. // Set user defined log & span attributes.
setAttributeBySource(ctx, config, RequestBody, body, log) setAttributeBySource(ctx, config, RequestBody, body)
// Set span attributes for ARMS. // Set span attributes for ARMS.
requestModel := gjson.GetBytes(body, "model").String() requestModel := "UNKNOWN"
if requestModel == "" { if model := gjson.GetBytes(body, "model"); model.Exists() {
requestModel = "UNKNOWN" requestModel = model.String()
} else {
requestPath := ctx.GetStringContext(RequestPath, "")
if strings.Contains(requestPath, "generateContent") || strings.Contains(requestPath, "streamGenerateContent") { // Google Gemini GenerateContent
reg := regexp.MustCompile(`^.*/(?P<api_version>[^/]+)/models/(?P<model>[^:]+):\w+Content$`)
matches := reg.FindStringSubmatch(requestPath)
if len(matches) == 3 {
requestModel = matches[2]
} }
setSpanAttribute(ArmsRequestModel, requestModel, log) }
}
setSpanAttribute(ArmsRequestModel, requestModel)
// Set the number of conversation rounds // Set the number of conversation rounds
if gjson.GetBytes(body, "messages").Exists() {
userPromptCount := 0 userPromptCount := 0
for _, msg := range gjson.GetBytes(body, "messages").Array() { if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
for _, msg := range messages.Array() {
if msg.Get("role").String() == "user" { if msg.Get("role").String() == "user" {
userPromptCount += 1 userPromptCount += 1
} }
} }
ctx.SetUserAttribute(ChatRound, userPromptCount) } else if contents := gjson.GetBytes(body, "contents"); contents.Exists() && contents.IsArray() { // Google Gemini GenerateContent
for _, content := range contents.Array() {
if !content.Get("role").Exists() || content.Get("role").String() == "user" {
userPromptCount += 1
} }
}
}
ctx.SetUserAttribute(ChatRound, userPromptCount)
// Write log // Write log
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
return types.ActionContinue return types.ActionContinue
} }
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log log.Log) types.Action { func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig) types.Action {
contentType, _ := proxywasm.GetHttpResponseHeader("content-type") contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
if !strings.Contains(contentType, "text/event-stream") { if !strings.Contains(contentType, "text/event-stream") {
ctx.BufferResponseBody() ctx.BufferResponseBody()
} }
// Set user defined log & span attributes. // Set user defined log & span attributes.
setAttributeBySource(ctx, config, ResponseHeader, nil, log) setAttributeBySource(ctx, config, ResponseHeader, nil)
return types.ActionContinue return types.ActionContinue
} }
func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, data []byte, endOfStream bool, log log.Log) []byte { func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, data []byte, endOfStream bool) []byte {
// Buffer stream body for record log & span attributes // Buffer stream body for record log & span attributes
if config.shouldBufferStreamingBody { if config.shouldBufferStreamingBody {
var streamingBodyBuffer []byte
streamingBodyBuffer, ok := ctx.GetContext(CtxStreamingBodyBuffer).([]byte) streamingBodyBuffer, ok := ctx.GetContext(CtxStreamingBodyBuffer).([]byte)
if !ok { if !ok {
streamingBodyBuffer = data streamingBodyBuffer = data
@@ -255,9 +273,13 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
} }
ctx.SetUserAttribute(ResponseType, "stream") ctx.SetUserAttribute(ResponseType, "stream")
chatID := gjson.GetBytes(data, "id").String() if chatID := wrapper.GetValueFromBody(data, []string{
if chatID != "" { "id",
ctx.SetUserAttribute(ChatID, chatID) "response.id",
"responseId", // Gemini generateContent
"message.id", // anthropic messages
}); chatID != nil {
ctx.SetUserAttribute(ChatID, chatID.String())
} }
// Get requestStartTime from http context // Get requestStartTime from http context
@@ -276,15 +298,12 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
// Set information about this request // Set information about this request
if !config.disableOpenaiUsage { if !config.disableOpenaiUsage {
if model, inputToken, outputToken, ok := getUsage(data); ok { if usage := tokenusage.GetTokenUsage(ctx, data); usage.TotalToken > 0 {
ctx.SetUserAttribute(Model, model)
ctx.SetUserAttribute(InputToken, inputToken)
ctx.SetUserAttribute(OutputToken, outputToken)
// Set span attributes for ARMS. // Set span attributes for ARMS.
setSpanAttribute(ArmsModelName, model, log) setSpanAttribute(ArmsTotalToken, usage.TotalToken)
setSpanAttribute(ArmsInputToken, inputToken, log) setSpanAttribute(ArmsModelName, usage.Model)
setSpanAttribute(ArmsOutputToken, outputToken, log) setSpanAttribute(ArmsInputToken, usage.InputToken)
setSpanAttribute(ArmsTotalToken, inputToken+outputToken, log) setSpanAttribute(ArmsOutputToken, usage.OutputToken)
} }
} }
// If the end of the stream is reached, record metrics/logs/spans. // If the end of the stream is reached, record metrics/logs/spans.
@@ -298,19 +317,19 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
if !ok { if !ok {
return data return data
} }
setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer, log) setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer)
} }
// Write log // Write log
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
// Write metrics // Write metrics
writeMetric(ctx, config, log) writeMetric(ctx, config)
} }
return data return data
} }
func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log log.Log) types.Action { func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte) types.Action {
// Get requestStartTime from http context // Get requestStartTime from http context
requestStartTime, _ := ctx.GetContext(StatisticsRequestStartTime).(int64) requestStartTime, _ := ctx.GetContext(StatisticsRequestStartTime).(int64)
@@ -318,74 +337,41 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body
ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime) ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime)
ctx.SetUserAttribute(ResponseType, "normal") ctx.SetUserAttribute(ResponseType, "normal")
chatID := gjson.GetBytes(body, "id").String() if chatID := wrapper.GetValueFromBody(body, []string{
if chatID != "" { "id",
ctx.SetUserAttribute(ChatID, chatID) "response.id",
"responseId", // Gemini generateContent
"message.id", // anthropic messages
}); chatID != nil {
ctx.SetUserAttribute(ChatID, chatID.String())
} }
// Set information about this request // Set information about this request
if !config.disableOpenaiUsage { if !config.disableOpenaiUsage {
if model, inputToken, outputToken, ok := getUsage(body); ok { if usage := tokenusage.GetTokenUsage(ctx, body); usage.TotalToken > 0 {
ctx.SetUserAttribute(Model, model)
ctx.SetUserAttribute(InputToken, inputToken)
ctx.SetUserAttribute(OutputToken, outputToken)
// Set span attributes for ARMS. // Set span attributes for ARMS.
setSpanAttribute(ArmsModelName, model, log) setSpanAttribute(ArmsModelName, usage.Model)
setSpanAttribute(ArmsInputToken, inputToken, log) setSpanAttribute(ArmsInputToken, usage.InputToken)
setSpanAttribute(ArmsOutputToken, outputToken, log) setSpanAttribute(ArmsOutputToken, usage.OutputToken)
setSpanAttribute(ArmsTotalToken, inputToken+outputToken, log) setSpanAttribute(ArmsTotalToken, usage.TotalToken)
} }
} }
// Set user defined log & span attributes. // Set user defined log & span attributes.
setAttributeBySource(ctx, config, ResponseBody, body, log) setAttributeBySource(ctx, config, ResponseBody, body)
// Write log // Write log
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
// Write metrics // Write metrics
writeMetric(ctx, config, log) writeMetric(ctx, config)
return types.ActionContinue return types.ActionContinue
} }
// unifySSEChunk normalizes the line endings found in data so that later
// splitting only has to deal with "\n".
//
// Every "\r\n" pair is rewritten first, then every remaining lone "\r"; doing
// it in that order keeps a CRLF from turning into two newlines. The input
// slice is not modified: the result comes from bytes.ReplaceAll.
func unifySSEChunk(data []byte) []byte {
	var (
		crlf = []byte("\r\n")
		cr   = []byte("\r")
		lf   = []byte("\n")
	)
	withoutCRLF := bytes.ReplaceAll(data, crlf, lf)
	return bytes.ReplaceAll(withoutCRLF, cr, lf)
}
// getUsage looks for token usage information in data.
//
// data is normalized with unifySSEChunk, trimmed, and split on blank lines
// ("\n\n"). A chunk is inspected only if it contains both literal substrings
// "prompt_tokens" and "completion_tokens", e.g.
//
//	{"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}}
//
// For the first inspected chunk in which both "usage.prompt_tokens" and
// "usage.completion_tokens" exist, it returns the chunk's "model" field
// ("unknown" when that field is absent), the two counts, and ok == true.
//
// When no chunk qualifies, ok is false and both counts are zero. model may
// still carry the value taken from the last chunk that passed the substring
// filter, so callers should ignore it unless ok is true.
func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsage int64, ok bool) {
	promptMarker := []byte("prompt_tokens")
	completionMarker := []byte("completion_tokens")

	normalized := bytes.TrimSpace(unifySSEChunk(data))
	for _, event := range bytes.Split(normalized, []byte("\n\n")) {
		// Cheap substring pre-filter before running any JSON path lookups.
		if !(bytes.Contains(event, promptMarker) && bytes.Contains(event, completionMarker)) {
			continue
		}

		model = "unknown"
		if name := gjson.GetBytes(event, "model"); name.Exists() {
			model = name.String()
		}

		prompt := gjson.GetBytes(event, "usage.prompt_tokens")
		completion := gjson.GetBytes(event, "usage.completion_tokens")
		if !prompt.Exists() || !completion.Exists() {
			// The markers appeared somewhere other than under "usage"; keep looking.
			continue
		}
		return model, prompt.Int(), completion.Int(), true
	}
	return model, 0, 0, false
}
// fetches the tracing span value from the specified source. // fetches the tracing span value from the specified source.
func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log log.Log) {
func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte) {
for _, attribute := range config.attributes { for _, attribute := range config.attributes {
var key string var key string
var value interface{} var value interface{}
@@ -401,7 +387,7 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so
case ResponseHeader: case ResponseHeader:
value, _ = proxywasm.GetHttpResponseHeader(attribute.Value) value, _ = proxywasm.GetHttpResponseHeader(attribute.Value)
case ResponseStreamingBody: case ResponseStreamingBody:
value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log) value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule)
case ResponseBody: case ResponseBody:
value = gjson.GetBytes(body, attribute.Value).Value() value = gjson.GetBytes(body, attribute.Value).Value()
default: default:
@@ -421,21 +407,21 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so
} }
} }
// for metrics // for metrics
if key == Model || key == InputToken || key == OutputToken { if key == tokenusage.CtxKeyModel || key == tokenusage.CtxKeyInputToken || key == tokenusage.CtxKeyOutputToken || key == tokenusage.CtxKeyTotalToken {
ctx.SetContext(key, value) ctx.SetContext(key, value)
} }
if attribute.ApplyToSpan { if attribute.ApplyToSpan {
if attribute.TraceSpanKey != "" { if attribute.TraceSpanKey != "" {
key = attribute.TraceSpanKey key = attribute.TraceSpanKey
} }
setSpanAttribute(key, value, log) setSpanAttribute(key, value)
} }
} }
} }
} }
func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log log.Log) interface{} { func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string) interface{} {
chunks := bytes.Split(bytes.TrimSpace(unifySSEChunk(data)), []byte("\n\n")) chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n"))
var value interface{} var value interface{}
if rule == RuleFirst { if rule == RuleFirst {
for _, chunk := range chunks { for _, chunk := range chunks {
@@ -469,7 +455,7 @@ func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, l
} }
// Set the tracing span with value. // Set the tracing span with value.
func setSpanAttribute(key string, value interface{}, log log.Log) { func setSpanAttribute(key string, value interface{}) {
if value != "" { if value != "" {
traceSpanTag := wrapper.TraceSpanTagPrefix + key traceSpanTag := wrapper.TraceSpanTagPrefix + key
if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(fmt.Sprint(value))); e != nil { if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(fmt.Sprint(value))); e != nil {
@@ -480,11 +466,10 @@ func setSpanAttribute(key string, value interface{}, log log.Log) {
} }
} }
func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log log.Log) { func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig) {
// Generate usage metrics // Generate usage metrics
var ok bool var ok bool
var route, cluster, model string var route, cluster, model string
var inputToken, outputToken uint64
consumer := ctx.GetStringContext(ConsumerKey, "none") consumer := ctx.GetStringContext(ConsumerKey, "none")
route, ok = ctx.GetContext(RouteName).(string) route, ok = ctx.GetContext(RouteName).(string)
if !ok { if !ok {
@@ -501,31 +486,30 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log log.Log
return return
} }
if ctx.GetUserAttribute(Model) == nil || ctx.GetUserAttribute(InputToken) == nil || ctx.GetUserAttribute(OutputToken) == nil { if ctx.GetUserAttribute(tokenusage.CtxKeyModel) == nil || ctx.GetUserAttribute(tokenusage.CtxKeyInputToken) == nil || ctx.GetUserAttribute(tokenusage.CtxKeyOutputToken) == nil || ctx.GetUserAttribute(tokenusage.CtxKeyTotalToken) == nil {
log.Warnf("get usage information failed, skip metric record") log.Warnf("get usage information failed, skip metric record")
return return
} }
model, ok = ctx.GetUserAttribute(Model).(string) model, ok = ctx.GetUserAttribute(tokenusage.CtxKeyModel).(string)
if !ok { if !ok {
log.Warnf("Model typd assert failed, skip metric record") log.Warnf("Model typd assert failed, skip metric record")
return return
} }
inputToken, ok = convertToUInt(ctx.GetUserAttribute(InputToken)) if inputToken, ok := convertToUInt(ctx.GetUserAttribute(tokenusage.CtxKeyInputToken)); ok {
if !ok { config.incrementCounter(generateMetricName(route, cluster, model, consumer, tokenusage.CtxKeyInputToken), inputToken)
} else {
log.Warnf("InputToken typd assert failed, skip metric record") log.Warnf("InputToken typd assert failed, skip metric record")
return
} }
outputToken, ok = convertToUInt(ctx.GetUserAttribute(OutputToken)) if outputToken, ok := convertToUInt(ctx.GetUserAttribute(tokenusage.CtxKeyOutputToken)); ok {
if !ok { config.incrementCounter(generateMetricName(route, cluster, model, consumer, tokenusage.CtxKeyOutputToken), outputToken)
} else {
log.Warnf("OutputToken typd assert failed, skip metric record") log.Warnf("OutputToken typd assert failed, skip metric record")
return
} }
if inputToken == 0 || outputToken == 0 { if totalToken, ok := convertToUInt(ctx.GetUserAttribute(tokenusage.CtxKeyTotalToken)); ok {
log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record") config.incrementCounter(generateMetricName(route, cluster, model, consumer, tokenusage.CtxKeyTotalToken), totalToken)
return } else {
log.Warnf("TotalToken typd assert failed, skip metric record")
} }
config.incrementCounter(generateMetricName(route, cluster, model, consumer, InputToken), inputToken)
config.incrementCounter(generateMetricName(route, cluster, model, consumer, OutputToken), outputToken)
// Generate duration metrics // Generate duration metrics
var llmFirstTokenDuration, llmServiceDuration uint64 var llmFirstTokenDuration, llmServiceDuration uint64

View File

@@ -3,9 +3,8 @@ package main
import ( import (
"errors" "errors"
"fmt" "fmt"
"strings"
re "regexp" re "regexp"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/wasm-go/pkg/log"
@@ -86,7 +85,7 @@ type LimitConfigItem struct {
timeWindow int64 // 时间窗口大小 timeWindow int64 // 时间窗口大小
} }
func initRedisClusterClient(json gjson.Result, config *ClusterKeyRateLimitConfig, log log.Log) error { func initRedisClusterClient(json gjson.Result, config *ClusterKeyRateLimitConfig) error {
redisConfig := json.Get("redis") redisConfig := json.Get("redis")
if !redisConfig.Exists() { if !redisConfig.Exists() {
return errors.New("missing redis in config") return errors.New("missing redis in config")

View File

@@ -6,16 +6,14 @@ toolchain go1.24.4
require ( require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.0
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
github.com/tidwall/resp v0.1.1
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837 github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837
) )
require github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 // indirect
require ( require (
github.com/google/uuid v1.6.0 // indirect github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 // indirect
github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1
) )

View File

@@ -1,17 +1,12 @@
github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 h1:Wi5Tgn8K+jDcBYL+dIMS1+qXYH2r7tpRAyBgqrWfQtw= github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 h1:Wi5Tgn8K+jDcBYL+dIMS1+qXYH2r7tpRAyBgqrWfQtw=
github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56/go.mod h1:8BhOLuqtSuT5NZtZMwfvEibi09RO3u79uqfHZzfDTR4= github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56/go.mod h1:8BhOLuqtSuT5NZtZMwfvEibi09RO3u79uqfHZzfDTR4=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw= github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa h1:GnYXjsG9/nRJ4+GQeJBKS8/a28N8yAt1pUmGZcxnHd4=
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M= github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
@@ -24,4 +19,3 @@ github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYg
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837 h1:DjHnADS2r2zynZ3WkCFAQ+PNYngMSNceRROi0pO6c3M= github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837 h1:DjHnADS2r2zynZ3WkCFAQ+PNYngMSNceRROi0pO6c3M=
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837/go.mod h1:9vp0bxqozzQwcjBwenEXfKVq8+mYbwHkQ1NF9Ap0DMw= github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837/go.mod h1:9vp0bxqozzQwcjBwenEXfKVq8+mYbwHkQ1NF9Ap0DMw=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -15,7 +15,6 @@
package main package main
import ( import (
"bytes"
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
@@ -25,6 +24,7 @@ import (
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/tokenusage"
"github.com/higress-group/wasm-go/pkg/wrapper" "github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/resp" "github.com/tidwall/resp"
@@ -35,9 +35,9 @@ func main() {}
func init() { func init() {
wrapper.SetCtx( wrapper.SetCtx(
"ai-token-ratelimit", "ai-token-ratelimit",
wrapper.ParseConfigBy(parseConfig), wrapper.ParseConfig(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingBody), wrapper.ProcessStreamingResponseBody(onHttpStreamingBody),
) )
} }
@@ -84,8 +84,8 @@ type LimitRedisContext struct {
window int64 window int64
} }
func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log log.Log) error { func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig) error {
err := initRedisClusterClient(json, config, log) err := initRedisClusterClient(json, config)
if err != nil { if err != nil {
return err return err
} }
@@ -98,9 +98,9 @@ func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log log.L
return nil return nil
} }
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, log log.Log) types.Action { func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig) types.Action {
// 判断是否命中限流规则 // 判断是否命中限流规则
val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems, log) val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems)
if ruleItem == nil || configItem == nil { if ruleItem == nil || configItem == nil {
return types.ActionContinue return types.ActionContinue
} }
@@ -146,18 +146,17 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon
return types.ActionPause return types.ActionPause
} }
func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool, log log.Log) []byte { func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool) []byte {
var inputToken, outputToken int64 if usage := tokenusage.GetTokenUsage(ctx, data); usage.TotalToken > 0 {
if inputToken, outputToken, ok := getUsage(data); ok { ctx.SetContext(tokenusage.CtxKeyInputToken, usage.InputToken)
ctx.SetContext("input_token", inputToken) ctx.SetContext(tokenusage.CtxKeyOutputToken, usage.OutputToken)
ctx.SetContext("output_token", outputToken)
} }
if endOfStream { if endOfStream {
if ctx.GetContext("input_token") == nil || ctx.GetContext("output_token") == nil { if ctx.GetContext(tokenusage.CtxKeyInputToken) == nil || ctx.GetContext(tokenusage.CtxKeyOutputToken) == nil {
return data return data
} }
inputToken = ctx.GetContext("input_token").(int64) inputToken := ctx.GetContext(tokenusage.CtxKeyInputToken).(int64)
outputToken = ctx.GetContext("output_token").(int64) outputToken := ctx.GetContext(tokenusage.CtxKeyOutputToken).(int64)
limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext) limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext)
if !ok { if !ok {
return data return data
@@ -172,29 +171,9 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConf
return data return data
} }
func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bool) { func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem) (string, *LimitRuleItem, *LimitConfigItem) {
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
for _, chunk := range chunks {
// the feature strings are used to identify the usage data, like:
// {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}}
if !bytes.Contains(chunk, []byte("prompt_tokens")) || !bytes.Contains(chunk, []byte("completion_tokens")) {
continue
}
inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens")
outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens")
if inputTokenObj.Exists() && outputTokenObj.Exists() {
inputTokenUsage = inputTokenObj.Int()
outputTokenUsage = outputTokenObj.Int()
ok = true
return
}
}
return
}
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log log.Log) (string, *LimitRuleItem, *LimitConfigItem) {
for _, rule := range ruleItems { for _, rule := range ruleItems {
val, ruleItem, configItem := hitRateRuleItem(ctx, rule, log) val, ruleItem, configItem := hitRateRuleItem(ctx, rule)
if ruleItem != nil && configItem != nil { if ruleItem != nil && configItem != nil {
return val, ruleItem, configItem return val, ruleItem, configItem
} }
@@ -202,46 +181,46 @@ func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRule
return "", nil, nil return "", nil, nil
} }
func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log log.Log) (string, *LimitRuleItem, *LimitConfigItem) { func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem) (string, *LimitRuleItem, *LimitConfigItem) {
switch rule.limitType { switch rule.limitType {
// 根据HTTP请求头限流 // 根据HTTP请求头限流
case limitByHeaderType, limitByPerHeaderType: case limitByHeaderType, limitByPerHeaderType:
val, err := proxywasm.GetHttpRequestHeader(rule.key) val, err := proxywasm.GetHttpRequestHeader(rule.key)
if err != nil { if err != nil {
return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", rule.key, err) return logDebugAndReturnEmpty("failed to get request header %s: %v", rule.key, err)
} }
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val) return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
// 根据HTTP请求参数限流 // 根据HTTP请求参数限流
case limitByParamType, limitByPerParamType: case limitByParamType, limitByPerParamType:
parse, err := url.Parse(ctx.Path()) parse, err := url.Parse(ctx.Path())
if err != nil { if err != nil {
return logDebugAndReturnEmpty(log, "failed to parse request path: %v", err) return logDebugAndReturnEmpty("failed to parse request path: %v", err)
} }
query, err := url.ParseQuery(parse.RawQuery) query, err := url.ParseQuery(parse.RawQuery)
if err != nil { if err != nil {
return logDebugAndReturnEmpty(log, "failed to parse query params: %v", err) return logDebugAndReturnEmpty("failed to parse query params: %v", err)
} }
val, ok := query[rule.key] val, ok := query[rule.key]
if !ok { if !ok {
return logDebugAndReturnEmpty(log, "request param %s is empty", rule.key) return logDebugAndReturnEmpty("request param %s is empty", rule.key)
} }
return val[0], &rule, findMatchingItem(rule.limitType, rule.configItems, val[0]) return val[0], &rule, findMatchingItem(rule.limitType, rule.configItems, val[0])
// 根据consumer限流 // 根据consumer限流
case limitByConsumerType, limitByPerConsumerType: case limitByConsumerType, limitByPerConsumerType:
val, err := proxywasm.GetHttpRequestHeader(ConsumerHeader) val, err := proxywasm.GetHttpRequestHeader(ConsumerHeader)
if err != nil { if err != nil {
return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", ConsumerHeader, err) return logDebugAndReturnEmpty("failed to get request header %s: %v", ConsumerHeader, err)
} }
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val) return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
// 根据cookie中key值限流 // 根据cookie中key值限流
case limitByCookieType, limitByPerCookieType: case limitByCookieType, limitByPerCookieType:
cookie, err := proxywasm.GetHttpRequestHeader(CookieHeader) cookie, err := proxywasm.GetHttpRequestHeader(CookieHeader)
if err != nil { if err != nil {
return logDebugAndReturnEmpty(log, "failed to get request cookie : %v", err) return logDebugAndReturnEmpty("failed to get request cookie : %v", err)
} }
val := extractCookieValueByKey(cookie, rule.key) val := extractCookieValueByKey(cookie, rule.key)
if val == "" { if val == "" {
return logDebugAndReturnEmpty(log, "cookie key '%s' extracted from cookie '%s' is empty.", rule.key, cookie) return logDebugAndReturnEmpty("cookie key '%s' extracted from cookie '%s' is empty.", rule.key, cookie)
} }
return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val) return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val)
// 根据客户端IP限流 // 根据客户端IP限流
@@ -261,7 +240,7 @@ func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log log.Log) (
return "", nil, nil return "", nil, nil
} }
// logDebugAndReturnEmpty emits a debug-level log entry built from errMsg and
// args (passed straight through to log.Debugf) and then returns the
// "no match" triple: an empty string, a nil *LimitRuleItem and a nil
// *LimitConfigItem. The visible callers use it as a one-line bail-out when a
// header, query parameter, or cookie value cannot be obtained.
//
// NOTE(review): each line below shows two revisions of this function side by
// side. The left one takes the logger as an explicit log.Log parameter; the
// right one drops that parameter, so log.Debugf there presumably resolves to a
// package-level logger — confirm against the file's imports.
func logDebugAndReturnEmpty(log log.Log, errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) { func logDebugAndReturnEmpty(errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) {
log.Debugf(errMsg, args...) log.Debugf(errMsg, args...)
return "", nil, nil return "", nil, nil
} }