mirror of
https://github.com/alibaba/higress.git
synced 2026-06-05 02:27:28 +08:00
ai-statistics: increase default value_length_limit and emit AILog during streaming usage (#3624)
This commit is contained in:
@@ -132,13 +132,13 @@ const (
|
|||||||
ToolCallsPathStreaming = "choices.0.delta.tool_calls"
|
ToolCallsPathStreaming = "choices.0.delta.tool_calls"
|
||||||
|
|
||||||
// Claude/Anthropic tool calls paths (streaming)
|
// Claude/Anthropic tool calls paths (streaming)
|
||||||
ClaudeEventType = "type"
|
ClaudeEventType = "type"
|
||||||
ClaudeContentBlockType = "content_block.type"
|
ClaudeContentBlockType = "content_block.type"
|
||||||
ClaudeContentBlockID = "content_block.id"
|
ClaudeContentBlockID = "content_block.id"
|
||||||
ClaudeContentBlockName = "content_block.name"
|
ClaudeContentBlockName = "content_block.name"
|
||||||
ClaudeContentBlockInput = "content_block.input"
|
ClaudeContentBlockInput = "content_block.input"
|
||||||
ClaudeDeltaPartialJSON = "delta.partial_json"
|
ClaudeDeltaPartialJSON = "delta.partial_json"
|
||||||
ClaudeIndex = "index"
|
ClaudeIndex = "index"
|
||||||
|
|
||||||
// Reasoning paths
|
// Reasoning paths
|
||||||
ReasoningPathNonStreaming = "choices.0.message.reasoning_content"
|
ReasoningPathNonStreaming = "choices.0.message.reasoning_content"
|
||||||
@@ -154,10 +154,10 @@ func getDefaultAttributes() []Attribute {
|
|||||||
return []Attribute{
|
return []Attribute{
|
||||||
// Extract complete conversation history from request body
|
// Extract complete conversation history from request body
|
||||||
{
|
{
|
||||||
Key: "messages",
|
Key: "messages",
|
||||||
ValueSource: RequestBody,
|
ValueSource: RequestBody,
|
||||||
Value: "messages",
|
Value: "messages",
|
||||||
ApplyToLog: true,
|
ApplyToLog: true,
|
||||||
},
|
},
|
||||||
// Built-in attributes (no value_source needed, will be auto-extracted)
|
// Built-in attributes (no value_source needed, will be auto-extracted)
|
||||||
{
|
{
|
||||||
@@ -259,10 +259,10 @@ func extractSessionId(customHeader string) string {
|
|||||||
|
|
||||||
// ToolCall represents a single tool call in the response
|
// ToolCall represents a single tool call in the response
|
||||||
type ToolCall struct {
|
type ToolCall struct {
|
||||||
Index int `json:"index,omitempty"`
|
Index int `json:"index,omitempty"`
|
||||||
ID string `json:"id,omitempty"`
|
ID string `json:"id,omitempty"`
|
||||||
Type string `json:"type,omitempty"`
|
Type string `json:"type,omitempty"`
|
||||||
Function ToolCallFunction `json:"function,omitempty"`
|
Function ToolCallFunction `json:"function,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToolCallFunction represents the function details in a tool call
|
// ToolCallFunction represents the function details in a tool call
|
||||||
@@ -297,7 +297,7 @@ func extractStreamingToolCalls(data []byte, buffer *StreamingToolCallsBuffer) *S
|
|||||||
|
|
||||||
for _, tcResult := range toolCallsResult.Array() {
|
for _, tcResult := range toolCallsResult.Array() {
|
||||||
index := int(tcResult.Get("index").Int())
|
index := int(tcResult.Get("index").Int())
|
||||||
|
|
||||||
// Get or create tool call entry
|
// Get or create tool call entry
|
||||||
tc, exists := buffer.ToolCalls[index]
|
tc, exists := buffer.ToolCalls[index]
|
||||||
if !exists {
|
if !exists {
|
||||||
@@ -350,10 +350,10 @@ func extractClaudeStreamingToolCalls(data []byte, buffer *StreamingToolCallsBuff
|
|||||||
contentBlockType := gjson.GetBytes(chunk, ClaudeContentBlockType)
|
contentBlockType := gjson.GetBytes(chunk, ClaudeContentBlockType)
|
||||||
if contentBlockType.Exists() && contentBlockType.String() == "tool_use" {
|
if contentBlockType.Exists() && contentBlockType.String() == "tool_use" {
|
||||||
index := int(gjson.GetBytes(chunk, ClaudeIndex).Int())
|
index := int(gjson.GetBytes(chunk, ClaudeIndex).Int())
|
||||||
|
|
||||||
// Create tool call entry
|
// Create tool call entry
|
||||||
tc := &ToolCall{Index: index}
|
tc := &ToolCall{Index: index}
|
||||||
|
|
||||||
// Extract id and name
|
// Extract id and name
|
||||||
if id := gjson.GetBytes(chunk, ClaudeContentBlockID).String(); id != "" {
|
if id := gjson.GetBytes(chunk, ClaudeContentBlockID).String(); id != "" {
|
||||||
tc.ID = id
|
tc.ID = id
|
||||||
@@ -362,11 +362,11 @@ func extractClaudeStreamingToolCalls(data []byte, buffer *StreamingToolCallsBuff
|
|||||||
tc.Function.Name = name
|
tc.Function.Name = name
|
||||||
}
|
}
|
||||||
tc.Type = "tool_use"
|
tc.Type = "tool_use"
|
||||||
|
|
||||||
buffer.ToolCalls[index] = tc
|
buffer.ToolCalls[index] = tc
|
||||||
buffer.InToolBlock[index] = true
|
buffer.InToolBlock[index] = true
|
||||||
buffer.ArgumentsBuffer[index] = ""
|
buffer.ArgumentsBuffer[index] = ""
|
||||||
|
|
||||||
// Try to extract initial input if present
|
// Try to extract initial input if present
|
||||||
if input := gjson.GetBytes(chunk, ClaudeContentBlockInput); input.Exists() {
|
if input := gjson.GetBytes(chunk, ClaudeContentBlockInput); input.Exists() {
|
||||||
if inputMap, ok := input.Value().(map[string]interface{}); ok {
|
if inputMap, ok := input.Value().(map[string]interface{}); ok {
|
||||||
@@ -393,7 +393,7 @@ func extractClaudeStreamingToolCalls(data []byte, buffer *StreamingToolCallsBuff
|
|||||||
index := int(gjson.GetBytes(chunk, ClaudeIndex).Int())
|
index := int(gjson.GetBytes(chunk, ClaudeIndex).Int())
|
||||||
if buffer.InToolBlock[index] {
|
if buffer.InToolBlock[index] {
|
||||||
buffer.InToolBlock[index] = false
|
buffer.InToolBlock[index] = false
|
||||||
|
|
||||||
// Parse accumulated arguments and set them
|
// Parse accumulated arguments and set them
|
||||||
if tc, exists := buffer.ToolCalls[index]; exists {
|
if tc, exists := buffer.ToolCalls[index]; exists {
|
||||||
tc.Function.Arguments = buffer.ArgumentsBuffer[index]
|
tc.Function.Arguments = buffer.ArgumentsBuffer[index]
|
||||||
@@ -555,7 +555,7 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig) error {
|
|||||||
if configJson.Get("value_length_limit").Exists() {
|
if configJson.Get("value_length_limit").Exists() {
|
||||||
config.valueLengthLimit = int(configJson.Get("value_length_limit").Int())
|
config.valueLengthLimit = int(configJson.Get("value_length_limit").Int())
|
||||||
} else {
|
} else {
|
||||||
config.valueLengthLimit = 4000
|
config.valueLengthLimit = 32000
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse attributes or use defaults
|
// Parse attributes or use defaults
|
||||||
@@ -843,7 +843,7 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
|
|||||||
setSpanAttribute(ArmsModelName, usage.Model)
|
setSpanAttribute(ArmsModelName, usage.Model)
|
||||||
setSpanAttribute(ArmsInputToken, usage.InputToken)
|
setSpanAttribute(ArmsInputToken, usage.InputToken)
|
||||||
setSpanAttribute(ArmsOutputToken, usage.OutputToken)
|
setSpanAttribute(ArmsOutputToken, usage.OutputToken)
|
||||||
|
|
||||||
// Set token details to context for later use in attributes
|
// Set token details to context for later use in attributes
|
||||||
if len(usage.InputTokenDetails) > 0 {
|
if len(usage.InputTokenDetails) > 0 {
|
||||||
ctx.SetContext(tokenusage.CtxKeyInputTokenDetails, usage.InputTokenDetails)
|
ctx.SetContext(tokenusage.CtxKeyInputTokenDetails, usage.InputTokenDetails)
|
||||||
@@ -851,6 +851,9 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
|
|||||||
if len(usage.OutputTokenDetails) > 0 {
|
if len(usage.OutputTokenDetails) > 0 {
|
||||||
ctx.SetContext(tokenusage.CtxKeyOutputTokenDetails, usage.OutputTokenDetails)
|
ctx.SetContext(tokenusage.CtxKeyOutputTokenDetails, usage.OutputTokenDetails)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write once
|
||||||
|
_ = ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// If the end of the stream is reached, record metrics/logs/spans.
|
// If the end of the stream is reached, record metrics/logs/spans.
|
||||||
@@ -907,7 +910,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body
|
|||||||
setSpanAttribute(ArmsInputToken, usage.InputToken)
|
setSpanAttribute(ArmsInputToken, usage.InputToken)
|
||||||
setSpanAttribute(ArmsOutputToken, usage.OutputToken)
|
setSpanAttribute(ArmsOutputToken, usage.OutputToken)
|
||||||
setSpanAttribute(ArmsTotalToken, usage.TotalToken)
|
setSpanAttribute(ArmsTotalToken, usage.TotalToken)
|
||||||
|
|
||||||
// Set token details to context for later use in attributes
|
// Set token details to context for later use in attributes
|
||||||
if len(usage.InputTokenDetails) > 0 {
|
if len(usage.InputTokenDetails) > 0 {
|
||||||
ctx.SetContext(tokenusage.CtxKeyInputTokenDetails, usage.InputTokenDetails)
|
ctx.SetContext(tokenusage.CtxKeyInputTokenDetails, usage.InputTokenDetails)
|
||||||
@@ -975,7 +978,7 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so
|
|||||||
if (value == nil || value == "") && attribute.DefaultValue != "" {
|
if (value == nil || value == "") && attribute.DefaultValue != "" {
|
||||||
value = attribute.DefaultValue
|
value = attribute.DefaultValue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Format value for logging/span
|
// Format value for logging/span
|
||||||
var formattedValue interface{}
|
var formattedValue interface{}
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
@@ -994,7 +997,7 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so
|
|||||||
formattedValue = fmt.Sprint(value)[:config.valueLengthLimit/2] + " [truncated] " + fmt.Sprint(value)[len(fmt.Sprint(value))-config.valueLengthLimit/2:]
|
formattedValue = fmt.Sprint(value)[:config.valueLengthLimit/2] + " [truncated] " + fmt.Sprint(value)[len(fmt.Sprint(value))-config.valueLengthLimit/2:]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("[attribute] source type: %s, key: %s, value: %+v", source, key, formattedValue)
|
log.Debugf("[attribute] source type: %s, key: %s, value: %+v", source, key, formattedValue)
|
||||||
if attribute.ApplyToLog {
|
if attribute.ApplyToLog {
|
||||||
if attribute.AsSeparateLogField {
|
if attribute.AsSeparateLogField {
|
||||||
@@ -1124,7 +1127,7 @@ func getBuiltinAttributeFallback(ctx wrapper.HttpContext, config AIStatisticsCon
|
|||||||
// Also try Claude format (both formats can be checked)
|
// Also try Claude format (both formats can be checked)
|
||||||
buffer = extractClaudeStreamingToolCalls(body, buffer)
|
buffer = extractClaudeStreamingToolCalls(body, buffer)
|
||||||
ctx.SetContext(CtxStreamingToolCallsBuffer, buffer)
|
ctx.SetContext(CtxStreamingToolCallsBuffer, buffer)
|
||||||
|
|
||||||
// Also set tool_calls to user attributes so they appear in ai_log
|
// Also set tool_calls to user attributes so they appear in ai_log
|
||||||
toolCalls := getToolCallsFromBuffer(buffer)
|
toolCalls := getToolCallsFromBuffer(buffer)
|
||||||
if len(toolCalls) > 0 {
|
if len(toolCalls) > 0 {
|
||||||
|
|||||||
@@ -1712,3 +1712,305 @@ func TestTokenDetails(t *testing.T) {
|
|||||||
host.CompleteHttp()
|
host.CompleteHttp()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnmatchedPathsAndContentTypes(t *testing.T) {
|
||||||
|
test.RunTest(t, func(t *testing.T) {
|
||||||
|
restrictiveConfig := func() json.RawMessage {
|
||||||
|
data, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"enable_path_suffixes": []string{"/allowed_path"},
|
||||||
|
"enable_content_types": []string{"application/json"},
|
||||||
|
"attributes": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"key": "test_attr",
|
||||||
|
"value_source": "response_body",
|
||||||
|
"value": "data",
|
||||||
|
"apply_to_log": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"disable_openai_usage": true,
|
||||||
|
})
|
||||||
|
return data
|
||||||
|
}()
|
||||||
|
|
||||||
|
t.Run("skip request for unenabled path", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(restrictiveConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/disallowed_path"},
|
||||||
|
{":method", "POST"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
host.CompleteHttp()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("skip response for unenabled content type", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(restrictiveConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/allowed_path"},
|
||||||
|
{":method", "POST"},
|
||||||
|
})
|
||||||
|
|
||||||
|
action := host.CallOnHttpResponseHeaders([][2]string{
|
||||||
|
{":status", "200"},
|
||||||
|
{"content-type", "text/plain"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
host.CompleteHttp()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetSpanAttributeAndLoggingEdgeCases(t *testing.T) {
|
||||||
|
test.RunTest(t, func(t *testing.T) {
|
||||||
|
configBytes := []byte(`{
|
||||||
|
"attributes": [
|
||||||
|
{
|
||||||
|
"key": "test_attr1",
|
||||||
|
"value_source": "fixed_value",
|
||||||
|
"value": "",
|
||||||
|
"apply_to_span": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"key": "test_attr2",
|
||||||
|
"value_source": "fixed_value",
|
||||||
|
"value": "long_value_that_exceeds_limit_long_value_that_exceeds_limit_long_value_that_exceeds_limit",
|
||||||
|
"apply_to_log": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"value_length_limit": 20
|
||||||
|
}`)
|
||||||
|
|
||||||
|
t.Run("span attribute edge cases", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(configBytes)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
// Setting fixed value attribute to empty should just print a debug log and skip setting span
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/api/chat"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
host.CompleteHttp()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRouteAndClusterNameEdgeCases(t *testing.T) {
|
||||||
|
test.RunTest(t, func(t *testing.T) {
|
||||||
|
t.Run("properties absence", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost([]byte(`{}`))
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
// Host doesn't have route_name implicitly by default without SetRouteName, but getRouteName handles err
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/api/chat"},
|
||||||
|
})
|
||||||
|
host.CompleteHttp()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("api name with @", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost([]byte(`{}`))
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
host.SetRouteName("api@v1@service@extra") // @ has special handling in getAPIName
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/api/chat"},
|
||||||
|
})
|
||||||
|
host.CompleteHttp()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractClaudeStreamingToolCallsMissingInput(t *testing.T) {
|
||||||
|
t.Run("claude missing partial_json", func(t *testing.T) {
|
||||||
|
chunks := [][]byte{
|
||||||
|
[]byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"tool_123","name":"get_weather","input":{}}}`),
|
||||||
|
[]byte(`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta"}}`),
|
||||||
|
[]byte(`data: {"type":"content_block_stop","index":0}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
var buffer *StreamingToolCallsBuffer
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
buffer = extractClaudeStreamingToolCalls(chunk, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls := getToolCallsFromBuffer(buffer)
|
||||||
|
require.Len(t, toolCalls, 1)
|
||||||
|
require.Equal(t, "tool_123", toolCalls[0].ID)
|
||||||
|
require.Equal(t, "tool_use", toolCalls[0].Type)
|
||||||
|
require.Equal(t, "get_weather", toolCalls[0].Function.Name)
|
||||||
|
// partial_json absence means arguments might be empty
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMetricEdgeCases(t *testing.T) {
|
||||||
|
test.RunTest(t, func(t *testing.T) {
|
||||||
|
t.Run("disable_openai_usage true", func(t *testing.T) {
|
||||||
|
configBytes := []byte(`{
|
||||||
|
"disable_openai_usage": true
|
||||||
|
}`)
|
||||||
|
host, status := test.NewTestHost(configBytes)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
host.SetRouteName("api-v1")
|
||||||
|
host.SetClusterName("cluster-1")
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/api/chat"},
|
||||||
|
})
|
||||||
|
|
||||||
|
host.CallOnHttpResponseHeaders([][2]string{
|
||||||
|
{":status", "200"},
|
||||||
|
{"content-type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
responseBody := []byte(`{
|
||||||
|
"usage": {"prompt_tokens": 5, "completion_tokens": 8, "total_tokens": 13},
|
||||||
|
"model": "gpt-3.5-turbo"
|
||||||
|
}`)
|
||||||
|
host.CallOnHttpResponseBody(responseBody)
|
||||||
|
host.CompleteHttp()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsPathEnabled(t *testing.T) {
|
||||||
|
require.True(t, isPathEnabled("/v1/chat/completions", nil))
|
||||||
|
require.True(t, isPathEnabled("/v1/chat/completions", []string{}))
|
||||||
|
require.True(t, isPathEnabled("/v1/chat/completions", []string{"/completions", "/messages"}))
|
||||||
|
require.True(t, isPathEnabled("/v1/messages", []string{"/completions", "/messages"}))
|
||||||
|
require.False(t, isPathEnabled("/v1/embeddings", []string{"/completions", "/messages"}))
|
||||||
|
|
||||||
|
// test query params
|
||||||
|
require.True(t, isPathEnabled("/v1/chat/completions?stream=true", []string{"/completions"}))
|
||||||
|
require.False(t, isPathEnabled("/v1/embeddings?stream=true", []string{"/completions"}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsContentTypeEnabled(t *testing.T) {
|
||||||
|
require.True(t, isContentTypeEnabled("application/json", nil))
|
||||||
|
require.True(t, isContentTypeEnabled("application/json", []string{}))
|
||||||
|
require.True(t, isContentTypeEnabled("application/json", []string{"application/json", "text/event-stream"}))
|
||||||
|
require.True(t, isContentTypeEnabled("text/event-stream; charset=utf-8", []string{"application/json", "text/event-stream"}))
|
||||||
|
require.False(t, isContentTypeEnabled("text/html", []string{"application/json", "text/event-stream"}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertToUInt(t *testing.T) {
|
||||||
|
val, ok := convertToUInt(int32(10))
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, uint64(10), val)
|
||||||
|
|
||||||
|
val, ok = convertToUInt(int64(10))
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, uint64(10), val)
|
||||||
|
|
||||||
|
val, ok = convertToUInt(uint32(10))
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, uint64(10), val)
|
||||||
|
|
||||||
|
val, ok = convertToUInt(uint64(10))
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, uint64(10), val)
|
||||||
|
|
||||||
|
val, ok = convertToUInt(float32(10.0))
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, uint64(10), val)
|
||||||
|
|
||||||
|
val, ok = convertToUInt(float64(10.0))
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, uint64(10), val)
|
||||||
|
|
||||||
|
_, ok = convertToUInt("10")
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractClaudeStreamingToolCalls(t *testing.T) {
|
||||||
|
t.Run("claude tool use assembly", func(t *testing.T) {
|
||||||
|
chunks := [][]byte{
|
||||||
|
[]byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"tool_123","name":"get_weather"}}`),
|
||||||
|
[]byte(`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\"loc"}}}`),
|
||||||
|
[]byte(`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"ation\":\"Bei"}}}`),
|
||||||
|
[]byte(`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"jing\"}"}}}`),
|
||||||
|
[]byte(`data: {"type":"content_block_stop","index":0}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
var buffer *StreamingToolCallsBuffer
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
buffer = extractClaudeStreamingToolCalls(chunk, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls := getToolCallsFromBuffer(buffer)
|
||||||
|
require.Len(t, toolCalls, 1)
|
||||||
|
require.Equal(t, "tool_123", toolCalls[0].ID)
|
||||||
|
require.Equal(t, "tool_use", toolCalls[0].Type)
|
||||||
|
require.Equal(t, "get_weather", toolCalls[0].Function.Name)
|
||||||
|
require.Equal(t, `{"location":"Beijing"}`, toolCalls[0].Function.Arguments)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("claude empty chunks", func(t *testing.T) {
|
||||||
|
chunks := [][]byte{
|
||||||
|
[]byte(`data: {"type":"ping"}`),
|
||||||
|
[]byte(`data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`),
|
||||||
|
}
|
||||||
|
var buffer *StreamingToolCallsBuffer
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
buffer = extractClaudeStreamingToolCalls(chunk, buffer)
|
||||||
|
}
|
||||||
|
toolCalls := getToolCallsFromBuffer(buffer)
|
||||||
|
require.Len(t, toolCalls, 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("claude tool use with initial input", func(t *testing.T) {
|
||||||
|
chunks := [][]byte{
|
||||||
|
[]byte(`data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"tool_456","name":"get_time","input":{"timezone":"UTC+8"}}}`),
|
||||||
|
[]byte(`data: {"type":"content_block_stop","index":1}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
var buffer *StreamingToolCallsBuffer
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
buffer = extractClaudeStreamingToolCalls(chunk, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls := getToolCallsFromBuffer(buffer)
|
||||||
|
require.Len(t, toolCalls, 1)
|
||||||
|
require.Equal(t, "tool_456", toolCalls[0].ID)
|
||||||
|
require.Equal(t, "tool_use", toolCalls[0].Type)
|
||||||
|
require.Equal(t, "get_time", toolCalls[0].Function.Name)
|
||||||
|
require.Equal(t, `{"timezone":"UTC+8"}`, toolCalls[0].Function.Arguments)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigWithDefaultAttributes(t *testing.T) {
|
||||||
|
test.RunTest(t, func(t *testing.T) {
|
||||||
|
t.Run("use default attributes config", func(t *testing.T) {
|
||||||
|
defaultConfig := []byte(`{
|
||||||
|
"use_default_attributes": true
|
||||||
|
}`)
|
||||||
|
host, status := test.NewTestHost(defaultConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("use default response attributes config", func(t *testing.T) {
|
||||||
|
defaultRespConfig := []byte(`{
|
||||||
|
"use_default_response_attributes": true
|
||||||
|
}`)
|
||||||
|
host, status := test.NewTestHost(defaultRespConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user