mirror of
https://github.com/alibaba/higress.git
synced 2026-02-28 06:30:49 +08:00
Compare commits
2 Commits
add-releas
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e2a22d1171 | ||
|
|
e9aecb6e1f |
@@ -225,9 +225,9 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !strings.Contains(contentType, util.MimeTypeApplicationJson) {
|
if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !isSupportedRequestContentType(apiName, contentType) {
|
||||||
ctx.DontReadRequestBody()
|
ctx.DontReadRequestBody()
|
||||||
log.Debugf("[onHttpRequestHeader] unsupported content type: %s, will not process the request body", contentType)
|
log.Debugf("[onHttpRequestHeader] unsupported content type for api %s: %s, will not process the request body", apiName, contentType)
|
||||||
}
|
}
|
||||||
|
|
||||||
if apiName == "" {
|
if apiName == "" {
|
||||||
@@ -306,6 +306,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return action
|
return action
|
||||||
}
|
}
|
||||||
|
log.Errorf("[onHttpRequestBody] failed to process request body, apiName=%s, err=%v", apiName, err)
|
||||||
_ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
|
_ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
|
||||||
}
|
}
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
@@ -594,3 +595,14 @@ func getApiName(path string) provider.ApiName {
|
|||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isSupportedRequestContentType(apiName provider.ApiName, contentType string) bool {
|
||||||
|
if strings.Contains(contentType, util.MimeTypeApplicationJson) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
contentType = strings.ToLower(contentType)
|
||||||
|
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||||
|
return apiName == provider.ApiNameImageEdit || apiName == provider.ApiNameImageVariation
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -63,6 +63,54 @@ func Test_getApiName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_isSupportedRequestContentType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
apiName provider.ApiName
|
||||||
|
contentType string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "json chat completion",
|
||||||
|
apiName: provider.ApiNameChatCompletion,
|
||||||
|
contentType: "application/json",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multipart image edit",
|
||||||
|
apiName: provider.ApiNameImageEdit,
|
||||||
|
contentType: "multipart/form-data; boundary=----boundary",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multipart image variation",
|
||||||
|
apiName: provider.ApiNameImageVariation,
|
||||||
|
contentType: "multipart/form-data; boundary=----boundary",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multipart chat completion",
|
||||||
|
apiName: provider.ApiNameChatCompletion,
|
||||||
|
contentType: "multipart/form-data; boundary=----boundary",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text plain image edit",
|
||||||
|
apiName: provider.ApiNameImageEdit,
|
||||||
|
contentType: "text/plain",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := isSupportedRequestContentType(tt.apiName, tt.contentType)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("isSupportedRequestContentType(%v, %q) = %v, want %v", tt.apiName, tt.contentType, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAi360(t *testing.T) {
|
func TestAi360(t *testing.T) {
|
||||||
test.RunAi360ParseConfigTests(t)
|
test.RunAi360ParseConfigTests(t)
|
||||||
test.RunAi360OnHttpRequestHeadersTests(t)
|
test.RunAi360OnHttpRequestHeadersTests(t)
|
||||||
@@ -137,6 +185,8 @@ func TestVertex(t *testing.T) {
|
|||||||
test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
|
test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
|
||||||
test.RunVertexExpressModeImageGenerationRequestBodyTests(t)
|
test.RunVertexExpressModeImageGenerationRequestBodyTests(t)
|
||||||
test.RunVertexExpressModeImageGenerationResponseBodyTests(t)
|
test.RunVertexExpressModeImageGenerationResponseBodyTests(t)
|
||||||
|
test.RunVertexExpressModeImageEditVariationRequestBodyTests(t)
|
||||||
|
test.RunVertexExpressModeImageEditVariationResponseBodyTests(t)
|
||||||
// Vertex Raw 模式测试
|
// Vertex Raw 模式测试
|
||||||
test.RunVertexRawModeOnHttpRequestHeadersTests(t)
|
test.RunVertexRawModeOnHttpRequestHeadersTests(t)
|
||||||
test.RunVertexRawModeOnHttpRequestBodyTests(t)
|
test.RunVertexRawModeOnHttpRequestBodyTests(t)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package provider
|
package provider
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -461,6 +462,122 @@ type imageGenerationRequest struct {
|
|||||||
Size string `json:"size,omitempty"`
|
Size string `json:"size,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type imageInputURL struct {
|
||||||
|
URL string `json:"url,omitempty"`
|
||||||
|
ImageURL *chatMessageContentImageUrl `json:"image_url,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *imageInputURL) UnmarshalJSON(data []byte) error {
|
||||||
|
// Support a plain string payload, e.g. "data:image/png;base64,..."
|
||||||
|
var rawURL string
|
||||||
|
if err := json.Unmarshal(data, &rawURL); err == nil {
|
||||||
|
i.URL = rawURL
|
||||||
|
i.ImageURL = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type alias imageInputURL
|
||||||
|
var value alias
|
||||||
|
if err := json.Unmarshal(data, &value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*i = imageInputURL(value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *imageInputURL) GetURL() string {
|
||||||
|
if i == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if i.ImageURL != nil && i.ImageURL.Url != "" {
|
||||||
|
return i.ImageURL.Url
|
||||||
|
}
|
||||||
|
return i.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
type imageEditRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Image *imageInputURL `json:"image,omitempty"`
|
||||||
|
Images []imageInputURL `json:"images,omitempty"`
|
||||||
|
ImageURL *imageInputURL `json:"image_url,omitempty"`
|
||||||
|
Mask *imageInputURL `json:"mask,omitempty"`
|
||||||
|
MaskURL *imageInputURL `json:"mask_url,omitempty"`
|
||||||
|
Background string `json:"background,omitempty"`
|
||||||
|
Moderation string `json:"moderation,omitempty"`
|
||||||
|
OutputCompression int `json:"output_compression,omitempty"`
|
||||||
|
OutputFormat string `json:"output_format,omitempty"`
|
||||||
|
Quality string `json:"quality,omitempty"`
|
||||||
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
|
Style string `json:"style,omitempty"`
|
||||||
|
N int `json:"n,omitempty"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *imageEditRequest) GetImageURLs() []string {
|
||||||
|
urls := make([]string, 0, len(r.Images)+2)
|
||||||
|
for _, image := range r.Images {
|
||||||
|
if url := image.GetURL(); url != "" {
|
||||||
|
urls = append(urls, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.Image != nil {
|
||||||
|
if url := r.Image.GetURL(); url != "" {
|
||||||
|
urls = append(urls, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.ImageURL != nil {
|
||||||
|
if url := r.ImageURL.GetURL(); url != "" {
|
||||||
|
urls = append(urls, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return urls
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *imageEditRequest) HasMask() bool {
|
||||||
|
if r.Mask != nil && r.Mask.GetURL() != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return r.MaskURL != nil && r.MaskURL.GetURL() != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type imageVariationRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
Image *imageInputURL `json:"image,omitempty"`
|
||||||
|
Images []imageInputURL `json:"images,omitempty"`
|
||||||
|
ImageURL *imageInputURL `json:"image_url,omitempty"`
|
||||||
|
Background string `json:"background,omitempty"`
|
||||||
|
Moderation string `json:"moderation,omitempty"`
|
||||||
|
OutputCompression int `json:"output_compression,omitempty"`
|
||||||
|
OutputFormat string `json:"output_format,omitempty"`
|
||||||
|
Quality string `json:"quality,omitempty"`
|
||||||
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
|
Style string `json:"style,omitempty"`
|
||||||
|
N int `json:"n,omitempty"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *imageVariationRequest) GetImageURLs() []string {
|
||||||
|
urls := make([]string, 0, len(r.Images)+2)
|
||||||
|
for _, image := range r.Images {
|
||||||
|
if url := image.GetURL(); url != "" {
|
||||||
|
urls = append(urls, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.Image != nil {
|
||||||
|
if url := r.Image.GetURL(); url != "" {
|
||||||
|
urls = append(urls, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.ImageURL != nil {
|
||||||
|
if url := r.ImageURL.GetURL(); url != "" {
|
||||||
|
urls = append(urls, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return urls
|
||||||
|
}
|
||||||
|
|
||||||
type imageGenerationData struct {
|
type imageGenerationData struct {
|
||||||
URL string `json:"url,omitempty"`
|
URL string `json:"url,omitempty"`
|
||||||
B64 string `json:"b64_json,omitempty"`
|
B64 string `json:"b64_json,omitempty"`
|
||||||
|
|||||||
156
plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go
Normal file
156
plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type multipartImageRequest struct {
|
||||||
|
Model string
|
||||||
|
Prompt string
|
||||||
|
Size string
|
||||||
|
OutputFormat string
|
||||||
|
N int
|
||||||
|
ImageURLs []string
|
||||||
|
HasMask bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func isMultipartFormData(contentType string) bool {
|
||||||
|
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.EqualFold(mediaType, "multipart/form-data")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) {
|
||||||
|
_, params, err := mime.ParseMediaType(contentType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to parse content-type: %v", err)
|
||||||
|
}
|
||||||
|
boundary := params["boundary"]
|
||||||
|
if boundary == "" {
|
||||||
|
return nil, fmt.Errorf("missing multipart boundary")
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &multipartImageRequest{
|
||||||
|
ImageURLs: make([]string, 0),
|
||||||
|
}
|
||||||
|
reader := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||||
|
for {
|
||||||
|
part, err := reader.NextPart()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to read multipart part: %v", err)
|
||||||
|
}
|
||||||
|
fieldName := part.FormName()
|
||||||
|
if fieldName == "" {
|
||||||
|
_ = part.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
partContentType := strings.TrimSpace(part.Header.Get("Content-Type"))
|
||||||
|
|
||||||
|
partData, err := io.ReadAll(part)
|
||||||
|
_ = part.Close()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to read multipart field %s: %v", fieldName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
value := strings.TrimSpace(string(partData))
|
||||||
|
switch fieldName {
|
||||||
|
case "model":
|
||||||
|
req.Model = value
|
||||||
|
continue
|
||||||
|
case "prompt":
|
||||||
|
req.Prompt = value
|
||||||
|
continue
|
||||||
|
case "size":
|
||||||
|
req.Size = value
|
||||||
|
continue
|
||||||
|
case "output_format":
|
||||||
|
req.OutputFormat = value
|
||||||
|
continue
|
||||||
|
case "n":
|
||||||
|
if value != "" {
|
||||||
|
if parsed, err := strconv.Atoi(value); err == nil {
|
||||||
|
req.N = parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if isMultipartImageField(fieldName) {
|
||||||
|
if isMultipartImageURLValue(value) {
|
||||||
|
req.ImageURLs = append(req.ImageURLs, value)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(partData) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
imageURL := buildMultipartDataURL(partContentType, partData)
|
||||||
|
req.ImageURLs = append(req.ImageURLs, imageURL)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isMultipartMaskField(fieldName) {
|
||||||
|
if len(partData) > 0 || value != "" {
|
||||||
|
req.HasMask = true
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isMultipartImageField(fieldName string) bool {
|
||||||
|
return fieldName == "image" || fieldName == "image[]" || strings.HasPrefix(fieldName, "image[")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isMultipartMaskField(fieldName string) bool {
|
||||||
|
return fieldName == "mask" || fieldName == "mask[]" || strings.HasPrefix(fieldName, "mask[")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isMultipartImageURLValue(value string) bool {
|
||||||
|
if value == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
loweredValue := strings.ToLower(value)
|
||||||
|
return strings.HasPrefix(loweredValue, "data:") || strings.HasPrefix(loweredValue, "http://") || strings.HasPrefix(loweredValue, "https://")
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildMultipartDataURL(contentType string, data []byte) string {
|
||||||
|
mimeType := strings.TrimSpace(contentType)
|
||||||
|
if mimeType == "" || strings.EqualFold(mimeType, "application/octet-stream") {
|
||||||
|
mimeType = http.DetectContentType(data)
|
||||||
|
}
|
||||||
|
mimeType = normalizeMultipartMimeType(mimeType)
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = "application/octet-stream"
|
||||||
|
}
|
||||||
|
encoded := base64.StdEncoding.EncodeToString(data)
|
||||||
|
return fmt.Sprintf("data:%s;base64,%s", mimeType, encoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeMultipartMimeType(contentType string) string {
|
||||||
|
contentType = strings.TrimSpace(contentType)
|
||||||
|
if contentType == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||||
|
if err == nil && mediaType != "" {
|
||||||
|
return strings.TrimSpace(mediaType)
|
||||||
|
}
|
||||||
|
if idx := strings.Index(contentType, ";"); idx > 0 {
|
||||||
|
return strings.TrimSpace(contentType[:idx])
|
||||||
|
}
|
||||||
|
return contentType
|
||||||
|
}
|
||||||
@@ -763,19 +763,19 @@ func (c *ProviderConfig) GetRandomToken() string {
|
|||||||
func isStatefulAPI(apiName string) bool {
|
func isStatefulAPI(apiName string) bool {
|
||||||
// These APIs maintain session state and should be routed to the same provider consistently
|
// These APIs maintain session state and should be routed to the same provider consistently
|
||||||
statefulAPIs := map[string]bool{
|
statefulAPIs := map[string]bool{
|
||||||
string(ApiNameResponses): true, // Response API - uses previous_response_id
|
string(ApiNameResponses): true, // Response API - uses previous_response_id
|
||||||
string(ApiNameFiles): true, // Files API - maintains file state
|
string(ApiNameFiles): true, // Files API - maintains file state
|
||||||
string(ApiNameRetrieveFile): true, // File retrieval - depends on file upload
|
string(ApiNameRetrieveFile): true, // File retrieval - depends on file upload
|
||||||
string(ApiNameRetrieveFileContent): true, // File content - depends on file upload
|
string(ApiNameRetrieveFileContent): true, // File content - depends on file upload
|
||||||
string(ApiNameBatches): true, // Batch API - maintains batch state
|
string(ApiNameBatches): true, // Batch API - maintains batch state
|
||||||
string(ApiNameRetrieveBatch): true, // Batch status - depends on batch creation
|
string(ApiNameRetrieveBatch): true, // Batch status - depends on batch creation
|
||||||
string(ApiNameCancelBatch): true, // Batch operations - depends on batch state
|
string(ApiNameCancelBatch): true, // Batch operations - depends on batch state
|
||||||
string(ApiNameFineTuningJobs): true, // Fine-tuning - maintains job state
|
string(ApiNameFineTuningJobs): true, // Fine-tuning - maintains job state
|
||||||
string(ApiNameRetrieveFineTuningJob): true, // Fine-tuning job status
|
string(ApiNameRetrieveFineTuningJob): true, // Fine-tuning job status
|
||||||
string(ApiNameFineTuningJobEvents): true, // Fine-tuning events
|
string(ApiNameFineTuningJobEvents): true, // Fine-tuning events
|
||||||
string(ApiNameFineTuningJobCheckpoints): true, // Fine-tuning checkpoints
|
string(ApiNameFineTuningJobCheckpoints): true, // Fine-tuning checkpoints
|
||||||
string(ApiNameCancelFineTuningJob): true, // Cancel fine-tuning job
|
string(ApiNameCancelFineTuningJob): true, // Cancel fine-tuning job
|
||||||
string(ApiNameResumeFineTuningJob): true, // Resume fine-tuning job
|
string(ApiNameResumeFineTuningJob): true, // Resume fine-tuning job
|
||||||
}
|
}
|
||||||
return statefulAPIs[apiName]
|
return statefulAPIs[apiName]
|
||||||
}
|
}
|
||||||
@@ -845,6 +845,16 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.setRequestModel(ctx, req)
|
return c.setRequestModel(ctx, req)
|
||||||
|
case *imageEditRequest:
|
||||||
|
if err := decodeImageEditRequest(body, req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.setRequestModel(ctx, req)
|
||||||
|
case *imageVariationRequest:
|
||||||
|
if err := decodeImageVariationRequest(body, req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.setRequestModel(ctx, req)
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported request type")
|
return errors.New("unsupported request type")
|
||||||
}
|
}
|
||||||
@@ -860,6 +870,10 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf
|
|||||||
model = &req.Model
|
model = &req.Model
|
||||||
case *imageGenerationRequest:
|
case *imageGenerationRequest:
|
||||||
model = &req.Model
|
model = &req.Model
|
||||||
|
case *imageEditRequest:
|
||||||
|
model = &req.Model
|
||||||
|
case *imageVariationRequest:
|
||||||
|
model = &req.Model
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported request type")
|
return errors.New("unsupported request type")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ const (
|
|||||||
qwenCompatibleRetrieveBatchPath = "/compatible-mode/v1/batches/{batch_id}"
|
qwenCompatibleRetrieveBatchPath = "/compatible-mode/v1/batches/{batch_id}"
|
||||||
qwenBailianPath = "/api/v1/apps"
|
qwenBailianPath = "/api/v1/apps"
|
||||||
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
|
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
|
||||||
qwenAnthropicMessagesPath = "/api/v2/apps/claude-code-proxy/v1/messages"
|
qwenAnthropicMessagesPath = "/apps/anthropic/v1/messages"
|
||||||
|
|
||||||
qwenAsyncAIGCPath = "/api/v1/services/aigc/"
|
qwenAsyncAIGCPath = "/api/v1/services/aigc/"
|
||||||
qwenAsyncTaskPath = "/api/v1/tasks/"
|
qwenAsyncTaskPath = "/api/v1/tasks/"
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/higress-group/wasm-go/pkg/log"
|
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) error {
|
func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) error {
|
||||||
@@ -32,6 +32,20 @@ func decodeImageGenerationRequest(body []byte, request *imageGenerationRequest)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func decodeImageEditRequest(body []byte, request *imageEditRequest) error {
|
||||||
|
if err := json.Unmarshal(body, request); err != nil {
|
||||||
|
return fmt.Errorf("unable to unmarshal request: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeImageVariationRequest(body []byte, request *imageVariationRequest) error {
|
||||||
|
if err := json.Unmarshal(body, request); err != nil {
|
||||||
|
return fmt.Errorf("unable to unmarshal request: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func replaceJsonRequestBody(request interface{}) error {
|
func replaceJsonRequestBody(request interface{}) error {
|
||||||
body, err := json.Marshal(request)
|
body, err := json.Marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ const (
|
|||||||
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
|
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
|
||||||
contextVertexRawMarker = "isVertexRawRequest"
|
contextVertexRawMarker = "isVertexRawRequest"
|
||||||
vertexAnthropicVersion = "vertex-2023-10-16"
|
vertexAnthropicVersion = "vertex-2023-10-16"
|
||||||
|
vertexImageVariationDefaultPrompt = "Create variations of the provided image."
|
||||||
)
|
)
|
||||||
|
|
||||||
// vertexRawPathRegex 匹配原生 Vertex AI REST API 路径
|
// vertexRawPathRegex 匹配原生 Vertex AI REST API 路径
|
||||||
@@ -98,6 +99,8 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
|
|||||||
string(ApiNameChatCompletion): vertexPathTemplate,
|
string(ApiNameChatCompletion): vertexPathTemplate,
|
||||||
string(ApiNameEmbeddings): vertexPathTemplate,
|
string(ApiNameEmbeddings): vertexPathTemplate,
|
||||||
string(ApiNameImageGeneration): vertexPathTemplate,
|
string(ApiNameImageGeneration): vertexPathTemplate,
|
||||||
|
string(ApiNameImageEdit): vertexPathTemplate,
|
||||||
|
string(ApiNameImageVariation): vertexPathTemplate,
|
||||||
string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换
|
string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -307,6 +310,10 @@ func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap
|
|||||||
return v.onEmbeddingsRequestBody(ctx, body, headers)
|
return v.onEmbeddingsRequestBody(ctx, body, headers)
|
||||||
case ApiNameImageGeneration:
|
case ApiNameImageGeneration:
|
||||||
return v.onImageGenerationRequestBody(ctx, body, headers)
|
return v.onImageGenerationRequestBody(ctx, body, headers)
|
||||||
|
case ApiNameImageEdit:
|
||||||
|
return v.onImageEditRequestBody(ctx, body, headers)
|
||||||
|
case ApiNameImageVariation:
|
||||||
|
return v.onImageVariationRequestBody(ctx, body, headers)
|
||||||
default:
|
default:
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
@@ -387,11 +394,108 @@ func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, b
|
|||||||
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
|
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
|
||||||
util.OverwriteRequestPathHeader(headers, path)
|
util.OverwriteRequestPathHeader(headers, path)
|
||||||
|
|
||||||
vertexRequest := v.buildVertexImageGenerationRequest(request)
|
vertexRequest, err := v.buildVertexImageGenerationRequest(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return json.Marshal(vertexRequest)
|
return json.Marshal(vertexRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest {
|
func (v *vertexProvider) onImageEditRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||||
|
request := &imageEditRequest{}
|
||||||
|
imageURLs := make([]string, 0)
|
||||||
|
contentType := headers.Get("Content-Type")
|
||||||
|
if isMultipartFormData(contentType) {
|
||||||
|
parsedRequest, err := parseMultipartImageRequest(body, contentType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
request.Model = parsedRequest.Model
|
||||||
|
request.Prompt = parsedRequest.Prompt
|
||||||
|
request.Size = parsedRequest.Size
|
||||||
|
request.OutputFormat = parsedRequest.OutputFormat
|
||||||
|
request.N = parsedRequest.N
|
||||||
|
imageURLs = parsedRequest.ImageURLs
|
||||||
|
if err := v.config.mapModel(ctx, &request.Model); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if parsedRequest.HasMask {
|
||||||
|
return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if request.HasMask() {
|
||||||
|
return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
|
||||||
|
}
|
||||||
|
imageURLs = request.GetImageURLs()
|
||||||
|
}
|
||||||
|
if len(imageURLs) == 0 {
|
||||||
|
return nil, fmt.Errorf("missing image_url in request")
|
||||||
|
}
|
||||||
|
if request.Prompt == "" {
|
||||||
|
return nil, fmt.Errorf("missing prompt in request")
|
||||||
|
}
|
||||||
|
|
||||||
|
path := v.getRequestPath(ApiNameImageEdit, request.Model, false)
|
||||||
|
util.OverwriteRequestPathHeader(headers, path)
|
||||||
|
headers.Set("Content-Type", util.MimeTypeApplicationJson)
|
||||||
|
vertexRequest, err := v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, imageURLs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return json.Marshal(vertexRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *vertexProvider) onImageVariationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||||
|
request := &imageVariationRequest{}
|
||||||
|
imageURLs := make([]string, 0)
|
||||||
|
contentType := headers.Get("Content-Type")
|
||||||
|
if isMultipartFormData(contentType) {
|
||||||
|
parsedRequest, err := parseMultipartImageRequest(body, contentType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
request.Model = parsedRequest.Model
|
||||||
|
request.Prompt = parsedRequest.Prompt
|
||||||
|
request.Size = parsedRequest.Size
|
||||||
|
request.OutputFormat = parsedRequest.OutputFormat
|
||||||
|
request.N = parsedRequest.N
|
||||||
|
imageURLs = parsedRequest.ImageURLs
|
||||||
|
if err := v.config.mapModel(ctx, &request.Model); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
imageURLs = request.GetImageURLs()
|
||||||
|
}
|
||||||
|
if len(imageURLs) == 0 {
|
||||||
|
return nil, fmt.Errorf("missing image_url in request")
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt := request.Prompt
|
||||||
|
if prompt == "" {
|
||||||
|
prompt = vertexImageVariationDefaultPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
path := v.getRequestPath(ApiNameImageVariation, request.Model, false)
|
||||||
|
util.OverwriteRequestPathHeader(headers, path)
|
||||||
|
headers.Set("Content-Type", util.MimeTypeApplicationJson)
|
||||||
|
vertexRequest, err := v.buildVertexImageRequest(prompt, request.Size, request.OutputFormat, imageURLs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return json.Marshal(vertexRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) (*vertexChatRequest, error) {
|
||||||
|
return v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *vertexProvider) buildVertexImageRequest(prompt string, size string, outputFormat string, imageURLs []string) (*vertexChatRequest, error) {
|
||||||
// 构建安全设置
|
// 构建安全设置
|
||||||
safetySettings := make([]vertexChatSafetySetting, 0)
|
safetySettings := make([]vertexChatSafetySetting, 0)
|
||||||
for category, threshold := range v.config.geminiSafetySetting {
|
for category, threshold := range v.config.geminiSafetySetting {
|
||||||
@@ -402,12 +506,12 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 解析尺寸参数
|
// 解析尺寸参数
|
||||||
aspectRatio, imageSize := v.parseImageSize(request.Size)
|
aspectRatio, imageSize := v.parseImageSize(size)
|
||||||
|
|
||||||
// 确定输出 MIME 类型
|
// 确定输出 MIME 类型
|
||||||
mimeType := "image/png"
|
mimeType := "image/png"
|
||||||
if request.OutputFormat != "" {
|
if outputFormat != "" {
|
||||||
switch request.OutputFormat {
|
switch outputFormat {
|
||||||
case "jpeg", "jpg":
|
case "jpeg", "jpg":
|
||||||
mimeType = "image/jpeg"
|
mimeType = "image/jpeg"
|
||||||
case "webp":
|
case "webp":
|
||||||
@@ -417,12 +521,27 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
parts := make([]vertexPart, 0, len(imageURLs)+1)
|
||||||
|
for _, imageURL := range imageURLs {
|
||||||
|
part, err := convertMediaContent(imageURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
parts = append(parts, part)
|
||||||
|
}
|
||||||
|
if prompt != "" {
|
||||||
|
parts = append(parts, vertexPart{
|
||||||
|
Text: prompt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return nil, fmt.Errorf("missing prompt and image_url in request")
|
||||||
|
}
|
||||||
|
|
||||||
vertexRequest := &vertexChatRequest{
|
vertexRequest := &vertexChatRequest{
|
||||||
Contents: []vertexChatContent{{
|
Contents: []vertexChatContent{{
|
||||||
Role: roleUser,
|
Role: roleUser,
|
||||||
Parts: []vertexPart{{
|
Parts: parts,
|
||||||
Text: request.Prompt,
|
|
||||||
}},
|
|
||||||
}},
|
}},
|
||||||
SafetySettings: safetySettings,
|
SafetySettings: safetySettings,
|
||||||
GenerationConfig: vertexChatGenerationConfig{
|
GenerationConfig: vertexChatGenerationConfig{
|
||||||
@@ -440,7 +559,7 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return vertexRequest
|
return vertexRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize
|
// parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize
|
||||||
@@ -553,7 +672,7 @@ func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
|
|||||||
return v.onChatCompletionResponseBody(ctx, body)
|
return v.onChatCompletionResponseBody(ctx, body)
|
||||||
case ApiNameEmbeddings:
|
case ApiNameEmbeddings:
|
||||||
return v.onEmbeddingsResponseBody(ctx, body)
|
return v.onEmbeddingsResponseBody(ctx, body)
|
||||||
case ApiNameImageGeneration:
|
case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
|
||||||
return v.onImageGenerationResponseBody(ctx, body)
|
return v.onImageGenerationResponseBody(ctx, body)
|
||||||
default:
|
default:
|
||||||
return body, nil
|
return body, nil
|
||||||
@@ -784,7 +903,7 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
|
|||||||
switch apiName {
|
switch apiName {
|
||||||
case ApiNameEmbeddings:
|
case ApiNameEmbeddings:
|
||||||
action = vertexEmbeddingAction
|
action = vertexEmbeddingAction
|
||||||
case ApiNameImageGeneration:
|
case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
|
||||||
// 图片生成使用非流式端点,需要完整响应
|
// 图片生成使用非流式端点,需要完整响应
|
||||||
action = vertexChatCompletionAction
|
action = vertexChatCompletionAction
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
package test
|
package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"mime/multipart"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -1273,6 +1275,324 @@ func RunVertexExpressModeImageGenerationResponseBodyTests(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildMultipartRequestBody(t *testing.T, fields map[string]string, files map[string][]byte) ([]byte, string) {
|
||||||
|
var buffer bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&buffer)
|
||||||
|
|
||||||
|
for key, value := range fields {
|
||||||
|
require.NoError(t, writer.WriteField(key, value))
|
||||||
|
}
|
||||||
|
|
||||||
|
for fieldName, data := range files {
|
||||||
|
part, err := writer.CreateFormFile(fieldName, "upload-image.png")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = part.Write(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, writer.Close())
|
||||||
|
return buffer.Bytes(), writer.FormDataContentType()
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunVertexExpressModeImageEditVariationRequestBodyTests(t *testing.T) {
|
||||||
|
test.RunTest(t, func(t *testing.T) {
|
||||||
|
const testDataURL = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||||
|
|
||||||
|
t.Run("vertex express mode image edit request body with image_url", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/images/edits"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add sunglasses to the cat","image":{"image_url":{"url":"` + testDataURL + `"}},"size":"1024x1024"}`
|
||||||
|
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
|
||||||
|
bodyStr := string(processedBody)
|
||||||
|
require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image_url")
|
||||||
|
require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
|
||||||
|
require.NotContains(t, bodyStr, "image_url", "OpenAI image_url field should be converted to Vertex format")
|
||||||
|
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
pathHeader := ""
|
||||||
|
for _, header := range requestHeaders {
|
||||||
|
if header[0] == ":path" {
|
||||||
|
pathHeader = header[1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Contains(t, pathHeader, "generateContent", "Image edit should use generateContent action")
|
||||||
|
require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("vertex express mode image edit request body with image string", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/images/edits"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add sunglasses to the cat","image":"` + testDataURL + `","size":"1024x1024"}`
|
||||||
|
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
|
||||||
|
bodyStr := string(processedBody)
|
||||||
|
require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image string")
|
||||||
|
require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("vertex express mode image edit multipart request body", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
body, contentType := buildMultipartRequestBody(t, map[string]string{
|
||||||
|
"model": "gemini-2.0-flash-exp",
|
||||||
|
"prompt": "Add sunglasses to the cat",
|
||||||
|
"size": "1024x1024",
|
||||||
|
}, map[string][]byte{
|
||||||
|
"image": []byte("fake-image-content"),
|
||||||
|
})
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/images/edits"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", contentType},
|
||||||
|
})
|
||||||
|
|
||||||
|
action := host.CallOnHttpRequestBody(body)
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
bodyStr := string(processedBody)
|
||||||
|
require.Contains(t, bodyStr, "inlineData", "Multipart image should be converted to inlineData")
|
||||||
|
require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
|
||||||
|
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
require.True(t, test.HasHeaderWithValue(requestHeaders, "Content-Type", "application/json"), "Content-Type should be rewritten to application/json")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("vertex express mode image variation multipart request body", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
body, contentType := buildMultipartRequestBody(t, map[string]string{
|
||||||
|
"model": "gemini-2.0-flash-exp",
|
||||||
|
"size": "1024x1024",
|
||||||
|
}, map[string][]byte{
|
||||||
|
"image": []byte("fake-image-content"),
|
||||||
|
})
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/images/variations"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", contentType},
|
||||||
|
})
|
||||||
|
|
||||||
|
action := host.CallOnHttpRequestBody(body)
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
bodyStr := string(processedBody)
|
||||||
|
require.Contains(t, bodyStr, "inlineData", "Multipart image should be converted to inlineData")
|
||||||
|
require.Contains(t, bodyStr, "Create variations of the provided image.", "Variation request should inject a default prompt")
|
||||||
|
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
require.True(t, test.HasHeaderWithValue(requestHeaders, "Content-Type", "application/json"), "Content-Type should be rewritten to application/json")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("vertex express mode image edit with model mapping", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/images/edits"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
requestBody := `{"model":"gpt-4","prompt":"Turn it into watercolor","image_url":{"url":"` + testDataURL + `"}}`
|
||||||
|
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
pathHeader := ""
|
||||||
|
for _, header := range requestHeaders {
|
||||||
|
if header[0] == ":path" {
|
||||||
|
pathHeader = header[1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Contains(t, pathHeader, "gemini-2.5-flash", "Path should contain mapped model name")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("vertex express mode image variation request body with image_url", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/images/variations"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
requestBody := `{"model":"gemini-2.0-flash-exp","image_url":{"url":"` + testDataURL + `"}}`
|
||||||
|
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
|
||||||
|
bodyStr := string(processedBody)
|
||||||
|
require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image_url")
|
||||||
|
require.Contains(t, bodyStr, "Create variations of the provided image.", "Variation request should inject a default prompt")
|
||||||
|
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
pathHeader := ""
|
||||||
|
for _, header := range requestHeaders {
|
||||||
|
if header[0] == ":path" {
|
||||||
|
pathHeader = header[1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Contains(t, pathHeader, "generateContent", "Image variation should use generateContent action")
|
||||||
|
require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunVertexExpressModeImageEditVariationResponseBodyTests(t *testing.T) {
|
||||||
|
test.RunTest(t, func(t *testing.T) {
|
||||||
|
const testDataURL = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||||
|
|
||||||
|
t.Run("vertex express mode image edit response body", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/images/edits"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add glasses","image_url":{"url":"` + testDataURL + `"}}`
|
||||||
|
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
|
||||||
|
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||||
|
host.CallOnHttpResponseHeaders([][2]string{
|
||||||
|
{":status", "200"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
responseBody := `{
|
||||||
|
"candidates": [{
|
||||||
|
"content": {
|
||||||
|
"role": "model",
|
||||||
|
"parts": [{
|
||||||
|
"inlineData": {
|
||||||
|
"mimeType": "image/png",
|
||||||
|
"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
"usageMetadata": {
|
||||||
|
"promptTokenCount": 12,
|
||||||
|
"candidatesTokenCount": 1024,
|
||||||
|
"totalTokenCount": 1036
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedResponseBody := host.GetResponseBody()
|
||||||
|
require.NotNil(t, processedResponseBody)
|
||||||
|
|
||||||
|
responseStr := string(processedResponseBody)
|
||||||
|
require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field")
|
||||||
|
require.Contains(t, responseStr, "usage", "Response should contain usage field")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("vertex express mode image variation response body", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/images/variations"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
requestBody := `{"model":"gemini-2.0-flash-exp","image_url":{"url":"` + testDataURL + `"}}`
|
||||||
|
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
|
||||||
|
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||||
|
host.CallOnHttpResponseHeaders([][2]string{
|
||||||
|
{":status", "200"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
responseBody := `{
|
||||||
|
"candidates": [{
|
||||||
|
"content": {
|
||||||
|
"role": "model",
|
||||||
|
"parts": [{
|
||||||
|
"inlineData": {
|
||||||
|
"mimeType": "image/png",
|
||||||
|
"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
"usageMetadata": {
|
||||||
|
"promptTokenCount": 8,
|
||||||
|
"candidatesTokenCount": 768,
|
||||||
|
"totalTokenCount": 776
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedResponseBody := host.GetResponseBody()
|
||||||
|
require.NotNil(t, processedResponseBody)
|
||||||
|
|
||||||
|
responseStr := string(processedResponseBody)
|
||||||
|
require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field")
|
||||||
|
require.Contains(t, responseStr, "usage", "Response should contain usage field")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// ==================== Vertex Raw 模式测试 ====================
|
// ==================== Vertex Raw 模式测试 ====================
|
||||||
|
|
||||||
func RunVertexRawModeOnHttpRequestHeadersTests(t *testing.T) {
|
func RunVertexRawModeOnHttpRequestHeadersTests(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user