Files
higress/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go
2025-12-20 10:33:53 +08:00

874 lines
27 KiB
Go

package provider
import (
"crypto"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"net/http"
"regexp"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
// Hosts used by the Vertex provider.
const (
	// vertexAuthDomain is the OAuth2 host the service-account JWT is exchanged at.
	vertexAuthDomain = "oauth2.googleapis.com"
	// vertexDomain is the Vertex AI host; non-global regions prefix it with "<region>-".
	vertexDomain = "aiplatform.googleapis.com"
	// vertexGlobalRegion is the region that uses vertexDomain without a prefix.
	vertexGlobalRegion = "global"
)

// Request path templates, filled with (project, region, model, action).
const (
	// /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION}
	vertexPathTemplate          = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
	vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s"
)

// Model actions placed after the ':' of a request path.
const (
	vertexChatCompletionAction         = "generateContent"
	vertexChatCompletionStreamAction   = "streamGenerateContent?alt=sse"
	vertexAnthropicMessageAction       = "rawPredict"
	vertexAnthropicMessageStreamAction = "streamRawPredict"
	vertexEmbeddingAction              = "predict"
)

// Claude-on-Vertex support.
const (
	// contextClaudeMarker is the HttpContext key set to true for requests routed to a Claude model.
	contextClaudeMarker = "isClaudeRequest"
	// vertexAnthropicVersion is assigned to the Claude request's AnthropicVersion field.
	vertexAnthropicVersion = "vertex-2023-10-16"
)
// vertexProviderInitializer validates Vertex provider configuration and
// constructs vertexProvider instances.
type vertexProviderInitializer struct{}
// ValidateConfig rejects a config that lacks any setting the Vertex provider
// cannot work without: the auth key, region and project ID, and the name of
// the service used to reach the OAuth endpoint.
func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error {
	switch {
	case config.vertexAuthKey == "":
		return errors.New("missing vertexAuthKey in vertex provider config")
	case config.vertexRegion == "", config.vertexProjectId == "":
		return errors.New("missing vertexRegion or vertexProjectId in vertex provider config")
	case config.vertexAuthServiceName == "":
		return errors.New("missing vertexAuthServiceName in vertex provider config")
	default:
		return nil
	}
}
// DefaultCapabilities maps every API the provider supports (chat completions
// and embeddings) to the shared Vertex path template.
func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
	supported := []ApiName{ApiNameChatCompletion, ApiNameEmbeddings}
	capabilities := make(map[string]string, len(supported))
	for _, api := range supported {
		capabilities[string(api)] = vertexPathTemplate
	}
	return capabilities
}
// CreateProvider applies the default capabilities to config and returns a
// vertexProvider. Its HTTP client targets vertexAuthDomain on port 443 through
// the configured auth service. Claude models are handled by an embedded
// claudeProvider that gets the same config and its own context cache.
func (v *vertexProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
	config.setDefaultCapabilities(v.DefaultCapabilities())
	authClient := wrapper.NewClusterClient(wrapper.DnsCluster{
		Domain:      vertexAuthDomain,
		ServiceName: config.vertexAuthServiceName,
		Port:        443,
	})
	vertexCache := createContextCache(&config)
	claudeDelegate := &claudeProvider{
		config:       config,
		contextCache: createContextCache(&config),
	}
	return &vertexProvider{
		config:       config,
		client:       authClient,
		contextCache: vertexCache,
		claude:       claudeDelegate,
	}, nil
}
// vertexProvider implements the Vertex AI provider. Requests for models whose
// name starts with "claude" are converted by the embedded claude provider.
// All other models are converted to Vertex generateContent/predict payloads.
type vertexProvider struct {
	// client calls the OAuth2 /token endpoint (see CreateProvider / getAccessToken).
	client       wrapper.HttpClient
	config       ProviderConfig
	contextCache *contextCache
	// claude converts Claude-on-Vertex requests and responses.
	claude *claudeProvider
}
// GetProviderType identifies this provider as providerTypeVertex.
func (v *vertexProvider) GetProviderType() string { return providerTypeVertex }
// GetApiName infers the API from the request path suffix.
// generateContent and streamGenerateContent map to chat completions, and
// predict maps to embeddings. Any other path yields the empty ApiName.
func (v *vertexProvider) GetApiName(path string) ApiName {
	switch {
	case strings.HasSuffix(path, vertexChatCompletionAction),
		strings.HasSuffix(path, vertexChatCompletionStreamAction):
		return ApiNameChatCompletion
	case strings.HasSuffix(path, vertexEmbeddingAction):
		return ApiNameEmbeddings
	}
	return ""
}
// OnRequestHeaders hands header processing to the shared provider-config logic.
// It always succeeds.
func (v *vertexProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
	cfg := &v.config
	cfg.handleRequestHeaders(v, ctx, apiName)
	return nil
}
// TransformRequestHeaders points the Host header at the Vertex endpoint for the
// configured region. Regional endpoints are "<region>-aiplatform.googleapis.com";
// the global region uses the bare domain.
func (v *vertexProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
	host := vertexDomain
	if region := v.config.vertexRegion; region != vertexGlobalRegion {
		host = region + "-" + vertexDomain
	}
	util.OverwriteRequestHostHeader(headers, host)
}
// getToken makes sure the outgoing request will carry an OAuth2 bearer token.
//
// If shared data holds a token that is still valid, getToken writes it to the
// Authorization header and returns cached=true. Otherwise it builds a signed
// JWT from the service-account key in vertexAuthKey and dispatches an
// asynchronous token request. On success it returns cached=false and err=nil;
// the caller must then pause the request until the token callback resumes it.
// A non-nil err means no token request was dispatched.
func (v *vertexProvider) getToken() (cached bool, err error) {
	cacheKeyName := v.buildTokenKey()
	token, cacheErr := v.getCachedAccessToken(cacheKeyName)
	if cacheErr != nil {
		// An unreadable cache entry is not fatal: fall through and fetch a fresh token.
		log.Warnf("[vertex]: unable to read cached access token: %v", cacheErr)
	} else if token != "" {
		_ = proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+token)
		return true, nil
	}

	var key ServiceAccountKey
	if err := json.Unmarshal([]byte(v.config.vertexAuthKey), &key); err != nil {
		return false, fmt.Errorf("[vertex]: unable to unmarshal auth key json: %w", err)
	}
	if key.ClientEmail == "" || key.PrivateKey == "" || key.TokenURI == "" {
		return false, errors.New("[vertex]: missing auth params")
	}
	jwtToken, err := createJWT(&key)
	if err != nil {
		return false, fmt.Errorf("[vertex]: unable to create JWT token: %w", err)
	}
	if err := v.getAccessToken(jwtToken); err != nil {
		return false, fmt.Errorf("[vertex]: unable to get access token: %w", err)
	}
	// Token request dispatched; the callback sets the header and resumes the request.
	return false, nil
}
// OnRequestBody converts an OpenAI-style request body into the Vertex (or
// Claude-on-Vertex) format. It writes the rewritten headers and body back to
// the request and then makes sure an access token is attached.
//
// Return values:
//   - ActionContinue: the token came from cache, or the provider runs in original mode.
//   - ActionPause: an access token is being fetched asynchronously.
//   - ActionContinue with a non-nil error: the request failed.
//
// If the body conversion fails, the request is left untouched.
func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
	if !v.config.isSupportedAPI(apiName) {
		return types.ActionContinue, errUnsupportedApiName
	}
	if v.config.IsOriginal() {
		return types.ActionContinue, nil
	}
	headers := util.GetRequestHeaders()
	transformed, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
	if err != nil {
		// Bail out before writing anything back: on error the transformed body is
		// nil, and replacing the request body with it would wipe the client payload.
		return types.ActionContinue, err
	}
	headers.Set("Content-Length", fmt.Sprint(len(transformed)))
	util.ReplaceRequestHeaders(headers)
	_ = proxywasm.ReplaceHttpRequestBody(transformed)

	cached, err := v.getToken()
	switch {
	case err != nil:
		return types.ActionContinue, err
	case cached:
		return types.ActionContinue, nil
	default:
		// The token callback in getAccessToken resumes the request.
		return types.ActionPause, nil
	}
}
// TransformRequestBodyHeaders dispatches body and header conversion by API.
// Chat completions use the chat converter; every other API is treated as an
// embeddings request.
func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
	switch apiName {
	case ApiNameChatCompletion:
		return v.onChatCompletionRequestBody(ctx, body, headers)
	default:
		return v.onEmbeddingsRequestBody(ctx, body, headers)
	}
}
// onChatCompletionRequestBody parses an OpenAI chat request, applying model
// mapping, and re-encodes it for Vertex.
//
// Models whose name starts with "claude" are handled as Claude requests:
// the request is flagged in the context, routed to the Anthropic publisher
// path, and encoded as a Claude request with Model cleared and
// AnthropicVersion set to vertexAnthropicVersion. All other models become
// Vertex generateContent requests.
func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
	request := &chatCompletionRequest{}
	if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
		return nil, err
	}

	if !strings.HasPrefix(request.Model, "claude") {
		geminiPath := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
		util.OverwriteRequestPathHeader(headers, geminiPath)
		return json.Marshal(v.buildVertexChatRequest(request))
	}

	ctx.SetContext(contextClaudeMarker, true)
	claudePath := v.getAhthropicRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
	util.OverwriteRequestPathHeader(headers, claudePath)
	claudeRequest := v.claude.buildClaudeTextGenRequest(request)
	claudeRequest.Model = ""
	claudeRequest.AnthropicVersion = vertexAnthropicVersion
	return json.Marshal(claudeRequest)
}
// onEmbeddingsRequestBody parses an OpenAI embeddings request, applying model
// mapping. It points the path at the model's predict action and returns the
// Vertex embeddings payload.
func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
	var request embeddingsRequest
	err := v.config.parseRequestAndMapModel(ctx, &request, body)
	if err != nil {
		return nil, err
	}
	predictPath := v.getRequestPath(ApiNameEmbeddings, request.Model, false)
	util.OverwriteRequestPathHeader(headers, predictPath)
	return json.Marshal(v.buildEmbeddingRequest(&request))
}
// OnStreamingResponseBody converts a chunk of Vertex's SSE stream
// (streamGenerateContent?alt=sse) into OpenAI chat.completion.chunk events.
//
// Claude requests are delegated to the Claude provider. Non-chat APIs pass
// through untouched.
//
// Network chunks are not aligned with SSE lines. An unterminated trailing line
// is therefore kept in the HTTP context and prepended to the next chunk,
// instead of being dropped as unparseable JSON. When isLastChunk is set, the
// buffered remainder and the final chunk are converted, and the OpenAI
// "[DONE]" terminator is appended.
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
	if marker := ctx.GetContext(contextClaudeMarker); marker != nil && marker.(bool) {
		return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk)
	}
	// Per-chunk dumps are debug material; Info level floods logs on every stream.
	log.Debugf("[vertexProvider] receive chunk body: %s", string(chunk))
	if name != ApiNameChatCompletion {
		return chunk, nil
	}
	if len(chunk) == 0 && !isLastChunk {
		return nil, nil
	}

	// Context key holding the unterminated tail of the previous chunk.
	const pendingLineKey = "vertexStreamPendingLine"
	lines := strings.Split(ctx.GetStringContext(pendingLineKey, "")+string(chunk), "\n")
	pending := ""
	if !isLastChunk {
		// The last element has no newline yet ("" when the chunk ended on one).
		pending = lines[len(lines)-1]
		lines = lines[:len(lines)-1]
	}
	ctx.SetContext(pendingLineKey, pending)

	responseBuilder := &strings.Builder{}
	for _, line := range lines {
		line = strings.TrimSpace(line)
		if !strings.HasPrefix(line, "data:") {
			// Blank event separators and non-data SSE fields carry no payload.
			continue
		}
		data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		if data == "" {
			continue
		}
		var vertexResp vertexChatResponse
		if err := json.Unmarshal([]byte(data), &vertexResp); err != nil {
			log.Errorf("unable to unmarshal vertex response: %v", err)
			continue
		}
		response := v.buildChatCompletionStreamResponse(ctx, &vertexResp)
		responseBody, err := json.Marshal(response)
		if err != nil {
			log.Errorf("unable to marshal response: %v", err)
			return nil, err
		}
		v.appendResponse(responseBuilder, string(responseBody))
	}
	if isLastChunk {
		responseBuilder.WriteString(ssePrefix + "[DONE]\n\n")
	}
	modifiedResponseChunk := responseBuilder.String()
	log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
	return []byte(modifiedResponseChunk), nil
}
// TransformResponseBody converts a non-streaming Vertex response into the
// OpenAI format. Claude responses go to the Claude provider, chat completions
// to the chat converter, and anything else to the embeddings converter.
func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
	if marker := ctx.GetContext(contextClaudeMarker); marker != nil && marker.(bool) {
		return v.claude.TransformResponseBody(ctx, apiName, body)
	}
	if apiName != ApiNameChatCompletion {
		return v.onEmbeddingsResponseBody(ctx, body)
	}
	return v.onChatCompletionResponseBody(ctx, body)
}
// onChatCompletionResponseBody decodes a Vertex generateContent response and
// re-encodes it as an OpenAI chat.completion object.
func (v *vertexProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
	var vertexResponse vertexChatResponse
	if err := json.Unmarshal(body, &vertexResponse); err != nil {
		// %w keeps the underlying JSON error available to errors.Is/errors.As.
		return nil, fmt.Errorf("unable to unmarshal vertex chat response: %w", err)
	}
	return json.Marshal(v.buildChatCompletionResponse(ctx, &vertexResponse))
}
// buildChatCompletionResponse converts a Vertex generateContent response into an
// OpenAI chat.completion response. Each Vertex candidate becomes one choice.
//
// Every part of a candidate is inspected, not only the first:
//   - parts with thought=true are concatenated and wrapped in
//     reasoningStartTag/reasoningEndTag, ahead of the answer text;
//   - other text parts are concatenated into the message content;
//   - each functionCall part becomes one tool call, with its arguments
//     re-encoded as JSON.
//
// A candidate without parts gets an empty content string. Usage comes from
// usageMetadata.
func (v *vertexProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, response *vertexChatResponse) *chatCompletionResponse {
	fullTextResponse := chatCompletionResponse{
		Id:      response.ResponseId,
		Object:  objectChatCompletion,
		Created: time.Now().UnixMilli() / 1000,
		Model:   ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
		Choices: make([]chatCompletionChoice, 0, len(response.Candidates)),
		Usage: &usage{
			PromptTokens:     response.UsageMetadata.PromptTokenCount,
			CompletionTokens: response.UsageMetadata.CandidatesTokenCount,
			TotalTokens:      response.UsageMetadata.TotalTokenCount,
		},
	}
	for _, candidate := range response.Candidates {
		choice := chatCompletionChoice{
			Index:        candidate.Index,
			Message:      &chatMessage{Role: roleAssistant},
			FinishReason: util.Ptr(candidate.FinishReason),
		}
		if len(candidate.Content.Parts) == 0 {
			choice.Message.Content = ""
		}
		var reasoning, answer strings.Builder
		hasReasoning := false
		for _, part := range candidate.Content.Parts {
			switch {
			case part.FunctionCall != nil:
				// Args was decoded from JSON, so re-encoding it cannot fail.
				args, _ := json.Marshal(part.FunctionCall.Args)
				choice.Message.ToolCalls = append(choice.Message.ToolCalls, toolCall{
					Type: "function",
					Function: functionCall{
						Name:      part.FunctionCall.Name,
						Arguments: string(args),
					},
				})
			case part.Thounght != nil && *part.Thounght:
				hasReasoning = true
				reasoning.WriteString(part.Text)
			default:
				answer.WriteString(part.Text)
			}
		}
		switch {
		case hasReasoning:
			choice.Message.Content = reasoningStartTag + reasoning.String() + reasoningEndTag + answer.String()
		case answer.Len() > 0:
			choice.Message.Content = answer.String()
		}
		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
	}
	return &fullTextResponse
}
// onEmbeddingsResponseBody decodes a Vertex :predict embeddings response and
// re-encodes it as an OpenAI embeddings list.
func (v *vertexProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
	var vertexResponse vertexEmbeddingResponse
	if err := json.Unmarshal(body, &vertexResponse); err != nil {
		// %w keeps the underlying JSON error available to errors.Is/errors.As.
		return nil, fmt.Errorf("unable to unmarshal vertex embeddings response: %w", err)
	}
	return json.Marshal(v.buildEmbeddingsResponse(ctx, &vertexResponse))
}
// buildEmbeddingsResponse converts Vertex embedding predictions into an OpenAI
// embeddings list.
//
// Each embedding's Index is the position of its prediction. Predictions are
// assumed to come back in input order (TODO confirm against the Vertex API).
// Usage.TotalTokens is the sum of the per-prediction token counts that Vertex
// reports. The model name is read from the context; it is "" if the name was
// never stored.
func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertexResp *vertexEmbeddingResponse) *embeddingsResponse {
	response := embeddingsResponse{
		Object: "list",
		Data:   make([]embedding, 0, len(vertexResp.Predictions)),
		// GetStringContext avoids the panic of a bare .(string) assertion on a missing value.
		Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
	}
	totalTokens := 0
	for i, item := range vertexResp.Predictions {
		response.Data = append(response.Data, embedding{
			Object: `embedding`,
			// Clients map results back to their inputs by index, so it must not be constant.
			Index:     i,
			Embedding: item.Embeddings.Values,
		})
		if stats := item.Embeddings.Statistics; stats != nil {
			totalTokens += stats.TokenCount
		}
	}
	response.Usage.TotalTokens = totalTokens
	return &response
}
// buildChatCompletionStreamResponse converts one streamed Vertex response event
// into an OpenAI chat.completion.chunk.
//
// All parts of the first candidate are folded into a single delta:
//   - functionCall parts become tool calls, with arguments re-encoded as JSON;
//   - thought parts are streamed as content. reasoningStartTag is emitted
//     before the first thought of the stream (tracked by the "thinking_start"
//     context flag);
//   - the first non-empty text part after thinking has started is prefixed
//     with reasoningEndTag (tracked by the "thinking_end" context flag);
//   - other non-empty text parts are appended unchanged.
//
// The candidate's finishReason, when present, is forwarded so clients can see
// why the stream ended. Usage comes from usageMetadata.
func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
	choice := chatCompletionChoice{Delta: &chatMessage{}}
	if len(vertexResp.Candidates) > 0 {
		candidate := vertexResp.Candidates[0]
		var content strings.Builder
		for _, part := range candidate.Content.Parts {
			switch {
			case part.FunctionCall != nil:
				// Args was decoded from JSON, so re-encoding it cannot fail.
				args, _ := json.Marshal(part.FunctionCall.Args)
				choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, toolCall{
					Type: "function",
					Function: functionCall{
						Name:      part.FunctionCall.Name,
						Arguments: string(args),
					},
				})
			case part.Thounght != nil && *part.Thounght:
				if ctx.GetContext("thinking_start") == nil {
					content.WriteString(reasoningStartTag)
					ctx.SetContext("thinking_start", true)
				}
				content.WriteString(part.Text)
			case part.Text != "":
				if ctx.GetContext("thinking_start") != nil && ctx.GetContext("thinking_end") == nil {
					content.WriteString(reasoningEndTag)
					ctx.SetContext("thinking_end", true)
				}
				content.WriteString(part.Text)
			}
		}
		if content.Len() > 0 {
			choice.Delta.Content = content.String()
		}
		if candidate.FinishReason != "" {
			choice.FinishReason = util.Ptr(candidate.FinishReason)
		}
	}
	streamResponse := chatCompletionResponse{
		Id:      vertexResp.ResponseId,
		Object:  objectChatCompletionChunk,
		Created: time.Now().UnixMilli() / 1000,
		Model:   ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
		Choices: []chatCompletionChoice{choice},
		Usage: &usage{
			PromptTokens:     vertexResp.UsageMetadata.PromptTokenCount,
			CompletionTokens: vertexResp.UsageMetadata.CandidatesTokenCount,
			TotalTokens:      vertexResp.UsageMetadata.TotalTokenCount,
		},
	}
	return &streamResponse
}
// appendResponse writes one SSE event to responseBuilder: streamDataItemKey,
// a space, the JSON body, then a blank line.
func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
	fmt.Fprintf(responseBuilder, "%s %s\n\n", streamDataItemKey, responseBody)
}
// getAhthropicRequestPath builds the Anthropic-publisher path for a Claude model.
// It uses streamRawPredict when streaming and rawPredict otherwise. apiName is
// currently unused.
func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string, stream bool) string {
	action := vertexAnthropicMessageAction
	if stream {
		action = vertexAnthropicMessageStreamAction
	}
	return fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
}
// getRequestPath builds the Google-publisher path for a model. The action is:
//   - predict for embeddings;
//   - streamGenerateContent?alt=sse for streaming chat;
//   - generateContent otherwise.
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
	var action string
	switch {
	case apiName == ApiNameEmbeddings:
		action = vertexEmbeddingAction
	case stream:
		action = vertexChatCompletionStreamAction
	default:
		action = vertexChatCompletionAction
	}
	return fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
}
// buildVertexChatRequest converts an OpenAI-style chat completion request into a
// Vertex generateContent request.
//
// Mapping:
//   - Each geminiSafetySetting entry becomes a safetySettings item. Their order
//     follows Go map iteration and is therefore not deterministic.
//   - Temperature, TopP and MaxTokens go into generationConfig.
//   - A non-empty ReasoningEffort enables thoughts with a fixed ThinkingBudget:
//     low=1024, medium=4096, high=16384, any other value=1024.
//   - Every tool's Function becomes one entry of a single functionDeclarations list.
//   - Messages become contents, with roles remapped to "user"/"model"
//     (see the switch inside the loop).
//
// NOTE(review):
//   - Only the first tool call of an assistant message is forwarded.
//   - Tool results are attributed to the most recently seen function name,
//     not matched by tool call id, so parallel tool calls are not represented
//     faithfully.
//   - vertexChatRequest.SystemInstruction is not used; system prompts are sent
//     as user turns.
func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest {
	safetySettings := make([]vertexChatSafetySetting, 0)
	for category, threshold := range v.config.geminiSafetySetting {
		safetySettings = append(safetySettings, vertexChatSafetySetting{
			Category:  category,
			Threshold: threshold,
		})
	}
	vertexRequest := vertexChatRequest{
		Contents:       make([]vertexChatContent, 0),
		SafetySettings: safetySettings,
		GenerationConfig: vertexChatGenerationConfig{
			Temperature:     request.Temperature,
			TopP:            request.TopP,
			MaxOutputTokens: request.MaxTokens,
		},
	}
	if request.ReasoningEffort != "" {
		thinkingBudget := 1024 // default, also used for unrecognized effort values
		switch request.ReasoningEffort {
		case "low":
			thinkingBudget = 1024
		case "medium":
			thinkingBudget = 4096
		case "high":
			thinkingBudget = 16384
		}
		vertexRequest.GenerationConfig.ThinkingConfig = vertexThinkingConfig{
			IncludeThoughts: true,
			ThinkingBudget:  thinkingBudget,
		}
	}
	if request.Tools != nil {
		functions := make([]function, 0, len(request.Tools))
		for _, tool := range request.Tools {
			functions = append(functions, tool.Function)
		}
		vertexRequest.Tools = []vertexTool{
			{
				FunctionDeclarations: functions,
			},
		}
	}
	shouldAddDummyModelMessage := false
	// lastFunctionName is the name of the most recent assistant tool call; the
	// next tool-role message's results are attributed to it.
	var lastFunctionName string
	for _, message := range request.Messages {
		content := vertexChatContent{
			Role:  message.Role,
			Parts: []vertexPart{},
		}
		if len(message.ToolCalls) > 0 {
			// Only the first tool call is forwarded. Any other content of this
			// message is not converted.
			lastFunctionName = message.ToolCalls[0].Function.Name
			args := make(map[string]interface{})
			// Malformed arguments are logged, and the call is sent with whatever was decoded.
			if err := json.Unmarshal([]byte(message.ToolCalls[0].Function.Arguments), &args); err != nil {
				log.Errorf("unable to unmarshal function arguments: %v", err)
			}
			content.Parts = append(content.Parts, vertexPart{
				FunctionCall: &vertexFunctionCall{
					Name: lastFunctionName,
					Args: args,
				},
			})
		} else {
			for _, part := range message.ParseContent() {
				switch part.Type {
				case contentTypeText:
					if message.Role == roleTool {
						// Tool output is sent back as a functionResponse for the last called function.
						content.Parts = append(content.Parts, vertexPart{
							FunctionResponse: &vertexFunctionResponse{
								Name: lastFunctionName,
								Response: vertexFunctionResponseDetail{
									Output: part.Text,
								},
							},
						})
					} else {
						content.Parts = append(content.Parts, vertexPart{
							Text: part.Text,
						})
					}
				case contentTypeImageUrl:
					// Images that cannot be converted are logged and skipped.
					vpart, err := convertImageContent(part.ImageUrl.Url)
					if err != nil {
						log.Errorf("unable to convert image content: %v", err)
					} else {
						content.Parts = append(content.Parts, vpart)
					}
				}
			}
		}
		// Vertex has no "assistant" role and rejects roles other than "user" and
		// "model": assistant -> model, tool results -> user, system -> user.
		switch content.Role {
		case roleAssistant:
			content.Role = "model"
		case roleTool:
			content.Role = roleUser
		case roleSystem: // system prompts are sent as user turns for the same reason
			content.Role = roleUser
			shouldAddDummyModelMessage = true
		}
		vertexRequest.Contents = append(vertexRequest.Contents, content)
		// After EVERY converted system message (not only a trailing one), insert a
		// placeholder "Okay" model turn. Presumably this keeps the converted system
		// prompt from running straight into the next user turn.
		// NOTE(review): confirm Vertex still requires this.
		if shouldAddDummyModelMessage {
			vertexRequest.Contents = append(vertexRequest.Contents, vertexChatContent{
				Role: "model",
				Parts: []vertexPart{
					{
						Text: "Okay",
					},
				},
			})
			shouldAddDummyModelMessage = false
		}
	}
	return &vertexRequest
}
// buildEmbeddingRequest creates one Vertex embedding instance per parsed input
// string. TaskType and Title are left empty, and no request parameters are set.
func (v *vertexProvider) buildEmbeddingRequest(request *embeddingsRequest) *vertexEmbeddingRequest {
	inputs := request.ParseInput()
	out := &vertexEmbeddingRequest{
		Instances: make([]vertexEmbeddingInstance, 0, len(inputs)),
	}
	for _, text := range inputs {
		out.Instances = append(out.Instances, vertexEmbeddingInstance{Content: text})
	}
	return out
}
// vertexChatRequest is the JSON body sent to Vertex generateContent and
// streamGenerateContent. buildVertexChatRequest fills Contents, Tools,
// SafetySettings and GenerationConfig.
//
// NOTE(review): omitempty has no effect on the non-pointer struct field
// GenerationConfig, so "generationConfig" is always serialized.
type vertexChatRequest struct {
	CachedContent     string                     `json:"cachedContent,omitempty"`
	Contents          []vertexChatContent        `json:"contents"`
	SystemInstruction *vertexSystemInstruction   `json:"systemInstruction,omitempty"`
	Tools             []vertexTool               `json:"tools,omitempty"`
	SafetySettings    []vertexChatSafetySetting  `json:"safetySettings,omitempty"`
	GenerationConfig  vertexChatGenerationConfig `json:"generationConfig,omitempty"`
	Labels            map[string]string          `json:"labels,omitempty"`
}

