package provider import ( "crypto" "crypto/rsa" "crypto/sha256" "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "errors" "fmt" "net/http" "regexp" "strings" "time" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) const ( vertexAuthDomain = "oauth2.googleapis.com" vertexDomain = "aiplatform.googleapis.com" // /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION} vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s" vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s" // Express Mode 路径模板 (不含 project/location) vertexExpressPathTemplate = "/v1/publishers/google/models/%s:%s" vertexExpressPathAnthropicTemplate = "/v1/publishers/anthropic/models/%s:%s" // OpenAI-compatible endpoint 路径模板 // /v1beta1/projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/openapi/chat/completions vertexOpenAICompatiblePathTemplate = "/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions" vertexChatCompletionAction = "generateContent" vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse" vertexAnthropicMessageAction = "rawPredict" vertexAnthropicMessageStreamAction = "streamRawPredict" vertexEmbeddingAction = "predict" vertexGlobalRegion = "global" contextClaudeMarker = "isClaudeRequest" contextOpenAICompatibleMarker = "isOpenAICompatibleRequest" contextVertexRawMarker = "isVertexRawRequest" vertexAnthropicVersion = "vertex-2023-10-16" ) // vertexRawPathRegex 匹配原生 Vertex AI REST API 路径 // 格式: [任意前缀]/{api-version}/projects/{project}/locations/{location}/publishers/{publisher}/models/{model}:{action} // 允许任意 basePath 前缀,兼容 basePathHandling 配置 var vertexRawPathRegex = regexp.MustCompile(`^.*/([^/]+)/projects/([^/]+)/locations/([^/]+)/publishers/([^/]+)/models/([^/:]+):([^/?]+)`) type vertexProviderInitializer struct{} func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error { // Express Mode: 如果配置了 apiTokens,则使用 API Key 认证 if len(config.apiTokens) > 0 { // Express Mode 与 OpenAI 兼容模式互斥 if config.vertexOpenAICompatible { return errors.New("vertexOpenAICompatible is not compatible with Express Mode (apiTokens)") } // Express Mode 不需要其他配置 return nil } // OpenAI 兼容模式: 需要 OAuth 认证配置 if config.vertexOpenAICompatible { if config.vertexAuthKey == "" { return errors.New("missing vertexAuthKey in vertex provider config for OpenAI compatible mode") } if config.vertexRegion == "" || config.vertexProjectId == "" { return errors.New("missing vertexRegion or vertexProjectId in vertex provider config for OpenAI compatible mode") } if config.vertexAuthServiceName == "" { return errors.New("missing vertexAuthServiceName in vertex provider config for OpenAI compatible mode") } return nil } // 标准模式: 保持原有验证逻辑 if config.vertexAuthKey == "" { return errors.New("missing vertexAuthKey in vertex provider config") } if config.vertexRegion == "" || config.vertexProjectId == "" { return errors.New("missing vertexRegion or vertexProjectId in vertex provider config") } if config.vertexAuthServiceName == "" { return errors.New("missing vertexAuthServiceName in vertex provider config") } return nil } func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string { return map[string]string{ string(ApiNameChatCompletion): vertexPathTemplate, string(ApiNameEmbeddings): vertexPathTemplate, string(ApiNameImageGeneration): vertexPathTemplate, string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换 } } func (v *vertexProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { config.setDefaultCapabilities(v.DefaultCapabilities()) provider := &vertexProvider{ config: config, contextCache: createContextCache(&config), claude: &claudeProvider{ config: config, contextCache: createContextCache(&config), }, } // 仅标准模式需要 OAuth 客户端(Express Mode 通过 apiTokens 配置) if !provider.isExpressMode() { provider.client = wrapper.NewClusterClient(wrapper.DnsCluster{ Domain: vertexAuthDomain, ServiceName: config.vertexAuthServiceName, Port: 443, }) } return provider, nil } // isExpressMode 检测是否启用 Express Mode // 如果配置了 apiTokens,则使用 Express Mode(API Key 认证) func (v *vertexProvider) isExpressMode() bool { return len(v.config.apiTokens) > 0 } // isOpenAICompatibleMode 检测是否启用 OpenAI 兼容模式 // 使用 Vertex AI 的 OpenAI-compatible Chat Completions API func (v *vertexProvider) isOpenAICompatibleMode() bool { return v.config.vertexOpenAICompatible } type vertexProvider struct { client wrapper.HttpClient config ProviderConfig contextCache *contextCache claude *claudeProvider } func (v *vertexProvider) GetProviderType() string { return providerTypeVertex } func (v *vertexProvider) GetApiName(path string) ApiName { // 优先匹配原生 Vertex AI REST API 路径,支持任意 basePath 前缀 // 格式: [任意前缀]/{api-version}/projects/{project}/locations/{location}/publishers/{publisher}/models/{model}:{action} // 必须在其他 action 检查之前,因为 :predict、:generateContent 等 action 会被其他规则匹配 if vertexRawPathRegex.MatchString(path) { return ApiNameVertexRaw } if strings.HasSuffix(path, vertexChatCompletionAction) || strings.HasSuffix(path, vertexChatCompletionStreamAction) { return ApiNameChatCompletion } if strings.HasSuffix(path, vertexEmbeddingAction) { return ApiNameEmbeddings } return "" } func (v *vertexProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error { v.config.handleRequestHeaders(v, ctx, apiName) return nil } func (v *vertexProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) { var finalVertexDomain string if v.isExpressMode() { // Express Mode: 固定域名,不带 region 前缀 finalVertexDomain = vertexDomain } else { // 标准模式: 带 region 前缀 if v.config.vertexRegion != vertexGlobalRegion { finalVertexDomain = fmt.Sprintf("%s-%s", v.config.vertexRegion, vertexDomain) } else { finalVertexDomain = vertexDomain } } util.OverwriteRequestHostHeader(headers, finalVertexDomain) } func (v *vertexProvider) getToken() (cached bool, err error) { cacheKeyName := v.buildTokenKey() cachedAccessToken, err := v.getCachedAccessToken(cacheKeyName) if err == nil && cachedAccessToken != "" { _ = proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+cachedAccessToken) return true, nil } var key ServiceAccountKey if err := json.Unmarshal([]byte(v.config.vertexAuthKey), &key); err != nil { return false, fmt.Errorf("[vertex]: unable to unmarshal auth key json: %v", err) } if key.ClientEmail == "" || key.PrivateKey == "" || key.TokenURI == "" { return false, fmt.Errorf("[vertex]: missing auth params") } jwtToken, err := createJWT(&key) if err != nil { log.Errorf("[vertex]: unable to create JWT token: %v", err) return false, err } err = v.getAccessToken(jwtToken) if err != nil { log.Errorf("[vertex]: unable to get access token: %v", err) return false, err } return false, err } func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { if !v.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } // Vertex Raw 模式: 透传请求体,只做 OAuth 认证 // 用于直接访问 Vertex AI REST API,不做协议转换 // 注意:此检查必须在 IsOriginal() 之前,因为 Vertex Raw 模式通常与 original 协议一起使用 if apiName == ApiNameVertexRaw { ctx.SetContext(contextVertexRawMarker, true) // Express Mode 不需要 OAuth 认证 if v.isExpressMode() { return types.ActionContinue, nil } // 标准模式需要获取 OAuth token cached, err := v.getToken() if cached { return types.ActionContinue, nil } if err == nil { return types.ActionPause, nil } return types.ActionContinue, err } if v.config.IsOriginal() { return types.ActionContinue, nil } headers := util.GetRequestHeaders() // OpenAI 兼容模式: 不转换请求体,只设置路径和进行模型映射 if v.isOpenAICompatibleMode() { ctx.SetContext(contextOpenAICompatibleMarker, true) body, err := v.onOpenAICompatibleRequestBody(ctx, apiName, body, headers) headers.Set("Content-Length", fmt.Sprint(len(body))) util.ReplaceRequestHeaders(headers) _ = proxywasm.ReplaceHttpRequestBody(body) if err != nil { return types.ActionContinue, err } // OpenAI 兼容模式需要 OAuth token cached, err := v.getToken() if cached { return types.ActionContinue, nil } if err == nil { return types.ActionPause, nil } return types.ActionContinue, err } body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers) headers.Set("Content-Length", fmt.Sprint(len(body))) if v.isExpressMode() { // Express Mode: 不需要 Authorization header,API Key 已在 URL 中 headers.Del("Authorization") util.ReplaceRequestHeaders(headers) _ = proxywasm.ReplaceHttpRequestBody(body) return types.ActionContinue, err } // 标准模式: 需要获取 OAuth token util.ReplaceRequestHeaders(headers) _ = proxywasm.ReplaceHttpRequestBody(body) if err != nil { return types.ActionContinue, err } cached, err := v.getToken() if cached { return types.ActionContinue, nil } if err == nil { return types.ActionPause, nil } return types.ActionContinue, err } func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) { switch apiName { case ApiNameChatCompletion: return v.onChatCompletionRequestBody(ctx, body, headers) case ApiNameEmbeddings: return v.onEmbeddingsRequestBody(ctx, body, headers) case ApiNameImageGeneration: return v.onImageGenerationRequestBody(ctx, body, headers) default: return body, nil } } // onOpenAICompatibleRequestBody 处理 OpenAI 兼容模式的请求 // 不转换请求体格式,只进行模型映射和路径设置 func (v *vertexProvider) onOpenAICompatibleRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) { if apiName != ApiNameChatCompletion { return nil, fmt.Errorf("OpenAI compatible mode only supports chat completions API") } // 解析请求进行模型映射 request := &chatCompletionRequest{} if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil { return nil, err } // 设置 OpenAI 兼容端点路径 path := v.getOpenAICompatibleRequestPath() util.OverwriteRequestPathHeader(headers, path) // 如果模型被映射,需要更新请求体中的模型字段 if request.Model != "" { body, _ = sjson.SetBytes(body, "model", request.Model) } // 保持 OpenAI 格式,直接返回(可能更新了模型字段) return body, nil } func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) { request := &chatCompletionRequest{} err := v.config.parseRequestAndMapModel(ctx, request, body) if err != nil { return nil, err } if strings.HasPrefix(request.Model, "claude") { ctx.SetContext(contextClaudeMarker, true) path := v.getAhthropicRequestPath(ApiNameChatCompletion, request.Model, request.Stream) util.OverwriteRequestPathHeader(headers, path) claudeRequest := v.claude.buildClaudeTextGenRequest(request) claudeRequest.Model = "" claudeRequest.AnthropicVersion = vertexAnthropicVersion claudeBody, err := json.Marshal(claudeRequest) if err != nil { return nil, err } return claudeBody, nil } else { path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream) util.OverwriteRequestPathHeader(headers, path) vertexRequest := v.buildVertexChatRequest(request) return json.Marshal(vertexRequest) } } func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) { request := &embeddingsRequest{} if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil { return nil, err } path := v.getRequestPath(ApiNameEmbeddings, request.Model, false) util.OverwriteRequestPathHeader(headers, path) vertexRequest := v.buildEmbeddingRequest(request) return json.Marshal(vertexRequest) } func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) { request := &imageGenerationRequest{} if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil { return nil, err } // 图片生成不使用流式端点,需要完整响应 path := v.getRequestPath(ApiNameImageGeneration, request.Model, false) util.OverwriteRequestPathHeader(headers, path) vertexRequest := v.buildVertexImageGenerationRequest(request) return json.Marshal(vertexRequest) } func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest { // 构建安全设置 safetySettings := make([]vertexChatSafetySetting, 0) for category, threshold := range v.config.geminiSafetySetting { safetySettings = append(safetySettings, vertexChatSafetySetting{ Category: category, Threshold: threshold, }) } // 解析尺寸参数 aspectRatio, imageSize := v.parseImageSize(request.Size) // 确定输出 MIME 类型 mimeType := "image/png" if request.OutputFormat != "" { switch request.OutputFormat { case "jpeg", "jpg": mimeType = "image/jpeg" case "webp": mimeType = "image/webp" default: mimeType = "image/png" } } vertexRequest := &vertexChatRequest{ Contents: []vertexChatContent{{ Role: roleUser, Parts: []vertexPart{{ Text: request.Prompt, }}, }}, SafetySettings: safetySettings, GenerationConfig: vertexChatGenerationConfig{ Temperature: 1.0, MaxOutputTokens: 32768, ResponseModalities: []string{"TEXT", "IMAGE"}, ImageConfig: &vertexImageConfig{ AspectRatio: aspectRatio, ImageSize: imageSize, ImageOutputOptions: &vertexImageOutputOptions{ MimeType: mimeType, }, PersonGeneration: "ALLOW_ALL", }, }, } return vertexRequest } // parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize // Vertex AI 支持的 aspectRatio: 1:1, 3:2, 2:3, 3:4, 4:3, 4:5, 5:4, 9:16, 16:9, 21:9 // Vertex AI 支持的 imageSize: 1k, 2k, 4k func (v *vertexProvider) parseImageSize(size string) (aspectRatio, imageSize string) { // 默认值 aspectRatio = "1:1" imageSize = "1k" if size == "" { return } // 预定义的尺寸映射(OpenAI 标准尺寸) sizeMapping := map[string]struct { aspectRatio string imageSize string }{ // OpenAI DALL-E 标准尺寸 "256x256": {"1:1", "1k"}, "512x512": {"1:1", "1k"}, "1024x1024": {"1:1", "1k"}, "1792x1024": {"16:9", "2k"}, "1024x1792": {"9:16", "2k"}, // 扩展尺寸支持 "2048x2048": {"1:1", "2k"}, "4096x4096": {"1:1", "4k"}, // 3:2 和 2:3 比例 "1536x1024": {"3:2", "2k"}, "1024x1536": {"2:3", "2k"}, // 4:3 和 3:4 比例 "1024x768": {"4:3", "1k"}, "768x1024": {"3:4", "1k"}, "1365x1024": {"4:3", "1k"}, "1024x1365": {"3:4", "1k"}, // 5:4 和 4:5 比例 "1280x1024": {"5:4", "1k"}, "1024x1280": {"4:5", "1k"}, // 21:9 超宽比例 "2560x1080": {"21:9", "2k"}, } if mapping, ok := sizeMapping[size]; ok { return mapping.aspectRatio, mapping.imageSize } return } func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) { // OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列 // Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON,将非 ASCII 字符编码为 \uXXXX if ctx.GetContext(contextOpenAICompatibleMarker) != nil && ctx.GetContext(contextOpenAICompatibleMarker).(bool) { return util.DecodeUnicodeEscapesInSSE(chunk), nil } if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) { return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk) } log.Infof("[vertexProvider] receive chunk body: %s", string(chunk)) if isLastChunk { return []byte(ssePrefix + "[DONE]\n\n"), nil } if len(chunk) == 0 { return nil, nil } if name != ApiNameChatCompletion { return chunk, nil } responseBuilder := &strings.Builder{} lines := strings.Split(string(chunk), "\n") for _, data := range lines { if len(data) < 6 { // ignore blank line or wrong format continue } data = data[6:] var vertexResp vertexChatResponse if err := json.Unmarshal([]byte(data), &vertexResp); err != nil { log.Errorf("unable to unmarshal vertex response: %v", err) continue } response := v.buildChatCompletionStreamResponse(ctx, &vertexResp) responseBody, err := json.Marshal(response) if err != nil { log.Errorf("unable to marshal response: %v", err) return nil, err } v.appendResponse(responseBuilder, string(responseBody)) } modifiedResponseChunk := responseBuilder.String() log.Debugf("=== modified response chunk: %s", modifiedResponseChunk) return []byte(modifiedResponseChunk), nil } func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) { // OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列 // Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON,将非 ASCII 字符编码为 \uXXXX if ctx.GetContext(contextOpenAICompatibleMarker) != nil && ctx.GetContext(contextOpenAICompatibleMarker).(bool) { return util.DecodeUnicodeEscapes(body), nil } if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) { return v.claude.TransformResponseBody(ctx, apiName, body) } switch apiName { case ApiNameChatCompletion: return v.onChatCompletionResponseBody(ctx, body) case ApiNameEmbeddings: return v.onEmbeddingsResponseBody(ctx, body) case ApiNameImageGeneration: return v.onImageGenerationResponseBody(ctx, body) default: return body, nil } } func (v *vertexProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) { vertexResponse := &vertexChatResponse{} if err := json.Unmarshal(body, vertexResponse); err != nil { return nil, fmt.Errorf("unable to unmarshal vertex chat response: %v", err) } response := v.buildChatCompletionResponse(ctx, vertexResponse) return json.Marshal(response) } func (v *vertexProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, response *vertexChatResponse) *chatCompletionResponse { fullTextResponse := chatCompletionResponse{ Id: response.ResponseId, Object: objectChatCompletion, Created: time.Now().UnixMilli() / 1000, Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), Choices: make([]chatCompletionChoice, 0, len(response.Candidates)), Usage: &usage{ PromptTokens: response.UsageMetadata.PromptTokenCount, CompletionTokens: response.UsageMetadata.CandidatesTokenCount, TotalTokens: response.UsageMetadata.TotalTokenCount, CompletionTokensDetails: &completionTokensDetails{ ReasoningTokens: response.UsageMetadata.ThoughtsTokenCount, }, }, } for _, candidate := range response.Candidates { choice := chatCompletionChoice{ Index: candidate.Index, Message: &chatMessage{ Role: roleAssistant, }, FinishReason: util.Ptr(candidate.FinishReason), } if len(candidate.Content.Parts) > 0 { part := candidate.Content.Parts[0] if part.FunctionCall != nil { args, _ := json.Marshal(part.FunctionCall.Args) choice.Message.ToolCalls = []toolCall{ { Type: "function", Function: functionCall{ Name: part.FunctionCall.Name, Arguments: string(args), }, }, } } else if part.Thounght != nil && len(candidate.Content.Parts) > 1 { choice.Message.Content = reasoningStartTag + part.Text + reasoningEndTag + candidate.Content.Parts[1].Text } else if part.Text != "" { choice.Message.Content = part.Text } } else { choice.Message.Content = "" } fullTextResponse.Choices = append(fullTextResponse.Choices, choice) } return &fullTextResponse } func (v *vertexProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) { vertexResponse := &vertexEmbeddingResponse{} if err := json.Unmarshal(body, vertexResponse); err != nil { return nil, fmt.Errorf("unable to unmarshal vertex embeddings response: %v", err) } response := v.buildEmbeddingsResponse(ctx, vertexResponse) return json.Marshal(response) } func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertexResp *vertexEmbeddingResponse) *embeddingsResponse { response := embeddingsResponse{ Object: "list", Data: make([]embedding, 0, len(vertexResp.Predictions)), Model: ctx.GetContext(ctxKeyFinalRequestModel).(string), } totalTokens := 0 for _, item := range vertexResp.Predictions { response.Data = append(response.Data, embedding{ Object: `embedding`, Index: 0, Embedding: item.Embeddings.Values, }) if item.Embeddings.Statistics != nil { totalTokens += item.Embeddings.Statistics.TokenCount } } response.Usage.TotalTokens = totalTokens return &response } func (v *vertexProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) { // 使用 gjson 直接提取字段,避免完整反序列化大型 base64 数据 // 这样可以显著减少内存分配和复制次数 response := v.buildImageGenerationResponseFromJSON(body) return json.Marshal(response) } // buildImageGenerationResponseFromJSON 使用 gjson 从原始 JSON 中提取图片生成响应 // 相比 json.Unmarshal 完整反序列化,这种方式内存效率更高 func (v *vertexProvider) buildImageGenerationResponseFromJSON(body []byte) *imageGenerationResponse { result := gjson.ParseBytes(body) data := make([]imageGenerationData, 0) // 遍历所有 candidates,提取图片数据 candidates := result.Get("candidates") candidates.ForEach(func(_, candidate gjson.Result) bool { parts := candidate.Get("content.parts") parts.ForEach(func(_, part gjson.Result) bool { // 跳过思考过程 (thought: true) if part.Get("thought").Bool() { return true } // 提取图片数据 inlineData := part.Get("inlineData.data") if inlineData.Exists() && inlineData.String() != "" { data = append(data, imageGenerationData{ B64: inlineData.String(), }) } return true }) return true }) // 提取 usage 信息 usage := result.Get("usageMetadata") return &imageGenerationResponse{ Created: time.Now().UnixMilli() / 1000, Data: data, Usage: &imageGenerationUsage{ TotalTokens: int(usage.Get("totalTokenCount").Int()), InputTokens: int(usage.Get("promptTokenCount").Int()), OutputTokens: int(usage.Get("candidatesTokenCount").Int()), }, } } func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse { var choice chatCompletionChoice choice.Delta = &chatMessage{} if len(vertexResp.Candidates) > 0 && len(vertexResp.Candidates[0].Content.Parts) > 0 { part := vertexResp.Candidates[0].Content.Parts[0] if part.FunctionCall != nil { args, _ := json.Marshal(part.FunctionCall.Args) choice.Delta = &chatMessage{ ToolCalls: []toolCall{ { Type: "function", Function: functionCall{ Name: part.FunctionCall.Name, Arguments: string(args), }, }, }, } } else if part.Thounght != nil { if ctx.GetContext("thinking_start") == nil { choice.Delta = &chatMessage{Content: reasoningStartTag + part.Text} ctx.SetContext("thinking_start", true) } else { choice.Delta = &chatMessage{Content: part.Text} } } else if part.Text != "" { if ctx.GetContext("thinking_start") != nil && ctx.GetContext("thinking_end") == nil { choice.Delta = &chatMessage{Content: reasoningEndTag + part.Text} ctx.SetContext("thinking_end", true) } else { choice.Delta = &chatMessage{Content: part.Text} } } } streamResponse := chatCompletionResponse{ Id: vertexResp.ResponseId, Object: objectChatCompletionChunk, Created: time.Now().UnixMilli() / 1000, Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), Choices: []chatCompletionChoice{choice}, Usage: &usage{ PromptTokens: vertexResp.UsageMetadata.PromptTokenCount, CompletionTokens: vertexResp.UsageMetadata.CandidatesTokenCount, TotalTokens: vertexResp.UsageMetadata.TotalTokenCount, CompletionTokensDetails: &completionTokensDetails{ ReasoningTokens: vertexResp.UsageMetadata.ThoughtsTokenCount, }, }, } return &streamResponse } func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) { responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) } func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string, stream bool) string { action := "" if stream { action = vertexAnthropicMessageStreamAction } else { action = vertexAnthropicMessageAction } if v.isExpressMode() { // Express Mode: 简化路径 + API Key 参数 basePath := fmt.Sprintf(vertexExpressPathAnthropicTemplate, modelId, action) apiKey := v.config.GetRandomToken() // 如果 action 已经包含 ?,使用 & 拼接 var fullPath string if strings.Contains(action, "?") { fullPath = basePath + "&key=" + apiKey } else { fullPath = basePath + "?key=" + apiKey } return fullPath } path := fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action) return path } func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string { action := "" switch apiName { case ApiNameEmbeddings: action = vertexEmbeddingAction case ApiNameImageGeneration: // 图片生成使用非流式端点,需要完整响应 action = vertexChatCompletionAction default: if stream { action = vertexChatCompletionStreamAction } else { action = vertexChatCompletionAction } } if v.isExpressMode() { // Express Mode: 简化路径 + API Key 参数 basePath := fmt.Sprintf(vertexExpressPathTemplate, modelId, action) apiKey := v.config.GetRandomToken() // 如果 action 已经包含 ?(如 streamGenerateContent?alt=sse),使用 & 拼接 var fullPath string if strings.Contains(action, "?") { fullPath = basePath + "&key=" + apiKey } else { fullPath = basePath + "?key=" + apiKey } return fullPath } path := fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action) return path } // getOpenAICompatibleRequestPath 获取 OpenAI 兼容模式的请求路径 func (v *vertexProvider) getOpenAICompatibleRequestPath() string { return fmt.Sprintf(vertexOpenAICompatiblePathTemplate, v.config.vertexProjectId, v.config.vertexRegion) } func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest { safetySettings := make([]vertexChatSafetySetting, 0) for category, threshold := range v.config.geminiSafetySetting { safetySettings = append(safetySettings, vertexChatSafetySetting{ Category: category, Threshold: threshold, }) } vertexRequest := vertexChatRequest{ Contents: make([]vertexChatContent, 0), SafetySettings: safetySettings, GenerationConfig: vertexChatGenerationConfig{ Temperature: request.Temperature, TopP: request.TopP, MaxOutputTokens: request.MaxTokens, }, } if request.ReasoningEffort != "" { thinkingConfig := vertexThinkingConfig{ IncludeThoughts: true, ThinkingBudget: 1024, } switch request.ReasoningEffort { case "none": thinkingConfig.IncludeThoughts = false thinkingConfig.ThinkingBudget = 0 case "low": thinkingConfig.ThinkingBudget = 1024 case "medium": thinkingConfig.ThinkingBudget = 4096 case "high": thinkingConfig.ThinkingBudget = 16384 } vertexRequest.GenerationConfig.ThinkingConfig = thinkingConfig } if request.Tools != nil { functions := make([]function, 0, len(request.Tools)) for _, tool := range request.Tools { functions = append(functions, tool.Function) } vertexRequest.Tools = []vertexTool{ { FunctionDeclarations: functions, }, } } shouldAddDummyModelMessage := false var lastFunctionName string for _, message := range request.Messages { content := vertexChatContent{ Role: message.Role, Parts: []vertexPart{}, } if len(message.ToolCalls) > 0 { lastFunctionName = message.ToolCalls[0].Function.Name args := make(map[string]interface{}) if err := json.Unmarshal([]byte(message.ToolCalls[0].Function.Arguments), &args); err != nil { log.Errorf("unable to unmarshal function arguments: %v", err) } content.Parts = append(content.Parts, vertexPart{ FunctionCall: &vertexFunctionCall{ Name: lastFunctionName, Args: args, }, }) } else { for _, part := range message.ParseContent() { switch part.Type { case contentTypeText: if message.Role == roleTool { content.Parts = append(content.Parts, vertexPart{ FunctionResponse: &vertexFunctionResponse{ Name: lastFunctionName, Response: vertexFunctionResponseDetail{ Output: part.Text, }, }, }) } else { content.Parts = append(content.Parts, vertexPart{ Text: part.Text, }) } case contentTypeImageUrl: vpart, err := convertMediaContent(part.ImageUrl.Url) if err != nil { log.Errorf("unable to convert image content: %v", err) } else { content.Parts = append(content.Parts, vpart) } } } } // there's no assistant role in vertex and API shall vomit if role is not user or model switch content.Role { case roleAssistant: content.Role = "model" case roleTool: content.Role = roleUser case roleSystem: // converting system prompt to prompt from user for the same reason content.Role = roleUser shouldAddDummyModelMessage = true } vertexRequest.Contents = append(vertexRequest.Contents, content) // if a system message is the last message, we need to add a dummy model message to make vertex happy if shouldAddDummyModelMessage { vertexRequest.Contents = append(vertexRequest.Contents, vertexChatContent{ Role: "model", Parts: []vertexPart{ { Text: "Okay", }, }, }) shouldAddDummyModelMessage = false } } return &vertexRequest } func (v *vertexProvider) buildEmbeddingRequest(request *embeddingsRequest) *vertexEmbeddingRequest { inputs := request.ParseInput() instances := make([]vertexEmbeddingInstance, len(inputs)) for i, input := range inputs { instances[i] = vertexEmbeddingInstance{ Content: input, } } return &vertexEmbeddingRequest{Instances: instances} } type vertexChatRequest struct { CachedContent string `json:"cachedContent,omitempty"` Contents []vertexChatContent `json:"contents"` SystemInstruction *vertexSystemInstruction `json:"systemInstruction,omitempty"` Tools []vertexTool `json:"tools,omitempty"` SafetySettings []vertexChatSafetySetting `json:"safetySettings,omitempty"` GenerationConfig vertexChatGenerationConfig `json:"generationConfig,omitempty"` Labels map[string]string `json:"labels,omitempty"` } type vertexChatContent struct { // The producer of the content. Must be either 'user' or 'model'. Role string `json:"role,omitempty"` Parts []vertexPart `json:"parts"` } type vertexPart struct { Text string `json:"text,omitempty"` InlineData *blob `json:"inlineData,omitempty"` FileData *fileData `json:"fileData,omitempty"` FunctionCall *vertexFunctionCall `json:"functionCall,omitempty"` FunctionResponse *vertexFunctionResponse `json:"functionResponse,omitempty"` Thounght *bool `json:"thought,omitempty"` } type blob struct { MimeType string `json:"mimeType"` Data string `json:"data"` } type fileData struct { MimeType string `json:"mimeType"` FileUri string `json:"fileUri"` } type vertexFunctionCall struct { Name string `json:"name"` Args map[string]interface{} `json:"args,omitempty"` } type vertexFunctionResponse struct { Name string `json:"name"` Response vertexFunctionResponseDetail `json:"response"` } type vertexFunctionResponseDetail struct { Output string `json:"output,omitempty"` Error string `json:"error,omitempty"` } type vertexSystemInstruction struct { Role string `json:"role"` Parts []vertexPart `json:"parts"` } type vertexTool struct { FunctionDeclarations any `json:"functionDeclarations"` } type vertexChatSafetySetting struct { Category string `json:"category"` Threshold string `json:"threshold"` } type vertexChatGenerationConfig struct { Temperature float64 `json:"temperature,omitempty"` TopP float64 `json:"topP,omitempty"` TopK int `json:"topK,omitempty"` CandidateCount int `json:"candidateCount,omitempty"` MaxOutputTokens int `json:"maxOutputTokens,omitempty"` ThinkingConfig vertexThinkingConfig `json:"thinkingConfig,omitempty"` ResponseModalities []string `json:"responseModalities,omitempty"` ImageConfig *vertexImageConfig `json:"imageConfig,omitempty"` } type vertexImageConfig struct { AspectRatio string `json:"aspectRatio,omitempty"` ImageSize string `json:"imageSize,omitempty"` ImageOutputOptions *vertexImageOutputOptions `json:"imageOutputOptions,omitempty"` PersonGeneration string `json:"personGeneration,omitempty"` } type vertexImageOutputOptions struct { MimeType string `json:"mimeType,omitempty"` } type vertexThinkingConfig struct { IncludeThoughts bool `json:"includeThoughts,omitempty"` ThinkingBudget int `json:"thinkingBudget,omitempty"` } type vertexEmbeddingRequest struct { Instances []vertexEmbeddingInstance `json:"instances"` Parameters *vertexEmbeddingParams `json:"parameters,omitempty"` } type vertexEmbeddingInstance struct { TaskType string `json:"task_type"` Title string `json:"title,omitempty"` Content string `json:"content"` } type vertexEmbeddingParams struct { AutoTruncate bool `json:"autoTruncate,omitempty"` } type vertexChatResponse struct { Candidates []vertexChatCandidate `json:"candidates"` ResponseId string `json:"responseId,omitempty"` PromptFeedback vertexChatPromptFeedback `json:"promptFeedback"` UsageMetadata vertexUsageMetadata `json:"usageMetadata"` } type vertexChatCandidate struct { Content vertexChatContent `json:"content"` FinishReason string `json:"finishReason"` Index int `json:"index"` SafetyRatings []vertexChatSafetyRating `json:"safetyRatings"` } type vertexChatSafetyRating struct { Category string `json:"category"` Probability string `json:"probability"` } type vertexChatPromptFeedback struct { SafetyRatings []vertexChatSafetyRating `json:"safetyRatings"` } type vertexUsageMetadata struct { PromptTokenCount int `json:"promptTokenCount,omitempty"` CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"` TotalTokenCount int `json:"totalTokenCount,omitempty"` ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` } type vertexEmbeddingResponse struct { Predictions []vertexPredictions `json:"predictions"` } type vertexPredictions struct { Embeddings struct { Values []float64 `json:"values"` Statistics *vertexStatistics `json:"statistics,omitempty"` } `json:"embeddings"` } type vertexStatistics struct { TokenCount int `json:"token_count"` Truncated bool `json:"truncated"` } type ServiceAccountKey struct { ClientEmail string `json:"client_email"` PrivateKeyID string `json:"private_key_id"` PrivateKey string `json:"private_key"` TokenURI string `json:"token_uri"` } func createJWT(key *ServiceAccountKey) (string, error) { // 解析 PEM 格式的 RSA 私钥 block, _ := pem.Decode([]byte(key.PrivateKey)) if block == nil { return "", fmt.Errorf("invalid PEM block") } parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { return "", err } rsaKey := parsedKey.(*rsa.PrivateKey) // 构造 JWT Header jwtHeader := map[string]string{ "alg": "RS256", "typ": "JWT", "kid": key.PrivateKeyID, } headerJSON, _ := json.Marshal(jwtHeader) headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) // 构造 JWT Claims now := time.Now().Unix() claims := map[string]interface{}{ "iss": key.ClientEmail, "scope": "https://www.googleapis.com/auth/cloud-platform", "aud": key.TokenURI, "iat": now, "exp": now + 3600, // 1 小时有效期 } claimsJSON, _ := json.Marshal(claims) claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON) signingInput := fmt.Sprintf("%s.%s", headerB64, claimsB64) hashed := sha256.Sum256([]byte(signingInput)) signature, err := rsaKey.Sign(nil, hashed[:], crypto.SHA256) if err != nil { return "", err } sigB64 := base64.RawURLEncoding.EncodeToString(signature) return fmt.Sprintf("%s.%s.%s", headerB64, claimsB64, sigB64), nil } func (v *vertexProvider) getAccessToken(jwtToken string) error { headers := [][2]string{ {"Content-Type", "application/x-www-form-urlencoded"}, } reqBody := "grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer&assertion=" + jwtToken err := v.client.Post("/token", headers, []byte(reqBody), func(statusCode int, responseHeaders http.Header, responseBody []byte) { responseString := string(responseBody) defer func() { _ = proxywasm.ResumeHttpRequest() }() if statusCode != http.StatusOK { log.Errorf("failed to create vertex access key, status: %d body: %s", statusCode, responseString) _ = util.ErrorHandler("ai-proxy.vertex.load_ak_failed", fmt.Errorf("failed to load vertex ak")) return } responseJson := gjson.Parse(responseString) accessToken := responseJson.Get("access_token").String() _ = proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+accessToken) expiresIn := int64(3600) if expiresInVal := responseJson.Get("expires_in"); expiresInVal.Exists() { expiresIn = expiresInVal.Int() } expireTime := time.Now().Add(time.Duration(expiresIn) * time.Second).Unix() keyName := v.buildTokenKey() err := setCachedAccessToken(keyName, accessToken, expireTime) if err != nil { log.Errorf("[vertex]: unable to cache access token: %v", err) } }, v.config.timeout) return err } func (v *vertexProvider) buildTokenKey() string { region := v.config.vertexRegion projectID := v.config.vertexProjectId return fmt.Sprintf("vertex-%s-%s-access-token", region, projectID) } type cachedAccessToken struct { Token string `json:"token"` ExpireAt int64 `json:"expireAt"` } func (v *vertexProvider) getCachedAccessToken(key string) (string, error) { data, _, err := proxywasm.GetSharedData(key) if err != nil { if errors.Is(err, types.ErrorStatusNotFound) { return "", nil } return "", err } if data == nil { return "", nil } var tokenInfo cachedAccessToken if err = json.Unmarshal(data, &tokenInfo); err != nil { return "", err } now := time.Now().Unix() refreshAhead := v.config.vertexTokenRefreshAhead if tokenInfo.ExpireAt > now+refreshAhead { return tokenInfo.Token, nil } return "", nil } func setCachedAccessToken(key string, accessToken string, expireTime int64) error { tokenInfo := cachedAccessToken{ Token: accessToken, ExpireAt: expireTime, } _, cas, err := proxywasm.GetSharedData(key) if err != nil && !errors.Is(err, types.ErrorStatusNotFound) { return err } data, err := json.Marshal(tokenInfo) if err != nil { return err } return proxywasm.SetSharedData(key, data, cas) } // convertMediaContent 将 OpenAI 格式的媒体 URL 转换为 Vertex AI 格式 // 支持图片、视频、音频等多种媒体类型 func convertMediaContent(mediaUrl string) (vertexPart, error) { part := vertexPart{} if strings.HasPrefix(mediaUrl, "http") { mimeType := detectMimeTypeFromURL(mediaUrl) part.FileData = &fileData{ MimeType: mimeType, FileUri: mediaUrl, } return part, nil } else { // Base64 data URL 格式: data:;base64, re := regexp.MustCompile(`^data:([^;]+);base64,`) matches := re.FindStringSubmatch(mediaUrl) if len(matches) < 2 { return part, fmt.Errorf("invalid base64 format, expected data:;base64,") } mimeType := matches[1] // e.g. image/png, video/mp4, audio/mp3 parts := strings.Split(mimeType, "/") if len(parts) < 2 { return part, fmt.Errorf("invalid mimeType: %s", mimeType) } part.InlineData = &blob{ MimeType: mimeType, Data: strings.TrimPrefix(mediaUrl, matches[0]), } return part, nil } } // detectMimeTypeFromURL 根据 URL 的文件扩展名检测 MIME 类型 // 支持图片、视频、音频和文档类型 func detectMimeTypeFromURL(url string) string { // 移除查询参数和片段标识符 if idx := strings.Index(url, "?"); idx != -1 { url = url[:idx] } if idx := strings.Index(url, "#"); idx != -1 { url = url[:idx] } // 获取最后一个路径段 lastSlash := strings.LastIndex(url, "/") if lastSlash != -1 { url = url[lastSlash+1:] } // 获取扩展名 lastDot := strings.LastIndex(url, ".") if lastDot == -1 || lastDot == len(url)-1 { return "application/octet-stream" } ext := strings.ToLower(url[lastDot+1:]) // 扩展名到 MIME 类型的映射 mimeTypes := map[string]string{ // 图片格式 "jpg": "image/jpeg", "jpeg": "image/jpeg", "png": "image/png", "gif": "image/gif", "webp": "image/webp", "bmp": "image/bmp", "svg": "image/svg+xml", "ico": "image/x-icon", "heic": "image/heic", "heif": "image/heif", "tiff": "image/tiff", "tif": "image/tiff", // 视频格式 "mp4": "video/mp4", "mpeg": "video/mpeg", "mpg": "video/mpeg", "mov": "video/quicktime", "avi": "video/x-msvideo", "wmv": "video/x-ms-wmv", "webm": "video/webm", "mkv": "video/x-matroska", "flv": "video/x-flv", "3gp": "video/3gpp", "3g2": "video/3gpp2", "m4v": "video/x-m4v", // 音频格式 "mp3": "audio/mpeg", "wav": "audio/wav", "ogg": "audio/ogg", "flac": "audio/flac", "aac": "audio/aac", "m4a": "audio/mp4", "wma": "audio/x-ms-wma", "opus": "audio/opus", // 文档格式 "pdf": "application/pdf", } if mimeType, ok := mimeTypes[ext]; ok { return mimeType } return "application/octet-stream" }