diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index ce4b15581..ce6e43c73 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -262,6 +262,19 @@ Dify 所对应的 `type` 为 `dify`。它特有的配置字段如下: | `inputVariable` | string | 非必填 | - | dify 中应用类型为 workflow 时需要设置输入变量,当 botType 为 workflow 时一起使用 | | `outputVariable` | string | 非必填 | - | dify 中应用类型为 workflow 时需要设置输出变量,当 botType 为 workflow 时一起使用 | +#### Google Vertex AI + +Google Vertex AI 所对应的 type 为 vertex。它特有的配置字段如下: + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------| +| `vertexAuthKey` | string | 必填 | - | 用于认证的 Google Service Account JSON Key,格式为 PEM 编码的 PKCS#8 私钥和 client_email 等信息 | +| `vertexRegion` | string | 必填 | - | Google Cloud 区域(如 us-central1, europe-west4 等),用于构建 Vertex API 地址 | +| `vertexProjectId` | string | 必填 | - | Google Cloud 项目 ID,用于标识目标 GCP 项目 | +| `vertexAuthServiceName` | string | 必填 | - | 用于 OAuth2 认证的服务名称,插件通过该服务访问 oauth2.googleapis.com | +| `geminiSafetySetting` | map of string | 非必填 | - | Gemini 模型的内容安全过滤设置。 | +| `vertexTokenRefreshAhead` | number | 非必填 | 60 | Vertex access token 刷新提前时间(单位秒) | + ## 用法示例 ### 使用 OpenAI 协议代理 Azure OpenAI 服务 @@ -1629,6 +1642,69 @@ provider: } ``` +### 使用 OpenAI 协议代理 Google Vertex 服务 + +**配置信息** + +```yaml +provider: + type: vertex + vertexAuthKey: | + { + "type": "service_account", + "project_id": "your-project-id", + "private_key_id": "your-private-key-id", + "private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n", + "client_email": "your-service-account@your-project.iam.gserviceaccount.com", + "token_uri": "https://oauth2.googleapis.com/token" + } + vertexRegion: us-central1 + vertexProjectId: your-project-id + vertexAuthServiceName: your-auth-service-name +``` + +**请求示例** + +```json +{ + "model": 
"gemini-2.0-flash-001", + "messages": [ + { + "role": "user", + "content": "你好,你是谁?" + } + ], + "stream": false +} +``` + +**响应示例** + +```json +{ + "id": "chatcmpl-0000000000000", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "你好!我是 Vertex AI 提供的 Gemini 模型,由 Google 开发的人工智能助手。我可以回答问题、提供信息和帮助完成各种任务。有什么我可以帮您的吗?" + }, + "finish_reason": "stop" + } + ], + "created": 1729986750, + "model": "gemini-2.0-flash-001", + "object": "chat.completion", + "usage": { + "prompt_tokens": 15, + "completion_tokens": 43, + "total_tokens": 58 + } +} +``` + + ## 完整配置示例 ### Kubernetes 示例 diff --git a/plugins/wasm-go/extensions/ai-proxy/README_EN.md b/plugins/wasm-go/extensions/ai-proxy/README_EN.md index 67aef4ff3..b0586d896 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README_EN.md +++ b/plugins/wasm-go/extensions/ai-proxy/README_EN.md @@ -208,6 +208,18 @@ For DeepL, the corresponding `type` is `deepl`. Its unique configuration field i | ------------ | --------- | ----------- | ------- | ------------------------------------ | | `targetLang` | string | Required | - | The target language required by the DeepL translation service | +#### Google Vertex AI +For Vertex, the corresponding `type` is `vertex`. Its unique configuration field is: + +| Name | Data Type | Requirement | Default | Description | +|-----------------------------|---------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `vertexAuthKey` | string | Required | - | Google Service Account JSON Key used for authentication. 
The format should be PEM encoded PKCS#8 private key along with client_email and other information | +| `vertexRegion` | string | Required | - | Google Cloud region (e.g., us-central1, europe-west4) used to build the Vertex API address | +| `vertexProjectId` | string | Required | - | Google Cloud Project ID, used to identify the target GCP project | +| `vertexAuthServiceName` | string | Required | - | Service name for OAuth2 authentication, used to access oauth2.googleapis.com | +| `geminiSafetySetting` | map of string | Optional | - | Gemini model content safety filtering settings. | +| `vertexTokenRefreshAhead` | number | Optional | 60 | How many seconds before expiry the Vertex access token is refreshed | + ## Usage Examples ### Using OpenAI Protocol Proxy for Azure OpenAI Service @@ -1411,6 +1423,64 @@ provider: } ``` +### Utilizing OpenAI Protocol Proxy for Google Vertex Services +**Configuration Information** +```yaml +provider: + type: vertex + vertexAuthKey: | + { + "type": "service_account", + "project_id": "your-project-id", + "private_key_id": "your-private-key-id", + "private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n", + "client_email": "your-service-account@your-project.iam.gserviceaccount.com", + "token_uri": "https://oauth2.googleapis.com/token" + } + vertexRegion: us-central1 + vertexProjectId: your-project-id + vertexAuthServiceName: your-auth-service-name +``` + +**Request Example** +```json +{ + "model": "gemini-2.0-flash-001", + "messages": [ + { + "role": "user", + "content": "Who are you?" + } + ], + "stream": false +} +``` + +**Response Example** +```json +{ + "id": "chatcmpl-0000000000000", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! I am the Gemini model provided by Vertex AI, developed by Google. I can answer questions, provide information, and assist in completing various tasks. How can I help you today?" 
+ }, + "finish_reason": "stop" + } + ], + "created": 1729986750, + "model": "gemini-2.0-flash-001", + "object": "chat.completion", + "usage": { + "prompt_tokens": 15, + "completion_tokens": 43, + "total_tokens": 58 + } +} +``` + ## Full Configuration Example ### Kubernetes Example diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 0cd55bdb3..563f868a6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -90,6 +90,7 @@ const ( providerTypeTogetherAI = "together-ai" providerTypeDify = "dify" providerTypeBedrock = "bedrock" + providerTypeVertex = "vertex" protocolOpenAI = "openai" protocolOriginal = "original" @@ -161,6 +162,7 @@ var ( providerTypeTogetherAI: &togetherAIProviderInitializer{}, providerTypeDify: &difyProviderInitializer{}, providerTypeBedrock: &bedrockProviderInitializer{}, + providerTypeVertex: &vertexProviderInitializer{}, } ) @@ -298,6 +300,21 @@ type ProviderConfig struct { // @Title zh-CN Gemini AI内容过滤和安全级别设定 // @Description zh-CN 仅适用于 Gemini AI 服务。参考:https://ai.google.dev/gemini-api/docs/safety-settings geminiSafetySetting map[string]string `required:"false" yaml:"geminiSafetySetting" json:"geminiSafetySetting"` + // @Title zh-CN Vertex AI访问区域 + // @Description zh-CN 仅适用于Vertex AI服务。如需查看支持的区域的完整列表,请参阅https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations?hl=zh-cn#available-regions + vertexRegion string `required:"false" yaml:"vertexRegion" json:"vertexRegion"` + // @Title zh-CN Vertex AI项目Id + // @Description zh-CN 仅适用于Vertex AI服务。创建和管理项目请参阅https://cloud.google.com/resource-manager/docs/creating-managing-projects?hl=zh-cn#identifiers + vertexProjectId string `required:"false" yaml:"vertexProjectId" json:"vertexProjectId"` + // @Title zh-CN Vertex 认证秘钥 + // @Description zh-CN 
用于Google服务账号认证的完整JSON密钥文件内容,获取可参考https://cloud.google.com/iam/docs/keys-create-delete?hl=zh-cn#iam-service-account-keys-create-console + vertexAuthKey string `required:"false" yaml:"vertexAuthKey" json:"vertexAuthKey"` + // @Title zh-CN Vertex 认证服务名 + // @Description zh-CN 用于Google服务账号认证的服务,DNS类型的服务名 + vertexAuthServiceName string `required:"false" yaml:"vertexAuthServiceName" json:"vertexAuthServiceName"` + // @Title zh-CN Vertex token刷新提前时间 + // @Description zh-CN 用于Google服务账号认证,access token过期时间判定提前刷新,单位为秒,默认值为60秒 + vertexTokenRefreshAhead int64 `required:"false" yaml:"vertexTokenRefreshAhead" json:"vertexTokenRefreshAhead"` // @Title zh-CN 翻译服务需指定的目标语种 // @Description zh-CN 翻译结果的语种,目前仅适用于DeepL服务。 targetLang string `required:"false" yaml:"targetLang" json:"targetLang"` @@ -390,12 +407,20 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.minimaxApiType = json.Get("minimaxApiType").String() c.minimaxGroupId = json.Get("minimaxGroupId").String() c.cloudflareAccountId = json.Get("cloudflareAccountId").String() - if c.typ == providerTypeGemini { + if c.typ == providerTypeGemini || c.typ == providerTypeVertex { c.geminiSafetySetting = make(map[string]string) for k, v := range json.Get("geminiSafetySetting").Map() { c.geminiSafetySetting[k] = v.String() } } + c.vertexRegion = json.Get("vertexRegion").String() + c.vertexProjectId = json.Get("vertexProjectId").String() + c.vertexAuthKey = json.Get("vertexAuthKey").String() + c.vertexAuthServiceName = json.Get("vertexAuthServiceName").String() + c.vertexTokenRefreshAhead = json.Get("vertexTokenRefreshAhead").Int() + if c.vertexTokenRefreshAhead == 0 { + c.vertexTokenRefreshAhead = 60 + } c.targetLang = json.Get("targetLang").String() if schemaValue, ok := json.Get("responseJsonSchema").Value().(map[string]interface{}); ok { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go new file mode 100644 index 000000000..b0d115d5c --- /dev/null +++ 
b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go
@@ -0,0 +1,751 @@
+package provider
+
+import (
+	"crypto"
+	"crypto/rsa"
+	"crypto/sha256"
+	"crypto/x509"
+	"encoding/base64"
+	"encoding/json"
+	"encoding/pem"
+	"errors"
+	"fmt"
+	"net/http"
+	"strings"
+	"time"
+
+	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
+	"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
+	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
+	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+	"github.com/tidwall/gjson"
+)
+
+const (
+	vertexAuthDomain = "oauth2.googleapis.com"
+	vertexDomain     = "{REGION}-aiplatform.googleapis.com"
+	// /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION}
+	vertexPathTemplate               = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
+	vertexChatCompletionAction       = "generateContent"
+	vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
+	vertexEmbeddingAction            = "predict"
+)
+
+// vertexProviderInitializer validates configuration for and creates Vertex AI providers.
+type vertexProviderInitializer struct {
+}
+
+// ValidateConfig reports an error when any setting required to reach Vertex AI
+// (auth key, region, project ID, auth service name) is missing.
+func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error {
+	if config.vertexAuthKey == "" {
+		return errors.New("missing vertexAuthKey in vertex provider config")
+	}
+	if config.vertexRegion == "" || config.vertexProjectId == "" {
+		return errors.New("missing vertexRegion or vertexProjectId in vertex provider config")
+	}
+	if config.vertexAuthServiceName == "" {
+		return errors.New("missing vertexAuthServiceName in vertex provider config")
+	}
+	return nil
+}
+
+// DefaultCapabilities maps every supported API to the Vertex AI request path template.
+func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
+	return map[string]string{
+		string(ApiNameChatCompletion): vertexPathTemplate,
+		string(ApiNameEmbeddings):     vertexPathTemplate,
+	}
+}
+
+// CreateProvider builds a vertexProvider whose HTTP client targets the Google
+// OAuth2 token endpoint through the configured DNS service.
+func (v *vertexProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
+	config.setDefaultCapabilities(v.DefaultCapabilities())
+	return &vertexProvider{
+		config: config,
+		client: wrapper.NewClusterClient(wrapper.DnsCluster{
+			Domain:      vertexAuthDomain,
+			ServiceName: config.vertexAuthServiceName,
+			Port:        443,
+		}),
+		contextCache: createContextCache(&config),
+	}, nil
+}
+
+// vertexProvider proxies OpenAI-style requests to Google Vertex AI.
+type vertexProvider struct {
+	client       wrapper.HttpClient
+	config       ProviderConfig
+	contextCache *contextCache
+}
+
+// GetProviderType returns the provider type identifier for Vertex AI.
+func (v *vertexProvider) GetProviderType() string {
+	return providerTypeVertex
+}
+
+// GetApiName infers the API from the action suffix of a Vertex AI request path.
+// It returns an empty ApiName when the suffix is not recognized.
+func (v *vertexProvider) GetApiName(path string) ApiName {
+	if strings.HasSuffix(path, vertexChatCompletionAction) || strings.HasSuffix(path, vertexChatCompletionStreamAction) {
+		return ApiNameChatCompletion
+	}
+	if strings.HasSuffix(path, vertexEmbeddingAction) {
+		return ApiNameEmbeddings
+	}
+	return ""
+}
+
+// OnRequestHeaders delegates header handling to the shared provider configuration.
+func (v *vertexProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
+	v.config.handleRequestHeaders(v, ctx, apiName)
+	return nil
+}
+
+// TransformRequestHeaders points the Host header at the regional Vertex AI endpoint.
+func (v *vertexProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
+	vertexRegionDomain := strings.Replace(vertexDomain, "{REGION}", v.config.vertexRegion, 1)
+	util.OverwriteRequestHostHeader(headers, vertexRegionDomain)
+}
+
+// getToken puts a Google OAuth2 access token on the request's Authorization header.
+// If a token that is still valid is found in shared data, it is applied immediately
+// and cached is true. Otherwise a signed JWT is posted to the token endpoint; cached
+// is false and, when err is nil, the header is set later by the response callback,
+// which also resumes the request.
+func (v *vertexProvider) getToken() (cached bool, err error) {
+	token, cacheErr := v.getCachedAccessToken(v.buildTokenKey())
+	if cacheErr != nil {
+		// An unreadable cache entry is not fatal: request a fresh token instead.
+		log.Errorf("[vertex]: unable to read cached access token: %v", cacheErr)
+	} else if token != "" {
+		if err := proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+token); err != nil {
+			return false, fmt.Errorf("[vertex]: unable to set authorization header: %v", err)
+		}
+		return true, nil
+	}
+
+	var key ServiceAccountKey
+	if err := json.Unmarshal([]byte(v.config.vertexAuthKey), &key); err != nil {
+		return false, fmt.Errorf("[vertex]: unable to unmarshal auth key json: %v", err)
+	}
+
+	if key.ClientEmail == "" || key.PrivateKey == "" || key.TokenURI == "" {
+		return false, fmt.Errorf("[vertex]: missing auth params")
+	}
+
+	jwtToken, err := createJWT(&key)
+	if err != nil {
+		log.Errorf("[vertex]: unable to create JWT token: %v", err)
+		return false, err
+	}
+
+	if err = v.getAccessToken(jwtToken); err != nil {
+		log.Errorf("[vertex]: unable to get access token: %v", err)
+		return false, err
+	}
+
+	// The token request is in flight; the caller pauses until the callback resumes.
+	return false, nil
+}
+
+// OnRequestBody converts the OpenAI-style request into a Vertex AI request and
+// attaches an access token. It returns ActionPause while a new token is being
+// fetched asynchronously and ActionContinue in every other case.
+func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
+	if !v.config.isSupportedAPI(apiName) {
+		return types.ActionContinue, errUnsupportedApiName
+	}
+	if v.config.IsOriginal() {
+		return types.ActionContinue, nil
+	}
+	headers := util.GetOriginalRequestHeaders()
+	body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
+	if err != nil {
+		// Do not overwrite the request with a half-transformed (nil) body.
+		return types.ActionContinue, err
+	}
+	util.ReplaceRequestHeaders(headers)
+	if err := proxywasm.ReplaceHttpRequestBody(body); err != nil {
+		return types.ActionContinue, err
+	}
+	cached, err := v.getToken()
+	if err != nil {
+		return types.ActionContinue, err
+	}
+	if cached {
+		return types.ActionContinue, nil
+	}
+	// Wait for the token callback to set the header and resume the request.
+	return types.ActionPause, nil
+}
+
+// TransformRequestBodyHeaders rewrites the request body and path header for the
+// given API. Any API other than chat completion is handled as an embeddings request.
+func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
+	if apiName == ApiNameChatCompletion {
+		return v.onChatCompletionRequestBody(ctx, body, headers)
+	}
+	return v.onEmbeddingsRequestBody(ctx, body, headers)
+}
+
+// onChatCompletionRequestBody maps the model, sets the Vertex AI request path and
+// returns the request encoded as a Vertex AI generateContent payload.
+func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
+	request := &chatCompletionRequest{}
+	if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
+		return nil, err
+	}
+	path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
+	util.OverwriteRequestPathHeader(headers, path)
+
+	vertexRequest := v.buildVertexChatRequest(request)
+	return json.Marshal(vertexRequest)
+}
+
+// onEmbeddingsRequestBody maps the model, sets the Vertex AI request path and
+// returns the request encoded as a Vertex AI predict payload.
+func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
+	request := &embeddingsRequest{}
+	if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
+		return nil, err
+	}
+	path := v.getRequestPath(ApiNameEmbeddings, request.Model, false)
+	util.OverwriteRequestPathHeader(headers, path)
+
+	vertexRequest := v.buildEmbeddingRequest(request)
+	return json.Marshal(vertexRequest)
+}
+
+// OnStreamingResponseBody converts a chunk of the Vertex AI server-sent-events
+// stream into OpenAI-style chat completion chunks. Chunks of other APIs are
+// passed through unchanged.
+func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
+	// Payload lines of the stream start with this field name; others are skipped.
+	const dataPrefix = "data:"
+
+	log.Debugf("[vertexProvider] receive chunk body: %s", string(chunk))
+	if isLastChunk || len(chunk) == 0 {
+		return nil, nil
+	}
+	if name != ApiNameChatCompletion {
+		return chunk, nil
+	}
+	responseBuilder := &strings.Builder{}
+	for _, line := range strings.Split(string(chunk), "\n") {
+		line = strings.TrimSpace(line)
+		if !strings.HasPrefix(line, dataPrefix) {
+			// ignore blank lines and anything that is not a data field
+			continue
+		}
+		data := strings.TrimSpace(strings.TrimPrefix(line, dataPrefix))
+		var vertexResp vertexChatResponse
+		if err := json.Unmarshal([]byte(data), &vertexResp); err != nil {
+			log.Errorf("unable to unmarshal vertex response: %v", err)
+			continue
+		}
+		response := v.buildChatCompletionStreamResponse(ctx, &vertexResp)
+		responseBody, err := json.Marshal(response)
+		if err != nil {
+			log.Errorf("unable to marshal response: %v", err)
+			return nil, err
+		}
+		v.appendResponse(responseBuilder, string(responseBody))
+	}
+	modifiedResponseChunk := responseBuilder.String()
+	log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
+	return []byte(modifiedResponseChunk), nil
+}
+
+// TransformResponseBody converts a complete Vertex AI response body into its
+// OpenAI-style equivalent. Any API other than chat completion is handled as embeddings.
+func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
+	if apiName == ApiNameChatCompletion {
+		return v.onChatCompletionResponseBody(ctx, body)
+	}
+	return v.onEmbeddingsResponseBody(ctx, body)
+}
+
+// onChatCompletionResponseBody decodes a Vertex AI chat response and re-encodes it
+// as an OpenAI-style chat completion.
+func (v *vertexProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
+	vertexResponse := &vertexChatResponse{}
+	if err := json.Unmarshal(body, vertexResponse); err != nil {
+		return nil, fmt.Errorf("unable to unmarshal vertex chat response: %v", err)
+	}
+	response := v.buildChatCompletionResponse(ctx, vertexResponse)
+	return json.Marshal(response)
+}
+
+// buildChatCompletionResponse maps every Vertex AI candidate to a chat completion
+// choice. Only the text of a candidate's first part is used as message content.
+func (v *vertexProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, response *vertexChatResponse) *chatCompletionResponse {
+	fullTextResponse := chatCompletionResponse{
+		Id:      response.ResponseId,
+		Object:  objectChatCompletion,
+		Created: time.Now().Unix(),
+		Model:   ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
+		Choices: make([]chatCompletionChoice, 0, len(response.Candidates)),
+		Usage: usage{
+			PromptTokens:     response.UsageMetadata.PromptTokenCount,
+			CompletionTokens: response.UsageMetadata.CandidatesTokenCount,
+			TotalTokens:      response.UsageMetadata.TotalTokenCount,
+		},
+	}
+	for _, candidate := range response.Candidates {
+		choice := chatCompletionChoice{
+			Index: candidate.Index,
+			Message: &chatMessage{
+				Role: roleAssistant,
+			},
+			FinishReason: candidate.FinishReason,
+		}
+		if len(candidate.Content.Parts) > 0 {
+			choice.Message.Content = candidate.Content.Parts[0].Text
+		} else {
+			choice.Message.Content = ""
+		}
+		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+	}
+	return &fullTextResponse
+}
+
+// onEmbeddingsResponseBody decodes a Vertex AI predict response and re-encodes it
+// as an OpenAI-style embeddings response.
+func (v *vertexProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
+	vertexResponse := &vertexEmbeddingResponse{}
+	if err := json.Unmarshal(body, vertexResponse); err != nil {
+		return nil, fmt.Errorf("unable to unmarshal vertex embeddings response: %v", err)
+	}
+	response := v.buildEmbeddingsResponse(ctx, vertexResponse)
+	return json.Marshal(response)
+}
+
+// buildEmbeddingsResponse maps each Vertex AI prediction to one embedding entry,
+// numbered by its position, and sums the reported token counts into the usage.
+func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertexResp *vertexEmbeddingResponse) *embeddingsResponse {
+	// Read the model the same way the chat path does so an unset value cannot panic.
+	response := embeddingsResponse{
+		Object: "list",
+		Data:   make([]embedding, 0, len(vertexResp.Predictions)),
+		Model:  ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
+	}
+	totalTokens := 0
+	for i, item := range vertexResp.Predictions {
+		response.Data = append(response.Data, embedding{
+			Object:    `embedding`,
+			Index:     i,
+			Embedding: item.Embeddings.Values,
+		})
+		if item.Embeddings.Statistics != nil {
+			totalTokens += item.Embeddings.Statistics.TokenCount
+		}
+	}
+	response.Usage.TotalTokens = totalTokens
+	return &response
+}
+
+// buildChatCompletionStreamResponse wraps the first part of the first candidate,
+// when present, into a single-choice OpenAI-style streaming chunk.
+func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
+	var choice chatCompletionChoice
+	if len(vertexResp.Candidates) > 0 && len(vertexResp.Candidates[0].Content.Parts) > 0 {
+		choice.Delta = &chatMessage{Content: vertexResp.Candidates[0].Content.Parts[0].Text}
+	}
+	streamResponse := chatCompletionResponse{
+		Id:      vertexResp.ResponseId,
+		Object:  objectChatCompletionChunk,
+		Created: time.Now().Unix(),
+		Model:   ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
+		Choices: []chatCompletionChoice{choice},
+		Usage: usage{
+			PromptTokens:     vertexResp.UsageMetadata.PromptTokenCount,
+			CompletionTokens: vertexResp.UsageMetadata.CandidatesTokenCount,
+			TotalTokens:      vertexResp.UsageMetadata.TotalTokenCount,
+		},
+	}
+	return &streamResponse
+}
+
+// appendResponse writes one server-sent-events data item, terminated by a blank line.
+func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
+	fmt.Fprintf(responseBuilder, "%s %s\n\n", streamDataItemKey, responseBody)
+}
+
+// getRequestPath builds the Vertex AI request path for the given API, model and
+// streaming mode.
+func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
+	var action string
+	switch {
+	case apiName == ApiNameEmbeddings:
+		action = vertexEmbeddingAction
+	case stream:
+		action = vertexChatCompletionStreamAction
+	default:
+		action = vertexChatCompletionAction
+	}
+	return fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
+}
+
+// buildVertexChatRequest converts an OpenAI-style chat request into a Vertex AI
+// request. The assistant role becomes "model"; a system message is sent as a user
+// message followed by a dummy model reply, because contents only accept the user
+// and model roles.
+func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest {
+	safetySettings := make([]vertexChatSafetySetting, 0, len(v.config.geminiSafetySetting))
+	for category, threshold := range v.config.geminiSafetySetting {
+		safetySettings = append(safetySettings, vertexChatSafetySetting{
+			Category:  category,
+			Threshold: threshold,
+		})
+	}
+	vertexRequest := vertexChatRequest{
+		Contents:       make([]vertexChatContent, 0, len(request.Messages)),
+		SafetySettings: safetySettings,
+		GenerationConfig: vertexChatGenerationConfig{
+			Temperature:     request.Temperature,
+			TopP:            request.TopP,
+			MaxOutputTokens: request.MaxTokens,
+		},
+	}
+	if len(request.Tools) > 0 {
+		functions := make([]function, 0, len(request.Tools))
+		for _, tool := range request.Tools {
+			functions = append(functions, tool.Function)
+		}
+		vertexRequest.Tools = []vertexTool{
+			{
+				FunctionDeclarations: functions,
+			},
+		}
+	}
+	for _, message := range request.Messages {
+		content := vertexChatContent{
+			Role: message.Role,
+			Parts: []vertexPart{
+				{
+					Text: message.StringContent(),
+				},
+			},
+		}
+
+		// contents only accept the user and model roles
+		isSystemPrompt := content.Role == roleSystem
+		if content.Role == roleAssistant {
+			content.Role = "model"
+		} else if isSystemPrompt {
+			content.Role = roleUser
+		}
+		vertexRequest.Contents = append(vertexRequest.Contents, content)
+
+		// every converted system prompt is followed by a dummy model reply
+		if isSystemPrompt {
+			vertexRequest.Contents = append(vertexRequest.Contents, vertexChatContent{
+				Role: "model",
+				Parts: []vertexPart{
+					{
+						Text: "Okay",
+					},
+				},
+			})
+		}
+	}
+
+	return &vertexRequest
+}
+
+// buildEmbeddingRequest creates one Vertex AI instance per input of the request.
+func (v *vertexProvider) buildEmbeddingRequest(request *embeddingsRequest) *vertexEmbeddingRequest {
+	inputs := request.ParseInput()
+	instances := make([]vertexEmbeddingInstance, len(inputs))
+	for i, input := range inputs {
+		instances[i] = vertexEmbeddingInstance{
+			Content: input,
+		}
+	}
+	return &vertexEmbeddingRequest{Instances: instances}
+}
+
+// vertexChatRequest is the body of a generateContent / streamGenerateContent call.
+type vertexChatRequest struct {
+	CachedContent     string                     `json:"cachedContent,omitempty"`
+	Contents          []vertexChatContent        `json:"contents"`
+	SystemInstruction *vertexSystemInstruction   `json:"systemInstruction,omitempty"`
+	Tools             []vertexTool               `json:"tools,omitempty"`
+	SafetySettings    []vertexChatSafetySetting  `json:"safetySettings,omitempty"`
+	GenerationConfig  vertexChatGenerationConfig `json:"generationConfig,omitempty"`
+	Labels            map[string]string          `json:"labels,omitempty"`
+}
+
+type vertexChatContent struct {
+	// The producer of the content. Must be either 'user' or 'model'.
+	Role  string       `json:"role,omitempty"`
+	Parts []vertexPart `json:"parts"`
+}
+
+type vertexPart struct {
+	Text       string    `json:"text,omitempty"`
+	InlineData *blob     `json:"inlineData,omitempty"`
+	FileData   *fileData `json:"fileData,omitempty"`
+}
+
+type blob struct {
+	MimeType string `json:"mimeType"`
+	Data     string `json:"data"`
+}
+
+type fileData struct {
+	MimeType string `json:"mimeType"`
+	FileUri  string `json:"fileUri"`
+}
+
+type vertexSystemInstruction struct {
+	Role  string       `json:"role"`
+	Parts []vertexPart `json:"parts"`
+}
+
+type vertexTool struct {
+	FunctionDeclarations any `json:"functionDeclarations"`
+}
+
+type vertexChatSafetySetting struct {
+	Category  string `json:"category"`
+	Threshold string `json:"threshold"`
+}
+
+type vertexChatGenerationConfig struct {
+	Temperature     float64 `json:"temperature,omitempty"`
+	TopP            float64 `json:"topP,omitempty"`
+	TopK            int     `json:"topK,omitempty"`
+	CandidateCount  int     `json:"candidateCount,omitempty"`
+	MaxOutputTokens int     `json:"maxOutputTokens,omitempty"`
+}
+
+type vertexEmbeddingRequest struct {
+	Instances  []vertexEmbeddingInstance `json:"instances"`
+	Parameters *vertexEmbeddingParams    `json:"parameters,omitempty"`
+}
+
+type vertexEmbeddingInstance struct {
+	// omitempty: no code path sets TaskType, so do not send it as an empty string.
+	TaskType string `json:"task_type,omitempty"`
+	Title    string `json:"title,omitempty"`
+	Content  string `json:"content"`
+}
+
+type vertexEmbeddingParams struct {
+	AutoTruncate bool `json:"autoTruncate,omitempty"`
+}
+
+type vertexChatResponse struct {
+	Candidates     []vertexChatCandidate    `json:"candidates"`
+	ResponseId     string                   `json:"responseId,omitempty"`
+	PromptFeedback vertexChatPromptFeedback `json:"promptFeedback"`
+	UsageMetadata  vertexUsageMetadata      `json:"usageMetadata"`
+}
+
+type vertexChatCandidate struct {
+	Content       vertexChatContent        `json:"content"`
+	FinishReason  string                   `json:"finishReason"`
+	Index         int                      `json:"index"`
+	SafetyRatings []vertexChatSafetyRating `json:"safetyRatings"`
+}
+
+type vertexChatSafetyRating struct {
+	Category    string `json:"category"`
+	Probability string `json:"probability"`
+}
+
+type vertexChatPromptFeedback struct {
+	SafetyRatings []vertexChatSafetyRating `json:"safetyRatings"`
+}
+
+type vertexUsageMetadata struct {
+	PromptTokenCount     int `json:"promptTokenCount,omitempty"`
+	CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
+	TotalTokenCount      int `json:"totalTokenCount,omitempty"`
+}
+
+type vertexEmbeddingResponse struct {
+	Predictions []vertexPredictions `json:"predictions"`
+}
+
+type vertexPredictions struct {
+	Embeddings struct {
+		Values     []float64         `json:"values"`
+		Statistics *vertexStatistics `json:"statistics,omitempty"`
+	} `json:"embeddings"`
+}
+
+type vertexStatistics struct {
+	TokenCount int  `json:"token_count"`
+	Truncated  bool `json:"truncated"`
+}
+
+// ServiceAccountKey holds the fields of a Google service account JSON key that
+// are needed to request an OAuth2 access token.
+type ServiceAccountKey struct {
+	ClientEmail  string `json:"client_email"`
+	PrivateKeyID string `json:"private_key_id"`
+	PrivateKey   string `json:"private_key"`
+	TokenURI     string `json:"token_uri"`
+}
+
+// createJWT builds an RS256-signed JWT assertion for the given service account key,
+// valid for one hour and scoped to the Google Cloud Platform API. It returns an
+// error when the private key is not a PEM-encoded PKCS#8 RSA key.
+func createJWT(key *ServiceAccountKey) (string, error) {
+	// Parse the PEM-encoded RSA private key.
+	block, _ := pem.Decode([]byte(key.PrivateKey))
+	if block == nil {
+		return "", fmt.Errorf("invalid PEM block")
+	}
+	parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
+	if err != nil {
+		return "", err
+	}
+	// PKCS#8 can also carry non-RSA keys; reject those instead of panicking.
+	rsaKey, ok := parsedKey.(*rsa.PrivateKey)
+	if !ok {
+		return "", fmt.Errorf("unsupported private key type %T, an RSA key is required", parsedKey)
+	}
+
+	// Build the JWT header.
+	jwtHeader := map[string]string{
+		"alg": "RS256",
+		"typ": "JWT",
+		"kid": key.PrivateKeyID,
+	}
+	headerJSON, err := json.Marshal(jwtHeader)
+	if err != nil {
+		return "", err
+	}
+	headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
+
+	// Build the JWT claims.
+	now := time.Now().Unix()
+	claims := map[string]interface{}{
+		"iss":   key.ClientEmail,
+		"scope": "https://www.googleapis.com/auth/cloud-platform",
+		"aud":   key.TokenURI,
+		"iat":   now,
+		"exp":   now + 3600, // valid for one hour
+	}
+	claimsJSON, err := json.Marshal(claims)
+	if err != nil {
+		return "", err
+	}
+	claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
+
+	signingInput := fmt.Sprintf("%s.%s", headerB64, claimsB64)
+	hashed := sha256.Sum256([]byte(signingInput))
+	signature, err := rsaKey.Sign(nil, hashed[:], crypto.SHA256)
+	if err != nil {
+		return "", err
+	}
+	sigB64 := base64.RawURLEncoding.EncodeToString(signature)
+
+	return fmt.Sprintf("%s.%s.%s", headerB64, claimsB64, sigB64), nil
+}
+
+// getAccessToken exchanges the signed JWT for an access token at the token
+// endpoint. The call is asynchronous: the callback sets the Authorization header,
+// caches the token together with its expiry time and resumes the paused request.
+func (v *vertexProvider) getAccessToken(jwtToken string) error {
+	headers := [][2]string{
+		{"Content-Type", "application/x-www-form-urlencoded"},
+	}
+	reqBody := "grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer&assertion=" + jwtToken
+	err := v.client.Post("/token", headers, []byte(reqBody), func(statusCode int, responseHeaders http.Header, responseBody []byte) {
+		responseString := string(responseBody)
+		defer func() {
+			_ = proxywasm.ResumeHttpRequest()
+		}()
+		if statusCode != http.StatusOK {
+			log.Errorf("failed to create vertex access key, status: %d body: %s", statusCode, responseString)
+			_ = util.ErrorHandler("ai-proxy.vertex.load_ak_failed", fmt.Errorf("failed to load vertex ak"))
+			return
+		}
+		responseJson := gjson.Parse(responseString)
+		accessToken := responseJson.Get("access_token").String()
+		if accessToken == "" {
+			// Never forward or cache an empty bearer token.
+			log.Errorf("[vertex]: token response does not contain an access_token")
+			_ = util.ErrorHandler("ai-proxy.vertex.load_ak_failed", fmt.Errorf("failed to load vertex ak"))
+			return
+		}
+		_ = proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+accessToken)
+
+		expiresIn := int64(3600)
+		if expiresInVal := responseJson.Get("expires_in"); expiresInVal.Exists() {
+			expiresIn = expiresInVal.Int()
+		}
+		expireTime := time.Now().Add(time.Duration(expiresIn) * time.Second).Unix()
+		if err := setCachedAccessToken(v.buildTokenKey(), accessToken, expireTime); err != nil {
+			log.Errorf("[vertex]: unable to cache access token: %v", err)
+		}
+	}, v.config.timeout)
+	return err
+}
+
+// buildTokenKey returns the shared-data key of the cached token for the configured
+// region and project.
+func (v *vertexProvider) buildTokenKey() string {
+	return fmt.Sprintf("vertex-%s-%s-access-token", v.config.vertexRegion, v.config.vertexProjectId)
+}
+
+// cachedAccessToken is the JSON form in which a token is kept in shared data.
+type cachedAccessToken struct {
+	Token    string `json:"token"`
+	ExpireAt int64  `json:"expireAt"`
+}
+
+// getCachedAccessToken returns the token stored under key, or "" when there is
+// none or when it expires within the configured refresh-ahead window.
+func (v *vertexProvider) getCachedAccessToken(key string) (string, error) {
+	data, _, err := proxywasm.GetSharedData(key)
+	if err != nil {
+		if errors.Is(err, types.ErrorStatusNotFound) {
+			return "", nil
+		}
+		return "", err
+	}
+	if len(data) == 0 {
+		return "", nil
+	}
+
+	var tokenInfo cachedAccessToken
+	if err = json.Unmarshal(data, &tokenInfo); err != nil {
+		return "", err
+	}
+
+	// Treat the token as expired refreshAhead seconds early so it is renewed in time.
+	refreshAhead := v.config.vertexTokenRefreshAhead
+	if tokenInfo.ExpireAt > time.Now().Unix()+refreshAhead {
+		return tokenInfo.Token, nil
+	}
+	return "", nil
+}
+
+// setCachedAccessToken stores the token and its expiry time (Unix seconds) in
+// shared data under key, using the entry's current CAS value for the write.
+func setCachedAccessToken(key string, accessToken string, expireTime int64) error {
+	tokenInfo := cachedAccessToken{
+		Token:    accessToken,
+		ExpireAt: expireTime,
+	}
+
+	_, cas, err := proxywasm.GetSharedData(key)
+	if err != nil && !errors.Is(err, types.ErrorStatusNotFound) {
+		return err
+	}
+
+	data, err := json.Marshal(tokenInfo)
+	if err != nil {
+		return err
+	}
+
+	return proxywasm.SetSharedData(key, data, cas)
+}