// vertexChatContent is one conversation turn: a role plus its parts. It is used
// both in requests and in response candidates.
type vertexChatContent struct {
	// The producer of the content. Must be either 'user' or 'model'.
	Role  string       `json:"role,omitempty"`
	Parts []vertexPart `json:"parts"`
}

// vertexPart is a single piece of content. The converters in this file set one
// field per part. Thounght (sic; serialized as "thought") marks reasoning
// output in responses.
type vertexPart struct {
	Text             string                  `json:"text,omitempty"`
	InlineData       *blob                   `json:"inlineData,omitempty"`
	FileData         *fileData               `json:"fileData,omitempty"`
	FunctionCall     *vertexFunctionCall     `json:"functionCall,omitempty"`
	FunctionResponse *vertexFunctionResponse `json:"functionResponse,omitempty"`
	Thounght         *bool                   `json:"thought,omitempty"`
}

// blob is inline binary content. Data carries the base64 payload taken from a
// data URL (see convertImageContent).
type blob struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}

// fileData references content by URI instead of embedding it.
type fileData struct {
	MimeType string `json:"mimeType"`
	FileUri  string `json:"fileUri"`
}

// vertexFunctionCall is a model-issued (or replayed) function call with decoded arguments.
type vertexFunctionCall struct {
	Name string                 `json:"name"`
	Args map[string]interface{} `json:"args,omitempty"`
}

