mirror of
https://github.com/alibaba/higress.git
synced 2026-05-23 04:07:26 +08:00
[ai-proxy] vertex image edits & variations (#3536)
This commit is contained in:
@@ -46,6 +46,7 @@ const (
|
||||
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
|
||||
contextVertexRawMarker = "isVertexRawRequest"
|
||||
vertexAnthropicVersion = "vertex-2023-10-16"
|
||||
vertexImageVariationDefaultPrompt = "Create variations of the provided image."
|
||||
)
|
||||
|
||||
// vertexRawPathRegex 匹配原生 Vertex AI REST API 路径
|
||||
@@ -98,6 +99,8 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
string(ApiNameChatCompletion): vertexPathTemplate,
|
||||
string(ApiNameEmbeddings): vertexPathTemplate,
|
||||
string(ApiNameImageGeneration): vertexPathTemplate,
|
||||
string(ApiNameImageEdit): vertexPathTemplate,
|
||||
string(ApiNameImageVariation): vertexPathTemplate,
|
||||
string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换
|
||||
}
|
||||
}
|
||||
@@ -307,6 +310,10 @@ func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap
|
||||
return v.onEmbeddingsRequestBody(ctx, body, headers)
|
||||
case ApiNameImageGeneration:
|
||||
return v.onImageGenerationRequestBody(ctx, body, headers)
|
||||
case ApiNameImageEdit:
|
||||
return v.onImageEditRequestBody(ctx, body, headers)
|
||||
case ApiNameImageVariation:
|
||||
return v.onImageVariationRequestBody(ctx, body, headers)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
@@ -387,11 +394,108 @@ func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, b
|
||||
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest := v.buildVertexImageGenerationRequest(request)
|
||||
vertexRequest, err := v.buildVertexImageGenerationRequest(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest {
|
||||
func (v *vertexProvider) onImageEditRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &imageEditRequest{}
|
||||
imageURLs := make([]string, 0)
|
||||
contentType := headers.Get("Content-Type")
|
||||
if isMultipartFormData(contentType) {
|
||||
parsedRequest, err := parseMultipartImageRequest(body, contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Model = parsedRequest.Model
|
||||
request.Prompt = parsedRequest.Prompt
|
||||
request.Size = parsedRequest.Size
|
||||
request.OutputFormat = parsedRequest.OutputFormat
|
||||
request.N = parsedRequest.N
|
||||
imageURLs = parsedRequest.ImageURLs
|
||||
if err := v.config.mapModel(ctx, &request.Model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsedRequest.HasMask {
|
||||
return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
|
||||
}
|
||||
} else {
|
||||
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if request.HasMask() {
|
||||
return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
|
||||
}
|
||||
imageURLs = request.GetImageURLs()
|
||||
}
|
||||
if len(imageURLs) == 0 {
|
||||
return nil, fmt.Errorf("missing image_url in request")
|
||||
}
|
||||
if request.Prompt == "" {
|
||||
return nil, fmt.Errorf("missing prompt in request")
|
||||
}
|
||||
|
||||
path := v.getRequestPath(ApiNameImageEdit, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
headers.Set("Content-Type", util.MimeTypeApplicationJson)
|
||||
vertexRequest, err := v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, imageURLs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onImageVariationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &imageVariationRequest{}
|
||||
imageURLs := make([]string, 0)
|
||||
contentType := headers.Get("Content-Type")
|
||||
if isMultipartFormData(contentType) {
|
||||
parsedRequest, err := parseMultipartImageRequest(body, contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Model = parsedRequest.Model
|
||||
request.Prompt = parsedRequest.Prompt
|
||||
request.Size = parsedRequest.Size
|
||||
request.OutputFormat = parsedRequest.OutputFormat
|
||||
request.N = parsedRequest.N
|
||||
imageURLs = parsedRequest.ImageURLs
|
||||
if err := v.config.mapModel(ctx, &request.Model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
imageURLs = request.GetImageURLs()
|
||||
}
|
||||
if len(imageURLs) == 0 {
|
||||
return nil, fmt.Errorf("missing image_url in request")
|
||||
}
|
||||
|
||||
prompt := request.Prompt
|
||||
if prompt == "" {
|
||||
prompt = vertexImageVariationDefaultPrompt
|
||||
}
|
||||
|
||||
path := v.getRequestPath(ApiNameImageVariation, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
headers.Set("Content-Type", util.MimeTypeApplicationJson)
|
||||
vertexRequest, err := v.buildVertexImageRequest(prompt, request.Size, request.OutputFormat, imageURLs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) (*vertexChatRequest, error) {
|
||||
return v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, nil)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexImageRequest(prompt string, size string, outputFormat string, imageURLs []string) (*vertexChatRequest, error) {
|
||||
// 构建安全设置
|
||||
safetySettings := make([]vertexChatSafetySetting, 0)
|
||||
for category, threshold := range v.config.geminiSafetySetting {
|
||||
@@ -402,12 +506,12 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
||||
}
|
||||
|
||||
// 解析尺寸参数
|
||||
aspectRatio, imageSize := v.parseImageSize(request.Size)
|
||||
aspectRatio, imageSize := v.parseImageSize(size)
|
||||
|
||||
// 确定输出 MIME 类型
|
||||
mimeType := "image/png"
|
||||
if request.OutputFormat != "" {
|
||||
switch request.OutputFormat {
|
||||
if outputFormat != "" {
|
||||
switch outputFormat {
|
||||
case "jpeg", "jpg":
|
||||
mimeType = "image/jpeg"
|
||||
case "webp":
|
||||
@@ -417,12 +521,27 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
||||
}
|
||||
}
|
||||
|
||||
parts := make([]vertexPart, 0, len(imageURLs)+1)
|
||||
for _, imageURL := range imageURLs {
|
||||
part, err := convertMediaContent(imageURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
if prompt != "" {
|
||||
parts = append(parts, vertexPart{
|
||||
Text: prompt,
|
||||
})
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return nil, fmt.Errorf("missing prompt and image_url in request")
|
||||
}
|
||||
|
||||
vertexRequest := &vertexChatRequest{
|
||||
Contents: []vertexChatContent{{
|
||||
Role: roleUser,
|
||||
Parts: []vertexPart{{
|
||||
Text: request.Prompt,
|
||||
}},
|
||||
Role: roleUser,
|
||||
Parts: parts,
|
||||
}},
|
||||
SafetySettings: safetySettings,
|
||||
GenerationConfig: vertexChatGenerationConfig{
|
||||
@@ -440,7 +559,7 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
||||
},
|
||||
}
|
||||
|
||||
return vertexRequest
|
||||
return vertexRequest, nil
|
||||
}
|
||||
|
||||
// parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize
|
||||
@@ -553,7 +672,7 @@ func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
|
||||
return v.onChatCompletionResponseBody(ctx, body)
|
||||
case ApiNameEmbeddings:
|
||||
return v.onEmbeddingsResponseBody(ctx, body)
|
||||
case ApiNameImageGeneration:
|
||||
case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
|
||||
return v.onImageGenerationResponseBody(ctx, body)
|
||||
default:
|
||||
return body, nil
|
||||
@@ -784,7 +903,7 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
|
||||
switch apiName {
|
||||
case ApiNameEmbeddings:
|
||||
action = vertexEmbeddingAction
|
||||
case ApiNameImageGeneration:
|
||||
case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
|
||||
// 图片生成使用非流式端点,需要完整响应
|
||||
action = vertexChatCompletionAction
|
||||
default:
|
||||
|
||||
Reference in New Issue
Block a user