From 5e2892f18cb67c3894a016d5ad5ca3d9d6e4113a Mon Sep 17 00:00:00 2001 From: woody Date: Wed, 11 Feb 2026 12:33:12 +0800 Subject: [PATCH] =?UTF-8?q?fix(provider/bedrock.go):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=E8=B0=83=E7=94=A8=E6=B6=88=E6=81=AF=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=80=BB=E8=BE=91=20||=20fix(provider/bedrock.go):=20?= =?UTF-8?q?Optimization=20tool=20calls=20message=20processing=20logic=20(#?= =?UTF-8?q?3470)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../wasm-go/extensions/ai-proxy/main_test.go | 1 + .../extensions/ai-proxy/provider/bedrock.go | 42 ++-- .../extensions/ai-proxy/test/bedrock.go | 180 ++++++++++++++++++ 3 files changed, 206 insertions(+), 17 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main_test.go b/plugins/wasm-go/extensions/ai-proxy/main_test.go index 3155ce219..ffae63f0b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/main_test.go @@ -149,6 +149,7 @@ func TestBedrock(t *testing.T) { test.RunBedrockOnHttpRequestBodyTests(t) test.RunBedrockOnHttpResponseHeadersTests(t) test.RunBedrockOnHttpResponseBodyTests(t) + test.RunBedrockToolCallTests(t) } func TestClaude(t *testing.T) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go index 85b341a2a..84a03dee0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go @@ -769,7 +769,15 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom case roleSystem: systemMessages = append(systemMessages, systemContentBlock{Text: msg.StringContent()}) case roleTool: - messages = append(messages, chatToolMessage2BedrockMessage(msg)) + toolResultContent := chatToolMessage2BedrockToolResultContent(msg) + if len(messages) > 0 && messages[len(messages)-1].Role == roleUser && messages[len(messages)-1].Content[0].ToolResult != nil { + messages[len(messages)-1].Content = append(messages[len(messages)-1].Content, toolResultContent) + } else { + messages = append(messages, bedrockMessage{ + Role: roleUser, + Content: []bedrockMessageContent{toolResultContent}, + }) + } default: messages = append(messages, chatMessage2BedrockMessage(msg)) } @@ -1060,7 +1068,7 @@ type tokenUsage struct { TotalTokens int `json:"totalTokens"` } -func chatToolMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage { +func chatToolMessage2BedrockToolResultContent(chatMessage chatMessage) bedrockMessageContent { toolResultContent := &toolResultBlock{} toolResultContent.ToolUseId = chatMessage.ToolCallId if text, ok := chatMessage.Content.(string); ok { @@ -1083,29 +1091,29 @@ func chatToolMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage { } else { log.Warnf("the content type is not supported, current content is %v", chatMessage.Content) } - return bedrockMessage{ - Role: roleUser, - Content: []bedrockMessageContent{ - { - ToolResult: toolResultContent, - }, - }, + return bedrockMessageContent{ + ToolResult: toolResultContent, } } func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage { var result bedrockMessage if len(chatMessage.ToolCalls) > 0 { + contents := make([]bedrockMessageContent, 0, len(chatMessage.ToolCalls)) + for _, toolCall := range chatMessage.ToolCalls { + params := map[string]interface{}{} + json.Unmarshal([]byte(toolCall.Function.Arguments), ¶ms) + contents = append(contents, bedrockMessageContent{ + ToolUse: &toolUseBlock{ + Input: params, + Name: toolCall.Function.Name, + ToolUseId: toolCall.Id, + }, + }) + } result = bedrockMessage{ Role: chatMessage.Role, - Content: []bedrockMessageContent{{}}, - } - params := map[string]interface{}{} - json.Unmarshal([]byte(chatMessage.ToolCalls[0].Function.Arguments), ¶ms) - result.Content[0].ToolUse = &toolUseBlock{ - Input: params, - Name: chatMessage.ToolCalls[0].Function.Name, - ToolUseId: chatMessage.ToolCalls[0].Id, + Content: contents, } } else if chatMessage.IsStringContent() { result = bedrockMessage{ diff --git a/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go b/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go index 766e11e06..c4d9d4c23 100644 --- a/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go +++ b/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go @@ -442,6 +442,186 @@ func RunBedrockOnHttpResponseHeadersTests(t *testing.T) { }) } +func RunBedrockToolCallTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // Test single tool call conversion (regression test) + t.Run("bedrock single tool call conversion", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "messages": [ + {"role": "user", "content": "What is the weather in Beijing?"}, + {"role": "assistant", "content": "Let me check the weather for you.", "tool_calls": [{"id": "call_001", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Beijing\"}"}}]}, + {"role": "tool", "content": "Sunny, 25°C", "tool_call_id": "call_001"} + ], + "tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather info", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}] + }` + + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + var bodyMap map[string]interface{} + err := json.Unmarshal(processedBody, &bodyMap) + require.NoError(t, err) + + messages := bodyMap["messages"].([]interface{}) + // messages[0] = user, messages[1] = assistant with toolUse, messages[2] = user with toolResult + require.Len(t, messages, 3, "Should have 3 messages: user, assistant, user(toolResult)") + + // Verify assistant message has exactly 1 toolUse + assistantMsg := messages[1].(map[string]interface{}) + require.Equal(t, "assistant", assistantMsg["role"]) + assistantContent := assistantMsg["content"].([]interface{}) + require.Len(t, assistantContent, 1, "Assistant should have 1 content block") + toolUseBlock := assistantContent[0].(map[string]interface{}) + require.Contains(t, toolUseBlock, "toolUse", "Content block should contain toolUse") + + // Verify tool result message + toolResultMsg := messages[2].(map[string]interface{}) + require.Equal(t, "user", toolResultMsg["role"]) + toolResultContent := toolResultMsg["content"].([]interface{}) + require.Len(t, toolResultContent, 1, "Tool result message should have 1 content block") + require.Contains(t, toolResultContent[0].(map[string]interface{}), "toolResult", "Content block should contain toolResult") + }) + + // Test multiple parallel tool calls conversion + t.Run("bedrock multiple parallel tool calls conversion", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "messages": [ + {"role": "user", "content": "What is the weather in Beijing and Shanghai?"}, + {"role": "assistant", "content": "Let me check both cities.", "tool_calls": [{"id": "call_001", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Beijing\"}"}}, {"id": "call_002", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Shanghai\"}"}}]}, + {"role": "tool", "content": "Sunny, 25°C", "tool_call_id": "call_001"}, + {"role": "tool", "content": "Cloudy, 22°C", "tool_call_id": "call_002"} + ], + "tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather info", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}] + }` + + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + var bodyMap map[string]interface{} + err := json.Unmarshal(processedBody, &bodyMap) + require.NoError(t, err) + + messages := bodyMap["messages"].([]interface{}) + // messages[0] = user, messages[1] = assistant with 2 toolUse, messages[2] = user with 2 toolResult + require.Len(t, messages, 3, "Should have 3 messages: user, assistant, user(toolResults merged)") + + // Verify assistant message has 2 toolUse blocks + assistantMsg := messages[1].(map[string]interface{}) + require.Equal(t, "assistant", assistantMsg["role"]) + assistantContent := assistantMsg["content"].([]interface{}) + require.Len(t, assistantContent, 2, "Assistant should have 2 content blocks for parallel tool calls") + + firstToolUse := assistantContent[0].(map[string]interface{})["toolUse"].(map[string]interface{}) + require.Equal(t, "get_weather", firstToolUse["name"]) + require.Equal(t, "call_001", firstToolUse["toolUseId"]) + + secondToolUse := assistantContent[1].(map[string]interface{})["toolUse"].(map[string]interface{}) + require.Equal(t, "get_weather", secondToolUse["name"]) + require.Equal(t, "call_002", secondToolUse["toolUseId"]) + + // Verify tool results are merged into a single user message + toolResultMsg := messages[2].(map[string]interface{}) + require.Equal(t, "user", toolResultMsg["role"]) + toolResultContent := toolResultMsg["content"].([]interface{}) + require.Len(t, toolResultContent, 2, "Tool results should be merged into 2 content blocks in one user message") + + firstResult := toolResultContent[0].(map[string]interface{})["toolResult"].(map[string]interface{}) + require.Equal(t, "call_001", firstResult["toolUseId"]) + + secondResult := toolResultContent[1].(map[string]interface{})["toolResult"].(map[string]interface{}) + require.Equal(t, "call_002", secondResult["toolUseId"]) + }) + + // Test tool call with text content mixed + t.Run("bedrock tool call with text content mixed", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "messages": [ + {"role": "user", "content": "What is the weather in Beijing?"}, + {"role": "assistant", "content": "Let me check.", "tool_calls": [{"id": "call_001", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Beijing\"}"}}]}, + {"role": "tool", "content": "Sunny, 25°C", "tool_call_id": "call_001"}, + {"role": "assistant", "content": "The weather in Beijing is sunny with 25°C."}, + {"role": "user", "content": "Thanks!"} + ], + "tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather info", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}] + }` + + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + var bodyMap map[string]interface{} + err := json.Unmarshal(processedBody, &bodyMap) + require.NoError(t, err) + + messages := bodyMap["messages"].([]interface{}) + // messages[0] = user, messages[1] = assistant(toolUse), messages[2] = user(toolResult), + // messages[3] = assistant(text), messages[4] = user(text) + require.Len(t, messages, 5, "Should have 5 messages in mixed tool call and text scenario") + + // Verify message roles alternate correctly + require.Equal(t, "user", messages[0].(map[string]interface{})["role"]) + require.Equal(t, "assistant", messages[1].(map[string]interface{})["role"]) + require.Equal(t, "user", messages[2].(map[string]interface{})["role"]) + require.Equal(t, "assistant", messages[3].(map[string]interface{})["role"]) + require.Equal(t, "user", messages[4].(map[string]interface{})["role"]) + + // Verify assistant text message (messages[3]) has text content + assistantTextMsg := messages[3].(map[string]interface{}) + assistantTextContent := assistantTextMsg["content"].([]interface{}) + require.Len(t, assistantTextContent, 1) + require.Contains(t, assistantTextContent[0].(map[string]interface{}), "text", "Text assistant message should have text content") + require.Contains(t, assistantTextContent[0].(map[string]interface{})["text"], "sunny", "Text content should contain weather info") + }) + }) +} + func RunBedrockOnHttpResponseBodyTests(t *testing.T) { test.RunTest(t, func(t *testing.T) { // Test Bedrock response body processing