mirror of
https://github.com/alibaba/higress.git
synced 2026-03-06 17:40:51 +08:00
fix(provider/bedrock.go): 优化工具调用消息处理逻辑 || fix(provider/bedrock.go): Optimization tool calls message processing logic (#3470)
This commit is contained in:
@@ -149,6 +149,7 @@ func TestBedrock(t *testing.T) {
|
||||
test.RunBedrockOnHttpRequestBodyTests(t)
|
||||
test.RunBedrockOnHttpResponseHeadersTests(t)
|
||||
test.RunBedrockOnHttpResponseBodyTests(t)
|
||||
test.RunBedrockToolCallTests(t)
|
||||
}
|
||||
|
||||
func TestClaude(t *testing.T) {
|
||||
|
||||
@@ -769,7 +769,15 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
|
||||
case roleSystem:
|
||||
systemMessages = append(systemMessages, systemContentBlock{Text: msg.StringContent()})
|
||||
case roleTool:
|
||||
messages = append(messages, chatToolMessage2BedrockMessage(msg))
|
||||
toolResultContent := chatToolMessage2BedrockToolResultContent(msg)
|
||||
if len(messages) > 0 && messages[len(messages)-1].Role == roleUser && messages[len(messages)-1].Content[0].ToolResult != nil {
|
||||
messages[len(messages)-1].Content = append(messages[len(messages)-1].Content, toolResultContent)
|
||||
} else {
|
||||
messages = append(messages, bedrockMessage{
|
||||
Role: roleUser,
|
||||
Content: []bedrockMessageContent{toolResultContent},
|
||||
})
|
||||
}
|
||||
default:
|
||||
messages = append(messages, chatMessage2BedrockMessage(msg))
|
||||
}
|
||||
@@ -1060,7 +1068,7 @@ type tokenUsage struct {
|
||||
TotalTokens int `json:"totalTokens"`
|
||||
}
|
||||
|
||||
func chatToolMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
|
||||
func chatToolMessage2BedrockToolResultContent(chatMessage chatMessage) bedrockMessageContent {
|
||||
toolResultContent := &toolResultBlock{}
|
||||
toolResultContent.ToolUseId = chatMessage.ToolCallId
|
||||
if text, ok := chatMessage.Content.(string); ok {
|
||||
@@ -1083,29 +1091,29 @@ func chatToolMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
|
||||
} else {
|
||||
log.Warnf("the content type is not supported, current content is %v", chatMessage.Content)
|
||||
}
|
||||
return bedrockMessage{
|
||||
Role: roleUser,
|
||||
Content: []bedrockMessageContent{
|
||||
{
|
||||
ToolResult: toolResultContent,
|
||||
},
|
||||
},
|
||||
return bedrockMessageContent{
|
||||
ToolResult: toolResultContent,
|
||||
}
|
||||
}
|
||||
|
||||
func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
|
||||
var result bedrockMessage
|
||||
if len(chatMessage.ToolCalls) > 0 {
|
||||
contents := make([]bedrockMessageContent, 0, len(chatMessage.ToolCalls))
|
||||
for _, toolCall := range chatMessage.ToolCalls {
|
||||
params := map[string]interface{}{}
|
||||
json.Unmarshal([]byte(toolCall.Function.Arguments), ¶ms)
|
||||
contents = append(contents, bedrockMessageContent{
|
||||
ToolUse: &toolUseBlock{
|
||||
Input: params,
|
||||
Name: toolCall.Function.Name,
|
||||
ToolUseId: toolCall.Id,
|
||||
},
|
||||
})
|
||||
}
|
||||
result = bedrockMessage{
|
||||
Role: chatMessage.Role,
|
||||
Content: []bedrockMessageContent{{}},
|
||||
}
|
||||
params := map[string]interface{}{}
|
||||
json.Unmarshal([]byte(chatMessage.ToolCalls[0].Function.Arguments), ¶ms)
|
||||
result.Content[0].ToolUse = &toolUseBlock{
|
||||
Input: params,
|
||||
Name: chatMessage.ToolCalls[0].Function.Name,
|
||||
ToolUseId: chatMessage.ToolCalls[0].Id,
|
||||
Content: contents,
|
||||
}
|
||||
} else if chatMessage.IsStringContent() {
|
||||
result = bedrockMessage{
|
||||
|
||||
@@ -442,6 +442,186 @@ func RunBedrockOnHttpResponseHeadersTests(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func RunBedrockToolCallTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// Test single tool call conversion (regression test)
|
||||
t.Run("bedrock single tool call conversion", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the weather in Beijing?"},
|
||||
{"role": "assistant", "content": "Let me check the weather for you.", "tool_calls": [{"id": "call_001", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Beijing\"}"}}]},
|
||||
{"role": "tool", "content": "Sunny, 25°C", "tool_call_id": "call_001"}
|
||||
],
|
||||
"tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather info", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}]
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(processedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
messages := bodyMap["messages"].([]interface{})
|
||||
// messages[0] = user, messages[1] = assistant with toolUse, messages[2] = user with toolResult
|
||||
require.Len(t, messages, 3, "Should have 3 messages: user, assistant, user(toolResult)")
|
||||
|
||||
// Verify assistant message has exactly 1 toolUse
|
||||
assistantMsg := messages[1].(map[string]interface{})
|
||||
require.Equal(t, "assistant", assistantMsg["role"])
|
||||
assistantContent := assistantMsg["content"].([]interface{})
|
||||
require.Len(t, assistantContent, 1, "Assistant should have 1 content block")
|
||||
toolUseBlock := assistantContent[0].(map[string]interface{})
|
||||
require.Contains(t, toolUseBlock, "toolUse", "Content block should contain toolUse")
|
||||
|
||||
// Verify tool result message
|
||||
toolResultMsg := messages[2].(map[string]interface{})
|
||||
require.Equal(t, "user", toolResultMsg["role"])
|
||||
toolResultContent := toolResultMsg["content"].([]interface{})
|
||||
require.Len(t, toolResultContent, 1, "Tool result message should have 1 content block")
|
||||
require.Contains(t, toolResultContent[0].(map[string]interface{}), "toolResult", "Content block should contain toolResult")
|
||||
})
|
||||
|
||||
// Test multiple parallel tool calls conversion
|
||||
t.Run("bedrock multiple parallel tool calls conversion", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the weather in Beijing and Shanghai?"},
|
||||
{"role": "assistant", "content": "Let me check both cities.", "tool_calls": [{"id": "call_001", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Beijing\"}"}}, {"id": "call_002", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Shanghai\"}"}}]},
|
||||
{"role": "tool", "content": "Sunny, 25°C", "tool_call_id": "call_001"},
|
||||
{"role": "tool", "content": "Cloudy, 22°C", "tool_call_id": "call_002"}
|
||||
],
|
||||
"tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather info", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}]
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(processedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
messages := bodyMap["messages"].([]interface{})
|
||||
// messages[0] = user, messages[1] = assistant with 2 toolUse, messages[2] = user with 2 toolResult
|
||||
require.Len(t, messages, 3, "Should have 3 messages: user, assistant, user(toolResults merged)")
|
||||
|
||||
// Verify assistant message has 2 toolUse blocks
|
||||
assistantMsg := messages[1].(map[string]interface{})
|
||||
require.Equal(t, "assistant", assistantMsg["role"])
|
||||
assistantContent := assistantMsg["content"].([]interface{})
|
||||
require.Len(t, assistantContent, 2, "Assistant should have 2 content blocks for parallel tool calls")
|
||||
|
||||
firstToolUse := assistantContent[0].(map[string]interface{})["toolUse"].(map[string]interface{})
|
||||
require.Equal(t, "get_weather", firstToolUse["name"])
|
||||
require.Equal(t, "call_001", firstToolUse["toolUseId"])
|
||||
|
||||
secondToolUse := assistantContent[1].(map[string]interface{})["toolUse"].(map[string]interface{})
|
||||
require.Equal(t, "get_weather", secondToolUse["name"])
|
||||
require.Equal(t, "call_002", secondToolUse["toolUseId"])
|
||||
|
||||
// Verify tool results are merged into a single user message
|
||||
toolResultMsg := messages[2].(map[string]interface{})
|
||||
require.Equal(t, "user", toolResultMsg["role"])
|
||||
toolResultContent := toolResultMsg["content"].([]interface{})
|
||||
require.Len(t, toolResultContent, 2, "Tool results should be merged into 2 content blocks in one user message")
|
||||
|
||||
firstResult := toolResultContent[0].(map[string]interface{})["toolResult"].(map[string]interface{})
|
||||
require.Equal(t, "call_001", firstResult["toolUseId"])
|
||||
|
||||
secondResult := toolResultContent[1].(map[string]interface{})["toolResult"].(map[string]interface{})
|
||||
require.Equal(t, "call_002", secondResult["toolUseId"])
|
||||
})
|
||||
|
||||
// Test tool call with text content mixed
|
||||
t.Run("bedrock tool call with text content mixed", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the weather in Beijing?"},
|
||||
{"role": "assistant", "content": "Let me check.", "tool_calls": [{"id": "call_001", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Beijing\"}"}}]},
|
||||
{"role": "tool", "content": "Sunny, 25°C", "tool_call_id": "call_001"},
|
||||
{"role": "assistant", "content": "The weather in Beijing is sunny with 25°C."},
|
||||
{"role": "user", "content": "Thanks!"}
|
||||
],
|
||||
"tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather info", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}]
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(processedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
messages := bodyMap["messages"].([]interface{})
|
||||
// messages[0] = user, messages[1] = assistant(toolUse), messages[2] = user(toolResult),
|
||||
// messages[3] = assistant(text), messages[4] = user(text)
|
||||
require.Len(t, messages, 5, "Should have 5 messages in mixed tool call and text scenario")
|
||||
|
||||
// Verify message roles alternate correctly
|
||||
require.Equal(t, "user", messages[0].(map[string]interface{})["role"])
|
||||
require.Equal(t, "assistant", messages[1].(map[string]interface{})["role"])
|
||||
require.Equal(t, "user", messages[2].(map[string]interface{})["role"])
|
||||
require.Equal(t, "assistant", messages[3].(map[string]interface{})["role"])
|
||||
require.Equal(t, "user", messages[4].(map[string]interface{})["role"])
|
||||
|
||||
// Verify assistant text message (messages[3]) has text content
|
||||
assistantTextMsg := messages[3].(map[string]interface{})
|
||||
assistantTextContent := assistantTextMsg["content"].([]interface{})
|
||||
require.Len(t, assistantTextContent, 1)
|
||||
require.Contains(t, assistantTextContent[0].(map[string]interface{}), "text", "Text assistant message should have text content")
|
||||
require.Contains(t, assistantTextContent[0].(map[string]interface{})["text"], "sunny", "Text content should contain weather info")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunBedrockOnHttpResponseBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// Test Bedrock response body processing
|
||||
|
||||
Reference in New Issue
Block a user