mirror of
https://github.com/alibaba/higress.git
synced 2026-05-09 13:27:27 +08:00
bugfix: map bedrock tool-call indexes and tool_choice (#3786)
Signed-off-by: Betula-L <6059935+Betula-L@users.noreply.github.com> Co-authored-by: Betula-L <6059935+Betula-L@users.noreply.github.com>
This commit is contained in:
@@ -47,6 +47,8 @@ const (
|
||||
bedrockCachePointPositionSystemPrompt = "systemPrompt"
|
||||
bedrockCachePointPositionLastUserMessage = "lastUserMessage"
|
||||
bedrockCachePointPositionLastMessage = "lastMessage"
|
||||
|
||||
ctxKeyBedrockToolCallState = "bedrock_tool_call_state"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -121,11 +123,13 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
|
||||
chatChoice.Delta.Role = *bedrockEvent.Role
|
||||
}
|
||||
if bedrockEvent.Start != nil {
|
||||
toolCallIndex := getBedrockOpenAIToolCallIndex(ctx, bedrockEvent.ContentBlockIndex)
|
||||
chatChoice.Delta.Content = nil
|
||||
chatChoice.Delta.ToolCalls = []toolCall{
|
||||
{
|
||||
Id: bedrockEvent.Start.ToolUse.ToolUseID,
|
||||
Type: "function",
|
||||
Index: toolCallIndex,
|
||||
Id: bedrockEvent.Start.ToolUse.ToolUseID,
|
||||
Type: "function",
|
||||
Function: functionCall{
|
||||
Name: bedrockEvent.Start.ToolUse.Name,
|
||||
Arguments: "",
|
||||
@@ -152,9 +156,11 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
|
||||
chatChoice.Delta = &chatMessage{Content: &content}
|
||||
}
|
||||
if bedrockEvent.Delta.ToolUse != nil {
|
||||
toolCallIndex := getBedrockOpenAIToolCallIndex(ctx, bedrockEvent.ContentBlockIndex)
|
||||
chatChoice.Delta.ToolCalls = []toolCall{
|
||||
{
|
||||
Type: "function",
|
||||
Index: toolCallIndex,
|
||||
Type: "function",
|
||||
Function: functionCall{
|
||||
Arguments: bedrockEvent.Delta.ToolUse.Input,
|
||||
},
|
||||
@@ -192,6 +198,28 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
|
||||
return []byte(openAIChunk.String()), nil
|
||||
}
|
||||
|
||||
type bedrockToolCallState struct {
|
||||
indexes map[int]int
|
||||
next int
|
||||
}
|
||||
|
||||
func getBedrockOpenAIToolCallIndex(ctx wrapper.HttpContext, contentBlockIndex int) int {
|
||||
state, _ := ctx.GetContext(ctxKeyBedrockToolCallState).(*bedrockToolCallState)
|
||||
if state == nil {
|
||||
state = &bedrockToolCallState{indexes: make(map[int]int)}
|
||||
ctx.SetContext(ctxKeyBedrockToolCallState, state)
|
||||
}
|
||||
|
||||
if toolCallIndex, ok := state.indexes[contentBlockIndex]; ok {
|
||||
return toolCallIndex
|
||||
}
|
||||
|
||||
toolCallIndex := state.next
|
||||
state.indexes[contentBlockIndex] = toolCallIndex
|
||||
state.next++
|
||||
return toolCallIndex
|
||||
}
|
||||
|
||||
type ConverseStreamEvent struct {
|
||||
ContentBlockIndex int `json:"contentBlockIndex,omitempty"`
|
||||
Delta *converseStreamEventContentBlockDelta `json:"delta,omitempty"`
|
||||
@@ -870,22 +898,25 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
|
||||
}
|
||||
}
|
||||
|
||||
if origRequest.Tools != nil {
|
||||
if origRequest.Tools != nil && origRequest.getToolChoiceType() != "none" {
|
||||
request.ToolConfig = &bedrockToolConfig{}
|
||||
if origRequest.ToolChoice == nil {
|
||||
request.ToolConfig.ToolChoice.Auto = &struct{}{}
|
||||
} else if choice_type, ok := origRequest.ToolChoice.(string); ok {
|
||||
request.ToolConfig.ToolChoice.Auto = &struct{}{}
|
||||
if choice_type := origRequest.getToolChoiceType(); choice_type != "" {
|
||||
switch choice_type {
|
||||
case "required":
|
||||
// "any" is accepted for direct Anthropic-compatible callers; OpenAI
|
||||
// uses "required" for the same "must call at least one tool" behavior.
|
||||
case "required", "any":
|
||||
request.ToolConfig.ToolChoice.Auto = nil
|
||||
request.ToolConfig.ToolChoice.Any = &struct{}{}
|
||||
case "auto":
|
||||
request.ToolConfig.ToolChoice.Auto = &struct{}{}
|
||||
case "none":
|
||||
request.ToolConfig.ToolChoice.Auto = &struct{}{}
|
||||
}
|
||||
} else if choice, ok := origRequest.ToolChoice.(toolChoice); ok {
|
||||
request.ToolConfig.ToolChoice.Tool = &bedrockToolSpecification{
|
||||
Name: choice.Function.Name,
|
||||
case "function":
|
||||
if choice := origRequest.getToolChoiceObject(); choice != nil && choice.Function.Name != "" {
|
||||
request.ToolConfig.ToolChoice.Auto = nil
|
||||
request.ToolConfig.ToolChoice.Tool = &bedrockSpecificToolChoice{
|
||||
Name: choice.Function.Name,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
request.ToolConfig.Tools = []bedrockTool{}
|
||||
@@ -1151,9 +1182,13 @@ type bedrockTool struct {
|
||||
}
|
||||
|
||||
type bedrockToolChoice struct {
|
||||
Any *struct{} `json:"any,omitempty"`
|
||||
Auto *struct{} `json:"auto,omitempty"`
|
||||
Tool *bedrockToolSpecification `json:"tool,omitempty"`
|
||||
Any *struct{} `json:"any,omitempty"`
|
||||
Auto *struct{} `json:"auto,omitempty"`
|
||||
Tool *bedrockSpecificToolChoice `json:"tool,omitempty"`
|
||||
}
|
||||
|
||||
type bedrockSpecificToolChoice struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type bedrockToolSpecification struct {
|
||||
|
||||
@@ -672,11 +672,30 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
|
||||
claudeRequest.Tools = append(claudeRequest.Tools, claudeTool)
|
||||
}
|
||||
|
||||
if tc := origRequest.getToolChoiceObject(); tc != nil {
|
||||
claudeRequest.ToolChoice = &claudeToolChoice{
|
||||
Name: tc.Function.Name,
|
||||
Type: tc.Type,
|
||||
DisableParallelToolUse: !origRequest.ParallelToolCalls,
|
||||
if origRequest.ToolChoice != nil {
|
||||
parallelToolCalls := true
|
||||
if origRequest.ParallelToolCalls != nil {
|
||||
parallelToolCalls = *origRequest.ParallelToolCalls
|
||||
}
|
||||
|
||||
choiceType := origRequest.getToolChoiceType()
|
||||
if tc := origRequest.getToolChoiceObject(); tc != nil && tc.Type == "function" && tc.Function.Name != "" {
|
||||
claudeRequest.ToolChoice = &claudeToolChoice{
|
||||
Name: tc.Function.Name,
|
||||
Type: "tool",
|
||||
DisableParallelToolUse: !parallelToolCalls,
|
||||
}
|
||||
} else if choiceType != "" {
|
||||
switch choiceType {
|
||||
case "required":
|
||||
choiceType = "any"
|
||||
}
|
||||
claudeRequest.ToolChoice = &claudeToolChoice{
|
||||
Type: choiceType,
|
||||
}
|
||||
if choiceType != "none" {
|
||||
claudeRequest.ToolChoice.DisableParallelToolUse = !parallelToolCalls
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -183,6 +183,87 @@ func TestClaudeProvider_BuildClaudeTextGenRequest_StandardMode(t *testing.T) {
|
||||
assert.False(t, claudeReq.System.IsArray)
|
||||
assert.Equal(t, "You are a helpful assistant.", claudeReq.System.StringValue)
|
||||
})
|
||||
|
||||
t.Run("maps_openai_function_tool_choice_to_claude_tool_choice", func(t *testing.T) {
|
||||
request := &chatCompletionRequest{
|
||||
Model: "claude-sonnet-4-5-20250929",
|
||||
MaxTokens: 8192,
|
||||
Messages: []chatMessage{
|
||||
{Role: roleUser, Content: "Search."},
|
||||
},
|
||||
Tools: []tool{{
|
||||
Type: "function",
|
||||
Function: function{
|
||||
Name: "web_search",
|
||||
Description: "Search the web.",
|
||||
Parameters: map[string]interface{}{"type": "object"},
|
||||
},
|
||||
}},
|
||||
ToolChoice: map[string]interface{}{
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": "web_search",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
claudeReq := provider.buildClaudeTextGenRequest(request)
|
||||
|
||||
require.NotNil(t, claudeReq.ToolChoice)
|
||||
assert.Equal(t, "tool", claudeReq.ToolChoice.Type)
|
||||
assert.Equal(t, "web_search", claudeReq.ToolChoice.Name)
|
||||
})
|
||||
|
||||
t.Run("maps_openai_string_required_tool_choice_to_claude_any", func(t *testing.T) {
|
||||
parallelToolCalls := false
|
||||
request := &chatCompletionRequest{
|
||||
Model: "claude-sonnet-4-5-20250929",
|
||||
MaxTokens: 8192,
|
||||
Messages: []chatMessage{
|
||||
{Role: roleUser, Content: "Search."},
|
||||
},
|
||||
Tools: []tool{{
|
||||
Type: "function",
|
||||
Function: function{
|
||||
Name: "web_search",
|
||||
Parameters: map[string]interface{}{"type": "object"},
|
||||
},
|
||||
}},
|
||||
ToolChoice: "required",
|
||||
ParallelToolCalls: ¶llelToolCalls,
|
||||
}
|
||||
|
||||
claudeReq := provider.buildClaudeTextGenRequest(request)
|
||||
|
||||
require.NotNil(t, claudeReq.ToolChoice)
|
||||
assert.Equal(t, "any", claudeReq.ToolChoice.Type)
|
||||
assert.Empty(t, claudeReq.ToolChoice.Name)
|
||||
assert.True(t, claudeReq.ToolChoice.DisableParallelToolUse)
|
||||
})
|
||||
|
||||
t.Run("maps_openai_string_none_tool_choice_to_claude_none", func(t *testing.T) {
|
||||
request := &chatCompletionRequest{
|
||||
Model: "claude-sonnet-4-5-20250929",
|
||||
MaxTokens: 8192,
|
||||
Messages: []chatMessage{
|
||||
{Role: roleUser, Content: "Answer without tools."},
|
||||
},
|
||||
Tools: []tool{{
|
||||
Type: "function",
|
||||
Function: function{
|
||||
Name: "web_search",
|
||||
Parameters: map[string]interface{}{"type": "object"},
|
||||
},
|
||||
}},
|
||||
ToolChoice: "none",
|
||||
}
|
||||
|
||||
claudeReq := provider.buildClaudeTextGenRequest(request)
|
||||
|
||||
require.NotNil(t, claudeReq.ToolChoice)
|
||||
assert.Equal(t, "none", claudeReq.ToolChoice.Type)
|
||||
assert.Empty(t, claudeReq.ToolChoice.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeProvider_BuildClaudeTextGenRequest_ClaudeCodeMode(t *testing.T) {
|
||||
|
||||
@@ -178,12 +178,19 @@ func (c *ClaudeToOpenAIConverter) ConvertClaudeRequestToOpenAI(body []byte) ([]b
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// For other types like "auto", "none", etc.
|
||||
openaiRequest.ToolChoice = claudeRequest.ToolChoice.Type
|
||||
// Anthropic's "any" means the model must call at least one tool.
|
||||
// OpenAI-compatible requests express the same behavior as "required".
|
||||
if claudeRequest.ToolChoice.Type == "any" {
|
||||
openaiRequest.ToolChoice = "required"
|
||||
} else {
|
||||
// For other types like "auto", "none", etc.
|
||||
openaiRequest.ToolChoice = claudeRequest.ToolChoice.Type
|
||||
}
|
||||
}
|
||||
|
||||
// Handle parallel tool calls
|
||||
openaiRequest.ParallelToolCalls = !claudeRequest.ToolChoice.DisableParallelToolUse
|
||||
parallelToolCalls := !claudeRequest.ToolChoice.DisableParallelToolUse
|
||||
openaiRequest.ParallelToolCalls = ¶llelToolCalls
|
||||
}
|
||||
|
||||
// Convert thinking configuration if present
|
||||
|
||||
@@ -35,6 +35,66 @@ func init() {
|
||||
func TestClaudeToOpenAIConverter_ConvertClaudeRequestToOpenAI(t *testing.T) {
|
||||
converter := &ClaudeToOpenAIConverter{}
|
||||
|
||||
t.Run("convert_tool_choice_any_to_required", func(t *testing.T) {
|
||||
claudeRequest := `{
|
||||
"model": "claude-sonnet-4",
|
||||
"max_tokens": 1000,
|
||||
"messages": [{"role": "user", "content": "Run a search."}],
|
||||
"tools": [{
|
||||
"name": "web_search",
|
||||
"description": "Search the web.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"]
|
||||
}
|
||||
}],
|
||||
"tool_choice": {"type": "any"}
|
||||
}`
|
||||
|
||||
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
|
||||
require.NoError(t, err)
|
||||
|
||||
var openaiRequest chatCompletionRequest
|
||||
err = json.Unmarshal(result, &openaiRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "required", openaiRequest.ToolChoice)
|
||||
require.NotNil(t, openaiRequest.ParallelToolCalls)
|
||||
require.True(t, *openaiRequest.ParallelToolCalls)
|
||||
require.Contains(t, string(result), `"parallel_tool_calls":true`)
|
||||
})
|
||||
|
||||
t.Run("convert_tool_choice_any_preserves_disable_parallel_tool_use", func(t *testing.T) {
|
||||
claudeRequest := `{
|
||||
"model": "claude-sonnet-4",
|
||||
"max_tokens": 1000,
|
||||
"messages": [{"role": "user", "content": "Run a search."}],
|
||||
"tools": [{
|
||||
"name": "web_search",
|
||||
"description": "Search the web.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"]
|
||||
}
|
||||
}],
|
||||
"tool_choice": {"type": "any", "disable_parallel_tool_use": true}
|
||||
}`
|
||||
|
||||
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
|
||||
require.NoError(t, err)
|
||||
|
||||
var openaiRequest chatCompletionRequest
|
||||
err = json.Unmarshal(result, &openaiRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "required", openaiRequest.ToolChoice)
|
||||
require.NotNil(t, openaiRequest.ParallelToolCalls)
|
||||
require.False(t, *openaiRequest.ParallelToolCalls)
|
||||
require.Contains(t, string(result), `"parallel_tool_calls":false`)
|
||||
})
|
||||
|
||||
t.Run("convert_multiple_text_content_blocks", func(t *testing.T) {
|
||||
// Test case: multiple text content blocks should remain as separate array elements with cache control support
|
||||
// Both system and user messages should handle array content format
|
||||
|
||||
@@ -70,7 +70,7 @@ type chatCompletionRequest struct {
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Tools []tool `json:"tools,omitempty"`
|
||||
ToolChoice interface{} `json:"tool_choice,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
@@ -92,15 +92,33 @@ func (c *chatCompletionRequest) getToolChoiceString() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *chatCompletionRequest) getToolChoiceObject() *toolChoice {
|
||||
if c.ToolChoice == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tc, ok := c.ToolChoice.(*toolChoice); ok {
|
||||
func (c *chatCompletionRequest) getToolChoiceType() string {
|
||||
if tc := c.getToolChoiceString(); tc != "" {
|
||||
return tc
|
||||
}
|
||||
return nil
|
||||
if tc := c.getToolChoiceObject(); tc != nil {
|
||||
return tc.Type
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *chatCompletionRequest) getToolChoiceObject() *toolChoice {
|
||||
switch tc := c.ToolChoice.(type) {
|
||||
case nil, string:
|
||||
return nil
|
||||
case *toolChoice:
|
||||
return tc
|
||||
}
|
||||
|
||||
body, err := json.Marshal(c.ToolChoice)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var parsed toolChoice
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil
|
||||
}
|
||||
return &parsed
|
||||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
@@ -474,7 +492,7 @@ type toolCall struct {
|
||||
|
||||
type functionCall struct {
|
||||
Id string `json:"id,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
|
||||
@@ -1365,6 +1365,177 @@ func RunBedrockToolCallTests(t *testing.T) {
|
||||
require.Equal(t, "call_002", secondResult["toolUseId"])
|
||||
})
|
||||
|
||||
t.Run("bedrock maps any tool choice to converse any tool choice", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Run a search."}],
|
||||
"tool_choice": "any",
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(processedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
toolConfig := bodyMap["toolConfig"].(map[string]interface{})
|
||||
toolChoice := toolConfig["toolChoice"].(map[string]interface{})
|
||||
require.Contains(t, toolChoice, "any")
|
||||
require.Equal(t, map[string]interface{}{}, toolChoice["any"])
|
||||
})
|
||||
|
||||
t.Run("bedrock maps object tool choice to converse tool choice", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Run a search."}],
|
||||
"tool_choice": {"type":"function","function":{"name":"web_search"}},
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(processedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
toolConfig := bodyMap["toolConfig"].(map[string]interface{})
|
||||
toolChoice := toolConfig["toolChoice"].(map[string]interface{})
|
||||
tool := toolChoice["tool"].(map[string]interface{})
|
||||
require.Equal(t, "web_search", tool["name"])
|
||||
require.Len(t, tool, 1, "Bedrock specific tool choice should only include name")
|
||||
})
|
||||
|
||||
t.Run("bedrock maps object any tool choice to converse any tool choice", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Run a search."}],
|
||||
"tool_choice": {"type":"any"},
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web.",
|
||||
"parameters": {"type": "object"}
|
||||
}
|
||||
}]
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(host.GetRequestBody(), &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
toolConfig := bodyMap["toolConfig"].(map[string]interface{})
|
||||
toolChoice := toolConfig["toolChoice"].(map[string]interface{})
|
||||
require.Contains(t, toolChoice, "any")
|
||||
require.NotContains(t, toolChoice, "tool")
|
||||
})
|
||||
|
||||
t.Run("bedrock maps none tool choice by omitting tools", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Answer without tools."}],
|
||||
"tool_choice": {"type":"none"},
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web.",
|
||||
"parameters": {"type": "object"}
|
||||
}
|
||||
}]
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(host.GetRequestBody(), &bodyMap)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, bodyMap, "toolConfig")
|
||||
})
|
||||
|
||||
// Test tool call with text content mixed
|
||||
t.Run("bedrock tool call with text content mixed", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
@@ -1669,6 +1840,138 @@ func RunBedrockOnStreamingResponseBodyTests(t *testing.T) {
|
||||
require.Equal(t, "", payload)
|
||||
})
|
||||
|
||||
t.Run("bedrock streaming parallel tool calls should use dense OpenAI indexes", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Run three independent searches."}],
|
||||
"stream": true,
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
action = host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/vnd.amazon.eventstream"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
extractFirstToolCallIndex := func(body []byte) int {
|
||||
dataPayload := extractFirstDataPayload(body)
|
||||
require.NotEmpty(t, dataPayload, "streaming chunk should contain one SSE data payload")
|
||||
|
||||
var responseMap map[string]interface{}
|
||||
err := json.Unmarshal([]byte(dataPayload), &responseMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
choices := responseMap["choices"].([]interface{})
|
||||
require.Len(t, choices, 1, "streaming chunk should contain one choice")
|
||||
delta := choices[0].(map[string]interface{})["delta"].(map[string]interface{})
|
||||
toolCalls := delta["tool_calls"].([]interface{})
|
||||
require.Len(t, toolCalls, 1, "streaming chunk should contain one tool call delta")
|
||||
|
||||
index, ok := toolCalls[0].(map[string]interface{})["index"].(float64)
|
||||
require.True(t, ok, "tool call delta should include an index")
|
||||
return int(index)
|
||||
}
|
||||
extractFirstToolCall := func(body []byte) map[string]interface{} {
|
||||
dataPayload := extractFirstDataPayload(body)
|
||||
require.NotEmpty(t, dataPayload, "streaming chunk should contain one SSE data payload")
|
||||
|
||||
var responseMap map[string]interface{}
|
||||
err := json.Unmarshal([]byte(dataPayload), &responseMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
choices := responseMap["choices"].([]interface{})
|
||||
require.Len(t, choices, 1, "streaming chunk should contain one choice")
|
||||
delta := choices[0].(map[string]interface{})["delta"].(map[string]interface{})
|
||||
toolCalls := delta["tool_calls"].([]interface{})
|
||||
require.Len(t, toolCalls, 1, "streaming chunk should contain one tool call delta")
|
||||
|
||||
toolCall, ok := toolCalls[0].(map[string]interface{})
|
||||
require.True(t, ok, "tool call delta should be an object")
|
||||
return toolCall
|
||||
}
|
||||
|
||||
for expectedIndex, item := range []struct {
|
||||
contentBlockIndex int
|
||||
toolUseId string
|
||||
}{
|
||||
{contentBlockIndex: 1, toolUseId: "tooluse_first"},
|
||||
{contentBlockIndex: 3, toolUseId: "tooluse_second"},
|
||||
{contentBlockIndex: 4, toolUseId: "tooluse_third"},
|
||||
} {
|
||||
toolCallStart := buildBedrockEventStreamMessage(t, map[string]interface{}{
|
||||
"contentBlockIndex": item.contentBlockIndex,
|
||||
"start": map[string]interface{}{
|
||||
"toolUse": map[string]interface{}{
|
||||
"toolUseId": item.toolUseId,
|
||||
"name": "web_search",
|
||||
},
|
||||
},
|
||||
})
|
||||
action = host.CallOnHttpStreamingResponseBody(toolCallStart, false)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
toolCall := extractFirstToolCall(host.GetResponseBody())
|
||||
require.Equal(t, expectedIndex, extractFirstToolCallIndex(host.GetResponseBody()))
|
||||
require.Equal(t, item.toolUseId, toolCall["id"])
|
||||
require.Equal(t, "function", toolCall["type"])
|
||||
function := toolCall["function"].(map[string]interface{})
|
||||
require.Equal(t, "web_search", function["name"])
|
||||
require.Equal(t, "", function["arguments"])
|
||||
}
|
||||
|
||||
for expectedIndex, item := range []struct {
|
||||
contentBlockIndex int
|
||||
query string
|
||||
}{
|
||||
{contentBlockIndex: 1, query: "first synthetic query"},
|
||||
{contentBlockIndex: 3, query: "second synthetic query"},
|
||||
{contentBlockIndex: 4, query: "third synthetic query"},
|
||||
} {
|
||||
toolCallDelta := buildBedrockEventStreamMessage(t, map[string]interface{}{
|
||||
"contentBlockIndex": item.contentBlockIndex,
|
||||
"delta": map[string]interface{}{
|
||||
"toolUse": map[string]interface{}{
|
||||
"input": "{\"query\":\"" + item.query + "\"}",
|
||||
},
|
||||
},
|
||||
})
|
||||
action = host.CallOnHttpStreamingResponseBody(toolCallDelta, false)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
toolCall := extractFirstToolCall(host.GetResponseBody())
|
||||
require.Equal(t, expectedIndex, extractFirstToolCallIndex(host.GetResponseBody()))
|
||||
require.NotContains(t, toolCall, "id")
|
||||
function := toolCall["function"].(map[string]interface{})
|
||||
require.NotContains(t, function, "name")
|
||||
require.Equal(t, "{\"query\":\""+item.query+"\"}", function["arguments"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bedrock streaming usage should map cached_tokens", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
|
||||
Reference in New Issue
Block a user