mirror of
https://github.com/alibaba/higress.git
synced 2026-06-02 09:07:26 +08:00
fix(provider/bedrock.go): 优化工具调用消息处理逻辑 || fix(provider/bedrock.go): Optimization tool calls message processing logic (#3470)
This commit is contained in:
@@ -149,6 +149,7 @@ func TestBedrock(t *testing.T) {
|
|||||||
test.RunBedrockOnHttpRequestBodyTests(t)
|
test.RunBedrockOnHttpRequestBodyTests(t)
|
||||||
test.RunBedrockOnHttpResponseHeadersTests(t)
|
test.RunBedrockOnHttpResponseHeadersTests(t)
|
||||||
test.RunBedrockOnHttpResponseBodyTests(t)
|
test.RunBedrockOnHttpResponseBodyTests(t)
|
||||||
|
test.RunBedrockToolCallTests(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClaude(t *testing.T) {
|
func TestClaude(t *testing.T) {
|
||||||
|
|||||||
@@ -769,7 +769,15 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
|
|||||||
case roleSystem:
|
case roleSystem:
|
||||||
systemMessages = append(systemMessages, systemContentBlock{Text: msg.StringContent()})
|
systemMessages = append(systemMessages, systemContentBlock{Text: msg.StringContent()})
|
||||||
case roleTool:
|
case roleTool:
|
||||||
messages = append(messages, chatToolMessage2BedrockMessage(msg))
|
toolResultContent := chatToolMessage2BedrockToolResultContent(msg)
|
||||||
|
if len(messages) > 0 && messages[len(messages)-1].Role == roleUser && messages[len(messages)-1].Content[0].ToolResult != nil {
|
||||||
|
messages[len(messages)-1].Content = append(messages[len(messages)-1].Content, toolResultContent)
|
||||||
|
} else {
|
||||||
|
messages = append(messages, bedrockMessage{
|
||||||
|
Role: roleUser,
|
||||||
|
Content: []bedrockMessageContent{toolResultContent},
|
||||||
|
})
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
messages = append(messages, chatMessage2BedrockMessage(msg))
|
messages = append(messages, chatMessage2BedrockMessage(msg))
|
||||||
}
|
}
|
||||||
@@ -1060,7 +1068,7 @@ type tokenUsage struct {
|
|||||||
TotalTokens int `json:"totalTokens"`
|
TotalTokens int `json:"totalTokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func chatToolMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
|
func chatToolMessage2BedrockToolResultContent(chatMessage chatMessage) bedrockMessageContent {
|
||||||
toolResultContent := &toolResultBlock{}
|
toolResultContent := &toolResultBlock{}
|
||||||
toolResultContent.ToolUseId = chatMessage.ToolCallId
|
toolResultContent.ToolUseId = chatMessage.ToolCallId
|
||||||
if text, ok := chatMessage.Content.(string); ok {
|
if text, ok := chatMessage.Content.(string); ok {
|
||||||
@@ -1083,29 +1091,29 @@ func chatToolMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
|
|||||||
} else {
|
} else {
|
||||||
log.Warnf("the content type is not supported, current content is %v", chatMessage.Content)
|
log.Warnf("the content type is not supported, current content is %v", chatMessage.Content)
|
||||||
}
|
}
|
||||||
return bedrockMessage{
|
return bedrockMessageContent{
|
||||||
Role: roleUser,
|
ToolResult: toolResultContent,
|
||||||
Content: []bedrockMessageContent{
|
|
||||||
{
|
|
||||||
ToolResult: toolResultContent,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
|
func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
|
||||||
var result bedrockMessage
|
var result bedrockMessage
|
||||||
if len(chatMessage.ToolCalls) > 0 {
|
if len(chatMessage.ToolCalls) > 0 {
|
||||||
|
contents := make([]bedrockMessageContent, 0, len(chatMessage.ToolCalls))
|
||||||
|
for _, toolCall := range chatMessage.ToolCalls {
|
||||||
|
params := map[string]interface{}{}
|
||||||
|
json.Unmarshal([]byte(toolCall.Function.Arguments), ¶ms)
|
||||||
|
contents = append(contents, bedrockMessageContent{
|
||||||
|
ToolUse: &toolUseBlock{
|
||||||
|
Input: params,
|
||||||
|
Name: toolCall.Function.Name,
|
||||||
|
ToolUseId: toolCall.Id,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
result = bedrockMessage{
|
result = bedrockMessage{
|
||||||
Role: chatMessage.Role,
|
Role: chatMessage.Role,
|
||||||
Content: []bedrockMessageContent{{}},
|
Content: contents,
|
||||||
}
|
|
||||||
params := map[string]interface{}{}
|
|
||||||
json.Unmarshal([]byte(chatMessage.ToolCalls[0].Function.Arguments), ¶ms)
|
|
||||||
result.Content[0].ToolUse = &toolUseBlock{
|
|
||||||
Input: params,
|
|
||||||
Name: chatMessage.ToolCalls[0].Function.Name,
|
|
||||||
ToolUseId: chatMessage.ToolCalls[0].Id,
|
|
||||||
}
|
}
|
||||||
} else if chatMessage.IsStringContent() {
|
} else if chatMessage.IsStringContent() {
|
||||||
result = bedrockMessage{
|
result = bedrockMessage{
|
||||||
|
|||||||
@@ -442,6 +442,186 @@ func RunBedrockOnHttpResponseHeadersTests(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RunBedrockToolCallTests(t *testing.T) {
|
||||||
|
test.RunTest(t, func(t *testing.T) {
|
||||||
|
// Test single tool call conversion (regression test)
|
||||||
|
t.Run("bedrock single tool call conversion", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/chat/completions"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
requestBody := `{
|
||||||
|
"model": "gpt-4",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What is the weather in Beijing?"},
|
||||||
|
{"role": "assistant", "content": "Let me check the weather for you.", "tool_calls": [{"id": "call_001", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Beijing\"}"}}]},
|
||||||
|
{"role": "tool", "content": "Sunny, 25°C", "tool_call_id": "call_001"}
|
||||||
|
],
|
||||||
|
"tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather info", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}]
|
||||||
|
}`
|
||||||
|
|
||||||
|
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
|
||||||
|
var bodyMap map[string]interface{}
|
||||||
|
err := json.Unmarshal(processedBody, &bodyMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
messages := bodyMap["messages"].([]interface{})
|
||||||
|
// messages[0] = user, messages[1] = assistant with toolUse, messages[2] = user with toolResult
|
||||||
|
require.Len(t, messages, 3, "Should have 3 messages: user, assistant, user(toolResult)")
|
||||||
|
|
||||||
|
// Verify assistant message has exactly 1 toolUse
|
||||||
|
assistantMsg := messages[1].(map[string]interface{})
|
||||||
|
require.Equal(t, "assistant", assistantMsg["role"])
|
||||||
|
assistantContent := assistantMsg["content"].([]interface{})
|
||||||
|
require.Len(t, assistantContent, 1, "Assistant should have 1 content block")
|
||||||
|
toolUseBlock := assistantContent[0].(map[string]interface{})
|
||||||
|
require.Contains(t, toolUseBlock, "toolUse", "Content block should contain toolUse")
|
||||||
|
|
||||||
|
// Verify tool result message
|
||||||
|
toolResultMsg := messages[2].(map[string]interface{})
|
||||||
|
require.Equal(t, "user", toolResultMsg["role"])
|
||||||
|
toolResultContent := toolResultMsg["content"].([]interface{})
|
||||||
|
require.Len(t, toolResultContent, 1, "Tool result message should have 1 content block")
|
||||||
|
require.Contains(t, toolResultContent[0].(map[string]interface{}), "toolResult", "Content block should contain toolResult")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test multiple parallel tool calls conversion
|
||||||
|
t.Run("bedrock multiple parallel tool calls conversion", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/chat/completions"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
requestBody := `{
|
||||||
|
"model": "gpt-4",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What is the weather in Beijing and Shanghai?"},
|
||||||
|
{"role": "assistant", "content": "Let me check both cities.", "tool_calls": [{"id": "call_001", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Beijing\"}"}}, {"id": "call_002", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Shanghai\"}"}}]},
|
||||||
|
{"role": "tool", "content": "Sunny, 25°C", "tool_call_id": "call_001"},
|
||||||
|
{"role": "tool", "content": "Cloudy, 22°C", "tool_call_id": "call_002"}
|
||||||
|
],
|
||||||
|
"tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather info", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}]
|
||||||
|
}`
|
||||||
|
|
||||||
|
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
|
||||||
|
var bodyMap map[string]interface{}
|
||||||
|
err := json.Unmarshal(processedBody, &bodyMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
messages := bodyMap["messages"].([]interface{})
|
||||||
|
// messages[0] = user, messages[1] = assistant with 2 toolUse, messages[2] = user with 2 toolResult
|
||||||
|
require.Len(t, messages, 3, "Should have 3 messages: user, assistant, user(toolResults merged)")
|
||||||
|
|
||||||
|
// Verify assistant message has 2 toolUse blocks
|
||||||
|
assistantMsg := messages[1].(map[string]interface{})
|
||||||
|
require.Equal(t, "assistant", assistantMsg["role"])
|
||||||
|
assistantContent := assistantMsg["content"].([]interface{})
|
||||||
|
require.Len(t, assistantContent, 2, "Assistant should have 2 content blocks for parallel tool calls")
|
||||||
|
|
||||||
|
firstToolUse := assistantContent[0].(map[string]interface{})["toolUse"].(map[string]interface{})
|
||||||
|
require.Equal(t, "get_weather", firstToolUse["name"])
|
||||||
|
require.Equal(t, "call_001", firstToolUse["toolUseId"])
|
||||||
|
|
||||||
|
secondToolUse := assistantContent[1].(map[string]interface{})["toolUse"].(map[string]interface{})
|
||||||
|
require.Equal(t, "get_weather", secondToolUse["name"])
|
||||||
|
require.Equal(t, "call_002", secondToolUse["toolUseId"])
|
||||||
|
|
||||||
|
// Verify tool results are merged into a single user message
|
||||||
|
toolResultMsg := messages[2].(map[string]interface{})
|
||||||
|
require.Equal(t, "user", toolResultMsg["role"])
|
||||||
|
toolResultContent := toolResultMsg["content"].([]interface{})
|
||||||
|
require.Len(t, toolResultContent, 2, "Tool results should be merged into 2 content blocks in one user message")
|
||||||
|
|
||||||
|
firstResult := toolResultContent[0].(map[string]interface{})["toolResult"].(map[string]interface{})
|
||||||
|
require.Equal(t, "call_001", firstResult["toolUseId"])
|
||||||
|
|
||||||
|
secondResult := toolResultContent[1].(map[string]interface{})["toolResult"].(map[string]interface{})
|
||||||
|
require.Equal(t, "call_002", secondResult["toolUseId"])
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test tool call with text content mixed
|
||||||
|
t.Run("bedrock tool call with text content mixed", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/chat/completions"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
requestBody := `{
|
||||||
|
"model": "gpt-4",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What is the weather in Beijing?"},
|
||||||
|
{"role": "assistant", "content": "Let me check.", "tool_calls": [{"id": "call_001", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Beijing\"}"}}]},
|
||||||
|
{"role": "tool", "content": "Sunny, 25°C", "tool_call_id": "call_001"},
|
||||||
|
{"role": "assistant", "content": "The weather in Beijing is sunny with 25°C."},
|
||||||
|
{"role": "user", "content": "Thanks!"}
|
||||||
|
],
|
||||||
|
"tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather info", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}]
|
||||||
|
}`
|
||||||
|
|
||||||
|
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
|
||||||
|
var bodyMap map[string]interface{}
|
||||||
|
err := json.Unmarshal(processedBody, &bodyMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
messages := bodyMap["messages"].([]interface{})
|
||||||
|
// messages[0] = user, messages[1] = assistant(toolUse), messages[2] = user(toolResult),
|
||||||
|
// messages[3] = assistant(text), messages[4] = user(text)
|
||||||
|
require.Len(t, messages, 5, "Should have 5 messages in mixed tool call and text scenario")
|
||||||
|
|
||||||
|
// Verify message roles alternate correctly
|
||||||
|
require.Equal(t, "user", messages[0].(map[string]interface{})["role"])
|
||||||
|
require.Equal(t, "assistant", messages[1].(map[string]interface{})["role"])
|
||||||
|
require.Equal(t, "user", messages[2].(map[string]interface{})["role"])
|
||||||
|
require.Equal(t, "assistant", messages[3].(map[string]interface{})["role"])
|
||||||
|
require.Equal(t, "user", messages[4].(map[string]interface{})["role"])
|
||||||
|
|
||||||
|
// Verify assistant text message (messages[3]) has text content
|
||||||
|
assistantTextMsg := messages[3].(map[string]interface{})
|
||||||
|
assistantTextContent := assistantTextMsg["content"].([]interface{})
|
||||||
|
require.Len(t, assistantTextContent, 1)
|
||||||
|
require.Contains(t, assistantTextContent[0].(map[string]interface{}), "text", "Text assistant message should have text content")
|
||||||
|
require.Contains(t, assistantTextContent[0].(map[string]interface{})["text"], "sunny", "Text content should contain weather info")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func RunBedrockOnHttpResponseBodyTests(t *testing.T) {
|
func RunBedrockOnHttpResponseBodyTests(t *testing.T) {
|
||||||
test.RunTest(t, func(t *testing.T) {
|
test.RunTest(t, func(t *testing.T) {
|
||||||
// Test Bedrock response body processing
|
// Test Bedrock response body processing
|
||||||
|
|||||||
Reference in New Issue
Block a user