// vertexFunctionResponse returns a tool result to the model for the named function.
type vertexFunctionResponse struct {
	Name     string                       `json:"name"`
	Response vertexFunctionResponseDetail `json:"response"`
}

// vertexFunctionResponseDetail carries the tool output (or an error string).
type vertexFunctionResponseDetail struct {
	Output string `json:"output,omitempty"`
	Error  string `json:"error,omitempty"`
}

// vertexSystemInstruction is Vertex's dedicated system-prompt field. It is not
// currently filled by buildVertexChatRequest.
type vertexSystemInstruction struct {
	Role  string       `json:"role"`
	Parts []vertexPart `json:"parts"`
}

// vertexTool wraps the function declarations offered to the model.
// buildVertexChatRequest stores a []function here.
type vertexTool struct {
	FunctionDeclarations any `json:"functionDeclarations"`
}

// vertexChatSafetySetting is one category/threshold pair taken from geminiSafetySetting.
type vertexChatSafetySetting struct {
	Category  string `json:"category"`
	Threshold string `json:"threshold"`
}

// vertexChatGenerationConfig holds sampling and length options.
//
// NOTE(review): because of omitempty, an explicit Temperature or TopP of 0 is
// dropped from the request. ThinkingConfig is a non-pointer struct, so it is
// always serialized (as {} when unset).
type vertexChatGenerationConfig struct {
	Temperature     float64              `json:"temperature,omitempty"`
	TopP            float64              `json:"topP,omitempty"`
	TopK            int                  `json:"topK,omitempty"`
	CandidateCount  int                  `json:"candidateCount,omitempty"`
	MaxOutputTokens int                  `json:"maxOutputTokens,omitempty"`
	ThinkingConfig  vertexThinkingConfig `json:"thinkingConfig,omitempty"`
}

