mirror of
https://github.com/alibaba/higress.git
synced 2026-02-06 23:21:08 +08:00
541 lines
17 KiB
Go
541 lines
17 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
// main registers the "ai-statistics" plugin with the wasm-go wrapper. It wires
// up the configuration parser and one callback per HTTP phase handled in this
// file: request headers, request body, response headers, streaming response
// body and buffered response body.
func main() {
	wrapper.SetCtx(
		"ai-statistics",
		wrapper.ParseConfigBy(parseConfig),
		wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
		wrapper.ProcessRequestBodyBy(onHttpRequestBody),
		wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
		wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingBody),
		wrapper.ProcessResponseBodyBy(onHttpResponseBody),
	)
}
|
|
|
|
// defaultMaxBodyBytes is the request body buffer limit (100 MiB) applied in
// onHttpRequestHeaders when the request carries a body.
const defaultMaxBodyBytes uint32 = 100 * 1024 * 1024

// Keys used to store per-request state. RouteName, ClusterName and the two
// Statistics* keys are written with ctx.SetContext; APIName is written as a
// user attribute; ConsumerKey is both the request header that is read and the
// context key its value is stored under.
const (
	StatisticsRequestStartTime = "ai-statistics-request-start-time"
	StatisticsFirstTokenTime   = "ai-statistics-first-token-time"
	CtxGeneralAtrribute        = "attributes"
	CtxLogAtrribute            = "logAttributes"
	CtxStreamingBodyBuffer     = "streamingBodyBuffer"
	RouteName                  = "route"
	ClusterName                = "cluster"
	APIName                    = "api"
	ConsumerKey                = "x-mse-consumer"
)

// Accepted values of Attribute.ValueSource, i.e. where an attribute value is
// read from.
const (
	FixedValue            = "fixed_value"
	RequestHeader         = "request_header"
	RequestBody           = "request_body"
	ResponseHeader        = "response_header"
	ResponseStreamingBody = "response_streaming_body"
	ResponseBody          = "response_body"
)

// Built-in attribute and metric names recorded by the plugin itself.
const (
	Model                  = "model"
	InputToken             = "input_token"
	OutputToken            = "output_token"
	LLMFirstTokenDuration  = "llm_first_token_duration"
	LLMServiceDuration     = "llm_service_duration"
	LLMDurationCount       = "llm_duration_count"
	LLMStreamDurationCount = "llm_stream_duration_count"
	ResponseType           = "response_type"
	ChatID                 = "chat_id"
	ChatRound              = "chat_round"
)

// Built-in span attribute names set for ARMS.
const (
	ArmsSpanKind     = "gen_ai.span.kind"
	ArmsModelName    = "gen_ai.model_name"
	ArmsRequestModel = "gen_ai.request.model"
	ArmsInputToken   = "gen_ai.usage.input_tokens"
	ArmsOutputToken  = "gen_ai.usage.output_tokens"
	ArmsTotalToken   = "gen_ai.usage.total_tokens"
)

