vertex support multi-modal, function call and thinking (#2926)

This commit is contained in:
rinfx
2025-09-18 14:22:22 +08:00
committed by GitHub
parent 78860ce399
commit d7bebf79e1

View File

@@ -11,6 +11,7 @@ import (
"errors"
"fmt"
"net/http"
"regexp"
"strings"
"time"
@@ -30,6 +31,8 @@ const (
vertexChatCompletionAction = "generateContent"
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
vertexEmbeddingAction = "predict"
reasoningContextMarkerStart = "<think>"
reasoningContextMarkerEnd = "</think>"
)
type vertexProviderInitializer struct{}
@@ -188,7 +191,10 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
log.Infof("[vertexProvider] receive chunk body: %s", string(chunk))
if isLastChunk || len(chunk) == 0 {
if isLastChunk {
return []byte(ssePrefix + "[DONE]\n\n"), nil
}
if len(chunk) == 0 {
return nil, nil
}
if name != ApiNameChatCompletion {
@@ -259,7 +265,23 @@ func (v *vertexProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, re
FinishReason: util.Ptr(candidate.FinishReason),
}
if len(candidate.Content.Parts) > 0 {
choice.Message.Content = candidate.Content.Parts[0].Text
part := candidate.Content.Parts[0]
if part.FunctionCall != nil {
args, _ := json.Marshal(part.FunctionCall.Args)
choice.Message.ToolCalls = []toolCall{
{
Type: "function",
Function: functionCall{
Name: part.FunctionCall.Name,
Arguments: string(args),
},
},
}
} else if part.Thounght != nil && len(candidate.Content.Parts) > 1 {
choice.Message.Content = reasoningContextMarkerStart + part.Text + reasoningContextMarkerEnd + candidate.Content.Parts[1].Text
} else if part.Text != "" {
choice.Message.Content = part.Text
}
} else {
choice.Message.Content = ""
}
@@ -301,7 +323,35 @@ func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertex
func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
var choice chatCompletionChoice
if len(vertexResp.Candidates) > 0 && len(vertexResp.Candidates[0].Content.Parts) > 0 {
choice.Delta = &chatMessage{Content: vertexResp.Candidates[0].Content.Parts[0].Text}
part := vertexResp.Candidates[0].Content.Parts[0]
if part.FunctionCall != nil {
args, _ := json.Marshal(part.FunctionCall.Args)
choice.Delta = &chatMessage{
ToolCalls: []toolCall{
{
Type: "function",
Function: functionCall{
Name: part.FunctionCall.Name,
Arguments: string(args),
},
},
},
}
} else if part.Thounght != nil {
if ctx.GetContext("thinking_start") == nil {
choice.Delta = &chatMessage{Content: reasoningContextMarkerStart + part.Text}
ctx.SetContext("thinking_start", true)
} else {
choice.Delta = &chatMessage{Content: part.Text}
}
} else if part.Text != "" {
if ctx.GetContext("thinking_start") != nil && ctx.GetContext("thinking_end") == nil {
choice.Delta = &chatMessage{Content: reasoningContextMarkerEnd + part.Text}
ctx.SetContext("thinking_end", true)
} else {
choice.Delta = &chatMessage{Content: part.Text}
}
}
}
streamResponse := chatCompletionResponse{
Id: vertexResp.ResponseId,
@@ -351,6 +401,21 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest)
MaxOutputTokens: request.MaxTokens,
},
}
if request.ReasoningEffort != "" {
thinkingBudget := 1024 // default
switch request.ReasoningEffort {
case "low":
thinkingBudget = 1024
case "medium":
thinkingBudget = 4096
case "high":
thinkingBudget = 16384
}
vertexRequest.GenerationConfig.ThinkingConfig = vertexThinkingConfig{
IncludeThoughts: true,
ThinkingBudget: thinkingBudget,
}
}
if request.Tools != nil {
functions := make([]function, 0, len(request.Tools))
for _, tool := range request.Tools {
@@ -363,20 +428,60 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest)
}
}
shouldAddDummyModelMessage := false
var lastFunctionName string
for _, message := range request.Messages {
content := vertexChatContent{
Role: message.Role,
Parts: []vertexPart{
{
Text: message.StringContent(),
Role: message.Role,
Parts: []vertexPart{},
}
if len(message.ToolCalls) > 0 {
lastFunctionName = message.ToolCalls[0].Function.Name
args := make(map[string]interface{})
if err := json.Unmarshal([]byte(message.ToolCalls[0].Function.Arguments), &args); err != nil {
log.Errorf("unable to unmarshal function arguments: %v", err)
}
content.Parts = append(content.Parts, vertexPart{
FunctionCall: &vertexFunctionCall{
Name: lastFunctionName,
Args: args,
},
},
})
} else {
for _, part := range message.ParseContent() {
switch part.Type {
case contentTypeText:
if message.Role == roleTool {
content.Parts = append(content.Parts, vertexPart{
FunctionResponse: &vertexFunctionResponse{
Name: lastFunctionName,
Response: vertexFunctionResponseDetail{
Output: part.Text,
},
},
})
} else {
content.Parts = append(content.Parts, vertexPart{
Text: part.Text,
})
}
case contentTypeImageUrl:
vpart, err := convertImageContent(part.ImageUrl.Url)
if err != nil {
log.Errorf("unable to convert image content: %v", err)
} else {
content.Parts = append(content.Parts, vpart)
}
}
}
}
// there's no assistant role in vertex and API shall vomit if role is not user or model
if content.Role == roleAssistant {
switch content.Role {
case roleAssistant:
content.Role = "model"
} else if content.Role == roleSystem { // converting system prompt to prompt from user for the same reason
case roleTool:
content.Role = roleUser
case roleSystem: // converting system prompt to prompt from user for the same reason
content.Role = roleUser
shouldAddDummyModelMessage = true
}
@@ -427,9 +532,12 @@ type vertexChatContent struct {
}
type vertexPart struct {
Text string `json:"text,omitempty"`
InlineData *blob `json:"inlineData,omitempty"`
FileData *fileData `json:"fileData,omitempty"`
Text string `json:"text,omitempty"`
InlineData *blob `json:"inlineData,omitempty"`
FileData *fileData `json:"fileData,omitempty"`
FunctionCall *vertexFunctionCall `json:"functionCall,omitempty"`
FunctionResponse *vertexFunctionResponse `json:"functionResponse,omitempty"`
Thounght *bool `json:"thought,omitempty"`
}
type blob struct {
@@ -442,6 +550,21 @@ type fileData struct {
FileUri string `json:"fileUri"`
}
type vertexFunctionCall struct {
Name string `json:"name"`
Args map[string]interface{} `json:"args,omitempty"`
}
type vertexFunctionResponse struct {
Name string `json:"name"`
Response vertexFunctionResponseDetail `json:"response"`
}
type vertexFunctionResponseDetail struct {
Output string `json:"output,omitempty"`
Error string `json:"error,omitempty"`
}
type vertexSystemInstruction struct {
Role string `json:"role"`
Parts []vertexPart `json:"parts"`
@@ -457,11 +580,17 @@ type vertexChatSafetySetting struct {
}
type vertexChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
ThinkingConfig vertexThinkingConfig `json:"thinkingConfig,omitempty"`
}
type vertexThinkingConfig struct {
IncludeThoughts bool `json:"includeThoughts,omitempty"`
ThinkingBudget int `json:"thinkingBudget,omitempty"`
}
type vertexEmbeddingRequest struct {
@@ -665,3 +794,33 @@ func setCachedAccessToken(key string, accessToken string, expireTime int64) erro
return proxywasm.SetSharedData(key, data, cas)
}
func convertImageContent(imageUrl string) (vertexPart, error) {
part := vertexPart{}
if strings.HasPrefix(imageUrl, "http") {
arr := strings.Split(imageUrl, ".")
mimeType := "image/" + arr[len(arr)-1]
part.FileData = &fileData{
MimeType: mimeType,
FileUri: imageUrl,
}
return part, nil
} else {
re := regexp.MustCompile(`^data:([^;]+);base64,`)
matches := re.FindStringSubmatch(imageUrl)
if len(matches) < 2 {
return part, fmt.Errorf("invalid base64 format")
}
mimeType := matches[1] // e.g. image/png
parts := strings.Split(mimeType, "/")
if len(parts) < 2 {
return part, fmt.Errorf("invalid mimeType")
}
part.InlineData = &blob{
MimeType: mimeType,
Data: strings.TrimPrefix(imageUrl, matches[0]),
}
return part, nil
}
}