// vertexThinkingConfig enables thought output. ThinkingBudget is presumably a
// token budget; confirm against the Vertex API docs.
type vertexThinkingConfig struct {
	IncludeThoughts bool `json:"includeThoughts,omitempty"`
	ThinkingBudget  int  `json:"thinkingBudget,omitempty"`
}
// vertexEmbeddingRequest is the JSON body of a Vertex text-embedding :predict call.
type vertexEmbeddingRequest struct {
	Instances  []vertexEmbeddingInstance `json:"instances"`
	Parameters *vertexEmbeddingParams    `json:"parameters,omitempty"`
}

// vertexEmbeddingInstance is one text to embed. TaskType and Title are optional
// and are omitted when empty. Previously an empty "task_type" was always sent,
// even though no caller sets it.
type vertexEmbeddingInstance struct {
	TaskType string `json:"task_type,omitempty"`
	Title    string `json:"title,omitempty"`
	Content  string `json:"content"`
}

// vertexEmbeddingParams holds optional request-level embedding parameters.
type vertexEmbeddingParams struct {
	AutoTruncate bool `json:"autoTruncate,omitempty"`
}
// vertexChatResponse is the JSON body of a generateContent response. It is also
// the payload of each SSE event in a streamGenerateContent response.
type vertexChatResponse struct {
	Candidates     []vertexChatCandidate    `json:"candidates"`
	ResponseId     string                   `json:"responseId,omitempty"`
	PromptFeedback vertexChatPromptFeedback `json:"promptFeedback"`
	UsageMetadata  vertexUsageMetadata      `json:"usageMetadata"`
}

