From 032a69556f0682917dd6cf50c62ce49f0a4b567f Mon Sep 17 00:00:00 2001 From: woody Date: Tue, 13 Jan 2026 20:00:05 +0800 Subject: [PATCH] =?UTF-8?q?feat(vertex):=20=E4=B8=BA=20ai-proxy=20?= =?UTF-8?q?=E6=8F=92=E4=BB=B6=E7=9A=84=20Vertex=20AI=20Provider=20?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20Express=20Mode=20=E6=94=AF=E6=8C=81=20||?= =?UTF-8?q?=20feat(vertex):=20Add=20Express=20Mode=20support=20to=20Vertex?= =?UTF-8?q?=20AI=20Provider=20of=20ai-proxy=20plug-in=20(#3301)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- plugins/wasm-go/extensions/ai-proxy/README.md | 69 ++- .../wasm-go/extensions/ai-proxy/README_EN.md | 66 ++- .../wasm-go/extensions/ai-proxy/main_test.go | 8 + .../extensions/ai-proxy/provider/vertex.go | 103 +++- .../extensions/ai-proxy/test/vertex.go | 499 ++++++++++++++++++ 5 files changed, 726 insertions(+), 19 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-proxy/test/vertex.go diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 499ae5444..9c3c8bf91 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -309,7 +309,9 @@ Dify 所对应的 `type` 为 `dify`。它特有的配置字段如下: #### Google Vertex AI -Google Vertex AI 所对应的 type 为 vertex。它特有的配置字段如下: +Google Vertex AI 所对应的 type 为 vertex。支持两种认证模式: + +**标准模式**(使用 Service Account): | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | |-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------| @@ -320,6 +322,15 @@ Google Vertex AI 所对应的 type 为 vertex。它特有的配置字段如下 | `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) | | `vertexTokenRefreshAhead` | number | 非必填 | - | Vertex access token刷新提前时间(单位秒) | +**Express Mode**(使用 API Key,简化配置): + +Express Mode 是 Vertex AI 推出的简化访问模式,只需 API Key 即可快速开始使用,无需配置 Service Account。详见 [Vertex AI Express Mode 文档](https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview)。 + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------| +| `apiTokens` | array of string | 必填 | - | Express Mode 使用的 API Key,从 Google Cloud Console 的 API & Services > Credentials 获取 | +| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) | + #### AWS Bedrock AWS Bedrock 所对应的 type 为 bedrock。它支持两种认证方式: @@ -1955,7 +1966,7 @@ provider: } ``` -### 使用 OpenAI 协议代理 Google Vertex 服务 +### 使用 OpenAI 协议代理 Google Vertex 服务(标准模式) **配置信息** @@ -2017,6 +2028,60 @@ provider: } ``` +### 使用 OpenAI 协议代理 Google Vertex 服务(Express Mode) + +Express Mode 是 Vertex AI 的简化访问模式,只需 API Key 即可快速开始使用。 + +**配置信息** + +```yaml +provider: + type: vertex + apiTokens: + - "YOUR_API_KEY" +``` + +**请求示例** + +```json +{ + "model": "gemini-2.5-flash", + "messages": [ + { + "role": "user", + "content": "你好,你是谁?" + } + ], + "stream": false +} +``` + +**响应示例** + +```json +{ + "id": "chatcmpl-0000000000000", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "你好!我是 Gemini,由 Google 开发的人工智能助手。有什么我可以帮您的吗?" + }, + "finish_reason": "stop" + } + ], + "created": 1729986750, + "model": "gemini-2.5-flash", + "object": "chat.completion", + "usage": { + "prompt_tokens": 10, + "completion_tokens": 25, + "total_tokens": 35 + } +} +``` + ### 使用 OpenAI 协议代理 AWS Bedrock 服务 AWS Bedrock 支持两种认证方式: diff --git a/plugins/wasm-go/extensions/ai-proxy/README_EN.md b/plugins/wasm-go/extensions/ai-proxy/README_EN.md index b02812b24..83583972c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README_EN.md +++ b/plugins/wasm-go/extensions/ai-proxy/README_EN.md @@ -255,7 +255,9 @@ For DeepL, the corresponding `type` is `deepl`. Its unique configuration field i | `targetLang` | string | Required | - | The target language required by the DeepL translation service | #### Google Vertex AI -For Vertex, the corresponding `type` is `vertex`. Its unique configuration field is: +For Vertex, the corresponding `type` is `vertex`. It supports two authentication modes: + +**Standard Mode** (using Service Account): | Name | Data Type | Requirement | Default | Description | |-----------------------------|---------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------| @@ -266,6 +268,15 @@ For Vertex, the corresponding `type` is `vertex`. Its unique configuration field | `vertexGeminiSafetySetting` | map of string | Optional | - | Gemini model content safety filtering settings. | | `vertexTokenRefreshAhead` | number | Optional | - | Vertex access token refresh ahead time in seconds | +**Express Mode** (using API Key, simplified configuration): + +Express Mode is a simplified access mode introduced by Vertex AI. You can quickly get started with just an API Key, without configuring a Service Account. See [Vertex AI Express Mode documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview). + +| Name | Data Type | Requirement | Default | Description | +|-----------------------------|------------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `apiTokens` | array of string | Required | - | API Key for Express Mode, obtained from Google Cloud Console under API & Services > Credentials | +| `vertexGeminiSafetySetting` | map of string | Optional | - | Gemini model content safety filtering settings. | + #### AWS Bedrock For AWS Bedrock, the corresponding `type` is `bedrock`. It supports two authentication methods: @@ -1728,7 +1739,7 @@ provider: } ``` -### Utilizing OpenAI Protocol Proxy for Google Vertex Services +### Utilizing OpenAI Protocol Proxy for Google Vertex Services (Standard Mode) **Configuration Information** ```yaml provider: @@ -1786,6 +1797,57 @@ provider: } ``` +### Utilizing OpenAI Protocol Proxy for Google Vertex Services (Express Mode) + +Express Mode is a simplified access mode for Vertex AI. You only need an API Key to get started quickly. + +**Configuration Information** +```yaml +provider: + type: vertex + apiTokens: + - "YOUR_API_KEY" +``` + +**Request Example** +```json +{ + "model": "gemini-2.5-flash", + "messages": [ + { + "role": "user", + "content": "Who are you?" + } + ], + "stream": false +} +``` + +**Response Example** +```json +{ + "id": "chatcmpl-0000000000000", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! I am Gemini, an AI assistant developed by Google. How can I help you today?" + }, + "finish_reason": "stop" + } + ], + "created": 1729986750, + "model": "gemini-2.5-flash", + "object": "chat.completion", + "usage": { + "prompt_tokens": 10, + "completion_tokens": 25, + "total_tokens": 35 + } +} +``` + ### Utilizing OpenAI Protocol Proxy for AWS Bedrock Services AWS Bedrock supports two authentication methods: diff --git a/plugins/wasm-go/extensions/ai-proxy/main_test.go b/plugins/wasm-go/extensions/ai-proxy/main_test.go index c7accee46..93ded731d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/main_test.go @@ -129,6 +129,14 @@ func TestGeneric(t *testing.T) { test.RunGenericOnHttpRequestBodyTests(t) } +func TestVertex(t *testing.T) { + test.RunVertexParseConfigTests(t) + test.RunVertexExpressModeOnHttpRequestHeadersTests(t) + test.RunVertexExpressModeOnHttpRequestBodyTests(t) + test.RunVertexExpressModeOnHttpResponseBodyTests(t) + test.RunVertexExpressModeOnStreamingResponseBodyTests(t) +} + func TestBedrock(t *testing.T) { test.RunBedrockParseConfigTests(t) test.RunBedrockOnHttpRequestHeadersTests(t) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go index bcf68a779..b947501e0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go @@ -27,8 +27,11 @@ const ( vertexAuthDomain = "oauth2.googleapis.com" vertexDomain = "aiplatform.googleapis.com" // /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION} - vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s" - vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s" + vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s" + vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s" + // Express Mode 路径模板 (不含 project/location) + vertexExpressPathTemplate = "/v1/publishers/google/models/%s:%s" + vertexExpressPathAnthropicTemplate = "/v1/publishers/anthropic/models/%s:%s" vertexChatCompletionAction = "generateContent" vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse" vertexAnthropicMessageAction = "rawPredict" @@ -42,6 +45,13 @@ const ( type vertexProviderInitializer struct{} func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error { + // Express Mode: 如果配置了 apiTokens,则使用 API Key 认证 + if len(config.apiTokens) > 0 { + // Express Mode 不需要其他配置 + return nil + } + + // 标准模式: 保持原有验证逻辑 if config.vertexAuthKey == "" { return errors.New("missing vertexAuthKey in vertex provider config") } @@ -63,19 +73,32 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string { func (v *vertexProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { config.setDefaultCapabilities(v.DefaultCapabilities()) - return &vertexProvider{ - config: config, - client: wrapper.NewClusterClient(wrapper.DnsCluster{ - Domain: vertexAuthDomain, - ServiceName: config.vertexAuthServiceName, - Port: 443, - }), + + provider := &vertexProvider{ + config: config, contextCache: createContextCache(&config), claude: &claudeProvider{ config: config, contextCache: createContextCache(&config), }, - }, nil + } + + // 仅标准模式需要 OAuth 客户端(Express Mode 通过 apiTokens 配置) + if !provider.isExpressMode() { + provider.client = wrapper.NewClusterClient(wrapper.DnsCluster{ + Domain: vertexAuthDomain, + ServiceName: config.vertexAuthServiceName, + Port: 443, + }) + } + + return provider, nil +} + +// isExpressMode 检测是否启用 Express Mode +// 如果配置了 apiTokens,则使用 Express Mode(API Key 认证) +func (v *vertexProvider) isExpressMode() bool { + return len(v.config.apiTokens) > 0 } type vertexProvider struct { @@ -106,11 +129,19 @@ func (v *vertexProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa func (v *vertexProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) { var finalVertexDomain string - if v.config.vertexRegion != vertexGlobalRegion { - finalVertexDomain = fmt.Sprintf("%s-%s", v.config.vertexRegion, vertexDomain) - } else { + + if v.isExpressMode() { + // Express Mode: 固定域名,不带 region 前缀 finalVertexDomain = vertexDomain + } else { + // 标准模式: 带 region 前缀 + if v.config.vertexRegion != vertexGlobalRegion { + finalVertexDomain = fmt.Sprintf("%s-%s", v.config.vertexRegion, vertexDomain) + } else { + finalVertexDomain = vertexDomain + } } + util.OverwriteRequestHostHeader(headers, finalVertexDomain) } @@ -156,6 +187,16 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, headers := util.GetRequestHeaders() body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers) headers.Set("Content-Length", fmt.Sprint(len(body))) + + if v.isExpressMode() { + // Express Mode: 不需要 Authorization header,API Key 已在 URL 中 + headers.Del("Authorization") + util.ReplaceRequestHeaders(headers) + _ = proxywasm.ReplaceHttpRequestBody(body) + return types.ActionContinue, err + } + + // 标准模式: 需要获取 OAuth token util.ReplaceRequestHeaders(headers) _ = proxywasm.ReplaceHttpRequestBody(body) if err != nil { @@ -422,7 +463,23 @@ func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string } else { action = vertexAnthropicMessageAction } - return fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action) + + if v.isExpressMode() { + // Express Mode: 简化路径 + API Key 参数 + basePath := fmt.Sprintf(vertexExpressPathAnthropicTemplate, modelId, action) + apiKey := v.config.GetRandomToken() + // 如果 action 已经包含 ?,使用 & 拼接 + var fullPath string + if strings.Contains(action, "?") { + fullPath = basePath + "&key=" + apiKey + } else { + fullPath = basePath + "?key=" + apiKey + } + return fullPath + } + + path := fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action) + return path } func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string { @@ -434,7 +491,23 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream } else { action = vertexChatCompletionAction } - return fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action) + + if v.isExpressMode() { + // Express Mode: 简化路径 + API Key 参数 + basePath := fmt.Sprintf(vertexExpressPathTemplate, modelId, action) + apiKey := v.config.GetRandomToken() + // 如果 action 已经包含 ?(如 streamGenerateContent?alt=sse),使用 & 拼接 + var fullPath string + if strings.Contains(action, "?") { + fullPath = basePath + "&key=" + apiKey + } else { + fullPath = basePath + "?key=" + apiKey + } + return fullPath + } + + path := fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action) + return path } func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest { diff --git a/plugins/wasm-go/extensions/ai-proxy/test/vertex.go b/plugins/wasm-go/extensions/ai-proxy/test/vertex.go new file mode 100644 index 000000000..e2ba84ade --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/vertex.go @@ -0,0 +1,499 @@ +package test + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" +) + +// 测试配置:Vertex 标准模式配置 +var basicVertexConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "vertex", + "vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`, + "vertexRegion": "us-central1", + "vertexProjectId": "test-project-id", + "vertexAuthServiceName": "test-auth-service", + }, + }) + return data +}() + +// 测试配置:Vertex Express Mode 配置(使用 apiTokens) +var vertexExpressModeConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "vertex", + "apiTokens": []string{"test-api-key-123456789"}, + }, + }) + return data +}() + +// 测试配置:Vertex Express Mode 配置(含模型映射) +var vertexExpressModeWithModelMappingConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "vertex", + "apiTokens": []string{"test-api-key-123456789"}, + "modelMapping": map[string]string{ + "gpt-4": "gemini-2.5-flash", + "gpt-3.5-turbo": "gemini-2.5-flash-lite", + "text-embedding-ada-002": "text-embedding-001", + }, + }, + }) + return data +}() + +// 测试配置:Vertex Express Mode 配置(含安全设置) +var vertexExpressModeWithSafetyConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "vertex", + "apiTokens": []string{"test-api-key-123456789"}, + "geminiSafetySetting": map[string]string{ + "HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE", + "HARM_CATEGORY_HATE_SPEECH": "BLOCK_LOW_AND_ABOVE", + "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE", + }, + }, + }) + return data +}() + +// 测试配置:无效 Vertex 标准模式配置(缺少 vertexAuthKey) +var invalidVertexStandardModeConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "vertex", + // 缺少必需的标准模式配置 + }, + }) + return data +}() + +func RunVertexParseConfigTests(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + // 测试 Vertex 标准模式配置解析 + t.Run("vertex standard mode config", func(t *testing.T) { + host, status := test.NewTestHost(basicVertexConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + + // 测试 Vertex Express Mode 配置解析 + t.Run("vertex express mode config", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + + // 测试 Vertex Express Mode 配置(含模型映射) + t.Run("vertex express mode with model mapping config", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + + // 测试无效 Vertex 标准模式配置(缺少 vertexAuthKey) + t.Run("invalid vertex standard mode config - missing auth key", func(t *testing.T) { + host, status := test.NewTestHost(invalidVertexStandardModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) + + // 测试 Vertex Express Mode 配置(含安全设置) + t.Run("vertex express mode with safety setting config", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeWithSafetyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + }) +} + +func RunVertexExpressModeOnHttpRequestHeadersTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // 测试 Vertex Express Mode 请求头处理(聊天完成接口) + t.Run("vertex express mode chat completion request headers", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 设置请求头 + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 应该返回HeaderStopIteration,因为需要处理请求体 + require.Equal(t, types.HeaderStopIteration, action) + + // 验证请求头是否被正确处理 + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + // 验证Host是否被改为 vertex 域名(Express Mode 使用不带 region 前缀的域名) + require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"), "Host header should be changed to vertex domain without region prefix") + + // 检查是否有相关的处理日志 + debugLogs := host.GetDebugLogs() + hasVertexLogs := false + for _, log := range debugLogs { + if strings.Contains(log, "vertex") { + hasVertexLogs = true + break + } + } + require.True(t, hasVertexLogs, "Should have vertex processing logs") + }) + + // 测试 Vertex Express Mode 请求头处理(嵌入接口) + t.Run("vertex express mode embeddings request headers", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 设置请求头 + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/embeddings"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + require.Equal(t, types.HeaderStopIteration, action) + + // 验证嵌入接口的请求头处理 + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + // 验证Host转换 + require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"), "Host header should be changed to vertex domain") + }) + }) +} + +func RunVertexExpressModeOnHttpRequestBodyTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // 测试 Vertex Express Mode 请求体处理(聊天完成接口) + t.Run("vertex express mode chat completion request body", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体 + requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}]}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + // Express Mode 不需要暂停等待 OAuth token + require.Equal(t, types.ActionContinue, action) + + // 验证请求体是否被正确处理 + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + // 验证请求体被转换为 Vertex 格式 + require.Contains(t, string(processedBody), "contents", "Request should be converted to vertex format") + require.Contains(t, string(processedBody), "generationConfig", "Request should contain vertex generation config") + + // 验证路径包含 API Key + requestHeaders := host.GetRequestHeaders() + pathHeader := "" + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key as query parameter") + require.Contains(t, pathHeader, "/v1/publishers/google/models/", "Path should use Express Mode format without project/location") + + // 验证没有 Authorization header(Express Mode 使用 URL 参数) + hasAuthHeader := false + for _, header := range requestHeaders { + if header[0] == "Authorization" && header[1] != "" { + hasAuthHeader = true + break + } + } + require.False(t, hasAuthHeader, "Authorization header should be removed in Express Mode") + + // 检查是否有相关的处理日志 + debugLogs := host.GetDebugLogs() + hasVertexLogs := false + for _, log := range debugLogs { + if strings.Contains(log, "vertex") { + hasVertexLogs = true + break + } + } + require.True(t, hasVertexLogs, "Should have vertex processing logs") + }) + + // 测试 Vertex Express Mode 请求体处理(嵌入接口) + t.Run("vertex express mode embeddings request body", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/embeddings"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体 + requestBody := `{"model":"text-embedding-001","input":"test text"}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + require.Equal(t, types.ActionContinue, action) + + // 验证嵌入接口的请求体处理 + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + // 验证请求体被转换为 Vertex 格式 + require.Contains(t, string(processedBody), "instances", "Request should be converted to vertex format") + + // 验证路径包含 API Key + requestHeaders := host.GetRequestHeaders() + pathHeader := "" + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key as query parameter") + }) + + // 测试 Vertex Express Mode 请求体处理(流式请求) + t.Run("vertex express mode streaming request body", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置流式请求体 + requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + require.Equal(t, types.ActionContinue, action) + + // 验证路径包含流式 action + requestHeaders := host.GetRequestHeaders() + pathHeader := "" + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.Contains(t, pathHeader, "streamGenerateContent", "Path should contain streaming action") + require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key") + }) + + // 测试 Vertex Express Mode 请求体处理(含模型映射) + t.Run("vertex express mode with model mapping request body", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体(使用 OpenAI 模型名) + requestBody := `{"model":"gpt-4","messages":[{"role":"user","content":"test"}]}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + require.Equal(t, types.ActionContinue, action) + + // 验证路径包含映射后的模型名 + requestHeaders := host.GetRequestHeaders() + pathHeader := "" + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.Contains(t, pathHeader, "gemini-2.5-flash", "Path should contain mapped model name") + }) + }) +} + +func RunVertexExpressModeOnHttpResponseBodyTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // 测试 Vertex Express Mode 响应体处理(聊天完成接口) + t.Run("vertex express mode chat completion response body", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体 + requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}]}` + host.CallOnHttpRequestBody([]byte(requestBody)) + + // 设置响应属性,确保IsResponseFromUpstream()返回true + host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream")) + + // 设置响应头 + responseHeaders := [][2]string{ + {":status", "200"}, + {"Content-Type", "application/json"}, + } + host.CallOnHttpResponseHeaders(responseHeaders) + + // 设置响应体(Vertex 格式) + responseBody := `{ + "candidates": [{ + "content": { + "parts": [{ + "text": "Hello! How can I help you today?" + }] + }, + "finishReason": "STOP", + "index": 0 + }], + "usageMetadata": { + "promptTokenCount": 9, + "candidatesTokenCount": 12, + "totalTokenCount": 21 + } + }` + action := host.CallOnHttpResponseBody([]byte(responseBody)) + + require.Equal(t, types.ActionContinue, action) + + // 验证响应体是否被正确处理 + processedResponseBody := host.GetResponseBody() + require.NotNil(t, processedResponseBody) + + // 验证响应体内容(转换为OpenAI格式) + responseStr := string(processedResponseBody) + + // 检查响应体是否被转换 + if strings.Contains(responseStr, "chat.completion") { + require.Contains(t, responseStr, "assistant", "Response should contain assistant role") + require.Contains(t, responseStr, "usage", "Response should contain usage information") + } + + // 检查是否有相关的处理日志 + debugLogs := host.GetDebugLogs() + hasResponseBodyLogs := false + for _, log := range debugLogs { + if strings.Contains(log, "response") || strings.Contains(log, "body") || strings.Contains(log, "vertex") { + hasResponseBodyLogs = true + break + } + } + require.True(t, hasResponseBodyLogs, "Should have response body processing logs") + }) + }) +} + +func RunVertexExpressModeOnStreamingResponseBodyTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // 测试 Vertex Express Mode 流式响应处理 + t.Run("vertex express mode streaming response body", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置流式请求体 + requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}` + host.CallOnHttpRequestBody([]byte(requestBody)) + + // 设置流式响应头 + responseHeaders := [][2]string{ + {":status", "200"}, + {"Content-Type", "text/event-stream"}, + } + host.CallOnHttpResponseHeaders(responseHeaders) + + // 模拟流式响应体 + chunk1 := `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"finishReason":"","index":0}],"usageMetadata":{"promptTokenCount":9,"candidatesTokenCount":5,"totalTokenCount":14}}` + chunk2 := `data: {"candidates":[{"content":{"parts":[{"text":"Hello! How can I help you today?"}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":9,"candidatesTokenCount":12,"totalTokenCount":21}}` + + // 处理流式响应体 + action1 := host.CallOnHttpStreamingResponseBody([]byte(chunk1), false) + require.Equal(t, types.ActionContinue, action1) + + action2 := host.CallOnHttpStreamingResponseBody([]byte(chunk2), true) + require.Equal(t, types.ActionContinue, action2) + + // 验证流式响应处理 + debugLogs := host.GetDebugLogs() + hasStreamingLogs := false + for _, log := range debugLogs { + if strings.Contains(log, "streaming") || strings.Contains(log, "chunk") || strings.Contains(log, "vertex") { + hasStreamingLogs = true + break + } + } + require.True(t, hasStreamingLogs, "Should have streaming response processing logs") + }) + }) +}