From caae3ee068dbebd503f9a19743ae5336f58db0ea Mon Sep 17 00:00:00 2001 From: rinfx Date: Thu, 18 Sep 2025 14:22:37 +0800 Subject: [PATCH] [feature] bedrock provider support multimodal and thinking (#2897) --- .../extensions/ai-proxy/provider/bedrock.go | 120 ++++++++++++++++-- 1 file changed, 108 insertions(+), 12 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go index c2ebbf873..c78b2d9ed 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go @@ -14,6 +14,7 @@ import ( "io" "net/http" "net/url" + "regexp" "strconv" "strings" "time" @@ -34,9 +35,11 @@ const ( // converseStream路径 /model/{modelId}/converse-stream bedrockStreamChatCompletionPath = "/model/%s/converse-stream" // invoke_model 路径 /model/{modelId}/invoke - bedrockInvokeModelPath = "/model/%s/invoke" - bedrockSignedHeaders = "host;x-amz-date" - requestIdHeader = "X-Amzn-Requestid" + bedrockInvokeModelPath = "/model/%s/invoke" + bedrockSignedHeaders = "host;x-amz-date" + requestIdHeader = "X-Amzn-Requestid" + reasoningContextMarkerStart = "" + reasoningContextMarkerEnd = "" ) type bedrockProviderInitializer struct{} @@ -74,6 +77,9 @@ type bedrockProvider struct { func (b *bedrockProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) { events := extractAmazonEventStreamEvents(ctx, chunk) if len(events) == 0 { + if isLastChunk { + return []byte(ssePrefix + "[DONE]\n\n"), nil + } return chunk, fmt.Errorf("No events are extracted ") } var responseBuilder strings.Builder @@ -85,6 +91,9 @@ func (b *bedrockProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name } responseBuilder.WriteString(string(outputEvent)) } + if isLastChunk { + responseBuilder.WriteString(ssePrefix + "[DONE]\n\n") + } return []byte(responseBuilder.String()), nil } @@ -110,7 +119,23 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex } } if bedrockEvent.Delta != nil { - chatChoice.Delta = &chatMessage{Content: bedrockEvent.Delta.Text} + if bedrockEvent.Delta.ReasoningContent != nil { + var content string + if ctx.GetContext("thinking_start") == nil { + content += reasoningContextMarkerStart + ctx.SetContext("thinking_start", true) + } + content += bedrockEvent.Delta.ReasoningContent.Text + chatChoice.Delta = &chatMessage{Content: &content} + } else if bedrockEvent.Delta.Text != nil { + var content string + if ctx.GetContext("thinking_start") != nil && ctx.GetContext("thinking_end") == nil { + content += reasoningContextMarkerEnd + ctx.SetContext("thinking_end", true) + } + content += *bedrockEvent.Delta.Text + chatChoice.Delta = &chatMessage{Content: &content} + } if bedrockEvent.Delta.ToolUse != nil { chatChoice.Delta.ToolCalls = []toolCall{ { @@ -162,8 +187,9 @@ type ConverseStreamEvent struct { } type converseStreamEventContentBlockDelta struct { - Text *string `json:"text,omitempty"` - ToolUse *toolUseBlockDelta `json:"toolUse,omitempty"` + Text *string `json:"text,omitempty"` + ToolUse *toolUseBlockDelta `json:"toolUse,omitempty"` + ReasoningContent *reasoningContentDelta `json:"reasoningContent,omitempty"` } type toolUseBlockStart struct { @@ -179,6 +205,11 @@ type toolUseBlockDelta struct { Input string `json:"input"` } +type reasoningContentDelta struct { + Text string `json:"text,omitempty"` + Signature string `json:"signature,omitempty"` +} + type bedrockImageGenerationResponse struct { Images []string `json:"images"` Error string `json:"error"` @@ -747,6 +778,22 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom }, } + if origRequest.ReasoningEffort != "" { + thinkingBudget := 1024 // default + switch origRequest.ReasoningEffort { + case "low": + thinkingBudget = 1024 + case "medium": + thinkingBudget = 4096 + case "high": + thinkingBudget = 16384 + } + request.AdditionalModelRequestFields["thinking"] = map[string]interface{}{ + "type": "enabled", + "budget_tokens": thinkingBudget, + } + } + if origRequest.Tools != nil { request.ToolConfig = &bedrockToolConfig{} if origRequest.ToolChoice == nil { @@ -787,9 +834,19 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom } func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockConverseResponse) *chatCompletionResponse { - var outputContent string - if len(bedrockResponse.Output.Message.Content) > 0 { - outputContent = bedrockResponse.Output.Message.Content[0].Text + var outputContent, reasoningContent, normalContent string + for _, content := range bedrockResponse.Output.Message.Content { + if content.ReasoningContent != nil { + reasoningContent = content.ReasoningContent.ReasoningText.Text + } + if content.Text != "" { + normalContent = content.Text + } + } + if reasoningContent != "" { + outputContent = reasoningContextMarkerStart + reasoningContent + reasoningContextMarkerEnd + normalContent + } else { + outputContent = normalContent } choice := chatCompletionChoice{ Index: 0, @@ -964,8 +1021,18 @@ type message struct { } type contentBlock struct { - Text string `json:"text,omitempty"` - ToolUse *bedrockToolUse `json:"toolUse,omitempty"` + Text string `json:"text,omitempty"` + ToolUse *bedrockToolUse `json:"toolUse,omitempty"` + ReasoningContent *reasoningContent `json:"reasoningContent,omitempty"` +} + +type reasoningContent struct { + ReasoningText reasoningText `json:"reasoningText"` +} + +type reasoningText struct { + Text string `json:"text,omitempty"` + Signature string `json:"signature,omitempty"` } type bedrockToolUse struct { @@ -1039,8 +1106,22 @@ func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage { var content bedrockMessageContent if part.Type == contentTypeText { content.Text = part.Text + } else if part.Type == contentTypeImageUrl { + base64Str := part.ImageUrl.Url + prefix, imageType, err := extractImageType(base64Str) + if err != nil { + log.Warn("image url is not supported") + continue + } + base64WoPrefix, _ := strings.CutPrefix(base64Str, prefix) + content.Image = &imageBlock{ + Format: imageType, + Source: imageSource{ + Bytes: base64WoPrefix, + }, + } } else { - log.Warnf("imageUrl is not supported: %s", part.Type) + log.Warnf("type is not supported: %s", part.Type) continue } contents = append(contents, content) @@ -1118,3 +1199,18 @@ func hmacHex(key []byte, data string) string { h.Write([]byte(data)) return hex.EncodeToString(h.Sum(nil)) } + +func extractImageType(base64Str string) (string, string, error) { + re := regexp.MustCompile(`^data:([^;]+);base64,`) + matches := re.FindStringSubmatch(base64Str) + if len(matches) < 2 { + return "", "", fmt.Errorf("invalid base64 format") + } + + mimeType := matches[1] // e.g. image/png + parts := strings.Split(mimeType, "/") + if len(parts) < 2 { + return "", "", fmt.Errorf("invalid mimeType") + } + return matches[0], parts[1], nil +}