mirror of
https://github.com/alibaba/higress.git
synced 2026-05-11 22:37:32 +08:00
[ai-proxy] vertex image edits & variations (#3536)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -461,6 +462,122 @@ type imageGenerationRequest struct {
|
||||
Size string `json:"size,omitempty"`
|
||||
}
|
||||
|
||||
type imageInputURL struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
ImageURL *chatMessageContentImageUrl `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
func (i *imageInputURL) UnmarshalJSON(data []byte) error {
|
||||
// Support a plain string payload, e.g. "data:image/png;base64,..."
|
||||
var rawURL string
|
||||
if err := json.Unmarshal(data, &rawURL); err == nil {
|
||||
i.URL = rawURL
|
||||
i.ImageURL = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
type alias imageInputURL
|
||||
var value alias
|
||||
if err := json.Unmarshal(data, &value); err != nil {
|
||||
return err
|
||||
}
|
||||
*i = imageInputURL(value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *imageInputURL) GetURL() string {
|
||||
if i == nil {
|
||||
return ""
|
||||
}
|
||||
if i.ImageURL != nil && i.ImageURL.Url != "" {
|
||||
return i.ImageURL.Url
|
||||
}
|
||||
return i.URL
|
||||
}
|
||||
|
||||
type imageEditRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Image *imageInputURL `json:"image,omitempty"`
|
||||
Images []imageInputURL `json:"images,omitempty"`
|
||||
ImageURL *imageInputURL `json:"image_url,omitempty"`
|
||||
Mask *imageInputURL `json:"mask,omitempty"`
|
||||
MaskURL *imageInputURL `json:"mask_url,omitempty"`
|
||||
Background string `json:"background,omitempty"`
|
||||
Moderation string `json:"moderation,omitempty"`
|
||||
OutputCompression int `json:"output_compression,omitempty"`
|
||||
OutputFormat string `json:"output_format,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
}
|
||||
|
||||
func (r *imageEditRequest) GetImageURLs() []string {
|
||||
urls := make([]string, 0, len(r.Images)+2)
|
||||
for _, image := range r.Images {
|
||||
if url := image.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
if r.Image != nil {
|
||||
if url := r.Image.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
if r.ImageURL != nil {
|
||||
if url := r.ImageURL.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
return urls
|
||||
}
|
||||
|
||||
func (r *imageEditRequest) HasMask() bool {
|
||||
if r.Mask != nil && r.Mask.GetURL() != "" {
|
||||
return true
|
||||
}
|
||||
return r.MaskURL != nil && r.MaskURL.GetURL() != ""
|
||||
}
|
||||
|
||||
type imageVariationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Image *imageInputURL `json:"image,omitempty"`
|
||||
Images []imageInputURL `json:"images,omitempty"`
|
||||
ImageURL *imageInputURL `json:"image_url,omitempty"`
|
||||
Background string `json:"background,omitempty"`
|
||||
Moderation string `json:"moderation,omitempty"`
|
||||
OutputCompression int `json:"output_compression,omitempty"`
|
||||
OutputFormat string `json:"output_format,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
}
|
||||
|
||||
func (r *imageVariationRequest) GetImageURLs() []string {
|
||||
urls := make([]string, 0, len(r.Images)+2)
|
||||
for _, image := range r.Images {
|
||||
if url := image.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
if r.Image != nil {
|
||||
if url := r.Image.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
if r.ImageURL != nil {
|
||||
if url := r.ImageURL.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
return urls
|
||||
}
|
||||
|
||||
type imageGenerationData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64 string `json:"b64_json,omitempty"`
|
||||
|
||||
156
plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go
Normal file
156
plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// multipartImageRequest is the result of parsing a multipart/form-data image
// request (see parseMultipartImageRequest).
type multipartImageRequest struct {
	// Plain text form fields.
	Model        string
	Prompt       string
	Size         string
	OutputFormat string
	// N is the parsed "n" field; it stays 0 when the field is absent or not an integer.
	N int

	// ImageURLs holds one entry per image field: either the URL / data URL the
	// client sent as text, or a data URL built from the uploaded bytes.
	ImageURLs []string
	// HasMask is true when a non-empty mask field was present.
	HasMask bool
}
|
||||
|
||||
// isMultipartFormData reports whether contentType is a parseable media type
// whose base type is multipart/form-data (parameters such as boundary are
// ignored; the comparison is case-insensitive).
func isMultipartFormData(contentType string) bool {
	baseType, _, parseErr := mime.ParseMediaType(contentType)
	return parseErr == nil && strings.EqualFold(baseType, "multipart/form-data")
}
|
||||
|
||||
func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) {
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse content-type: %v", err)
|
||||
}
|
||||
boundary := params["boundary"]
|
||||
if boundary == "" {
|
||||
return nil, fmt.Errorf("missing multipart boundary")
|
||||
}
|
||||
|
||||
req := &multipartImageRequest{
|
||||
ImageURLs: make([]string, 0),
|
||||
}
|
||||
reader := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read multipart part: %v", err)
|
||||
}
|
||||
fieldName := part.FormName()
|
||||
if fieldName == "" {
|
||||
_ = part.Close()
|
||||
continue
|
||||
}
|
||||
partContentType := strings.TrimSpace(part.Header.Get("Content-Type"))
|
||||
|
||||
partData, err := io.ReadAll(part)
|
||||
_ = part.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read multipart field %s: %v", fieldName, err)
|
||||
}
|
||||
|
||||
value := strings.TrimSpace(string(partData))
|
||||
switch fieldName {
|
||||
case "model":
|
||||
req.Model = value
|
||||
continue
|
||||
case "prompt":
|
||||
req.Prompt = value
|
||||
continue
|
||||
case "size":
|
||||
req.Size = value
|
||||
continue
|
||||
case "output_format":
|
||||
req.OutputFormat = value
|
||||
continue
|
||||
case "n":
|
||||
if value != "" {
|
||||
if parsed, err := strconv.Atoi(value); err == nil {
|
||||
req.N = parsed
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if isMultipartImageField(fieldName) {
|
||||
if isMultipartImageURLValue(value) {
|
||||
req.ImageURLs = append(req.ImageURLs, value)
|
||||
continue
|
||||
}
|
||||
if len(partData) == 0 {
|
||||
continue
|
||||
}
|
||||
imageURL := buildMultipartDataURL(partContentType, partData)
|
||||
req.ImageURLs = append(req.ImageURLs, imageURL)
|
||||
continue
|
||||
}
|
||||
if isMultipartMaskField(fieldName) {
|
||||
if len(partData) > 0 || value != "" {
|
||||
req.HasMask = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// isMultipartImageField reports whether a multipart field name denotes an
// image input: exactly "image", or any indexed/array form starting with
// "image[" (which includes "image[]").
func isMultipartImageField(fieldName string) bool {
	if fieldName == "image" {
		return true
	}
	return strings.HasPrefix(fieldName, "image[")
}
|
||||
|
||||
// isMultipartMaskField reports whether a multipart field name denotes a mask
// input: exactly "mask", or any indexed/array form starting with "mask["
// (which includes "mask[]").
func isMultipartMaskField(fieldName string) bool {
	if fieldName == "mask" {
		return true
	}
	return strings.HasPrefix(fieldName, "mask[")
}
|
||||
|
||||
// isMultipartImageURLValue reports whether value starts with one of the
// schemes "data:", "http://" or "https://" (compared case-insensitively).
// The value is expected to be already trimmed by the caller; an empty value
// yields false.
//
// Only the leading bytes are inspected, so a multi-megabyte data URL is not
// copied just to test its prefix.
func isMultipartImageURLValue(value string) bool {
	for _, scheme := range [...]string{"data:", "http://", "https://"} {
		if len(value) >= len(scheme) && strings.EqualFold(value[:len(scheme)], scheme) {
			return true
		}
	}
	return false
}
|
||||
|
||||
func buildMultipartDataURL(contentType string, data []byte) string {
|
||||
mimeType := strings.TrimSpace(contentType)
|
||||
if mimeType == "" || strings.EqualFold(mimeType, "application/octet-stream") {
|
||||
mimeType = http.DetectContentType(data)
|
||||
}
|
||||
mimeType = normalizeMultipartMimeType(mimeType)
|
||||
if mimeType == "" {
|
||||
mimeType = "application/octet-stream"
|
||||
}
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
return fmt.Sprintf("data:%s;base64,%s", mimeType, encoded)
|
||||
}
|
||||
|
||||
// normalizeMultipartMimeType reduces a Content-Type value to its bare media
// type. Blank input yields "". When the value parses as a media type, the
// parsed type (without parameters) is returned; otherwise everything before
// the first ';' is used, and if there is no usable ';' the trimmed input is
// returned unchanged.
func normalizeMultipartMimeType(contentType string) string {
	trimmed := strings.TrimSpace(contentType)
	if len(trimmed) == 0 {
		return ""
	}

	if parsed, _, parseErr := mime.ParseMediaType(trimmed); parseErr == nil && parsed != "" {
		return strings.TrimSpace(parsed)
	}

	// Fallback for values the mime package rejects: cut off the parameter list by hand.
	if semi := strings.IndexByte(trimmed, ';'); semi > 0 {
		return strings.TrimSpace(trimmed[:semi])
	}
	return trimmed
}
|
||||
@@ -763,19 +763,19 @@ func (c *ProviderConfig) GetRandomToken() string {
|
||||
func isStatefulAPI(apiName string) bool {
|
||||
// These APIs maintain session state and should be routed to the same provider consistently
|
||||
statefulAPIs := map[string]bool{
|
||||
string(ApiNameResponses): true, // Response API - uses previous_response_id
|
||||
string(ApiNameFiles): true, // Files API - maintains file state
|
||||
string(ApiNameRetrieveFile): true, // File retrieval - depends on file upload
|
||||
string(ApiNameRetrieveFileContent): true, // File content - depends on file upload
|
||||
string(ApiNameBatches): true, // Batch API - maintains batch state
|
||||
string(ApiNameRetrieveBatch): true, // Batch status - depends on batch creation
|
||||
string(ApiNameCancelBatch): true, // Batch operations - depends on batch state
|
||||
string(ApiNameFineTuningJobs): true, // Fine-tuning - maintains job state
|
||||
string(ApiNameRetrieveFineTuningJob): true, // Fine-tuning job status
|
||||
string(ApiNameFineTuningJobEvents): true, // Fine-tuning events
|
||||
string(ApiNameFineTuningJobCheckpoints): true, // Fine-tuning checkpoints
|
||||
string(ApiNameCancelFineTuningJob): true, // Cancel fine-tuning job
|
||||
string(ApiNameResumeFineTuningJob): true, // Resume fine-tuning job
|
||||
string(ApiNameResponses): true, // Response API - uses previous_response_id
|
||||
string(ApiNameFiles): true, // Files API - maintains file state
|
||||
string(ApiNameRetrieveFile): true, // File retrieval - depends on file upload
|
||||
string(ApiNameRetrieveFileContent): true, // File content - depends on file upload
|
||||
string(ApiNameBatches): true, // Batch API - maintains batch state
|
||||
string(ApiNameRetrieveBatch): true, // Batch status - depends on batch creation
|
||||
string(ApiNameCancelBatch): true, // Batch operations - depends on batch state
|
||||
string(ApiNameFineTuningJobs): true, // Fine-tuning - maintains job state
|
||||
string(ApiNameRetrieveFineTuningJob): true, // Fine-tuning job status
|
||||
string(ApiNameFineTuningJobEvents): true, // Fine-tuning events
|
||||
string(ApiNameFineTuningJobCheckpoints): true, // Fine-tuning checkpoints
|
||||
string(ApiNameCancelFineTuningJob): true, // Cancel fine-tuning job
|
||||
string(ApiNameResumeFineTuningJob): true, // Resume fine-tuning job
|
||||
}
|
||||
return statefulAPIs[apiName]
|
||||
}
|
||||
@@ -845,6 +845,16 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
|
||||
return err
|
||||
}
|
||||
return c.setRequestModel(ctx, req)
|
||||
case *imageEditRequest:
|
||||
if err := decodeImageEditRequest(body, req); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.setRequestModel(ctx, req)
|
||||
case *imageVariationRequest:
|
||||
if err := decodeImageVariationRequest(body, req); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.setRequestModel(ctx, req)
|
||||
default:
|
||||
return errors.New("unsupported request type")
|
||||
}
|
||||
@@ -860,6 +870,10 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf
|
||||
model = &req.Model
|
||||
case *imageGenerationRequest:
|
||||
model = &req.Model
|
||||
case *imageEditRequest:
|
||||
model = &req.Model
|
||||
case *imageVariationRequest:
|
||||
model = &req.Model
|
||||
default:
|
||||
return errors.New("unsupported request type")
|
||||
}
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
)
|
||||
|
||||
func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) error {
|
||||
@@ -32,6 +32,20 @@ func decodeImageGenerationRequest(body []byte, request *imageGenerationRequest)
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeImageEditRequest(body []byte, request *imageEditRequest) error {
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeImageVariationRequest(body []byte, request *imageVariationRequest) error {
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func replaceJsonRequestBody(request interface{}) error {
|
||||
body, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
|
||||
@@ -46,6 +46,7 @@ const (
|
||||
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
|
||||
contextVertexRawMarker = "isVertexRawRequest"
|
||||
vertexAnthropicVersion = "vertex-2023-10-16"
|
||||
vertexImageVariationDefaultPrompt = "Create variations of the provided image."
|
||||
)
|
||||
|
||||
// vertexRawPathRegex 匹配原生 Vertex AI REST API 路径
|
||||
@@ -98,6 +99,8 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
string(ApiNameChatCompletion): vertexPathTemplate,
|
||||
string(ApiNameEmbeddings): vertexPathTemplate,
|
||||
string(ApiNameImageGeneration): vertexPathTemplate,
|
||||
string(ApiNameImageEdit): vertexPathTemplate,
|
||||
string(ApiNameImageVariation): vertexPathTemplate,
|
||||
string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换
|
||||
}
|
||||
}
|
||||
@@ -307,6 +310,10 @@ func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap
|
||||
return v.onEmbeddingsRequestBody(ctx, body, headers)
|
||||
case ApiNameImageGeneration:
|
||||
return v.onImageGenerationRequestBody(ctx, body, headers)
|
||||
case ApiNameImageEdit:
|
||||
return v.onImageEditRequestBody(ctx, body, headers)
|
||||
case ApiNameImageVariation:
|
||||
return v.onImageVariationRequestBody(ctx, body, headers)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
@@ -387,11 +394,108 @@ func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, b
|
||||
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest := v.buildVertexImageGenerationRequest(request)
|
||||
vertexRequest, err := v.buildVertexImageGenerationRequest(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest {
|
||||
func (v *vertexProvider) onImageEditRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &imageEditRequest{}
|
||||
imageURLs := make([]string, 0)
|
||||
contentType := headers.Get("Content-Type")
|
||||
if isMultipartFormData(contentType) {
|
||||
parsedRequest, err := parseMultipartImageRequest(body, contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Model = parsedRequest.Model
|
||||
request.Prompt = parsedRequest.Prompt
|
||||
request.Size = parsedRequest.Size
|
||||
request.OutputFormat = parsedRequest.OutputFormat
|
||||
request.N = parsedRequest.N
|
||||
imageURLs = parsedRequest.ImageURLs
|
||||
if err := v.config.mapModel(ctx, &request.Model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsedRequest.HasMask {
|
||||
return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
|
||||
}
|
||||
} else {
|
||||
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if request.HasMask() {
|
||||
return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
|
||||
}
|
||||
imageURLs = request.GetImageURLs()
|
||||
}
|
||||
if len(imageURLs) == 0 {
|
||||
return nil, fmt.Errorf("missing image_url in request")
|
||||
}
|
||||
if request.Prompt == "" {
|
||||
return nil, fmt.Errorf("missing prompt in request")
|
||||
}
|
||||
|
||||
path := v.getRequestPath(ApiNameImageEdit, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
headers.Set("Content-Type", util.MimeTypeApplicationJson)
|
||||
vertexRequest, err := v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, imageURLs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onImageVariationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &imageVariationRequest{}
|
||||
imageURLs := make([]string, 0)
|
||||
contentType := headers.Get("Content-Type")
|
||||
if isMultipartFormData(contentType) {
|
||||
parsedRequest, err := parseMultipartImageRequest(body, contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Model = parsedRequest.Model
|
||||
request.Prompt = parsedRequest.Prompt
|
||||
request.Size = parsedRequest.Size
|
||||
request.OutputFormat = parsedRequest.OutputFormat
|
||||
request.N = parsedRequest.N
|
||||
imageURLs = parsedRequest.ImageURLs
|
||||
if err := v.config.mapModel(ctx, &request.Model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
imageURLs = request.GetImageURLs()
|
||||
}
|
||||
if len(imageURLs) == 0 {
|
||||
return nil, fmt.Errorf("missing image_url in request")
|
||||
}
|
||||
|
||||
prompt := request.Prompt
|
||||
if prompt == "" {
|
||||
prompt = vertexImageVariationDefaultPrompt
|
||||
}
|
||||
|
||||
path := v.getRequestPath(ApiNameImageVariation, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
headers.Set("Content-Type", util.MimeTypeApplicationJson)
|
||||
vertexRequest, err := v.buildVertexImageRequest(prompt, request.Size, request.OutputFormat, imageURLs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) (*vertexChatRequest, error) {
|
||||
return v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, nil)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexImageRequest(prompt string, size string, outputFormat string, imageURLs []string) (*vertexChatRequest, error) {
|
||||
// 构建安全设置
|
||||
safetySettings := make([]vertexChatSafetySetting, 0)
|
||||
for category, threshold := range v.config.geminiSafetySetting {
|
||||
@@ -402,12 +506,12 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
||||
}
|
||||
|
||||
// 解析尺寸参数
|
||||
aspectRatio, imageSize := v.parseImageSize(request.Size)
|
||||
aspectRatio, imageSize := v.parseImageSize(size)
|
||||
|
||||
// 确定输出 MIME 类型
|
||||
mimeType := "image/png"
|
||||
if request.OutputFormat != "" {
|
||||
switch request.OutputFormat {
|
||||
if outputFormat != "" {
|
||||
switch outputFormat {
|
||||
case "jpeg", "jpg":
|
||||
mimeType = "image/jpeg"
|
||||
case "webp":
|
||||
@@ -417,12 +521,27 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
||||
}
|
||||
}
|
||||
|
||||
parts := make([]vertexPart, 0, len(imageURLs)+1)
|
||||
for _, imageURL := range imageURLs {
|
||||
part, err := convertMediaContent(imageURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
if prompt != "" {
|
||||
parts = append(parts, vertexPart{
|
||||
Text: prompt,
|
||||
})
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return nil, fmt.Errorf("missing prompt and image_url in request")
|
||||
}
|
||||
|
||||
vertexRequest := &vertexChatRequest{
|
||||
Contents: []vertexChatContent{{
|
||||
Role: roleUser,
|
||||
Parts: []vertexPart{{
|
||||
Text: request.Prompt,
|
||||
}},
|
||||
Role: roleUser,
|
||||
Parts: parts,
|
||||
}},
|
||||
SafetySettings: safetySettings,
|
||||
GenerationConfig: vertexChatGenerationConfig{
|
||||
@@ -440,7 +559,7 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
||||
},
|
||||
}
|
||||
|
||||
return vertexRequest
|
||||
return vertexRequest, nil
|
||||
}
|
||||
|
||||
// parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize
|
||||
@@ -553,7 +672,7 @@ func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
|
||||
return v.onChatCompletionResponseBody(ctx, body)
|
||||
case ApiNameEmbeddings:
|
||||
return v.onEmbeddingsResponseBody(ctx, body)
|
||||
case ApiNameImageGeneration:
|
||||
case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
|
||||
return v.onImageGenerationResponseBody(ctx, body)
|
||||
default:
|
||||
return body, nil
|
||||
@@ -784,7 +903,7 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
|
||||
switch apiName {
|
||||
case ApiNameEmbeddings:
|
||||
action = vertexEmbeddingAction
|
||||
case ApiNameImageGeneration:
|
||||
case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
|
||||
// 图片生成使用非流式端点,需要完整响应
|
||||
action = vertexChatCompletionAction
|
||||
default:
|
||||
|
||||
Reference in New Issue
Block a user