[feature] bedrock provider support multimodal and thinking (#2897)

This commit is contained in:
rinfx
2025-09-18 14:22:37 +08:00
committed by GitHub
parent d7bebf79e1
commit caae3ee068

View File

@@ -14,6 +14,7 @@ import (
"io"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
@@ -34,9 +35,11 @@ const (
// converseStream路径 /model/{modelId}/converse-stream
bedrockStreamChatCompletionPath = "/model/%s/converse-stream"
// invoke_model 路径 /model/{modelId}/invoke
bedrockInvokeModelPath = "/model/%s/invoke"
bedrockSignedHeaders = "host;x-amz-date"
requestIdHeader = "X-Amzn-Requestid"
bedrockInvokeModelPath = "/model/%s/invoke"
bedrockSignedHeaders = "host;x-amz-date"
requestIdHeader = "X-Amzn-Requestid"
reasoningContextMarkerStart = "<think>"
reasoningContextMarkerEnd = "</think>"
)
type bedrockProviderInitializer struct{}
@@ -74,6 +77,9 @@ type bedrockProvider struct {
func (b *bedrockProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
events := extractAmazonEventStreamEvents(ctx, chunk)
if len(events) == 0 {
if isLastChunk {
return []byte(ssePrefix + "[DONE]\n\n"), nil
}
return chunk, fmt.Errorf("No events are extracted ")
}
var responseBuilder strings.Builder
@@ -85,6 +91,9 @@ func (b *bedrockProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
}
responseBuilder.WriteString(string(outputEvent))
}
if isLastChunk {
responseBuilder.WriteString(ssePrefix + "[DONE]\n\n")
}
return []byte(responseBuilder.String()), nil
}
@@ -110,7 +119,23 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
}
}
if bedrockEvent.Delta != nil {
chatChoice.Delta = &chatMessage{Content: bedrockEvent.Delta.Text}
if bedrockEvent.Delta.ReasoningContent != nil {
var content string
if ctx.GetContext("thinking_start") == nil {
content += reasoningContextMarkerStart
ctx.SetContext("thinking_start", true)
}
content += bedrockEvent.Delta.ReasoningContent.Text
chatChoice.Delta = &chatMessage{Content: &content}
} else if bedrockEvent.Delta.Text != nil {
var content string
if ctx.GetContext("thinking_start") != nil && ctx.GetContext("thinking_end") == nil {
content += reasoningContextMarkerEnd
ctx.SetContext("thinking_end", true)
}
content += *bedrockEvent.Delta.Text
chatChoice.Delta = &chatMessage{Content: &content}
}
if bedrockEvent.Delta.ToolUse != nil {
chatChoice.Delta.ToolCalls = []toolCall{
{
@@ -162,8 +187,9 @@ type ConverseStreamEvent struct {
}
type converseStreamEventContentBlockDelta struct {
Text *string `json:"text,omitempty"`
ToolUse *toolUseBlockDelta `json:"toolUse,omitempty"`
Text *string `json:"text,omitempty"`
ToolUse *toolUseBlockDelta `json:"toolUse,omitempty"`
ReasoningContent *reasoningContentDelta `json:"reasoningContent,omitempty"`
}
type toolUseBlockStart struct {
@@ -179,6 +205,11 @@ type toolUseBlockDelta struct {
Input string `json:"input"`
}
type reasoningContentDelta struct {
Text string `json:"text,omitempty"`
Signature string `json:"signature,omitempty"`
}
type bedrockImageGenerationResponse struct {
Images []string `json:"images"`
Error string `json:"error"`
@@ -747,6 +778,22 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
},
}
if origRequest.ReasoningEffort != "" {
thinkingBudget := 1024 // default
switch origRequest.ReasoningEffort {
case "low":
thinkingBudget = 1024
case "medium":
thinkingBudget = 4096
case "high":
thinkingBudget = 16384
}
request.AdditionalModelRequestFields["thinking"] = map[string]interface{}{
"type": "enabled",
"budget_tokens": thinkingBudget,
}
}
if origRequest.Tools != nil {
request.ToolConfig = &bedrockToolConfig{}
if origRequest.ToolChoice == nil {
@@ -787,9 +834,19 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
}
func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockConverseResponse) *chatCompletionResponse {
var outputContent string
if len(bedrockResponse.Output.Message.Content) > 0 {
outputContent = bedrockResponse.Output.Message.Content[0].Text
var outputContent, reasoningContent, normalContent string
for _, content := range bedrockResponse.Output.Message.Content {
if content.ReasoningContent != nil {
reasoningContent = content.ReasoningContent.ReasoningText.Text
}
if content.Text != "" {
normalContent = content.Text
}
}
if reasoningContent != "" {
outputContent = reasoningContextMarkerStart + reasoningContent + reasoningContextMarkerEnd + normalContent
} else {
outputContent = normalContent
}
choice := chatCompletionChoice{
Index: 0,
@@ -964,8 +1021,18 @@ type message struct {
}
type contentBlock struct {
Text string `json:"text,omitempty"`
ToolUse *bedrockToolUse `json:"toolUse,omitempty"`
Text string `json:"text,omitempty"`
ToolUse *bedrockToolUse `json:"toolUse,omitempty"`
ReasoningContent *reasoningContent `json:"reasoningContent,omitempty"`
}
type reasoningContent struct {
ReasoningText reasoningText `json:"reasoningText"`
}
type reasoningText struct {
Text string `json:"text,omitempty"`
Signature string `json:"signature,omitempty"`
}
type bedrockToolUse struct {
@@ -1039,8 +1106,22 @@ func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
var content bedrockMessageContent
if part.Type == contentTypeText {
content.Text = part.Text
} else if part.Type == contentTypeImageUrl {
base64Str := part.ImageUrl.Url
prefix, imageType, err := extractImageType(base64Str)
if err != nil {
log.Warn("image url is not supported")
continue
}
base64WoPrefix, _ := strings.CutPrefix(base64Str, prefix)
content.Image = &imageBlock{
Format: imageType,
Source: imageSource{
Bytes: base64WoPrefix,
},
}
} else {
log.Warnf("imageUrl is not supported: %s", part.Type)
log.Warnf("type is not supported: %s", part.Type)
continue
}
contents = append(contents, content)
@@ -1118,3 +1199,18 @@ func hmacHex(key []byte, data string) string {
h.Write([]byte(data))
return hex.EncodeToString(h.Sum(nil))
}
func extractImageType(base64Str string) (string, string, error) {
re := regexp.MustCompile(`^data:([^;]+);base64,`)
matches := re.FindStringSubmatch(base64Str)
if len(matches) < 2 {
return "", "", fmt.Errorf("invalid base64 format")
}
mimeType := matches[1] // e.g. image/png
parts := strings.Split(mimeType, "/")
if len(parts) < 2 {
return "", "", fmt.Errorf("invalid mimeType")
}
return matches[0], parts[1], nil
}