diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go index f27060c9e..e98cb7d54 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "net/http" + "regexp" "strings" "time" @@ -30,6 +31,8 @@ const ( vertexChatCompletionAction = "generateContent" vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse" vertexEmbeddingAction = "predict" + reasoningContextMarkerStart = "" + reasoningContextMarkerEnd = "" ) type vertexProviderInitializer struct{} @@ -188,7 +191,10 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) { log.Infof("[vertexProvider] receive chunk body: %s", string(chunk)) - if isLastChunk || len(chunk) == 0 { + if isLastChunk { + return []byte(ssePrefix + "[DONE]\n\n"), nil + } + if len(chunk) == 0 { return nil, nil } if name != ApiNameChatCompletion { @@ -259,7 +265,23 @@ func (v *vertexProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, re FinishReason: util.Ptr(candidate.FinishReason), } if len(candidate.Content.Parts) > 0 { - choice.Message.Content = candidate.Content.Parts[0].Text + part := candidate.Content.Parts[0] + if part.FunctionCall != nil { + args, _ := json.Marshal(part.FunctionCall.Args) + choice.Message.ToolCalls = []toolCall{ + { + Type: "function", + Function: functionCall{ + Name: part.FunctionCall.Name, + Arguments: string(args), + }, + }, + } + } else if part.Thounght != nil && len(candidate.Content.Parts) > 1 { + choice.Message.Content = reasoningContextMarkerStart + part.Text + reasoningContextMarkerEnd + candidate.Content.Parts[1].Text + } else if part.Text != "" { + choice.Message.Content = part.Text + } } else { choice.Message.Content = "" } @@ -301,7 +323,35 @@ func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertex func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse { var choice chatCompletionChoice if len(vertexResp.Candidates) > 0 && len(vertexResp.Candidates[0].Content.Parts) > 0 { - choice.Delta = &chatMessage{Content: vertexResp.Candidates[0].Content.Parts[0].Text} + part := vertexResp.Candidates[0].Content.Parts[0] + if part.FunctionCall != nil { + args, _ := json.Marshal(part.FunctionCall.Args) + choice.Delta = &chatMessage{ + ToolCalls: []toolCall{ + { + Type: "function", + Function: functionCall{ + Name: part.FunctionCall.Name, + Arguments: string(args), + }, + }, + }, + } + } else if part.Thounght != nil { + if ctx.GetContext("thinking_start") == nil { + choice.Delta = &chatMessage{Content: reasoningContextMarkerStart + part.Text} + ctx.SetContext("thinking_start", true) + } else { + choice.Delta = &chatMessage{Content: part.Text} + } + } else if part.Text != "" { + if ctx.GetContext("thinking_start") != nil && ctx.GetContext("thinking_end") == nil { + choice.Delta = &chatMessage{Content: reasoningContextMarkerEnd + part.Text} + ctx.SetContext("thinking_end", true) + } else { + choice.Delta = &chatMessage{Content: part.Text} + } + } } streamResponse := chatCompletionResponse{ Id: vertexResp.ResponseId, @@ -351,6 +401,21 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) MaxOutputTokens: request.MaxTokens, }, } + if request.ReasoningEffort != "" { + thinkingBudget := 1024 // default + switch request.ReasoningEffort { + case "low": + thinkingBudget = 1024 + case "medium": + thinkingBudget = 4096 + case "high": + thinkingBudget = 16384 + } + vertexRequest.GenerationConfig.ThinkingConfig = vertexThinkingConfig{ + IncludeThoughts: true, + ThinkingBudget: thinkingBudget, + } + } if request.Tools != nil { functions := make([]function, 0, len(request.Tools)) for _, tool := range request.Tools { @@ -363,20 +428,60 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) } } shouldAddDummyModelMessage := false + var lastFunctionName string for _, message := range request.Messages { content := vertexChatContent{ - Role: message.Role, - Parts: []vertexPart{ - { - Text: message.StringContent(), + Role: message.Role, + Parts: []vertexPart{}, + } + if len(message.ToolCalls) > 0 { + lastFunctionName = message.ToolCalls[0].Function.Name + args := make(map[string]interface{}) + if err := json.Unmarshal([]byte(message.ToolCalls[0].Function.Arguments), &args); err != nil { + log.Errorf("unable to unmarshal function arguments: %v", err) + } + content.Parts = append(content.Parts, vertexPart{ + FunctionCall: &vertexFunctionCall{ + Name: lastFunctionName, + Args: args, }, - }, + }) + } else { + for _, part := range message.ParseContent() { + switch part.Type { + case contentTypeText: + if message.Role == roleTool { + content.Parts = append(content.Parts, vertexPart{ + FunctionResponse: &vertexFunctionResponse{ + Name: lastFunctionName, + Response: vertexFunctionResponseDetail{ + Output: part.Text, + }, + }, + }) + } else { + content.Parts = append(content.Parts, vertexPart{ + Text: part.Text, + }) + } + case contentTypeImageUrl: + vpart, err := convertImageContent(part.ImageUrl.Url) + if err != nil { + log.Errorf("unable to convert image content: %v", err) + } else { + content.Parts = append(content.Parts, vpart) + } + } + } } // there's no assistant role in vertex and API shall vomit if role is not user or model - if content.Role == roleAssistant { + switch content.Role { + case roleAssistant: content.Role = "model" - } else if content.Role == roleSystem { // converting system prompt to prompt from user for the same reason + case roleTool: + content.Role = roleUser + case roleSystem: // converting system prompt to prompt from user for the same reason content.Role = roleUser shouldAddDummyModelMessage = true } @@ -427,9 +532,12 @@ type vertexChatContent struct { } type vertexPart struct { - Text string `json:"text,omitempty"` - InlineData *blob `json:"inlineData,omitempty"` - FileData *fileData `json:"fileData,omitempty"` + Text string `json:"text,omitempty"` + InlineData *blob `json:"inlineData,omitempty"` + FileData *fileData `json:"fileData,omitempty"` + FunctionCall *vertexFunctionCall `json:"functionCall,omitempty"` + FunctionResponse *vertexFunctionResponse `json:"functionResponse,omitempty"` + Thounght *bool `json:"thought,omitempty"` } type blob struct { @@ -442,6 +550,21 @@ type fileData struct { FileUri string `json:"fileUri"` } +type vertexFunctionCall struct { + Name string `json:"name"` + Args map[string]interface{} `json:"args,omitempty"` +} + +type vertexFunctionResponse struct { + Name string `json:"name"` + Response vertexFunctionResponseDetail `json:"response"` +} + +type vertexFunctionResponseDetail struct { + Output string `json:"output,omitempty"` + Error string `json:"error,omitempty"` +} + type vertexSystemInstruction struct { Role string `json:"role"` Parts []vertexPart `json:"parts"` @@ -457,11 +580,17 @@ type vertexChatSafetySetting struct { } type vertexChatGenerationConfig struct { - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK int `json:"topK,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK int `json:"topK,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + ThinkingConfig vertexThinkingConfig `json:"thinkingConfig,omitempty"` +} + +type vertexThinkingConfig struct { + IncludeThoughts bool `json:"includeThoughts,omitempty"` + ThinkingBudget int `json:"thinkingBudget,omitempty"` } type vertexEmbeddingRequest struct { @@ -665,3 +794,33 @@ func setCachedAccessToken(key string, accessToken string, expireTime int64) erro return proxywasm.SetSharedData(key, data, cas) } + +func convertImageContent(imageUrl string) (vertexPart, error) { + part := vertexPart{} + if strings.HasPrefix(imageUrl, "http") { + arr := strings.Split(imageUrl, ".") + mimeType := "image/" + arr[len(arr)-1] + part.FileData = &fileData{ + MimeType: mimeType, + FileUri: imageUrl, + } + return part, nil + } else { + re := regexp.MustCompile(`^data:([^;]+);base64,`) + matches := re.FindStringSubmatch(imageUrl) + if len(matches) < 2 { + return part, fmt.Errorf("invalid base64 format") + } + + mimeType := matches[1] // e.g. image/png + parts := strings.Split(mimeType, "/") + if len(parts) < 2 { + return part, fmt.Errorf("invalid mimeType") + } + part.InlineData = &blob{ + MimeType: mimeType, + Data: strings.TrimPrefix(imageUrl, matches[0]), + } + return part, nil + } +}