// vertexChatCandidate is one generated answer. FinishReason is forwarded as-is
// to the OpenAI finish_reason.
type vertexChatCandidate struct {
	Content       vertexChatContent        `json:"content"`
	FinishReason  string                   `json:"finishReason"`
	Index         int                      `json:"index"`
	SafetyRatings []vertexChatSafetyRating `json:"safetyRatings"`
}

// vertexChatSafetyRating is a per-category safety assessment.
type vertexChatSafetyRating struct {
	Category    string `json:"category"`
	Probability string `json:"probability"`
}

// vertexChatPromptFeedback carries safety ratings for the prompt itself.
type vertexChatPromptFeedback struct {
	SafetyRatings []vertexChatSafetyRating `json:"safetyRatings"`
}

// vertexUsageMetadata is token accounting, mapped to OpenAI usage
// (prompt, completion, total).
type vertexUsageMetadata struct {
	PromptTokenCount     int `json:"promptTokenCount,omitempty"`
	CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
	TotalTokenCount      int `json:"totalTokenCount,omitempty"`
}

// vertexEmbeddingResponse is the JSON body of a text-embedding :predict
// response, with one prediction per instance.
type vertexEmbeddingResponse struct {
	Predictions []vertexPredictions `json:"predictions"`
}

// vertexPredictions holds one embedding vector and optional statistics.
type vertexPredictions struct {
	Embeddings struct {
		Values     []float64         `json:"values"`
		Statistics *vertexStatistics `json:"statistics,omitempty"`
	} `json:"embeddings"`
}

