mirror of
https://github.com/alibaba/higress.git
synced 2026-02-28 06:30:49 +08:00
[ai-proxy] vertex image edits & variations (#3536)
This commit is contained in:
@@ -225,9 +225,9 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
}
|
||||
}
|
||||
|
||||
if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !strings.Contains(contentType, util.MimeTypeApplicationJson) {
|
||||
if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !isSupportedRequestContentType(apiName, contentType) {
|
||||
ctx.DontReadRequestBody()
|
||||
log.Debugf("[onHttpRequestHeader] unsupported content type: %s, will not process the request body", contentType)
|
||||
log.Debugf("[onHttpRequestHeader] unsupported content type for api %s: %s, will not process the request body", apiName, contentType)
|
||||
}
|
||||
|
||||
if apiName == "" {
|
||||
@@ -306,6 +306,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
||||
if err == nil {
|
||||
return action
|
||||
}
|
||||
log.Errorf("[onHttpRequestBody] failed to process request body, apiName=%s, err=%v", apiName, err)
|
||||
_ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
|
||||
}
|
||||
return types.ActionContinue
|
||||
@@ -594,3 +595,14 @@ func getApiName(path string) provider.ApiName {
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func isSupportedRequestContentType(apiName provider.ApiName, contentType string) bool {
|
||||
if strings.Contains(contentType, util.MimeTypeApplicationJson) {
|
||||
return true
|
||||
}
|
||||
contentType = strings.ToLower(contentType)
|
||||
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
return apiName == provider.ApiNameImageEdit || apiName == provider.ApiNameImageVariation
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -63,6 +63,54 @@ func Test_getApiName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isSupportedRequestContentType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
apiName provider.ApiName
|
||||
contentType string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "json chat completion",
|
||||
apiName: provider.ApiNameChatCompletion,
|
||||
contentType: "application/json",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "multipart image edit",
|
||||
apiName: provider.ApiNameImageEdit,
|
||||
contentType: "multipart/form-data; boundary=----boundary",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "multipart image variation",
|
||||
apiName: provider.ApiNameImageVariation,
|
||||
contentType: "multipart/form-data; boundary=----boundary",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "multipart chat completion",
|
||||
apiName: provider.ApiNameChatCompletion,
|
||||
contentType: "multipart/form-data; boundary=----boundary",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "text plain image edit",
|
||||
apiName: provider.ApiNameImageEdit,
|
||||
contentType: "text/plain",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isSupportedRequestContentType(tt.apiName, tt.contentType)
|
||||
if got != tt.want {
|
||||
t.Errorf("isSupportedRequestContentType(%v, %q) = %v, want %v", tt.apiName, tt.contentType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAi360(t *testing.T) {
|
||||
test.RunAi360ParseConfigTests(t)
|
||||
test.RunAi360OnHttpRequestHeadersTests(t)
|
||||
@@ -137,6 +185,8 @@ func TestVertex(t *testing.T) {
|
||||
test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
|
||||
test.RunVertexExpressModeImageGenerationRequestBodyTests(t)
|
||||
test.RunVertexExpressModeImageGenerationResponseBodyTests(t)
|
||||
test.RunVertexExpressModeImageEditVariationRequestBodyTests(t)
|
||||
test.RunVertexExpressModeImageEditVariationResponseBodyTests(t)
|
||||
// Vertex Raw 模式测试
|
||||
test.RunVertexRawModeOnHttpRequestHeadersTests(t)
|
||||
test.RunVertexRawModeOnHttpRequestBodyTests(t)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -461,6 +462,122 @@ type imageGenerationRequest struct {
|
||||
Size string `json:"size,omitempty"`
|
||||
}
|
||||
|
||||
type imageInputURL struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
ImageURL *chatMessageContentImageUrl `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
func (i *imageInputURL) UnmarshalJSON(data []byte) error {
|
||||
// Support a plain string payload, e.g. "data:image/png;base64,..."
|
||||
var rawURL string
|
||||
if err := json.Unmarshal(data, &rawURL); err == nil {
|
||||
i.URL = rawURL
|
||||
i.ImageURL = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
type alias imageInputURL
|
||||
var value alias
|
||||
if err := json.Unmarshal(data, &value); err != nil {
|
||||
return err
|
||||
}
|
||||
*i = imageInputURL(value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *imageInputURL) GetURL() string {
|
||||
if i == nil {
|
||||
return ""
|
||||
}
|
||||
if i.ImageURL != nil && i.ImageURL.Url != "" {
|
||||
return i.ImageURL.Url
|
||||
}
|
||||
return i.URL
|
||||
}
|
||||
|
||||
type imageEditRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Image *imageInputURL `json:"image,omitempty"`
|
||||
Images []imageInputURL `json:"images,omitempty"`
|
||||
ImageURL *imageInputURL `json:"image_url,omitempty"`
|
||||
Mask *imageInputURL `json:"mask,omitempty"`
|
||||
MaskURL *imageInputURL `json:"mask_url,omitempty"`
|
||||
Background string `json:"background,omitempty"`
|
||||
Moderation string `json:"moderation,omitempty"`
|
||||
OutputCompression int `json:"output_compression,omitempty"`
|
||||
OutputFormat string `json:"output_format,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
}
|
||||
|
||||
func (r *imageEditRequest) GetImageURLs() []string {
|
||||
urls := make([]string, 0, len(r.Images)+2)
|
||||
for _, image := range r.Images {
|
||||
if url := image.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
if r.Image != nil {
|
||||
if url := r.Image.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
if r.ImageURL != nil {
|
||||
if url := r.ImageURL.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
return urls
|
||||
}
|
||||
|
||||
func (r *imageEditRequest) HasMask() bool {
|
||||
if r.Mask != nil && r.Mask.GetURL() != "" {
|
||||
return true
|
||||
}
|
||||
return r.MaskURL != nil && r.MaskURL.GetURL() != ""
|
||||
}
|
||||
|
||||
type imageVariationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Image *imageInputURL `json:"image,omitempty"`
|
||||
Images []imageInputURL `json:"images,omitempty"`
|
||||
ImageURL *imageInputURL `json:"image_url,omitempty"`
|
||||
Background string `json:"background,omitempty"`
|
||||
Moderation string `json:"moderation,omitempty"`
|
||||
OutputCompression int `json:"output_compression,omitempty"`
|
||||
OutputFormat string `json:"output_format,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
}
|
||||
|
||||
func (r *imageVariationRequest) GetImageURLs() []string {
|
||||
urls := make([]string, 0, len(r.Images)+2)
|
||||
for _, image := range r.Images {
|
||||
if url := image.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
if r.Image != nil {
|
||||
if url := r.Image.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
if r.ImageURL != nil {
|
||||
if url := r.ImageURL.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
return urls
|
||||
}
|
||||
|
||||
type imageGenerationData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64 string `json:"b64_json,omitempty"`
|
||||
|
||||
156
plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go
Normal file
156
plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type multipartImageRequest struct {
|
||||
Model string
|
||||
Prompt string
|
||||
Size string
|
||||
OutputFormat string
|
||||
N int
|
||||
ImageURLs []string
|
||||
HasMask bool
|
||||
}
|
||||
|
||||
func isMultipartFormData(contentType string) bool {
|
||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(mediaType, "multipart/form-data")
|
||||
}
|
||||
|
||||
func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) {
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse content-type: %v", err)
|
||||
}
|
||||
boundary := params["boundary"]
|
||||
if boundary == "" {
|
||||
return nil, fmt.Errorf("missing multipart boundary")
|
||||
}
|
||||
|
||||
req := &multipartImageRequest{
|
||||
ImageURLs: make([]string, 0),
|
||||
}
|
||||
reader := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read multipart part: %v", err)
|
||||
}
|
||||
fieldName := part.FormName()
|
||||
if fieldName == "" {
|
||||
_ = part.Close()
|
||||
continue
|
||||
}
|
||||
partContentType := strings.TrimSpace(part.Header.Get("Content-Type"))
|
||||
|
||||
partData, err := io.ReadAll(part)
|
||||
_ = part.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read multipart field %s: %v", fieldName, err)
|
||||
}
|
||||
|
||||
value := strings.TrimSpace(string(partData))
|
||||
switch fieldName {
|
||||
case "model":
|
||||
req.Model = value
|
||||
continue
|
||||
case "prompt":
|
||||
req.Prompt = value
|
||||
continue
|
||||
case "size":
|
||||
req.Size = value
|
||||
continue
|
||||
case "output_format":
|
||||
req.OutputFormat = value
|
||||
continue
|
||||
case "n":
|
||||
if value != "" {
|
||||
if parsed, err := strconv.Atoi(value); err == nil {
|
||||
req.N = parsed
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if isMultipartImageField(fieldName) {
|
||||
if isMultipartImageURLValue(value) {
|
||||
req.ImageURLs = append(req.ImageURLs, value)
|
||||
continue
|
||||
}
|
||||
if len(partData) == 0 {
|
||||
continue
|
||||
}
|
||||
imageURL := buildMultipartDataURL(partContentType, partData)
|
||||
req.ImageURLs = append(req.ImageURLs, imageURL)
|
||||
continue
|
||||
}
|
||||
if isMultipartMaskField(fieldName) {
|
||||
if len(partData) > 0 || value != "" {
|
||||
req.HasMask = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func isMultipartImageField(fieldName string) bool {
|
||||
return fieldName == "image" || fieldName == "image[]" || strings.HasPrefix(fieldName, "image[")
|
||||
}
|
||||
|
||||
func isMultipartMaskField(fieldName string) bool {
|
||||
return fieldName == "mask" || fieldName == "mask[]" || strings.HasPrefix(fieldName, "mask[")
|
||||
}
|
||||
|
||||
func isMultipartImageURLValue(value string) bool {
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
loweredValue := strings.ToLower(value)
|
||||
return strings.HasPrefix(loweredValue, "data:") || strings.HasPrefix(loweredValue, "http://") || strings.HasPrefix(loweredValue, "https://")
|
||||
}
|
||||
|
||||
func buildMultipartDataURL(contentType string, data []byte) string {
|
||||
mimeType := strings.TrimSpace(contentType)
|
||||
if mimeType == "" || strings.EqualFold(mimeType, "application/octet-stream") {
|
||||
mimeType = http.DetectContentType(data)
|
||||
}
|
||||
mimeType = normalizeMultipartMimeType(mimeType)
|
||||
if mimeType == "" {
|
||||
mimeType = "application/octet-stream"
|
||||
}
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
return fmt.Sprintf("data:%s;base64,%s", mimeType, encoded)
|
||||
}
|
||||
|
||||
func normalizeMultipartMimeType(contentType string) string {
|
||||
contentType = strings.TrimSpace(contentType)
|
||||
if contentType == "" {
|
||||
return ""
|
||||
}
|
||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||
if err == nil && mediaType != "" {
|
||||
return strings.TrimSpace(mediaType)
|
||||
}
|
||||
if idx := strings.Index(contentType, ";"); idx > 0 {
|
||||
return strings.TrimSpace(contentType[:idx])
|
||||
}
|
||||
return contentType
|
||||
}
|
||||
@@ -763,19 +763,19 @@ func (c *ProviderConfig) GetRandomToken() string {
|
||||
func isStatefulAPI(apiName string) bool {
|
||||
// These APIs maintain session state and should be routed to the same provider consistently
|
||||
statefulAPIs := map[string]bool{
|
||||
string(ApiNameResponses): true, // Response API - uses previous_response_id
|
||||
string(ApiNameFiles): true, // Files API - maintains file state
|
||||
string(ApiNameRetrieveFile): true, // File retrieval - depends on file upload
|
||||
string(ApiNameRetrieveFileContent): true, // File content - depends on file upload
|
||||
string(ApiNameBatches): true, // Batch API - maintains batch state
|
||||
string(ApiNameRetrieveBatch): true, // Batch status - depends on batch creation
|
||||
string(ApiNameCancelBatch): true, // Batch operations - depends on batch state
|
||||
string(ApiNameFineTuningJobs): true, // Fine-tuning - maintains job state
|
||||
string(ApiNameRetrieveFineTuningJob): true, // Fine-tuning job status
|
||||
string(ApiNameFineTuningJobEvents): true, // Fine-tuning events
|
||||
string(ApiNameFineTuningJobCheckpoints): true, // Fine-tuning checkpoints
|
||||
string(ApiNameCancelFineTuningJob): true, // Cancel fine-tuning job
|
||||
string(ApiNameResumeFineTuningJob): true, // Resume fine-tuning job
|
||||
string(ApiNameResponses): true, // Response API - uses previous_response_id
|
||||
string(ApiNameFiles): true, // Files API - maintains file state
|
||||
string(ApiNameRetrieveFile): true, // File retrieval - depends on file upload
|
||||
string(ApiNameRetrieveFileContent): true, // File content - depends on file upload
|
||||
string(ApiNameBatches): true, // Batch API - maintains batch state
|
||||
string(ApiNameRetrieveBatch): true, // Batch status - depends on batch creation
|
||||
string(ApiNameCancelBatch): true, // Batch operations - depends on batch state
|
||||
string(ApiNameFineTuningJobs): true, // Fine-tuning - maintains job state
|
||||
string(ApiNameRetrieveFineTuningJob): true, // Fine-tuning job status
|
||||
string(ApiNameFineTuningJobEvents): true, // Fine-tuning events
|
||||
string(ApiNameFineTuningJobCheckpoints): true, // Fine-tuning checkpoints
|
||||
string(ApiNameCancelFineTuningJob): true, // Cancel fine-tuning job
|
||||
string(ApiNameResumeFineTuningJob): true, // Resume fine-tuning job
|
||||
}
|
||||
return statefulAPIs[apiName]
|
||||
}
|
||||
@@ -845,6 +845,16 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
|
||||
return err
|
||||
}
|
||||
return c.setRequestModel(ctx, req)
|
||||
case *imageEditRequest:
|
||||
if err := decodeImageEditRequest(body, req); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.setRequestModel(ctx, req)
|
||||
case *imageVariationRequest:
|
||||
if err := decodeImageVariationRequest(body, req); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.setRequestModel(ctx, req)
|
||||
default:
|
||||
return errors.New("unsupported request type")
|
||||
}
|
||||
@@ -860,6 +870,10 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf
|
||||
model = &req.Model
|
||||
case *imageGenerationRequest:
|
||||
model = &req.Model
|
||||
case *imageEditRequest:
|
||||
model = &req.Model
|
||||
case *imageVariationRequest:
|
||||
model = &req.Model
|
||||
default:
|
||||
return errors.New("unsupported request type")
|
||||
}
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
)
|
||||
|
||||
func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) error {
|
||||
@@ -32,6 +32,20 @@ func decodeImageGenerationRequest(body []byte, request *imageGenerationRequest)
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeImageEditRequest(body []byte, request *imageEditRequest) error {
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeImageVariationRequest(body []byte, request *imageVariationRequest) error {
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func replaceJsonRequestBody(request interface{}) error {
|
||||
body, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
|
||||
@@ -46,6 +46,7 @@ const (
|
||||
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
|
||||
contextVertexRawMarker = "isVertexRawRequest"
|
||||
vertexAnthropicVersion = "vertex-2023-10-16"
|
||||
vertexImageVariationDefaultPrompt = "Create variations of the provided image."
|
||||
)
|
||||
|
||||
// vertexRawPathRegex 匹配原生 Vertex AI REST API 路径
|
||||
@@ -98,6 +99,8 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
string(ApiNameChatCompletion): vertexPathTemplate,
|
||||
string(ApiNameEmbeddings): vertexPathTemplate,
|
||||
string(ApiNameImageGeneration): vertexPathTemplate,
|
||||
string(ApiNameImageEdit): vertexPathTemplate,
|
||||
string(ApiNameImageVariation): vertexPathTemplate,
|
||||
string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换
|
||||
}
|
||||
}
|
||||
@@ -307,6 +310,10 @@ func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap
|
||||
return v.onEmbeddingsRequestBody(ctx, body, headers)
|
||||
case ApiNameImageGeneration:
|
||||
return v.onImageGenerationRequestBody(ctx, body, headers)
|
||||
case ApiNameImageEdit:
|
||||
return v.onImageEditRequestBody(ctx, body, headers)
|
||||
case ApiNameImageVariation:
|
||||
return v.onImageVariationRequestBody(ctx, body, headers)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
@@ -387,11 +394,108 @@ func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, b
|
||||
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest := v.buildVertexImageGenerationRequest(request)
|
||||
vertexRequest, err := v.buildVertexImageGenerationRequest(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest {
|
||||
func (v *vertexProvider) onImageEditRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &imageEditRequest{}
|
||||
imageURLs := make([]string, 0)
|
||||
contentType := headers.Get("Content-Type")
|
||||
if isMultipartFormData(contentType) {
|
||||
parsedRequest, err := parseMultipartImageRequest(body, contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Model = parsedRequest.Model
|
||||
request.Prompt = parsedRequest.Prompt
|
||||
request.Size = parsedRequest.Size
|
||||
request.OutputFormat = parsedRequest.OutputFormat
|
||||
request.N = parsedRequest.N
|
||||
imageURLs = parsedRequest.ImageURLs
|
||||
if err := v.config.mapModel(ctx, &request.Model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsedRequest.HasMask {
|
||||
return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
|
||||
}
|
||||
} else {
|
||||
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if request.HasMask() {
|
||||
return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
|
||||
}
|
||||
imageURLs = request.GetImageURLs()
|
||||
}
|
||||
if len(imageURLs) == 0 {
|
||||
return nil, fmt.Errorf("missing image_url in request")
|
||||
}
|
||||
if request.Prompt == "" {
|
||||
return nil, fmt.Errorf("missing prompt in request")
|
||||
}
|
||||
|
||||
path := v.getRequestPath(ApiNameImageEdit, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
headers.Set("Content-Type", util.MimeTypeApplicationJson)
|
||||
vertexRequest, err := v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, imageURLs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onImageVariationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &imageVariationRequest{}
|
||||
imageURLs := make([]string, 0)
|
||||
contentType := headers.Get("Content-Type")
|
||||
if isMultipartFormData(contentType) {
|
||||
parsedRequest, err := parseMultipartImageRequest(body, contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Model = parsedRequest.Model
|
||||
request.Prompt = parsedRequest.Prompt
|
||||
request.Size = parsedRequest.Size
|
||||
request.OutputFormat = parsedRequest.OutputFormat
|
||||
request.N = parsedRequest.N
|
||||
imageURLs = parsedRequest.ImageURLs
|
||||
if err := v.config.mapModel(ctx, &request.Model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
imageURLs = request.GetImageURLs()
|
||||
}
|
||||
if len(imageURLs) == 0 {
|
||||
return nil, fmt.Errorf("missing image_url in request")
|
||||
}
|
||||
|
||||
prompt := request.Prompt
|
||||
if prompt == "" {
|
||||
prompt = vertexImageVariationDefaultPrompt
|
||||
}
|
||||
|
||||
path := v.getRequestPath(ApiNameImageVariation, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
headers.Set("Content-Type", util.MimeTypeApplicationJson)
|
||||
vertexRequest, err := v.buildVertexImageRequest(prompt, request.Size, request.OutputFormat, imageURLs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) (*vertexChatRequest, error) {
|
||||
return v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, nil)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexImageRequest(prompt string, size string, outputFormat string, imageURLs []string) (*vertexChatRequest, error) {
|
||||
// 构建安全设置
|
||||
safetySettings := make([]vertexChatSafetySetting, 0)
|
||||
for category, threshold := range v.config.geminiSafetySetting {
|
||||
@@ -402,12 +506,12 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
||||
}
|
||||
|
||||
// 解析尺寸参数
|
||||
aspectRatio, imageSize := v.parseImageSize(request.Size)
|
||||
aspectRatio, imageSize := v.parseImageSize(size)
|
||||
|
||||
// 确定输出 MIME 类型
|
||||
mimeType := "image/png"
|
||||
if request.OutputFormat != "" {
|
||||
switch request.OutputFormat {
|
||||
if outputFormat != "" {
|
||||
switch outputFormat {
|
||||
case "jpeg", "jpg":
|
||||
mimeType = "image/jpeg"
|
||||
case "webp":
|
||||
@@ -417,12 +521,27 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
||||
}
|
||||
}
|
||||
|
||||
parts := make([]vertexPart, 0, len(imageURLs)+1)
|
||||
for _, imageURL := range imageURLs {
|
||||
part, err := convertMediaContent(imageURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
if prompt != "" {
|
||||
parts = append(parts, vertexPart{
|
||||
Text: prompt,
|
||||
})
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return nil, fmt.Errorf("missing prompt and image_url in request")
|
||||
}
|
||||
|
||||
vertexRequest := &vertexChatRequest{
|
||||
Contents: []vertexChatContent{{
|
||||
Role: roleUser,
|
||||
Parts: []vertexPart{{
|
||||
Text: request.Prompt,
|
||||
}},
|
||||
Role: roleUser,
|
||||
Parts: parts,
|
||||
}},
|
||||
SafetySettings: safetySettings,
|
||||
GenerationConfig: vertexChatGenerationConfig{
|
||||
@@ -440,7 +559,7 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
||||
},
|
||||
}
|
||||
|
||||
return vertexRequest
|
||||
return vertexRequest, nil
|
||||
}
|
||||
|
||||
// parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize
|
||||
@@ -553,7 +672,7 @@ func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
|
||||
return v.onChatCompletionResponseBody(ctx, body)
|
||||
case ApiNameEmbeddings:
|
||||
return v.onEmbeddingsResponseBody(ctx, body)
|
||||
case ApiNameImageGeneration:
|
||||
case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
|
||||
return v.onImageGenerationResponseBody(ctx, body)
|
||||
default:
|
||||
return body, nil
|
||||
@@ -784,7 +903,7 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
|
||||
switch apiName {
|
||||
case ApiNameEmbeddings:
|
||||
action = vertexEmbeddingAction
|
||||
case ApiNameImageGeneration:
|
||||
case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
|
||||
// 图片生成使用非流式端点,需要完整响应
|
||||
action = vertexChatCompletionAction
|
||||
default:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -1273,6 +1275,324 @@ func RunVertexExpressModeImageGenerationResponseBodyTests(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func buildMultipartRequestBody(t *testing.T, fields map[string]string, files map[string][]byte) ([]byte, string) {
|
||||
var buffer bytes.Buffer
|
||||
writer := multipart.NewWriter(&buffer)
|
||||
|
||||
for key, value := range fields {
|
||||
require.NoError(t, writer.WriteField(key, value))
|
||||
}
|
||||
|
||||
for fieldName, data := range files {
|
||||
part, err := writer.CreateFormFile(fieldName, "upload-image.png")
|
||||
require.NoError(t, err)
|
||||
_, err = part.Write(data)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.NoError(t, writer.Close())
|
||||
return buffer.Bytes(), writer.FormDataContentType()
|
||||
}
|
||||
|
||||
func RunVertexExpressModeImageEditVariationRequestBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
const testDataURL = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
|
||||
t.Run("vertex express mode image edit request body with image_url", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/edits"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add sunglasses to the cat","image":{"image_url":{"url":"` + testDataURL + `"}},"size":"1024x1024"}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
bodyStr := string(processedBody)
|
||||
require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image_url")
|
||||
require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
|
||||
require.NotContains(t, bodyStr, "image_url", "OpenAI image_url field should be converted to Vertex format")
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathHeader := ""
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Contains(t, pathHeader, "generateContent", "Image edit should use generateContent action")
|
||||
require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key")
|
||||
})
|
||||
|
||||
t.Run("vertex express mode image edit request body with image string", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/edits"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add sunglasses to the cat","image":"` + testDataURL + `","size":"1024x1024"}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
bodyStr := string(processedBody)
|
||||
require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image string")
|
||||
require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
|
||||
})
|
||||
|
||||
t.Run("vertex express mode image edit multipart request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
body, contentType := buildMultipartRequestBody(t, map[string]string{
|
||||
"model": "gemini-2.0-flash-exp",
|
||||
"prompt": "Add sunglasses to the cat",
|
||||
"size": "1024x1024",
|
||||
}, map[string][]byte{
|
||||
"image": []byte("fake-image-content"),
|
||||
})
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/edits"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", contentType},
|
||||
})
|
||||
|
||||
action := host.CallOnHttpRequestBody(body)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
bodyStr := string(processedBody)
|
||||
require.Contains(t, bodyStr, "inlineData", "Multipart image should be converted to inlineData")
|
||||
require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.True(t, test.HasHeaderWithValue(requestHeaders, "Content-Type", "application/json"), "Content-Type should be rewritten to application/json")
|
||||
})
|
||||
|
||||
t.Run("vertex express mode image variation multipart request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
body, contentType := buildMultipartRequestBody(t, map[string]string{
|
||||
"model": "gemini-2.0-flash-exp",
|
||||
"size": "1024x1024",
|
||||
}, map[string][]byte{
|
||||
"image": []byte("fake-image-content"),
|
||||
})
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/variations"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", contentType},
|
||||
})
|
||||
|
||||
action := host.CallOnHttpRequestBody(body)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
bodyStr := string(processedBody)
|
||||
require.Contains(t, bodyStr, "inlineData", "Multipart image should be converted to inlineData")
|
||||
require.Contains(t, bodyStr, "Create variations of the provided image.", "Variation request should inject a default prompt")
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.True(t, test.HasHeaderWithValue(requestHeaders, "Content-Type", "application/json"), "Content-Type should be rewritten to application/json")
|
||||
})
|
||||
|
||||
t.Run("vertex express mode image edit with model mapping", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/edits"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"gpt-4","prompt":"Turn it into watercolor","image_url":{"url":"` + testDataURL + `"}}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathHeader := ""
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Contains(t, pathHeader, "gemini-2.5-flash", "Path should contain mapped model name")
|
||||
})
|
||||
|
||||
t.Run("vertex express mode image variation request body with image_url", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/variations"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"gemini-2.0-flash-exp","image_url":{"url":"` + testDataURL + `"}}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
bodyStr := string(processedBody)
|
||||
require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image_url")
|
||||
require.Contains(t, bodyStr, "Create variations of the provided image.", "Variation request should inject a default prompt")
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathHeader := ""
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Contains(t, pathHeader, "generateContent", "Image variation should use generateContent action")
|
||||
require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunVertexExpressModeImageEditVariationResponseBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
const testDataURL = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
|
||||
t.Run("vertex express mode image edit response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/edits"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add glasses","image_url":{"url":"` + testDataURL + `"}}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
responseBody := `{
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"role": "model",
|
||||
"parts": [{
|
||||
"inlineData": {
|
||||
"mimeType": "image/png",
|
||||
"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
}
|
||||
}]
|
||||
}
|
||||
}],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 12,
|
||||
"candidatesTokenCount": 1024,
|
||||
"totalTokenCount": 1036
|
||||
}
|
||||
}`
|
||||
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, processedResponseBody)
|
||||
|
||||
responseStr := string(processedResponseBody)
|
||||
require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field")
|
||||
require.Contains(t, responseStr, "usage", "Response should contain usage field")
|
||||
})
|
||||
|
||||
t.Run("vertex express mode image variation response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/variations"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"gemini-2.0-flash-exp","image_url":{"url":"` + testDataURL + `"}}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
responseBody := `{
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"role": "model",
|
||||
"parts": [{
|
||||
"inlineData": {
|
||||
"mimeType": "image/png",
|
||||
"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
}
|
||||
}]
|
||||
}
|
||||
}],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 8,
|
||||
"candidatesTokenCount": 768,
|
||||
"totalTokenCount": 776
|
||||
}
|
||||
}`
|
||||
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, processedResponseBody)
|
||||
|
||||
responseStr := string(processedResponseBody)
|
||||
require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field")
|
||||
require.Contains(t, responseStr, "usage", "Response should contain usage field")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== Vertex Raw 模式测试 ====================
|
||||
|
||||
func RunVertexRawModeOnHttpRequestHeadersTests(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user