fix(vertex): add API Key auth for Vertex Raw Express Mode and fix tok… (#3695)

This commit is contained in:
woody
2026-04-10 09:55:27 +08:00
committed by GitHub
parent 2c15f97246
commit bf96860a78
3 changed files with 460 additions and 30 deletions

View File

@@ -11,6 +11,7 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
"time"
@@ -224,6 +225,34 @@ func (v *vertexProvider) getToken() (cached bool, err error) {
return false, err
}
func appendOrReplaceAPIKey(path, apiKey string) string {
if apiKey == "" {
return path
}
parsedPath, err := url.ParseRequestURI(path)
if err != nil {
// Fallback to simple append when path is not parseable.
if strings.Contains(path, "?") {
return path + "&key=" + apiKey
}
return path + "?key=" + apiKey
}
query := parsedPath.Query()
query.Set("key", apiKey)
parsedPath.RawQuery = query.Encode()
return parsedPath.RequestURI()
}
func (v *vertexProvider) getExpressAPIKey(ctx wrapper.HttpContext) string {
apiKey := v.config.GetApiTokenInUse(ctx)
if apiKey == "" {
apiKey = v.config.GetRandomToken()
}
return apiKey
}
func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !v.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
@@ -234,8 +263,14 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
// 注意:此检查必须在 IsOriginal() 之前,因为 Vertex Raw 模式通常与 original 协议一起使用
if apiName == ApiNameVertexRaw {
ctx.SetContext(contextVertexRawMarker, true)
// Express Mode 不需要 OAuth 认证
// Express Mode: 将 API Key 追加到 URL query 参数中
if v.isExpressMode() {
headers := util.GetRequestHeaders()
path := headers.Get(":path")
path = appendOrReplaceAPIKey(path, v.getExpressAPIKey(ctx))
util.OverwriteRequestPathHeader(headers, path)
headers.Del("Authorization")
util.ReplaceRequestHeaders(headers)
return types.ActionContinue, nil
}
// 标准模式需要获取 OAuth token
@@ -354,7 +389,7 @@ func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo
}
if strings.HasPrefix(request.Model, "claude") {
ctx.SetContext(contextClaudeMarker, true)
path := v.getAhthropicRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
path := v.getAhthropicRequestPath(ctx, ApiNameChatCompletion, request.Model, request.Stream)
util.OverwriteRequestPathHeader(headers, path)
claudeRequest := v.claude.buildClaudeTextGenRequest(request)
@@ -366,7 +401,7 @@ func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo
}
return claudeBody, nil
} else {
path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
path := v.getRequestPath(ctx, ApiNameChatCompletion, request.Model, request.Stream)
util.OverwriteRequestPathHeader(headers, path)
vertexRequest, err := v.buildVertexChatRequest(request)
@@ -382,7 +417,7 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
return nil, err
}
path := v.getRequestPath(ApiNameEmbeddings, request.Model, false)
path := v.getRequestPath(ctx, ApiNameEmbeddings, request.Model, false)
util.OverwriteRequestPathHeader(headers, path)
vertexRequest := v.buildEmbeddingRequest(request)
@@ -395,7 +430,7 @@ func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, b
return nil, err
}
// 图片生成不使用流式端点,需要完整响应
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
path := v.getRequestPath(ctx, ApiNameImageGeneration, request.Model, false)
util.OverwriteRequestPathHeader(headers, path)
vertexRequest, err := v.buildVertexImageGenerationRequest(request)
@@ -442,7 +477,7 @@ func (v *vertexProvider) onImageEditRequestBody(ctx wrapper.HttpContext, body []
return nil, fmt.Errorf("missing prompt in request")
}
path := v.getRequestPath(ApiNameImageEdit, request.Model, false)
path := v.getRequestPath(ctx, ApiNameImageEdit, request.Model, false)
util.OverwriteRequestPathHeader(headers, path)
headers.Set("Content-Type", util.MimeTypeApplicationJson)
vertexRequest, err := v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, imageURLs)
@@ -485,7 +520,7 @@ func (v *vertexProvider) onImageVariationRequestBody(ctx wrapper.HttpContext, bo
prompt = vertexImageVariationDefaultPrompt
}
path := v.getRequestPath(ApiNameImageVariation, request.Model, false)
path := v.getRequestPath(ctx, ApiNameImageVariation, request.Model, false)
util.OverwriteRequestPathHeader(headers, path)
headers.Set("Content-Type", util.MimeTypeApplicationJson)
vertexRequest, err := v.buildVertexImageRequest(prompt, request.Size, request.OutputFormat, imageURLs)
@@ -909,7 +944,7 @@ func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, respon
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string, stream bool) string {
func (v *vertexProvider) getAhthropicRequestPath(ctx wrapper.HttpContext, apiName ApiName, modelId string, stream bool) string {
action := ""
if stream {
action = vertexAnthropicMessageStreamAction
@@ -920,22 +955,15 @@ func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string
if v.isExpressMode() {
// Express Mode: 简化路径 + API Key 参数
basePath := fmt.Sprintf(vertexExpressPathAnthropicTemplate, modelId, action)
apiKey := v.config.GetRandomToken()
// 如果 action 已经包含 ?,使用 & 拼接
var fullPath string
if strings.Contains(action, "?") {
fullPath = basePath + "&key=" + apiKey
} else {
fullPath = basePath + "?key=" + apiKey
}
return fullPath
apiKey := v.getExpressAPIKey(ctx)
return appendOrReplaceAPIKey(basePath, apiKey)
}
path := fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
return path
}
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
func (v *vertexProvider) getRequestPath(ctx wrapper.HttpContext, apiName ApiName, modelId string, stream bool) string {
action := ""
switch apiName {
case ApiNameEmbeddings:
@@ -954,15 +982,8 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
if v.isExpressMode() {
// Express Mode: 简化路径 + API Key 参数
basePath := fmt.Sprintf(vertexExpressPathTemplate, modelId, action)
apiKey := v.config.GetRandomToken()
// 如果 action 已经包含 ?(如 streamGenerateContent?alt=sse使用 & 拼接
var fullPath string
if strings.Contains(action, "?") {
fullPath = basePath + "&key=" + apiKey
} else {
fullPath = basePath + "?key=" + apiKey
}
return fullPath
apiKey := v.getExpressAPIKey(ctx)
return appendOrReplaceAPIKey(basePath, apiKey)
}
path := fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)

View File

@@ -8,6 +8,42 @@ import (
"github.com/stretchr/testify/require"
)
func TestAppendOrReplaceAPIKey(t *testing.T) {
t.Run("empty apiKey returns path unchanged", func(t *testing.T) {
path := "/v1/publishers/google/models/gemini:generateContent"
assert.Equal(t, path, appendOrReplaceAPIKey(path, ""))
})
t.Run("path without query appends ?key=", func(t *testing.T) {
result := appendOrReplaceAPIKey("/v1/models/gemini:generateContent", "my-key")
assert.Equal(t, "/v1/models/gemini:generateContent?key=my-key", result)
})
t.Run("path with existing query appends &key=", func(t *testing.T) {
result := appendOrReplaceAPIKey("/v1/models/gemini:streamGenerateContent?alt=sse", "my-key")
assert.Contains(t, result, "alt=sse")
assert.Contains(t, result, "key=my-key")
})
t.Run("existing key parameter is replaced", func(t *testing.T) {
result := appendOrReplaceAPIKey("/v1/models/gemini:generateContent?key=old-key&trace=1", "new-key")
assert.Contains(t, result, "key=new-key")
assert.NotContains(t, result, "old-key")
assert.Contains(t, result, "trace=1")
})
t.Run("unparseable path without query falls back to ?key= append", func(t *testing.T) {
// A bare string with no leading slash is not a valid RequestURI
result := appendOrReplaceAPIKey("not-a-valid-uri", "my-key")
assert.Equal(t, "not-a-valid-uri?key=my-key", result)
})
t.Run("unparseable path with query falls back to &key= append", func(t *testing.T) {
result := appendOrReplaceAPIKey("not-a-valid-uri?foo=bar", "my-key")
assert.Equal(t, "not-a-valid-uri?foo=bar&key=my-key", result)
})
}
func TestVertexProviderBuildChatRequestStructuredOutputMapping(t *testing.T) {
t.Run("json_object response format", func(t *testing.T) {
v := &vertexProvider{}