// vertexStatistics reports the token count of an embedded input and whether it was truncated.
type vertexStatistics struct {
	TokenCount int  `json:"token_count"`
	Truncated  bool `json:"truncated"`
}
// ServiceAccountKey holds the fields of a Google service-account key JSON (the
// vertexAuthKey setting) that are needed to sign the OAuth2 JWT assertion.
// PrivateKey is the PEM-encoded private key.
type ServiceAccountKey struct {
	ClientEmail  string `json:"client_email"`
	PrivateKeyID string `json:"private_key_id"`
	PrivateKey   string `json:"private_key"`
	TokenURI     string `json:"token_uri"`
}
// createJWT builds the RS256-signed JWT assertion that a service account uses in
// the OAuth2 JWT-bearer flow.
//
// Header: the key id. Claims: iss=ClientEmail, aud=TokenURI, the cloud-platform
// scope, and a one-hour lifetime starting now.
//
// The private key must be a PEM-encoded RSA key in either form:
//   - PKCS#8 ("BEGIN PRIVATE KEY");
//   - PKCS#1 ("BEGIN RSA PRIVATE KEY").
//
// A missing PEM block, an unparsable key, or a non-RSA key yields an error.
// A non-RSA key previously caused a panic.
func createJWT(key *ServiceAccountKey) (string, error) {
	rsaKey, err := parseVertexRSAPrivateKey(key.PrivateKey)
	if err != nil {
		return "", err
	}

	headerJSON, err := json.Marshal(map[string]string{
		"alg": "RS256",
		"typ": "JWT",
		"kid": key.PrivateKeyID,
	})
	if err != nil {
		return "", fmt.Errorf("encoding JWT header: %w", err)
	}
	now := time.Now().Unix()
	claimsJSON, err := json.Marshal(map[string]interface{}{
		"iss":   key.ClientEmail,
		"scope": "https://www.googleapis.com/auth/cloud-platform",
		"aud":   key.TokenURI,
		"iat":   now,
		"exp":   now + 3600, // one-hour validity
	})
	if err != nil {
		return "", fmt.Errorf("encoding JWT claims: %w", err)
	}

	enc := base64.RawURLEncoding
	signingInput := enc.EncodeToString(headerJSON) + "." + enc.EncodeToString(claimsJSON)
	hashed := sha256.Sum256([]byte(signingInput))
	// rand is nil, as before; PKCS#1 v1.5 signatures are deterministic.
	signature, err := rsaKey.Sign(nil, hashed[:], crypto.SHA256)
	if err != nil {
		return "", err
	}
	return signingInput + "." + enc.EncodeToString(signature), nil
}

