mirror of
https://github.com/alibaba/higress.git
synced 2026-02-06 23:21:08 +08:00
feat(ai-proxy): bedrock support tool use (#2730)
This commit is contained in:
@@ -96,8 +96,31 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
|
||||
if bedrockEvent.Role != nil {
|
||||
chatChoice.Delta.Role = *bedrockEvent.Role
|
||||
}
|
||||
if bedrockEvent.Start != nil {
|
||||
chatChoice.Delta.Content = nil
|
||||
chatChoice.Delta.ToolCalls = []toolCall{
|
||||
{
|
||||
Id: bedrockEvent.Start.ToolUse.ToolUseID,
|
||||
Type: "function",
|
||||
Function: functionCall{
|
||||
Name: bedrockEvent.Start.ToolUse.Name,
|
||||
Arguments: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
if bedrockEvent.Delta != nil {
|
||||
chatChoice.Delta = &chatMessage{Content: bedrockEvent.Delta.Text}
|
||||
if bedrockEvent.Delta.ToolUse != nil {
|
||||
chatChoice.Delta.ToolCalls = []toolCall{
|
||||
{
|
||||
Type: "function",
|
||||
Function: functionCall{
|
||||
Arguments: bedrockEvent.Delta.ToolUse.Input,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
if bedrockEvent.StopReason != nil {
|
||||
chatChoice.FinishReason = util.Ptr(stopReasonBedrock2OpenAI(*bedrockEvent.StopReason))
|
||||
@@ -700,9 +723,12 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
|
||||
systemMessages := make([]systemContentBlock, 0)
|
||||
|
||||
for _, msg := range origRequest.Messages {
|
||||
if msg.Role == roleSystem {
|
||||
switch msg.Role {
|
||||
case roleSystem:
|
||||
systemMessages = append(systemMessages, systemContentBlock{Text: msg.StringContent()})
|
||||
} else {
|
||||
case roleTool:
|
||||
messages = append(messages, chatToolMessage2BedrockMessage(msg))
|
||||
default:
|
||||
messages = append(messages, chatMessage2BedrockMessage(msg))
|
||||
}
|
||||
}
|
||||
@@ -721,6 +747,36 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
|
||||
},
|
||||
}
|
||||
|
||||
if origRequest.Tools != nil {
|
||||
request.ToolConfig = &bedrockToolConfig{}
|
||||
if origRequest.ToolChoice == nil {
|
||||
request.ToolConfig.ToolChoice.Auto = &struct{}{}
|
||||
} else if choice_type, ok := origRequest.ToolChoice.(string); ok {
|
||||
switch choice_type {
|
||||
case "required":
|
||||
request.ToolConfig.ToolChoice.Any = &struct{}{}
|
||||
case "auto":
|
||||
request.ToolConfig.ToolChoice.Auto = &struct{}{}
|
||||
case "none":
|
||||
request.ToolConfig.ToolChoice.Auto = &struct{}{}
|
||||
}
|
||||
} else if choice, ok := origRequest.ToolChoice.(toolChoice); ok {
|
||||
request.ToolConfig.ToolChoice.Tool = &bedrockToolSpecification{
|
||||
Name: choice.Function.Name,
|
||||
}
|
||||
}
|
||||
request.ToolConfig.Tools = []bedrockTool{}
|
||||
for _, tool := range origRequest.Tools {
|
||||
request.ToolConfig.Tools = append(request.ToolConfig.Tools, bedrockTool{
|
||||
ToolSpec: bedrockToolSpecification{
|
||||
InputSchema: bedrockToolInputSchemaJson{Json: tool.Function.Parameters},
|
||||
Name: tool.Function.Name,
|
||||
Description: tool.Function.Description,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for key, value := range b.config.bedrockAdditionalFields {
|
||||
request.AdditionalModelRequestFields[key] = value
|
||||
}
|
||||
@@ -735,16 +791,29 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b
|
||||
if len(bedrockResponse.Output.Message.Content) > 0 {
|
||||
outputContent = bedrockResponse.Output.Message.Content[0].Text
|
||||
}
|
||||
choices := []chatCompletionChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: &chatMessage{
|
||||
Role: bedrockResponse.Output.Message.Role,
|
||||
Content: outputContent,
|
||||
},
|
||||
FinishReason: util.Ptr(stopReasonBedrock2OpenAI(bedrockResponse.StopReason)),
|
||||
choice := chatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: &chatMessage{
|
||||
Role: bedrockResponse.Output.Message.Role,
|
||||
Content: outputContent,
|
||||
},
|
||||
FinishReason: util.Ptr(stopReasonBedrock2OpenAI(bedrockResponse.StopReason)),
|
||||
}
|
||||
choice.Message.ToolCalls = []toolCall{}
|
||||
for _, content := range bedrockResponse.Output.Message.Content {
|
||||
if content.ToolUse != nil {
|
||||
args, _ := json.Marshal(content.ToolUse.Input)
|
||||
choice.Message.ToolCalls = append(choice.Message.ToolCalls, toolCall{
|
||||
Id: content.ToolUse.ToolUseId,
|
||||
Type: "function",
|
||||
Function: functionCall{
|
||||
Name: content.ToolUse.Name,
|
||||
Arguments: string(args),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
choices := []chatCompletionChoice{choice}
|
||||
requestId := ctx.GetStringContext(requestIdHeader, "")
|
||||
modelId, _ := url.QueryUnescape(ctx.GetStringContext(ctxKeyFinalRequestModel, ""))
|
||||
return &chatCompletionResponse{
|
||||
@@ -781,6 +850,8 @@ func stopReasonBedrock2OpenAI(reason string) string {
|
||||
return finishReasonStop
|
||||
case "max_tokens":
|
||||
return finishReasonLength
|
||||
case "tool_use":
|
||||
return finishReasonToolCall
|
||||
default:
|
||||
return reason
|
||||
}
|
||||
@@ -792,20 +863,48 @@ type bedrockTextGenRequest struct {
|
||||
InferenceConfig bedrockInferenceConfig `json:"inferenceConfig,omitempty"`
|
||||
AdditionalModelRequestFields map[string]interface{} `json:"additionalModelRequestFields,omitempty"`
|
||||
PerformanceConfig PerformanceConfiguration `json:"performanceConfig,omitempty"`
|
||||
ToolConfig *bedrockToolConfig `json:"toolConfig,omitempty"`
|
||||
}
|
||||
|
||||
type bedrockToolConfig struct {
|
||||
Tools []bedrockTool `json:"tools,omitempty"`
|
||||
ToolChoice bedrockToolChoice `json:"toolChoice,omitempty"`
|
||||
}
|
||||
|
||||
type PerformanceConfiguration struct {
|
||||
Latency string `json:"latency,omitempty"`
|
||||
}
|
||||
|
||||
type bedrockTool struct {
|
||||
ToolSpec bedrockToolSpecification `json:"toolSpec,omitempty"`
|
||||
}
|
||||
|
||||
type bedrockToolChoice struct {
|
||||
Any *struct{} `json:"any,omitempty"`
|
||||
Auto *struct{} `json:"auto,omitempty"`
|
||||
Tool *bedrockToolSpecification `json:"tool,omitempty"`
|
||||
}
|
||||
|
||||
type bedrockToolSpecification struct {
|
||||
InputSchema bedrockToolInputSchemaJson `json:"inputSchema,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
type bedrockToolInputSchemaJson struct {
|
||||
Json map[string]interface{} `json:"json,omitempty"`
|
||||
}
|
||||
|
||||
type bedrockMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content []bedrockMessageContent `json:"content"`
|
||||
}
|
||||
|
||||
type bedrockMessageContent struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Image *imageBlock `json:"image,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Image *imageBlock `json:"image,omitempty"`
|
||||
ToolResult *toolResultBlock `json:"toolResult,omitempty"`
|
||||
ToolUse *toolUseBlock `json:"toolUse,omitempty"`
|
||||
}
|
||||
|
||||
type systemContentBlock struct {
|
||||
@@ -821,6 +920,22 @@ type imageSource struct {
|
||||
Bytes string `json:"bytes,omitempty"`
|
||||
}
|
||||
|
||||
type toolResultBlock struct {
|
||||
ToolUseId string `json:"toolUseId"`
|
||||
Content []toolResultContentBlock `json:"content"`
|
||||
Status string `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
type toolResultContentBlock struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type toolUseBlock struct {
|
||||
Input map[string]interface{} `json:"input"`
|
||||
Name string `json:"name"`
|
||||
ToolUseId string `json:"toolUseId"`
|
||||
}
|
||||
|
||||
type bedrockInferenceConfig struct {
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
MaxTokens int `json:"maxTokens,omitempty"`
|
||||
@@ -844,13 +959,19 @@ type converseOutputMemberMessage struct {
|
||||
}
|
||||
|
||||
type message struct {
|
||||
Content []contentBlockMemberText `json:"content"`
|
||||
|
||||
Role string `json:"role"`
|
||||
Content []contentBlock `json:"content"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type contentBlockMemberText struct {
|
||||
Text string `json:"text"`
|
||||
type contentBlock struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
ToolUse *bedrockToolUse `json:"toolUse,omitempty"`
|
||||
}
|
||||
|
||||
type bedrockToolUse struct {
|
||||
Name string `json:"name"`
|
||||
ToolUseId string `json:"toolUseId"`
|
||||
Input map[string]interface{} `json:"input"`
|
||||
}
|
||||
|
||||
type tokenUsage struct {
|
||||
@@ -861,9 +982,53 @@ type tokenUsage struct {
|
||||
TotalTokens int `json:"totalTokens"`
|
||||
}
|
||||
|
||||
func chatToolMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
|
||||
toolResultContent := &toolResultBlock{}
|
||||
toolResultContent.ToolUseId = chatMessage.ToolCallId
|
||||
if text, ok := chatMessage.Content.(string); ok {
|
||||
toolResultContent.Content = []toolResultContentBlock{
|
||||
{
|
||||
Text: text,
|
||||
},
|
||||
}
|
||||
openaiContent := chatMessage.ParseContent()
|
||||
for _, part := range openaiContent {
|
||||
var content bedrockMessageContent
|
||||
if part.Type == contentTypeText {
|
||||
content.Text = part.Text
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Warnf("only text content is supported, current content is %v", chatMessage.Content)
|
||||
}
|
||||
return bedrockMessage{
|
||||
Role: roleUser,
|
||||
Content: []bedrockMessageContent{
|
||||
{
|
||||
ToolResult: toolResultContent,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
|
||||
if chatMessage.IsStringContent() {
|
||||
return bedrockMessage{
|
||||
var result bedrockMessage
|
||||
if len(chatMessage.ToolCalls) > 0 {
|
||||
result = bedrockMessage{
|
||||
Role: chatMessage.Role,
|
||||
Content: []bedrockMessageContent{{}},
|
||||
}
|
||||
params := map[string]interface{}{}
|
||||
json.Unmarshal([]byte(chatMessage.ToolCalls[0].Function.Arguments), ¶ms)
|
||||
result.Content[0].ToolUse = &toolUseBlock{
|
||||
Input: params,
|
||||
Name: chatMessage.ToolCalls[0].Function.Name,
|
||||
ToolUseId: chatMessage.ToolCalls[0].Id,
|
||||
}
|
||||
} else if chatMessage.IsStringContent() {
|
||||
result = bedrockMessage{
|
||||
Role: chatMessage.Role,
|
||||
Content: []bedrockMessageContent{{Text: chatMessage.StringContent()}},
|
||||
}
|
||||
@@ -880,11 +1045,12 @@ func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
|
||||
}
|
||||
contents = append(contents, content)
|
||||
}
|
||||
return bedrockMessage{
|
||||
result = bedrockMessage{
|
||||
Role: chatMessage.Role,
|
||||
Content: contents,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) {
|
||||
|
||||
@@ -171,6 +171,7 @@ type chatMessage struct {
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []toolCall `json:"tool_calls,omitempty"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
func (m *chatMessage) handleNonStreamingReasoningContent(reasoningContentMode string) {
|
||||
@@ -377,14 +378,14 @@ func (m *chatMessage) ParseContent() []chatMessageContent {
|
||||
}
|
||||
|
||||
type toolCall struct {
|
||||
Index int `json:"index"`
|
||||
Id string `json:"id"`
|
||||
Index int `json:"index,omitempty"`
|
||||
Id string `json:"id,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Function functionCall `json:"function"`
|
||||
}
|
||||
|
||||
type functionCall struct {
|
||||
Id string `json:"id"`
|
||||
Id string `json:"id,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
@@ -137,9 +137,11 @@ const (
|
||||
roleSystem = "system"
|
||||
roleAssistant = "assistant"
|
||||
roleUser = "user"
|
||||
roleTool = "tool"
|
||||
|
||||
finishReasonStop = "stop"
|
||||
finishReasonLength = "length"
|
||||
finishReasonStop = "stop"
|
||||
finishReasonLength = "length"
|
||||
finishReasonToolCall = "tool_calls"
|
||||
|
||||
ctxKeyIncrementalStreaming = "incrementalStreaming"
|
||||
ctxKeyApiKey = "apiKey"
|
||||
|
||||
Reference in New Issue
Block a user