mirror of
https://github.com/alibaba/higress.git
synced 2026-05-27 06:07:27 +08:00
fix(vertex): add API Key auth for Vertex Raw Express Mode and fix tok… (#3695)
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -224,6 +225,34 @@ func (v *vertexProvider) getToken() (cached bool, err error) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
func appendOrReplaceAPIKey(path, apiKey string) string {
|
||||
if apiKey == "" {
|
||||
return path
|
||||
}
|
||||
|
||||
parsedPath, err := url.ParseRequestURI(path)
|
||||
if err != nil {
|
||||
// Fallback to simple append when path is not parseable.
|
||||
if strings.Contains(path, "?") {
|
||||
return path + "&key=" + apiKey
|
||||
}
|
||||
return path + "?key=" + apiKey
|
||||
}
|
||||
|
||||
query := parsedPath.Query()
|
||||
query.Set("key", apiKey)
|
||||
parsedPath.RawQuery = query.Encode()
|
||||
return parsedPath.RequestURI()
|
||||
}
|
||||
|
||||
func (v *vertexProvider) getExpressAPIKey(ctx wrapper.HttpContext) string {
|
||||
apiKey := v.config.GetApiTokenInUse(ctx)
|
||||
if apiKey == "" {
|
||||
apiKey = v.config.GetRandomToken()
|
||||
}
|
||||
return apiKey
|
||||
}
|
||||
|
||||
func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !v.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
@@ -234,8 +263,14 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
// 注意:此检查必须在 IsOriginal() 之前,因为 Vertex Raw 模式通常与 original 协议一起使用
|
||||
if apiName == ApiNameVertexRaw {
|
||||
ctx.SetContext(contextVertexRawMarker, true)
|
||||
// Express Mode 不需要 OAuth 认证
|
||||
// Express Mode: 将 API Key 追加到 URL query 参数中
|
||||
if v.isExpressMode() {
|
||||
headers := util.GetRequestHeaders()
|
||||
path := headers.Get(":path")
|
||||
path = appendOrReplaceAPIKey(path, v.getExpressAPIKey(ctx))
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
headers.Del("Authorization")
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
// 标准模式需要获取 OAuth token
|
||||
@@ -354,7 +389,7 @@ func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo
|
||||
}
|
||||
if strings.HasPrefix(request.Model, "claude") {
|
||||
ctx.SetContext(contextClaudeMarker, true)
|
||||
path := v.getAhthropicRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
|
||||
path := v.getAhthropicRequestPath(ctx, ApiNameChatCompletion, request.Model, request.Stream)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
claudeRequest := v.claude.buildClaudeTextGenRequest(request)
|
||||
@@ -366,7 +401,7 @@ func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo
|
||||
}
|
||||
return claudeBody, nil
|
||||
} else {
|
||||
path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
|
||||
path := v.getRequestPath(ctx, ApiNameChatCompletion, request.Model, request.Stream)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest, err := v.buildVertexChatRequest(request)
|
||||
@@ -382,7 +417,7 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
|
||||
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path := v.getRequestPath(ApiNameEmbeddings, request.Model, false)
|
||||
path := v.getRequestPath(ctx, ApiNameEmbeddings, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest := v.buildEmbeddingRequest(request)
|
||||
@@ -395,7 +430,7 @@ func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, b
|
||||
return nil, err
|
||||
}
|
||||
// 图片生成不使用流式端点,需要完整响应
|
||||
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
|
||||
path := v.getRequestPath(ctx, ApiNameImageGeneration, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest, err := v.buildVertexImageGenerationRequest(request)
|
||||
@@ -442,7 +477,7 @@ func (v *vertexProvider) onImageEditRequestBody(ctx wrapper.HttpContext, body []
|
||||
return nil, fmt.Errorf("missing prompt in request")
|
||||
}
|
||||
|
||||
path := v.getRequestPath(ApiNameImageEdit, request.Model, false)
|
||||
path := v.getRequestPath(ctx, ApiNameImageEdit, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
headers.Set("Content-Type", util.MimeTypeApplicationJson)
|
||||
vertexRequest, err := v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, imageURLs)
|
||||
@@ -485,7 +520,7 @@ func (v *vertexProvider) onImageVariationRequestBody(ctx wrapper.HttpContext, bo
|
||||
prompt = vertexImageVariationDefaultPrompt
|
||||
}
|
||||
|
||||
path := v.getRequestPath(ApiNameImageVariation, request.Model, false)
|
||||
path := v.getRequestPath(ctx, ApiNameImageVariation, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
headers.Set("Content-Type", util.MimeTypeApplicationJson)
|
||||
vertexRequest, err := v.buildVertexImageRequest(prompt, request.Size, request.OutputFormat, imageURLs)
|
||||
@@ -909,7 +944,7 @@ func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, respon
|
||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||
}
|
||||
|
||||
func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string, stream bool) string {
|
||||
func (v *vertexProvider) getAhthropicRequestPath(ctx wrapper.HttpContext, apiName ApiName, modelId string, stream bool) string {
|
||||
action := ""
|
||||
if stream {
|
||||
action = vertexAnthropicMessageStreamAction
|
||||
@@ -920,22 +955,15 @@ func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string
|
||||
if v.isExpressMode() {
|
||||
// Express Mode: 简化路径 + API Key 参数
|
||||
basePath := fmt.Sprintf(vertexExpressPathAnthropicTemplate, modelId, action)
|
||||
apiKey := v.config.GetRandomToken()
|
||||
// 如果 action 已经包含 ?,使用 & 拼接
|
||||
var fullPath string
|
||||
if strings.Contains(action, "?") {
|
||||
fullPath = basePath + "&key=" + apiKey
|
||||
} else {
|
||||
fullPath = basePath + "?key=" + apiKey
|
||||
}
|
||||
return fullPath
|
||||
apiKey := v.getExpressAPIKey(ctx)
|
||||
return appendOrReplaceAPIKey(basePath, apiKey)
|
||||
}
|
||||
|
||||
path := fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
return path
|
||||
}
|
||||
|
||||
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
|
||||
func (v *vertexProvider) getRequestPath(ctx wrapper.HttpContext, apiName ApiName, modelId string, stream bool) string {
|
||||
action := ""
|
||||
switch apiName {
|
||||
case ApiNameEmbeddings:
|
||||
@@ -954,15 +982,8 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
|
||||
if v.isExpressMode() {
|
||||
// Express Mode: 简化路径 + API Key 参数
|
||||
basePath := fmt.Sprintf(vertexExpressPathTemplate, modelId, action)
|
||||
apiKey := v.config.GetRandomToken()
|
||||
// 如果 action 已经包含 ?(如 streamGenerateContent?alt=sse),使用 & 拼接
|
||||
var fullPath string
|
||||
if strings.Contains(action, "?") {
|
||||
fullPath = basePath + "&key=" + apiKey
|
||||
} else {
|
||||
fullPath = basePath + "?key=" + apiKey
|
||||
}
|
||||
return fullPath
|
||||
apiKey := v.getExpressAPIKey(ctx)
|
||||
return appendOrReplaceAPIKey(basePath, apiKey)
|
||||
}
|
||||
|
||||
path := fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
|
||||
@@ -8,6 +8,42 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAppendOrReplaceAPIKey(t *testing.T) {
|
||||
t.Run("empty apiKey returns path unchanged", func(t *testing.T) {
|
||||
path := "/v1/publishers/google/models/gemini:generateContent"
|
||||
assert.Equal(t, path, appendOrReplaceAPIKey(path, ""))
|
||||
})
|
||||
|
||||
t.Run("path without query appends ?key=", func(t *testing.T) {
|
||||
result := appendOrReplaceAPIKey("/v1/models/gemini:generateContent", "my-key")
|
||||
assert.Equal(t, "/v1/models/gemini:generateContent?key=my-key", result)
|
||||
})
|
||||
|
||||
t.Run("path with existing query appends &key=", func(t *testing.T) {
|
||||
result := appendOrReplaceAPIKey("/v1/models/gemini:streamGenerateContent?alt=sse", "my-key")
|
||||
assert.Contains(t, result, "alt=sse")
|
||||
assert.Contains(t, result, "key=my-key")
|
||||
})
|
||||
|
||||
t.Run("existing key parameter is replaced", func(t *testing.T) {
|
||||
result := appendOrReplaceAPIKey("/v1/models/gemini:generateContent?key=old-key&trace=1", "new-key")
|
||||
assert.Contains(t, result, "key=new-key")
|
||||
assert.NotContains(t, result, "old-key")
|
||||
assert.Contains(t, result, "trace=1")
|
||||
})
|
||||
|
||||
t.Run("unparseable path without query falls back to ?key= append", func(t *testing.T) {
|
||||
// A bare string with no leading slash is not a valid RequestURI
|
||||
result := appendOrReplaceAPIKey("not-a-valid-uri", "my-key")
|
||||
assert.Equal(t, "not-a-valid-uri?key=my-key", result)
|
||||
})
|
||||
|
||||
t.Run("unparseable path with query falls back to &key= append", func(t *testing.T) {
|
||||
result := appendOrReplaceAPIKey("not-a-valid-uri?foo=bar", "my-key")
|
||||
assert.Equal(t, "not-a-valid-uri?foo=bar&key=my-key", result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestVertexProviderBuildChatRequestStructuredOutputMapping(t *testing.T) {
|
||||
t.Run("json_object response format", func(t *testing.T) {
|
||||
v := &vertexProvider{}
|
||||
|
||||
Reference in New Issue
Block a user