add support for image generation in Vertex AI provider (#3335)

This commit is contained in:
woody
2026-01-19 16:40:29 +08:00
committed by GitHub
parent ac69eb5b27
commit 399d2f372e
5 changed files with 848 additions and 27 deletions

View File

@@ -886,3 +886,348 @@ func RunVertexOpenAICompatibleModeOnHttpResponseBodyTests(t *testing.T) {
})
})
}
// ==================== 图片生成测试 ====================
// RunVertexExpressModeImageGenerationRequestBodyTests covers how the plugin
// handles OpenAI-style image generation requests (POST /v1/images/generations)
// under the Vertex Express Mode test configurations. Every sub-test builds its
// own test host, feeds it request headers followed by a request body, and then
// inspects the rewritten body and/or the rewritten :path header.
func RunVertexExpressModeImageGenerationRequestBodyTests(t *testing.T) {
	// newImageGenerationHeaders returns a freshly allocated header set on each
	// call, so every sub-test passes its own slice to its host.
	newImageGenerationHeaders := func() [][2]string {
		return [][2]string{
			{":authority", "example.com"},
			{":path", "/v1/images/generations"},
			{":method", "POST"},
			{"Content-Type", "application/json"},
		}
	}

	// expectation pairs a substring that must be present with the message
	// reported when it is missing.
	type expectation struct {
		substr string
		msg    string
	}

	test.RunTest(t, func(t *testing.T) {
		// Basic case: the OpenAI image request must be rewritten into the
		// Vertex request shape and routed to the non-streaming action.
		t.Run("vertex express mode image generation request body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Headers have to be delivered before the body.
			host.CallOnHttpRequestHeaders(newImageGenerationHeaders())

			// Request body in the OpenAI image generation format.
			payload := []byte(`{"model":"gemini-2.0-flash-exp","prompt":"A cute orange cat napping in the sunshine","size":"1024x1024"}`)
			// Express Mode does not need to pause and wait for an OAuth token.
			require.Equal(t, types.ActionContinue, host.CallOnHttpRequestBody(payload))

			// The body must have been processed.
			rewritten := host.GetRequestBody()
			require.NotNil(t, rewritten)

			// The body must now be in the Vertex format.
			bodyStr := string(rewritten)
			for _, want := range []expectation{
				{"contents", "Request should be converted to vertex format with contents"},
				{"generationConfig", "Request should contain generationConfig"},
				{"responseModalities", "Request should contain responseModalities for image generation"},
				{"IMAGE", "Request should specify IMAGE in responseModalities"},
				{"imageConfig", "Request should contain imageConfig"},
			} {
				require.Contains(t, bodyStr, want.substr, want.msg)
			}

			// The path must carry the API key and target the expected action.
			var rewrittenPath string
			for _, kv := range host.GetRequestHeaders() {
				if kv[0] != ":path" {
					continue
				}
				rewrittenPath = kv[1]
				break
			}
			require.Contains(t, rewrittenPath, "key=test-api-key-123456789", "Path should contain API key as query parameter")
			require.Contains(t, rewrittenPath, "/v1/publishers/google/models/", "Path should use Express Mode format")
			require.Contains(t, rewrittenPath, "generateContent", "Path should use generateContent action for image generation")
			require.NotContains(t, rewrittenPath, "streamGenerateContent", "Path should NOT use streaming for image generation")
		})

		// Custom size: a widescreen OpenAI size must become an aspect ratio.
		t.Run("vertex express mode image generation with custom size", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Headers have to be delivered before the body.
			host.CallOnHttpRequestHeaders(newImageGenerationHeaders())

			// Request body asking for a widescreen image.
			payload := []byte(`{"model":"gemini-2.0-flash-exp","prompt":"A beautiful sunset over the ocean","size":"1792x1024"}`)
			require.Equal(t, types.ActionContinue, host.CallOnHttpRequestBody(payload))

			// Check the size mapping in the rewritten body.
			rewritten := host.GetRequestBody()
			require.NotNil(t, rewritten)
			bodyStr := string(rewritten)

			// 1792x1024 is expected to map to the 16:9 aspect ratio.
			for _, want := range []expectation{
				{"aspectRatio", "Request should contain aspectRatio in imageConfig"},
				{"16:9", "Request should map 1792x1024 to 16:9 aspect ratio"},
			} {
				require.Contains(t, bodyStr, want.substr, want.msg)
			}
		})

		// Safety settings: uses the safety-enabled config (defined elsewhere in
		// the package) and expects safetySettings in the outgoing body.
		t.Run("vertex express mode image generation with safety settings", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeWithSafetyConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Headers have to be delivered before the body.
			host.CallOnHttpRequestHeaders(newImageGenerationHeaders())

			payload := []byte(`{"model":"gemini-2.0-flash-exp","prompt":"A mountain landscape"}`)
			require.Equal(t, types.ActionContinue, host.CallOnHttpRequestBody(payload))

			// The rewritten body must include the safety settings.
			rewritten := host.GetRequestBody()
			require.NotNil(t, rewritten)
			require.Contains(t, string(rewritten), "safetySettings", "Request should contain safetySettings")
		})

		// Model mapping: the request names the pre-mapping model and the path
		// is expected to carry the mapped one (mapping lives in the config,
		// which is defined elsewhere in the package).
		t.Run("vertex express mode image generation with model mapping", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Headers have to be delivered before the body.
			host.CallOnHttpRequestHeaders(newImageGenerationHeaders())

			// Request body using the model name from before the mapping.
			payload := []byte(`{"model":"gpt-4","prompt":"A futuristic city"}`)
			require.Equal(t, types.ActionContinue, host.CallOnHttpRequestBody(payload))

			// The path must use the mapped model name.
			var rewrittenPath string
			for _, kv := range host.GetRequestHeaders() {
				if kv[0] != ":path" {
					continue
				}
				rewrittenPath = kv[1]
				break
			}
			require.Contains(t, rewrittenPath, "gemini-2.5-flash", "Path should contain mapped model name")
		})
	})
}
// RunVertexExpressModeImageGenerationResponseBodyTests covers how the plugin
// rewrites Vertex image generation responses into the OpenAI image generation
// response shape under the Vertex Express Mode test configuration.
//
// Every sub-test follows the same sequence on its own test host: send the
// request headers and body for POST /v1/images/generations, mark the response
// as coming from upstream, send the response headers, then send a fake Vertex
// response body and inspect what the plugin produced.
func RunVertexExpressModeImageGenerationResponseBodyTests(t *testing.T) {
	// newImageGenerationHeaders returns a freshly allocated request header set
	// on each call, so every sub-test passes its own slice to its host.
	newImageGenerationHeaders := func() [][2]string {
		return [][2]string{
			{":authority", "example.com"},
			{":path", "/v1/images/generations"},
			{":method", "POST"},
			{"Content-Type", "application/json"},
		}
	}

	// newJSONResponseHeaders returns a fresh set of 200/JSON response headers.
	newJSONResponseHeaders := func() [][2]string {
		return [][2]string{
			{":status", "200"},
			{"Content-Type", "application/json"},
		}
	}

	// pngDataPrefix is the leading part of the base64 PNG embedded in the fake
	// upstream responses below. Only the prefix (which contains no '+', '/'
	// or '=') is matched so the check does not depend on how the JSON encoder
	// treats those characters.
	const pngDataPrefix = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk"

	test.RunTest(t, func(t *testing.T) {
		// Happy path: a single inline image must come back as b64_json.
		t.Run("vertex express mode image generation response body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Headers have to be delivered before the body.
			host.CallOnHttpRequestHeaders(newImageGenerationHeaders())

			// Fail fast if the request phase did not complete, instead of
			// letting a broken setup surface as a confusing response failure.
			requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"A cute cat"}`
			reqAction := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, reqAction, "Request phase should complete before the response is processed")

			// Set the response property so that IsResponseFromUpstream() returns true.
			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))

			host.CallOnHttpResponseHeaders(newJSONResponseHeaders())

			// Response body in the Vertex image generation format.
			responseBody := `{
				"candidates": [{
					"content": {
						"role": "model",
						"parts": [{
							"inlineData": {
								"mimeType": "image/png",
								"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
							}
						}]
					},
					"finishReason": "STOP"
				}],
				"usageMetadata": {
					"promptTokenCount": 10,
					"candidatesTokenCount": 1024,
					"totalTokenCount": 1034
				}
			}`
			respAction := host.CallOnHttpResponseBody([]byte(responseBody))
			require.Equal(t, types.ActionContinue, respAction)

			// The response body must have been processed.
			processedResponseBody := host.GetResponseBody()
			require.NotNil(t, processedResponseBody)
			responseStr := string(processedResponseBody)

			// The response must now be in the OpenAI image generation format.
			require.Contains(t, responseStr, "created", "Response should contain created field")
			require.Contains(t, responseStr, "data", "Response should contain data array")
			require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field with base64 image data")
			// Checking the key alone would pass with an empty value; make sure
			// the upstream image payload is actually carried through.
			require.Contains(t, responseStr, pngDataPrefix, "Response should carry the upstream base64 image payload")
			require.Contains(t, responseStr, "usage", "Response should contain usage information")
			require.Contains(t, responseStr, "total_tokens", "Response should contain total_tokens in usage")
		})

		// Thinking parts: a text part flagged with "thought": true precedes
		// the image and must not leak into the converted response.
		t.Run("vertex express mode image generation response body - skip thinking", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Headers have to be delivered before the body.
			host.CallOnHttpRequestHeaders(newImageGenerationHeaders())

			requestBody := `{"model":"gemini-3-pro-image-preview","prompt":"An Eiffel tower"}`
			reqAction := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, reqAction, "Request phase should complete before the response is processed")

			// Mark the response as coming from upstream.
			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))

			host.CallOnHttpResponseHeaders(newJSONResponseHeaders())

			// Response body containing both a thinking part and an image.
			responseBody := `{
				"candidates": [{
					"content": {
						"role": "model",
						"parts": [
							{
								"text": "Considering visual elements...",
								"thought": true
							},
							{
								"inlineData": {
									"mimeType": "image/png",
									"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
								}
							}
						]
					},
					"finishReason": "STOP"
				}],
				"usageMetadata": {
					"promptTokenCount": 13,
					"candidatesTokenCount": 1120,
					"totalTokenCount": 1356,
					"thoughtsTokenCount": 223
				}
			}`
			respAction := host.CallOnHttpResponseBody([]byte(responseBody))
			require.Equal(t, types.ActionContinue, respAction)

			// The response body must have been processed.
			processedResponseBody := host.GetResponseBody()
			require.NotNil(t, processedResponseBody)
			responseStr := string(processedResponseBody)

			// Only the image data may appear; the thinking text must not.
			require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field")
			require.NotContains(t, responseStr, "Considering visual elements", "Response should NOT contain thinking text")
			require.NotContains(t, responseStr, "thought", "Response should NOT contain thought field")
		})

		// No image: upstream answered with text only (finishReason SAFETY).
		t.Run("vertex express mode image generation response body - no image", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Headers have to be delivered before the body.
			host.CallOnHttpRequestHeaders(newImageGenerationHeaders())

			requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"test"}`
			reqAction := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, reqAction, "Request phase should complete before the response is processed")

			// Mark the response as coming from upstream.
			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))

			host.CallOnHttpResponseHeaders(newJSONResponseHeaders())

			// Response body with a text part only and no image.
			responseBody := `{
				"candidates": [{
					"content": {
						"role": "model",
						"parts": [{
							"text": "I cannot generate that image."
						}]
					},
					"finishReason": "SAFETY"
				}],
				"usageMetadata": {
					"promptTokenCount": 5,
					"candidatesTokenCount": 10,
					"totalTokenCount": 15
				}
			}`
			respAction := host.CallOnHttpResponseBody([]byte(responseBody))
			require.Equal(t, types.ActionContinue, respAction)

			// The response body must still be processed even without an image.
			processedResponseBody := host.GetResponseBody()
			require.NotNil(t, processedResponseBody)
			responseStr := string(processedResponseBody)

			// The structure must be valid and the data array empty. The
			// presence checks alone cannot tell an empty array from a
			// populated one, so also require that no image entry was emitted.
			require.Contains(t, responseStr, "created", "Response should contain created field")
			require.Contains(t, responseStr, "data", "Response should contain data array")
			require.NotContains(t, responseStr, "b64_json", "Response should NOT contain b64_json when upstream returned no image")
		})
	})
}