From 890a802481e55463d16f91cac3c6ee90de830c08 Mon Sep 17 00:00:00 2001 From: rinfx Date: Tue, 19 Aug 2025 16:54:50 +0800 Subject: [PATCH] feat(ai-proxy): bedrock support tool use (#2730) --- .../extensions/ai-proxy/provider/bedrock.go | 206 ++++++++++++++++-- .../extensions/ai-proxy/provider/model.go | 7 +- .../extensions/ai-proxy/provider/provider.go | 6 +- 3 files changed, 194 insertions(+), 25 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go index fee47046e..c2ebbf873 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go @@ -96,8 +96,31 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex if bedrockEvent.Role != nil { chatChoice.Delta.Role = *bedrockEvent.Role } + if bedrockEvent.Start != nil { + chatChoice.Delta.Content = nil + chatChoice.Delta.ToolCalls = []toolCall{ + { + Id: bedrockEvent.Start.ToolUse.ToolUseID, + Type: "function", + Function: functionCall{ + Name: bedrockEvent.Start.ToolUse.Name, + Arguments: "", + }, + }, + } + } if bedrockEvent.Delta != nil { chatChoice.Delta = &chatMessage{Content: bedrockEvent.Delta.Text} + if bedrockEvent.Delta.ToolUse != nil { + chatChoice.Delta.ToolCalls = []toolCall{ + { + Type: "function", + Function: functionCall{ + Arguments: bedrockEvent.Delta.ToolUse.Input, + }, + }, + } + } } if bedrockEvent.StopReason != nil { chatChoice.FinishReason = util.Ptr(stopReasonBedrock2OpenAI(*bedrockEvent.StopReason)) @@ -700,9 +723,12 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom systemMessages := make([]systemContentBlock, 0) for _, msg := range origRequest.Messages { - if msg.Role == roleSystem { + switch msg.Role { + case roleSystem: systemMessages = append(systemMessages, systemContentBlock{Text: msg.StringContent()}) - } else { + case roleTool: + messages = append(messages, chatToolMessage2BedrockMessage(msg)) + default: messages = append(messages, chatMessage2BedrockMessage(msg)) } } @@ -721,6 +747,36 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom }, } + if origRequest.Tools != nil { + request.ToolConfig = &bedrockToolConfig{} + if origRequest.ToolChoice == nil { + request.ToolConfig.ToolChoice.Auto = &struct{}{} + } else if choice_type, ok := origRequest.ToolChoice.(string); ok { + switch choice_type { + case "required": + request.ToolConfig.ToolChoice.Any = &struct{}{} + case "auto": + request.ToolConfig.ToolChoice.Auto = &struct{}{} + case "none": + request.ToolConfig.ToolChoice.Auto = &struct{}{} + } + } else if choice, ok := origRequest.ToolChoice.(toolChoice); ok { + request.ToolConfig.ToolChoice.Tool = &bedrockToolSpecification{ + Name: choice.Function.Name, + } + } + request.ToolConfig.Tools = []bedrockTool{} + for _, tool := range origRequest.Tools { + request.ToolConfig.Tools = append(request.ToolConfig.Tools, bedrockTool{ + ToolSpec: bedrockToolSpecification{ + InputSchema: bedrockToolInputSchemaJson{Json: tool.Function.Parameters}, + Name: tool.Function.Name, + Description: tool.Function.Description, + }, + }) + } + } + for key, value := range b.config.bedrockAdditionalFields { request.AdditionalModelRequestFields[key] = value } @@ -735,16 +791,29 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b if len(bedrockResponse.Output.Message.Content) > 0 { outputContent = bedrockResponse.Output.Message.Content[0].Text } - choices := []chatCompletionChoice{ - { - Index: 0, - Message: &chatMessage{ - Role: bedrockResponse.Output.Message.Role, - Content: outputContent, - }, - FinishReason: util.Ptr(stopReasonBedrock2OpenAI(bedrockResponse.StopReason)), + choice := chatCompletionChoice{ + Index: 0, + Message: &chatMessage{ + Role: bedrockResponse.Output.Message.Role, + Content: outputContent, }, + FinishReason: util.Ptr(stopReasonBedrock2OpenAI(bedrockResponse.StopReason)), } + choice.Message.ToolCalls = []toolCall{} + for _, content := range bedrockResponse.Output.Message.Content { + if content.ToolUse != nil { + args, _ := json.Marshal(content.ToolUse.Input) + choice.Message.ToolCalls = append(choice.Message.ToolCalls, toolCall{ + Id: content.ToolUse.ToolUseId, + Type: "function", + Function: functionCall{ + Name: content.ToolUse.Name, + Arguments: string(args), + }, + }) + } + } + choices := []chatCompletionChoice{choice} requestId := ctx.GetStringContext(requestIdHeader, "") modelId, _ := url.QueryUnescape(ctx.GetStringContext(ctxKeyFinalRequestModel, "")) return &chatCompletionResponse{ @@ -781,6 +850,8 @@ func stopReasonBedrock2OpenAI(reason string) string { return finishReasonStop case "max_tokens": return finishReasonLength + case "tool_use": + return finishReasonToolCall default: return reason } @@ -792,20 +863,48 @@ type bedrockTextGenRequest struct { InferenceConfig bedrockInferenceConfig `json:"inferenceConfig,omitempty"` AdditionalModelRequestFields map[string]interface{} `json:"additionalModelRequestFields,omitempty"` PerformanceConfig PerformanceConfiguration `json:"performanceConfig,omitempty"` + ToolConfig *bedrockToolConfig `json:"toolConfig,omitempty"` +} + +type bedrockToolConfig struct { + Tools []bedrockTool `json:"tools,omitempty"` + ToolChoice bedrockToolChoice `json:"toolChoice,omitempty"` } type PerformanceConfiguration struct { Latency string `json:"latency,omitempty"` } +type bedrockTool struct { + ToolSpec bedrockToolSpecification `json:"toolSpec,omitempty"` +} + +type bedrockToolChoice struct { + Any *struct{} `json:"any,omitempty"` + Auto *struct{} `json:"auto,omitempty"` + Tool *bedrockToolSpecification `json:"tool,omitempty"` +} + +type bedrockToolSpecification struct { + InputSchema bedrockToolInputSchemaJson `json:"inputSchema,omitempty"` + Name string `json:"name"` + Description string `json:"description,omitempty"` +} + +type bedrockToolInputSchemaJson struct { + Json map[string]interface{} `json:"json,omitempty"` +} + type bedrockMessage struct { Role string `json:"role"` Content []bedrockMessageContent `json:"content"` } type bedrockMessageContent struct { - Text string `json:"text,omitempty"` - Image *imageBlock `json:"image,omitempty"` + Text string `json:"text,omitempty"` + Image *imageBlock `json:"image,omitempty"` + ToolResult *toolResultBlock `json:"toolResult,omitempty"` + ToolUse *toolUseBlock `json:"toolUse,omitempty"` } type systemContentBlock struct { @@ -821,6 +920,22 @@ type imageSource struct { Bytes string `json:"bytes,omitempty"` } +type toolResultBlock struct { + ToolUseId string `json:"toolUseId"` + Content []toolResultContentBlock `json:"content"` + Status string `json:"status,omitempty"` +} + +type toolResultContentBlock struct { + Text string `json:"text"` +} + +type toolUseBlock struct { + Input map[string]interface{} `json:"input"` + Name string `json:"name"` + ToolUseId string `json:"toolUseId"` +} + type bedrockInferenceConfig struct { StopSequences []string `json:"stopSequences,omitempty"` MaxTokens int `json:"maxTokens,omitempty"` @@ -844,13 +959,19 @@ type converseOutputMemberMessage struct { } type message struct { - Content []contentBlockMemberText `json:"content"` - - Role string `json:"role"` + Content []contentBlock `json:"content"` + Role string `json:"role"` } -type contentBlockMemberText struct { - Text string `json:"text"` +type contentBlock struct { + Text string `json:"text,omitempty"` + ToolUse *bedrockToolUse `json:"toolUse,omitempty"` +} + +type bedrockToolUse struct { + Name string `json:"name"` + ToolUseId string `json:"toolUseId"` + Input map[string]interface{} `json:"input"` } type tokenUsage struct { @@ -861,9 +982,53 @@ type tokenUsage struct { TotalTokens int `json:"totalTokens"` } +func chatToolMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage { + toolResultContent := &toolResultBlock{} + toolResultContent.ToolUseId = chatMessage.ToolCallId + if text, ok := chatMessage.Content.(string); ok { + toolResultContent.Content = []toolResultContentBlock{ + { + Text: text, + }, + } + openaiContent := chatMessage.ParseContent() + for _, part := range openaiContent { + var content bedrockMessageContent + if part.Type == contentTypeText { + content.Text = part.Text + } else { + continue + } + } + } else { + log.Warnf("only text content is supported, current content is %v", chatMessage.Content) + } + return bedrockMessage{ + Role: roleUser, + Content: []bedrockMessageContent{ + { + ToolResult: toolResultContent, + }, + }, + } +} + func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage { - if chatMessage.IsStringContent() { - return bedrockMessage{ + var result bedrockMessage + if len(chatMessage.ToolCalls) > 0 { + result = bedrockMessage{ + Role: chatMessage.Role, + Content: []bedrockMessageContent{{}}, + } + params := map[string]interface{}{} + json.Unmarshal([]byte(chatMessage.ToolCalls[0].Function.Arguments), ¶ms) + result.Content[0].ToolUse = &toolUseBlock{ + Input: params, + Name: chatMessage.ToolCalls[0].Function.Name, + ToolUseId: chatMessage.ToolCalls[0].Id, + } + } else if chatMessage.IsStringContent() { + result = bedrockMessage{ Role: chatMessage.Role, Content: []bedrockMessageContent{{Text: chatMessage.StringContent()}}, } @@ -880,11 +1045,12 @@ func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage { } contents = append(contents, content) } - return bedrockMessage{ + result = bedrockMessage{ Role: chatMessage.Role, Content: contents, } } + return result } func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index 7c1e2811d..d11d223a8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go @@ -171,6 +171,7 @@ type chatMessage struct { ReasoningContent string `json:"reasoning_content,omitempty"` ToolCalls []toolCall `json:"tool_calls,omitempty"` Refusal string `json:"refusal,omitempty"` + ToolCallId string `json:"tool_call_id,omitempty"` } func (m *chatMessage) handleNonStreamingReasoningContent(reasoningContentMode string) { @@ -377,14 +378,14 @@ func (m *chatMessage) ParseContent() []chatMessageContent { } type toolCall struct { - Index int `json:"index"` - Id string `json:"id"` + Index int `json:"index,omitempty"` + Id string `json:"id,omitempty"` Type string `json:"type"` Function functionCall `json:"function"` } type functionCall struct { - Id string `json:"id"` + Id string `json:"id,omitempty"` Name string `json:"name"` Arguments string `json:"arguments"` } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 2fd24f507..faff0647d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -137,9 +137,11 @@ const ( roleSystem = "system" roleAssistant = "assistant" roleUser = "user" + roleTool = "tool" - finishReasonStop = "stop" - finishReasonLength = "length" + finishReasonStop = "stop" + finishReasonLength = "length" + finishReasonToolCall = "tool_calls" ctxKeyIncrementalStreaming = "incrementalStreaming" ctxKeyApiKey = "apiKey"