mirror of
https://github.com/alibaba/higress.git
synced 2026-04-20 03:27:26 +08:00
fix(vertex): add API Key auth for Vertex Raw Express Mode and fix tok… (#3695)
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -224,6 +225,34 @@ func (v *vertexProvider) getToken() (cached bool, err error) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
func appendOrReplaceAPIKey(path, apiKey string) string {
|
||||
if apiKey == "" {
|
||||
return path
|
||||
}
|
||||
|
||||
parsedPath, err := url.ParseRequestURI(path)
|
||||
if err != nil {
|
||||
// Fallback to simple append when path is not parseable.
|
||||
if strings.Contains(path, "?") {
|
||||
return path + "&key=" + apiKey
|
||||
}
|
||||
return path + "?key=" + apiKey
|
||||
}
|
||||
|
||||
query := parsedPath.Query()
|
||||
query.Set("key", apiKey)
|
||||
parsedPath.RawQuery = query.Encode()
|
||||
return parsedPath.RequestURI()
|
||||
}
|
||||
|
||||
func (v *vertexProvider) getExpressAPIKey(ctx wrapper.HttpContext) string {
|
||||
apiKey := v.config.GetApiTokenInUse(ctx)
|
||||
if apiKey == "" {
|
||||
apiKey = v.config.GetRandomToken()
|
||||
}
|
||||
return apiKey
|
||||
}
|
||||
|
||||
func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !v.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
@@ -234,8 +263,14 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
// 注意:此检查必须在 IsOriginal() 之前,因为 Vertex Raw 模式通常与 original 协议一起使用
|
||||
if apiName == ApiNameVertexRaw {
|
||||
ctx.SetContext(contextVertexRawMarker, true)
|
||||
// Express Mode 不需要 OAuth 认证
|
||||
// Express Mode: 将 API Key 追加到 URL query 参数中
|
||||
if v.isExpressMode() {
|
||||
headers := util.GetRequestHeaders()
|
||||
path := headers.Get(":path")
|
||||
path = appendOrReplaceAPIKey(path, v.getExpressAPIKey(ctx))
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
headers.Del("Authorization")
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
// 标准模式需要获取 OAuth token
|
||||
@@ -354,7 +389,7 @@ func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo
|
||||
}
|
||||
if strings.HasPrefix(request.Model, "claude") {
|
||||
ctx.SetContext(contextClaudeMarker, true)
|
||||
path := v.getAhthropicRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
|
||||
path := v.getAhthropicRequestPath(ctx, ApiNameChatCompletion, request.Model, request.Stream)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
claudeRequest := v.claude.buildClaudeTextGenRequest(request)
|
||||
@@ -366,7 +401,7 @@ func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo
|
||||
}
|
||||
return claudeBody, nil
|
||||
} else {
|
||||
path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
|
||||
path := v.getRequestPath(ctx, ApiNameChatCompletion, request.Model, request.Stream)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest, err := v.buildVertexChatRequest(request)
|
||||
@@ -382,7 +417,7 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
|
||||
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path := v.getRequestPath(ApiNameEmbeddings, request.Model, false)
|
||||
path := v.getRequestPath(ctx, ApiNameEmbeddings, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest := v.buildEmbeddingRequest(request)
|
||||
@@ -395,7 +430,7 @@ func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, b
|
||||
return nil, err
|
||||
}
|
||||
// 图片生成不使用流式端点,需要完整响应
|
||||
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
|
||||
path := v.getRequestPath(ctx, ApiNameImageGeneration, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest, err := v.buildVertexImageGenerationRequest(request)
|
||||
@@ -442,7 +477,7 @@ func (v *vertexProvider) onImageEditRequestBody(ctx wrapper.HttpContext, body []
|
||||
return nil, fmt.Errorf("missing prompt in request")
|
||||
}
|
||||
|
||||
path := v.getRequestPath(ApiNameImageEdit, request.Model, false)
|
||||
path := v.getRequestPath(ctx, ApiNameImageEdit, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
headers.Set("Content-Type", util.MimeTypeApplicationJson)
|
||||
vertexRequest, err := v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, imageURLs)
|
||||
@@ -485,7 +520,7 @@ func (v *vertexProvider) onImageVariationRequestBody(ctx wrapper.HttpContext, bo
|
||||
prompt = vertexImageVariationDefaultPrompt
|
||||
}
|
||||
|
||||
path := v.getRequestPath(ApiNameImageVariation, request.Model, false)
|
||||
path := v.getRequestPath(ctx, ApiNameImageVariation, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
headers.Set("Content-Type", util.MimeTypeApplicationJson)
|
||||
vertexRequest, err := v.buildVertexImageRequest(prompt, request.Size, request.OutputFormat, imageURLs)
|
||||
@@ -909,7 +944,7 @@ func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, respon
|
||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||
}
|
||||
|
||||
func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string, stream bool) string {
|
||||
func (v *vertexProvider) getAhthropicRequestPath(ctx wrapper.HttpContext, apiName ApiName, modelId string, stream bool) string {
|
||||
action := ""
|
||||
if stream {
|
||||
action = vertexAnthropicMessageStreamAction
|
||||
@@ -920,22 +955,15 @@ func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string
|
||||
if v.isExpressMode() {
|
||||
// Express Mode: 简化路径 + API Key 参数
|
||||
basePath := fmt.Sprintf(vertexExpressPathAnthropicTemplate, modelId, action)
|
||||
apiKey := v.config.GetRandomToken()
|
||||
// 如果 action 已经包含 ?,使用 & 拼接
|
||||
var fullPath string
|
||||
if strings.Contains(action, "?") {
|
||||
fullPath = basePath + "&key=" + apiKey
|
||||
} else {
|
||||
fullPath = basePath + "?key=" + apiKey
|
||||
}
|
||||
return fullPath
|
||||
apiKey := v.getExpressAPIKey(ctx)
|
||||
return appendOrReplaceAPIKey(basePath, apiKey)
|
||||
}
|
||||
|
||||
path := fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
return path
|
||||
}
|
||||
|
||||
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
|
||||
func (v *vertexProvider) getRequestPath(ctx wrapper.HttpContext, apiName ApiName, modelId string, stream bool) string {
|
||||
action := ""
|
||||
switch apiName {
|
||||
case ApiNameEmbeddings:
|
||||
@@ -954,15 +982,8 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
|
||||
if v.isExpressMode() {
|
||||
// Express Mode: 简化路径 + API Key 参数
|
||||
basePath := fmt.Sprintf(vertexExpressPathTemplate, modelId, action)
|
||||
apiKey := v.config.GetRandomToken()
|
||||
// 如果 action 已经包含 ?(如 streamGenerateContent?alt=sse),使用 & 拼接
|
||||
var fullPath string
|
||||
if strings.Contains(action, "?") {
|
||||
fullPath = basePath + "&key=" + apiKey
|
||||
} else {
|
||||
fullPath = basePath + "?key=" + apiKey
|
||||
}
|
||||
return fullPath
|
||||
apiKey := v.getExpressAPIKey(ctx)
|
||||
return appendOrReplaceAPIKey(basePath, apiKey)
|
||||
}
|
||||
|
||||
path := fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
|
||||
@@ -8,6 +8,42 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAppendOrReplaceAPIKey(t *testing.T) {
|
||||
t.Run("empty apiKey returns path unchanged", func(t *testing.T) {
|
||||
path := "/v1/publishers/google/models/gemini:generateContent"
|
||||
assert.Equal(t, path, appendOrReplaceAPIKey(path, ""))
|
||||
})
|
||||
|
||||
t.Run("path without query appends ?key=", func(t *testing.T) {
|
||||
result := appendOrReplaceAPIKey("/v1/models/gemini:generateContent", "my-key")
|
||||
assert.Equal(t, "/v1/models/gemini:generateContent?key=my-key", result)
|
||||
})
|
||||
|
||||
t.Run("path with existing query appends &key=", func(t *testing.T) {
|
||||
result := appendOrReplaceAPIKey("/v1/models/gemini:streamGenerateContent?alt=sse", "my-key")
|
||||
assert.Contains(t, result, "alt=sse")
|
||||
assert.Contains(t, result, "key=my-key")
|
||||
})
|
||||
|
||||
t.Run("existing key parameter is replaced", func(t *testing.T) {
|
||||
result := appendOrReplaceAPIKey("/v1/models/gemini:generateContent?key=old-key&trace=1", "new-key")
|
||||
assert.Contains(t, result, "key=new-key")
|
||||
assert.NotContains(t, result, "old-key")
|
||||
assert.Contains(t, result, "trace=1")
|
||||
})
|
||||
|
||||
t.Run("unparseable path without query falls back to ?key= append", func(t *testing.T) {
|
||||
// A bare string with no leading slash is not a valid RequestURI
|
||||
result := appendOrReplaceAPIKey("not-a-valid-uri", "my-key")
|
||||
assert.Equal(t, "not-a-valid-uri?key=my-key", result)
|
||||
})
|
||||
|
||||
t.Run("unparseable path with query falls back to &key= append", func(t *testing.T) {
|
||||
result := appendOrReplaceAPIKey("not-a-valid-uri?foo=bar", "my-key")
|
||||
assert.Equal(t, "not-a-valid-uri?foo=bar&key=my-key", result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestVertexProviderBuildChatRequestStructuredOutputMapping(t *testing.T) {
|
||||
t.Run("json_object response format", func(t *testing.T) {
|
||||
v := &vertexProvider{}
|
||||
|
||||
@@ -3,7 +3,9 @@ package test
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"mime/multipart"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -37,6 +39,17 @@ var vertexExpressModeConfig = func() json.RawMessage {
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Vertex Express Mode 配置(多 API Token)
|
||||
var vertexExpressModeMultiTokensConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "vertex",
|
||||
"apiTokens": []string{"test-api-key-express-a", "test-api-key-express-b"},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Vertex Express Mode 配置(含模型映射)
|
||||
var vertexExpressModeWithModelMappingConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
@@ -167,6 +180,18 @@ var vertexRawModeWithBasePathConfig = func() json.RawMessage {
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Vertex Raw 模式配置(Express Mode + 多 API Token)
|
||||
var vertexRawModeExpressMultiTokensConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "vertex",
|
||||
"apiTokens": []string{"test-api-key-raw-a", "test-api-key-raw-b"},
|
||||
"protocol": "original",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func RunVertexParseConfigTests(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试 Vertex 标准模式配置解析
|
||||
@@ -380,6 +405,149 @@ func RunVertexExpressModeOnHttpRequestBodyTests(t *testing.T) {
|
||||
require.True(t, hasVertexLogs, "Should have vertex processing logs")
|
||||
})
|
||||
|
||||
// 测试 Vertex Express Mode 请求体处理(多 token - Google 路径使用请求上下文中的 apiTokenInUse)
|
||||
t.Run("vertex express mode chat completion should reuse api token in context", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeMultiTokensConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
tokens := []string{"test-api-key-express-a", "test-api-key-express-b"}
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 从 debug log 中提取请求头阶段固定的 apiTokenInUse
|
||||
var apiTokenInUse string
|
||||
for _, debugLog := range host.GetDebugLogs() {
|
||||
const prefix = "Use apiToken "
|
||||
const suffix = " to send request"
|
||||
start := strings.Index(debugLog, prefix)
|
||||
if start == -1 {
|
||||
continue
|
||||
}
|
||||
start += len(prefix)
|
||||
end := strings.Index(debugLog[start:], suffix)
|
||||
if end == -1 {
|
||||
continue
|
||||
}
|
||||
apiTokenInUse = debugLog[start : start+end]
|
||||
break
|
||||
}
|
||||
require.Contains(t, tokens, apiTokenInUse, "apiTokenInUse should be selected from configured tokens")
|
||||
|
||||
// 强制设置随机种子,让旧实现(OnRequestBody 再次随机)必然选到不同 token
|
||||
targetIndex := 0
|
||||
if apiTokenInUse == tokens[0] {
|
||||
targetIndex = 1
|
||||
}
|
||||
seed := int64(1)
|
||||
for {
|
||||
if rand.New(rand.NewSource(seed)).Intn(len(tokens)) == targetIndex {
|
||||
break
|
||||
}
|
||||
seed++
|
||||
}
|
||||
rand.Seed(seed)
|
||||
|
||||
requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"token consistency test"}]}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathHeader := ""
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotEmpty(t, pathHeader, "Path header should not be empty")
|
||||
require.Contains(t, pathHeader, "/v1/publishers/google/models/", "Path should use Google publisher endpoint")
|
||||
|
||||
parsedPath, err := url.ParseRequestURI(pathHeader)
|
||||
require.NoError(t, err)
|
||||
query := parsedPath.Query()
|
||||
require.Len(t, query["key"], 1, "Path should contain exactly one key query parameter")
|
||||
require.Equal(t, apiTokenInUse, query.Get("key"),
|
||||
"Path key should use apiTokenInUse selected in request headers phase")
|
||||
})
|
||||
|
||||
// 测试 Vertex Express Mode 请求体处理(多 token - Anthropic 路径使用请求上下文中的 apiTokenInUse)
|
||||
t.Run("vertex express mode anthropic request should reuse api token in context", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeMultiTokensConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
tokens := []string{"test-api-key-express-a", "test-api-key-express-b"}
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 从 debug log 中提取请求头阶段固定的 apiTokenInUse
|
||||
var apiTokenInUse string
|
||||
for _, debugLog := range host.GetDebugLogs() {
|
||||
const prefix = "Use apiToken "
|
||||
const suffix = " to send request"
|
||||
start := strings.Index(debugLog, prefix)
|
||||
if start == -1 {
|
||||
continue
|
||||
}
|
||||
start += len(prefix)
|
||||
end := strings.Index(debugLog[start:], suffix)
|
||||
if end == -1 {
|
||||
continue
|
||||
}
|
||||
apiTokenInUse = debugLog[start : start+end]
|
||||
break
|
||||
}
|
||||
require.Contains(t, tokens, apiTokenInUse, "apiTokenInUse should be selected from configured tokens")
|
||||
|
||||
// 强制设置随机种子,让旧实现(OnRequestBody 再次随机)必然选到不同 token
|
||||
targetIndex := 0
|
||||
if apiTokenInUse == tokens[0] {
|
||||
targetIndex = 1
|
||||
}
|
||||
seed := int64(1)
|
||||
for {
|
||||
if rand.New(rand.NewSource(seed)).Intn(len(tokens)) == targetIndex {
|
||||
break
|
||||
}
|
||||
seed++
|
||||
}
|
||||
rand.Seed(seed)
|
||||
|
||||
requestBody := `{"model":"claude-sonnet-4@20250514","messages":[{"role":"user","content":"hello anthropic"}]}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathHeader := ""
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotEmpty(t, pathHeader, "Path header should not be empty")
|
||||
require.Contains(t, pathHeader, "/v1/publishers/anthropic/models/claude-sonnet-4@20250514:rawPredict",
|
||||
"Path should use Anthropic publisher endpoint")
|
||||
|
||||
parsedPath, err := url.ParseRequestURI(pathHeader)
|
||||
require.NoError(t, err)
|
||||
query := parsedPath.Query()
|
||||
require.Len(t, query["key"], 1, "Path should contain exactly one key query parameter")
|
||||
require.Equal(t, apiTokenInUse, query.Get("key"),
|
||||
"Path key should use apiTokenInUse selected in request headers phase")
|
||||
})
|
||||
|
||||
// 测试 Vertex Express Mode structured outputs: json_schema 映射
|
||||
t.Run("vertex express mode structured outputs json_schema request body mapping", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
@@ -2202,7 +2370,7 @@ func RunVertexRawModeOnHttpRequestHeadersTests(t *testing.T) {
|
||||
|
||||
func RunVertexRawModeOnHttpRequestBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试 Vertex Raw 模式请求体处理(Express Mode - 透传请求体)
|
||||
// 测试 Vertex Raw 模式请求体处理(Express Mode - 透传请求体 + API Key 认证)
|
||||
t.Run("vertex raw mode express - request body passthrough", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexRawModeExpressConfig)
|
||||
defer host.Reset()
|
||||
@@ -2214,6 +2382,7 @@ func RunVertexRawModeOnHttpRequestBodyTests(t *testing.T) {
|
||||
{":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"Authorization", "Bearer some-token"},
|
||||
})
|
||||
|
||||
// 设置原生 Vertex 格式的请求体
|
||||
@@ -2229,6 +2398,22 @@ func RunVertexRawModeOnHttpRequestBodyTests(t *testing.T) {
|
||||
|
||||
// 请求体应该保持原样
|
||||
require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged")
|
||||
|
||||
// 验证 API Key 被追加到 URL path 中
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
var pathHeader string
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Contains(t, pathHeader, "?key=test-api-key-for-raw-mode",
|
||||
"API key should be appended to path as query parameter")
|
||||
|
||||
// 验证 Authorization header 被删除
|
||||
require.False(t, test.HasHeaderWithValue(requestHeaders, "Authorization", "Bearer some-token"),
|
||||
"Authorization header should be removed in Express Mode")
|
||||
})
|
||||
|
||||
// 测试 Vertex Raw 模式请求体处理(标准模式 - 需要 OAuth token)
|
||||
@@ -2304,13 +2489,13 @@ func RunVertexRawModeOnHttpRequestBodyTests(t *testing.T) {
|
||||
require.NotContains(t, pathHeader, "/vertex-proxy", "Path should have basePath prefix removed")
|
||||
})
|
||||
|
||||
// 测试 Vertex Raw 模式请求体处理(流式请求)
|
||||
// 测试 Vertex Raw 模式请求体处理(流式请求 - path 已含 ? 时用 & 拼接 API Key)
|
||||
t.Run("vertex raw mode express - streaming request body passthrough", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexRawModeExpressConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头(流式端点)
|
||||
// 先设置请求头(流式端点,path 已含 ?alt=sse)
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:streamGenerateContent?alt=sse"},
|
||||
@@ -2328,6 +2513,194 @@ func RunVertexRawModeOnHttpRequestBodyTests(t *testing.T) {
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged")
|
||||
|
||||
// 验证 API Key 使用 & 拼接(因为 path 已含 ?alt=sse)
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
var pathHeader string
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Contains(t, pathHeader, "?alt=sse&key=test-api-key-for-raw-mode",
|
||||
"API key should be appended with & when path already contains ?")
|
||||
})
|
||||
|
||||
// 测试 Vertex Raw 模式请求体处理(Express Mode + Anthropic 模型路径)
|
||||
t.Run("vertex raw mode express - anthropic model request body with api key", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexRawModeExpressConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 使用 Anthropic 模型的原生 Vertex AI REST API 路径
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/projects/test-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4@20250514:rawPredict"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"anthropic_version":"vertex-2023-10-16","messages":[{"role":"user","content":"Hello"}],"max_tokens":1024}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证请求体被透传
|
||||
processedBody := host.GetRequestBody()
|
||||
require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged")
|
||||
|
||||
// 验证 API Key 被追加到 path
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
var pathHeader string
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Contains(t, pathHeader, "?key=test-api-key-for-raw-mode",
|
||||
"API key should be appended to anthropic model path")
|
||||
})
|
||||
|
||||
// 测试 Vertex Raw 模式请求体处理(Express Mode + basePath - API Key 正确追加)
|
||||
t.Run("vertex raw mode with basePath express - request body with api key", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexRawModeWithBasePathConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 带 basePath 前缀的请求
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/vertex-proxy/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"contents":[{"role":"user","parts":[{"text":"Hello"}]}]}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证路径:basePath 被移除 + API Key 被追加
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
var pathHeader string
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotContains(t, pathHeader, "/vertex-proxy",
|
||||
"Path should have basePath prefix removed")
|
||||
require.Contains(t, pathHeader, "?key=test-api-key-for-raw-mode",
|
||||
"API key should be appended after basePath removal")
|
||||
})
|
||||
|
||||
// 测试 Vertex Raw 模式请求体处理(Express Mode + 多 token,使用请求上下文中的 apiTokenInUse)
|
||||
t.Run("vertex raw mode express - should reuse api token in context for query key", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexRawModeExpressMultiTokensConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 选择一个保证前两次 Intn(2) 结果不同的种子:
|
||||
// 第一次用于 SetApiTokenInUse,第二次仅在旧实现中用于 OnRequestBody.GetRandomToken。
|
||||
seed := int64(1)
|
||||
for {
|
||||
r := rand.New(rand.NewSource(seed))
|
||||
if r.Intn(2) != r.Intn(2) {
|
||||
break
|
||||
}
|
||||
seed++
|
||||
}
|
||||
rand.Seed(seed)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"contents":[{"role":"user","parts":[{"text":"Hello"}]}]}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
var pathHeader string
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotEmpty(t, pathHeader, "Path header should not be empty")
|
||||
|
||||
parsedPath, err := url.ParseRequestURI(pathHeader)
|
||||
require.NoError(t, err)
|
||||
query := parsedPath.Query()
|
||||
require.Len(t, query["key"], 1, "Path should contain exactly one key query parameter")
|
||||
keyInPath := query.Get("key")
|
||||
require.NotEmpty(t, keyInPath, "Path should contain key query parameter")
|
||||
|
||||
// 从 debug log 中提取本次请求固定的 apiTokenInUse
|
||||
var apiTokenInUse string
|
||||
for _, debugLog := range host.GetDebugLogs() {
|
||||
const prefix = "Use apiToken "
|
||||
const suffix = " to send request"
|
||||
start := strings.Index(debugLog, prefix)
|
||||
if start == -1 {
|
||||
continue
|
||||
}
|
||||
start += len(prefix)
|
||||
end := strings.Index(debugLog[start:], suffix)
|
||||
if end == -1 {
|
||||
continue
|
||||
}
|
||||
apiTokenInUse = debugLog[start : start+end]
|
||||
break
|
||||
}
|
||||
require.NotEmpty(t, apiTokenInUse, "apiTokenInUse should be logged")
|
||||
require.Equal(t, apiTokenInUse, keyInPath,
|
||||
"Query key must use apiTokenInUse from request context")
|
||||
})
|
||||
|
||||
// 测试 Vertex Raw 模式请求体处理(Express Mode + 已有 key 参数时应覆盖而不是追加重复)
|
||||
t.Run("vertex raw mode express - should replace existing key query parameter", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexRawModeExpressConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=client-key&trace=1"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"contents":[{"role":"user","parts":[{"text":"Hello"}]}]}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
var pathHeader string
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotEmpty(t, pathHeader, "Path header should not be empty")
|
||||
|
||||
parsedPath, err := url.ParseRequestURI(pathHeader)
|
||||
require.NoError(t, err)
|
||||
query := parsedPath.Query()
|
||||
|
||||
require.Len(t, query["key"], 1, "Path should contain exactly one key query parameter")
|
||||
require.Equal(t, "test-api-key-for-raw-mode", query.Get("key"),
|
||||
"Existing key query parameter should be replaced by configured API key")
|
||||
require.Equal(t, "sse", query.Get("alt"), "Existing query parameter alt should be preserved")
|
||||
require.Equal(t, "1", query.Get("trace"), "Existing query parameter trace should be preserved")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user