// parseVertexRSAPrivateKey decodes the first PEM block of pemData and returns the
// RSA private key it contains. It accepts PKCS#8 and PKCS#1 encodings.
func parseVertexRSAPrivateKey(pemData string) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode([]byte(pemData))
	if block == nil {
		return nil, fmt.Errorf("invalid PEM block")
	}
	parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		// Fall back to PKCS#1, which some tools emit for RSA keys.
		if rsaKey, pkcs1Err := x509.ParsePKCS1PrivateKey(block.Bytes); pkcs1Err == nil {
			return rsaKey, nil
		}
		return nil, err
	}
	rsaKey, ok := parsed.(*rsa.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("unsupported private key type %T, expected RSA", parsed)
	}
	return rsaKey, nil
}
// getAccessToken exchanges the signed JWT for an OAuth2 access token. It POSTs
// a jwt-bearer grant to the auth service's /token endpoint.
//
// The exchange is asynchronous, and the callback always resumes the paused
// request. A non-200 status, or a 200 response without access_token, is
// reported through util.ErrorHandler. Otherwise the callback:
//   - sets the Authorization header;
//   - caches the token until now+expires_in seconds (3600 when expires_in is
//     missing or not positive).
//
// The returned error only reflects dispatching the call.
func (v *vertexProvider) getAccessToken(jwtToken string) error {
	headers := [][2]string{
		{"Content-Type", "application/x-www-form-urlencoded"},
	}
	reqBody := "grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer&assertion=" + jwtToken
	return v.client.Post("/token", headers, []byte(reqBody), func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		// The request was paused waiting for this callback; resume it on every path.
		defer func() {
			_ = proxywasm.ResumeHttpRequest()
		}()
		responseString := string(responseBody)
		if statusCode != http.StatusOK {
			log.Errorf("failed to create vertex access key, status: %d body: %s", statusCode, responseString)
			_ = util.ErrorHandler("ai-proxy.vertex.load_ak_failed", fmt.Errorf("failed to load vertex ak"))
			return
		}
		responseJson := gjson.Parse(responseString)
		accessToken := responseJson.Get("access_token").String()
		if accessToken == "" {
			// Without this check the request would go upstream as "Bearer " and fail
			// with an opaque 401. The body is not logged because it may hold other credentials.
			log.Errorf("vertex token response (status %d) has no access_token", statusCode)
			_ = util.ErrorHandler("ai-proxy.vertex.load_ak_failed", fmt.Errorf("failed to load vertex ak"))
			return
		}
		_ = proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+accessToken)
		expiresIn := int64(3600)
		if expiresInVal := responseJson.Get("expires_in"); expiresInVal.Exists() && expiresInVal.Int() > 0 {
			expiresIn = expiresInVal.Int()
		}
		expireTime := time.Now().Add(time.Duration(expiresIn) * time.Second).Unix()
		if err := setCachedAccessToken(v.buildTokenKey(), accessToken, expireTime); err != nil {
			log.Errorf("[vertex]: unable to cache access token: %v", err)
		}
	}, v.config.timeout)
}
// buildTokenKey returns the shared-data key under which the access token for
// this provider's region and project is cached:
// "vertex-<region>-<project>-access-token".
func (v *vertexProvider) buildTokenKey() string {
	return "vertex-" + v.config.vertexRegion + "-" + v.config.vertexProjectId + "-access-token"
}
// cachedAccessToken is the JSON value stored in proxy shared data for a cached
// OAuth2 access token. ExpireAt is a Unix timestamp in seconds.
type cachedAccessToken struct {
	Token    string `json:"token"`
	ExpireAt int64  `json:"expireAt"`
}
// getCachedAccessToken reads the access token cached in proxy shared data under key.
//
// It returns "" and a nil error when nothing is cached, or when the cached
// token expires within vertexTokenRefreshAhead seconds; the caller then
// fetches a fresh token. Shared-data read failures (other than not-found) and
// JSON decode failures are returned as errors.
func (v *vertexProvider) getCachedAccessToken(key string) (string, error) {
	raw, _, err := proxywasm.GetSharedData(key)
	switch {
	case errors.Is(err, types.ErrorStatusNotFound):
		return "", nil
	case err != nil:
		return "", err
	case raw == nil:
		return "", nil
	}
	var entry cachedAccessToken
	if err := json.Unmarshal(raw, &entry); err != nil {
		return "", err
	}
	refreshDeadline := time.Now().Unix() + v.config.vertexTokenRefreshAhead
	if entry.ExpireAt <= refreshDeadline {
		return "", nil
	}
	return entry.Token, nil
}
func setCachedAccessToken(key string, accessToken string, expireTime int64) error {
tokenInfo := cachedAccessToken{
Token: accessToken,
ExpireAt: expireTime,
}
_, cas, err := proxywasm.GetSharedData(key)
if err != nil && !errors.Is(err, types.ErrorStatusNotFound) {
return err
}
data, err := json.Marshal(tokenInfo)
if err != nil {
return err
}
return proxywasm.SetSharedData(key, data, cas)
}
// vertexImageDataURLRegexp matches the "data:<mime>;base64," prefix of an inline
// image data URL and captures the MIME type. It is compiled once instead of on
// every call.
var vertexImageDataURLRegexp = regexp.MustCompile(`^data:([^;]+);base64,`)

