bugfix: map bedrock tool-call indexes and tool_choice (#3786)

Signed-off-by: Betula-L <6059935+Betula-L@users.noreply.github.com>
Co-authored-by: Betula-L <6059935+Betula-L@users.noreply.github.com>
This commit is contained in:
Betula-L
2026-05-06 04:48:42 -07:00
committed by GitHub
parent 4aba4a9860
commit 6199fe414d
7 changed files with 557 additions and 34 deletions

View File

@@ -1365,6 +1365,177 @@ func RunBedrockToolCallTests(t *testing.T) {
require.Equal(t, "call_002", secondResult["toolUseId"])
})
t.Run("bedrock maps any tool choice to converse any tool choice", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Run a search."}],
"tool_choice": "any",
"tools": [{
"type": "function",
"function": {
"name": "web_search",
"description": "Search the web.",
"parameters": {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"]
}
}
}]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
var bodyMap map[string]interface{}
err := json.Unmarshal(processedBody, &bodyMap)
require.NoError(t, err)
toolConfig := bodyMap["toolConfig"].(map[string]interface{})
toolChoice := toolConfig["toolChoice"].(map[string]interface{})
require.Contains(t, toolChoice, "any")
require.Equal(t, map[string]interface{}{}, toolChoice["any"])
})
t.Run("bedrock maps object tool choice to converse tool choice", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Run a search."}],
"tool_choice": {"type":"function","function":{"name":"web_search"}},
"tools": [{
"type": "function",
"function": {
"name": "web_search",
"description": "Search the web.",
"parameters": {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"]
}
}
}]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
var bodyMap map[string]interface{}
err := json.Unmarshal(processedBody, &bodyMap)
require.NoError(t, err)
toolConfig := bodyMap["toolConfig"].(map[string]interface{})
toolChoice := toolConfig["toolChoice"].(map[string]interface{})
tool := toolChoice["tool"].(map[string]interface{})
require.Equal(t, "web_search", tool["name"])
require.Len(t, tool, 1, "Bedrock specific tool choice should only include name")
})
t.Run("bedrock maps object any tool choice to converse any tool choice", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Run a search."}],
"tool_choice": {"type":"any"},
"tools": [{
"type": "function",
"function": {
"name": "web_search",
"description": "Search the web.",
"parameters": {"type": "object"}
}
}]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
var bodyMap map[string]interface{}
err := json.Unmarshal(host.GetRequestBody(), &bodyMap)
require.NoError(t, err)
toolConfig := bodyMap["toolConfig"].(map[string]interface{})
toolChoice := toolConfig["toolChoice"].(map[string]interface{})
require.Contains(t, toolChoice, "any")
require.NotContains(t, toolChoice, "tool")
})
t.Run("bedrock maps none tool choice by omitting tools", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Answer without tools."}],
"tool_choice": {"type":"none"},
"tools": [{
"type": "function",
"function": {
"name": "web_search",
"description": "Search the web.",
"parameters": {"type": "object"}
}
}]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
var bodyMap map[string]interface{}
err := json.Unmarshal(host.GetRequestBody(), &bodyMap)
require.NoError(t, err)
require.NotContains(t, bodyMap, "toolConfig")
})
// Test tool call with text content mixed
t.Run("bedrock tool call with text content mixed", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
@@ -1669,6 +1840,138 @@ func RunBedrockOnStreamingResponseBodyTests(t *testing.T) {
require.Equal(t, "", payload)
})
t.Run("bedrock streaming parallel tool calls should use dense OpenAI indexes", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Run three independent searches."}],
"stream": true,
"tools": [{
"type": "function",
"function": {
"name": "web_search",
"description": "Search the web.",
"parameters": {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"]
}
}
}]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
action = host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"Content-Type", "application/vnd.amazon.eventstream"},
})
require.Equal(t, types.ActionContinue, action)
extractFirstToolCallIndex := func(body []byte) int {
dataPayload := extractFirstDataPayload(body)
require.NotEmpty(t, dataPayload, "streaming chunk should contain one SSE data payload")
var responseMap map[string]interface{}
err := json.Unmarshal([]byte(dataPayload), &responseMap)
require.NoError(t, err)
choices := responseMap["choices"].([]interface{})
require.Len(t, choices, 1, "streaming chunk should contain one choice")
delta := choices[0].(map[string]interface{})["delta"].(map[string]interface{})
toolCalls := delta["tool_calls"].([]interface{})
require.Len(t, toolCalls, 1, "streaming chunk should contain one tool call delta")
index, ok := toolCalls[0].(map[string]interface{})["index"].(float64)
require.True(t, ok, "tool call delta should include an index")
return int(index)
}
extractFirstToolCall := func(body []byte) map[string]interface{} {
dataPayload := extractFirstDataPayload(body)
require.NotEmpty(t, dataPayload, "streaming chunk should contain one SSE data payload")
var responseMap map[string]interface{}
err := json.Unmarshal([]byte(dataPayload), &responseMap)
require.NoError(t, err)
choices := responseMap["choices"].([]interface{})
require.Len(t, choices, 1, "streaming chunk should contain one choice")
delta := choices[0].(map[string]interface{})["delta"].(map[string]interface{})
toolCalls := delta["tool_calls"].([]interface{})
require.Len(t, toolCalls, 1, "streaming chunk should contain one tool call delta")
toolCall, ok := toolCalls[0].(map[string]interface{})
require.True(t, ok, "tool call delta should be an object")
return toolCall
}
for expectedIndex, item := range []struct {
contentBlockIndex int
toolUseId string
}{
{contentBlockIndex: 1, toolUseId: "tooluse_first"},
{contentBlockIndex: 3, toolUseId: "tooluse_second"},
{contentBlockIndex: 4, toolUseId: "tooluse_third"},
} {
toolCallStart := buildBedrockEventStreamMessage(t, map[string]interface{}{
"contentBlockIndex": item.contentBlockIndex,
"start": map[string]interface{}{
"toolUse": map[string]interface{}{
"toolUseId": item.toolUseId,
"name": "web_search",
},
},
})
action = host.CallOnHttpStreamingResponseBody(toolCallStart, false)
require.Equal(t, types.ActionContinue, action)
toolCall := extractFirstToolCall(host.GetResponseBody())
require.Equal(t, expectedIndex, extractFirstToolCallIndex(host.GetResponseBody()))
require.Equal(t, item.toolUseId, toolCall["id"])
require.Equal(t, "function", toolCall["type"])
function := toolCall["function"].(map[string]interface{})
require.Equal(t, "web_search", function["name"])
require.Equal(t, "", function["arguments"])
}
for expectedIndex, item := range []struct {
contentBlockIndex int
query string
}{
{contentBlockIndex: 1, query: "first synthetic query"},
{contentBlockIndex: 3, query: "second synthetic query"},
{contentBlockIndex: 4, query: "third synthetic query"},
} {
toolCallDelta := buildBedrockEventStreamMessage(t, map[string]interface{}{
"contentBlockIndex": item.contentBlockIndex,
"delta": map[string]interface{}{
"toolUse": map[string]interface{}{
"input": "{\"query\":\"" + item.query + "\"}",
},
},
})
action = host.CallOnHttpStreamingResponseBody(toolCallDelta, false)
require.Equal(t, types.ActionContinue, action)
toolCall := extractFirstToolCall(host.GetResponseBody())
require.Equal(t, expectedIndex, extractFirstToolCallIndex(host.GetResponseBody()))
require.NotContains(t, toolCall, "id")
function := toolCall["function"].(map[string]interface{})
require.NotContains(t, function, "name")
require.Equal(t, "{\"query\":\""+item.query+"\"}", function["arguments"])
}
})
t.Run("bedrock streaming usage should map cached_tokens", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
defer host.Reset()