diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index a03516a15..a1c213f99 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -225,9 +225,9 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf } } - if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !strings.Contains(contentType, util.MimeTypeApplicationJson) { + if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !isSupportedRequestContentType(apiName, contentType) { ctx.DontReadRequestBody() - log.Debugf("[onHttpRequestHeader] unsupported content type: %s, will not process the request body", contentType) + log.Debugf("[onHttpRequestHeader] unsupported content type for api %s: %s, will not process the request body", apiName, contentType) } if apiName == "" { @@ -306,6 +306,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig if err == nil { return action } + log.Errorf("[onHttpRequestBody] failed to process request body, apiName=%s, err=%v", apiName, err) _ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err)) } return types.ActionContinue @@ -594,3 +595,14 @@ func getApiName(path string) provider.ApiName { return "" } + +func isSupportedRequestContentType(apiName provider.ApiName, contentType string) bool { + if strings.Contains(contentType, util.MimeTypeApplicationJson) { + return true + } + contentType = strings.ToLower(contentType) + if strings.HasPrefix(contentType, "multipart/form-data") { + return apiName == provider.ApiNameImageEdit || apiName == provider.ApiNameImageVariation + } + return false +} diff --git a/plugins/wasm-go/extensions/ai-proxy/main_test.go b/plugins/wasm-go/extensions/ai-proxy/main_test.go index 947a57fdd..acd699c30 100644 --- 
a/plugins/wasm-go/extensions/ai-proxy/main_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/main_test.go @@ -63,6 +63,54 @@ func Test_getApiName(t *testing.T) { } } +func Test_isSupportedRequestContentType(t *testing.T) { + tests := []struct { + name string + apiName provider.ApiName + contentType string + want bool + }{ + { + name: "json chat completion", + apiName: provider.ApiNameChatCompletion, + contentType: "application/json", + want: true, + }, + { + name: "multipart image edit", + apiName: provider.ApiNameImageEdit, + contentType: "multipart/form-data; boundary=----boundary", + want: true, + }, + { + name: "multipart image variation", + apiName: provider.ApiNameImageVariation, + contentType: "multipart/form-data; boundary=----boundary", + want: true, + }, + { + name: "multipart chat completion", + apiName: provider.ApiNameChatCompletion, + contentType: "multipart/form-data; boundary=----boundary", + want: false, + }, + { + name: "text plain image edit", + apiName: provider.ApiNameImageEdit, + contentType: "text/plain", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isSupportedRequestContentType(tt.apiName, tt.contentType) + if got != tt.want { + t.Errorf("isSupportedRequestContentType(%v, %q) = %v, want %v", tt.apiName, tt.contentType, got, tt.want) + } + }) + } +} + func TestAi360(t *testing.T) { test.RunAi360ParseConfigTests(t) test.RunAi360OnHttpRequestHeadersTests(t) @@ -137,6 +185,8 @@ func TestVertex(t *testing.T) { test.RunVertexExpressModeOnStreamingResponseBodyTests(t) test.RunVertexExpressModeImageGenerationRequestBodyTests(t) test.RunVertexExpressModeImageGenerationResponseBodyTests(t) + test.RunVertexExpressModeImageEditVariationRequestBodyTests(t) + test.RunVertexExpressModeImageEditVariationResponseBodyTests(t) // Vertex Raw 模式测试 test.RunVertexRawModeOnHttpRequestHeadersTests(t) test.RunVertexRawModeOnHttpRequestBodyTests(t) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go 
b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index 465bccc3b..b231c739d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go @@ -1,6 +1,7 @@ package provider import ( + "encoding/json" "fmt" "strings" @@ -461,6 +462,122 @@ type imageGenerationRequest struct { Size string `json:"size,omitempty"` } +type imageInputURL struct { + URL string `json:"url,omitempty"` + ImageURL *chatMessageContentImageUrl `json:"image_url,omitempty"` +} + +func (i *imageInputURL) UnmarshalJSON(data []byte) error { + // Support a plain string payload, e.g. "data:image/png;base64,..." + var rawURL string + if err := json.Unmarshal(data, &rawURL); err == nil { + i.URL = rawURL + i.ImageURL = nil + return nil + } + + type alias imageInputURL + var value alias + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *i = imageInputURL(value) + return nil +} + +func (i *imageInputURL) GetURL() string { + if i == nil { + return "" + } + if i.ImageURL != nil && i.ImageURL.Url != "" { + return i.ImageURL.Url + } + return i.URL +} + +type imageEditRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Image *imageInputURL `json:"image,omitempty"` + Images []imageInputURL `json:"images,omitempty"` + ImageURL *imageInputURL `json:"image_url,omitempty"` + Mask *imageInputURL `json:"mask,omitempty"` + MaskURL *imageInputURL `json:"mask_url,omitempty"` + Background string `json:"background,omitempty"` + Moderation string `json:"moderation,omitempty"` + OutputCompression int `json:"output_compression,omitempty"` + OutputFormat string `json:"output_format,omitempty"` + Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Style string `json:"style,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` +} + +func (r *imageEditRequest) GetImageURLs() []string { + urls := make([]string, 0, len(r.Images)+2) + 
for _, image := range r.Images { + if url := image.GetURL(); url != "" { + urls = append(urls, url) + } + } + if r.Image != nil { + if url := r.Image.GetURL(); url != "" { + urls = append(urls, url) + } + } + if r.ImageURL != nil { + if url := r.ImageURL.GetURL(); url != "" { + urls = append(urls, url) + } + } + return urls +} + +func (r *imageEditRequest) HasMask() bool { + if r.Mask != nil && r.Mask.GetURL() != "" { + return true + } + return r.MaskURL != nil && r.MaskURL.GetURL() != "" +} + +type imageVariationRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt,omitempty"` + Image *imageInputURL `json:"image,omitempty"` + Images []imageInputURL `json:"images,omitempty"` + ImageURL *imageInputURL `json:"image_url,omitempty"` + Background string `json:"background,omitempty"` + Moderation string `json:"moderation,omitempty"` + OutputCompression int `json:"output_compression,omitempty"` + OutputFormat string `json:"output_format,omitempty"` + Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Style string `json:"style,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` +} + +func (r *imageVariationRequest) GetImageURLs() []string { + urls := make([]string, 0, len(r.Images)+2) + for _, image := range r.Images { + if url := image.GetURL(); url != "" { + urls = append(urls, url) + } + } + if r.Image != nil { + if url := r.Image.GetURL(); url != "" { + urls = append(urls, url) + } + } + if r.ImageURL != nil { + if url := r.ImageURL.GetURL(); url != "" { + urls = append(urls, url) + } + } + return urls +} + type imageGenerationData struct { URL string `json:"url,omitempty"` B64 string `json:"b64_json,omitempty"` diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go b/plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go new file mode 100644 index 000000000..a54f9282d --- /dev/null +++ 
// multipartImageRequest holds the fields extracted from a multipart/form-data
// image request: the scalar form fields, every image as a URL (remote URLs and
// data URLs are kept verbatim, uploaded bytes become data URLs), and whether a
// non-empty mask part was present.
type multipartImageRequest struct {
	Model        string
	Prompt       string
	Size         string
	OutputFormat string
	N            int
	ImageURLs    []string
	HasMask      bool
}

// isMultipartFormData reports whether contentType parses as a
// multipart/form-data media type. Unparseable values yield false.
func isMultipartFormData(contentType string) bool {
	mediaType, _, err := mime.ParseMediaType(contentType)
	if err != nil {
		return false
	}
	return strings.EqualFold(mediaType, "multipart/form-data")
}

// parseMultipartImageRequest decodes a multipart/form-data body using the
// boundary found in contentType.
//
// It returns an error when contentType cannot be parsed, when it carries no
// boundary, or when a part cannot be read. Underlying errors are wrapped so
// callers can inspect them. Parts without a form field name and fields that
// are not recognised are skipped.
func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) {
	_, params, err := mime.ParseMediaType(contentType)
	if err != nil {
		return nil, fmt.Errorf("unable to parse content-type: %w", err)
	}
	boundary := params["boundary"]
	if boundary == "" {
		return nil, fmt.Errorf("missing multipart boundary")
	}

	req := &multipartImageRequest{
		ImageURLs: make([]string, 0),
	}
	reader := multipart.NewReader(bytes.NewReader(body), boundary)
	for {
		part, err := reader.NextPart()
		if err == io.EOF {
			return req, nil
		}
		if err != nil {
			return nil, fmt.Errorf("unable to read multipart part: %w", err)
		}

		fieldName := part.FormName()
		if fieldName == "" {
			// Close only drains the part; there is nothing to report.
			_ = part.Close()
			continue
		}
		partContentType := strings.TrimSpace(part.Header.Get("Content-Type"))

		partData, err := io.ReadAll(part)
		_ = part.Close()
		if err != nil {
			return nil, fmt.Errorf("unable to read multipart field %s: %w", fieldName, err)
		}
		req.applyField(fieldName, partContentType, partData)
	}
}

// applyField stores one decoded form field on the request. contentType is the
// part's own Content-Type header (possibly empty) and data its raw payload.
func (r *multipartImageRequest) applyField(fieldName, contentType string, data []byte) {
	switch fieldName {
	case "model":
		r.Model = strings.TrimSpace(string(data))
		return
	case "prompt":
		r.Prompt = strings.TrimSpace(string(data))
		return
	case "size":
		r.Size = strings.TrimSpace(string(data))
		return
	case "output_format":
		r.OutputFormat = strings.TrimSpace(string(data))
		return
	case "n":
		// Best effort: an empty or non-numeric count leaves N untouched
		// rather than failing the whole request.
		if parsed, err := strconv.Atoi(strings.TrimSpace(string(data))); err == nil {
			r.N = parsed
		}
		return
	}

	switch {
	case isMultipartImageField(fieldName):
		// A textual value that is already a URL is forwarded verbatim;
		// anything else is treated as uploaded bytes.
		if value := strings.TrimSpace(string(data)); isMultipartImageURLValue(value) {
			r.ImageURLs = append(r.ImageURLs, value)
			return
		}
		if len(data) == 0 {
			return
		}
		r.ImageURLs = append(r.ImageURLs, buildMultipartDataURL(contentType, data))
	case isMultipartMaskField(fieldName):
		if len(data) > 0 {
			r.HasMask = true
		}
	}
}

// isMultipartImageField reports whether fieldName names an image part:
// "image" itself or any array-style name starting with "image[".
func isMultipartImageField(fieldName string) bool {
	return fieldName == "image" || strings.HasPrefix(fieldName, "image[")
}

// isMultipartMaskField reports whether fieldName names a mask part:
// "mask" itself or any array-style name starting with "mask[".
func isMultipartMaskField(fieldName string) bool {
	return fieldName == "mask" || strings.HasPrefix(fieldName, "mask[")
}

// multipartHasPrefixFold reports whether s starts with the ASCII prefix,
// ignoring case, without allocating a lower-cased copy of s.
func multipartHasPrefixFold(s, prefix string) bool {
	return len(s) >= len(prefix) && strings.EqualFold(s[:len(prefix)], prefix)
}

// isMultipartImageURLValue reports whether value looks like a URL that can be
// forwarded as-is: a data URL or an http(s) URL, matched case-insensitively.
// Only the prefix is inspected, so large payloads are not copied.
func isMultipartImageURLValue(value string) bool {
	return multipartHasPrefixFold(value, "data:") ||
		multipartHasPrefixFold(value, "http://") ||
		multipartHasPrefixFold(value, "https://")
}

// buildMultipartDataURL encodes data as a base64 data URL. The MIME type comes
// from contentType with its parameters removed; when contentType is empty or
// application/octet-stream it is sniffed from the data instead.
func buildMultipartDataURL(contentType string, data []byte) string {
	mimeType := strings.TrimSpace(contentType)
	if mimeType == "" || strings.EqualFold(mimeType, "application/octet-stream") {
		mimeType = http.DetectContentType(data)
	}
	mimeType = normalizeMultipartMimeType(mimeType)
	if mimeType == "" {
		mimeType = "application/octet-stream"
	}
	return "data:" + mimeType + ";base64," + base64.StdEncoding.EncodeToString(data)
}

// normalizeMultipartMimeType strips parameters from a Content-Type value and
// returns the bare media type. If the value cannot be parsed, it falls back to
// the text before the first ';', or to the trimmed input when there is none.
func normalizeMultipartMimeType(contentType string) string {
	contentType = strings.TrimSpace(contentType)
	if contentType == "" {
		return ""
	}
	if mediaType, _, err := mime.ParseMediaType(contentType); err == nil && mediaType != "" {
		return strings.TrimSpace(mediaType)
	}
	if before, _, found := strings.Cut(contentType, ";"); found && before != "" {
		return strings.TrimSpace(before)
	}
	return contentType
}
a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index cd4d1cab5..2518da321 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -763,19 +763,19 @@ func (c *ProviderConfig) GetRandomToken() string { func isStatefulAPI(apiName string) bool { // These APIs maintain session state and should be routed to the same provider consistently statefulAPIs := map[string]bool{ - string(ApiNameResponses): true, // Response API - uses previous_response_id - string(ApiNameFiles): true, // Files API - maintains file state - string(ApiNameRetrieveFile): true, // File retrieval - depends on file upload - string(ApiNameRetrieveFileContent): true, // File content - depends on file upload - string(ApiNameBatches): true, // Batch API - maintains batch state - string(ApiNameRetrieveBatch): true, // Batch status - depends on batch creation - string(ApiNameCancelBatch): true, // Batch operations - depends on batch state - string(ApiNameFineTuningJobs): true, // Fine-tuning - maintains job state - string(ApiNameRetrieveFineTuningJob): true, // Fine-tuning job status - string(ApiNameFineTuningJobEvents): true, // Fine-tuning events - string(ApiNameFineTuningJobCheckpoints): true, // Fine-tuning checkpoints - string(ApiNameCancelFineTuningJob): true, // Cancel fine-tuning job - string(ApiNameResumeFineTuningJob): true, // Resume fine-tuning job + string(ApiNameResponses): true, // Response API - uses previous_response_id + string(ApiNameFiles): true, // Files API - maintains file state + string(ApiNameRetrieveFile): true, // File retrieval - depends on file upload + string(ApiNameRetrieveFileContent): true, // File content - depends on file upload + string(ApiNameBatches): true, // Batch API - maintains batch state + string(ApiNameRetrieveBatch): true, // Batch status - depends on batch creation + string(ApiNameCancelBatch): true, // Batch 
operations - depends on batch state + string(ApiNameFineTuningJobs): true, // Fine-tuning - maintains job state + string(ApiNameRetrieveFineTuningJob): true, // Fine-tuning job status + string(ApiNameFineTuningJobEvents): true, // Fine-tuning events + string(ApiNameFineTuningJobCheckpoints): true, // Fine-tuning checkpoints + string(ApiNameCancelFineTuningJob): true, // Cancel fine-tuning job + string(ApiNameResumeFineTuningJob): true, // Resume fine-tuning job } return statefulAPIs[apiName] } @@ -845,6 +845,16 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques return err } return c.setRequestModel(ctx, req) + case *imageEditRequest: + if err := decodeImageEditRequest(body, req); err != nil { + return err + } + return c.setRequestModel(ctx, req) + case *imageVariationRequest: + if err := decodeImageVariationRequest(body, req); err != nil { + return err + } + return c.setRequestModel(ctx, req) default: return errors.New("unsupported request type") } @@ -860,6 +870,10 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf model = &req.Model case *imageGenerationRequest: model = &req.Model + case *imageEditRequest: + model = &req.Model + case *imageVariationRequest: + model = &req.Model default: return errors.New("unsupported request type") } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go index 55e496ab7..3e3a57aa4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go @@ -4,8 +4,8 @@ import ( "encoding/json" "fmt" - "github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/wasm-go/pkg/log" ) func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) error { @@ -32,6 +32,20 @@ func decodeImageGenerationRequest(body []byte, request 
*imageGenerationRequest) return nil } +func decodeImageEditRequest(body []byte, request *imageEditRequest) error { + if err := json.Unmarshal(body, request); err != nil { + return fmt.Errorf("unable to unmarshal request: %v", err) + } + return nil +} + +func decodeImageVariationRequest(body []byte, request *imageVariationRequest) error { + if err := json.Unmarshal(body, request); err != nil { + return fmt.Errorf("unable to unmarshal request: %v", err) + } + return nil +} + func replaceJsonRequestBody(request interface{}) error { body, err := json.Marshal(request) if err != nil { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go index 3791e06ef..b6a10ac36 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go @@ -46,6 +46,7 @@ const ( contextOpenAICompatibleMarker = "isOpenAICompatibleRequest" contextVertexRawMarker = "isVertexRawRequest" vertexAnthropicVersion = "vertex-2023-10-16" + vertexImageVariationDefaultPrompt = "Create variations of the provided image." 
) // vertexRawPathRegex 匹配原生 Vertex AI REST API 路径 @@ -98,6 +99,8 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string { string(ApiNameChatCompletion): vertexPathTemplate, string(ApiNameEmbeddings): vertexPathTemplate, string(ApiNameImageGeneration): vertexPathTemplate, + string(ApiNameImageEdit): vertexPathTemplate, + string(ApiNameImageVariation): vertexPathTemplate, string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换 } } @@ -307,6 +310,10 @@ func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap return v.onEmbeddingsRequestBody(ctx, body, headers) case ApiNameImageGeneration: return v.onImageGenerationRequestBody(ctx, body, headers) + case ApiNameImageEdit: + return v.onImageEditRequestBody(ctx, body, headers) + case ApiNameImageVariation: + return v.onImageVariationRequestBody(ctx, body, headers) default: return body, nil } @@ -387,11 +394,108 @@ func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, b path := v.getRequestPath(ApiNameImageGeneration, request.Model, false) util.OverwriteRequestPathHeader(headers, path) - vertexRequest := v.buildVertexImageGenerationRequest(request) + vertexRequest, err := v.buildVertexImageGenerationRequest(request) + if err != nil { + return nil, err + } return json.Marshal(vertexRequest) } -func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest { +func (v *vertexProvider) onImageEditRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) { + request := &imageEditRequest{} + imageURLs := make([]string, 0) + contentType := headers.Get("Content-Type") + if isMultipartFormData(contentType) { + parsedRequest, err := parseMultipartImageRequest(body, contentType) + if err != nil { + return nil, err + } + request.Model = parsedRequest.Model + request.Prompt = parsedRequest.Prompt + request.Size = parsedRequest.Size + request.OutputFormat = parsedRequest.OutputFormat + 
request.N = parsedRequest.N + imageURLs = parsedRequest.ImageURLs + if err := v.config.mapModel(ctx, &request.Model); err != nil { + return nil, err + } + if parsedRequest.HasMask { + return nil, fmt.Errorf("mask is not supported for vertex image edits yet") + } + } else { + if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil { + return nil, err + } + if request.HasMask() { + return nil, fmt.Errorf("mask is not supported for vertex image edits yet") + } + imageURLs = request.GetImageURLs() + } + if len(imageURLs) == 0 { + return nil, fmt.Errorf("missing image_url in request") + } + if request.Prompt == "" { + return nil, fmt.Errorf("missing prompt in request") + } + + path := v.getRequestPath(ApiNameImageEdit, request.Model, false) + util.OverwriteRequestPathHeader(headers, path) + headers.Set("Content-Type", util.MimeTypeApplicationJson) + vertexRequest, err := v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, imageURLs) + if err != nil { + return nil, err + } + return json.Marshal(vertexRequest) +} + +func (v *vertexProvider) onImageVariationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) { + request := &imageVariationRequest{} + imageURLs := make([]string, 0) + contentType := headers.Get("Content-Type") + if isMultipartFormData(contentType) { + parsedRequest, err := parseMultipartImageRequest(body, contentType) + if err != nil { + return nil, err + } + request.Model = parsedRequest.Model + request.Prompt = parsedRequest.Prompt + request.Size = parsedRequest.Size + request.OutputFormat = parsedRequest.OutputFormat + request.N = parsedRequest.N + imageURLs = parsedRequest.ImageURLs + if err := v.config.mapModel(ctx, &request.Model); err != nil { + return nil, err + } + } else { + if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil { + return nil, err + } + imageURLs = request.GetImageURLs() + } + if len(imageURLs) == 0 { + return nil, 
fmt.Errorf("missing image_url in request") + } + + prompt := request.Prompt + if prompt == "" { + prompt = vertexImageVariationDefaultPrompt + } + + path := v.getRequestPath(ApiNameImageVariation, request.Model, false) + util.OverwriteRequestPathHeader(headers, path) + headers.Set("Content-Type", util.MimeTypeApplicationJson) + vertexRequest, err := v.buildVertexImageRequest(prompt, request.Size, request.OutputFormat, imageURLs) + if err != nil { + return nil, err + } + return json.Marshal(vertexRequest) +} + +func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) (*vertexChatRequest, error) { + return v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, nil) +} + +func (v *vertexProvider) buildVertexImageRequest(prompt string, size string, outputFormat string, imageURLs []string) (*vertexChatRequest, error) { // 构建安全设置 safetySettings := make([]vertexChatSafetySetting, 0) for category, threshold := range v.config.geminiSafetySetting { @@ -402,12 +506,12 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat } // 解析尺寸参数 - aspectRatio, imageSize := v.parseImageSize(request.Size) + aspectRatio, imageSize := v.parseImageSize(size) // 确定输出 MIME 类型 mimeType := "image/png" - if request.OutputFormat != "" { - switch request.OutputFormat { + if outputFormat != "" { + switch outputFormat { case "jpeg", "jpg": mimeType = "image/jpeg" case "webp": @@ -417,12 +521,27 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat } } + parts := make([]vertexPart, 0, len(imageURLs)+1) + for _, imageURL := range imageURLs { + part, err := convertMediaContent(imageURL) + if err != nil { + return nil, err + } + parts = append(parts, part) + } + if prompt != "" { + parts = append(parts, vertexPart{ + Text: prompt, + }) + } + if len(parts) == 0 { + return nil, fmt.Errorf("missing prompt and image_url in request") + } + vertexRequest := &vertexChatRequest{ Contents: 
[]vertexChatContent{{ - Role: roleUser, - Parts: []vertexPart{{ - Text: request.Prompt, - }}, + Role: roleUser, + Parts: parts, }}, SafetySettings: safetySettings, GenerationConfig: vertexChatGenerationConfig{ @@ -440,7 +559,7 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat }, } - return vertexRequest + return vertexRequest, nil } // parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize @@ -553,7 +672,7 @@ func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName return v.onChatCompletionResponseBody(ctx, body) case ApiNameEmbeddings: return v.onEmbeddingsResponseBody(ctx, body) - case ApiNameImageGeneration: + case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation: return v.onImageGenerationResponseBody(ctx, body) default: return body, nil @@ -784,7 +903,7 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream switch apiName { case ApiNameEmbeddings: action = vertexEmbeddingAction - case ApiNameImageGeneration: + case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation: // 图片生成使用非流式端点,需要完整响应 action = vertexChatCompletionAction default: diff --git a/plugins/wasm-go/extensions/ai-proxy/test/vertex.go b/plugins/wasm-go/extensions/ai-proxy/test/vertex.go index bc3b53c11..d57f86fad 100644 --- a/plugins/wasm-go/extensions/ai-proxy/test/vertex.go +++ b/plugins/wasm-go/extensions/ai-proxy/test/vertex.go @@ -1,7 +1,9 @@ package test import ( + "bytes" "encoding/json" + "mime/multipart" "strings" "testing" @@ -1273,6 +1275,324 @@ func RunVertexExpressModeImageGenerationResponseBodyTests(t *testing.T) { }) } +func buildMultipartRequestBody(t *testing.T, fields map[string]string, files map[string][]byte) ([]byte, string) { + var buffer bytes.Buffer + writer := multipart.NewWriter(&buffer) + + for key, value := range fields { + require.NoError(t, writer.WriteField(key, value)) + } + + for fieldName, data := range files { + part, err := 
writer.CreateFormFile(fieldName, "upload-image.png") + require.NoError(t, err) + _, err = part.Write(data) + require.NoError(t, err) + } + + require.NoError(t, writer.Close()) + return buffer.Bytes(), writer.FormDataContentType() +} + +func RunVertexExpressModeImageEditVariationRequestBodyTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + const testDataURL = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + + t.Run("vertex express mode image edit request body with image_url", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/edits"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add sunglasses to the cat","image":{"image_url":{"url":"` + testDataURL + `"}},"size":"1024x1024"}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + bodyStr := string(processedBody) + require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image_url") + require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved") + require.NotContains(t, bodyStr, "image_url", "OpenAI image_url field should be converted to Vertex format") + + requestHeaders := host.GetRequestHeaders() + pathHeader := "" + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.Contains(t, pathHeader, "generateContent", "Image edit should use generateContent action") + require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key") + }) + + t.Run("vertex 
express mode image edit request body with image string", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/edits"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add sunglasses to the cat","image":"` + testDataURL + `","size":"1024x1024"}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + bodyStr := string(processedBody) + require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image string") + require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved") + }) + + t.Run("vertex express mode image edit multipart request body", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + body, contentType := buildMultipartRequestBody(t, map[string]string{ + "model": "gemini-2.0-flash-exp", + "prompt": "Add sunglasses to the cat", + "size": "1024x1024", + }, map[string][]byte{ + "image": []byte("fake-image-content"), + }) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/edits"}, + {":method", "POST"}, + {"Content-Type", contentType}, + }) + + action := host.CallOnHttpRequestBody(body) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + bodyStr := string(processedBody) + require.Contains(t, bodyStr, "inlineData", "Multipart image should be converted to inlineData") + require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should 
be preserved") + + requestHeaders := host.GetRequestHeaders() + require.True(t, test.HasHeaderWithValue(requestHeaders, "Content-Type", "application/json"), "Content-Type should be rewritten to application/json") + }) + + t.Run("vertex express mode image variation multipart request body", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + body, contentType := buildMultipartRequestBody(t, map[string]string{ + "model": "gemini-2.0-flash-exp", + "size": "1024x1024", + }, map[string][]byte{ + "image": []byte("fake-image-content"), + }) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/variations"}, + {":method", "POST"}, + {"Content-Type", contentType}, + }) + + action := host.CallOnHttpRequestBody(body) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + bodyStr := string(processedBody) + require.Contains(t, bodyStr, "inlineData", "Multipart image should be converted to inlineData") + require.Contains(t, bodyStr, "Create variations of the provided image.", "Variation request should inject a default prompt") + + requestHeaders := host.GetRequestHeaders() + require.True(t, test.HasHeaderWithValue(requestHeaders, "Content-Type", "application/json"), "Content-Type should be rewritten to application/json") + }) + + t.Run("vertex express mode image edit with model mapping", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/edits"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + requestBody := `{"model":"gpt-4","prompt":"Turn it into watercolor","image_url":{"url":"` + testDataURL + `"}}` 
+ action := host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + requestHeaders := host.GetRequestHeaders() + pathHeader := "" + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.Contains(t, pathHeader, "gemini-2.5-flash", "Path should contain mapped model name") + }) + + t.Run("vertex express mode image variation request body with image_url", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/variations"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + requestBody := `{"model":"gemini-2.0-flash-exp","image_url":{"url":"` + testDataURL + `"}}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + bodyStr := string(processedBody) + require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image_url") + require.Contains(t, bodyStr, "Create variations of the provided image.", "Variation request should inject a default prompt") + + requestHeaders := host.GetRequestHeaders() + pathHeader := "" + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.Contains(t, pathHeader, "generateContent", "Image variation should use generateContent action") + require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key") + }) + }) +} + +func RunVertexExpressModeImageEditVariationResponseBodyTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + const testDataURL = 
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + + t.Run("vertex express mode image edit response body", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/edits"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add glasses","image_url":{"url":"` + testDataURL + `"}}` + host.CallOnHttpRequestBody([]byte(requestBody)) + + host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream")) + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"Content-Type", "application/json"}, + }) + + responseBody := `{ + "candidates": [{ + "content": { + "role": "model", + "parts": [{ + "inlineData": { + "mimeType": "image/png", + "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + } + }] + } + }], + "usageMetadata": { + "promptTokenCount": 12, + "candidatesTokenCount": 1024, + "totalTokenCount": 1036 + } + }` + action := host.CallOnHttpResponseBody([]byte(responseBody)) + require.Equal(t, types.ActionContinue, action) + + processedResponseBody := host.GetResponseBody() + require.NotNil(t, processedResponseBody) + + responseStr := string(processedResponseBody) + require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field") + require.Contains(t, responseStr, "usage", "Response should contain usage field") + }) + + t.Run("vertex express mode image variation response body", func(t *testing.T) { + host, status := test.NewTestHost(vertexExpressModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + 
{":path", "/v1/images/variations"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + requestBody := `{"model":"gemini-2.0-flash-exp","image_url":{"url":"` + testDataURL + `"}}` + host.CallOnHttpRequestBody([]byte(requestBody)) + + host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream")) + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"Content-Type", "application/json"}, + }) + + responseBody := `{ + "candidates": [{ + "content": { + "role": "model", + "parts": [{ + "inlineData": { + "mimeType": "image/png", + "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + } + }] + } + }], + "usageMetadata": { + "promptTokenCount": 8, + "candidatesTokenCount": 768, + "totalTokenCount": 776 + } + }` + action := host.CallOnHttpResponseBody([]byte(responseBody)) + require.Equal(t, types.ActionContinue, action) + + processedResponseBody := host.GetResponseBody() + require.NotNil(t, processedResponseBody) + + responseStr := string(processedResponseBody) + require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field") + require.Contains(t, responseStr, "usage", "Response should contain usage field") + }) + }) +} + // ==================== Vertex Raw 模式测试 ==================== func RunVertexRawModeOnHttpRequestHeadersTests(t *testing.T) {