[ai-proxy] vertex image edits & variations (#3536)

This commit is contained in:
woody
2026-02-27 10:18:30 +08:00
committed by GitHub
parent e9aecb6e1f
commit e2a22d1171
8 changed files with 830 additions and 28 deletions

View File

@@ -46,6 +46,7 @@ const (
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
contextVertexRawMarker = "isVertexRawRequest"
vertexAnthropicVersion = "vertex-2023-10-16"
vertexImageVariationDefaultPrompt = "Create variations of the provided image."
)
// vertexRawPathRegex 匹配原生 Vertex AI REST API 路径
@@ -98,6 +99,8 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
string(ApiNameChatCompletion): vertexPathTemplate,
string(ApiNameEmbeddings): vertexPathTemplate,
string(ApiNameImageGeneration): vertexPathTemplate,
string(ApiNameImageEdit): vertexPathTemplate,
string(ApiNameImageVariation): vertexPathTemplate,
string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换
}
}
@@ -307,6 +310,10 @@ func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap
return v.onEmbeddingsRequestBody(ctx, body, headers)
case ApiNameImageGeneration:
return v.onImageGenerationRequestBody(ctx, body, headers)
case ApiNameImageEdit:
return v.onImageEditRequestBody(ctx, body, headers)
case ApiNameImageVariation:
return v.onImageVariationRequestBody(ctx, body, headers)
default:
return body, nil
}
@@ -387,11 +394,108 @@ func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, b
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
util.OverwriteRequestPathHeader(headers, path)
vertexRequest := v.buildVertexImageGenerationRequest(request)
vertexRequest, err := v.buildVertexImageGenerationRequest(request)
if err != nil {
return nil, err
}
return json.Marshal(vertexRequest)
}
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest {
// onImageEditRequestBody rewrites an incoming image edit request into the
// Vertex request produced by buildVertexImageRequest.
//
// The body is read through parseMultipartImageRequest when the Content-Type
// header is multipart form data, and through parseRequestAndMapModel otherwise.
// Either way the model name goes through the provider config's model mapping.
// Requests that carry a mask are rejected, and at least one input image plus a
// non-empty prompt are required.
//
// Before the Vertex request is built, the request path header is overwritten
// with the path for the mapped model and Content-Type is set to JSON. The
// returned bytes are the JSON encoding of the Vertex request.
func (v *vertexProvider) onImageEditRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
	const maskUnsupportedMsg = "mask is not supported for vertex image edits yet"

	var (
		editReq = &imageEditRequest{}
		images  []string
	)

	switch mediaType := headers.Get("Content-Type"); {
	case isMultipartFormData(mediaType):
		form, err := parseMultipartImageRequest(body, mediaType)
		if err != nil {
			return nil, err
		}
		// Copy the parsed form fields onto the typed request.
		editReq.Model, editReq.Prompt = form.Model, form.Prompt
		editReq.Size, editReq.OutputFormat = form.Size, form.OutputFormat
		editReq.N = form.N
		images = form.ImageURLs
		// Model mapping runs before the mask check, so a mapping error wins.
		if mapErr := v.config.mapModel(ctx, &editReq.Model); mapErr != nil {
			return nil, mapErr
		}
		if form.HasMask {
			return nil, fmt.Errorf(maskUnsupportedMsg)
		}
	default:
		if parseErr := v.config.parseRequestAndMapModel(ctx, editReq, body); parseErr != nil {
			return nil, parseErr
		}
		if editReq.HasMask() {
			return nil, fmt.Errorf(maskUnsupportedMsg)
		}
		images = editReq.GetImageURLs()
	}

	// Validation order matters: a missing image is reported before a missing prompt.
	switch {
	case len(images) == 0:
		return nil, fmt.Errorf("missing image_url in request")
	case editReq.Prompt == "":
		return nil, fmt.Errorf("missing prompt in request")
	}

	util.OverwriteRequestPathHeader(headers, v.getRequestPath(ApiNameImageEdit, editReq.Model, false))
	// The outgoing body is always JSON, even if the client sent multipart.
	headers.Set("Content-Type", util.MimeTypeApplicationJson)

	payload, err := v.buildVertexImageRequest(editReq.Prompt, editReq.Size, editReq.OutputFormat, images)
	if err != nil {
		return nil, err
	}
	return json.Marshal(payload)
}
// onImageVariationRequestBody rewrites an incoming image variation request
// into the Vertex request produced by buildVertexImageRequest.
//
// The body is read through parseMultipartImageRequest when the Content-Type
// header is multipart form data, and through parseRequestAndMapModel otherwise.
// Either way the model name goes through the provider config's model mapping.
// At least one input image is required. When the request has no prompt,
// vertexImageVariationDefaultPrompt is used instead.
//
// Before the Vertex request is built, the request path header is overwritten
// with the path for the mapped model and Content-Type is set to JSON. The
// returned bytes are the JSON encoding of the Vertex request.
func (v *vertexProvider) onImageVariationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
	var (
		variationReq = &imageVariationRequest{}
		images       []string
	)

	switch mediaType := headers.Get("Content-Type"); {
	case isMultipartFormData(mediaType):
		form, err := parseMultipartImageRequest(body, mediaType)
		if err != nil {
			return nil, err
		}
		// Copy the parsed form fields onto the typed request.
		variationReq.Model, variationReq.Prompt = form.Model, form.Prompt
		variationReq.Size, variationReq.OutputFormat = form.Size, form.OutputFormat
		variationReq.N = form.N
		images = form.ImageURLs
		if mapErr := v.config.mapModel(ctx, &variationReq.Model); mapErr != nil {
			return nil, mapErr
		}
	default:
		if parseErr := v.config.parseRequestAndMapModel(ctx, variationReq, body); parseErr != nil {
			return nil, parseErr
		}
		images = variationReq.GetImageURLs()
	}

	if len(images) == 0 {
		return nil, fmt.Errorf("missing image_url in request")
	}

	// Fall back to the default instruction when the request has no prompt.
	instruction := vertexImageVariationDefaultPrompt
	if variationReq.Prompt != "" {
		instruction = variationReq.Prompt
	}

	util.OverwriteRequestPathHeader(headers, v.getRequestPath(ApiNameImageVariation, variationReq.Model, false))
	// The outgoing body is always JSON, even if the client sent multipart.
	headers.Set("Content-Type", util.MimeTypeApplicationJson)

	payload, err := v.buildVertexImageRequest(instruction, variationReq.Size, variationReq.OutputFormat, images)
	if err != nil {
		return nil, err
	}
	return json.Marshal(payload)
}
// buildVertexImageGenerationRequest builds the Vertex request for a plain
// image generation call. It delegates to buildVertexImageRequest with the
// request's prompt, size and output format, and no input images.
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) (*vertexChatRequest, error) {
	// Generation has no source images, so the image list stays nil.
	var noInputImages []string
	return v.buildVertexImageRequest(
		request.Prompt,
		request.Size,
		request.OutputFormat,
		noInputImages,
	)
}
func (v *vertexProvider) buildVertexImageRequest(prompt string, size string, outputFormat string, imageURLs []string) (*vertexChatRequest, error) {
// 构建安全设置
safetySettings := make([]vertexChatSafetySetting, 0)
for category, threshold := range v.config.geminiSafetySetting {
@@ -402,12 +506,12 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
}
// 解析尺寸参数
aspectRatio, imageSize := v.parseImageSize(request.Size)
aspectRatio, imageSize := v.parseImageSize(size)
// 确定输出 MIME 类型
mimeType := "image/png"
if request.OutputFormat != "" {
switch request.OutputFormat {
if outputFormat != "" {
switch outputFormat {
case "jpeg", "jpg":
mimeType = "image/jpeg"
case "webp":
@@ -417,12 +521,27 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
}
}
parts := make([]vertexPart, 0, len(imageURLs)+1)
for _, imageURL := range imageURLs {
part, err := convertMediaContent(imageURL)
if err != nil {
return nil, err
}
parts = append(parts, part)
}
if prompt != "" {
parts = append(parts, vertexPart{
Text: prompt,
})
}
if len(parts) == 0 {
return nil, fmt.Errorf("missing prompt and image_url in request")
}
vertexRequest := &vertexChatRequest{
Contents: []vertexChatContent{{
Role: roleUser,
Parts: []vertexPart{{
Text: request.Prompt,
}},
Role: roleUser,
Parts: parts,
}},
SafetySettings: safetySettings,
GenerationConfig: vertexChatGenerationConfig{
@@ -440,7 +559,7 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
},
}
return vertexRequest
return vertexRequest, nil
}
// parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize
@@ -553,7 +672,7 @@ func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
return v.onChatCompletionResponseBody(ctx, body)
case ApiNameEmbeddings:
return v.onEmbeddingsResponseBody(ctx, body)
case ApiNameImageGeneration:
case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
return v.onImageGenerationResponseBody(ctx, body)
default:
return body, nil
@@ -784,7 +903,7 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
switch apiName {
case ApiNameEmbeddings:
action = vertexEmbeddingAction
case ApiNameImageGeneration:
case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
// 图片生成使用非流式端点,需要完整响应
action = vertexChatCompletionAction
default: