From 67734823005d3c4bd53340b4aa53a564fa52882f Mon Sep 17 00:00:00 2001 From: rinfx Date: Fri, 18 Apr 2025 16:19:59 +0800 Subject: [PATCH] Enhance the compatibility of AI observability plugins with different LLM suppliers (#2088) --- .../wasm-go/extensions/ai-statistics/go.mod | 1 + .../wasm-go/extensions/ai-statistics/go.sum | 3 + .../wasm-go/extensions/ai-statistics/main.go | 59 +++++++++++++++++-- 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-statistics/go.mod b/plugins/wasm-go/extensions/ai-statistics/go.mod index a90a03c44..46d6780f9 100644 --- a/plugins/wasm-go/extensions/ai-statistics/go.mod +++ b/plugins/wasm-go/extensions/ai-statistics/go.mod @@ -17,4 +17,5 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/resp v0.1.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect ) diff --git a/plugins/wasm-go/extensions/ai-statistics/go.sum b/plugins/wasm-go/extensions/ai-statistics/go.sum index b4ab172fe..f76459d11 100644 --- a/plugins/wasm-go/extensions/ai-statistics/go.sum +++ b/plugins/wasm-go/extensions/ai-statistics/go.sum @@ -9,6 +9,7 @@ github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= @@ -17,4 +18,6 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/wasm-go/extensions/ai-statistics/main.go b/plugins/wasm-go/extensions/ai-statistics/main.go index f8d315ec8..ae0c4d88b 100644 --- a/plugins/wasm-go/extensions/ai-statistics/main.go +++ b/plugins/wasm-go/extensions/ai-statistics/main.go @@ -46,7 +46,7 @@ const ( ResponseStreamingBody = "response_streaming_body" ResponseBody = "response_body" - // Inner metric & log attributes name + // Inner metric & log attributes Model = "model" InputToken = "input_token" OutputToken = "output_token" @@ -55,6 +55,16 @@ const ( LLMDurationCount = "llm_duration_count" LLMStreamDurationCount = "llm_stream_duration_count" ResponseType = "response_type" + ChatID = "chat_id" + ChatRound = "chat_round" + + // Inner span attributes + ArmsSpanKind = "gen_ai.span.kind" + ArmsModelName = "gen_ai.model_name" + ArmsRequestModel = "gen_ai.request.model" + ArmsInputToken = "gen_ai.usage.input_tokens" + ArmsOutputToken = "gen_ai.usage.output_tokens" + ArmsTotalToken = "gen_ai.usage.total_tokens" // Extract Rule RuleFirst = "first" @@ -171,7 +181,8 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, lo setAttributeBySource(ctx, config, FixedValue, nil, log) // Set user defined log & span attributes which type is request_header setAttributeBySource(ctx, config, RequestHeader, nil, log) - // Set request start time. + // Set span attributes for ARMS. + setSpanAttribute(ArmsSpanKind, "LLM", log) return types.ActionContinue } @@ -179,6 +190,22 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, lo func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action { // Set user defined log & span attributes. setAttributeBySource(ctx, config, RequestBody, body, log) + // Set span attributes for ARMS. + requestModel := gjson.GetBytes(body, "model").String() + if requestModel == "" { + requestModel = "UNKNOWN" + } + setSpanAttribute(ArmsRequestModel, requestModel, log) + // Set the number of conversation rounds + if gjson.GetBytes(body, "messages").Exists() { + userPromptCount := 0 + for _, msg := range gjson.GetBytes(body, "messages").Array() { + if msg.Get("role").String() == "user" { + userPromptCount += 1 + } + } + ctx.SetUserAttribute(ChatRound, userPromptCount) + } // Write log ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) @@ -211,6 +238,10 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat } ctx.SetUserAttribute(ResponseType, "stream") + chatID := gjson.GetBytes(data, "id").String() + if chatID != "" { + ctx.SetUserAttribute(ChatID, chatID) + } // Get requestStartTime from http context requestStartTime, ok := ctx.GetContext(StatisticsRequestStartTime).(int64) @@ -231,6 +262,11 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat ctx.SetUserAttribute(Model, model) ctx.SetUserAttribute(InputToken, inputToken) ctx.SetUserAttribute(OutputToken, outputToken) + // Set span attributes for ARMS. + setSpanAttribute(ArmsModelName, model, log) + setSpanAttribute(ArmsInputToken, inputToken, log) + setSpanAttribute(ArmsOutputToken, outputToken, log) + setSpanAttribute(ArmsTotalToken, inputToken+outputToken, log) } // If the end of the stream is reached, record metrics/logs/spans. if endOfStream { @@ -263,12 +299,21 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime) ctx.SetUserAttribute(ResponseType, "normal") + chatID := gjson.GetBytes(body, "id").String() + if chatID != "" { + ctx.SetUserAttribute(ChatID, chatID) + } // Set information about this request if model, inputToken, outputToken, ok := getUsage(body); ok { ctx.SetUserAttribute(Model, model) ctx.SetUserAttribute(InputToken, inputToken) ctx.SetUserAttribute(OutputToken, outputToken) + // Set span attributes for ARMS. + setSpanAttribute(ArmsModelName, model, log) + setSpanAttribute(ArmsInputToken, inputToken, log) + setSpanAttribute(ArmsOutputToken, outputToken, log) + setSpanAttribute(ArmsTotalToken, inputToken+outputToken, log) } // Set user defined log & span attributes. @@ -283,8 +328,14 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body return types.ActionContinue } +func unifySSEChunk(data []byte) []byte { + data = bytes.ReplaceAll(data, []byte("\r\n"), []byte("\n")) + data = bytes.ReplaceAll(data, []byte("\r"), []byte("\n")) + return data +} + func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsage int64, ok bool) { - chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) + chunks := bytes.Split(bytes.TrimSpace(unifySSEChunk(data)), []byte("\n\n")) for _, chunk := range chunks { // the feature strings are used to identify the usage data, like: // {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}} @@ -353,7 +404,7 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so } func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) interface{} { - chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) + chunks := bytes.Split(bytes.TrimSpace(unifySSEChunk(data)), []byte("\n\n")) var value interface{} if rule == RuleFirst { for _, chunk := range chunks {