// Accepted values of Attribute.Rule, used when extracting a value from a
// streaming response body.
const (
	RuleFirst   = "first"
	RuleReplace = "replace"
	RuleAppend  = "append"
)
|
|
|
|
// Attribute is one user-configured attribute rule: the key to record, where
// its value comes from, and whether it is written to the log and/or the span.
type Attribute struct {
	// Key is the name the value is recorded under.
	Key string `json:"key"`
	// ValueSource selects where the value is read from. setAttributeBySource
	// handles fixed_value, request_header, request_body, response_header,
	// response_streaming_body and response_body.
	ValueSource string `json:"value_source"`
	// Value is interpreted according to ValueSource: the literal value for
	// fixed_value, a header name for the header sources, and a gjson path for
	// the body sources.
	Value string `json:"value"`
	// DefaultValue, when non-empty, replaces an extracted value that is nil
	// or the empty string.
	DefaultValue string `json:"default_value,omitempty"`
	// Rule is passed to extractStreamingBodyByJsonPath for the
	// response_streaming_body source. parseConfig accepts "", "first",
	// "replace" and "append".
	Rule string `json:"rule,omitempty"`
	// ApplyToLog records the value as a user attribute on the HTTP context.
	ApplyToLog bool `json:"apply_to_log,omitempty"`
	// ApplyToSpan records the value through setSpanAttribute.
	ApplyToSpan bool `json:"apply_to_span,omitempty"`
}
|
|
|
|
// AIStatisticsConfig is the plugin configuration populated by parseConfig.
type AIStatisticsConfig struct {
	// Metrics
	// TODO: add more metrics in Gauge and Histogram format
	// counterMetrics caches counters by full metric name; incrementCounter
	// defines a counter the first time its name is seen.
	counterMetrics map[string]proxywasm.MetricCounter
	// Attributes to be recorded in log & span
	attributes []Attribute
	// If there exist attributes extracted from streaming body, chunks should be buffered
	shouldBufferStreamingBody bool
}
|
|
|
|
// generateMetricName builds the full counter name for one metric by embedding
// the route, upstream cluster, model and consumer labels, in that order,
// followed by the metric name itself.
func generateMetricName(route, cluster, model, consumer, metricName string) string {
	const layout = "route.%s.upstream.%s.model.%s.consumer.%s.metric.%s"
	labels := []interface{}{route, cluster, model, consumer, metricName}
	return fmt.Sprintf(layout, labels...)
}
|
|
|
|
func getRouteName() (string, error) {
|
|
if raw, err := proxywasm.GetProperty([]string{"route_name"}); err != nil {
|
|
return "-", err
|
|
} else {
|
|
return string(raw), nil
|
|
}
|
|
}
|
|
|
|
func getAPIName() (string, error) {
|
|
if raw, err := proxywasm.GetProperty([]string{"route_name"}); err != nil {
|
|
return "-", err
|
|
} else {
|
|
parts := strings.Split(string(raw), "@")
|
|
if len(parts) != 5 {
|
|
return "-", errors.New("not api type")
|
|
} else {
|
|
return strings.Join(parts[:3], "@"), nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func getClusterName() (string, error) {
|
|
if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err != nil {
|
|
return "-", err
|
|
} else {
|
|
return string(raw), nil
|
|
}
|
|
}
|
|
|
|
func (config *AIStatisticsConfig) incrementCounter(metricName string, inc uint64) {
|
|
if inc == 0 {
|
|
return
|
|
}
|
|
counter, ok := config.counterMetrics[metricName]
|
|
if !ok {
|
|
counter = proxywasm.DefineCounterMetric(metricName)
|
|
config.counterMetrics[metricName] = counter
|
|
}
|
|
counter.Increment(inc)
|
|
}
|
|
|
|
func parseConfig(configJson gjson.Result, config *AIStatisticsConfig, log wrapper.Log) error {
|
|
// Parse tracing span attributes setting.
|
|
attributeConfigs := configJson.Get("attributes").Array()
|
|
config.attributes = make([]Attribute, len(attributeConfigs))
|
|
for i, attributeConfig := range attributeConfigs {
|
|
attribute := Attribute{}
|
|
err := json.Unmarshal([]byte(attributeConfig.Raw), &attribute)
|
|
if err != nil {
|
|
log.Errorf("parse config failed, %v", err)
|
|
return err
|
|
}
|
|
if attribute.ValueSource == ResponseStreamingBody {
|
|
config.shouldBufferStreamingBody = true
|
|
}
|
|
if attribute.Rule != "" && attribute.Rule != RuleFirst && attribute.Rule != RuleReplace && attribute.Rule != RuleAppend {
|
|
return errors.New("value of rule must be one of [nil, first, replace, append]")
|
|
}
|
|
config.attributes[i] = attribute
|
|
}
|
|
// Metric settings
|
|
config.counterMetrics = make(map[string]proxywasm.MetricCounter)
|
|
return nil
|
|
}
|
|
|
|
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) types.Action {
|
|
route, _ := getRouteName()
|
|
cluster, _ := getClusterName()
|
|
api, api_error := getAPIName()
|
|
if api_error == nil {
|
|
route = api
|
|
}
|
|
ctx.SetContext(RouteName, route)
|
|
ctx.SetContext(ClusterName, cluster)
|
|
ctx.SetUserAttribute(APIName, api)
|
|
ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli())
|
|
if consumer, _ := proxywasm.GetHttpRequestHeader(ConsumerKey); consumer != "" {
|
|
ctx.SetContext(ConsumerKey, consumer)
|
|
}
|
|
hasRequestBody := wrapper.HasRequestBody()
|
|
if hasRequestBody {
|
|
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
|
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
|
|
}
|
|
|
|
// Set user defined log & span attributes which type is fixed_value
|
|
setAttributeBySource(ctx, config, FixedValue, nil, log)
|
|
// Set user defined log & span attributes which type is request_header
|
|
setAttributeBySource(ctx, config, RequestHeader, nil, log)
|
|
// Set span attributes for ARMS.
|
|
setSpanAttribute(ArmsSpanKind, "LLM", log)
|
|
|
|
return types.ActionContinue
|
|
}
|
|
|
|
func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action {
|
|
// Set user defined log & span attributes.
|
|
setAttributeBySource(ctx, config, RequestBody, body, log)
|
|
// Set span attributes for ARMS.
|
|
requestModel := gjson.GetBytes(body, "model").String()
|
|
if requestModel == "" {
|
|
requestModel = "UNKNOWN"
|
|
}
|
|
setSpanAttribute(ArmsRequestModel, requestModel, log)
|
|
// Set the number of conversation rounds
|
|
if gjson.GetBytes(body, "messages").Exists() {
|
|
userPromptCount := 0
|
|
for _, msg := range gjson.GetBytes(body, "messages").Array() {
|
|
if msg.Get("role").String() == "user" {
|
|
userPromptCount += 1
|
|
}
|
|
}
|
|
ctx.SetUserAttribute(ChatRound, userPromptCount)
|
|
}
|
|
|
|
// Write log
|
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
|
return types.ActionContinue
|
|
}
|
|
|
|
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) types.Action {
|
|
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
|
if !strings.Contains(contentType, "text/event-stream") {
|
|
ctx.BufferResponseBody()
|
|
}
|
|
|
|
// Set user defined log & span attributes.
|
|
setAttributeBySource(ctx, config, ResponseHeader, nil, log)
|
|
|
|
return types.ActionContinue
|
|
}
|
|
|
|
func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, data []byte, endOfStream bool, log wrapper.Log) []byte {
|
|
// Buffer stream body for record log & span attributes
|
|
if config.shouldBufferStreamingBody {
|
|
var streamingBodyBuffer []byte
|
|
streamingBodyBuffer, ok := ctx.GetContext(CtxStreamingBodyBuffer).([]byte)
|
|
if !ok {
|
|
streamingBodyBuffer = data
|
|
} else {
|
|
streamingBodyBuffer = append(streamingBodyBuffer, data...)
|
|
}
|
|
ctx.SetContext(CtxStreamingBodyBuffer, streamingBodyBuffer)
|
|
}
|
|
|
|
ctx.SetUserAttribute(ResponseType, "stream")
|
|
chatID := gjson.GetBytes(data, "id").String()
|
|
if chatID != "" {
|
|
ctx.SetUserAttribute(ChatID, chatID)
|
|
}
|
|
|
|
// Get requestStartTime from http context
|
|
requestStartTime, ok := ctx.GetContext(StatisticsRequestStartTime).(int64)
|
|
if !ok {
|
|
log.Error("failed to get requestStartTime from http context")
|
|
return data
|
|
}
|
|
|
|
// If this is the first chunk, record first token duration metric and span attribute
|
|
if ctx.GetContext(StatisticsFirstTokenTime) == nil {
|
|
firstTokenTime := time.Now().UnixMilli()
|
|
ctx.SetContext(StatisticsFirstTokenTime, firstTokenTime)
|
|
ctx.SetUserAttribute(LLMFirstTokenDuration, firstTokenTime-requestStartTime)
|
|
}
|
|
|
|
// Set information about this request
|
|
if model, inputToken, outputToken, ok := getUsage(data); ok {
|
|
ctx.SetUserAttribute(Model, model)
|
|
ctx.SetUserAttribute(InputToken, inputToken)
|
|
ctx.SetUserAttribute(OutputToken, outputToken)
|
|
// Set span attributes for ARMS.
|
|
setSpanAttribute(ArmsModelName, model, log)
|
|
setSpanAttribute(ArmsInputToken, inputToken, log)
|
|
setSpanAttribute(ArmsOutputToken, outputToken, log)
|
|
setSpanAttribute(ArmsTotalToken, inputToken+outputToken, log)
|
|
}
|
|
// If the end of the stream is reached, record metrics/logs/spans.
|
|
if endOfStream {
|
|
responseEndTime := time.Now().UnixMilli()
|
|
ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime)
|
|
|
|
// Set user defined log & span attributes.
|
|
if config.shouldBufferStreamingBody {
|
|
streamingBodyBuffer, ok := ctx.GetContext(CtxStreamingBodyBuffer).([]byte)
|
|
if !ok {
|
|
return data
|
|
}
|
|
setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer, log)
|
|
}
|
|
|
|
// Write log
|
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
|
|
|
// Write metrics
|
|
writeMetric(ctx, config, log)
|
|
}
|
|
return data
|
|
}
|
|
|
|
func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action {
|
|
// Get requestStartTime from http context
|
|
requestStartTime, _ := ctx.GetContext(StatisticsRequestStartTime).(int64)
|
|
|
|
responseEndTime := time.Now().UnixMilli()
|
|
ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime)
|
|
|
|
ctx.SetUserAttribute(ResponseType, "normal")
|
|
chatID := gjson.GetBytes(body, "id").String()
|
|
if chatID != "" {
|
|
ctx.SetUserAttribute(ChatID, chatID)
|
|
}
|
|
|
|
// Set information about this request
|
|
if model, inputToken, outputToken, ok := getUsage(body); ok {
|
|
ctx.SetUserAttribute(Model, model)
|
|
ctx.SetUserAttribute(InputToken, inputToken)
|
|
ctx.SetUserAttribute(OutputToken, outputToken)
|
|
// Set span attributes for ARMS.
|
|
setSpanAttribute(ArmsModelName, model, log)
|
|
setSpanAttribute(ArmsInputToken, inputToken, log)
|
|
setSpanAttribute(ArmsOutputToken, outputToken, log)
|
|
setSpanAttribute(ArmsTotalToken, inputToken+outputToken, log)
|
|
}
|
|
|
|
// Set user defined log & span attributes.
|
|
setAttributeBySource(ctx, config, ResponseBody, body, log)
|
|
|
|
// Write log
|
|
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
|
|
|
// Write metrics
|
|
writeMetric(ctx, config, log)
|
|
|
|
return types.ActionContinue
|
|
}
|
|
|
|
// unifySSEChunk returns a copy of data in which every line ending is a single
// "\n": each "\r\n" pair and each remaining lone "\r" is replaced by "\n".
// The input slice is never modified.
func unifySSEChunk(data []byte) []byte {
	normalized := append([]byte(nil), data...)
	if bytes.IndexByte(normalized, '\r') < 0 {
		return normalized
	}
	// Compact in place: the write cursor never overtakes the read cursor.
	write := 0
	for read := 0; read < len(normalized); read++ {
		current := normalized[read]
		if current == '\r' {
			if read+1 < len(normalized) && normalized[read+1] == '\n' {
				read++ // swallow the "\n" of a "\r\n" pair
			}
			current = '\n'
		}
		normalized[write] = current
		write++
	}
	return normalized[:write]
}
|
|
|
|
func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsage int64, ok bool) {
|
|
chunks := bytes.Split(bytes.TrimSpace(unifySSEChunk(data)), []byte("\n\n"))
|
|
for _, chunk := range chunks {
|
|
// the feature strings are used to identify the usage data, like:
|
|
// {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}}
|
|
if !bytes.Contains(chunk, []byte("prompt_tokens")) {
|
|
continue
|
|
}
|
|
if !bytes.Contains(chunk, []byte("completion_tokens")) {
|
|
continue
|
|
}
|
|
modelObj := gjson.GetBytes(chunk, "model")
|
|
if modelObj.Exists() {
|
|
model = modelObj.String()
|
|
} else {
|
|
model = "unknown"
|
|
}
|
|
inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens")
|
|
outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens")
|
|
if inputTokenObj.Exists() && outputTokenObj.Exists() {
|
|
inputTokenUsage = inputTokenObj.Int()
|
|
outputTokenUsage = outputTokenObj.Int()
|
|
ok = true
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// fetches the tracing span value from the specified source.
|
|
func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log wrapper.Log) {
|
|
for _, attribute := range config.attributes {
|
|
var key string
|
|
var value interface{}
|
|
if source == attribute.ValueSource {
|
|
key = attribute.Key
|
|
switch source {
|
|
case FixedValue:
|
|
value = attribute.Value
|
|
case RequestHeader:
|
|
value, _ = proxywasm.GetHttpRequestHeader(attribute.Value)
|
|
case RequestBody:
|
|
value = gjson.GetBytes(body, attribute.Value).Value()
|
|
case ResponseHeader:
|
|
value, _ = proxywasm.GetHttpResponseHeader(attribute.Value)
|
|
case ResponseStreamingBody:
|
|
value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log)
|
|
case ResponseBody:
|
|
value = gjson.GetBytes(body, attribute.Value).Value()
|
|
default:
|
|
}
|
|
if (value == nil || value == "") && attribute.DefaultValue != "" {
|
|
value = attribute.DefaultValue
|
|
}
|
|
log.Debugf("[attribute] source type: %s, key: %s, value: %+v", source, key, value)
|
|
if attribute.ApplyToLog {
|
|
ctx.SetUserAttribute(key, value)
|
|
}
|
|
// for metrics
|
|
if key == Model || key == InputToken || key == OutputToken {
|
|
ctx.SetContext(key, value)
|
|
}
|
|
if attribute.ApplyToSpan {
|
|
setSpanAttribute(key, value, log)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) interface{} {
|
|
chunks := bytes.Split(bytes.TrimSpace(unifySSEChunk(data)), []byte("\n\n"))
|
|
var value interface{}
|
|
if rule == RuleFirst {
|
|
for _, chunk := range chunks {
|
|
jsonObj := gjson.GetBytes(chunk, jsonPath)
|
|
if jsonObj.Exists() {
|
|
value = jsonObj.Value()
|
|
break
|
|
}
|
|
}
|
|
} else if rule == RuleReplace {
|
|
for _, chunk := range chunks {
|
|
jsonObj := gjson.GetBytes(chunk, jsonPath)
|
|
if jsonObj.Exists() {
|
|
value = jsonObj.Value()
|
|
}
|
|
}
|
|
} else if rule == RuleAppend {
|
|
// extract llm response
|
|
var strValue string
|
|
for _, chunk := range chunks {
|
|
jsonObj := gjson.GetBytes(chunk, jsonPath)
|
|
if jsonObj.Exists() {
|
|
strValue += jsonObj.String()
|
|
}
|
|
}
|
|
value = strValue
|
|
} else {
|
|
log.Errorf("unsupported rule type: %s", rule)
|
|
}
|
|
return value
|
|
}
|
|
|
|
// Set the tracing span with value.
|
|
func setSpanAttribute(key string, value interface{}, log wrapper.Log) {
|
|
if value != "" {
|
|
traceSpanTag := wrapper.TraceSpanTagPrefix + key
|
|
if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(fmt.Sprint(value))); e != nil {
|
|
log.Warnf("failed to set %s in filter state: %v", traceSpanTag, e)
|
|
}
|
|
} else {
|
|
log.Debugf("failed to write span attribute [%s], because it's value is empty")
|
|
}
|
|
}
|
|
|
|
func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) {
|
|
// Generate usage metrics
|
|
var ok bool
|
|
var route, cluster, model string
|
|
var inputToken, outputToken uint64
|
|
consumer := ctx.GetStringContext(ConsumerKey, "none")
|
|
route, ok = ctx.GetContext(RouteName).(string)
|
|
if !ok {
|
|
log.Warnf("RouteName typd assert failed, skip metric record")
|
|
return
|
|
}
|
|
cluster, ok = ctx.GetContext(ClusterName).(string)
|
|
if !ok {
|
|
log.Warnf("ClusterName typd assert failed, skip metric record")
|
|
return
|
|
}
|
|
if ctx.GetUserAttribute(Model) == nil || ctx.GetUserAttribute(InputToken) == nil || ctx.GetUserAttribute(OutputToken) == nil {
|
|
log.Warnf("get usage information failed, skip metric record")
|
|
return
|
|
}
|
|
model, ok = ctx.GetUserAttribute(Model).(string)
|
|
if !ok {
|
|
log.Warnf("Model typd assert failed, skip metric record")
|
|
return
|
|
}
|
|
inputToken, ok = convertToUInt(ctx.GetUserAttribute(InputToken))
|
|
if !ok {
|
|
log.Warnf("InputToken typd assert failed, skip metric record")
|
|
return
|
|
}
|
|
outputToken, ok = convertToUInt(ctx.GetUserAttribute(OutputToken))
|
|
if !ok {
|
|
log.Warnf("OutputToken typd assert failed, skip metric record")
|
|
return
|
|
}
|
|
if inputToken == 0 || outputToken == 0 {
|
|
log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record")
|
|
return
|
|
}
|
|
config.incrementCounter(generateMetricName(route, cluster, model, consumer, InputToken), inputToken)
|
|
config.incrementCounter(generateMetricName(route, cluster, model, consumer, OutputToken), outputToken)
|
|
|
|
// Generate duration metrics
|
|
var llmFirstTokenDuration, llmServiceDuration uint64
|
|
// Is stream response
|
|
if ctx.GetUserAttribute(LLMFirstTokenDuration) != nil {
|
|
llmFirstTokenDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMFirstTokenDuration))
|
|
if !ok {
|
|
log.Warnf("LLMFirstTokenDuration typd assert failed")
|
|
return
|
|
}
|
|
config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMFirstTokenDuration), llmFirstTokenDuration)
|
|
config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMStreamDurationCount), 1)
|
|
}
|
|
if ctx.GetUserAttribute(LLMServiceDuration) != nil {
|
|
llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration))
|
|
if !ok {
|
|
log.Warnf("LLMServiceDuration typd assert failed")
|
|
return
|
|
}
|
|
config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMServiceDuration), llmServiceDuration)
|
|
config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMDurationCount), 1)
|
|
}
|
|
}
|
|
|
|
// convertToUInt converts a numeric value held in an interface to uint64.
//
// All built-in signed and unsigned integer types and both float types are
// accepted; floats are truncated toward zero. The boolean result is false,
// and the number 0, when val is not numeric or cannot be represented as a
// uint64: negative numbers, NaN, and floats of 2^64 or more. Rejecting those
// matters because a plain conversion would wrap a negative integer into a
// huge counter increment, and the result of an out-of-range float conversion
// is implementation-dependent.
func convertToUInt(val interface{}) (uint64, bool) {
	switch v := val.(type) {
	case float32:
		return floatToUInt(float64(v))
	case float64:
		return floatToUInt(v)
	case int:
		return signedToUInt(int64(v))
	case int8:
		return signedToUInt(int64(v))
	case int16:
		return signedToUInt(int64(v))
	case int32:
		return signedToUInt(int64(v))
	case int64:
		return signedToUInt(v)
	case uint:
		return uint64(v), true
	case uint8:
		return uint64(v), true
	case uint16:
		return uint64(v), true
	case uint32:
		return uint64(v), true
	case uint64:
		return v, true
	default:
		return 0, false
	}
}

// signedToUInt converts a non-negative signed integer to uint64 and rejects
// negative values.
func signedToUInt(v int64) (uint64, bool) {
	if v < 0 {
		return 0, false
	}
	return uint64(v), true
}

// floatToUInt truncates a float in [0, 2^64) to uint64 and rejects everything
// else, including NaN (which fails the v >= 0 comparison).
func floatToUInt(v float64) (uint64, bool) {
	const uint64Limit = 1 << 64 // first value that no longer fits in uint64
	if !(v >= 0 && v < uint64Limit) {
		return 0, false
	}
	return uint64(v), true
}
|