// convertImageContent converts an OpenAI image_url value into a Vertex part.
//
//   - Remote references (http/https URLs and gs:// Cloud Storage URIs) become
//     fileData parts. Their MIME type is guessed from the path's file extension.
//   - Base64 data URLs ("data:image/png;base64,...") become inlineData parts.
//
// Any other value, or a data URL whose MIME type has no "/", is rejected with
// an error.
func convertImageContent(imageUrl string) (vertexPart, error) {
	part := vertexPart{}
	if strings.HasPrefix(imageUrl, "http") || strings.HasPrefix(imageUrl, "gs://") {
		part.FileData = &fileData{
			MimeType: vertexImageMimeTypeFromURL(imageUrl),
			FileUri:  imageUrl,
		}
		return part, nil
	}
	matches := vertexImageDataURLRegexp.FindStringSubmatch(imageUrl)
	if len(matches) < 2 {
		return part, fmt.Errorf("invalid base64 format")
	}
	mimeType := matches[1] // e.g. image/png
	if !strings.Contains(mimeType, "/") {
		return part, fmt.Errorf("invalid mimeType")
	}
	part.InlineData = &blob{
		MimeType: mimeType,
		Data:     strings.TrimPrefix(imageUrl, matches[0]),
	}
	return part, nil
}

// vertexImageMimeTypeFromURL guesses an image MIME type from the extension of a
// URL's path. The scheme, host, query string and fragment are ignored, and the
// extension is lower-cased.
//
// "jpg" is normalized to the registered "image/jpeg". A path without an
// extension also falls back to "image/jpeg"; that fallback is a guess
// (NOTE(review)).
func vertexImageMimeTypeFromURL(rawURL string) string {
	p := rawURL
	if i := strings.IndexAny(p, "?#"); i >= 0 {
		p = p[:i]
	}
	if i := strings.Index(p, "://"); i >= 0 {
		p = p[i+len("://"):]
		if slash := strings.Index(p, "/"); slash >= 0 {
			p = p[slash:]
		} else {
			p = ""
		}
	}
	ext := ""
	if dot := strings.LastIndex(p, "."); dot > strings.LastIndex(p, "/") {
		ext = strings.ToLower(p[dot+1:])
	}
	switch ext {
	case "", "jpg", "jpeg":
		return "image/jpeg"
	default:
		return "image/" + ext
	}
}