From a2eb599eff0a5397fce3ede1a58c256886a015d9 Mon Sep 17 00:00:00 2001 From: woody Date: Wed, 21 Jan 2026 14:45:06 +0800 Subject: [PATCH] Implement Vertex Raw mode support in AI Proxy (#3375) --- .../wasm-go/extensions/ai-proxy/main_test.go | 4 + .../extensions/ai-proxy/provider/provider.go | 1 + .../extensions/ai-proxy/provider/vertex.go | 34 ++ .../extensions/ai-proxy/test/vertex.go | 376 +++++++++++++++++- 4 files changed, 403 insertions(+), 12 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main_test.go b/plugins/wasm-go/extensions/ai-proxy/main_test.go index fb2b4fbe9..83b3182a6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/main_test.go @@ -137,6 +137,10 @@ func TestVertex(t *testing.T) { test.RunVertexExpressModeOnStreamingResponseBodyTests(t) test.RunVertexExpressModeImageGenerationRequestBodyTests(t) test.RunVertexExpressModeImageGenerationResponseBodyTests(t) + // Vertex Raw 模式测试 + test.RunVertexRawModeOnHttpRequestHeadersTests(t) + test.RunVertexRawModeOnHttpRequestBodyTests(t) + test.RunVertexRawModeOnHttpResponseBodyTests(t) } func TestBedrock(t *testing.T) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 563664f89..ef71d6d98 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -70,6 +70,7 @@ const ( ApiNameGeminiStreamGenerateContent ApiName = "gemini/v1beta/streamgeneratecontent" ApiNameAnthropicMessages ApiName = "anthropic/v1/messages" ApiNameAnthropicComplete ApiName = "anthropic/v1/complete" + ApiNameVertexRaw ApiName = "vertex/raw" // OpenAI PathOpenAIPrefix = "/v1" diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go index 4598f9b6e..3791e06ef 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go @@ -44,9 +44,15 @@ const ( vertexGlobalRegion = "global" contextClaudeMarker = "isClaudeRequest" contextOpenAICompatibleMarker = "isOpenAICompatibleRequest" + contextVertexRawMarker = "isVertexRawRequest" vertexAnthropicVersion = "vertex-2023-10-16" ) +// vertexRawPathRegex 匹配原生 Vertex AI REST API 路径 +// 格式: [任意前缀]/{api-version}/projects/{project}/locations/{location}/publishers/{publisher}/models/{model}:{action} +// 允许任意 basePath 前缀,兼容 basePathHandling 配置 +var vertexRawPathRegex = regexp.MustCompile(`^.*/([^/]+)/projects/([^/]+)/locations/([^/]+)/publishers/([^/]+)/models/([^/:]+):([^/?]+)`) + type vertexProviderInitializer struct{} func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error { @@ -92,6 +98,7 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string { string(ApiNameChatCompletion): vertexPathTemplate, string(ApiNameEmbeddings): vertexPathTemplate, string(ApiNameImageGeneration): vertexPathTemplate, + string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换 } } @@ -143,6 +150,12 @@ func (v *vertexProvider) GetProviderType() string { } func (v *vertexProvider) GetApiName(path string) ApiName { + // 优先匹配原生 Vertex AI REST API 路径,支持任意 basePath 前缀 + // 格式: [任意前缀]/{api-version}/projects/{project}/locations/{location}/publishers/{publisher}/models/{model}:{action} + // 必须在其他 action 检查之前,因为 :predict、:generateContent 等 action 会被其他规则匹配 + if vertexRawPathRegex.MatchString(path) { + return ApiNameVertexRaw + } if strings.HasSuffix(path, vertexChatCompletionAction) || strings.HasSuffix(path, vertexChatCompletionStreamAction) { return ApiNameChatCompletion } @@ -211,6 +224,27 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if !v.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } + + // Vertex Raw 模式: 透传请求体,只做 OAuth 认证 + // 用于直接访问 Vertex AI REST API,不做协议转换 + // 注意:此检查必须在 IsOriginal() 之前,因为 Vertex Raw 模式通常与 original 协议一起使用 + if apiName == ApiNameVertexRaw { + ctx.SetContext(contextVertexRawMarker, true) + // Express Mode 不需要 OAuth 认证 + if v.isExpressMode() { + return types.ActionContinue, nil + } + // 标准模式需要获取 OAuth token + cached, err := v.getToken() + if cached { + return types.ActionContinue, nil + } + if err == nil { + return types.ActionPause, nil + } + return types.ActionContinue, err + } + if v.config.IsOriginal() { return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/test/vertex.go b/plugins/wasm-go/extensions/ai-proxy/test/vertex.go index 312e70260..bc3b53c11 100644 --- a/plugins/wasm-go/extensions/ai-proxy/test/vertex.go +++ b/plugins/wasm-go/extensions/ai-proxy/test/vertex.go @@ -82,12 +82,12 @@ var invalidVertexStandardModeConfig = func() json.RawMessage { var vertexOpenAICompatibleModeConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "provider": map[string]interface{}{ - "type": "vertex", - "vertexOpenAICompatible": true, - "vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`, - "vertexRegion": "us-central1", - "vertexProjectId": "test-project-id", - "vertexAuthServiceName": "test-auth-service", + "type": "vertex", + "vertexOpenAICompatible": true, + "vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`, + "vertexRegion": "us-central1", + "vertexProjectId": "test-project-id", + "vertexAuthServiceName": "test-auth-service", }, }) return data @@ -97,12 +97,12 @@ var vertexOpenAICompatibleModeConfig = func() json.RawMessage { var vertexOpenAICompatibleModeWithModelMappingConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ "provider": map[string]interface{}{ - "type": "vertex", - "vertexOpenAICompatible": true, - "vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`, - "vertexRegion": "us-central1", - "vertexProjectId": "test-project-id", - "vertexAuthServiceName": "test-auth-service", + "type": "vertex", + "vertexOpenAICompatible": true, + "vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`, + "vertexRegion": "us-central1", + "vertexProjectId": "test-project-id", + "vertexAuthServiceName": "test-auth-service", "modelMapping": map[string]string{ "gpt-4": "gemini-2.0-flash", "gpt-3.5-turbo": "gemini-1.5-flash", @@ -124,6 +124,47 @@ var invalidVertexExpressAndOpenAICompatibleConfig = func() json.RawMessage { return data }() +// 测试配置:Vertex Raw 模式配置(Express Mode + 原生 Vertex API 路径) +var vertexRawModeExpressConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "vertex", + "apiTokens": []string{"test-api-key-for-raw-mode"}, + "protocol": "original", + }, + }) + return data +}() + +// 测试配置:Vertex Raw 模式配置(标准模式 + 原生 Vertex API 路径) +var vertexRawModeStandardConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "vertex", + "vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`, + "vertexRegion": "us-central1", + "vertexProjectId": "test-project-id", + "vertexAuthServiceName": "test-auth-service", + "protocol": "original", + }, + }) + return data +}() + +// 测试配置:Vertex Raw 模式配置(Express Mode + basePath removePrefix) +var vertexRawModeWithBasePathConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "vertex", + "apiTokens": []string{"test-api-key-for-raw-mode"}, + "protocol": "original", + "basePath": "/vertex-proxy", + "basePathHandling": "removePrefix", + }, + }) + return data +}() + func RunVertexParseConfigTests(t *testing.T) { test.RunGoTest(t, func(t *testing.T) { // 测试 Vertex 标准模式配置解析 @@ -1231,3 +1272,314 @@ func RunVertexExpressModeImageGenerationResponseBodyTests(t *testing.T) { }) }) } + +// ==================== Vertex Raw 模式测试 ==================== + +func RunVertexRawModeOnHttpRequestHeadersTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // 测试 Vertex Raw 模式请求头处理(Express Mode + 原生 Vertex API 路径) + t.Run("vertex raw mode express - request headers with native vertex path", func(t *testing.T) { + host, status := test.NewTestHost(vertexRawModeExpressConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 使用原生 Vertex AI REST API 路径 + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 应该返回 HeaderStopIteration,因为需要处理请求体 + require.Equal(t, types.HeaderStopIteration, action) + + // 验证请求头是否被正确处理 + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + // 验证 Host 是否被改为 vertex 域名(Express Mode 使用不带 region 前缀的域名) + require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"), + "Host header should be changed to vertex domain without region prefix") + }) + + // 测试 Vertex Raw 模式请求头处理(标准模式 + 原生 Vertex API 路径) + t.Run("vertex raw mode standard - request headers with native vertex path", func(t *testing.T) { + host, status := test.NewTestHost(vertexRawModeStandardConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 使用原生 Vertex AI REST API 路径 + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + require.Equal(t, types.HeaderStopIteration, action) + + // 验证请求头是否被正确处理 + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + // 验证 Host 是否被改为 vertex 域名(标准模式使用带 region 前缀的域名) + require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "us-central1-aiplatform.googleapis.com"), + "Host header should be changed to vertex domain with region prefix") + }) + + // 测试 Vertex Raw 模式请求头处理(带 basePath 前缀) + t.Run("vertex raw mode with basePath - request headers", func(t *testing.T) { + host, status := test.NewTestHost(vertexRawModeWithBasePathConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 使用带 basePath 前缀的原生 Vertex AI REST API 路径 + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/vertex-proxy/v1/projects/test-project/locations/us-central1/publishers/google/models/imagen-4.0-generate-preview-06-06:predict"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + require.Equal(t, types.HeaderStopIteration, action) + + // 验证请求头是否被正确处理 + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + // 验证 Host 是否被改为 vertex 域名 + require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"), + "Host header should be changed to vertex domain") + + // 验证路径是否移除了 basePath 前缀 + pathHeader := "" + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.NotContains(t, pathHeader, "/vertex-proxy", "Path should have basePath prefix removed") + require.Contains(t, pathHeader, "/v1/projects/", "Path should contain original vertex path after basePath removal") + }) + + // 测试 Vertex Raw 模式请求头处理(Anthropic 模型路径) + t.Run("vertex raw mode express - request headers with anthropic model path", func(t *testing.T) { + host, status := test.NewTestHost(vertexRawModeExpressConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 使用 Anthropic 模型的原生 Vertex AI REST API 路径 + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/projects/test-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4@20250514:rawPredict"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + require.Equal(t, types.HeaderStopIteration, action) + + // 验证请求头是否被正确处理 + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + // 验证 Host 是否被改为 vertex 域名 + require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"), + "Host header should be changed to vertex domain") + }) + }) +} + +func RunVertexRawModeOnHttpRequestBodyTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // 测试 Vertex Raw 模式请求体处理(Express Mode - 透传请求体) + t.Run("vertex raw mode express - request body passthrough", func(t *testing.T) { + host, status := test.NewTestHost(vertexRawModeExpressConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置原生 Vertex 格式的请求体 + requestBody := `{"contents":[{"role":"user","parts":[{"text":"Hello, world!"}]}],"generationConfig":{"temperature":0.7}}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + // Express Mode 不需要暂停等待 OAuth token + require.Equal(t, types.ActionContinue, action) + + // 验证请求体被透传(不做格式转换) + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + // 请求体应该保持原样 + require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged") + }) + + // 测试 Vertex Raw 模式请求体处理(标准模式 - 需要 OAuth token) + // 注意:使用 countTokens action,因为 generateContent/predict 等会被识别为其他 API 类型 + // 注意:在单元测试环境中,由于测试配置使用的是无效的私钥,JWT 创建会失败, + // 因此 getToken() 会返回错误,导致 ActionContinue 而不是 ActionPause。 + // 这个测试主要验证代码正确进入了 Vertex Raw 模式的处理分支,请求体被透传。 + t.Run("vertex raw mode standard - request body with oauth", func(t *testing.T) { + host, status := test.NewTestHost(vertexRawModeStandardConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 - 使用 countTokens action,这是一个不会被其他 API 类型匹配的原生 Vertex API + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:countTokens"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置原生 Vertex 格式的请求体 + requestBody := `{"contents":[{"role":"user","parts":[{"text":"Hello, world!"}]}]}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + // 注意:在单元测试环境中,由于私钥无效,JWT 创建失败会返回 ActionContinue + // 在真实环境中,如果 JWT 创建成功,会返回 ActionPause 等待 OAuth token + // 这里我们只验证代码正确进入了 Vertex Raw 模式的处理分支 + require.Equal(t, types.ActionContinue, action) + + // 验证请求体被透传(不做格式转换) + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + // 请求体应该保持原样(这是 Vertex Raw 模式的核心功能) + require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged") + }) + + // 测试 Vertex Raw 模式请求体处理(带 basePath 前缀 - 路径正确处理) + t.Run("vertex raw mode with basePath - request body passthrough", func(t *testing.T) { + host, status := test.NewTestHost(vertexRawModeWithBasePathConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头(带 basePath 前缀) + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/vertex-proxy/v1/projects/test-project/locations/us-central1/publishers/google/models/imagen-4.0-generate-preview-06-06:predict"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置原生 Vertex 格式的请求体(图片生成) + requestBody := `{"instances":[{"prompt":"A beautiful sunset"}],"parameters":{"sampleCount":1}}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + // Express Mode 不需要暂停等待 OAuth token + require.Equal(t, types.ActionContinue, action) + + // 验证请求体被透传 + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged") + + // 验证路径已正确处理(移除 basePath) + requestHeaders := host.GetRequestHeaders() + pathHeader := "" + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.NotContains(t, pathHeader, "/vertex-proxy", "Path should have basePath prefix removed") + }) + + // 测试 Vertex Raw 模式请求体处理(流式请求) + t.Run("vertex raw mode express - streaming request body passthrough", func(t *testing.T) { + host, status := test.NewTestHost(vertexRawModeExpressConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头(流式端点) + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:streamGenerateContent?alt=sse"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置原生 Vertex 格式的请求体 + requestBody := `{"contents":[{"role":"user","parts":[{"text":"Tell me a story"}]}]}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + require.Equal(t, types.ActionContinue, action) + + // 验证请求体被透传 + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged") + }) + }) +} + +func RunVertexRawModeOnHttpResponseBodyTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // 测试 Vertex Raw 模式响应体处理(透传响应) + t.Run("vertex raw mode express - response body passthrough", func(t *testing.T) { + host, status := test.NewTestHost(vertexRawModeExpressConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体 + requestBody := `{"contents":[{"role":"user","parts":[{"text":"Hello"}]}]}` + host.CallOnHttpRequestBody([]byte(requestBody)) + + // 设置响应属性 + host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream")) + + // 设置响应头 + responseHeaders := [][2]string{ + {":status", "200"}, + {"Content-Type", "application/json"}, + } + host.CallOnHttpResponseHeaders(responseHeaders) + + // 设置原生 Vertex 格式的响应体 + responseBody := `{ + "candidates": [{ + "content": { + "role": "model", + "parts": [{"text": "Hello! How can I help you?"}] + }, + "finishReason": "STOP" + }], + "usageMetadata": { + "promptTokenCount": 5, + "candidatesTokenCount": 10, + "totalTokenCount": 15 + } + }` + action := host.CallOnHttpResponseBody([]byte(responseBody)) + + require.Equal(t, types.ActionContinue, action) + + // 验证响应体被透传(不做格式转换) + processedResponseBody := host.GetResponseBody() + require.NotNil(t, processedResponseBody) + + responseStr := string(processedResponseBody) + // 响应应该保持原生 Vertex 格式 + require.Contains(t, responseStr, "candidates", "Response should keep native vertex format with candidates") + require.Contains(t, responseStr, "usageMetadata", "Response should keep native vertex format with usageMetadata") + }) + }) +}