diff --git a/plugins/wasm-go/extensions/ai-statistics/main.go b/plugins/wasm-go/extensions/ai-statistics/main.go index 68511fe4e..121bad10e 100644 --- a/plugins/wasm-go/extensions/ai-statistics/main.go +++ b/plugins/wasm-go/extensions/ai-statistics/main.go @@ -132,13 +132,13 @@ const ( ToolCallsPathStreaming = "choices.0.delta.tool_calls" // Claude/Anthropic tool calls paths (streaming) - ClaudeEventType = "type" - ClaudeContentBlockType = "content_block.type" - ClaudeContentBlockID = "content_block.id" - ClaudeContentBlockName = "content_block.name" - ClaudeContentBlockInput = "content_block.input" - ClaudeDeltaPartialJSON = "delta.partial_json" - ClaudeIndex = "index" + ClaudeEventType = "type" + ClaudeContentBlockType = "content_block.type" + ClaudeContentBlockID = "content_block.id" + ClaudeContentBlockName = "content_block.name" + ClaudeContentBlockInput = "content_block.input" + ClaudeDeltaPartialJSON = "delta.partial_json" + ClaudeIndex = "index" // Reasoning paths ReasoningPathNonStreaming = "choices.0.message.reasoning_content" @@ -154,10 +154,10 @@ func getDefaultAttributes() []Attribute { return []Attribute{ // Extract complete conversation history from request body { - Key: "messages", + Key: "messages", ValueSource: RequestBody, - Value: "messages", - ApplyToLog: true, + Value: "messages", + ApplyToLog: true, }, // Built-in attributes (no value_source needed, will be auto-extracted) { @@ -259,10 +259,10 @@ func extractSessionId(customHeader string) string { // ToolCall represents a single tool call in the response type ToolCall struct { - Index int `json:"index,omitempty"` - ID string `json:"id,omitempty"` - Type string `json:"type,omitempty"` - Function ToolCallFunction `json:"function,omitempty"` + Index int `json:"index,omitempty"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Function ToolCallFunction `json:"function,omitempty"` } // ToolCallFunction represents the function details in a tool call @@ -297,7 +297,7 @@ func extractStreamingToolCalls(data []byte, buffer *StreamingToolCallsBuffer) *S for _, tcResult := range toolCallsResult.Array() { index := int(tcResult.Get("index").Int()) - + // Get or create tool call entry tc, exists := buffer.ToolCalls[index] if !exists { @@ -350,10 +350,10 @@ func extractClaudeStreamingToolCalls(data []byte, buffer *StreamingToolCallsBuff contentBlockType := gjson.GetBytes(chunk, ClaudeContentBlockType) if contentBlockType.Exists() && contentBlockType.String() == "tool_use" { index := int(gjson.GetBytes(chunk, ClaudeIndex).Int()) - + // Create tool call entry tc := &ToolCall{Index: index} - + // Extract id and name if id := gjson.GetBytes(chunk, ClaudeContentBlockID).String(); id != "" { tc.ID = id @@ -362,11 +362,11 @@ func extractClaudeStreamingToolCalls(data []byte, buffer *StreamingToolCallsBuff tc.Function.Name = name } tc.Type = "tool_use" - + buffer.ToolCalls[index] = tc buffer.InToolBlock[index] = true buffer.ArgumentsBuffer[index] = "" - + // Try to extract initial input if present if input := gjson.GetBytes(chunk, ClaudeContentBlockInput); input.Exists() { if inputMap, ok := input.Value().(map[string]interface{}); ok { @@ -393,7 +393,7 @@ func extractClaudeStreamingToolCalls(data []byte, buffer *StreamingToolCallsBuff index := int(gjson.GetBytes(chunk, ClaudeIndex).Int()) if buffer.InToolBlock[index] { buffer.InToolBlock[index] = false - + // Parse accumulated arguments and set them if tc, exists := buffer.ToolCalls[index]; exists { tc.Function.Arguments = buffer.ArgumentsBuffer[index] @@ -555,7 +555,7 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig) error { if configJson.Get("value_length_limit").Exists() { config.valueLengthLimit = int(configJson.Get("value_length_limit").Int()) } else { - config.valueLengthLimit = 4000 + config.valueLengthLimit = 32000 } // Parse attributes or use defaults @@ -843,7 +843,7 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat setSpanAttribute(ArmsModelName, usage.Model) setSpanAttribute(ArmsInputToken, usage.InputToken) setSpanAttribute(ArmsOutputToken, usage.OutputToken) - + // Set token details to context for later use in attributes if len(usage.InputTokenDetails) > 0 { ctx.SetContext(tokenusage.CtxKeyInputTokenDetails, usage.InputTokenDetails) @@ -851,6 +851,9 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat if len(usage.OutputTokenDetails) > 0 { ctx.SetContext(tokenusage.CtxKeyOutputTokenDetails, usage.OutputTokenDetails) } + + // Write once + _ = ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) } } // If the end of the stream is reached, record metrics/logs/spans. @@ -907,7 +910,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body setSpanAttribute(ArmsInputToken, usage.InputToken) setSpanAttribute(ArmsOutputToken, usage.OutputToken) setSpanAttribute(ArmsTotalToken, usage.TotalToken) - + // Set token details to context for later use in attributes if len(usage.InputTokenDetails) > 0 { ctx.SetContext(tokenusage.CtxKeyInputTokenDetails, usage.InputTokenDetails) @@ -975,7 +978,7 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so if (value == nil || value == "") && attribute.DefaultValue != "" { value = attribute.DefaultValue } - + // Format value for logging/span var formattedValue interface{} switch v := value.(type) { @@ -994,7 +997,7 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so formattedValue = fmt.Sprint(value)[:config.valueLengthLimit/2] + " [truncated] " + fmt.Sprint(value)[len(fmt.Sprint(value))-config.valueLengthLimit/2:] } } - + log.Debugf("[attribute] source type: %s, key: %s, value: %+v", source, key, formattedValue) if attribute.ApplyToLog { if attribute.AsSeparateLogField { @@ -1124,7 +1127,7 @@ func getBuiltinAttributeFallback(ctx wrapper.HttpContext, config AIStatisticsCon // Also try Claude format (both formats can be checked) buffer = extractClaudeStreamingToolCalls(body, buffer) ctx.SetContext(CtxStreamingToolCallsBuffer, buffer) - + // Also set tool_calls to user attributes so they appear in ai_log toolCalls := getToolCallsFromBuffer(buffer) if len(toolCalls) > 0 { diff --git a/plugins/wasm-go/extensions/ai-statistics/main_test.go b/plugins/wasm-go/extensions/ai-statistics/main_test.go index 2e6329ae6..4af76e686 100644 --- a/plugins/wasm-go/extensions/ai-statistics/main_test.go +++ b/plugins/wasm-go/extensions/ai-statistics/main_test.go @@ -1712,3 +1712,305 @@ func TestTokenDetails(t *testing.T) { host.CompleteHttp() }) } + +func TestUnmatchedPathsAndContentTypes(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + restrictiveConfig := func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "enable_path_suffixes": []string{"/allowed_path"}, + "enable_content_types": []string{"application/json"}, + "attributes": []map[string]interface{}{ + { + "key": "test_attr", + "value_source": "response_body", + "value": "data", + "apply_to_log": true, + }, + }, + "disable_openai_usage": true, + }) + return data + }() + + t.Run("skip request for unenabled path", func(t *testing.T) { + host, status := test.NewTestHost(restrictiveConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/disallowed_path"}, + {":method", "POST"}, + }) + require.Equal(t, types.ActionContinue, action) + host.CompleteHttp() + }) + + t.Run("skip response for unenabled content type", func(t *testing.T) { + host, status := test.NewTestHost(restrictiveConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/allowed_path"}, + {":method", "POST"}, + }) + + action := host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "text/plain"}, + }) + require.Equal(t, types.ActionContinue, action) + host.CompleteHttp() + }) + }) +} + +func TestSetSpanAttributeAndLoggingEdgeCases(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + configBytes := []byte(`{ + "attributes": [ + { + "key": "test_attr1", + "value_source": "fixed_value", + "value": "", + "apply_to_span": true + }, + { + "key": "test_attr2", + "value_source": "fixed_value", + "value": "long_value_that_exceeds_limit_long_value_that_exceeds_limit_long_value_that_exceeds_limit", + "apply_to_log": true + } + ], + "value_length_limit": 20 + }`) + + t.Run("span attribute edge cases", func(t *testing.T) { + host, status := test.NewTestHost(configBytes) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // Setting fixed value attribute to empty should just print a debug log and skip setting span + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/api/chat"}, + }) + require.Equal(t, types.ActionContinue, action) + host.CompleteHttp() + }) + }) +} + +func TestGetRouteAndClusterNameEdgeCases(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("properties absence", func(t *testing.T) { + host, status := test.NewTestHost([]byte(`{}`)) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // Host doesn't have route_name implicitly by default without SetRouteName, but getRouteName handles err + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/api/chat"}, + }) + host.CompleteHttp() + }) + + t.Run("api name with @", func(t *testing.T) { + host, status := test.NewTestHost([]byte(`{}`)) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.SetRouteName("api@v1@service@extra") // @ has special handling in getAPIName + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/api/chat"}, + }) + host.CompleteHttp() + }) + }) +} + +func TestExtractClaudeStreamingToolCallsMissingInput(t *testing.T) { + t.Run("claude missing partial_json", func(t *testing.T) { + chunks := [][]byte{ + []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"tool_123","name":"get_weather","input":{}}}`), + []byte(`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta"}}`), + []byte(`data: {"type":"content_block_stop","index":0}`), + } + + var buffer *StreamingToolCallsBuffer + for _, chunk := range chunks { + buffer = extractClaudeStreamingToolCalls(chunk, buffer) + } + + toolCalls := getToolCallsFromBuffer(buffer) + require.Len(t, toolCalls, 1) + require.Equal(t, "tool_123", toolCalls[0].ID) + require.Equal(t, "tool_use", toolCalls[0].Type) + require.Equal(t, "get_weather", toolCalls[0].Function.Name) + // partial_json absence means arguments might be empty + }) +} + +func TestWriteMetricEdgeCases(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("disable_openai_usage true", func(t *testing.T) { + configBytes := []byte(`{ + "disable_openai_usage": true + }`) + host, status := test.NewTestHost(configBytes) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.SetRouteName("api-v1") + host.SetClusterName("cluster-1") + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/api/chat"}, + }) + + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }) + + responseBody := []byte(`{ + "usage": {"prompt_tokens": 5, "completion_tokens": 8, "total_tokens": 13}, + "model": "gpt-3.5-turbo" + }`) + host.CallOnHttpResponseBody(responseBody) + host.CompleteHttp() + }) + }) +} + +func TestIsPathEnabled(t *testing.T) { + require.True(t, isPathEnabled("/v1/chat/completions", nil)) + require.True(t, isPathEnabled("/v1/chat/completions", []string{})) + require.True(t, isPathEnabled("/v1/chat/completions", []string{"/completions", "/messages"})) + require.True(t, isPathEnabled("/v1/messages", []string{"/completions", "/messages"})) + require.False(t, isPathEnabled("/v1/embeddings", []string{"/completions", "/messages"})) + + // test query params + require.True(t, isPathEnabled("/v1/chat/completions?stream=true", []string{"/completions"})) + require.False(t, isPathEnabled("/v1/embeddings?stream=true", []string{"/completions"})) +} + +func TestIsContentTypeEnabled(t *testing.T) { + require.True(t, isContentTypeEnabled("application/json", nil)) + require.True(t, isContentTypeEnabled("application/json", []string{})) + require.True(t, isContentTypeEnabled("application/json", []string{"application/json", "text/event-stream"})) + require.True(t, isContentTypeEnabled("text/event-stream; charset=utf-8", []string{"application/json", "text/event-stream"})) + require.False(t, isContentTypeEnabled("text/html", []string{"application/json", "text/event-stream"})) +} + +func TestConvertToUInt(t *testing.T) { + val, ok := convertToUInt(int32(10)) + require.True(t, ok) + require.Equal(t, uint64(10), val) + + val, ok = convertToUInt(int64(10)) + require.True(t, ok) + require.Equal(t, uint64(10), val) + + val, ok = convertToUInt(uint32(10)) + require.True(t, ok) + require.Equal(t, uint64(10), val) + + val, ok = convertToUInt(uint64(10)) + require.True(t, ok) + require.Equal(t, uint64(10), val) + + val, ok = convertToUInt(float32(10.0)) + require.True(t, ok) + require.Equal(t, uint64(10), val) + + val, ok = convertToUInt(float64(10.0)) + require.True(t, ok) + require.Equal(t, uint64(10), val) + + _, ok = convertToUInt("10") + require.False(t, ok) +} + +func TestExtractClaudeStreamingToolCalls(t *testing.T) { + t.Run("claude tool use assembly", func(t *testing.T) { + chunks := [][]byte{ + []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"tool_123","name":"get_weather"}}`), + []byte(`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\"loc"}}}`), + []byte(`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"ation\":\"Bei"}}}`), + []byte(`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"jing\"}"}}}`), + []byte(`data: {"type":"content_block_stop","index":0}`), + } + + var buffer *StreamingToolCallsBuffer + for _, chunk := range chunks { + buffer = extractClaudeStreamingToolCalls(chunk, buffer) + } + + toolCalls := getToolCallsFromBuffer(buffer) + require.Len(t, toolCalls, 1) + require.Equal(t, "tool_123", toolCalls[0].ID) + require.Equal(t, "tool_use", toolCalls[0].Type) + require.Equal(t, "get_weather", toolCalls[0].Function.Name) + require.Equal(t, `{"location":"Beijing"}`, toolCalls[0].Function.Arguments) + }) + + t.Run("claude empty chunks", func(t *testing.T) { + chunks := [][]byte{ + []byte(`data: {"type":"ping"}`), + []byte(`data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`), + } + var buffer *StreamingToolCallsBuffer + for _, chunk := range chunks { + buffer = extractClaudeStreamingToolCalls(chunk, buffer) + } + toolCalls := getToolCallsFromBuffer(buffer) + require.Len(t, toolCalls, 0) + }) + + t.Run("claude tool use with initial input", func(t *testing.T) { + chunks := [][]byte{ + []byte(`data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"tool_456","name":"get_time","input":{"timezone":"UTC+8"}}}`), + []byte(`data: {"type":"content_block_stop","index":1}`), + } + + var buffer *StreamingToolCallsBuffer + for _, chunk := range chunks { + buffer = extractClaudeStreamingToolCalls(chunk, buffer) + } + + toolCalls := getToolCallsFromBuffer(buffer) + require.Len(t, toolCalls, 1) + require.Equal(t, "tool_456", toolCalls[0].ID) + require.Equal(t, "tool_use", toolCalls[0].Type) + require.Equal(t, "get_time", toolCalls[0].Function.Name) + require.Equal(t, `{"timezone":"UTC+8"}`, toolCalls[0].Function.Arguments) + }) +} + +func TestConfigWithDefaultAttributes(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("use default attributes config", func(t *testing.T) { + defaultConfig := []byte(`{ + "use_default_attributes": true + }`) + host, status := test.NewTestHost(defaultConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + }) + + t.Run("use default response attributes config", func(t *testing.T) { + defaultRespConfig := []byte(`{ + "use_default_response_attributes": true + }`) + host, status := test.NewTestHost(defaultRespConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + }) + }) +}