mirror of
https://github.com/alibaba/higress.git
synced 2026-05-22 19:57:29 +08:00
bugfix: map bedrock tool-call indexes and tool_choice (#3786)
Signed-off-by: Betula-L <6059935+Betula-L@users.noreply.github.com> Co-authored-by: Betula-L <6059935+Betula-L@users.noreply.github.com>
This commit is contained in:
@@ -1365,6 +1365,177 @@ func RunBedrockToolCallTests(t *testing.T) {
|
||||
require.Equal(t, "call_002", secondResult["toolUseId"])
|
||||
})
|
||||
|
||||
t.Run("bedrock maps any tool choice to converse any tool choice", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Run a search."}],
|
||||
"tool_choice": "any",
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(processedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
toolConfig := bodyMap["toolConfig"].(map[string]interface{})
|
||||
toolChoice := toolConfig["toolChoice"].(map[string]interface{})
|
||||
require.Contains(t, toolChoice, "any")
|
||||
require.Equal(t, map[string]interface{}{}, toolChoice["any"])
|
||||
})
|
||||
|
||||
t.Run("bedrock maps object tool choice to converse tool choice", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Run a search."}],
|
||||
"tool_choice": {"type":"function","function":{"name":"web_search"}},
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(processedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
toolConfig := bodyMap["toolConfig"].(map[string]interface{})
|
||||
toolChoice := toolConfig["toolChoice"].(map[string]interface{})
|
||||
tool := toolChoice["tool"].(map[string]interface{})
|
||||
require.Equal(t, "web_search", tool["name"])
|
||||
require.Len(t, tool, 1, "Bedrock specific tool choice should only include name")
|
||||
})
|
||||
|
||||
t.Run("bedrock maps object any tool choice to converse any tool choice", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Run a search."}],
|
||||
"tool_choice": {"type":"any"},
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web.",
|
||||
"parameters": {"type": "object"}
|
||||
}
|
||||
}]
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(host.GetRequestBody(), &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
toolConfig := bodyMap["toolConfig"].(map[string]interface{})
|
||||
toolChoice := toolConfig["toolChoice"].(map[string]interface{})
|
||||
require.Contains(t, toolChoice, "any")
|
||||
require.NotContains(t, toolChoice, "tool")
|
||||
})
|
||||
|
||||
t.Run("bedrock maps none tool choice by omitting tools", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Answer without tools."}],
|
||||
"tool_choice": {"type":"none"},
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web.",
|
||||
"parameters": {"type": "object"}
|
||||
}
|
||||
}]
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(host.GetRequestBody(), &bodyMap)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, bodyMap, "toolConfig")
|
||||
})
|
||||
|
||||
// Test tool call with text content mixed
|
||||
t.Run("bedrock tool call with text content mixed", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
@@ -1669,6 +1840,138 @@ func RunBedrockOnStreamingResponseBodyTests(t *testing.T) {
|
||||
require.Equal(t, "", payload)
|
||||
})
|
||||
|
||||
t.Run("bedrock streaming parallel tool calls should use dense OpenAI indexes", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Run three independent searches."}],
|
||||
"stream": true,
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
action = host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/vnd.amazon.eventstream"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
extractFirstToolCallIndex := func(body []byte) int {
|
||||
dataPayload := extractFirstDataPayload(body)
|
||||
require.NotEmpty(t, dataPayload, "streaming chunk should contain one SSE data payload")
|
||||
|
||||
var responseMap map[string]interface{}
|
||||
err := json.Unmarshal([]byte(dataPayload), &responseMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
choices := responseMap["choices"].([]interface{})
|
||||
require.Len(t, choices, 1, "streaming chunk should contain one choice")
|
||||
delta := choices[0].(map[string]interface{})["delta"].(map[string]interface{})
|
||||
toolCalls := delta["tool_calls"].([]interface{})
|
||||
require.Len(t, toolCalls, 1, "streaming chunk should contain one tool call delta")
|
||||
|
||||
index, ok := toolCalls[0].(map[string]interface{})["index"].(float64)
|
||||
require.True(t, ok, "tool call delta should include an index")
|
||||
return int(index)
|
||||
}
|
||||
extractFirstToolCall := func(body []byte) map[string]interface{} {
|
||||
dataPayload := extractFirstDataPayload(body)
|
||||
require.NotEmpty(t, dataPayload, "streaming chunk should contain one SSE data payload")
|
||||
|
||||
var responseMap map[string]interface{}
|
||||
err := json.Unmarshal([]byte(dataPayload), &responseMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
choices := responseMap["choices"].([]interface{})
|
||||
require.Len(t, choices, 1, "streaming chunk should contain one choice")
|
||||
delta := choices[0].(map[string]interface{})["delta"].(map[string]interface{})
|
||||
toolCalls := delta["tool_calls"].([]interface{})
|
||||
require.Len(t, toolCalls, 1, "streaming chunk should contain one tool call delta")
|
||||
|
||||
toolCall, ok := toolCalls[0].(map[string]interface{})
|
||||
require.True(t, ok, "tool call delta should be an object")
|
||||
return toolCall
|
||||
}
|
||||
|
||||
for expectedIndex, item := range []struct {
|
||||
contentBlockIndex int
|
||||
toolUseId string
|
||||
}{
|
||||
{contentBlockIndex: 1, toolUseId: "tooluse_first"},
|
||||
{contentBlockIndex: 3, toolUseId: "tooluse_second"},
|
||||
{contentBlockIndex: 4, toolUseId: "tooluse_third"},
|
||||
} {
|
||||
toolCallStart := buildBedrockEventStreamMessage(t, map[string]interface{}{
|
||||
"contentBlockIndex": item.contentBlockIndex,
|
||||
"start": map[string]interface{}{
|
||||
"toolUse": map[string]interface{}{
|
||||
"toolUseId": item.toolUseId,
|
||||
"name": "web_search",
|
||||
},
|
||||
},
|
||||
})
|
||||
action = host.CallOnHttpStreamingResponseBody(toolCallStart, false)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
toolCall := extractFirstToolCall(host.GetResponseBody())
|
||||
require.Equal(t, expectedIndex, extractFirstToolCallIndex(host.GetResponseBody()))
|
||||
require.Equal(t, item.toolUseId, toolCall["id"])
|
||||
require.Equal(t, "function", toolCall["type"])
|
||||
function := toolCall["function"].(map[string]interface{})
|
||||
require.Equal(t, "web_search", function["name"])
|
||||
require.Equal(t, "", function["arguments"])
|
||||
}
|
||||
|
||||
for expectedIndex, item := range []struct {
|
||||
contentBlockIndex int
|
||||
query string
|
||||
}{
|
||||
{contentBlockIndex: 1, query: "first synthetic query"},
|
||||
{contentBlockIndex: 3, query: "second synthetic query"},
|
||||
{contentBlockIndex: 4, query: "third synthetic query"},
|
||||
} {
|
||||
toolCallDelta := buildBedrockEventStreamMessage(t, map[string]interface{}{
|
||||
"contentBlockIndex": item.contentBlockIndex,
|
||||
"delta": map[string]interface{}{
|
||||
"toolUse": map[string]interface{}{
|
||||
"input": "{\"query\":\"" + item.query + "\"}",
|
||||
},
|
||||
},
|
||||
})
|
||||
action = host.CallOnHttpStreamingResponseBody(toolCallDelta, false)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
toolCall := extractFirstToolCall(host.GetResponseBody())
|
||||
require.Equal(t, expectedIndex, extractFirstToolCallIndex(host.GetResponseBody()))
|
||||
require.NotContains(t, toolCall, "id")
|
||||
function := toolCall["function"].(map[string]interface{})
|
||||
require.NotContains(t, function, "name")
|
||||
require.Equal(t, "{\"query\":\""+item.query+"\"}", function["arguments"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bedrock streaming usage should map cached_tokens", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
|
||||
Reference in New Issue
Block a user