[ai-proxy] vertex image edits & variations (#3536)

This commit is contained in:
woody
2026-02-27 10:18:30 +08:00
committed by GitHub
parent e9aecb6e1f
commit e2a22d1171
8 changed files with 830 additions and 28 deletions

View File

@@ -1,6 +1,7 @@
package provider
import (
"encoding/json"
"fmt"
"strings"
@@ -461,6 +462,122 @@ type imageGenerationRequest struct {
Size string `json:"size,omitempty"`
}
// imageInputURL is one image reference inside an image request body.
// Its custom UnmarshalJSON accepts either a bare JSON string or an object
// with the keys declared in the tags below; GetURL resolves the effective value.
type imageInputURL struct {
// URL is the reference itself (JSON key "url", or the whole value when the
// payload was a bare string).
URL string `json:"url,omitempty"`
// ImageURL is the nested form (JSON key "image_url"). GetURL prefers its Url
// over URL whenever it is non-empty.
ImageURL *chatMessageContentImageUrl `json:"image_url,omitempty"`
}
// UnmarshalJSON decodes an image reference that is either a bare JSON string
// (for example "data:image/png;base64,...") or an object with "url" and/or
// "image_url" keys.
//
// Any payload that decodes cleanly into a Go string takes the string path:
// URL receives the text and ImageURL is cleared. Everything else is decoded
// as an object; on failure the receiver is left untouched and the decoder's
// error is returned as-is.
func (i *imageInputURL) UnmarshalJSON(data []byte) error {
	// plain has the same fields as imageInputURL but none of its methods, so
	// decoding into it does not call back into this function.
	type plain imageInputURL

	var text string
	if json.Unmarshal(data, &text) == nil {
		i.URL, i.ImageURL = text, nil
		return nil
	}

	var decoded plain
	err := json.Unmarshal(data, &decoded)
	if err == nil {
		*i = imageInputURL(decoded)
	}
	return err
}
// GetURL returns the effective image reference.
//
// A nil receiver yields "". A non-empty nested ImageURL.Url wins over the
// top-level URL; otherwise URL is returned (possibly empty).
func (i *imageInputURL) GetURL() string {
	switch {
	case i == nil:
		return ""
	case i.ImageURL != nil && i.ImageURL.Url != "":
		return i.ImageURL.Url
	default:
		return i.URL
	}
}
// imageEditRequest is the JSON body decoded for an image edit call.
// The input image(s) may arrive under "image", "images" or "image_url"
// (GetImageURLs gathers all three), and a mask under "mask" or "mask_url"
// (HasMask reports whether either carries a value).
type imageEditRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
// Input image carriers; each accepts a bare string or an object (see imageInputURL).
Image *imageInputURL `json:"image,omitempty"`
Images []imageInputURL `json:"images,omitempty"`
ImageURL *imageInputURL `json:"image_url,omitempty"`
// Mask carriers, inspected by HasMask.
Mask *imageInputURL `json:"mask,omitempty"`
MaskURL *imageInputURL `json:"mask_url,omitempty"`
// Remaining options are decoded under the JSON keys named in their tags.
Background string `json:"background,omitempty"`
Moderation string `json:"moderation,omitempty"`
OutputCompression int `json:"output_compression,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
}
// GetImageURLs returns every non-empty image reference in the request, in
// this order: the entries of Images, then Image, then ImageURL.
// The result is never nil; it is empty when no reference is present.
func (r *imageEditRequest) GetImageURLs() []string {
	collected := make([]string, 0, len(r.Images)+2)
	for idx := range r.Images {
		if ref := r.Images[idx].GetURL(); ref != "" {
			collected = append(collected, ref)
		}
	}
	// GetURL returns "" for a nil receiver, so unset optional fields drop out
	// through the same emptiness check.
	for _, single := range [...]*imageInputURL{r.Image, r.ImageURL} {
		if ref := single.GetURL(); ref != "" {
			collected = append(collected, ref)
		}
	}
	return collected
}
// HasMask reports whether Mask or MaskURL resolves to a non-empty reference.
func (r *imageEditRequest) HasMask() bool {
	for _, candidate := range [...]*imageInputURL{r.Mask, r.MaskURL} {
		if candidate != nil && candidate.GetURL() != "" {
			return true
		}
	}
	return false
}
// imageVariationRequest is the JSON body decoded for an image variation call.
// The input image(s) may arrive under "image", "images" or "image_url";
// GetImageURLs gathers all three. Unlike imageEditRequest it has no mask
// fields and its prompt is optional in the JSON encoding (omitempty).
type imageVariationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
// Input image carriers; each accepts a bare string or an object (see imageInputURL).
Image *imageInputURL `json:"image,omitempty"`
Images []imageInputURL `json:"images,omitempty"`
ImageURL *imageInputURL `json:"image_url,omitempty"`
// Remaining options are decoded under the JSON keys named in their tags.
Background string `json:"background,omitempty"`
Moderation string `json:"moderation,omitempty"`
OutputCompression int `json:"output_compression,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
}
// GetImageURLs returns every non-empty image reference in the request, in
// this order: the entries of Images, then Image, then ImageURL.
// The result is never nil; it is empty when no reference is present.
func (r *imageVariationRequest) GetImageURLs() []string {
	found := make([]string, 0, len(r.Images)+2)
	// keep appends a reference unless it resolves to the empty string.
	keep := func(in *imageInputURL) {
		if ref := in.GetURL(); ref != "" {
			found = append(found, ref)
		}
	}

	for idx := range r.Images {
		keep(&r.Images[idx])
	}
	if r.Image != nil {
		keep(r.Image)
	}
	if r.ImageURL != nil {
		keep(r.ImageURL)
	}
	return found
}
type imageGenerationData struct {
URL string `json:"url,omitempty"`
B64 string `json:"b64_json,omitempty"`

View File

@@ -0,0 +1,156 @@
package provider
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"mime"
"mime/multipart"
"net/http"
"strconv"
"strings"
)
// multipartImageRequest holds what parseMultipartImageRequest extracts from a
// multipart/form-data image request body.
type multipartImageRequest struct {
// Model, Prompt, Size and OutputFormat are the trimmed text of the form
// fields "model", "prompt", "size" and "output_format".
Model string
Prompt string
Size string
OutputFormat string
// N is the integer value of the "n" field; it stays 0 when the field is
// absent or not a valid integer.
N int
// ImageURLs has one entry per non-empty image field: the field text when it
// is a data:/http:/https: reference, otherwise a data URL built from the
// uploaded bytes.
ImageURLs []string
// HasMask is true when any mask field carried content.
HasMask bool
}
// isMultipartFormData reports whether contentType parses as a media type and
// that type is multipart/form-data (compared case-insensitively). Parameters
// such as boundary are ignored; a header that fails to parse yields false.
func isMultipartFormData(contentType string) bool {
	if mediaType, _, err := mime.ParseMediaType(contentType); err == nil {
		return strings.EqualFold("multipart/form-data", mediaType)
	}
	return false
}
// parseMultipartImageRequest reads a multipart/form-data body and extracts the
// fields used for image requests.
//
// contentType must carry the multipart boundary. Parts without a form name
// are skipped. The text fields "model", "prompt", "size", "output_format" and
// "n" are stored after trimming surrounding whitespace; an "n" that is not a
// valid integer is ignored. Image fields contribute either their text (when
// it is a data:/http:/https: reference) or a data URL built from the raw part
// bytes; empty image parts are dropped. Any non-empty mask field sets HasMask.
// Unrecognized fields are ignored.
//
// An error is returned when the content type cannot be parsed, the boundary
// is missing, or a part cannot be read.
func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) {
	_, mediaParams, err := mime.ParseMediaType(contentType)
	if err != nil {
		return nil, fmt.Errorf("unable to parse content-type: %v", err)
	}
	boundary := mediaParams["boundary"]
	if boundary == "" {
		return nil, fmt.Errorf("missing multipart boundary")
	}

	parsed := &multipartImageRequest{ImageURLs: []string{}}
	mr := multipart.NewReader(bytes.NewReader(body), boundary)
	for {
		part, nextErr := mr.NextPart()
		if nextErr == io.EOF {
			return parsed, nil
		}
		if nextErr != nil {
			return nil, fmt.Errorf("unable to read multipart part: %v", nextErr)
		}

		name := part.FormName()
		if name == "" {
			_ = part.Close()
			continue
		}

		partType := strings.TrimSpace(part.Header.Get("Content-Type"))
		raw, readErr := io.ReadAll(part)
		_ = part.Close()
		if readErr != nil {
			return nil, fmt.Errorf("unable to read multipart field %s: %v", name, readErr)
		}
		text := strings.TrimSpace(string(raw))

		switch {
		case name == "model":
			parsed.Model = text
		case name == "prompt":
			parsed.Prompt = text
		case name == "size":
			parsed.Size = text
		case name == "output_format":
			parsed.OutputFormat = text
		case name == "n":
			// Atoi rejects the empty string too, so blank and malformed
			// counts are both ignored here.
			if count, convErr := strconv.Atoi(text); convErr == nil {
				parsed.N = count
			}
		case isMultipartImageField(name):
			switch {
			case isMultipartImageURLValue(text):
				parsed.ImageURLs = append(parsed.ImageURLs, text)
			case len(raw) > 0:
				parsed.ImageURLs = append(parsed.ImageURLs, buildMultipartDataURL(partType, raw))
			}
		case isMultipartMaskField(name):
			// Non-empty trimmed text implies non-empty raw bytes, so the raw
			// length alone decides.
			if len(raw) > 0 {
				parsed.HasMask = true
			}
		}
	}
}
// isMultipartImageField reports whether fieldName is an image form field:
// exactly "image", or "image" followed by a bracket suffix such as "image[]"
// or "image[0]". The comparison is case-sensitive.
func isMultipartImageField(fieldName string) bool {
	const base = "image"
	if !strings.HasPrefix(fieldName, base) {
		return false
	}
	suffix := fieldName[len(base):]
	return suffix == "" || suffix[0] == '['
}
// isMultipartMaskField reports whether fieldName is a mask form field:
// exactly "mask", or any name beginning with "mask[" (which includes
// "mask[]"). The comparison is case-sensitive.
func isMultipartMaskField(fieldName string) bool {
	switch {
	case fieldName == "mask":
		return true
	case strings.HasPrefix(fieldName, "mask["):
		return true
	default:
		return false
	}
}
// isMultipartImageURLValue reports whether value begins, case-insensitively,
// with one of the schemes "data:", "http://" or "https://". The empty string
// yields false.
//
// Only the leading bytes are examined. The value may be a multi-megabyte
// upload or base64 data URL, so lower-casing all of it just to test a prefix
// would scan (and usually copy) the entire payload.
func isMultipartImageURLValue(value string) bool {
	for _, scheme := range [...]string{"data:", "http://", "https://"} {
		// The schemes are pure ASCII, so a slice of the same byte length can
		// only fold-equal one of them when it is itself ASCII: any multi-byte
		// rune in the slice leaves it with fewer runes than the scheme.
		if len(value) >= len(scheme) && strings.EqualFold(value[:len(scheme)], scheme) {
			return true
		}
	}
	return false
}
// buildMultipartDataURL encodes data as a "data:<mime>;base64,<payload>" URL.
//
// The MIME type comes from contentType. When that is blank or exactly
// application/octet-stream (case-insensitive), the type is sniffed from the
// bytes with http.DetectContentType instead. Parameters are stripped through
// normalizeMultipartMimeType, and application/octet-stream is the last resort
// if nothing usable remains.
func buildMultipartDataURL(contentType string, data []byte) string {
	const fallback = "application/octet-stream"

	declared := strings.TrimSpace(contentType)
	if declared == "" || strings.EqualFold(declared, fallback) {
		declared = http.DetectContentType(data)
	}

	resolved := normalizeMultipartMimeType(declared)
	if resolved == "" {
		resolved = fallback
	}

	payload := base64.StdEncoding.EncodeToString(data)
	return fmt.Sprintf("data:%s;base64,%s", resolved, payload)
}
// normalizeMultipartMimeType reduces a Content-Type value to its bare media
// type.
//
// Blank input yields "". When the value parses as a media type, the parsed
// type is returned without parameters. If parsing fails, everything before
// the first ';' is returned (trimmed), provided the ';' is not the first
// character; otherwise the trimmed input comes back unchanged.
func normalizeMultipartMimeType(contentType string) string {
	trimmed := strings.TrimSpace(contentType)
	if trimmed == "" {
		return ""
	}

	if parsed, _, err := mime.ParseMediaType(trimmed); err == nil && parsed != "" {
		return strings.TrimSpace(parsed)
	}

	semi := strings.Index(trimmed, ";")
	if semi <= 0 {
		return trimmed
	}
	return strings.TrimSpace(trimmed[:semi])
}

