diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go index 2ee4272d2..561e139d6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "net/http" + "net/url" "regexp" "strings" "time" @@ -224,6 +225,34 @@ func (v *vertexProvider) getToken() (cached bool, err error) { return false, err } +func appendOrReplaceAPIKey(path, apiKey string) string { + if apiKey == "" { + return path + } + + parsedPath, err := url.ParseRequestURI(path) + if err != nil { + // Fallback to simple append when path is not parseable. + if strings.Contains(path, "?") { + return path + "&key=" + apiKey + } + return path + "?key=" + apiKey + } + + query := parsedPath.Query() + query.Set("key", apiKey) + parsedPath.RawQuery = query.Encode() + return parsedPath.RequestURI() +} + +func (v *vertexProvider) getExpressAPIKey(ctx wrapper.HttpContext) string { + apiKey := v.config.GetApiTokenInUse(ctx) + if apiKey == "" { + apiKey = v.config.GetRandomToken() + } + return apiKey +} + func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { if !v.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName @@ -234,8 +263,14 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, // 注意:此检查必须在 IsOriginal() 之前,因为 Vertex Raw 模式通常与 original 协议一起使用 if apiName == ApiNameVertexRaw { ctx.SetContext(contextVertexRawMarker, true) - // Express Mode 不需要 OAuth 认证 + // Express Mode: 将 API Key 追加到 URL query 参数中 if v.isExpressMode() { + headers := util.GetRequestHeaders() + path := headers.Get(":path") + path = appendOrReplaceAPIKey(path, v.getExpressAPIKey(ctx)) + util.OverwriteRequestPathHeader(headers, path) + headers.Del("Authorization") + util.ReplaceRequestHeaders(headers) return types.ActionContinue, nil } // 标准模式需要获取 OAuth token @@ -354,7 +389,7 @@ func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo } if strings.HasPrefix(request.Model, "claude") { ctx.SetContext(contextClaudeMarker, true) - path := v.getAhthropicRequestPath(ApiNameChatCompletion, request.Model, request.Stream) + path := v.getAhthropicRequestPath(ctx, ApiNameChatCompletion, request.Model, request.Stream) util.OverwriteRequestPathHeader(headers, path) claudeRequest := v.claude.buildClaudeTextGenRequest(request) @@ -366,7 +401,7 @@ func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo } return claudeBody, nil } else { - path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream) + path := v.getRequestPath(ctx, ApiNameChatCompletion, request.Model, request.Stream) util.OverwriteRequestPathHeader(headers, path) vertexRequest, err := v.buildVertexChatRequest(request) @@ -382,7 +417,7 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [ if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil { return nil, err } - path := v.getRequestPath(ApiNameEmbeddings, request.Model, false) + path := v.getRequestPath(ctx, ApiNameEmbeddings, request.Model, false) util.OverwriteRequestPathHeader(headers, path) vertexRequest := v.buildEmbeddingRequest(request) @@ -395,7 +430,7 @@ func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, b return nil, err } // 图片生成不使用流式端点,需要完整响应 - path := v.getRequestPath(ApiNameImageGeneration, request.Model, false) + path := v.getRequestPath(ctx, ApiNameImageGeneration, request.Model, false) util.OverwriteRequestPathHeader(headers, path) vertexRequest, err := v.buildVertexImageGenerationRequest(request) @@ -442,7 +477,7 @@ func (v *vertexProvider) onImageEditRequestBody(ctx wrapper.HttpContext, body [] return nil, fmt.Errorf("missing prompt in request") } - path := v.getRequestPath(ApiNameImageEdit, request.Model, false) + path := v.getRequestPath(ctx, ApiNameImageEdit, request.Model, false) util.OverwriteRequestPathHeader(headers, path) headers.Set("Content-Type", util.MimeTypeApplicationJson) vertexRequest, err := v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, imageURLs) @@ -485,7 +520,7 @@ func (v *vertexProvider) onImageVariationRequestBody(ctx wrapper.HttpContext, bo prompt = vertexImageVariationDefaultPrompt } - path := v.getRequestPath(ApiNameImageVariation, request.Model, false) + path := v.getRequestPath(ctx, ApiNameImageVariation, request.Model, false) util.OverwriteRequestPathHeader(headers, path) headers.Set("Content-Type", util.MimeTypeApplicationJson) vertexRequest, err := v.buildVertexImageRequest(prompt, request.Size, request.OutputFormat, imageURLs) @@ -909,7 +944,7 @@ func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, respon responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) } -func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string, stream bool) string { +func (v *vertexProvider) getAhthropicRequestPath(ctx wrapper.HttpContext, apiName ApiName, modelId string, stream bool) string { action := "" if stream { action = vertexAnthropicMessageStreamAction @@ -920,22 +955,15 @@ func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string if v.isExpressMode() { // Express Mode: 简化路径 + API Key 参数 basePath := fmt.Sprintf(vertexExpressPathAnthropicTemplate, modelId, action) - apiKey := v.config.GetRandomToken() - // 如果 action 已经包含 ?,使用 & 拼接 - var fullPath string - if strings.Contains(action, "?") { - fullPath = basePath + "&key=" + apiKey - } else { - fullPath = basePath + "?key=" + apiKey - } - return fullPath + apiKey := v.getExpressAPIKey(ctx) + return appendOrReplaceAPIKey(basePath, apiKey) } path := fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action) return path } -func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string { +func (v *vertexProvider) getRequestPath(ctx wrapper.HttpContext, apiName ApiName, modelId string, stream bool) string { action := "" switch apiName { case ApiNameEmbeddings: @@ -954,15 +982,8 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream if v.isExpressMode() { // Express Mode: 简化路径 + API Key 参数 basePath := fmt.Sprintf(vertexExpressPathTemplate, modelId, action) - apiKey := v.config.GetRandomToken() - // 如果 action 已经包含 ?(如 streamGenerateContent?alt=sse),使用 & 拼接 - var fullPath string - if strings.Contains(action, "?") { - fullPath = basePath + "&key=" + apiKey - } else { - fullPath = basePath + "?key=" + apiKey - } - return fullPath + apiKey := v.getExpressAPIKey(ctx) + return appendOrReplaceAPIKey(basePath, apiKey) } path := fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/vertex_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/vertex_test.go index 5b81fcc06..c9113d40e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/vertex_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/vertex_test.go @@ -8,6 +8,42 @@ import ( "github.com/stretchr/testify/require" ) +func TestAppendOrReplaceAPIKey(t *testing.T) { + t.Run("empty apiKey returns path unchanged", func(t *testing.T) { + path := "/v1/publishers/google/models/gemini:generateContent" + assert.Equal(t, path, appendOrReplaceAPIKey(path, "")) + }) + + t.Run("path without query appends ?key=", func(t *testing.T) { + result := appendOrReplaceAPIKey("/v1/models/gemini:generateContent", "my-key") + assert.Equal(t, "/v1/models/gemini:generateContent?key=my-key", result) + }) + + t.Run("path with existing query appends &key=", func(t *testing.T) { + result := appendOrReplaceAPIKey("/v1/models/gemini:streamGenerateContent?alt=sse", "my-key") + assert.Contains(t, result, "alt=sse") + assert.Contains(t, result, "key=my-key") + }) + + t.Run("existing key parameter is replaced", func(t *testing.T) { + result := appendOrReplaceAPIKey("/v1/models/gemini:generateContent?key=old-key&trace=1", "new-key") + assert.Contains(t, result, "key=new-key") + assert.NotContains(t, result, "old-key") + assert.Contains(t, result, "trace=1") + }) + + t.Run("unparseable path without query falls back to ?key= append", func(t *testing.T) { + // A bare string with no leading slash is not a valid RequestURI + result := appendOrReplaceAPIKey("not-a-valid-uri", "my-key") + assert.Equal(t, "not-a-valid-uri?key=my-key", result) + }) + + t.Run("unparseable path with query falls back to &key= append", func(t *testing.T) { + result := appendOrReplaceAPIKey("not-a-valid-uri?foo=bar", "my-key") + assert.Equal(t, "not-a-valid-uri?foo=bar&key=my-key", result) + }) +} + func TestVertexProviderBuildChatRequestStructuredOutputMapping(t *testing.T) { t.Run("json_object response format", func(t *testing.T) { v := &vertexProvider{} diff --git a/plugins/wasm-go/extensions/ai-proxy/test/vertex.go b/plugins/wasm-go/extensions/ai-proxy/test/vertex.go index 591633a65..8ab18a68b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/test/vertex.go +++ b/plugins/wasm-go/extensions/ai-proxy/test/vertex.go @@ -3,7 +3,9 @@ package test import ( "bytes" "encoding/json" + "math/rand" "mime/multipart" + "net/url" "strings" "testing" @@ -37,6 +39,17 @@ var vertexExpressModeConfig = func() json.RawMessage { return data }() +// 测试配置:Vertex Express Mode 配置(多 API Token) +var vertexExpressModeMultiTokensConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "vertex", + "apiTokens": []string{"test-api-key-express-a", "test-api-key-express-b"}, + }, + }) + return data +}() + // 测试配置:Vertex Express Mode 配置(含模型映射) var vertexExpressModeWithModelMappingConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ @@ -167,6 +180,18 @@ var vertexRawModeWithBasePathConfig = func() json.RawMessage { return data }() +// 测试配置:Vertex Raw 模式配置(Express Mode + 多 API Token) +var vertexRawModeExpressMultiTokensConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "vertex", + "apiTokens": []string{"test-api-key-raw-a", "test-api-key-raw-b"}, + "protocol": "original", + }, + }) + return data +}() + func RunVertexParseConfigTests(t *testing.T) { test.RunGoTest(t, func(t *testing.T) { // 测试 Vertex 标准模式配置解析 @@ -380,6 +405,149 @@ func RunVertexExpressModeOnHttpRequestBodyTests(t *testing.T) { require.True(t, hasVertexLogs, "Should have vertex processing logs") }) + // 测试 Vertex Express Mode 请求体处理(多 token - Google 路径使用请求上下文中的 apiTokenInUse) + t.Run("vertex express mode chat completion should reuse api token in context", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeMultiTokensConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + tokens := []string{"test-api-key-express-a", "test-api-key-express-b"} + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 从 debug log 中提取请求头阶段固定的 apiTokenInUse + var apiTokenInUse string + for _, debugLog := range host.GetDebugLogs() { + const prefix = "Use apiToken " + const suffix = " to send request" + start := strings.Index(debugLog, prefix) + if start == -1 { + continue + } + start += len(prefix) + end := strings.Index(debugLog[start:], suffix) + if end == -1 { + continue + } + apiTokenInUse = debugLog[start : start+end] + break + } + require.Contains(t, tokens, apiTokenInUse, "apiTokenInUse should be selected from configured tokens") + + // 强制设置随机种子,让旧实现(OnRequestBody 再次随机)必然选到不同 token + targetIndex := 0 + if apiTokenInUse == tokens[0] { + targetIndex = 1 + } + seed := int64(1) + for { + if rand.New(rand.NewSource(seed)).Intn(len(tokens)) == targetIndex { + break + } + seed++ + } + rand.Seed(seed) + + requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"token consistency test"}]}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + requestHeaders := host.GetRequestHeaders() + pathHeader := "" + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.NotEmpty(t, pathHeader, "Path header should not be empty") + require.Contains(t, pathHeader, "/v1/publishers/google/models/", "Path should use Google publisher endpoint") + + parsedPath, err := url.ParseRequestURI(pathHeader) + require.NoError(t, err) + query := parsedPath.Query() + require.Len(t, query["key"], 1, "Path should contain exactly one key query parameter") + require.Equal(t, apiTokenInUse, query.Get("key"), + "Path key should use apiTokenInUse selected in request headers phase") + }) + + // 测试 Vertex Express Mode 请求体处理(多 token - Anthropic 路径使用请求上下文中的 apiTokenInUse) + t.Run("vertex express mode anthropic request should reuse api token in context", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeMultiTokensConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + tokens := []string{"test-api-key-express-a", "test-api-key-express-b"} + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 从 debug log 中提取请求头阶段固定的 apiTokenInUse + var apiTokenInUse string + for _, debugLog := range host.GetDebugLogs() { + const prefix = "Use apiToken " + const suffix = " to send request" + start := strings.Index(debugLog, prefix) + if start == -1 { + continue + } + start += len(prefix) + end := strings.Index(debugLog[start:], suffix) + if end == -1 { + continue + } + apiTokenInUse = debugLog[start : start+end] + break + } + require.Contains(t, tokens, apiTokenInUse, "apiTokenInUse should be selected from configured tokens") + + // 强制设置随机种子,让旧实现(OnRequestBody 再次随机)必然选到不同 token + targetIndex := 0 + if apiTokenInUse == tokens[0] { + targetIndex = 1 + } + seed := int64(1) + for { + if rand.New(rand.NewSource(seed)).Intn(len(tokens)) == targetIndex { + break + } + seed++ + } + rand.Seed(seed) + + requestBody := `{"model":"claude-sonnet-4@20250514","messages":[{"role":"user","content":"hello anthropic"}]}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + requestHeaders := host.GetRequestHeaders() + pathHeader := "" + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.NotEmpty(t, pathHeader, "Path header should not be empty") + require.Contains(t, pathHeader, "/v1/publishers/anthropic/models/claude-sonnet-4@20250514:rawPredict", + "Path should use Anthropic publisher endpoint") + + parsedPath, err := url.ParseRequestURI(pathHeader) + require.NoError(t, err) + query := parsedPath.Query() + require.Len(t, query["key"], 1, "Path should contain exactly one key query parameter") + require.Equal(t, apiTokenInUse, query.Get("key"), + "Path key should use apiTokenInUse selected in request headers phase") + }) + // 测试 Vertex Express Mode structured outputs: json_schema 映射 t.Run("vertex express mode structured outputs json_schema request body mapping", func(t *testing.T) { host, status := test.NewTestHost(vertexExpressModeConfig) @@ -2202,7 +2370,7 @@ func RunVertexRawModeOnHttpRequestHeadersTests(t *testing.T) { func RunVertexRawModeOnHttpRequestBodyTests(t *testing.T) { test.RunTest(t, func(t *testing.T) { - // 测试 Vertex Raw 模式请求体处理(Express Mode - 透传请求体) + // 测试 Vertex Raw 模式请求体处理(Express Mode - 透传请求体 + API Key 认证) t.Run("vertex raw mode express - request body passthrough", func(t *testing.T) { host, status := test.NewTestHost(vertexRawModeExpressConfig) defer host.Reset() @@ -2214,6 +2382,7 @@ func RunVertexRawModeOnHttpRequestBodyTests(t *testing.T) { {":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"}, {":method", "POST"}, {"Content-Type", "application/json"}, + {"Authorization", "Bearer some-token"}, }) // 设置原生 Vertex 格式的请求体 @@ -2229,6 +2398,22 @@ func RunVertexRawModeOnHttpRequestBodyTests(t *testing.T) { // 请求体应该保持原样 require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged") + + // 验证 API Key 被追加到 URL path 中 + requestHeaders := host.GetRequestHeaders() + var pathHeader string + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.Contains(t, pathHeader, "?key=test-api-key-for-raw-mode", + "API key should be appended to path as query parameter") + + // 验证 Authorization header 被删除 + require.False(t, test.HasHeaderWithValue(requestHeaders, "Authorization", "Bearer some-token"), + "Authorization header should be removed in Express Mode") }) // 测试 Vertex Raw 模式请求体处理(标准模式 - 需要 OAuth token) @@ -2304,13 +2489,13 @@ func RunVertexRawModeOnHttpRequestBodyTests(t *testing.T) { require.NotContains(t, pathHeader, "/vertex-proxy", "Path should have basePath prefix removed") }) - // 测试 Vertex Raw 模式请求体处理(流式请求) + // 测试 Vertex Raw 模式请求体处理(流式请求 - path 已含 ? 时用 & 拼接 API Key) t.Run("vertex raw mode express - streaming request body passthrough", func(t *testing.T) { host, status := test.NewTestHost(vertexRawModeExpressConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) - // 先设置请求头(流式端点) + // 先设置请求头(流式端点,path 已含 ?alt=sse) host.CallOnHttpRequestHeaders([][2]string{ {":authority", "example.com"}, {":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:streamGenerateContent?alt=sse"}, @@ -2328,6 +2513,194 @@ func RunVertexRawModeOnHttpRequestBodyTests(t *testing.T) { processedBody := host.GetRequestBody() require.NotNil(t, processedBody) require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged") + + // 验证 API Key 使用 & 拼接(因为 path 已含 ?alt=sse) + requestHeaders := host.GetRequestHeaders() + var pathHeader string + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.Contains(t, pathHeader, "?alt=sse&key=test-api-key-for-raw-mode", + "API key should be appended with & when path already contains ?") + }) + + // 测试 Vertex Raw 模式请求体处理(Express Mode + Anthropic 模型路径) + t.Run("vertex raw mode express - anthropic model request body with api key", func(t *testing.T) { + host, status := test.NewTestHost(vertexRawModeExpressConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 使用 Anthropic 模型的原生 Vertex AI REST API 路径 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/projects/test-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4@20250514:rawPredict"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + requestBody := `{"anthropic_version":"vertex-2023-10-16","messages":[{"role":"user","content":"Hello"}],"max_tokens":1024}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + require.Equal(t, types.ActionContinue, action) + + // 验证请求体被透传 + processedBody := host.GetRequestBody() + require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged") + + // 验证 API Key 被追加到 path + requestHeaders := host.GetRequestHeaders() + var pathHeader string + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.Contains(t, pathHeader, "?key=test-api-key-for-raw-mode", + "API key should be appended to anthropic model path") + }) + + // 测试 Vertex Raw 模式请求体处理(Express Mode + basePath - API Key 正确追加) + t.Run("vertex raw mode with basePath express - request body with api key", func(t *testing.T) { + host, status := test.NewTestHost(vertexRawModeWithBasePathConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 带 basePath 前缀的请求 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/vertex-proxy/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + requestBody := `{"contents":[{"role":"user","parts":[{"text":"Hello"}]}]}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + require.Equal(t, types.ActionContinue, action) + + // 验证路径:basePath 被移除 + API Key 被追加 + requestHeaders := host.GetRequestHeaders() + var pathHeader string + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.NotContains(t, pathHeader, "/vertex-proxy", + "Path should have basePath prefix removed") + require.Contains(t, pathHeader, "?key=test-api-key-for-raw-mode", + "API key should be appended after basePath removal") + }) + + // 测试 Vertex Raw 模式请求体处理(Express Mode + 多 token,使用请求上下文中的 apiTokenInUse) + t.Run("vertex raw mode express - should reuse api token in context for query key", func(t *testing.T) { + host, status := test.NewTestHost(vertexRawModeExpressMultiTokensConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 选择一个保证前两次 Intn(2) 结果不同的种子: + // 第一次用于 SetApiTokenInUse,第二次仅在旧实现中用于 OnRequestBody.GetRandomToken。 + seed := int64(1) + for { + r := rand.New(rand.NewSource(seed)) + if r.Intn(2) != r.Intn(2) { + break + } + seed++ + } + rand.Seed(seed) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + requestBody := `{"contents":[{"role":"user","parts":[{"text":"Hello"}]}]}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + requestHeaders := host.GetRequestHeaders() + var pathHeader string + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.NotEmpty(t, pathHeader, "Path header should not be empty") + + parsedPath, err := url.ParseRequestURI(pathHeader) + require.NoError(t, err) + query := parsedPath.Query() + require.Len(t, query["key"], 1, "Path should contain exactly one key query parameter") + keyInPath := query.Get("key") + require.NotEmpty(t, keyInPath, "Path should contain key query parameter") + + // 从 debug log 中提取本次请求固定的 apiTokenInUse + var apiTokenInUse string + for _, debugLog := range host.GetDebugLogs() { + const prefix = "Use apiToken " + const suffix = " to send request" + start := strings.Index(debugLog, prefix) + if start == -1 { + continue + } + start += len(prefix) + end := strings.Index(debugLog[start:], suffix) + if end == -1 { + continue + } + apiTokenInUse = debugLog[start : start+end] + break + } + require.NotEmpty(t, apiTokenInUse, "apiTokenInUse should be logged") + require.Equal(t, apiTokenInUse, keyInPath, + "Query key must use apiTokenInUse from request context") + }) + + // 测试 Vertex Raw 模式请求体处理(Express Mode + 已有 key 参数时应覆盖而不是追加重复) + t.Run("vertex raw mode express - should replace existing key query parameter", func(t *testing.T) { + host, status := test.NewTestHost(vertexRawModeExpressConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=client-key&trace=1"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + requestBody := `{"contents":[{"role":"user","parts":[{"text":"Hello"}]}]}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + requestHeaders := host.GetRequestHeaders() + var pathHeader string + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.NotEmpty(t, pathHeader, "Path header should not be empty") + + parsedPath, err := url.ParseRequestURI(pathHeader) + require.NoError(t, err) + query := parsedPath.Query() + + require.Len(t, query["key"], 1, "Path should contain exactly one key query parameter") + require.Equal(t, "test-api-key-for-raw-mode", query.Get("key"), + "Existing key query parameter should be replaced by configured API key") + require.Equal(t, "sse", query.Get("alt"), "Existing query parameter alt should be preserved") + require.Equal(t, "1", query.Get("trace"), "Existing query parameter trace should be preserved") }) }) }