From 6199fe414d2ecdc1be1ec778ef89ae3e725e5695 Mon Sep 17 00:00:00 2001 From: Betula-L Date: Wed, 6 May 2026 04:48:42 -0700 Subject: [PATCH] bugfix: map bedrock tool-call indexes and tool_choice (#3786) Signed-off-by: Betula-L <6059935+Betula-L@users.noreply.github.com> Co-authored-by: Betula-L <6059935+Betula-L@users.noreply.github.com> --- .../extensions/ai-proxy/provider/bedrock.go | 69 +++- .../extensions/ai-proxy/provider/claude.go | 29 +- .../ai-proxy/provider/claude_test.go | 81 +++++ .../ai-proxy/provider/claude_to_openai.go | 13 +- .../provider/claude_to_openai_test.go | 60 ++++ .../extensions/ai-proxy/provider/model.go | 36 ++- .../extensions/ai-proxy/test/bedrock.go | 303 ++++++++++++++++++ 7 files changed, 557 insertions(+), 34 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go index 80c6c346d..4fd4cda0b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go @@ -47,6 +47,8 @@ const ( bedrockCachePointPositionSystemPrompt = "systemPrompt" bedrockCachePointPositionLastUserMessage = "lastUserMessage" bedrockCachePointPositionLastMessage = "lastMessage" + + ctxKeyBedrockToolCallState = "bedrock_tool_call_state" ) var ( @@ -121,11 +123,13 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex chatChoice.Delta.Role = *bedrockEvent.Role } if bedrockEvent.Start != nil { + toolCallIndex := getBedrockOpenAIToolCallIndex(ctx, bedrockEvent.ContentBlockIndex) chatChoice.Delta.Content = nil chatChoice.Delta.ToolCalls = []toolCall{ { - Id: bedrockEvent.Start.ToolUse.ToolUseID, - Type: "function", + Index: toolCallIndex, + Id: bedrockEvent.Start.ToolUse.ToolUseID, + Type: "function", Function: functionCall{ Name: bedrockEvent.Start.ToolUse.Name, Arguments: "", @@ -152,9 +156,11 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex chatChoice.Delta = &chatMessage{Content: &content} } if bedrockEvent.Delta.ToolUse != nil { + toolCallIndex := getBedrockOpenAIToolCallIndex(ctx, bedrockEvent.ContentBlockIndex) chatChoice.Delta.ToolCalls = []toolCall{ { - Type: "function", + Index: toolCallIndex, + Type: "function", Function: functionCall{ Arguments: bedrockEvent.Delta.ToolUse.Input, }, @@ -192,6 +198,28 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex return []byte(openAIChunk.String()), nil } +type bedrockToolCallState struct { + indexes map[int]int + next int +} + +func getBedrockOpenAIToolCallIndex(ctx wrapper.HttpContext, contentBlockIndex int) int { + state, _ := ctx.GetContext(ctxKeyBedrockToolCallState).(*bedrockToolCallState) + if state == nil { + state = &bedrockToolCallState{indexes: make(map[int]int)} + ctx.SetContext(ctxKeyBedrockToolCallState, state) + } + + if toolCallIndex, ok := state.indexes[contentBlockIndex]; ok { + return toolCallIndex + } + + toolCallIndex := state.next + state.indexes[contentBlockIndex] = toolCallIndex + state.next++ + return toolCallIndex +} + type ConverseStreamEvent struct { ContentBlockIndex int `json:"contentBlockIndex,omitempty"` Delta *converseStreamEventContentBlockDelta `json:"delta,omitempty"` @@ -870,22 +898,25 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom } } - if origRequest.Tools != nil { + if origRequest.Tools != nil && origRequest.getToolChoiceType() != "none" { request.ToolConfig = &bedrockToolConfig{} - if origRequest.ToolChoice == nil { - request.ToolConfig.ToolChoice.Auto = &struct{}{} - } else if choice_type, ok := origRequest.ToolChoice.(string); ok { + request.ToolConfig.ToolChoice.Auto = &struct{}{} + if choice_type := origRequest.getToolChoiceType(); choice_type != "" { switch choice_type { - case "required": + // "any" is accepted for direct Anthropic-compatible callers; OpenAI + // uses "required" for the same "must call at least one tool" behavior. + case "required", "any": + request.ToolConfig.ToolChoice.Auto = nil request.ToolConfig.ToolChoice.Any = &struct{}{} case "auto": request.ToolConfig.ToolChoice.Auto = &struct{}{} - case "none": - request.ToolConfig.ToolChoice.Auto = &struct{}{} - } - } else if choice, ok := origRequest.ToolChoice.(toolChoice); ok { - request.ToolConfig.ToolChoice.Tool = &bedrockToolSpecification{ - Name: choice.Function.Name, + case "function": + if choice := origRequest.getToolChoiceObject(); choice != nil && choice.Function.Name != "" { + request.ToolConfig.ToolChoice.Auto = nil + request.ToolConfig.ToolChoice.Tool = &bedrockSpecificToolChoice{ + Name: choice.Function.Name, + } + } } } request.ToolConfig.Tools = []bedrockTool{} @@ -1151,9 +1182,13 @@ type bedrockTool struct { } type bedrockToolChoice struct { - Any *struct{} `json:"any,omitempty"` - Auto *struct{} `json:"auto,omitempty"` - Tool *bedrockToolSpecification `json:"tool,omitempty"` + Any *struct{} `json:"any,omitempty"` + Auto *struct{} `json:"auto,omitempty"` + Tool *bedrockSpecificToolChoice `json:"tool,omitempty"` +} + +type bedrockSpecificToolChoice struct { + Name string `json:"name"` } type bedrockToolSpecification struct { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 4b763ce75..05fd05349 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -672,11 +672,30 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe claudeRequest.Tools = append(claudeRequest.Tools, claudeTool) } - if tc := origRequest.getToolChoiceObject(); tc != nil { - claudeRequest.ToolChoice = &claudeToolChoice{ - Name: tc.Function.Name, - Type: tc.Type, - DisableParallelToolUse: !origRequest.ParallelToolCalls, + if origRequest.ToolChoice != nil { + parallelToolCalls := true + if origRequest.ParallelToolCalls != nil { + parallelToolCalls = *origRequest.ParallelToolCalls + } + + choiceType := origRequest.getToolChoiceType() + if tc := origRequest.getToolChoiceObject(); tc != nil && tc.Type == "function" && tc.Function.Name != "" { + claudeRequest.ToolChoice = &claudeToolChoice{ + Name: tc.Function.Name, + Type: "tool", + DisableParallelToolUse: !parallelToolCalls, + } + } else if choiceType != "" { + switch choiceType { + case "required": + choiceType = "any" + } + claudeRequest.ToolChoice = &claudeToolChoice{ + Type: choiceType, + } + if choiceType != "none" { + claudeRequest.ToolChoice.DisableParallelToolUse = !parallelToolCalls + } } } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude_test.go index 71494c8e5..4ad27162b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude_test.go @@ -183,6 +183,87 @@ func TestClaudeProvider_BuildClaudeTextGenRequest_StandardMode(t *testing.T) { assert.False(t, claudeReq.System.IsArray) assert.Equal(t, "You are a helpful assistant.", claudeReq.System.StringValue) }) + + t.Run("maps_openai_function_tool_choice_to_claude_tool_choice", func(t *testing.T) { + request := &chatCompletionRequest{ + Model: "claude-sonnet-4-5-20250929", + MaxTokens: 8192, + Messages: []chatMessage{ + {Role: roleUser, Content: "Search."}, + }, + Tools: []tool{{ + Type: "function", + Function: function{ + Name: "web_search", + Description: "Search the web.", + Parameters: map[string]interface{}{"type": "object"}, + }, + }}, + ToolChoice: map[string]interface{}{ + "type": "function", + "function": map[string]interface{}{ + "name": "web_search", + }, + }, + } + + claudeReq := provider.buildClaudeTextGenRequest(request) + + require.NotNil(t, claudeReq.ToolChoice) + assert.Equal(t, "tool", claudeReq.ToolChoice.Type) + assert.Equal(t, "web_search", claudeReq.ToolChoice.Name) + }) + + t.Run("maps_openai_string_required_tool_choice_to_claude_any", func(t *testing.T) { + parallelToolCalls := false + request := &chatCompletionRequest{ + Model: "claude-sonnet-4-5-20250929", + MaxTokens: 8192, + Messages: []chatMessage{ + {Role: roleUser, Content: "Search."}, + }, + Tools: []tool{{ + Type: "function", + Function: function{ + Name: "web_search", + Parameters: map[string]interface{}{"type": "object"}, + }, + }}, + ToolChoice: "required", + ParallelToolCalls: ¶llelToolCalls, + } + + claudeReq := provider.buildClaudeTextGenRequest(request) + + require.NotNil(t, claudeReq.ToolChoice) + assert.Equal(t, "any", claudeReq.ToolChoice.Type) + assert.Empty(t, claudeReq.ToolChoice.Name) + assert.True(t, claudeReq.ToolChoice.DisableParallelToolUse) + }) + + t.Run("maps_openai_string_none_tool_choice_to_claude_none", func(t *testing.T) { + request := &chatCompletionRequest{ + Model: "claude-sonnet-4-5-20250929", + MaxTokens: 8192, + Messages: []chatMessage{ + {Role: roleUser, Content: "Answer without tools."}, + }, + Tools: []tool{{ + Type: "function", + Function: function{ + Name: "web_search", + Parameters: map[string]interface{}{"type": "object"}, + }, + }}, + ToolChoice: "none", + } + + claudeReq := provider.buildClaudeTextGenRequest(request) + + require.NotNil(t, claudeReq.ToolChoice) + assert.Equal(t, "none", claudeReq.ToolChoice.Type) + assert.Empty(t, claudeReq.ToolChoice.Name) + }) } func TestClaudeProvider_BuildClaudeTextGenRequest_ClaudeCodeMode(t *testing.T) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude_to_openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude_to_openai.go index d9416270e..3fc1e2e66 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude_to_openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude_to_openai.go @@ -178,12 +178,19 @@ func (c *ClaudeToOpenAIConverter) ConvertClaudeRequestToOpenAI(body []byte) ([]b }, } } else { - // For other types like "auto", "none", etc. - openaiRequest.ToolChoice = claudeRequest.ToolChoice.Type + // Anthropic's "any" means the model must call at least one tool. + // OpenAI-compatible requests express the same behavior as "required". + if claudeRequest.ToolChoice.Type == "any" { + openaiRequest.ToolChoice = "required" + } else { + // For other types like "auto", "none", etc. + openaiRequest.ToolChoice = claudeRequest.ToolChoice.Type + } } // Handle parallel tool calls - openaiRequest.ParallelToolCalls = !claudeRequest.ToolChoice.DisableParallelToolUse + parallelToolCalls := !claudeRequest.ToolChoice.DisableParallelToolUse + openaiRequest.ParallelToolCalls = ¶llelToolCalls } // Convert thinking configuration if present diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude_to_openai_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude_to_openai_test.go index b490a3591..0e0a1b782 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude_to_openai_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude_to_openai_test.go @@ -35,6 +35,66 @@ func init() { func TestClaudeToOpenAIConverter_ConvertClaudeRequestToOpenAI(t *testing.T) { converter := &ClaudeToOpenAIConverter{} + t.Run("convert_tool_choice_any_to_required", func(t *testing.T) { + claudeRequest := `{ + "model": "claude-sonnet-4", + "max_tokens": 1000, + "messages": [{"role": "user", "content": "Run a search."}], + "tools": [{ + "name": "web_search", + "description": "Search the web.", + "input_schema": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"] + } + }], + "tool_choice": {"type": "any"} + }` + + result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest)) + require.NoError(t, err) + + var openaiRequest chatCompletionRequest + err = json.Unmarshal(result, &openaiRequest) + require.NoError(t, err) + + require.Equal(t, "required", openaiRequest.ToolChoice) + require.NotNil(t, openaiRequest.ParallelToolCalls) + require.True(t, *openaiRequest.ParallelToolCalls) + require.Contains(t, string(result), `"parallel_tool_calls":true`) + }) + + t.Run("convert_tool_choice_any_preserves_disable_parallel_tool_use", func(t *testing.T) { + claudeRequest := `{ + "model": "claude-sonnet-4", + "max_tokens": 1000, + "messages": [{"role": "user", "content": "Run a search."}], + "tools": [{ + "name": "web_search", + "description": "Search the web.", + "input_schema": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"] + } + }], + "tool_choice": {"type": "any", "disable_parallel_tool_use": true} + }` + + result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest)) + require.NoError(t, err) + + var openaiRequest chatCompletionRequest + err = json.Unmarshal(result, &openaiRequest) + require.NoError(t, err) + + require.Equal(t, "required", openaiRequest.ToolChoice) + require.NotNil(t, openaiRequest.ParallelToolCalls) + require.False(t, *openaiRequest.ParallelToolCalls) + require.Contains(t, string(result), `"parallel_tool_calls":false`) + }) + t.Run("convert_multiple_text_content_blocks", func(t *testing.T) { // Test case: multiple text content blocks should remain as separate array elements with cache control support // Both system and user messages should handle array content format diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index 8f951f543..c83314c74 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go @@ -70,7 +70,7 @@ type chatCompletionRequest struct { TopP float64 `json:"top_p,omitempty"` Tools []tool `json:"tools,omitempty"` ToolChoice interface{} `json:"tool_choice,omitempty"` - ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` User string `json:"user,omitempty"` } @@ -92,15 +92,33 @@ func (c *chatCompletionRequest) getToolChoiceString() string { return "" } -func (c *chatCompletionRequest) getToolChoiceObject() *toolChoice { - if c.ToolChoice == nil { - return nil - } - - if tc, ok := c.ToolChoice.(*toolChoice); ok { +func (c *chatCompletionRequest) getToolChoiceType() string { + if tc := c.getToolChoiceString(); tc != "" { return tc } - return nil + if tc := c.getToolChoiceObject(); tc != nil { + return tc.Type + } + return "" +} + +func (c *chatCompletionRequest) getToolChoiceObject() *toolChoice { + switch tc := c.ToolChoice.(type) { + case nil, string: + return nil + case *toolChoice: + return tc + } + + body, err := json.Marshal(c.ToolChoice) + if err != nil { + return nil + } + var parsed toolChoice + if err := json.Unmarshal(body, &parsed); err != nil { + return nil + } + return &parsed } type CompletionRequest struct { @@ -474,7 +492,7 @@ type toolCall struct { type functionCall struct { Id string `json:"id,omitempty"` - Name string `json:"name"` + Name string `json:"name,omitempty"` Arguments string `json:"arguments"` } diff --git a/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go b/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go index a690f39ec..4e2c214a3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go +++ b/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go @@ -1365,6 +1365,177 @@ func RunBedrockToolCallTests(t *testing.T) { require.Equal(t, "call_002", secondResult["toolUseId"]) }) + t.Run("bedrock maps any tool choice to converse any tool choice", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "messages": [{"role": "user", "content": "Run a search."}], + "tool_choice": "any", + "tools": [{ + "type": "function", + "function": { + "name": "web_search", + "description": "Search the web.", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"] + } + } + }] + }` + + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + var bodyMap map[string]interface{} + err := json.Unmarshal(processedBody, &bodyMap) + require.NoError(t, err) + + toolConfig := bodyMap["toolConfig"].(map[string]interface{}) + toolChoice := toolConfig["toolChoice"].(map[string]interface{}) + require.Contains(t, toolChoice, "any") + require.Equal(t, map[string]interface{}{}, toolChoice["any"]) + }) + + t.Run("bedrock maps object tool choice to converse tool choice", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "messages": [{"role": "user", "content": "Run a search."}], + "tool_choice": {"type":"function","function":{"name":"web_search"}}, + "tools": [{ + "type": "function", + "function": { + "name": "web_search", + "description": "Search the web.", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"] + } + } + }] + }` + + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + var bodyMap map[string]interface{} + err := json.Unmarshal(processedBody, &bodyMap) + require.NoError(t, err) + + toolConfig := bodyMap["toolConfig"].(map[string]interface{}) + toolChoice := toolConfig["toolChoice"].(map[string]interface{}) + tool := toolChoice["tool"].(map[string]interface{}) + require.Equal(t, "web_search", tool["name"]) + require.Len(t, tool, 1, "Bedrock specific tool choice should only include name") + }) + + t.Run("bedrock maps object any tool choice to converse any tool choice", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "messages": [{"role": "user", "content": "Run a search."}], + "tool_choice": {"type":"any"}, + "tools": [{ + "type": "function", + "function": { + "name": "web_search", + "description": "Search the web.", + "parameters": {"type": "object"} + } + }] + }` + + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + var bodyMap map[string]interface{} + err := json.Unmarshal(host.GetRequestBody(), &bodyMap) + require.NoError(t, err) + + toolConfig := bodyMap["toolConfig"].(map[string]interface{}) + toolChoice := toolConfig["toolChoice"].(map[string]interface{}) + require.Contains(t, toolChoice, "any") + require.NotContains(t, toolChoice, "tool") + }) + + t.Run("bedrock maps none tool choice by omitting tools", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "messages": [{"role": "user", "content": "Answer without tools."}], + "tool_choice": {"type":"none"}, + "tools": [{ + "type": "function", + "function": { + "name": "web_search", + "description": "Search the web.", + "parameters": {"type": "object"} + } + }] + }` + + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + var bodyMap map[string]interface{} + err := json.Unmarshal(host.GetRequestBody(), &bodyMap) + require.NoError(t, err) + require.NotContains(t, bodyMap, "toolConfig") + }) + // Test tool call with text content mixed t.Run("bedrock tool call with text content mixed", func(t *testing.T) { host, status := test.NewTestHost(bedrockApiTokenConfig) @@ -1669,6 +1840,138 @@ func RunBedrockOnStreamingResponseBodyTests(t *testing.T) { require.Equal(t, "", payload) }) + t.Run("bedrock streaming parallel tool calls should use dense OpenAI indexes", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "messages": [{"role": "user", "content": "Run three independent searches."}], + "stream": true, + "tools": [{ + "type": "function", + "function": { + "name": "web_search", + "description": "Search the web.", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"] + } + } + }] + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream")) + action = host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"Content-Type", "application/vnd.amazon.eventstream"}, + }) + require.Equal(t, types.ActionContinue, action) + + extractFirstToolCallIndex := func(body []byte) int { + dataPayload := extractFirstDataPayload(body) + require.NotEmpty(t, dataPayload, "streaming chunk should contain one SSE data payload") + + var responseMap map[string]interface{} + err := json.Unmarshal([]byte(dataPayload), &responseMap) + require.NoError(t, err) + + choices := responseMap["choices"].([]interface{}) + require.Len(t, choices, 1, "streaming chunk should contain one choice") + delta := choices[0].(map[string]interface{})["delta"].(map[string]interface{}) + toolCalls := delta["tool_calls"].([]interface{}) + require.Len(t, toolCalls, 1, "streaming chunk should contain one tool call delta") + + index, ok := toolCalls[0].(map[string]interface{})["index"].(float64) + require.True(t, ok, "tool call delta should include an index") + return int(index) + } + extractFirstToolCall := func(body []byte) map[string]interface{} { + dataPayload := extractFirstDataPayload(body) + require.NotEmpty(t, dataPayload, "streaming chunk should contain one SSE data payload") + + var responseMap map[string]interface{} + err := json.Unmarshal([]byte(dataPayload), &responseMap) + require.NoError(t, err) + + choices := responseMap["choices"].([]interface{}) + require.Len(t, choices, 1, "streaming chunk should contain one choice") + delta := choices[0].(map[string]interface{})["delta"].(map[string]interface{}) + toolCalls := delta["tool_calls"].([]interface{}) + require.Len(t, toolCalls, 1, "streaming chunk should contain one tool call delta") + + toolCall, ok := toolCalls[0].(map[string]interface{}) + require.True(t, ok, "tool call delta should be an object") + return toolCall + } + + for expectedIndex, item := range []struct { + contentBlockIndex int + toolUseId string + }{ + {contentBlockIndex: 1, toolUseId: "tooluse_first"}, + {contentBlockIndex: 3, toolUseId: "tooluse_second"}, + {contentBlockIndex: 4, toolUseId: "tooluse_third"}, + } { + toolCallStart := buildBedrockEventStreamMessage(t, map[string]interface{}{ + "contentBlockIndex": item.contentBlockIndex, + "start": map[string]interface{}{ + "toolUse": map[string]interface{}{ + "toolUseId": item.toolUseId, + "name": "web_search", + }, + }, + }) + action = host.CallOnHttpStreamingResponseBody(toolCallStart, false) + require.Equal(t, types.ActionContinue, action) + toolCall := extractFirstToolCall(host.GetResponseBody()) + require.Equal(t, expectedIndex, extractFirstToolCallIndex(host.GetResponseBody())) + require.Equal(t, item.toolUseId, toolCall["id"]) + require.Equal(t, "function", toolCall["type"]) + function := toolCall["function"].(map[string]interface{}) + require.Equal(t, "web_search", function["name"]) + require.Equal(t, "", function["arguments"]) + } + + for expectedIndex, item := range []struct { + contentBlockIndex int + query string + }{ + {contentBlockIndex: 1, query: "first synthetic query"}, + {contentBlockIndex: 3, query: "second synthetic query"}, + {contentBlockIndex: 4, query: "third synthetic query"}, + } { + toolCallDelta := buildBedrockEventStreamMessage(t, map[string]interface{}{ + "contentBlockIndex": item.contentBlockIndex, + "delta": map[string]interface{}{ + "toolUse": map[string]interface{}{ + "input": "{\"query\":\"" + item.query + "\"}", + }, + }, + }) + action = host.CallOnHttpStreamingResponseBody(toolCallDelta, false) + require.Equal(t, types.ActionContinue, action) + toolCall := extractFirstToolCall(host.GetResponseBody()) + require.Equal(t, expectedIndex, extractFirstToolCallIndex(host.GetResponseBody())) + require.NotContains(t, toolCall, "id") + function := toolCall["function"].(map[string]interface{}) + require.NotContains(t, function, "name") + require.Equal(t, "{\"query\":\""+item.query+"\"}", function["arguments"]) + } + }) + t.Run("bedrock streaming usage should map cached_tokens", func(t *testing.T) { host, status := test.NewTestHost(bedrockApiTokenConfig) defer host.Reset()