View File

@@ -763,19 +763,19 @@ func (c *ProviderConfig) GetRandomToken() string {
// isStatefulAPI reports whether apiName has an entry in the statefulAPIs
// table built below; a name with no entry yields false.
func isStatefulAPI(apiName string) bool {
// These APIs maintain session state and should be routed to the same provider consistently
statefulAPIs := map[string]bool{
string(ApiNameResponses): true, // Response API - uses previous_response_id
string(ApiNameFiles): true, // Files API - maintains file state
string(ApiNameRetrieveFile): true, // File retrieval - depends on file upload
string(ApiNameRetrieveFileContent): true, // File content - depends on file upload
string(ApiNameBatches): true, // Batch API - maintains batch state
string(ApiNameRetrieveBatch): true, // Batch status - depends on batch creation
string(ApiNameCancelBatch): true, // Batch operations - depends on batch state
string(ApiNameFineTuningJobs): true, // Fine-tuning - maintains job state
string(ApiNameRetrieveFineTuningJob): true, // Fine-tuning job status
string(ApiNameFineTuningJobEvents): true, // Fine-tuning events
string(ApiNameFineTuningJobCheckpoints): true, // Fine-tuning checkpoints
string(ApiNameCancelFineTuningJob): true, // Cancel fine-tuning job
string(ApiNameResumeFineTuningJob): true, // Resume fine-tuning job
string(ApiNameResponses): true, // Response API - uses previous_response_id
string(ApiNameFiles): true, // Files API - maintains file state
string(ApiNameRetrieveFile): true, // File retrieval - depends on file upload
string(ApiNameRetrieveFileContent): true, // File content - depends on file upload
string(ApiNameBatches): true, // Batch API - maintains batch state
string(ApiNameRetrieveBatch): true, // Batch status - depends on batch creation
string(ApiNameCancelBatch): true, // Batch operations - depends on batch state
string(ApiNameFineTuningJobs): true, // Fine-tuning - maintains job state
string(ApiNameRetrieveFineTuningJob): true, // Fine-tuning job status
string(ApiNameFineTuningJobEvents): true, // Fine-tuning events
string(ApiNameFineTuningJobCheckpoints): true, // Fine-tuning checkpoints
string(ApiNameCancelFineTuningJob): true, // Cancel fine-tuning job
string(ApiNameResumeFineTuningJob): true, // Resume fine-tuning job
}
return statefulAPIs[apiName]
}
@@ -845,6 +845,16 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
return err
}
return c.setRequestModel(ctx, req)
case *imageEditRequest:
if err := decodeImageEditRequest(body, req); err != nil {
return err
}
return c.setRequestModel(ctx, req)
case *imageVariationRequest:
if err := decodeImageVariationRequest(body, req); err != nil {
return err
}
return c.setRequestModel(ctx, req)
default:
return errors.New("unsupported request type")
}
@@ -860,6 +870,10 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf
model = &req.Model
case *imageGenerationRequest:
model = &req.Model
case *imageEditRequest:
model = &req.Model
case *imageVariationRequest:
model = &req.Model
default:
return errors.New("unsupported request type")
}

