mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 20:57:32 +08:00
vertex support multi-modal, function call and thinking (#2926)
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -30,6 +31,8 @@ const (
|
|||||||
vertexChatCompletionAction = "generateContent"
|
vertexChatCompletionAction = "generateContent"
|
||||||
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
|
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
|
||||||
vertexEmbeddingAction = "predict"
|
vertexEmbeddingAction = "predict"
|
||||||
|
reasoningContextMarkerStart = "<think>"
|
||||||
|
reasoningContextMarkerEnd = "</think>"
|
||||||
)
|
)
|
||||||
|
|
||||||
type vertexProviderInitializer struct{}
|
type vertexProviderInitializer struct{}
|
||||||
@@ -188,7 +191,10 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
|
|||||||
|
|
||||||
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
||||||
log.Infof("[vertexProvider] receive chunk body: %s", string(chunk))
|
log.Infof("[vertexProvider] receive chunk body: %s", string(chunk))
|
||||||
if isLastChunk || len(chunk) == 0 {
|
if isLastChunk {
|
||||||
|
return []byte(ssePrefix + "[DONE]\n\n"), nil
|
||||||
|
}
|
||||||
|
if len(chunk) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
if name != ApiNameChatCompletion {
|
if name != ApiNameChatCompletion {
|
||||||
@@ -259,7 +265,23 @@ func (v *vertexProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, re
|
|||||||
FinishReason: util.Ptr(candidate.FinishReason),
|
FinishReason: util.Ptr(candidate.FinishReason),
|
||||||
}
|
}
|
||||||
if len(candidate.Content.Parts) > 0 {
|
if len(candidate.Content.Parts) > 0 {
|
||||||
choice.Message.Content = candidate.Content.Parts[0].Text
|
part := candidate.Content.Parts[0]
|
||||||
|
if part.FunctionCall != nil {
|
||||||
|
args, _ := json.Marshal(part.FunctionCall.Args)
|
||||||
|
choice.Message.ToolCalls = []toolCall{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: functionCall{
|
||||||
|
Name: part.FunctionCall.Name,
|
||||||
|
Arguments: string(args),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else if part.Thounght != nil && len(candidate.Content.Parts) > 1 {
|
||||||
|
choice.Message.Content = reasoningContextMarkerStart + part.Text + reasoningContextMarkerEnd + candidate.Content.Parts[1].Text
|
||||||
|
} else if part.Text != "" {
|
||||||
|
choice.Message.Content = part.Text
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
choice.Message.Content = ""
|
choice.Message.Content = ""
|
||||||
}
|
}
|
||||||
@@ -301,7 +323,35 @@ func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertex
|
|||||||
func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
|
func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
|
||||||
var choice chatCompletionChoice
|
var choice chatCompletionChoice
|
||||||
if len(vertexResp.Candidates) > 0 && len(vertexResp.Candidates[0].Content.Parts) > 0 {
|
if len(vertexResp.Candidates) > 0 && len(vertexResp.Candidates[0].Content.Parts) > 0 {
|
||||||
choice.Delta = &chatMessage{Content: vertexResp.Candidates[0].Content.Parts[0].Text}
|
part := vertexResp.Candidates[0].Content.Parts[0]
|
||||||
|
if part.FunctionCall != nil {
|
||||||
|
args, _ := json.Marshal(part.FunctionCall.Args)
|
||||||
|
choice.Delta = &chatMessage{
|
||||||
|
ToolCalls: []toolCall{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: functionCall{
|
||||||
|
Name: part.FunctionCall.Name,
|
||||||
|
Arguments: string(args),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else if part.Thounght != nil {
|
||||||
|
if ctx.GetContext("thinking_start") == nil {
|
||||||
|
choice.Delta = &chatMessage{Content: reasoningContextMarkerStart + part.Text}
|
||||||
|
ctx.SetContext("thinking_start", true)
|
||||||
|
} else {
|
||||||
|
choice.Delta = &chatMessage{Content: part.Text}
|
||||||
|
}
|
||||||
|
} else if part.Text != "" {
|
||||||
|
if ctx.GetContext("thinking_start") != nil && ctx.GetContext("thinking_end") == nil {
|
||||||
|
choice.Delta = &chatMessage{Content: reasoningContextMarkerEnd + part.Text}
|
||||||
|
ctx.SetContext("thinking_end", true)
|
||||||
|
} else {
|
||||||
|
choice.Delta = &chatMessage{Content: part.Text}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
streamResponse := chatCompletionResponse{
|
streamResponse := chatCompletionResponse{
|
||||||
Id: vertexResp.ResponseId,
|
Id: vertexResp.ResponseId,
|
||||||
@@ -351,6 +401,21 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest)
|
|||||||
MaxOutputTokens: request.MaxTokens,
|
MaxOutputTokens: request.MaxTokens,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
if request.ReasoningEffort != "" {
|
||||||
|
thinkingBudget := 1024 // default
|
||||||
|
switch request.ReasoningEffort {
|
||||||
|
case "low":
|
||||||
|
thinkingBudget = 1024
|
||||||
|
case "medium":
|
||||||
|
thinkingBudget = 4096
|
||||||
|
case "high":
|
||||||
|
thinkingBudget = 16384
|
||||||
|
}
|
||||||
|
vertexRequest.GenerationConfig.ThinkingConfig = vertexThinkingConfig{
|
||||||
|
IncludeThoughts: true,
|
||||||
|
ThinkingBudget: thinkingBudget,
|
||||||
|
}
|
||||||
|
}
|
||||||
if request.Tools != nil {
|
if request.Tools != nil {
|
||||||
functions := make([]function, 0, len(request.Tools))
|
functions := make([]function, 0, len(request.Tools))
|
||||||
for _, tool := range request.Tools {
|
for _, tool := range request.Tools {
|
||||||
@@ -363,20 +428,60 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
shouldAddDummyModelMessage := false
|
shouldAddDummyModelMessage := false
|
||||||
|
var lastFunctionName string
|
||||||
for _, message := range request.Messages {
|
for _, message := range request.Messages {
|
||||||
content := vertexChatContent{
|
content := vertexChatContent{
|
||||||
Role: message.Role,
|
Role: message.Role,
|
||||||
Parts: []vertexPart{
|
Parts: []vertexPart{},
|
||||||
{
|
}
|
||||||
Text: message.StringContent(),
|
if len(message.ToolCalls) > 0 {
|
||||||
|
lastFunctionName = message.ToolCalls[0].Function.Name
|
||||||
|
args := make(map[string]interface{})
|
||||||
|
if err := json.Unmarshal([]byte(message.ToolCalls[0].Function.Arguments), &args); err != nil {
|
||||||
|
log.Errorf("unable to unmarshal function arguments: %v", err)
|
||||||
|
}
|
||||||
|
content.Parts = append(content.Parts, vertexPart{
|
||||||
|
FunctionCall: &vertexFunctionCall{
|
||||||
|
Name: lastFunctionName,
|
||||||
|
Args: args,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
for _, part := range message.ParseContent() {
|
||||||
|
switch part.Type {
|
||||||
|
case contentTypeText:
|
||||||
|
if message.Role == roleTool {
|
||||||
|
content.Parts = append(content.Parts, vertexPart{
|
||||||
|
FunctionResponse: &vertexFunctionResponse{
|
||||||
|
Name: lastFunctionName,
|
||||||
|
Response: vertexFunctionResponseDetail{
|
||||||
|
Output: part.Text,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
content.Parts = append(content.Parts, vertexPart{
|
||||||
|
Text: part.Text,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case contentTypeImageUrl:
|
||||||
|
vpart, err := convertImageContent(part.ImageUrl.Url)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to convert image content: %v", err)
|
||||||
|
} else {
|
||||||
|
content.Parts = append(content.Parts, vpart)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// there's no assistant role in vertex and API shall vomit if role is not user or model
|
// there's no assistant role in vertex and API shall vomit if role is not user or model
|
||||||
if content.Role == roleAssistant {
|
switch content.Role {
|
||||||
|
case roleAssistant:
|
||||||
content.Role = "model"
|
content.Role = "model"
|
||||||
} else if content.Role == roleSystem { // converting system prompt to prompt from user for the same reason
|
case roleTool:
|
||||||
|
content.Role = roleUser
|
||||||
|
case roleSystem: // converting system prompt to prompt from user for the same reason
|
||||||
content.Role = roleUser
|
content.Role = roleUser
|
||||||
shouldAddDummyModelMessage = true
|
shouldAddDummyModelMessage = true
|
||||||
}
|
}
|
||||||
@@ -430,6 +535,9 @@ type vertexPart struct {
|
|||||||
Text string `json:"text,omitempty"`
|
Text string `json:"text,omitempty"`
|
||||||
InlineData *blob `json:"inlineData,omitempty"`
|
InlineData *blob `json:"inlineData,omitempty"`
|
||||||
FileData *fileData `json:"fileData,omitempty"`
|
FileData *fileData `json:"fileData,omitempty"`
|
||||||
|
FunctionCall *vertexFunctionCall `json:"functionCall,omitempty"`
|
||||||
|
FunctionResponse *vertexFunctionResponse `json:"functionResponse,omitempty"`
|
||||||
|
Thounght *bool `json:"thought,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type blob struct {
|
type blob struct {
|
||||||
@@ -442,6 +550,21 @@ type fileData struct {
|
|||||||
FileUri string `json:"fileUri"`
|
FileUri string `json:"fileUri"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// vertexFunctionCall is the JSON shape of a "functionCall" part: the name of
// the function being invoked together with its arguments.
type vertexFunctionCall struct {
	// Name is the function name; it is always serialized.
	Name string `json:"name"`

	// Args holds the call arguments as a decoded JSON object. The key is
	// omitted from the output when the map is nil or empty.
	Args map[string]interface{} `json:"args,omitempty"`
}
|
||||||
|
|
||||||
|
type vertexFunctionResponse struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Response vertexFunctionResponseDetail `json:"response"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// vertexFunctionResponseDetail is the payload nested under "response" in a
// functionResponse part. Both fields are optional and are left out of the
// JSON when empty.
type vertexFunctionResponseDetail struct {
	// Output is the textual result of the function.
	Output string `json:"output,omitempty"`

	// Error is an error description.
	// NOTE(review): presumably set when the function call failed — confirm
	// against the callers that populate it.
	Error string `json:"error,omitempty"`
}
|
||||||
|
|
||||||
type vertexSystemInstruction struct {
|
type vertexSystemInstruction struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Parts []vertexPart `json:"parts"`
|
Parts []vertexPart `json:"parts"`
|
||||||
@@ -462,6 +585,12 @@ type vertexChatGenerationConfig struct {
|
|||||||
TopK int `json:"topK,omitempty"`
|
TopK int `json:"topK,omitempty"`
|
||||||
CandidateCount int `json:"candidateCount,omitempty"`
|
CandidateCount int `json:"candidateCount,omitempty"`
|
||||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||||
|
ThinkingConfig vertexThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type vertexThinkingConfig struct {
|
||||||
|
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
||||||
|
ThinkingBudget int `json:"thinkingBudget,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type vertexEmbeddingRequest struct {
|
type vertexEmbeddingRequest struct {
|
||||||
@@ -665,3 +794,33 @@ func setCachedAccessToken(key string, accessToken string, expireTime int64) erro
|
|||||||
|
|
||||||
return proxywasm.SetSharedData(key, data, cas)
|
return proxywasm.SetSharedData(key, data, cas)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func convertImageContent(imageUrl string) (vertexPart, error) {
|
||||||
|
part := vertexPart{}
|
||||||
|
if strings.HasPrefix(imageUrl, "http") {
|
||||||
|
arr := strings.Split(imageUrl, ".")
|
||||||
|
mimeType := "image/" + arr[len(arr)-1]
|
||||||
|
part.FileData = &fileData{
|
||||||
|
MimeType: mimeType,
|
||||||
|
FileUri: imageUrl,
|
||||||
|
}
|
||||||
|
return part, nil
|
||||||
|
} else {
|
||||||
|
re := regexp.MustCompile(`^data:([^;]+);base64,`)
|
||||||
|
matches := re.FindStringSubmatch(imageUrl)
|
||||||
|
if len(matches) < 2 {
|
||||||
|
return part, fmt.Errorf("invalid base64 format")
|
||||||
|
}
|
||||||
|
|
||||||
|
mimeType := matches[1] // e.g. image/png
|
||||||
|
parts := strings.Split(mimeType, "/")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return part, fmt.Errorf("invalid mimeType")
|
||||||
|
}
|
||||||
|
part.InlineData = &blob{
|
||||||
|
MimeType: mimeType,
|
||||||
|
Data: strings.TrimPrefix(imageUrl, matches[0]),
|
||||||
|
}
|
||||||
|
return part, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user