diff --git a/plugins/wasm-go/extensions/ai-statistics/main.go b/plugins/wasm-go/extensions/ai-statistics/main.go index 380e0d5bc..fa5f80879 100644 --- a/plugins/wasm-go/extensions/ai-statistics/main.go +++ b/plugins/wasm-go/extensions/ai-statistics/main.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "fmt" "strings" @@ -52,79 +53,66 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, l return types.ActionContinue } -func getLastChunk(data []byte) []byte { - chunks := strings.Split(strings.TrimSpace(string(data)), "\n\n") - length := len(chunks) - if length < 2 { +func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, data []byte, endOfStream bool, log wrapper.Log) []byte { + model, inputToken, outputToken, ok := getUsage(data) + if !ok { return data } - // ai-proxy append extra usage chunk - return []byte(chunks[length-1]) -} - -func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, data []byte, endOfStream bool, log wrapper.Log) []byte { - lastChunk := getLastChunk(data) - modelObj := gjson.GetBytes(lastChunk, "model") - inputTokenObj := gjson.GetBytes(lastChunk, "usage.prompt_tokens") - outputTokenObj := gjson.GetBytes(lastChunk, "usage.completion_tokens") - if modelObj.Exists() && inputTokenObj.Exists() && outputTokenObj.Exists() { - ctx.SetContext("model", modelObj.String()) - ctx.SetContext("input_token", inputTokenObj.Int()) - ctx.SetContext("output_token", outputTokenObj.Int()) - } - - if endOfStream { - var route, cluster string - if raw, err := proxywasm.GetProperty([]string{"route_name"}); err == nil { - route = string(raw) - } - if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err == nil { - cluster = string(raw) - } - model, ok := ctx.GetContext("model").(string) - if !ok { - log.Error("Get model failed!") - return data - } - inputToken, ok := ctx.GetContext("input_token").(int64) - if !ok { - log.Error("Get input_token failed!") - return data - } - outputToken, ok := ctx.GetContext("output_token").(int64) - if !ok { - log.Error("Get output_token failed!") - return data - } - config.incrementCounter("route."+route+".upstream."+cluster+".model."+model+".input_token", uint64(inputToken), log) - config.incrementCounter("route."+route+".upstream."+cluster+".model."+model+".output_token", uint64(outputToken), log) - proxywasm.SetProperty([]string{"model"}, []byte(model)) - proxywasm.SetProperty([]string{"input_token"}, []byte(fmt.Sprint(inputToken))) - proxywasm.SetProperty([]string{"output_token"}, []byte(fmt.Sprint(outputToken))) - } - + setFilterStateData(model, inputToken, outputToken, log) + incrementCounter(config, model, inputToken, outputToken, log) return data } func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action { - modeObj := gjson.GetBytes(body, "model") - inputTokenObj := gjson.GetBytes(body, "usage.prompt_tokens") - outputTokenObj := gjson.GetBytes(body, "usage.completion_tokens") - if !modeObj.Exists() { - log.Error("Get model failed") + model, inputToken, outputToken, ok := getUsage(body) + if !ok { return types.ActionContinue } - if !inputTokenObj.Exists() { - log.Error("Get input_token failed") - return types.ActionContinue + setFilterStateData(model, inputToken, outputToken, log) + incrementCounter(config, model, inputToken, outputToken, log) + return types.ActionContinue +} + +func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsage int64, ok bool) { + chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) + for _, chunk := range chunks { + // the feature strings are used to identify the usage data, like: + // {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}} + if !bytes.Contains(chunk, []byte("prompt_tokens")) { + continue + } + if !bytes.Contains(chunk, []byte("completion_tokens")) { + continue + } + modelObj := gjson.GetBytes(chunk, "model") + inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens") + outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens") + if modelObj.Exists() && inputTokenObj.Exists() && outputTokenObj.Exists() { + model = modelObj.String() + inputTokenUsage = inputTokenObj.Int() + outputTokenUsage = outputTokenObj.Int() + ok = true + return + } } - if !outputTokenObj.Exists() { - log.Error("Get output_token failed") - return types.ActionContinue + return +} + +// setFilterData sets the input_token and output_token in the filter state. +// ai-token-ratelimit will use these values to calculate the total token usage. +func setFilterStateData(model string, inputToken int64, outputToken int64, log wrapper.Log) { + if e := proxywasm.SetProperty([]string{"model"}, []byte(model)); e != nil { + log.Errorf("failed to set model in filter state: %v", e) } - model := modeObj.String() - inputToken := inputTokenObj.Int() - outputToken := outputTokenObj.Int() + if e := proxywasm.SetProperty([]string{"input_token"}, []byte(fmt.Sprintf("%d", inputToken))); e != nil { + log.Errorf("failed to set input_token in filter state: %v", e) + } + if e := proxywasm.SetProperty([]string{"output_token"}, []byte(fmt.Sprintf("%d", outputToken))); e != nil { + log.Errorf("failed to set output_token in filter state: %v", e) + } +} + +func incrementCounter(config AIStatisticsConfig, model string, inputToken int64, outputToken int64, log wrapper.Log) { var route, cluster string if raw, err := proxywasm.GetProperty([]string{"route_name"}); err == nil { route = string(raw) @@ -134,10 +122,4 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body } config.incrementCounter("route."+route+".upstream."+cluster+".model."+model+".input_token", uint64(inputToken), log) config.incrementCounter("route."+route+".upstream."+cluster+".model."+model+".output_token", uint64(outputToken), log) - - proxywasm.SetProperty([]string{"model"}, []byte(model)) - proxywasm.SetProperty([]string{"input_token"}, []byte(fmt.Sprint(inputToken))) - proxywasm.SetProperty([]string{"output_token"}, []byte(fmt.Sprint(outputToken))) - - return types.ActionContinue }