Files
higress/plugins/wasm-go/extensions/ai-statistics/main.go

126 lines
4.3 KiB
Go

package main
import (
"bytes"
"fmt"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
)
// main registers the plugin with the wasm-go wrapper, wiring up the
// configuration parser and the response header/body callbacks defined in
// this file.
func main() {
	const pluginName = "ai-statistics"

	wrapper.SetCtx(
		pluginName,
		wrapper.ParseConfigBy(parseConfig),
		wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
		wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingBody),
		wrapper.ProcessResponseBodyBy(onHttpResponseBody),
	)
}
// AIStatisticsConfig holds the plugin configuration together with the
// counters the plugin has defined so far.
type AIStatisticsConfig struct {
// enable is read from the "enable" key of the plugin configuration; when it
// is false the response-header callback asks not to read the response body.
enable bool
// metrics caches counters by metric name so that each name is defined only
// once; entries are added on first use by the incrementCounter method.
metrics map[string]proxywasm.MetricCounter
}
// incrementCounter adds inc to the counter registered under metricName.
// The counter is looked up in the config's cache first; when it is not there
// yet it is defined, stored in the cache, and then incremented.
// The log parameter is currently unused.
func (c *AIStatisticsConfig) incrementCounter(metricName string, inc uint64, log wrapper.Log) {
	if cached, found := c.metrics[metricName]; found {
		cached.Increment(inc)
		return
	}

	created := proxywasm.DefineCounterMetric(metricName)
	c.metrics[metricName] = created
	created.Increment(inc)
}
// parseConfig populates config from the plugin's JSON configuration: it
// starts with an empty counter cache and reads the "enable" key as a boolean.
// It always returns nil. The log parameter is currently unused.
func parseConfig(json gjson.Result, config *AIStatisticsConfig, log wrapper.Log) error {
	config.metrics = map[string]proxywasm.MetricCounter{}
	config.enable = json.Get("enable").Bool()
	return nil
}
// onHttpResponseHeaders decides how the response body will be handled.
//
// When the plugin is disabled it asks the wrapper not to read the response
// body at all. Otherwise, any response that is not a server-sent event stream
// is marked for buffering (presumably so the complete body reaches
// onHttpResponseBody in one piece — confirm against the wrapper), while
// event-stream responses are left to the streaming callback.
//
// It always returns types.ActionContinue.
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) types.Action {
	if !config.enable {
		ctx.DontReadResponseBody()
		return types.ActionContinue
	}

	// The lookup error is ignored on purpose: a missing header leaves
	// contentType empty, which is treated as "not an event stream".
	contentType, _ := proxywasm.GetHttpResponseHeader("content-type")

	// Media type names are case-insensitive, so normalise before matching;
	// otherwise a stream announced as e.g. "Text/Event-Stream" would be buffered.
	if !strings.Contains(strings.ToLower(contentType), "text/event-stream") {
		ctx.BufferResponseBody()
	}
	return types.ActionContinue
}
// onHttpStreamingBody inspects one piece of a streamed response body for
// token usage. When usage is found, it is published to the filter state and
// added to the counters. The data is always returned unmodified.
func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, data []byte, endOfStream bool, log wrapper.Log) []byte {
	if model, promptTokens, completionTokens, found := getUsage(data); found {
		setFilterStateData(model, promptTokens, completionTokens, log)
		incrementCounter(config, model, promptTokens, completionTokens, log)
	}
	return data
}
// onHttpResponseBody inspects a response body for token usage. When usage is
// found, it is published to the filter state and added to the counters.
// It always returns types.ActionContinue.
func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action {
	if model, promptTokens, completionTokens, found := getUsage(body); found {
		setFilterStateData(model, promptTokens, completionTokens, log)
		incrementCounter(config, model, promptTokens, completionTokens, log)
	}
	return types.ActionContinue
}
// getUsage scans data for the first chunk that carries token usage and
// returns the model name together with the prompt and completion token counts.
//
// data is split into chunks on blank lines; CRLF line endings are normalised
// to LF first so that events separated by "\r\n\r\n" are split as well. A
// chunk is accepted when it contains the fields "model", "usage.prompt_tokens"
// and "usage.completion_tokens", for example:
//
//	{"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}}
//
// ok is false, and the other results are zero values, when no chunk qualifies.
func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsage int64, ok bool) {
	payload := bytes.TrimSpace(data)
	// Only pay for the copy when a carriage return is actually present.
	if bytes.IndexByte(payload, '\r') >= 0 {
		payload = bytes.ReplaceAll(payload, []byte("\r\n"), []byte("\n"))
	}

	for _, chunk := range bytes.Split(payload, []byte("\n\n")) {
		// Cheap substring checks first, so chunks without usage data are
		// skipped before any JSON lookup.
		if !bytes.Contains(chunk, []byte("prompt_tokens")) ||
			!bytes.Contains(chunk, []byte("completion_tokens")) {
			continue
		}

		modelObj := gjson.GetBytes(chunk, "model")
		inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens")
		outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens")
		if !modelObj.Exists() || !inputTokenObj.Exists() || !outputTokenObj.Exists() {
			continue
		}
		return modelObj.String(), inputTokenObj.Int(), outputTokenObj.Int(), true
	}
	return "", 0, 0, false
}
// setFilterStateData publishes the model name and the input/output token
// counts as the properties "model", "input_token" and "output_token"; the
// counts are written as decimal strings.
// ai-token-ratelimit will use these values to calculate the total token usage.
// A failure to set a property is logged and does not stop the remaining
// properties from being set.
func setFilterStateData(model string, inputToken int64, outputToken int64, log wrapper.Log) {
	inputValue := []byte(fmt.Sprintf("%d", inputToken))
	outputValue := []byte(fmt.Sprintf("%d", outputToken))

	err := proxywasm.SetProperty([]string{"model"}, []byte(model))
	if err != nil {
		log.Errorf("failed to set model in filter state: %v", err)
	}

	err = proxywasm.SetProperty([]string{"input_token"}, inputValue)
	if err != nil {
		log.Errorf("failed to set input_token in filter state: %v", err)
	}

	err = proxywasm.SetProperty([]string{"output_token"}, outputValue)
	if err != nil {
		log.Errorf("failed to set output_token in filter state: %v", err)
	}
}
// incrementCounter adds the given token counts to the per-route, per-upstream,
// per-model counters named
//
//	route.<route>.upstream.<cluster>.model.<model>.input_token
//	route.<route>.upstream.<cluster>.model.<model>.output_token
//
// The route and cluster names are read from the "route_name" and
// "cluster_name" properties; if a lookup fails the corresponding name segment
// is left empty.
//
// Negative token counts are rejected and logged: converting them to uint64
// would wrap around and add an enormous value to the counter.
func incrementCounter(config AIStatisticsConfig, model string, inputToken int64, outputToken int64, log wrapper.Log) {
	if inputToken < 0 || outputToken < 0 {
		log.Errorf("skip metrics for model %s: negative token usage, input_token=%d output_token=%d", model, inputToken, outputToken)
		return
	}

	var route, cluster string
	if raw, err := proxywasm.GetProperty([]string{"route_name"}); err == nil {
		route = string(raw)
	}
	if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err == nil {
		cluster = string(raw)
	}

	prefix := "route." + route + ".upstream." + cluster + ".model." + model
	config.incrementCounter(prefix+".input_token", uint64(inputToken), log)
	config.incrementCounter(prefix+".output_token", uint64(outputToken), log)
}