Implement Vertex Raw mode support in AI Proxy (#3375)

This commit is contained in:
woody
2026-01-21 14:45:06 +08:00
committed by GitHub
parent 3a28a9b6a7
commit a2eb599eff
4 changed files with 403 additions and 12 deletions

View File

@@ -82,12 +82,12 @@ var invalidVertexStandardModeConfig = func() json.RawMessage {
var vertexOpenAICompatibleModeConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "vertex",
"vertexOpenAICompatible": true,
"vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`,
"vertexRegion": "us-central1",
"vertexProjectId": "test-project-id",
"vertexAuthServiceName": "test-auth-service",
"type": "vertex",
"vertexOpenAICompatible": true,
"vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`,
"vertexRegion": "us-central1",
"vertexProjectId": "test-project-id",
"vertexAuthServiceName": "test-auth-service",
},
})
return data
@@ -97,12 +97,12 @@ var vertexOpenAICompatibleModeConfig = func() json.RawMessage {
var vertexOpenAICompatibleModeWithModelMappingConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "vertex",
"vertexOpenAICompatible": true,
"vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`,
"vertexRegion": "us-central1",
"vertexProjectId": "test-project-id",
"vertexAuthServiceName": "test-auth-service",
"type": "vertex",
"vertexOpenAICompatible": true,
"vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`,
"vertexRegion": "us-central1",
"vertexProjectId": "test-project-id",
"vertexAuthServiceName": "test-auth-service",
"modelMapping": map[string]string{
"gpt-4": "gemini-2.0-flash",
"gpt-3.5-turbo": "gemini-1.5-flash",
@@ -124,6 +124,47 @@ var invalidVertexExpressAndOpenAICompatibleConfig = func() json.RawMessage {
return data
}()
// 测试配置Vertex Raw 模式配置Express Mode + 原生 Vertex API 路径)
var vertexRawModeExpressConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "vertex",
"apiTokens": []string{"test-api-key-for-raw-mode"},
"protocol": "original",
},
})
return data
}()
// 测试配置Vertex Raw 模式配置(标准模式 + 原生 Vertex API 路径)
var vertexRawModeStandardConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "vertex",
"vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`,
"vertexRegion": "us-central1",
"vertexProjectId": "test-project-id",
"vertexAuthServiceName": "test-auth-service",
"protocol": "original",
},
})
return data
}()
// 测试配置Vertex Raw 模式配置Express Mode + basePath removePrefix
var vertexRawModeWithBasePathConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "vertex",
"apiTokens": []string{"test-api-key-for-raw-mode"},
"protocol": "original",
"basePath": "/vertex-proxy",
"basePathHandling": "removePrefix",
},
})
return data
}()
func RunVertexParseConfigTests(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// 测试 Vertex 标准模式配置解析
@@ -1231,3 +1272,314 @@ func RunVertexExpressModeImageGenerationResponseBodyTests(t *testing.T) {
})
})
}
// ==================== Vertex Raw 模式测试 ====================
func RunVertexRawModeOnHttpRequestHeadersTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// 测试 Vertex Raw 模式请求头处理Express Mode + 原生 Vertex API 路径)
t.Run("vertex raw mode express - request headers with native vertex path", func(t *testing.T) {
host, status := test.NewTestHost(vertexRawModeExpressConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 使用原生 Vertex AI REST API 路径
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// 应该返回 HeaderStopIteration因为需要处理请求体
require.Equal(t, types.HeaderStopIteration, action)
// 验证请求头是否被正确处理
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// 验证 Host 是否被改为 vertex 域名Express Mode 使用不带 region 前缀的域名)
require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"),
"Host header should be changed to vertex domain without region prefix")
})
// 测试 Vertex Raw 模式请求头处理(标准模式 + 原生 Vertex API 路径)
t.Run("vertex raw mode standard - request headers with native vertex path", func(t *testing.T) {
host, status := test.NewTestHost(vertexRawModeStandardConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 使用原生 Vertex AI REST API 路径
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// 验证请求头是否被正确处理
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// 验证 Host 是否被改为 vertex 域名(标准模式使用带 region 前缀的域名)
require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "us-central1-aiplatform.googleapis.com"),
"Host header should be changed to vertex domain with region prefix")
})
// 测试 Vertex Raw 模式请求头处理(带 basePath 前缀)
t.Run("vertex raw mode with basePath - request headers", func(t *testing.T) {
host, status := test.NewTestHost(vertexRawModeWithBasePathConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 使用带 basePath 前缀的原生 Vertex AI REST API 路径
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/vertex-proxy/v1/projects/test-project/locations/us-central1/publishers/google/models/imagen-4.0-generate-preview-06-06:predict"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// 验证请求头是否被正确处理
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// 验证 Host 是否被改为 vertex 域名
require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"),
"Host header should be changed to vertex domain")
// 验证路径是否移除了 basePath 前缀
pathHeader := ""
for _, header := range requestHeaders {
if header[0] == ":path" {
pathHeader = header[1]
break
}
}
require.NotContains(t, pathHeader, "/vertex-proxy", "Path should have basePath prefix removed")
require.Contains(t, pathHeader, "/v1/projects/", "Path should contain original vertex path after basePath removal")
})
// 测试 Vertex Raw 模式请求头处理Anthropic 模型路径)
t.Run("vertex raw mode express - request headers with anthropic model path", func(t *testing.T) {
host, status := test.NewTestHost(vertexRawModeExpressConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 使用 Anthropic 模型的原生 Vertex AI REST API 路径
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/projects/test-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4@20250514:rawPredict"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// 验证请求头是否被正确处理
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// 验证 Host 是否被改为 vertex 域名
require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"),
"Host header should be changed to vertex domain")
})
})
}
func RunVertexRawModeOnHttpRequestBodyTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// 测试 Vertex Raw 模式请求体处理Express Mode - 透传请求体)
t.Run("vertex raw mode express - request body passthrough", func(t *testing.T) {
host, status := test.NewTestHost(vertexRawModeExpressConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// 设置原生 Vertex 格式的请求体
requestBody := `{"contents":[{"role":"user","parts":[{"text":"Hello, world!"}]}],"generationConfig":{"temperature":0.7}}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
// Express Mode 不需要暂停等待 OAuth token
require.Equal(t, types.ActionContinue, action)
// 验证请求体被透传(不做格式转换)
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
// 请求体应该保持原样
require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged")
})
// 测试 Vertex Raw 模式请求体处理(标准模式 - 需要 OAuth token
// 注意:使用 countTokens action因为 generateContent/predict 等会被识别为其他 API 类型
// 注意在单元测试环境中由于测试配置使用的是无效的私钥JWT 创建会失败,
// 因此 getToken() 会返回错误,导致 ActionContinue 而不是 ActionPause。
// 这个测试主要验证代码正确进入了 Vertex Raw 模式的处理分支,请求体被透传。
t.Run("vertex raw mode standard - request body with oauth", func(t *testing.T) {
host, status := test.NewTestHost(vertexRawModeStandardConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头 - 使用 countTokens action这是一个不会被其他 API 类型匹配的原生 Vertex API
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:countTokens"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// 设置原生 Vertex 格式的请求体
requestBody := `{"contents":[{"role":"user","parts":[{"text":"Hello, world!"}]}]}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
// 注意在单元测试环境中由于私钥无效JWT 创建失败会返回 ActionContinue
// 在真实环境中,如果 JWT 创建成功,会返回 ActionPause 等待 OAuth token
// 这里我们只验证代码正确进入了 Vertex Raw 模式的处理分支
require.Equal(t, types.ActionContinue, action)
// 验证请求体被透传(不做格式转换)
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
// 请求体应该保持原样(这是 Vertex Raw 模式的核心功能)
require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged")
})
// 测试 Vertex Raw 模式请求体处理(带 basePath 前缀 - 路径正确处理)
t.Run("vertex raw mode with basePath - request body passthrough", func(t *testing.T) {
host, status := test.NewTestHost(vertexRawModeWithBasePathConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头(带 basePath 前缀)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/vertex-proxy/v1/projects/test-project/locations/us-central1/publishers/google/models/imagen-4.0-generate-preview-06-06:predict"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// 设置原生 Vertex 格式的请求体(图片生成)
requestBody := `{"instances":[{"prompt":"A beautiful sunset"}],"parameters":{"sampleCount":1}}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
// Express Mode 不需要暂停等待 OAuth token
require.Equal(t, types.ActionContinue, action)
// 验证请求体被透传
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged")
// 验证路径已正确处理(移除 basePath
requestHeaders := host.GetRequestHeaders()
pathHeader := ""
for _, header := range requestHeaders {
if header[0] == ":path" {
pathHeader = header[1]
break
}
}
require.NotContains(t, pathHeader, "/vertex-proxy", "Path should have basePath prefix removed")
})
// 测试 Vertex Raw 模式请求体处理(流式请求)
t.Run("vertex raw mode express - streaming request body passthrough", func(t *testing.T) {
host, status := test.NewTestHost(vertexRawModeExpressConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头(流式端点)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:streamGenerateContent?alt=sse"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// 设置原生 Vertex 格式的请求体
requestBody := `{"contents":[{"role":"user","parts":[{"text":"Tell me a story"}]}]}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// 验证请求体被透传
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged")
})
})
}
func RunVertexRawModeOnHttpResponseBodyTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// 测试 Vertex Raw 模式响应体处理(透传响应)
t.Run("vertex raw mode express - response body passthrough", func(t *testing.T) {
host, status := test.NewTestHost(vertexRawModeExpressConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// 设置请求体
requestBody := `{"contents":[{"role":"user","parts":[{"text":"Hello"}]}]}`
host.CallOnHttpRequestBody([]byte(requestBody))
// 设置响应属性
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
// 设置响应头
responseHeaders := [][2]string{
{":status", "200"},
{"Content-Type", "application/json"},
}
host.CallOnHttpResponseHeaders(responseHeaders)
// 设置原生 Vertex 格式的响应体
responseBody := `{
"candidates": [{
"content": {
"role": "model",
"parts": [{"text": "Hello! How can I help you?"}]
},
"finishReason": "STOP"
}],
"usageMetadata": {
"promptTokenCount": 5,
"candidatesTokenCount": 10,
"totalTokenCount": 15
}
}`
action := host.CallOnHttpResponseBody([]byte(responseBody))
require.Equal(t, types.ActionContinue, action)
// 验证响应体被透传(不做格式转换)
processedResponseBody := host.GetResponseBody()
require.NotNil(t, processedResponseBody)
responseStr := string(processedResponseBody)
// 响应应该保持原生 Vertex 格式
require.Contains(t, responseStr, "candidates", "Response should keep native vertex format with candidates")
require.Contains(t, responseStr, "usageMetadata", "Response should keep native vertex format with usageMetadata")
})
})
}