feat(ai-proxy): bedrock support tool use (#2730)

This commit is contained in:
rinfx
2025-08-19 16:54:50 +08:00
committed by GitHub
parent bb69a1d50b
commit 890a802481
3 changed files with 194 additions and 25 deletions

View File

@@ -96,8 +96,31 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
if bedrockEvent.Role != nil {
chatChoice.Delta.Role = *bedrockEvent.Role
}
if bedrockEvent.Start != nil {
chatChoice.Delta.Content = nil
chatChoice.Delta.ToolCalls = []toolCall{
{
Id: bedrockEvent.Start.ToolUse.ToolUseID,
Type: "function",
Function: functionCall{
Name: bedrockEvent.Start.ToolUse.Name,
Arguments: "",
},
},
}
}
if bedrockEvent.Delta != nil {
chatChoice.Delta = &chatMessage{Content: bedrockEvent.Delta.Text}
if bedrockEvent.Delta.ToolUse != nil {
chatChoice.Delta.ToolCalls = []toolCall{
{
Type: "function",
Function: functionCall{
Arguments: bedrockEvent.Delta.ToolUse.Input,
},
},
}
}
}
if bedrockEvent.StopReason != nil {
chatChoice.FinishReason = util.Ptr(stopReasonBedrock2OpenAI(*bedrockEvent.StopReason))
@@ -700,9 +723,12 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
systemMessages := make([]systemContentBlock, 0)
for _, msg := range origRequest.Messages {
if msg.Role == roleSystem {
switch msg.Role {
case roleSystem:
systemMessages = append(systemMessages, systemContentBlock{Text: msg.StringContent()})
} else {
case roleTool:
messages = append(messages, chatToolMessage2BedrockMessage(msg))
default:
messages = append(messages, chatMessage2BedrockMessage(msg))
}
}
@@ -721,6 +747,36 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
},
}
if origRequest.Tools != nil {
request.ToolConfig = &bedrockToolConfig{}
if origRequest.ToolChoice == nil {
request.ToolConfig.ToolChoice.Auto = &struct{}{}
} else if choice_type, ok := origRequest.ToolChoice.(string); ok {
switch choice_type {
case "required":
request.ToolConfig.ToolChoice.Any = &struct{}{}
case "auto":
request.ToolConfig.ToolChoice.Auto = &struct{}{}
case "none":
request.ToolConfig.ToolChoice.Auto = &struct{}{}
}
} else if choice, ok := origRequest.ToolChoice.(toolChoice); ok {
request.ToolConfig.ToolChoice.Tool = &bedrockToolSpecification{
Name: choice.Function.Name,
}
}
request.ToolConfig.Tools = []bedrockTool{}
for _, tool := range origRequest.Tools {
request.ToolConfig.Tools = append(request.ToolConfig.Tools, bedrockTool{
ToolSpec: bedrockToolSpecification{
InputSchema: bedrockToolInputSchemaJson{Json: tool.Function.Parameters},
Name: tool.Function.Name,
Description: tool.Function.Description,
},
})
}
}
for key, value := range b.config.bedrockAdditionalFields {
request.AdditionalModelRequestFields[key] = value
}
@@ -735,16 +791,29 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b
if len(bedrockResponse.Output.Message.Content) > 0 {
outputContent = bedrockResponse.Output.Message.Content[0].Text
}
choices := []chatCompletionChoice{
{
Index: 0,
Message: &chatMessage{
Role: bedrockResponse.Output.Message.Role,
Content: outputContent,
},
FinishReason: util.Ptr(stopReasonBedrock2OpenAI(bedrockResponse.StopReason)),
choice := chatCompletionChoice{
Index: 0,
Message: &chatMessage{
Role: bedrockResponse.Output.Message.Role,
Content: outputContent,
},
FinishReason: util.Ptr(stopReasonBedrock2OpenAI(bedrockResponse.StopReason)),
}
choice.Message.ToolCalls = []toolCall{}
for _, content := range bedrockResponse.Output.Message.Content {
if content.ToolUse != nil {
args, _ := json.Marshal(content.ToolUse.Input)
choice.Message.ToolCalls = append(choice.Message.ToolCalls, toolCall{
Id: content.ToolUse.ToolUseId,
Type: "function",
Function: functionCall{
Name: content.ToolUse.Name,
Arguments: string(args),
},
})
}
}
choices := []chatCompletionChoice{choice}
requestId := ctx.GetStringContext(requestIdHeader, "")
modelId, _ := url.QueryUnescape(ctx.GetStringContext(ctxKeyFinalRequestModel, ""))
return &chatCompletionResponse{
@@ -781,6 +850,8 @@ func stopReasonBedrock2OpenAI(reason string) string {
return finishReasonStop
case "max_tokens":
return finishReasonLength
case "tool_use":
return finishReasonToolCall
default:
return reason
}
@@ -792,20 +863,48 @@ type bedrockTextGenRequest struct {
InferenceConfig bedrockInferenceConfig `json:"inferenceConfig,omitempty"`
AdditionalModelRequestFields map[string]interface{} `json:"additionalModelRequestFields,omitempty"`
PerformanceConfig PerformanceConfiguration `json:"performanceConfig,omitempty"`
ToolConfig *bedrockToolConfig `json:"toolConfig,omitempty"`
}
type bedrockToolConfig struct {
Tools []bedrockTool `json:"tools,omitempty"`
ToolChoice bedrockToolChoice `json:"toolChoice,omitempty"`
}
type PerformanceConfiguration struct {
Latency string `json:"latency,omitempty"`
}
type bedrockTool struct {
ToolSpec bedrockToolSpecification `json:"toolSpec,omitempty"`
}
type bedrockToolChoice struct {
Any *struct{} `json:"any,omitempty"`
Auto *struct{} `json:"auto,omitempty"`
Tool *bedrockToolSpecification `json:"tool,omitempty"`
}
type bedrockToolSpecification struct {
InputSchema bedrockToolInputSchemaJson `json:"inputSchema,omitempty"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
}
type bedrockToolInputSchemaJson struct {
Json map[string]interface{} `json:"json,omitempty"`
}
type bedrockMessage struct {
Role string `json:"role"`
Content []bedrockMessageContent `json:"content"`
}
type bedrockMessageContent struct {
Text string `json:"text,omitempty"`
Image *imageBlock `json:"image,omitempty"`
Text string `json:"text,omitempty"`
Image *imageBlock `json:"image,omitempty"`
ToolResult *toolResultBlock `json:"toolResult,omitempty"`
ToolUse *toolUseBlock `json:"toolUse,omitempty"`
}
type systemContentBlock struct {
@@ -821,6 +920,22 @@ type imageSource struct {
Bytes string `json:"bytes,omitempty"`
}
type toolResultBlock struct {
ToolUseId string `json:"toolUseId"`
Content []toolResultContentBlock `json:"content"`
Status string `json:"status,omitempty"`
}
type toolResultContentBlock struct {
Text string `json:"text"`
}
type toolUseBlock struct {
Input map[string]interface{} `json:"input"`
Name string `json:"name"`
ToolUseId string `json:"toolUseId"`
}
type bedrockInferenceConfig struct {
StopSequences []string `json:"stopSequences,omitempty"`
MaxTokens int `json:"maxTokens,omitempty"`
@@ -844,13 +959,19 @@ type converseOutputMemberMessage struct {
}
type message struct {
Content []contentBlockMemberText `json:"content"`
Role string `json:"role"`
Content []contentBlock `json:"content"`
Role string `json:"role"`
}
type contentBlockMemberText struct {
Text string `json:"text"`
type contentBlock struct {
Text string `json:"text,omitempty"`
ToolUse *bedrockToolUse `json:"toolUse,omitempty"`
}
type bedrockToolUse struct {
Name string `json:"name"`
ToolUseId string `json:"toolUseId"`
Input map[string]interface{} `json:"input"`
}
type tokenUsage struct {
@@ -861,9 +982,53 @@ type tokenUsage struct {
TotalTokens int `json:"totalTokens"`
}
func chatToolMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
toolResultContent := &toolResultBlock{}
toolResultContent.ToolUseId = chatMessage.ToolCallId
if text, ok := chatMessage.Content.(string); ok {
toolResultContent.Content = []toolResultContentBlock{
{
Text: text,
},
}
openaiContent := chatMessage.ParseContent()
for _, part := range openaiContent {
var content bedrockMessageContent
if part.Type == contentTypeText {
content.Text = part.Text
} else {
continue
}
}
} else {
log.Warnf("only text content is supported, current content is %v", chatMessage.Content)
}
return bedrockMessage{
Role: roleUser,
Content: []bedrockMessageContent{
{
ToolResult: toolResultContent,
},
},
}
}
func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
if chatMessage.IsStringContent() {
return bedrockMessage{
var result bedrockMessage
if len(chatMessage.ToolCalls) > 0 {
result = bedrockMessage{
Role: chatMessage.Role,
Content: []bedrockMessageContent{{}},
}
params := map[string]interface{}{}
json.Unmarshal([]byte(chatMessage.ToolCalls[0].Function.Arguments), &params)
result.Content[0].ToolUse = &toolUseBlock{
Input: params,
Name: chatMessage.ToolCalls[0].Function.Name,
ToolUseId: chatMessage.ToolCalls[0].Id,
}
} else if chatMessage.IsStringContent() {
result = bedrockMessage{
Role: chatMessage.Role,
Content: []bedrockMessageContent{{Text: chatMessage.StringContent()}},
}
@@ -880,11 +1045,12 @@ func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
}
contents = append(contents, content)
}
return bedrockMessage{
result = bedrockMessage{
Role: chatMessage.Role,
Content: contents,
}
}
return result
}
func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) {

View File

@@ -171,6 +171,7 @@ type chatMessage struct {
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls []toolCall `json:"tool_calls,omitempty"`
Refusal string `json:"refusal,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
}
func (m *chatMessage) handleNonStreamingReasoningContent(reasoningContentMode string) {
@@ -377,14 +378,14 @@ func (m *chatMessage) ParseContent() []chatMessageContent {
}
type toolCall struct {
Index int `json:"index"`
Id string `json:"id"`
Index int `json:"index,omitempty"`
Id string `json:"id,omitempty"`
Type string `json:"type"`
Function functionCall `json:"function"`
}
type functionCall struct {
Id string `json:"id"`
Id string `json:"id,omitempty"`
Name string `json:"name"`
Arguments string `json:"arguments"`
}

View File

@@ -137,9 +137,11 @@ const (
roleSystem = "system"
roleAssistant = "assistant"
roleUser = "user"
roleTool = "tool"
finishReasonStop = "stop"
finishReasonLength = "length"
finishReasonStop = "stop"
finishReasonLength = "length"
finishReasonToolCall = "tool_calls"
ctxKeyIncrementalStreaming = "incrementalStreaming"
ctxKeyApiKey = "apiKey"