mirror of
https://github.com/alibaba/higress.git
synced 2026-02-26 13:40:49 +08:00
fixed ai-statistics plugin statistics error (#1060)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -52,79 +53,66 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, l
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func getLastChunk(data []byte) []byte {
|
||||
chunks := strings.Split(strings.TrimSpace(string(data)), "\n\n")
|
||||
length := len(chunks)
|
||||
if length < 2 {
|
||||
func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, data []byte, endOfStream bool, log wrapper.Log) []byte {
|
||||
model, inputToken, outputToken, ok := getUsage(data)
|
||||
if !ok {
|
||||
return data
|
||||
}
|
||||
// ai-proxy append extra usage chunk
|
||||
return []byte(chunks[length-1])
|
||||
}
|
||||
|
||||
func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, data []byte, endOfStream bool, log wrapper.Log) []byte {
|
||||
lastChunk := getLastChunk(data)
|
||||
modelObj := gjson.GetBytes(lastChunk, "model")
|
||||
inputTokenObj := gjson.GetBytes(lastChunk, "usage.prompt_tokens")
|
||||
outputTokenObj := gjson.GetBytes(lastChunk, "usage.completion_tokens")
|
||||
if modelObj.Exists() && inputTokenObj.Exists() && outputTokenObj.Exists() {
|
||||
ctx.SetContext("model", modelObj.String())
|
||||
ctx.SetContext("input_token", inputTokenObj.Int())
|
||||
ctx.SetContext("output_token", outputTokenObj.Int())
|
||||
}
|
||||
|
||||
if endOfStream {
|
||||
var route, cluster string
|
||||
if raw, err := proxywasm.GetProperty([]string{"route_name"}); err == nil {
|
||||
route = string(raw)
|
||||
}
|
||||
if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err == nil {
|
||||
cluster = string(raw)
|
||||
}
|
||||
model, ok := ctx.GetContext("model").(string)
|
||||
if !ok {
|
||||
log.Error("Get model failed!")
|
||||
return data
|
||||
}
|
||||
inputToken, ok := ctx.GetContext("input_token").(int64)
|
||||
if !ok {
|
||||
log.Error("Get input_token failed!")
|
||||
return data
|
||||
}
|
||||
outputToken, ok := ctx.GetContext("output_token").(int64)
|
||||
if !ok {
|
||||
log.Error("Get output_token failed!")
|
||||
return data
|
||||
}
|
||||
config.incrementCounter("route."+route+".upstream."+cluster+".model."+model+".input_token", uint64(inputToken), log)
|
||||
config.incrementCounter("route."+route+".upstream."+cluster+".model."+model+".output_token", uint64(outputToken), log)
|
||||
proxywasm.SetProperty([]string{"model"}, []byte(model))
|
||||
proxywasm.SetProperty([]string{"input_token"}, []byte(fmt.Sprint(inputToken)))
|
||||
proxywasm.SetProperty([]string{"output_token"}, []byte(fmt.Sprint(outputToken)))
|
||||
}
|
||||
|
||||
setFilterStateData(model, inputToken, outputToken, log)
|
||||
incrementCounter(config, model, inputToken, outputToken, log)
|
||||
return data
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action {
|
||||
modeObj := gjson.GetBytes(body, "model")
|
||||
inputTokenObj := gjson.GetBytes(body, "usage.prompt_tokens")
|
||||
outputTokenObj := gjson.GetBytes(body, "usage.completion_tokens")
|
||||
if !modeObj.Exists() {
|
||||
log.Error("Get model failed")
|
||||
model, inputToken, outputToken, ok := getUsage(body)
|
||||
if !ok {
|
||||
return types.ActionContinue
|
||||
}
|
||||
if !inputTokenObj.Exists() {
|
||||
log.Error("Get input_token failed")
|
||||
return types.ActionContinue
|
||||
setFilterStateData(model, inputToken, outputToken, log)
|
||||
incrementCounter(config, model, inputToken, outputToken, log)
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsage int64, ok bool) {
|
||||
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
|
||||
for _, chunk := range chunks {
|
||||
// the feature strings are used to identify the usage data, like:
|
||||
// {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}}
|
||||
if !bytes.Contains(chunk, []byte("prompt_tokens")) {
|
||||
continue
|
||||
}
|
||||
if !bytes.Contains(chunk, []byte("completion_tokens")) {
|
||||
continue
|
||||
}
|
||||
modelObj := gjson.GetBytes(chunk, "model")
|
||||
inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens")
|
||||
outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens")
|
||||
if modelObj.Exists() && inputTokenObj.Exists() && outputTokenObj.Exists() {
|
||||
model = modelObj.String()
|
||||
inputTokenUsage = inputTokenObj.Int()
|
||||
outputTokenUsage = outputTokenObj.Int()
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
}
|
||||
if !outputTokenObj.Exists() {
|
||||
log.Error("Get output_token failed")
|
||||
return types.ActionContinue
|
||||
return
|
||||
}
|
||||
|
||||
// setFilterData sets the input_token and output_token in the filter state.
|
||||
// ai-token-ratelimit will use these values to calculate the total token usage.
|
||||
func setFilterStateData(model string, inputToken int64, outputToken int64, log wrapper.Log) {
|
||||
if e := proxywasm.SetProperty([]string{"model"}, []byte(model)); e != nil {
|
||||
log.Errorf("failed to set model in filter state: %v", e)
|
||||
}
|
||||
model := modeObj.String()
|
||||
inputToken := inputTokenObj.Int()
|
||||
outputToken := outputTokenObj.Int()
|
||||
if e := proxywasm.SetProperty([]string{"input_token"}, []byte(fmt.Sprintf("%d", inputToken))); e != nil {
|
||||
log.Errorf("failed to set input_token in filter state: %v", e)
|
||||
}
|
||||
if e := proxywasm.SetProperty([]string{"output_token"}, []byte(fmt.Sprintf("%d", outputToken))); e != nil {
|
||||
log.Errorf("failed to set output_token in filter state: %v", e)
|
||||
}
|
||||
}
|
||||
|
||||
func incrementCounter(config AIStatisticsConfig, model string, inputToken int64, outputToken int64, log wrapper.Log) {
|
||||
var route, cluster string
|
||||
if raw, err := proxywasm.GetProperty([]string{"route_name"}); err == nil {
|
||||
route = string(raw)
|
||||
@@ -134,10 +122,4 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body
|
||||
}
|
||||
config.incrementCounter("route."+route+".upstream."+cluster+".model."+model+".input_token", uint64(inputToken), log)
|
||||
config.incrementCounter("route."+route+".upstream."+cluster+".model."+model+".output_token", uint64(outputToken), log)
|
||||
|
||||
proxywasm.SetProperty([]string{"model"}, []byte(model))
|
||||
proxywasm.SetProperty([]string{"input_token"}, []byte(fmt.Sprint(inputToken)))
|
||||
proxywasm.SetProperty([]string{"output_token"}, []byte(fmt.Sprint(outputToken)))
|
||||
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user