View File

@@ -4,8 +4,8 @@ import (
"encoding/json"
"fmt"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/log"
)
func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) error {
@@ -32,6 +32,20 @@ func decodeImageGenerationRequest(body []byte, request *imageGenerationRequest)
return nil
}
// decodeImageEditRequest unmarshals the JSON in body into request.
// A decoding failure is returned with an "unable to unmarshal request" prefix.
func decodeImageEditRequest(body []byte, request *imageEditRequest) error {
	err := json.Unmarshal(body, request)
	if err == nil {
		return nil
	}
	return fmt.Errorf("unable to unmarshal request: %v", err)
}
// decodeImageVariationRequest unmarshals the JSON in body into request.
// A decoding failure is returned with an "unable to unmarshal request" prefix.
func decodeImageVariationRequest(body []byte, request *imageVariationRequest) error {
	err := json.Unmarshal(body, request)
	if err == nil {
		return nil
	}
	return fmt.Errorf("unable to unmarshal request: %v", err)
}
func replaceJsonRequestBody(request interface{}) error {
body, err := json.Marshal(request)
if err != nil {

View File

@@ -46,6 +46,7 @@ const (
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
contextVertexRawMarker = "isVertexRawRequest"
vertexAnthropicVersion = "vertex-2023-10-16"
vertexImageVariationDefaultPrompt = "Create variations of the provided image."
)
// vertexRawPathRegex 匹配原生 Vertex AI REST API 路径
@@ -98,6 +99,8 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
string(ApiNameChatCompletion): vertexPathTemplate,
string(ApiNameEmbeddings): vertexPathTemplate,
string(ApiNameImageGeneration): vertexPathTemplate,
string(ApiNameImageEdit): vertexPathTemplate,
string(ApiNameImageVariation): vertexPathTemplate,
string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换
}
}
@@ -307,6 +310,10 @@ func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap
return v.onEmbeddingsRequestBody(ctx, body, headers)
case ApiNameImageGeneration:
return v.onImageGenerationRequestBody(ctx, body, headers)
case ApiNameImageEdit:
return v.onImageEditRequestBody(ctx, body, headers)
case ApiNameImageVariation:
return v.onImageVariationRequestBody(ctx, body, headers)
default:
return body, nil
}
@@ -387,11 +394,108 @@ func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, b
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
util.OverwriteRequestPathHeader(headers, path)
vertexRequest := v.buildVertexImageGenerationRequest(request)
vertexRequest, err := v.buildVertexImageGenerationRequest(request)
if err != nil {
return nil, err
}
return json.Marshal(vertexRequest)
}
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest {
// onImageEditRequestBody turns an image edit request into a Vertex request body.
//
// The body is read as multipart/form-data when the Content-Type header says
// so, otherwise as JSON; in both cases the model name goes through the
// provider's model mapping. Requests that carry a mask, no image, or no
// prompt are rejected. On success the request path header is overwritten with
// the Vertex path for the model, Content-Type is set to JSON, and the
// marshalled Vertex request is returned.
func (v *vertexProvider) onImageEditRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
	var (
		request   = &imageEditRequest{}
		imageURLs []string
		hasMask   bool
	)

	if contentType := headers.Get("Content-Type"); isMultipartFormData(contentType) {
		form, err := parseMultipartImageRequest(body, contentType)
		if err != nil {
			return nil, err
		}
		request.Model, request.Prompt = form.Model, form.Prompt
		request.Size, request.OutputFormat = form.Size, form.OutputFormat
		request.N = form.N
		if err := v.config.mapModel(ctx, &request.Model); err != nil {
			return nil, err
		}
		imageURLs, hasMask = form.ImageURLs, form.HasMask
	} else {
		if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
			return nil, err
		}
		imageURLs, hasMask = request.GetImageURLs(), request.HasMask()
	}

	// Validation order: mask first, then images, then prompt.
	switch {
	case hasMask:
		return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
	case len(imageURLs) == 0:
		return nil, fmt.Errorf("missing image_url in request")
	case request.Prompt == "":
		return nil, fmt.Errorf("missing prompt in request")
	}

	util.OverwriteRequestPathHeader(headers, v.getRequestPath(ApiNameImageEdit, request.Model, false))
	// The outgoing body is always JSON, even when the incoming one was multipart.
	headers.Set("Content-Type", util.MimeTypeApplicationJson)

	vertexRequest, err := v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, imageURLs)
	if err != nil {
		return nil, err
	}
	return json.Marshal(vertexRequest)
}
// onImageVariationRequestBody turns an image variation request into a Vertex
// request body.
//
// The body is read as multipart/form-data when the Content-Type header says
// so, otherwise as JSON; in both cases the model name goes through the
// provider's model mapping. At least one image is required. A missing prompt
// is replaced by vertexImageVariationDefaultPrompt. On success the request
// path header is overwritten with the Vertex path for the model, Content-Type
// is set to JSON, and the marshalled Vertex request is returned.
func (v *vertexProvider) onImageVariationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
	request := new(imageVariationRequest)
	var imageURLs []string

	if contentType := headers.Get("Content-Type"); isMultipartFormData(contentType) {
		form, err := parseMultipartImageRequest(body, contentType)
		if err != nil {
			return nil, err
		}
		request.Model, request.Prompt = form.Model, form.Prompt
		request.Size, request.OutputFormat = form.Size, form.OutputFormat
		request.N = form.N
		imageURLs = form.ImageURLs
		if err := v.config.mapModel(ctx, &request.Model); err != nil {
			return nil, err
		}
	} else {
		if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
			return nil, err
		}
		imageURLs = request.GetImageURLs()
	}

	if len(imageURLs) == 0 {
		return nil, fmt.Errorf("missing image_url in request")
	}

	// A prompt is optional for variations; use the stock instruction when the
	// caller sent none.
	prompt := vertexImageVariationDefaultPrompt
	if request.Prompt != "" {
		prompt = request.Prompt
	}

	util.OverwriteRequestPathHeader(headers, v.getRequestPath(ApiNameImageVariation, request.Model, false))
	// The outgoing body is always JSON, even when the incoming one was multipart.
	headers.Set("Content-Type", util.MimeTypeApplicationJson)

	vertexRequest, err := v.buildVertexImageRequest(prompt, request.Size, request.OutputFormat, imageURLs)
	if err != nil {
		return nil, err
	}
	return json.Marshal(vertexRequest)
}
// buildVertexImageGenerationRequest builds the Vertex request for plain image
// generation by delegating to buildVertexImageRequest with the request's
// prompt, size and output format and no input images.
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) (*vertexChatRequest, error) {
	var noImages []string // generation has no source images
	return v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, noImages)
}
func (v *vertexProvider) buildVertexImageRequest(prompt string, size string, outputFormat string, imageURLs []string) (*vertexChatRequest, error) {
// 构建安全设置
safetySettings := make([]vertexChatSafetySetting, 0)
for category, threshold := range v.config.geminiSafetySetting {
@@ -402,12 +506,12 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
}
// 解析尺寸参数
aspectRatio, imageSize := v.parseImageSize(request.Size)
aspectRatio, imageSize := v.parseImageSize(size)
// 确定输出 MIME 类型
mimeType := "image/png"
if request.OutputFormat != "" {
switch request.OutputFormat {
if outputFormat != "" {
switch outputFormat {
case "jpeg", "jpg":
mimeType = "image/jpeg"
case "webp":
@@ -417,12 +521,27 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
}
}
parts := make([]vertexPart, 0, len(imageURLs)+1)
for _, imageURL := range imageURLs {
part, err := convertMediaContent(imageURL)
if err != nil {
return nil, err
}
parts = append(parts, part)
}
if prompt != "" {
parts = append(parts, vertexPart{
Text: prompt,
})
}
if len(parts) == 0 {
return nil, fmt.Errorf("missing prompt and image_url in request")
}
vertexRequest := &vertexChatRequest{
Contents: []vertexChatContent{{
Role: roleUser,
Parts: []vertexPart{{
Text: request.Prompt,
}},
Role: roleUser,
Parts: parts,
}},
SafetySettings: safetySettings,
GenerationConfig: vertexChatGenerationConfig{
@@ -440,7 +559,7 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
},
}
return vertexRequest
return vertexRequest, nil
}
// parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize
@@ -553,7 +672,7 @@ func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
return v.onChatCompletionResponseBody(ctx, body)
case ApiNameEmbeddings:
return v.onEmbeddingsResponseBody(ctx, body)
case ApiNameImageGeneration:
case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
return v.onImageGenerationResponseBody(ctx, body)
default:
return body, nil
@@ -784,7 +903,7 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
switch apiName {
case ApiNameEmbeddings:
action = vertexEmbeddingAction
case ApiNameImageGeneration:
case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
// 图片生成使用非流式端点,需要完整响应
action = vertexChatCompletionAction
default: