diff --git a/plugins/wasm-go/extensions/ai-quota/go.mod b/plugins/wasm-go/extensions/ai-quota/go.mod index bd99507de..f44c8dc4e 100644 --- a/plugins/wasm-go/extensions/ai-quota/go.mod +++ b/plugins/wasm-go/extensions/ai-quota/go.mod @@ -6,7 +6,7 @@ toolchain go1.24.4 require ( github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 - github.com/higress-group/wasm-go v1.0.0 + github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa github.com/tidwall/gjson v1.18.0 github.com/tidwall/resp v0.1.1 ) diff --git a/plugins/wasm-go/extensions/ai-quota/go.sum b/plugins/wasm-go/extensions/ai-quota/go.sum index bc44cf8f0..02ef6bc0d 100644 --- a/plugins/wasm-go/extensions/ai-quota/go.sum +++ b/plugins/wasm-go/extensions/ai-quota/go.sum @@ -6,6 +6,8 @@ github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw= github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M= +github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa h1:GnYXjsG9/nRJ4+GQeJBKS8/a28N8yAt1pUmGZcxnHd4= +github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= diff --git a/plugins/wasm-go/extensions/ai-quota/main.go b/plugins/wasm-go/extensions/ai-quota/main.go index 10021ce13..1a1edcfc6 100644 --- a/plugins/wasm-go/extensions/ai-quota/main.go +++ b/plugins/wasm-go/extensions/ai-quota/main.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "encoding/json" "errors" 
"fmt" @@ -10,13 +9,15 @@ import ( "strconv" "strings" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-quota/util" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/tokenusage" "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" "github.com/tidwall/resp" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-quota/util" ) const ( @@ -45,10 +46,10 @@ func main() {} func init() { wrapper.SetCtx( pluginName, - wrapper.ParseConfigBy(parseConfig), - wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), - wrapper.ProcessRequestBodyBy(onHttpRequestBody), - wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingResponseBody), + wrapper.ParseConfig(parseConfig), + wrapper.ProcessRequestHeaders(onHttpRequestHeaders), + wrapper.ProcessRequestBody(onHttpRequestBody), + wrapper.ProcessStreamingResponseBody(onHttpStreamingResponseBody), ) } @@ -75,7 +76,7 @@ type RedisInfo struct { Database int `required:"false" yaml:"database" json:"database"` } -func parseConfig(json gjson.Result, config *QuotaConfig, log log.Log) error { +func parseConfig(json gjson.Result, config *QuotaConfig) error { log.Debugf("parse config()") // admin config.AdminPath = json.Get("admin_path").String() @@ -129,7 +130,7 @@ func parseConfig(json gjson.Result, config *QuotaConfig, log log.Log) error { return config.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(database)) } -func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log log.Log) types.Action { +func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig) types.Action { log.Debugf("onHttpRequestHeaders()") // get tokens consumer, err := proxywasm.GetHttpRequestHeader("x-mse-consumer") @@ -142,7 +143,7 @@ func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log l rawPath := 
context.Path() path, _ := url.Parse(rawPath) - chatMode, adminMode := getOperationMode(path.Path, config.AdminPath, log) + chatMode, adminMode := getOperationMode(path.Path, config.AdminPath) context.SetContext("chatMode", chatMode) context.SetContext("adminMode", adminMode) context.SetContext("consumer", consumer) @@ -153,7 +154,7 @@ func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log l if chatMode == ChatModeAdmin { // query quota if adminMode == AdminModeQuery { - return queryQuota(context, config, consumer, path, log) + return queryQuota(context, config, consumer, path) } if adminMode == AdminModeRefresh || adminMode == AdminModeDelta { context.BufferRequestBody() @@ -186,7 +187,7 @@ func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log l return types.HeaderStopAllIterationAndWatermark } -func onHttpRequestBody(ctx wrapper.HttpContext, config QuotaConfig, body []byte, log log.Log) types.Action { +func onHttpRequestBody(ctx wrapper.HttpContext, config QuotaConfig, body []byte) types.Action { log.Debugf("onHttpRequestBody()") chatMode, ok := ctx.GetContext("chatMode").(ChatMode) if !ok { @@ -205,16 +206,16 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config QuotaConfig, body []byte, } if adminMode == AdminModeRefresh { - return refreshQuota(ctx, config, adminConsumer, string(body), log) + return refreshQuota(ctx, config, adminConsumer, string(body)) } if adminMode == AdminModeDelta { - return deltaQuota(ctx, config, adminConsumer, string(body), log) + return deltaQuota(ctx, config, adminConsumer, string(body)) } return types.ActionContinue } -func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, data []byte, endOfStream bool, log log.Log) []byte { +func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, data []byte, endOfStream bool) []byte { chatMode, ok := ctx.GetContext("chatMode").(ChatMode) if !ok { return data @@ -222,11 +223,9 @@ func 
onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, da if chatMode == ChatModeNone || chatMode == ChatModeAdmin { return data } - var inputToken, outputToken int64 - var consumer string - if inputToken, outputToken, ok := getUsage(data); ok { - ctx.SetContext("input_token", inputToken) - ctx.SetContext("output_token", outputToken) + if usage := tokenusage.GetTokenUsage(ctx, data); usage.TotalToken > 0 { + ctx.SetContext(tokenusage.CtxKeyInputToken, usage.InputToken) + ctx.SetContext(tokenusage.CtxKeyOutputToken, usage.OutputToken) } // chat completion mode @@ -234,39 +233,19 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, da return data } - if ctx.GetContext("input_token") == nil || ctx.GetContext("output_token") == nil || ctx.GetContext("consumer") == nil { + if ctx.GetContext(tokenusage.CtxKeyInputToken) == nil || ctx.GetContext(tokenusage.CtxKeyOutputToken) == nil || ctx.GetContext("consumer") == nil { return data } - inputToken = ctx.GetContext("input_token").(int64) - outputToken = ctx.GetContext("output_token").(int64) - consumer = ctx.GetContext("consumer").(string) + inputToken := ctx.GetContext(tokenusage.CtxKeyInputToken).(int64) + outputToken := ctx.GetContext(tokenusage.CtxKeyOutputToken).(int64) + consumer := ctx.GetContext("consumer").(string) totalToken := int(inputToken + outputToken) log.Debugf("update consumer:%s, totalToken:%d", consumer, totalToken) config.redisClient.DecrBy(config.RedisKeyPrefix+consumer, totalToken, nil) return data } -func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bool) { - chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) - for _, chunk := range chunks { - // the feature strings are used to identify the usage data, like: - // {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}} - if !bytes.Contains(chunk, []byte("prompt_tokens")) || !bytes.Contains(chunk, []byte("completion_tokens")) { - continue - } - inputTokenObj 
:= gjson.GetBytes(chunk, "usage.prompt_tokens") - outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens") - if inputTokenObj.Exists() && outputTokenObj.Exists() { - inputTokenUsage = inputTokenObj.Int() - outputTokenUsage = outputTokenObj.Int() - ok = true - return - } - } - return -} - func deniedNoKeyAuthData() types.Action { util.SendResponse(http.StatusUnauthorized, "ai-quota.no_key", "text/plain", "Request denied by ai quota check. No Key Authentication information found.") return types.ActionContinue @@ -277,7 +256,7 @@ func deniedUnauthorizedConsumer() types.Action { return types.ActionContinue } -func getOperationMode(path string, adminPath string, log log.Log) (ChatMode, AdminMode) { +func getOperationMode(path string, adminPath string) (ChatMode, AdminMode) { fullAdminPath := "/v1/chat/completions" + adminPath if strings.HasSuffix(path, fullAdminPath+"/refresh") { return ChatModeAdmin, AdminModeRefresh @@ -294,7 +273,7 @@ func getOperationMode(path string, adminPath string, log log.Log) (ChatMode, Adm return ChatModeNone, AdminModeNone } -func refreshQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string, log log.Log) types.Action { +func refreshQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string) types.Action { // check consumer if adminConsumer != config.AdminConsumer { util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. 
Unauthorized admin consumer.") @@ -328,7 +307,8 @@ func refreshQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer str return types.ActionPause } -func queryQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, url *url.URL, log log.Log) types.Action { + +func queryQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, url *url.URL) types.Action { // check consumer if adminConsumer != config.AdminConsumer { util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.") @@ -371,7 +351,8 @@ func queryQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer strin } return types.ActionPause } -func deltaQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string, log log.Log) types.Action { + +func deltaQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string) types.Action { // check consumer if adminConsumer != config.AdminConsumer { util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. 
Unauthorized admin consumer.") diff --git a/plugins/wasm-go/extensions/ai-search/go.mod b/plugins/wasm-go/extensions/ai-search/go.mod index 1eb193c73..00b0f52ca 100644 --- a/plugins/wasm-go/extensions/ai-search/go.mod +++ b/plugins/wasm-go/extensions/ai-search/go.mod @@ -7,7 +7,7 @@ toolchain go1.24.4 require ( github.com/antchfx/xmlquery v1.4.4 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 - github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802 + github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 ) @@ -19,6 +19,6 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/resp v0.1.1 // indirect - golang.org/x/net v0.33.0 // indirect - golang.org/x/text v0.21.0 // indirect + golang.org/x/net v0.38.0 // indirect + golang.org/x/text v0.23.0 // indirect ) diff --git a/plugins/wasm-go/extensions/ai-search/go.sum b/plugins/wasm-go/extensions/ai-search/go.sum index a6a24979c..ff21cab0b 100644 --- a/plugins/wasm-go/extensions/ai-search/go.sum +++ b/plugins/wasm-go/extensions/ai-search/go.sum @@ -11,10 +11,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= -github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw= -github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M= -github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802 h1:oaeYQ7bMtPL9gG2yZzxu0VXWLx5/C1RctyBbcpwG49I= -github.com/higress-group/wasm-go 
v1.0.1-0.20250703020647-acfb94430802/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M= +github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa h1:GnYXjsG9/nRJ4+GQeJBKS8/a28N8yAt1pUmGZcxnHd4= +github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= @@ -51,8 +49,9 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -88,8 +87,9 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod 
h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= diff --git a/plugins/wasm-go/extensions/ai-search/main.go b/plugins/wasm-go/extensions/ai-search/main.go index 99fe8c720..04cf24963 100644 --- a/plugins/wasm-go/extensions/ai-search/main.go +++ b/plugins/wasm-go/extensions/ai-search/main.go @@ -15,7 +15,6 @@ package main import ( - "bytes" _ "embed" "errors" "fmt" @@ -27,11 +26,10 @@ import ( "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" "github.com/tidwall/sjson" - "github.com/higress-group/wasm-go/pkg/wrapper" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/arxiv" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/bing" @@ -86,16 +84,16 @@ func main() {} func init() { wrapper.SetCtx( "ai-search", - wrapper.ParseConfigBy(parseConfig), - wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), - wrapper.ProcessRequestBodyBy(onHttpRequestBody), - wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders), - wrapper.ProcessStreamingResponseBodyBy(onStreamingResponseBody), - wrapper.ProcessResponseBodyBy(onHttpResponseBody), + wrapper.ParseConfig(parseConfig), + wrapper.ProcessRequestHeaders(onHttpRequestHeaders), + wrapper.ProcessRequestBody(onHttpRequestBody), + wrapper.ProcessResponseHeaders(onHttpResponseHeaders), + 
wrapper.ProcessStreamingResponseBody(onStreamingResponseBody), + wrapper.ProcessResponseBody(onHttpResponseBody), ) } -func parseConfig(json gjson.Result, config *Config, log log.Log) error { +func parseConfig(json gjson.Result, config *Config) error { config.defaultEnable = true // Default to true if not specified if json.Get("defaultEnable").Exists() { config.defaultEnable = json.Get("defaultEnable").Bool() @@ -279,7 +277,7 @@ func parseConfig(json gjson.Result, config *Config, log log.Log) error { return nil } -func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log log.Log) types.Action { +func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action { contentType, _ := proxywasm.GetHttpRequestHeader("content-type") // The request does not have a body. if contentType == "" { @@ -296,7 +294,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log log.Log) t return types.ActionContinue } -func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log log.Log) types.Action { +func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action { // Check if plugin should be enabled based on config and request webSearchOptions := gjson.GetBytes(body, "web_search_options") if !config.defaultEnable { @@ -437,7 +435,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log proxywasm.ResumeHttpRequest() return } - if types.ActionContinue == executeSearch(ctx, config, queryIndex, body, searchContexts, log) { + if types.ActionContinue == executeSearch(ctx, config, queryIndex, body, searchContexts) { proxywasm.ResumeHttpRequest() } }, searchRewrite.timeoutMillisecond) @@ -453,10 +451,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log return executeSearch(ctx, config, queryIndex, body, []engine.SearchContext{{ Querys: []string{query}, Language: config.defaultLanguage, - }}, log) + }}) } -func executeSearch(ctx 
wrapper.HttpContext, config Config, queryIndex int, body []byte, searchContexts []engine.SearchContext, log log.Log) types.Action { +func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body []byte, searchContexts []engine.SearchContext) types.Action { searchResultGroups := make([][]engine.SearchResult, len(config.engine)) var finished int var searching int @@ -559,7 +557,7 @@ func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body return types.ActionContinue } -func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config, log log.Log) types.Action { +func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config) types.Action { if !config.needReference { ctx.DontReadResponseBody() return types.ActionContinue @@ -576,7 +574,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config, log log.Log) return types.ActionContinue } -func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log log.Log) types.Action { +func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action { references := ctx.GetStringContext("References", "") if references == "" { return types.ActionContinue @@ -618,19 +616,13 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log return types.ActionContinue } -func unifySSEChunk(data []byte) []byte { - data = bytes.ReplaceAll(data, []byte("\r\n"), []byte("\n")) - data = bytes.ReplaceAll(data, []byte("\r"), []byte("\n")) - return data -} - const ( PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage" BUFFER_CONTENT_CONTEXT_KEY = "bufferContent" BUFFER_SIZE = 30 ) -func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool, log log.Log) []byte { +func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool) []byte { if ctx.GetBoolContext("ReferenceAppended", false) { return chunk } @@ -638,7 +630,7 @@ func onStreamingResponseBody(ctx 
wrapper.HttpContext, config Config, chunk []byt if references == "" { return chunk } - chunk = unifySSEChunk(chunk) + chunk = wrapper.UnifySSEChunk(chunk) var partialMessage []byte partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) log.Debugf("[handleStreamChunk] buffer content: %v", ctx.GetContext(BUFFER_CONTENT_CONTEXT_KEY)) @@ -651,7 +643,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt var newMessages []string for i, msg := range messages { if i < len(messages)-1 { - newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), config.referenceLocation == "tail", log) + newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), config.referenceLocation == "tail") if newMsg != "" { newMessages = append(newMessages, newMsg) } @@ -669,7 +661,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt } } -func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool, log log.Log) string { +func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool) string { log.Debugf("single sse message: %s", sseMessage) subMessages := strings.Split(sseMessage, "\n") var message string diff --git a/plugins/wasm-go/extensions/ai-statistics/go.mod b/plugins/wasm-go/extensions/ai-statistics/go.mod index 1925145c5..845da637b 100644 --- a/plugins/wasm-go/extensions/ai-statistics/go.mod +++ b/plugins/wasm-go/extensions/ai-statistics/go.mod @@ -6,7 +6,7 @@ toolchain go1.24.4 require ( github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 - github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802 + github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa github.com/tidwall/gjson v1.18.0 ) diff --git a/plugins/wasm-go/extensions/ai-statistics/go.sum b/plugins/wasm-go/extensions/ai-statistics/go.sum index f6a3df1ee..c7936609e 
100644 --- a/plugins/wasm-go/extensions/ai-statistics/go.sum +++ b/plugins/wasm-go/extensions/ai-statistics/go.sum @@ -8,6 +8,8 @@ github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxX github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M= github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802 h1:oaeYQ7bMtPL9gG2yZzxu0VXWLx5/C1RctyBbcpwG49I= github.com/higress-group/wasm-go v1.0.1-0.20250703020647-acfb94430802/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M= +github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa h1:GnYXjsG9/nRJ4+GQeJBKS8/a28N8yAt1pUmGZcxnHd4= +github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= diff --git a/plugins/wasm-go/extensions/ai-statistics/main.go b/plugins/wasm-go/extensions/ai-statistics/main.go index c9c131711..db097d00b 100644 --- a/plugins/wasm-go/extensions/ai-statistics/main.go +++ b/plugins/wasm-go/extensions/ai-statistics/main.go @@ -5,12 +5,14 @@ import ( "encoding/json" "errors" "fmt" + "regexp" "strings" "time" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/tokenusage" "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" ) @@ -20,12 +22,12 @@ func main() {} func init() { wrapper.SetCtx( "ai-statistics", - wrapper.ParseConfigBy(parseConfig), - wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), - wrapper.ProcessRequestBodyBy(onHttpRequestBody), - wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders), - 
wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingBody), - wrapper.ProcessResponseBodyBy(onHttpResponseBody), + wrapper.ParseConfig(parseConfig), + wrapper.ProcessRequestHeaders(onHttpRequestHeaders), + wrapper.ProcessRequestBody(onHttpRequestBody), + wrapper.ProcessResponseHeaders(onHttpResponseHeaders), + wrapper.ProcessStreamingResponseBody(onHttpStreamingBody), + wrapper.ProcessResponseBody(onHttpResponseBody), ) } @@ -41,6 +43,7 @@ const ( ClusterName = "cluster" APIName = "api" ConsumerKey = "x-mse-consumer" + RequestPath = "request_path" // Source Type FixedValue = "fixed_value" @@ -51,9 +54,6 @@ const ( ResponseBody = "response_body" // Inner metric & log attributes - Model = "model" - InputToken = "input_token" - OutputToken = "output_token" LLMFirstTokenDuration = "llm_first_token_duration" LLMServiceDuration = "llm_service_duration" LLMDurationCount = "llm_duration_count" @@ -146,7 +146,7 @@ func (config *AIStatisticsConfig) incrementCounter(metricName string, inc uint64 counter.Increment(inc) } -func parseConfig(configJson gjson.Result, config *AIStatisticsConfig, log log.Log) error { +func parseConfig(configJson gjson.Result, config *AIStatisticsConfig) error { // Parse tracing span attributes setting. 
attributeConfigs := configJson.Get("attributes").Array() config.attributes = make([]Attribute, len(attributeConfigs)) @@ -174,17 +174,20 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig, log log.Lo return nil } -func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log log.Log) types.Action { +func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig) types.Action { route, _ := getRouteName() cluster, _ := getClusterName() - api, api_error := getAPIName() - if api_error == nil { + api, apiError := getAPIName() + if apiError == nil { route = api } ctx.SetContext(RouteName, route) ctx.SetContext(ClusterName, cluster) ctx.SetUserAttribute(APIName, api) ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli()) + if requestPath, _ := proxywasm.GetHttpRequestHeader(":path"); requestPath != "" { + ctx.SetContext(RequestPath, requestPath) + } if consumer, _ := proxywasm.GetHttpRequestHeader(ConsumerKey); consumer != "" { ctx.SetContext(ConsumerKey, consumer) } @@ -195,56 +198,71 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, lo } // Set user defined log & span attributes which type is fixed_value - setAttributeBySource(ctx, config, FixedValue, nil, log) + setAttributeBySource(ctx, config, FixedValue, nil) // Set user defined log & span attributes which type is request_header - setAttributeBySource(ctx, config, RequestHeader, nil, log) + setAttributeBySource(ctx, config, RequestHeader, nil) // Set span attributes for ARMS. - setSpanAttribute(ArmsSpanKind, "LLM", log) + setSpanAttribute(ArmsSpanKind, "LLM") return types.ActionContinue } -func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log log.Log) types.Action { +func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte) types.Action { // Set user defined log & span attributes. 
- setAttributeBySource(ctx, config, RequestBody, body, log) + setAttributeBySource(ctx, config, RequestBody, body) // Set span attributes for ARMS. - requestModel := gjson.GetBytes(body, "model").String() - if requestModel == "" { - requestModel = "UNKNOWN" + requestModel := "UNKNOWN" + if model := gjson.GetBytes(body, "model"); model.Exists() { + requestModel = model.String() + } else { + requestPath := ctx.GetStringContext(RequestPath, "") + if strings.Contains(requestPath, "generateContent") || strings.Contains(requestPath, "streamGenerateContent") { // Google Gemini GenerateContent + reg := regexp.MustCompile(`^.*/(?P<api_version>[^/]+)/models/(?P<model>[^:]+):\w+Content$`) + matches := reg.FindStringSubmatch(requestPath) + if len(matches) == 3 { + requestModel = matches[2] + } + } } - setSpanAttribute(ArmsRequestModel, requestModel, log) + setSpanAttribute(ArmsRequestModel, requestModel) // Set the number of conversation rounds - if gjson.GetBytes(body, "messages").Exists() { - userPromptCount := 0 - for _, msg := range gjson.GetBytes(body, "messages").Array() { + + userPromptCount := 0 + if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() { + for _, msg := range messages.Array() { if msg.Get("role").String() == "user" { userPromptCount += 1 } } - ctx.SetUserAttribute(ChatRound, userPromptCount) + } else if contents := gjson.GetBytes(body, "contents"); contents.Exists() && contents.IsArray() { // Google Gemini GenerateContent + for _, content := range contents.Array() { + if !content.Get("role").Exists() || content.Get("role").String() == "user" { + userPromptCount += 1 + } + } } + ctx.SetUserAttribute(ChatRound, userPromptCount) // Write log ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) return types.ActionContinue } -func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log log.Log) types.Action { +func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig) types.Action { contentType, _ := 
proxywasm.GetHttpResponseHeader("content-type") if !strings.Contains(contentType, "text/event-stream") { ctx.BufferResponseBody() } // Set user defined log & span attributes. - setAttributeBySource(ctx, config, ResponseHeader, nil, log) + setAttributeBySource(ctx, config, ResponseHeader, nil) return types.ActionContinue } -func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, data []byte, endOfStream bool, log log.Log) []byte { +func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, data []byte, endOfStream bool) []byte { // Buffer stream body for record log & span attributes if config.shouldBufferStreamingBody { - var streamingBodyBuffer []byte streamingBodyBuffer, ok := ctx.GetContext(CtxStreamingBodyBuffer).([]byte) if !ok { streamingBodyBuffer = data @@ -255,9 +273,13 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat } ctx.SetUserAttribute(ResponseType, "stream") - chatID := gjson.GetBytes(data, "id").String() - if chatID != "" { - ctx.SetUserAttribute(ChatID, chatID) + if chatID := wrapper.GetValueFromBody(data, []string{ + "id", + "response.id", + "responseId", // Gemini generateContent + "message.id", // anthropic messages + }); chatID != nil { + ctx.SetUserAttribute(ChatID, chatID.String()) } // Get requestStartTime from http context @@ -276,15 +298,12 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat // Set information about this request if !config.disableOpenaiUsage { - if model, inputToken, outputToken, ok := getUsage(data); ok { - ctx.SetUserAttribute(Model, model) - ctx.SetUserAttribute(InputToken, inputToken) - ctx.SetUserAttribute(OutputToken, outputToken) + if usage := tokenusage.GetTokenUsage(ctx, data); usage.TotalToken > 0 { // Set span attributes for ARMS. 
- setSpanAttribute(ArmsModelName, model, log) - setSpanAttribute(ArmsInputToken, inputToken, log) - setSpanAttribute(ArmsOutputToken, outputToken, log) - setSpanAttribute(ArmsTotalToken, inputToken+outputToken, log) + setSpanAttribute(ArmsTotalToken, usage.TotalToken) + setSpanAttribute(ArmsModelName, usage.Model) + setSpanAttribute(ArmsInputToken, usage.InputToken) + setSpanAttribute(ArmsOutputToken, usage.OutputToken) } } // If the end of the stream is reached, record metrics/logs/spans. @@ -298,19 +317,19 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat if !ok { return data } - setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer, log) + setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer) } // Write log ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) // Write metrics - writeMetric(ctx, config, log) + writeMetric(ctx, config) } return data } -func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log log.Log) types.Action { +func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte) types.Action { // Get requestStartTime from http context requestStartTime, _ := ctx.GetContext(StatisticsRequestStartTime).(int64) @@ -318,74 +337,41 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime) ctx.SetUserAttribute(ResponseType, "normal") - chatID := gjson.GetBytes(body, "id").String() - if chatID != "" { - ctx.SetUserAttribute(ChatID, chatID) + if chatID := wrapper.GetValueFromBody(body, []string{ + "id", + "response.id", + "responseId", // Gemini generateContent + "message.id", // anthropic messages + }); chatID != nil { + ctx.SetUserAttribute(ChatID, chatID.String()) } // Set information about this request if !config.disableOpenaiUsage { - if model, inputToken, outputToken, ok := getUsage(body); ok { - 
ctx.SetUserAttribute(Model, model) - ctx.SetUserAttribute(InputToken, inputToken) - ctx.SetUserAttribute(OutputToken, outputToken) + if usage := tokenusage.GetTokenUsage(ctx, body); usage.TotalToken > 0 { // Set span attributes for ARMS. - setSpanAttribute(ArmsModelName, model, log) - setSpanAttribute(ArmsInputToken, inputToken, log) - setSpanAttribute(ArmsOutputToken, outputToken, log) - setSpanAttribute(ArmsTotalToken, inputToken+outputToken, log) + setSpanAttribute(ArmsModelName, usage.Model) + setSpanAttribute(ArmsInputToken, usage.InputToken) + setSpanAttribute(ArmsOutputToken, usage.OutputToken) + setSpanAttribute(ArmsTotalToken, usage.TotalToken) } } // Set user defined log & span attributes. - setAttributeBySource(ctx, config, ResponseBody, body, log) + setAttributeBySource(ctx, config, ResponseBody, body) // Write log ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) // Write metrics - writeMetric(ctx, config, log) + writeMetric(ctx, config) return types.ActionContinue } -func unifySSEChunk(data []byte) []byte { - data = bytes.ReplaceAll(data, []byte("\r\n"), []byte("\n")) - data = bytes.ReplaceAll(data, []byte("\r"), []byte("\n")) - return data -} - -func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsage int64, ok bool) { - chunks := bytes.Split(bytes.TrimSpace(unifySSEChunk(data)), []byte("\n\n")) - for _, chunk := range chunks { - // the feature strings are used to identify the usage data, like: - // {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}} - if !bytes.Contains(chunk, []byte("prompt_tokens")) { - continue - } - if !bytes.Contains(chunk, []byte("completion_tokens")) { - continue - } - modelObj := gjson.GetBytes(chunk, "model") - if modelObj.Exists() { - model = modelObj.String() - } else { - model = "unknown" - } - inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens") - outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens") - if inputTokenObj.Exists() && outputTokenObj.Exists() { 
- inputTokenUsage = inputTokenObj.Int() - outputTokenUsage = outputTokenObj.Int() - ok = true - return - } - } - return -} - // fetches the tracing span value from the specified source. -func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log log.Log) { + +func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte) { for _, attribute := range config.attributes { var key string var value interface{} @@ -401,7 +387,7 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so case ResponseHeader: value, _ = proxywasm.GetHttpResponseHeader(attribute.Value) case ResponseStreamingBody: - value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log) + value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule) case ResponseBody: value = gjson.GetBytes(body, attribute.Value).Value() default: @@ -421,21 +407,21 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so } } // for metrics - if key == Model || key == InputToken || key == OutputToken { + if key == tokenusage.CtxKeyModel || key == tokenusage.CtxKeyInputToken || key == tokenusage.CtxKeyOutputToken || key == tokenusage.CtxKeyTotalToken { ctx.SetContext(key, value) } if attribute.ApplyToSpan { if attribute.TraceSpanKey != "" { key = attribute.TraceSpanKey } - setSpanAttribute(key, value, log) + setSpanAttribute(key, value) } } } } -func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log log.Log) interface{} { - chunks := bytes.Split(bytes.TrimSpace(unifySSEChunk(data)), []byte("\n\n")) +func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string) interface{} { + chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) var value interface{} if rule == RuleFirst { for _, chunk := range chunks { @@ -469,7 +455,7 @@ func extractStreamingBodyByJsonPath(data 
[]byte, jsonPath string, rule string, l } // Set the tracing span with value. -func setSpanAttribute(key string, value interface{}, log log.Log) { +func setSpanAttribute(key string, value interface{}) { if value != "" { traceSpanTag := wrapper.TraceSpanTagPrefix + key if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(fmt.Sprint(value))); e != nil { @@ -480,11 +466,10 @@ func setSpanAttribute(key string, value interface{}, log log.Log) { } } -func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log log.Log) { +func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig) { // Generate usage metrics var ok bool var route, cluster, model string - var inputToken, outputToken uint64 consumer := ctx.GetStringContext(ConsumerKey, "none") route, ok = ctx.GetContext(RouteName).(string) if !ok { @@ -501,31 +486,30 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log log.Log return } - if ctx.GetUserAttribute(Model) == nil || ctx.GetUserAttribute(InputToken) == nil || ctx.GetUserAttribute(OutputToken) == nil { + if ctx.GetUserAttribute(tokenusage.CtxKeyModel) == nil || ctx.GetUserAttribute(tokenusage.CtxKeyInputToken) == nil || ctx.GetUserAttribute(tokenusage.CtxKeyOutputToken) == nil || ctx.GetUserAttribute(tokenusage.CtxKeyTotalToken) == nil { log.Warnf("get usage information failed, skip metric record") return } - model, ok = ctx.GetUserAttribute(Model).(string) + model, ok = ctx.GetUserAttribute(tokenusage.CtxKeyModel).(string) if !ok { log.Warnf("Model typd assert failed, skip metric record") return } - inputToken, ok = convertToUInt(ctx.GetUserAttribute(InputToken)) - if !ok { + if inputToken, ok := convertToUInt(ctx.GetUserAttribute(tokenusage.CtxKeyInputToken)); ok { + config.incrementCounter(generateMetricName(route, cluster, model, consumer, tokenusage.CtxKeyInputToken), inputToken) + } else { log.Warnf("InputToken typd assert failed, skip metric record") - return } - outputToken, ok = 
convertToUInt(ctx.GetUserAttribute(OutputToken)) - if !ok { + if outputToken, ok := convertToUInt(ctx.GetUserAttribute(tokenusage.CtxKeyOutputToken)); ok { + config.incrementCounter(generateMetricName(route, cluster, model, consumer, tokenusage.CtxKeyOutputToken), outputToken) + } else { log.Warnf("OutputToken typd assert failed, skip metric record") - return } - if inputToken == 0 || outputToken == 0 { - log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record") - return + if totalToken, ok := convertToUInt(ctx.GetUserAttribute(tokenusage.CtxKeyTotalToken)); ok { + config.incrementCounter(generateMetricName(route, cluster, model, consumer, tokenusage.CtxKeyTotalToken), totalToken) + } else { + log.Warnf("TotalToken typd assert failed, skip metric record") } - config.incrementCounter(generateMetricName(route, cluster, model, consumer, InputToken), inputToken) - config.incrementCounter(generateMetricName(route, cluster, model, consumer, OutputToken), outputToken) // Generate duration metrics var llmFirstTokenDuration, llmServiceDuration uint64 diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/config.go b/plugins/wasm-go/extensions/ai-token-ratelimit/config.go index 0ed598caa..8f87d952c 100644 --- a/plugins/wasm-go/extensions/ai-token-ratelimit/config.go +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/config.go @@ -3,9 +3,8 @@ package main import ( "errors" "fmt" - "strings" - re "regexp" + "strings" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/wasm-go/pkg/log" @@ -86,7 +85,7 @@ type LimitConfigItem struct { timeWindow int64 // 时间窗口大小 } -func initRedisClusterClient(json gjson.Result, config *ClusterKeyRateLimitConfig, log log.Log) error { +func initRedisClusterClient(json gjson.Result, config *ClusterKeyRateLimitConfig) error { redisConfig := json.Get("redis") if !redisConfig.Exists() { return errors.New("missing redis in config") diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/go.mod 
b/plugins/wasm-go/extensions/ai-token-ratelimit/go.mod index e3beaacb6..280299a3e 100644 --- a/plugins/wasm-go/extensions/ai-token-ratelimit/go.mod +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/go.mod @@ -6,16 +6,14 @@ toolchain go1.24.4 require ( github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 - github.com/higress-group/wasm-go v1.0.0 github.com/tidwall/gjson v1.18.0 + github.com/tidwall/resp v0.1.1 github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837 ) -require github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 // indirect - require ( - github.com/google/uuid v1.6.0 // indirect + github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 // indirect + github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - github.com/tidwall/resp v0.1.1 ) diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum b/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum index 588c05eaa..f3e02a04e 100644 --- a/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum @@ -1,17 +1,12 @@ github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 h1:Wi5Tgn8K+jDcBYL+dIMS1+qXYH2r7tpRAyBgqrWfQtw= github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56/go.mod h1:8BhOLuqtSuT5NZtZMwfvEibi09RO3u79uqfHZzfDTR4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod 
h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= -github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw= -github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M= +github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa h1:GnYXjsG9/nRJ4+GQeJBKS8/a28N8yAt1pUmGZcxnHd4= +github.com/higress-group/wasm-go v1.0.1-0.20250714125049-cb970b4561fa/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= @@ -24,4 +19,3 @@ github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYg github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837 h1:DjHnADS2r2zynZ3WkCFAQ+PNYngMSNceRROi0pO6c3M= github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837/go.mod h1:9vp0bxqozzQwcjBwenEXfKVq8+mYbwHkQ1NF9Ap0DMw= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/main.go b/plugins/wasm-go/extensions/ai-token-ratelimit/main.go index 00c0ad082..413a98c1a 100644 --- a/plugins/wasm-go/extensions/ai-token-ratelimit/main.go +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/main.go @@ -15,7 +15,6 @@ package main import ( - "bytes" "fmt" "net" "net/url" @@ -25,6 +24,7 @@ import ( "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" 
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/tokenusage" "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" "github.com/tidwall/resp" @@ -35,9 +35,9 @@ func main() {} func init() { wrapper.SetCtx( "ai-token-ratelimit", - wrapper.ParseConfigBy(parseConfig), - wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), - wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingBody), + wrapper.ParseConfig(parseConfig), + wrapper.ProcessRequestHeaders(onHttpRequestHeaders), + wrapper.ProcessStreamingResponseBody(onHttpStreamingBody), ) } @@ -84,8 +84,8 @@ type LimitRedisContext struct { window int64 } -func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log log.Log) error { - err := initRedisClusterClient(json, config, log) +func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig) error { + err := initRedisClusterClient(json, config) if err != nil { return err } @@ -98,9 +98,9 @@ func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log log.L return nil } -func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, log log.Log) types.Action { +func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig) types.Action { // 判断是否命中限流规则 - val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems, log) + val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems) if ruleItem == nil || configItem == nil { return types.ActionContinue } @@ -146,18 +146,17 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon return types.ActionPause } -func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool, log log.Log) []byte { - var inputToken, outputToken int64 - if inputToken, outputToken, ok := getUsage(data); ok { - ctx.SetContext("input_token", 
inputToken) - ctx.SetContext("output_token", outputToken) +func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool) []byte { + if usage := tokenusage.GetTokenUsage(ctx, data); usage.TotalToken > 0 { + ctx.SetContext(tokenusage.CtxKeyInputToken, usage.InputToken) + ctx.SetContext(tokenusage.CtxKeyOutputToken, usage.OutputToken) } if endOfStream { - if ctx.GetContext("input_token") == nil || ctx.GetContext("output_token") == nil { + if ctx.GetContext(tokenusage.CtxKeyInputToken) == nil || ctx.GetContext(tokenusage.CtxKeyOutputToken) == nil { return data } - inputToken = ctx.GetContext("input_token").(int64) - outputToken = ctx.GetContext("output_token").(int64) + inputToken := ctx.GetContext(tokenusage.CtxKeyInputToken).(int64) + outputToken := ctx.GetContext(tokenusage.CtxKeyOutputToken).(int64) limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext) if !ok { return data @@ -172,29 +171,9 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConf return data } -func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bool) { - chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) - for _, chunk := range chunks { - // the feature strings are used to identify the usage data, like: - // {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}} - if !bytes.Contains(chunk, []byte("prompt_tokens")) || !bytes.Contains(chunk, []byte("completion_tokens")) { - continue - } - inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens") - outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens") - if inputTokenObj.Exists() && outputTokenObj.Exists() { - inputTokenUsage = inputTokenObj.Int() - outputTokenUsage = outputTokenObj.Int() - ok = true - return - } - } - return -} - -func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log log.Log) (string, *LimitRuleItem, *LimitConfigItem) { +func 
checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem) (string, *LimitRuleItem, *LimitConfigItem) { for _, rule := range ruleItems { - val, ruleItem, configItem := hitRateRuleItem(ctx, rule, log) + val, ruleItem, configItem := hitRateRuleItem(ctx, rule) if ruleItem != nil && configItem != nil { return val, ruleItem, configItem } @@ -202,46 +181,46 @@ func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRule return "", nil, nil } -func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log log.Log) (string, *LimitRuleItem, *LimitConfigItem) { +func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem) (string, *LimitRuleItem, *LimitConfigItem) { switch rule.limitType { // 根据HTTP请求头限流 case limitByHeaderType, limitByPerHeaderType: val, err := proxywasm.GetHttpRequestHeader(rule.key) if err != nil { - return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", rule.key, err) + return logDebugAndReturnEmpty("failed to get request header %s: %v", rule.key, err) } return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val) // 根据HTTP请求参数限流 case limitByParamType, limitByPerParamType: parse, err := url.Parse(ctx.Path()) if err != nil { - return logDebugAndReturnEmpty(log, "failed to parse request path: %v", err) + return logDebugAndReturnEmpty("failed to parse request path: %v", err) } query, err := url.ParseQuery(parse.RawQuery) if err != nil { - return logDebugAndReturnEmpty(log, "failed to parse query params: %v", err) + return logDebugAndReturnEmpty("failed to parse query params: %v", err) } val, ok := query[rule.key] if !ok { - return logDebugAndReturnEmpty(log, "request param %s is empty", rule.key) + return logDebugAndReturnEmpty("request param %s is empty", rule.key) } return val[0], &rule, findMatchingItem(rule.limitType, rule.configItems, val[0]) // 根据consumer限流 case limitByConsumerType, limitByPerConsumerType: val, err := proxywasm.GetHttpRequestHeader(ConsumerHeader) if 
err != nil { - return logDebugAndReturnEmpty(log, "failed to get request header %s: %v", ConsumerHeader, err) + return logDebugAndReturnEmpty("failed to get request header %s: %v", ConsumerHeader, err) } return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val) // 根据cookie中key值限流 case limitByCookieType, limitByPerCookieType: cookie, err := proxywasm.GetHttpRequestHeader(CookieHeader) if err != nil { - return logDebugAndReturnEmpty(log, "failed to get request cookie : %v", err) + return logDebugAndReturnEmpty("failed to get request cookie : %v", err) } val := extractCookieValueByKey(cookie, rule.key) if val == "" { - return logDebugAndReturnEmpty(log, "cookie key '%s' extracted from cookie '%s' is empty.", rule.key, cookie) + return logDebugAndReturnEmpty("cookie key '%s' extracted from cookie '%s' is empty.", rule.key, cookie) } return val, &rule, findMatchingItem(rule.limitType, rule.configItems, val) // 根据客户端IP限流 @@ -261,7 +240,7 @@ func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log log.Log) ( return "", nil, nil } -func logDebugAndReturnEmpty(log log.Log, errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) { +func logDebugAndReturnEmpty(errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) { log.Debugf(errMsg, args...) return "", nil, nil }