diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index a62172372..56877a21e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -26,6 +26,8 @@ description: AI 代理插件配置参考 > 请求路径后缀匹配 `/v1/embeddings` 时,对应文本向量场景,会用 OpenAI 的文本向量协议解析请求 Body,再转换为对应 LLM 厂商的文本向量协议 +> 请求路径后缀匹配 `/v1/images/generations` 时,对应文生图场景,会用 OpenAI 的图片生成协议解析请求 Body,再转换为对应 LLM 厂商的图片生成协议 + ## 运行属性 插件执行阶段:`默认阶段` @@ -2164,6 +2166,108 @@ provider: } ``` +### 使用 OpenAI 协议代理 Google Vertex 图片生成服务 + +Vertex AI 支持使用 Gemini 模型进行图片生成。通过 ai-proxy 插件,可以使用 OpenAI 的 `/v1/images/generations` 接口协议来调用 Vertex AI 的图片生成能力。 + +**配置信息** + +```yaml +provider: + type: vertex + apiTokens: + - "YOUR_API_KEY" + modelMapping: + "dall-e-3": "gemini-2.0-flash-exp" + geminiSafetySetting: + HARM_CATEGORY_HARASSMENT: "OFF" + HARM_CATEGORY_HATE_SPEECH: "OFF" + HARM_CATEGORY_SEXUALLY_EXPLICIT: "OFF" + HARM_CATEGORY_DANGEROUS_CONTENT: "OFF" +``` + +**使用 curl 请求** + +```bash +curl -X POST "http://your-gateway-address/v1/images/generations" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gemini-2.0-flash-exp", + "prompt": "一只可爱的橘猫在阳光下打盹", + "size": "1024x1024" + }' +``` + +**使用 OpenAI Python SDK** + +```python +from openai import OpenAI + +client = OpenAI( + api_key="any-value", # 可以是任意值,认证由网关处理 + base_url="http://your-gateway-address/v1" +) + +response = client.images.generate( + model="gemini-2.0-flash-exp", + prompt="一只可爱的橘猫在阳光下打盹", + size="1024x1024", + n=1 +) + +# 获取生成的图片(base64 编码) +image_data = response.data[0].b64_json +print(f"Generated image (base64): {image_data[:100]}...") +``` + +**响应示例** + +```json +{ + "created": 1729986750, + "data": [ + { + "b64_json": "iVBORw0KGgoAAAANSUhEUgAABAAAAAQACAIAAADwf7zUAAAA..." 
+ } + ], + "usage": { + "total_tokens": 1356, + "input_tokens": 13, + "output_tokens": 1120 + } +} +``` + +**支持的尺寸参数** + +Vertex AI 支持的宽高比(aspectRatio):`1:1`、`3:2`、`2:3`、`3:4`、`4:3`、`4:5`、`5:4`、`9:16`、`16:9`、`21:9` + +Vertex AI 支持的分辨率(imageSize):`1k`、`2k`、`4k` + +| OpenAI size 参数 | Vertex AI aspectRatio | Vertex AI imageSize | +|------------------|----------------------|---------------------| +| 256x256 | 1:1 | 1k | +| 512x512 | 1:1 | 1k | +| 1024x1024 | 1:1 | 1k | +| 1792x1024 | 16:9 | 2k | +| 1024x1792 | 9:16 | 2k | +| 2048x2048 | 1:1 | 2k | +| 4096x4096 | 1:1 | 4k | +| 1536x1024 | 3:2 | 2k | +| 1024x1536 | 2:3 | 2k | +| 1024x768 | 4:3 | 1k | +| 768x1024 | 3:4 | 1k | +| 1280x1024 | 5:4 | 1k | +| 1024x1280 | 4:5 | 1k | +| 2560x1080 | 21:9 | 2k | + +**注意事项** + +- 图片生成使用 Gemini 模型(如 `gemini-2.0-flash-exp`、`gemini-3-pro-image-preview`),不同模型的可用性可能因区域而异 +- 返回的图片数据为 base64 编码格式(`b64_json`) +- 可以通过 `geminiSafetySetting` 配置内容安全过滤级别 +- 如果需要使用模型映射(如将 `dall-e-3` 映射到 Gemini 模型),可以配置 `modelMapping` + ### 使用 OpenAI 协议代理 AWS Bedrock 服务 AWS Bedrock 支持两种认证方式: diff --git a/plugins/wasm-go/extensions/ai-proxy/README_EN.md b/plugins/wasm-go/extensions/ai-proxy/README_EN.md index b1831f9dc..3a736753a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README_EN.md +++ b/plugins/wasm-go/extensions/ai-proxy/README_EN.md @@ -25,6 +25,8 @@ The plugin now supports **automatic protocol detection**, allowing seamless comp > When the request path suffix matches `/v1/embeddings`, it corresponds to text vector scenarios. The request body will be parsed using OpenAI's text vector protocol and then converted to the corresponding LLM vendor's text vector protocol. +> When the request path suffix matches `/v1/images/generations`, it corresponds to text-to-image scenarios. The request body will be parsed using OpenAI's image generation protocol and then converted to the corresponding LLM vendor's image generation protocol. 
+ ## Execution Properties Plugin execution phase: `Default Phase` Plugin execution priority: `100` @@ -1927,6 +1929,108 @@ provider: } ``` +### Utilizing OpenAI Protocol Proxy for Google Vertex Image Generation + +Vertex AI supports image generation using Gemini models. Through the ai-proxy plugin, you can use OpenAI's `/v1/images/generations` API to call Vertex AI's image generation capabilities. + +**Configuration Information** + +```yaml +provider: + type: vertex + apiTokens: + - "YOUR_API_KEY" + modelMapping: + "dall-e-3": "gemini-2.0-flash-exp" + geminiSafetySetting: + HARM_CATEGORY_HARASSMENT: "OFF" + HARM_CATEGORY_HATE_SPEECH: "OFF" + HARM_CATEGORY_SEXUALLY_EXPLICIT: "OFF" + HARM_CATEGORY_DANGEROUS_CONTENT: "OFF" +``` + +**Using curl** + +```bash +curl -X POST "http://your-gateway-address/v1/images/generations" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gemini-2.0-flash-exp", + "prompt": "A cute orange cat napping in the sunshine", + "size": "1024x1024" + }' +``` + +**Using OpenAI Python SDK** + +```python +from openai import OpenAI + +client = OpenAI( + api_key="any-value", # Can be any value, authentication is handled by the gateway + base_url="http://your-gateway-address/v1" +) + +response = client.images.generate( + model="gemini-2.0-flash-exp", + prompt="A cute orange cat napping in the sunshine", + size="1024x1024", + n=1 +) + +# Get the generated image (base64 encoded) +image_data = response.data[0].b64_json +print(f"Generated image (base64): {image_data[:100]}...") +``` + +**Response Example** + +```json +{ + "created": 1729986750, + "data": [ + { + "b64_json": "iVBORw0KGgoAAAANSUhEUgAABAAAAAQACAIAAADwf7zUAAAA..." 
+ } + ], + "usage": { + "total_tokens": 1356, + "input_tokens": 13, + "output_tokens": 1120 + } +} +``` + +**Supported Size Parameters** + +Vertex AI supported aspect ratios: `1:1`, `3:2`, `2:3`, `3:4`, `4:3`, `4:5`, `5:4`, `9:16`, `16:9`, `21:9` + +Vertex AI supported resolutions (imageSize): `1k`, `2k`, `4k` + +| OpenAI size parameter | Vertex AI aspectRatio | Vertex AI imageSize | +|-----------------------|----------------------|---------------------| +| 256x256 | 1:1 | 1k | +| 512x512 | 1:1 | 1k | +| 1024x1024 | 1:1 | 1k | +| 1792x1024 | 16:9 | 2k | +| 1024x1792 | 9:16 | 2k | +| 2048x2048 | 1:1 | 2k | +| 4096x4096 | 1:1 | 4k | +| 1536x1024 | 3:2 | 2k | +| 1024x1536 | 2:3 | 2k | +| 1024x768 | 4:3 | 1k | +| 768x1024 | 3:4 | 1k | +| 1280x1024 | 5:4 | 1k | +| 1024x1280 | 4:5 | 1k | +| 2560x1080 | 21:9 | 2k | + +**Notes** + +- Image generation uses Gemini models (e.g., `gemini-2.0-flash-exp`, `gemini-3-pro-image-preview`). Model availability may vary by region +- The returned image data is in base64 encoded format (`b64_json`) +- Content safety filtering levels can be configured via `geminiSafetySetting` +- If you need model mapping (e.g., mapping `dall-e-3` to a Gemini model), configure `modelMapping` + ### Utilizing OpenAI Protocol Proxy for AWS Bedrock Services AWS Bedrock supports two authentication methods: diff --git a/plugins/wasm-go/extensions/ai-proxy/main_test.go b/plugins/wasm-go/extensions/ai-proxy/main_test.go index 93ded731d..fb2b4fbe9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/main_test.go @@ -135,6 +135,8 @@ func TestVertex(t *testing.T) { test.RunVertexExpressModeOnHttpRequestBodyTests(t) test.RunVertexExpressModeOnHttpResponseBodyTests(t) test.RunVertexExpressModeOnStreamingResponseBodyTests(t) + test.RunVertexExpressModeImageGenerationRequestBodyTests(t) + test.RunVertexExpressModeImageGenerationResponseBodyTests(t) } func TestBedrock(t *testing.T) { diff --git 
a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go index 38bea82f7..4598f9b6e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go @@ -89,8 +89,9 @@ func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string { return map[string]string{ - string(ApiNameChatCompletion): vertexPathTemplate, - string(ApiNameEmbeddings): vertexPathTemplate, + string(ApiNameChatCompletion): vertexPathTemplate, + string(ApiNameEmbeddings): vertexPathTemplate, + string(ApiNameImageGeneration): vertexPathTemplate, } } @@ -265,10 +266,15 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, } func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) { - if apiName == ApiNameChatCompletion { + switch apiName { + case ApiNameChatCompletion: return v.onChatCompletionRequestBody(ctx, body, headers) - } else { + case ApiNameEmbeddings: return v.onEmbeddingsRequestBody(ctx, body, headers) + case ApiNameImageGeneration: + return v.onImageGenerationRequestBody(ctx, body, headers) + default: + return body, nil } } @@ -338,6 +344,119 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [ return json.Marshal(vertexRequest) } +func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) { + request := &imageGenerationRequest{} + if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil { + return nil, err + } + // 图片生成不使用流式端点,需要完整响应 + path := v.getRequestPath(ApiNameImageGeneration, request.Model, false) + util.OverwriteRequestPathHeader(headers, path) + + vertexRequest := v.buildVertexImageGenerationRequest(request) + return 
json.Marshal(vertexRequest) +} + +func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest { + // 构建安全设置 + safetySettings := make([]vertexChatSafetySetting, 0) + for category, threshold := range v.config.geminiSafetySetting { + safetySettings = append(safetySettings, vertexChatSafetySetting{ + Category: category, + Threshold: threshold, + }) + } + + // 解析尺寸参数 + aspectRatio, imageSize := v.parseImageSize(request.Size) + + // 确定输出 MIME 类型 + mimeType := "image/png" + if request.OutputFormat != "" { + switch request.OutputFormat { + case "jpeg", "jpg": + mimeType = "image/jpeg" + case "webp": + mimeType = "image/webp" + default: + mimeType = "image/png" + } + } + + vertexRequest := &vertexChatRequest{ + Contents: []vertexChatContent{{ + Role: roleUser, + Parts: []vertexPart{{ + Text: request.Prompt, + }}, + }}, + SafetySettings: safetySettings, + GenerationConfig: vertexChatGenerationConfig{ + Temperature: 1.0, + MaxOutputTokens: 32768, + ResponseModalities: []string{"TEXT", "IMAGE"}, + ImageConfig: &vertexImageConfig{ + AspectRatio: aspectRatio, + ImageSize: imageSize, + ImageOutputOptions: &vertexImageOutputOptions{ + MimeType: mimeType, + }, + PersonGeneration: "ALLOW_ALL", + }, + }, + } + + return vertexRequest +} + +// parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize +// Vertex AI 支持的 aspectRatio: 1:1, 3:2, 2:3, 3:4, 4:3, 4:5, 5:4, 9:16, 16:9, 21:9 +// Vertex AI 支持的 imageSize: 1k, 2k, 4k +func (v *vertexProvider) parseImageSize(size string) (aspectRatio, imageSize string) { + // 默认值 + aspectRatio = "1:1" + imageSize = "1k" + + if size == "" { + return + } + + // 预定义的尺寸映射(OpenAI 标准尺寸) + sizeMapping := map[string]struct { + aspectRatio string + imageSize string + }{ + // OpenAI DALL-E 标准尺寸 + "256x256": {"1:1", "1k"}, + "512x512": {"1:1", "1k"}, + "1024x1024": {"1:1", "1k"}, + "1792x1024": {"16:9", "2k"}, + "1024x1792": {"9:16", "2k"}, + // 扩展尺寸支持 + "2048x2048": {"1:1", "2k"}, 
+ "4096x4096": {"1:1", "4k"}, + // 3:2 和 2:3 比例 + "1536x1024": {"3:2", "2k"}, + "1024x1536": {"2:3", "2k"}, + // 4:3 和 3:4 比例 + "1024x768": {"4:3", "1k"}, + "768x1024": {"3:4", "1k"}, + "1365x1024": {"4:3", "1k"}, + "1024x1365": {"3:4", "1k"}, + // 5:4 和 4:5 比例 + "1280x1024": {"5:4", "1k"}, + "1024x1280": {"4:5", "1k"}, + // 21:9 超宽比例 + "2560x1080": {"21:9", "2k"}, + } + + if mapping, ok := sizeMapping[size]; ok { + return mapping.aspectRatio, mapping.imageSize + } + + return +} + func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) { // OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列 // Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON,将非 ASCII 字符编码为 \uXXXX @@ -394,10 +513,16 @@ func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) { return v.claude.TransformResponseBody(ctx, apiName, body) } - if apiName == ApiNameChatCompletion { + + switch apiName { + case ApiNameChatCompletion: return v.onChatCompletionResponseBody(ctx, body) - } else { + case ApiNameEmbeddings: return v.onEmbeddingsResponseBody(ctx, body) + case ApiNameImageGeneration: + return v.onImageGenerationResponseBody(ctx, body) + default: + return body, nil } } @@ -490,6 +615,54 @@ func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertex return &response } +func (v *vertexProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) { + // 使用 gjson 直接提取字段,避免完整反序列化大型 base64 数据 + // 这样可以显著减少内存分配和复制次数 + response := v.buildImageGenerationResponseFromJSON(body) + return json.Marshal(response) +} + +// buildImageGenerationResponseFromJSON 使用 gjson 从原始 JSON 中提取图片生成响应 +// 相比 json.Unmarshal 完整反序列化,这种方式内存效率更高 +func (v *vertexProvider) buildImageGenerationResponseFromJSON(body []byte) *imageGenerationResponse { + result := gjson.ParseBytes(body) + data := 
make([]imageGenerationData, 0) + + // 遍历所有 candidates,提取图片数据 + candidates := result.Get("candidates") + candidates.ForEach(func(_, candidate gjson.Result) bool { + parts := candidate.Get("content.parts") + parts.ForEach(func(_, part gjson.Result) bool { + // 跳过思考过程 (thought: true) + if part.Get("thought").Bool() { + return true + } + // 提取图片数据 + inlineData := part.Get("inlineData.data") + if inlineData.Exists() && inlineData.String() != "" { + data = append(data, imageGenerationData{ + B64: inlineData.String(), + }) + } + return true + }) + return true + }) + + // 提取 usage 信息 + usage := result.Get("usageMetadata") + + return &imageGenerationResponse{ + Created: time.Now().UnixMilli() / 1000, + Data: data, + Usage: &imageGenerationUsage{ + TotalTokens: int(usage.Get("totalTokenCount").Int()), + InputTokens: int(usage.Get("promptTokenCount").Int()), + OutputTokens: int(usage.Get("candidatesTokenCount").Int()), + }, + } +} + func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse { var choice chatCompletionChoice choice.Delta = &chatMessage{} @@ -574,12 +747,18 @@ func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string { action := "" - if apiName == ApiNameEmbeddings { + switch apiName { + case ApiNameEmbeddings: action = vertexEmbeddingAction - } else if stream { - action = vertexChatCompletionStreamAction - } else { + case ApiNameImageGeneration: + // 图片生成使用非流式端点,需要完整响应 action = vertexChatCompletionAction + default: + if stream { + action = vertexChatCompletionStreamAction + } else { + action = vertexChatCompletionAction + } } if v.isExpressMode() { @@ -689,7 +868,7 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) }) } case contentTypeImageUrl: - vpart, err := convertImageContent(part.ImageUrl.Url) + vpart, err := 
convertMediaContent(part.ImageUrl.Url) if err != nil { log.Errorf("unable to convert image content: %v", err) } else { @@ -804,12 +983,25 @@ type vertexChatSafetySetting struct { } type vertexChatGenerationConfig struct { - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK int `json:"topK,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - MaxOutputTokens int `json:"maxOutputTokens,omitempty"` - ThinkingConfig vertexThinkingConfig `json:"thinkingConfig,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK int `json:"topK,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + ThinkingConfig vertexThinkingConfig `json:"thinkingConfig,omitempty"` + ResponseModalities []string `json:"responseModalities,omitempty"` + ImageConfig *vertexImageConfig `json:"imageConfig,omitempty"` +} + +type vertexImageConfig struct { + AspectRatio string `json:"aspectRatio,omitempty"` + ImageSize string `json:"imageSize,omitempty"` + ImageOutputOptions *vertexImageOutputOptions `json:"imageOutputOptions,omitempty"` + PersonGeneration string `json:"personGeneration,omitempty"` +} + +type vertexImageOutputOptions struct { + MimeType string `json:"mimeType,omitempty"` } type vertexThinkingConfig struct { @@ -1020,32 +1212,106 @@ func setCachedAccessToken(key string, accessToken string, expireTime int64) erro return proxywasm.SetSharedData(key, data, cas) } -func convertImageContent(imageUrl string) (vertexPart, error) { +// convertMediaContent 将 OpenAI 格式的媒体 URL 转换为 Vertex AI 格式 +// 支持图片、视频、音频等多种媒体类型 +func convertMediaContent(mediaUrl string) (vertexPart, error) { part := vertexPart{} - if strings.HasPrefix(imageUrl, "http") { - arr := strings.Split(imageUrl, ".") - mimeType := "image/" + arr[len(arr)-1] + if strings.HasPrefix(mediaUrl, "http") { + mimeType := detectMimeTypeFromURL(mediaUrl) 
 		part.FileData = &fileData{
 			MimeType: mimeType,
-			FileUri:  imageUrl,
+			FileUri:  mediaUrl,
 		}
 		return part, nil
 	} else {
+		// Base64 data URL 格式: data:<mimeType>;base64,<data>
 		re := regexp.MustCompile(`^data:([^;]+);base64,`)
-		matches := re.FindStringSubmatch(imageUrl)
+		matches := re.FindStringSubmatch(mediaUrl)
 		if len(matches) < 2 {
-			return part, fmt.Errorf("invalid base64 format")
+			return part, fmt.Errorf("invalid base64 format, expected data:<mimeType>;base64,<data>")
 		}
-		mimeType := matches[1] // e.g. image/png
+		mimeType := matches[1] // e.g. image/png, video/mp4, audio/mp3
 		parts := strings.Split(mimeType, "/")
 		if len(parts) < 2 {
-			return part, fmt.Errorf("invalid mimeType")
+			return part, fmt.Errorf("invalid mimeType: %s", mimeType)
 		}
 		part.InlineData = &blob{
 			MimeType: mimeType,
-			Data:     strings.TrimPrefix(imageUrl, matches[0]),
+			Data:     strings.TrimPrefix(mediaUrl, matches[0]),
 		}
 		return part, nil
 	}
 }
+
+// detectMimeTypeFromURL 根据 URL 的文件扩展名检测 MIME 类型
+// 支持图片、视频、音频和文档类型
+func detectMimeTypeFromURL(url string) string {
+	// 移除查询参数和片段标识符
+	if idx := strings.Index(url, "?"); idx != -1 {
+		url = url[:idx]
+	}
+	if idx := strings.Index(url, "#"); idx != -1 {
+		url = url[:idx]
+	}
+
+	// 获取最后一个路径段
+	lastSlash := strings.LastIndex(url, "/")
+	if lastSlash != -1 {
+		url = url[lastSlash+1:]
+	}
+
+	// 获取扩展名
+	lastDot := strings.LastIndex(url, ".")
+	if lastDot == -1 || lastDot == len(url)-1 {
+		return "application/octet-stream"
+	}
+	ext := strings.ToLower(url[lastDot+1:])
+
+	// 扩展名到 MIME 类型的映射
+	mimeTypes := map[string]string{
+		// 图片格式
+		"jpg":  "image/jpeg",
+		"jpeg": "image/jpeg",
+		"png":  "image/png",
+		"gif":  "image/gif",
+		"webp": "image/webp",
+		"bmp":  "image/bmp",
+		"svg":  "image/svg+xml",
+		"ico":  "image/x-icon",
+		"heic": "image/heic",
+		"heif": "image/heif",
+		"tiff": "image/tiff",
+		"tif":  "image/tiff",
+		// 视频格式
+		"mp4":  "video/mp4",
+		"mpeg": "video/mpeg",
+		"mpg":  "video/mpeg",
+		"mov":  "video/quicktime",
+		"avi":  "video/x-msvideo",
+		"wmv":  "video/x-ms-wmv",
+		"webm": "video/webm",
+ "mkv": "video/x-matroska", + "flv": "video/x-flv", + "3gp": "video/3gpp", + "3g2": "video/3gpp2", + "m4v": "video/x-m4v", + // 音频格式 + "mp3": "audio/mpeg", + "wav": "audio/wav", + "ogg": "audio/ogg", + "flac": "audio/flac", + "aac": "audio/aac", + "m4a": "audio/mp4", + "wma": "audio/x-ms-wma", + "opus": "audio/opus", + // 文档格式 + "pdf": "application/pdf", + } + + if mimeType, ok := mimeTypes[ext]; ok { + return mimeType + } + + return "application/octet-stream" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/test/vertex.go b/plugins/wasm-go/extensions/ai-proxy/test/vertex.go index eb0fabf8c..312e70260 100644 --- a/plugins/wasm-go/extensions/ai-proxy/test/vertex.go +++ b/plugins/wasm-go/extensions/ai-proxy/test/vertex.go @@ -886,3 +886,348 @@ func RunVertexOpenAICompatibleModeOnHttpResponseBodyTests(t *testing.T) { }) }) } + +// ==================== 图片生成测试 ==================== + +func RunVertexExpressModeImageGenerationRequestBodyTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // 测试 Vertex Express Mode 图片生成请求体处理 + t.Run("vertex express mode image generation request body", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/generations"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体(OpenAI 图片生成格式) + requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"A cute orange cat napping in the sunshine","size":"1024x1024"}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + // Express Mode 不需要暂停等待 OAuth token + require.Equal(t, types.ActionContinue, action) + + // 验证请求体是否被正确处理 + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + // 验证请求体被转换为 Vertex 格式 + bodyStr := string(processedBody) + require.Contains(t, bodyStr, "contents", "Request should be converted to 
vertex format with contents") + require.Contains(t, bodyStr, "generationConfig", "Request should contain generationConfig") + require.Contains(t, bodyStr, "responseModalities", "Request should contain responseModalities for image generation") + require.Contains(t, bodyStr, "IMAGE", "Request should specify IMAGE in responseModalities") + require.Contains(t, bodyStr, "imageConfig", "Request should contain imageConfig") + + // 验证路径包含 API Key 和正确的模型 + requestHeaders := host.GetRequestHeaders() + pathHeader := "" + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key as query parameter") + require.Contains(t, pathHeader, "/v1/publishers/google/models/", "Path should use Express Mode format") + require.Contains(t, pathHeader, "generateContent", "Path should use generateContent action for image generation") + require.NotContains(t, pathHeader, "streamGenerateContent", "Path should NOT use streaming for image generation") + }) + + // 测试 Vertex Express Mode 图片生成请求体处理(自定义尺寸) + t.Run("vertex express mode image generation with custom size", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/generations"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体(宽屏尺寸) + requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"A beautiful sunset over the ocean","size":"1792x1024"}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + require.Equal(t, types.ActionContinue, action) + + // 验证请求体是否正确处理尺寸映射 + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + bodyStr := string(processedBody) + // 1792x1024 应该映射为 16:9 宽高比 + require.Contains(t, bodyStr, 
"aspectRatio", "Request should contain aspectRatio in imageConfig") + require.Contains(t, bodyStr, "16:9", "Request should map 1792x1024 to 16:9 aspect ratio") + }) + + // 测试 Vertex Express Mode 图片生成请求体处理(含安全设置) + t.Run("vertex express mode image generation with safety settings", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeWithSafetyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/generations"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体 + requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"A mountain landscape"}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + require.Equal(t, types.ActionContinue, action) + + // 验证请求体包含安全设置 + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + bodyStr := string(processedBody) + require.Contains(t, bodyStr, "safetySettings", "Request should contain safetySettings") + }) + + // 测试 Vertex Express Mode 图片生成请求体处理(含模型映射) + t.Run("vertex express mode image generation with model mapping", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/generations"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体(使用映射前的模型名称) + requestBody := `{"model":"gpt-4","prompt":"A futuristic city"}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + require.Equal(t, types.ActionContinue, action) + + // 验证路径中使用了映射后的模型名称 + requestHeaders := host.GetRequestHeaders() + pathHeader := "" + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.Contains(t, 
pathHeader, "gemini-2.5-flash", "Path should contain mapped model name") + }) + }) +} + +func RunVertexExpressModeImageGenerationResponseBodyTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // 测试 Vertex Express Mode 图片生成响应体处理 + t.Run("vertex express mode image generation response body", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/generations"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体 + requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"A cute cat"}` + host.CallOnHttpRequestBody([]byte(requestBody)) + + // 设置响应属性,确保IsResponseFromUpstream()返回true + host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream")) + + // 设置响应头 + responseHeaders := [][2]string{ + {":status", "200"}, + {"Content-Type", "application/json"}, + } + host.CallOnHttpResponseHeaders(responseHeaders) + + // 设置响应体(Vertex 图片生成格式) + responseBody := `{ + "candidates": [{ + "content": { + "role": "model", + "parts": [{ + "inlineData": { + "mimeType": "image/png", + "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + } + }] + }, + "finishReason": "STOP" + }], + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 1024, + "totalTokenCount": 1034 + } + }` + action := host.CallOnHttpResponseBody([]byte(responseBody)) + + require.Equal(t, types.ActionContinue, action) + + // 验证响应体是否被正确处理 + processedResponseBody := host.GetResponseBody() + require.NotNil(t, processedResponseBody) + + responseStr := string(processedResponseBody) + + // 验证响应体被转换为 OpenAI 图片生成格式 + require.Contains(t, responseStr, "created", "Response should contain created field") + require.Contains(t, responseStr, "data", "Response should contain data array") + 
require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field with base64 image data") + require.Contains(t, responseStr, "usage", "Response should contain usage information") + require.Contains(t, responseStr, "total_tokens", "Response should contain total_tokens in usage") + }) + + // 测试 Vertex Express Mode 图片生成响应体处理(跳过思考过程) + t.Run("vertex express mode image generation response body - skip thinking", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/generations"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体 + requestBody := `{"model":"gemini-3-pro-image-preview","prompt":"An Eiffel tower"}` + host.CallOnHttpRequestBody([]byte(requestBody)) + + // 设置响应属性 + host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream")) + + // 设置响应头 + responseHeaders := [][2]string{ + {":status", "200"}, + {"Content-Type", "application/json"}, + } + host.CallOnHttpResponseHeaders(responseHeaders) + + // 设置响应体(包含思考过程和图片) + responseBody := `{ + "candidates": [{ + "content": { + "role": "model", + "parts": [ + { + "text": "Considering visual elements...", + "thought": true + }, + { + "inlineData": { + "mimeType": "image/png", + "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + } + } + ] + }, + "finishReason": "STOP" + }], + "usageMetadata": { + "promptTokenCount": 13, + "candidatesTokenCount": 1120, + "totalTokenCount": 1356, + "thoughtsTokenCount": 223 + } + }` + action := host.CallOnHttpResponseBody([]byte(responseBody)) + + require.Equal(t, types.ActionContinue, action) + + // 验证响应体是否被正确处理 + processedResponseBody := host.GetResponseBody() + require.NotNil(t, processedResponseBody) + + responseStr := string(processedResponseBody) 
+ + // 验证响应体只包含图片数据,不包含思考过程文本 + require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field") + require.NotContains(t, responseStr, "Considering visual elements", "Response should NOT contain thinking text") + require.NotContains(t, responseStr, "thought", "Response should NOT contain thought field") + }) + + // 测试 Vertex Express Mode 图片生成响应体处理(空图片数据) + t.Run("vertex express mode image generation response body - no image", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/generations"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体 + requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"test"}` + host.CallOnHttpRequestBody([]byte(requestBody)) + + // 设置响应属性 + host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream")) + + // 设置响应头 + responseHeaders := [][2]string{ + {":status", "200"}, + {"Content-Type", "application/json"}, + } + host.CallOnHttpResponseHeaders(responseHeaders) + + // 设置响应体(只有文本,没有图片) + responseBody := `{ + "candidates": [{ + "content": { + "role": "model", + "parts": [{ + "text": "I cannot generate that image." + }] + }, + "finishReason": "SAFETY" + }], + "usageMetadata": { + "promptTokenCount": 5, + "candidatesTokenCount": 10, + "totalTokenCount": 15 + } + }` + action := host.CallOnHttpResponseBody([]byte(responseBody)) + + require.Equal(t, types.ActionContinue, action) + + // 验证响应体是否被正确处理(即使没有图片) + processedResponseBody := host.GetResponseBody() + require.NotNil(t, processedResponseBody) + + responseStr := string(processedResponseBody) + + // 验证响应体结构正确,data 数组为空 + require.Contains(t, responseStr, "created", "Response should contain created field") + require.Contains(t, responseStr, "data", "Response should contain data array") + }) + }) +}
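The `parseImageSize` mapping introduced in `provider/vertex.go` (OpenAI `"WxH"` size strings to Vertex AI `aspectRatio`/`imageSize` pairs, with a `1:1`/`1k` fallback) can be sanity-checked outside the plugin. The sketch below is a standalone re-implementation of a subset of that table for illustration only, not the plugin's actual function:

```go
package main

import (
	"fmt"
	"strings"
)

// parseImageSize mirrors the diff's OpenAI-size -> Vertex mapping for a
// subset of the supported sizes; unknown sizes fall back to "1:1" / "1k",
// matching the defaults in the plugin code above.
func parseImageSize(size string) (aspectRatio, imageSize string) {
	aspectRatio, imageSize = "1:1", "1k"
	mapping := map[string][2]string{
		"256x256":   {"1:1", "1k"},
		"512x512":   {"1:1", "1k"},
		"1024x1024": {"1:1", "1k"},
		"1792x1024": {"16:9", "2k"},
		"1024x1792": {"9:16", "2k"},
		"2048x2048": {"1:1", "2k"},
		"4096x4096": {"1:1", "4k"},
	}
	if m, ok := mapping[strings.TrimSpace(size)]; ok {
		return m[0], m[1]
	}
	return
}

func main() {
	// "800x600" is deliberately absent from the table to show the fallback.
	for _, s := range []string{"1024x1024", "1792x1024", "800x600"} {
		ar, is := parseImageSize(s)
		fmt.Printf("%s -> aspectRatio=%s imageSize=%s\n", s, ar, is)
	}
}
```

This keeps the OpenAI client contract unchanged while letting the gateway pick a Vertex-legal aspect ratio: clients keep sending DALL-E-style sizes, and unrecognized values degrade to a safe square default rather than failing the request.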