fix(provider/bedrock.go): 优化工具调用消息处理逻辑 || fix(provider/bedrock.go): Optimization tool calls message processing logic (#3470)

This commit is contained in:
woody
2026-02-11 12:33:12 +08:00
committed by GitHub
parent 0cc92aa6b8
commit 5e2892f18c
3 changed files with 206 additions and 17 deletions

View File

@@ -149,6 +149,7 @@ func TestBedrock(t *testing.T) {
test.RunBedrockOnHttpRequestBodyTests(t)
test.RunBedrockOnHttpResponseHeadersTests(t)
test.RunBedrockOnHttpResponseBodyTests(t)
test.RunBedrockToolCallTests(t)
}
func TestClaude(t *testing.T) {

View File

@@ -769,7 +769,15 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
case roleSystem:
systemMessages = append(systemMessages, systemContentBlock{Text: msg.StringContent()})
case roleTool:
messages = append(messages, chatToolMessage2BedrockMessage(msg))
toolResultContent := chatToolMessage2BedrockToolResultContent(msg)
if len(messages) > 0 && messages[len(messages)-1].Role == roleUser && messages[len(messages)-1].Content[0].ToolResult != nil {
messages[len(messages)-1].Content = append(messages[len(messages)-1].Content, toolResultContent)
} else {
messages = append(messages, bedrockMessage{
Role: roleUser,
Content: []bedrockMessageContent{toolResultContent},
})
}
default:
messages = append(messages, chatMessage2BedrockMessage(msg))
}
@@ -1060,7 +1068,7 @@ type tokenUsage struct {
TotalTokens int `json:"totalTokens"`
}
func chatToolMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
func chatToolMessage2BedrockToolResultContent(chatMessage chatMessage) bedrockMessageContent {
toolResultContent := &toolResultBlock{}
toolResultContent.ToolUseId = chatMessage.ToolCallId
if text, ok := chatMessage.Content.(string); ok {
@@ -1083,29 +1091,29 @@ func chatToolMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
} else {
log.Warnf("the content type is not supported, current content is %v", chatMessage.Content)
}
return bedrockMessage{
Role: roleUser,
Content: []bedrockMessageContent{
{
ToolResult: toolResultContent,
},
},
return bedrockMessageContent{
ToolResult: toolResultContent,
}
}
func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
var result bedrockMessage
if len(chatMessage.ToolCalls) > 0 {
contents := make([]bedrockMessageContent, 0, len(chatMessage.ToolCalls))
for _, toolCall := range chatMessage.ToolCalls {
params := map[string]interface{}{}
json.Unmarshal([]byte(toolCall.Function.Arguments), &params)
contents = append(contents, bedrockMessageContent{
ToolUse: &toolUseBlock{
Input: params,
Name: toolCall.Function.Name,
ToolUseId: toolCall.Id,
},
})
}
result = bedrockMessage{
Role: chatMessage.Role,
Content: []bedrockMessageContent{{}},
}
params := map[string]interface{}{}
json.Unmarshal([]byte(chatMessage.ToolCalls[0].Function.Arguments), &params)
result.Content[0].ToolUse = &toolUseBlock{
Input: params,
Name: chatMessage.ToolCalls[0].Function.Name,
ToolUseId: chatMessage.ToolCalls[0].Id,
Content: contents,
}
} else if chatMessage.IsStringContent() {
result = bedrockMessage{

View File

@@ -442,6 +442,186 @@ func RunBedrockOnHttpResponseHeadersTests(t *testing.T) {
})
}
func RunBedrockToolCallTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test single tool call conversion (regression test)
t.Run("bedrock single tool call conversion", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "What is the weather in Beijing?"},
{"role": "assistant", "content": "Let me check the weather for you.", "tool_calls": [{"id": "call_001", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Beijing\"}"}}]},
{"role": "tool", "content": "Sunny, 25°C", "tool_call_id": "call_001"}
],
"tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather info", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
var bodyMap map[string]interface{}
err := json.Unmarshal(processedBody, &bodyMap)
require.NoError(t, err)
messages := bodyMap["messages"].([]interface{})
// messages[0] = user, messages[1] = assistant with toolUse, messages[2] = user with toolResult
require.Len(t, messages, 3, "Should have 3 messages: user, assistant, user(toolResult)")
// Verify assistant message has exactly 1 toolUse
assistantMsg := messages[1].(map[string]interface{})
require.Equal(t, "assistant", assistantMsg["role"])
assistantContent := assistantMsg["content"].([]interface{})
require.Len(t, assistantContent, 1, "Assistant should have 1 content block")
toolUseBlock := assistantContent[0].(map[string]interface{})
require.Contains(t, toolUseBlock, "toolUse", "Content block should contain toolUse")
// Verify tool result message
toolResultMsg := messages[2].(map[string]interface{})
require.Equal(t, "user", toolResultMsg["role"])
toolResultContent := toolResultMsg["content"].([]interface{})
require.Len(t, toolResultContent, 1, "Tool result message should have 1 content block")
require.Contains(t, toolResultContent[0].(map[string]interface{}), "toolResult", "Content block should contain toolResult")
})
// Test multiple parallel tool calls conversion
t.Run("bedrock multiple parallel tool calls conversion", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "What is the weather in Beijing and Shanghai?"},
{"role": "assistant", "content": "Let me check both cities.", "tool_calls": [{"id": "call_001", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Beijing\"}"}}, {"id": "call_002", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Shanghai\"}"}}]},
{"role": "tool", "content": "Sunny, 25°C", "tool_call_id": "call_001"},
{"role": "tool", "content": "Cloudy, 22°C", "tool_call_id": "call_002"}
],
"tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather info", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
var bodyMap map[string]interface{}
err := json.Unmarshal(processedBody, &bodyMap)
require.NoError(t, err)
messages := bodyMap["messages"].([]interface{})
// messages[0] = user, messages[1] = assistant with 2 toolUse, messages[2] = user with 2 toolResult
require.Len(t, messages, 3, "Should have 3 messages: user, assistant, user(toolResults merged)")
// Verify assistant message has 2 toolUse blocks
assistantMsg := messages[1].(map[string]interface{})
require.Equal(t, "assistant", assistantMsg["role"])
assistantContent := assistantMsg["content"].([]interface{})
require.Len(t, assistantContent, 2, "Assistant should have 2 content blocks for parallel tool calls")
firstToolUse := assistantContent[0].(map[string]interface{})["toolUse"].(map[string]interface{})
require.Equal(t, "get_weather", firstToolUse["name"])
require.Equal(t, "call_001", firstToolUse["toolUseId"])
secondToolUse := assistantContent[1].(map[string]interface{})["toolUse"].(map[string]interface{})
require.Equal(t, "get_weather", secondToolUse["name"])
require.Equal(t, "call_002", secondToolUse["toolUseId"])
// Verify tool results are merged into a single user message
toolResultMsg := messages[2].(map[string]interface{})
require.Equal(t, "user", toolResultMsg["role"])
toolResultContent := toolResultMsg["content"].([]interface{})
require.Len(t, toolResultContent, 2, "Tool results should be merged into 2 content blocks in one user message")
firstResult := toolResultContent[0].(map[string]interface{})["toolResult"].(map[string]interface{})
require.Equal(t, "call_001", firstResult["toolUseId"])
secondResult := toolResultContent[1].(map[string]interface{})["toolResult"].(map[string]interface{})
require.Equal(t, "call_002", secondResult["toolUseId"])
})
// Test tool call with text content mixed
t.Run("bedrock tool call with text content mixed", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "What is the weather in Beijing?"},
{"role": "assistant", "content": "Let me check.", "tool_calls": [{"id": "call_001", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Beijing\"}"}}]},
{"role": "tool", "content": "Sunny, 25°C", "tool_call_id": "call_001"},
{"role": "assistant", "content": "The weather in Beijing is sunny with 25°C."},
{"role": "user", "content": "Thanks!"}
],
"tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather info", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
var bodyMap map[string]interface{}
err := json.Unmarshal(processedBody, &bodyMap)
require.NoError(t, err)
messages := bodyMap["messages"].([]interface{})
// messages[0] = user, messages[1] = assistant(toolUse), messages[2] = user(toolResult),
// messages[3] = assistant(text), messages[4] = user(text)
require.Len(t, messages, 5, "Should have 5 messages in mixed tool call and text scenario")
// Verify message roles alternate correctly
require.Equal(t, "user", messages[0].(map[string]interface{})["role"])
require.Equal(t, "assistant", messages[1].(map[string]interface{})["role"])
require.Equal(t, "user", messages[2].(map[string]interface{})["role"])
require.Equal(t, "assistant", messages[3].(map[string]interface{})["role"])
require.Equal(t, "user", messages[4].(map[string]interface{})["role"])
// Verify assistant text message (messages[3]) has text content
assistantTextMsg := messages[3].(map[string]interface{})
assistantTextContent := assistantTextMsg["content"].([]interface{})
require.Len(t, assistantTextContent, 1)
require.Contains(t, assistantTextContent[0].(map[string]interface{}), "text", "Text assistant message should have text content")
require.Contains(t, assistantTextContent[0].(map[string]interface{})["text"], "sunny", "Text content should contain weather info")
})
})
}
func RunBedrockOnHttpResponseBodyTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test Bedrock response body processing