From 889ea670131eb34c56bfd63687746dda3456fc05 Mon Sep 17 00:00:00 2001 From: woody Date: Mon, 30 Mar 2026 13:45:41 +0800 Subject: [PATCH] =?UTF-8?q?feat(provider):=20=E4=BC=98=E5=8C=96=20Azure=20?= =?UTF-8?q?multipart=20=E5=A4=84=E7=90=86=20||=20feat(provider):=20Optimiz?= =?UTF-8?q?e=20Azure=20multipart=20processing=20(#3651)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../wasm-go/extensions/ai-proxy/main_test.go | 1 + .../extensions/ai-proxy/provider/azure.go | 44 ++- .../ai-proxy/provider/multipart_helper.go | 123 +++++- .../provider/multipart_transform_test.go | 363 ++++++++++++++++++ .../extensions/ai-proxy/provider/provider.go | 26 ++ .../wasm-go/extensions/ai-proxy/test/azure.go | 203 ++++++++++ 6 files changed, 756 insertions(+), 4 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-proxy/provider/multipart_transform_test.go diff --git a/plugins/wasm-go/extensions/ai-proxy/main_test.go b/plugins/wasm-go/extensions/ai-proxy/main_test.go index bd7a421f7..8801c1c24 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/main_test.go @@ -158,6 +158,7 @@ func TestGemini(t *testing.T) { func TestAzure(t *testing.T) { test.RunAzureParseConfigTests(t) + test.RunAzureMultipartHelperTests(t) test.RunAzureOnHttpRequestHeadersTests(t) test.RunAzureOnHttpRequestBodyTests(t) test.RunAzureOnHttpResponseHeadersTests(t) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index 7b27dd38c..152435bcc 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/wasm-go/pkg/wrapper" @@ -151,17 +152,44 @@ func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body) } +func isAzureMultipartImageRequest(apiName ApiName, contentType string) bool { + if apiName != ApiNameImageEdit && apiName != ApiNameImageVariation { + return false + } + return isMultipartFormData(contentType) +} + func (m *azureProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (transformedBody []byte, err error) { transformedBody = body err = nil + contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType) + isMultipartImageRequest := isAzureMultipartImageRequest(apiName, contentType) + transformedBody, err = m.config.defaultTransformRequestBody(ctx, apiName, body) + if isMultipartImageRequest { + if err != nil { + log.Debugf("[azure multipart] body transform failed: api=%s, err=%v", apiName, err) + } else { + log.Debugf("[azure multipart] body transformed: api=%s, originalModel=%s, mappedModel=%s, bodyBytes=%d->%d", + apiName, + ctx.GetStringContext(ctxKeyOriginalRequestModel, ""), + ctx.GetStringContext(ctxKeyFinalRequestModel, ""), + len(body), + len(transformedBody), + ) + } + } if err != nil { return } // This must be called after the body is transformed, because it uses the model from the context filled by that call. if path := m.transformRequestPath(ctx, apiName); path != "" { + if isMultipartImageRequest { + log.Debugf("[azure multipart] body path overwrite: api=%s, path=%s, modelInContext=%s", + apiName, path, ctx.GetStringContext(ctxKeyFinalRequestModel, "")) + } err = util.OverwriteRequestPath(path) if err == nil { log.Debugf("azureProvider: overwrite request path to %s succeeded", path) @@ -222,16 +250,30 @@ func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName Ap } func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) { + contentType := headers.Get(util.HeaderContentType) + isMultipartImageRequest := isAzureMultipartImageRequest(apiName, contentType) + // We need to overwrite the request path in the request headers stage, // because for some APIs, we don't read the request body and the path is model irrelevant. if overwrittenPath := m.transformRequestPath(ctx, apiName); overwrittenPath != "" { util.OverwriteRequestPathHeader(headers, overwrittenPath) + if isMultipartImageRequest { + log.Debugf("[azure multipart] header path overwrite: api=%s, path=%s, modelInContext=%s", + apiName, overwrittenPath, ctx.GetStringContext(ctxKeyFinalRequestModel, "")) + } } util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host) headers.Set("api-key", m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") - if !m.config.isSupportedAPI(apiName) || !m.config.needToProcessRequestBody(apiName) { + supportedAPI := m.config.isSupportedAPI(apiName) + needProcessBody := m.config.needToProcessRequestBody(apiName) + if isMultipartImageRequest { + log.Debugf("[azure multipart] body processing decision: api=%s, supported=%t, needProcessBody=%t", + apiName, supportedAPI, needProcessBody) + } + + if !supportedAPI || !needProcessBody { // If the API is not supported or there is no need to process the body, // we should not read the request body and keep it as it is. ctx.DontReadRequestBody() diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go b/plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go index a54f9282d..ea58b67e6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go @@ -8,10 +8,15 @@ import ( "mime" "mime/multipart" "net/http" + "net/textproto" "strconv" "strings" ) +var newMultipartWriter = func(w io.Writer) *multipart.Writer { + return multipart.NewWriter(w) +} + type multipartImageRequest struct { Model string Prompt string @@ -30,14 +35,22 @@ func isMultipartFormData(contentType string) bool { return strings.EqualFold(mediaType, "multipart/form-data") } -func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) { +func parseMultipartBoundary(contentType string) (string, error) { _, params, err := mime.ParseMediaType(contentType) if err != nil { - return nil, fmt.Errorf("unable to parse content-type: %v", err) + return "", fmt.Errorf("unable to parse content-type: %v", err) } boundary := params["boundary"] if boundary == "" { - return nil, fmt.Errorf("missing multipart boundary") + return "", fmt.Errorf("missing multipart boundary") + } + return boundary, nil +} + +func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) { + boundary, err := parseMultipartBoundary(contentType) + if err != nil { + return nil, err } req := &multipartImageRequest{ @@ -111,6 +124,110 @@ func parseMultipartImageRequest(body []byte, contentType string) (*multipartImag return req, nil } +func extractMultipartModel(body []byte, contentType string) (string, error) { + boundary, err := parseMultipartBoundary(contentType) + if err != nil { + return "", err + } + + reader := multipart.NewReader(bytes.NewReader(body), boundary) + model := "" + for { + part, err := reader.NextPart() + if err == io.EOF { + break + } + if err != nil { + return "", fmt.Errorf("unable to read multipart part: %v", err) + } + + fieldName := part.FormName() + var readErr error + if fieldName == "model" { + var partData []byte + partData, readErr = io.ReadAll(part) + if readErr == nil { + model = strings.TrimSpace(string(partData)) + } + } else { + _, readErr = io.Copy(io.Discard, part) + } + _ = part.Close() + if readErr != nil { + return "", fmt.Errorf("unable to read multipart field %s: %v", fieldName, readErr) + } + } + + return model, nil +} + +func rewriteMultipartFormModel(body []byte, contentType string, model string) ([]byte, error) { + boundary, err := parseMultipartBoundary(contentType) + if err != nil { + return nil, err + } + + var buffer bytes.Buffer + writer := newMultipartWriter(&buffer) + if err := writer.SetBoundary(boundary); err != nil { + return nil, fmt.Errorf("unable to set multipart boundary: %v", err) + } + + reader := multipart.NewReader(bytes.NewReader(body), boundary) + modelFound := false + for { + part, err := reader.NextPart() + if err == io.EOF { + break + } + if err != nil { + return nil, fmt.Errorf("unable to read multipart part: %v", err) + } + + fieldName := part.FormName() + newPart, err := writer.CreatePart(cloneMultipartPartHeader(part.Header)) + if err != nil { + _ = part.Close() + return nil, fmt.Errorf("unable to create multipart field %s: %v", fieldName, err) + } + + var copyErr error + if fieldName == "model" { + modelFound = true + if _, copyErr = io.WriteString(newPart, model); copyErr == nil { + _, copyErr = io.Copy(io.Discard, part) + } + } else { + _, copyErr = io.Copy(newPart, part) + } + _ = part.Close() + if copyErr != nil { + return nil, fmt.Errorf("unable to write multipart field %s: %v", fieldName, copyErr) + } + } + + if !modelFound && model != "" { + if err := writer.WriteField("model", model); err != nil { + return nil, fmt.Errorf("unable to append multipart model field: %v", err) + } + } + if err := writer.Close(); err != nil { + return nil, fmt.Errorf("unable to finalize multipart body: %v", err) + } + + return buffer.Bytes(), nil +} + +func cloneMultipartPartHeader(header textproto.MIMEHeader) textproto.MIMEHeader { + cloned := make(textproto.MIMEHeader, len(header)) + for key, values := range header { + copied := make([]string, len(values)) + copy(copied, values) + cloned[key] = copied + } + return cloned +} + func isMultipartImageField(fieldName string) bool { return fieldName == "image" || fieldName == "image[]" || strings.HasPrefix(fieldName, "image[") } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/multipart_transform_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/multipart_transform_test.go new file mode 100644 index 000000000..6623e8e2b --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/multipart_transform_test.go @@ -0,0 +1,363 @@ +package provider + +import ( + "bytes" + "errors" + "io" + "mime/multipart" + "strings" + "testing" + + "github.com/higress-group/wasm-go/pkg/iface" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockMultipartHttpContext struct { + contextMap map[string]interface{} +} + +func newMockMultipartHttpContext() *mockMultipartHttpContext { + return &mockMultipartHttpContext{contextMap: make(map[string]interface{})} +} + +func (m *mockMultipartHttpContext) SetContext(key string, value interface{}) { + m.contextMap[key] = value +} +func (m *mockMultipartHttpContext) GetContext(key string) interface{} { return m.contextMap[key] } +func (m *mockMultipartHttpContext) GetBoolContext(key string, def bool) bool { return def } +func (m *mockMultipartHttpContext) GetStringContext(key, def string) string { return def } +func (m *mockMultipartHttpContext) GetByteSliceContext(key string, def []byte) []byte { return def } +func (m *mockMultipartHttpContext) Scheme() string { return "" } +func (m *mockMultipartHttpContext) Host() string { return "" } +func (m *mockMultipartHttpContext) Path() string { return "" } +func (m *mockMultipartHttpContext) Method() string { return "" } +func (m *mockMultipartHttpContext) GetUserAttribute(key string) interface{} { return nil } +func (m *mockMultipartHttpContext) SetUserAttribute(key string, value interface{}) {} +func (m *mockMultipartHttpContext) SetUserAttributeMap(kvmap map[string]interface{}) {} +func (m *mockMultipartHttpContext) GetUserAttributeMap() map[string]interface{} { return nil } +func (m *mockMultipartHttpContext) WriteUserAttributeToLog() error { return nil } +func (m *mockMultipartHttpContext) WriteUserAttributeToLogWithKey(key string) error { return nil } +func (m *mockMultipartHttpContext) WriteUserAttributeToTrace() error { return nil } +func (m *mockMultipartHttpContext) DontReadRequestBody() {} +func (m *mockMultipartHttpContext) DontReadResponseBody() {} +func (m *mockMultipartHttpContext) BufferRequestBody() {} +func (m *mockMultipartHttpContext) BufferResponseBody() {} +func (m *mockMultipartHttpContext) NeedPauseStreamingResponse() {} +func (m *mockMultipartHttpContext) PushBuffer(buffer []byte) {} +func (m *mockMultipartHttpContext) PopBuffer() []byte { return nil } +func (m *mockMultipartHttpContext) BufferQueueSize() int { return 0 } +func (m *mockMultipartHttpContext) DisableReroute() {} +func (m *mockMultipartHttpContext) SetRequestBodyBufferLimit(byteSize uint32) {} +func (m *mockMultipartHttpContext) SetResponseBodyBufferLimit(byteSize uint32) {} +func (m *mockMultipartHttpContext) RouteCall(method, url string, headers [][2]string, body []byte, callback iface.RouteResponseCallback) error { + return nil +} +func (m *mockMultipartHttpContext) GetExecutionPhase() iface.HTTPExecutionPhase { return 0 } +func (m *mockMultipartHttpContext) HasRequestBody() bool { return false } +func (m *mockMultipartHttpContext) HasResponseBody() bool { return false } +func (m *mockMultipartHttpContext) IsWebsocket() bool { return false } +func (m *mockMultipartHttpContext) IsBinaryRequestBody() bool { return false } +func (m *mockMultipartHttpContext) IsBinaryResponseBody() bool { return false } + +func buildProviderMultipartRequestBody(t *testing.T, fields map[string]string, files map[string][]byte) ([]byte, string) { + t.Helper() + + var buffer bytes.Buffer + writer := multipart.NewWriter(&buffer) + + for key, value := range fields { + require.NoError(t, writer.WriteField(key, value)) + } + for fieldName, data := range files { + part, err := writer.CreateFormFile(fieldName, "upload-image.png") + require.NoError(t, err) + _, err = part.Write(data) + require.NoError(t, err) + } + + require.NoError(t, writer.Close()) + return buffer.Bytes(), writer.FormDataContentType() +} + +type failAfterNWriteWriter struct { + target io.Writer + failAtCall int + writeCalls int +} + +func (w *failAfterNWriteWriter) Write(p []byte) (int, error) { + w.writeCalls++ + if w.writeCalls >= w.failAtCall { + return 0, errors.New("injected write failure") + } + return w.target.Write(p) +} + +func withInjectedMultipartWriterFactory(t *testing.T, failAtCall int, testFunc func()) { + t.Helper() + + originalFactory := newMultipartWriter + newMultipartWriter = func(target io.Writer) *multipart.Writer { + return multipart.NewWriter(&failAfterNWriteWriter{ + target: target, + failAtCall: failAtCall, + }) + } + defer func() { + newMultipartWriter = originalFactory + }() + + testFunc() +} + +func TestRewriteMultipartFormModel(t *testing.T) { + t.Run("rewrites existing model field", func(t *testing.T) { + body, contentType := buildProviderMultipartRequestBody(t, map[string]string{ + "model": "gpt-image-1.5", + "prompt": "Turn the dog white", + }, map[string][]byte{ + "image[]": []byte("fake-image-content"), + }) + + transformed, err := rewriteMultipartFormModel(body, contentType, "gpt-image-1") + require.NoError(t, err) + + req, err := parseMultipartImageRequest(transformed, contentType) + require.NoError(t, err) + assert.Equal(t, "gpt-image-1", req.Model) + assert.Equal(t, "Turn the dog white", req.Prompt) + assert.Len(t, req.ImageURLs, 1) + assert.Contains(t, string(transformed), "fake-image-content") + }) + + t.Run("appends model field when missing", func(t *testing.T) { + body, contentType := buildProviderMultipartRequestBody(t, map[string]string{ + "prompt": "Turn the dog white", + }, map[string][]byte{ + "image": []byte("fake-image-content"), + }) + + transformed, err := rewriteMultipartFormModel(body, contentType, "gpt-image-1") + require.NoError(t, err) + + req, err := parseMultipartImageRequest(transformed, contentType) + require.NoError(t, err) + assert.Equal(t, "gpt-image-1", req.Model) + assert.Equal(t, "Turn the dog white", req.Prompt) + assert.Len(t, req.ImageURLs, 1) + }) + + t.Run("returns error on invalid content type", func(t *testing.T) { + _, err := rewriteMultipartFormModel([]byte("not-multipart"), "multipart/form-data", "gpt-image-1") + require.Error(t, err) + assert.Contains(t, err.Error(), "missing multipart boundary") + }) + + t.Run("returns error when boundary cannot be set", func(t *testing.T) { + longBoundary := strings.Repeat("a", 71) + _, err := rewriteMultipartFormModel([]byte(""), "multipart/form-data; boundary="+longBoundary, "gpt-image-1") + require.Error(t, err) + assert.Contains(t, err.Error(), "unable to set multipart boundary") + }) + + t.Run("returns error on malformed multipart header", func(t *testing.T) { + body := []byte("--abc\r\nnot-a-header\r\n\r\nvalue\r\n--abc--\r\n") + _, err := rewriteMultipartFormModel(body, "multipart/form-data; boundary=abc", "gpt-image-1") + require.Error(t, err) + assert.Contains(t, err.Error(), "unable to read multipart part") + }) + + t.Run("returns error when multipart part copy fails", func(t *testing.T) { + body := []byte("--abc\r\nContent-Disposition: form-data; name=\"image\"; filename=\"a.png\"\r\nContent-Type: image/png\r\n\r\nabc\r\n--ab") + _, err := rewriteMultipartFormModel(body, "multipart/form-data; boundary=abc", "gpt-image-1") + require.Error(t, err) + assert.Contains(t, err.Error(), "unable to write multipart field image") + }) + + t.Run("returns error when creating rewritten multipart part fails", func(t *testing.T) { + body, contentType := buildProviderMultipartRequestBody(t, map[string]string{ + "prompt": "Turn the dog white", + }, map[string][]byte{ + "image": []byte("fake-image-content"), + }) + + withInjectedMultipartWriterFactory(t, 1, func() { + _, err := rewriteMultipartFormModel(body, contentType, "gpt-image-1") + require.Error(t, err) + assert.Contains(t, err.Error(), "unable to create multipart field") + }) + }) + + t.Run("returns error when appending model field fails", func(t *testing.T) { + withInjectedMultipartWriterFactory(t, 1, func() { + _, err := rewriteMultipartFormModel([]byte("--abc--\r\n"), "multipart/form-data; boundary=abc", "gpt-image-1") + require.Error(t, err) + assert.Contains(t, err.Error(), "unable to append multipart model field") + }) + }) + + t.Run("returns error when finalizing multipart body fails", func(t *testing.T) { + withInjectedMultipartWriterFactory(t, 1, func() { + _, err := rewriteMultipartFormModel([]byte("--abc--\r\n"), "multipart/form-data; boundary=abc", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "unable to finalize multipart body") + }) + }) +} + +func TestDefaultTransformMultipartRequestBody(t *testing.T) { + t.Run("maps multipart model and keeps body valid", func(t *testing.T) { + body, contentType := buildProviderMultipartRequestBody(t, map[string]string{ + "model": "gpt-image-1.5", + "prompt": "Turn the dog white", + "size": "1024x1024", + }, map[string][]byte{ + "image[]": []byte("fake-image-content"), + }) + + config := &ProviderConfig{ + modelMapping: map[string]string{ + "gpt-image-1.5": "gpt-image-1", + }, + } + ctx := newMockMultipartHttpContext() + + transformed, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameImageEdit, body, contentType) + require.NoError(t, err) + + req, err := parseMultipartImageRequest(transformed, contentType) + require.NoError(t, err) + assert.Equal(t, "gpt-image-1.5", ctx.GetContext(ctxKeyOriginalRequestModel)) + assert.Equal(t, "gpt-image-1", ctx.GetContext(ctxKeyFinalRequestModel)) + assert.Equal(t, "gpt-image-1", req.Model) + assert.Equal(t, "Turn the dog white", req.Prompt) + assert.Len(t, req.ImageURLs, 1) + assert.Contains(t, string(transformed), "fake-image-content") + }) + + t.Run("appends mapped model when multipart request omits model", func(t *testing.T) { + body, contentType := buildProviderMultipartRequestBody(t, map[string]string{ + "prompt": "Turn the dog white", + }, map[string][]byte{ + "image": []byte("fake-image-content"), + }) + + config := &ProviderConfig{ + modelMapping: map[string]string{ + "*": "gpt-image-1", + }, + } + ctx := newMockMultipartHttpContext() + + transformed, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameImageVariation, body, contentType) + require.NoError(t, err) + + req, err := parseMultipartImageRequest(transformed, contentType) + require.NoError(t, err) + assert.Equal(t, "", ctx.GetContext(ctxKeyOriginalRequestModel)) + assert.Equal(t, "gpt-image-1", ctx.GetContext(ctxKeyFinalRequestModel)) + assert.Equal(t, "gpt-image-1", req.Model) + }) + + t.Run("returns original body when multipart model is unchanged", func(t *testing.T) { + body, contentType := buildProviderMultipartRequestBody(t, map[string]string{ + "model": "gpt-image-1", + "prompt": "Turn the dog white", + }, map[string][]byte{ + "image": []byte("fake-image-content"), + }) + + config := &ProviderConfig{} + ctx := newMockMultipartHttpContext() + + transformed, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameImageEdit, body, contentType) + require.NoError(t, err) + assert.Equal(t, body, transformed) + assert.Equal(t, "gpt-image-1", ctx.GetContext(ctxKeyOriginalRequestModel)) + assert.Equal(t, "gpt-image-1", ctx.GetContext(ctxKeyFinalRequestModel)) + }) + + t.Run("ignores non image multipart apis", func(t *testing.T) { + body, contentType := buildProviderMultipartRequestBody(t, map[string]string{ + "model": "gpt-image-1", + }, nil) + + config := &ProviderConfig{ + modelMapping: map[string]string{ + "gpt-image-1": "mapped-model", + }, + } + ctx := newMockMultipartHttpContext() + + transformed, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameChatCompletion, body, contentType) + require.NoError(t, err) + assert.Equal(t, body, transformed) + assert.Nil(t, ctx.GetContext(ctxKeyOriginalRequestModel)) + assert.Nil(t, ctx.GetContext(ctxKeyFinalRequestModel)) + }) + + t.Run("surfaces multipart parse errors", func(t *testing.T) { + config := &ProviderConfig{} + ctx := newMockMultipartHttpContext() + + _, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameImageEdit, []byte("bad-body"), "multipart/form-data") + require.Error(t, err) + assert.Contains(t, err.Error(), "missing multipart boundary") + }) +} + +func TestExtractMultipartModel(t *testing.T) { + t.Run("extracts model value", func(t *testing.T) { + body, contentType := buildProviderMultipartRequestBody(t, map[string]string{ + "model": "gpt-image-1.5", + "prompt": "Turn the dog white", + }, nil) + + model, err := extractMultipartModel(body, contentType) + require.NoError(t, err) + assert.Equal(t, "gpt-image-1.5", model) + }) + + t.Run("returns empty model when field missing", func(t *testing.T) { + body, contentType := buildProviderMultipartRequestBody(t, map[string]string{ + "prompt": "Turn the dog white", + }, nil) + + model, err := extractMultipartModel(body, contentType) + require.NoError(t, err) + assert.Equal(t, "", model) + }) + + t.Run("returns parse error for invalid content type", func(t *testing.T) { + _, err := extractMultipartModel([]byte("bad-body"), "multipart/form-data; boundary=\"") + require.Error(t, err) + assert.Contains(t, err.Error(), "unable to parse content-type") + }) + + t.Run("returns parse error for malformed multipart header", func(t *testing.T) { + body := []byte("--abc\r\nnot-a-header\r\n\r\nvalue\r\n--abc--\r\n") + _, err := extractMultipartModel(body, "multipart/form-data; boundary=abc") + require.Error(t, err) + assert.Contains(t, err.Error(), "unable to read multipart part") + }) + + t.Run("returns field read error on truncated model part", func(t *testing.T) { + body := []byte("--abc\r\nContent-Disposition: form-data; name=\"model\"\r\n\r\nvalue\r\n--ab") + _, err := extractMultipartModel(body, "multipart/form-data; boundary=abc") + require.Error(t, err) + assert.Contains(t, err.Error(), "unable to read multipart field model") + }) +} + +func TestParseMultipartImageRequestContentTypeError(t *testing.T) { + _, err := parseMultipartImageRequest([]byte("bad-body"), "multipart/form-data; boundary=\"") + require.Error(t, err) + assert.Contains(t, err.Error(), "unable to parse content-type") +} + +func TestIsMultipartFormData(t *testing.T) { + assert.True(t, isMultipartFormData("multipart/form-data; boundary=abc")) + assert.False(t, isMultipartFormData("application/json")) + assert.False(t, isMultipartFormData("multipart/form-data; boundary=\"")) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 351a96a29..98f85091d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -1264,6 +1264,10 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt // defaultTransformRequestBody 默认的请求体转换方法,只做模型映射,用slog替换模型名称,不用序列化和反序列化,提高性能 func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) { + if contentType, err := proxywasm.GetHttpRequestHeader(util.HeaderContentType); err == nil && isMultipartFormData(contentType) { + return c.defaultTransformMultipartRequestBody(ctx, apiName, body, contentType) + } + switch apiName { case ApiNameChatCompletion, ApiNameVideos, @@ -1283,6 +1287,28 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap return sjson.SetBytes(body, "model", mappedModel) } +func (c *ProviderConfig) defaultTransformMultipartRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, contentType string) ([]byte, error) { + if apiName != ApiNameImageEdit && apiName != ApiNameImageVariation { + return body, nil + } + + model, err := extractMultipartModel(body, contentType) + if err != nil { + return nil, err + } + + ctx.SetContext(ctxKeyOriginalRequestModel, model) + + mappedModel := getMappedModel(model, c.modelMapping) + ctx.SetContext(ctxKeyFinalRequestModel, mappedModel) + + if mappedModel == model || (mappedModel == "" && model == "") { + return body, nil + } + + return rewriteMultipartFormModel(body, contentType, mappedModel) +} + func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) { if c.protocol == protocolOriginal { ctx.DontReadResponseBody() diff --git a/plugins/wasm-go/extensions/ai-proxy/test/azure.go b/plugins/wasm-go/extensions/ai-proxy/test/azure.go index eb23515f5..e0d0223c1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/test/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/test/azure.go @@ -1,7 +1,11 @@ package test import ( + "bytes" "encoding/json" + "io" + "mime" + "mime/multipart" "strings" "testing" @@ -80,6 +84,22 @@ var azureDomainOnlyConfig = func() json.RawMessage { return data }() +var azureDomainOnlyImageMultipartConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "azure", + "apiTokens": []string{ + "sk-azure-image-multipart", + }, + "azureServiceUrl": "https://domain-resource.openai.azure.com?api-version=2024-02-15-preview", + "modelMapping": map[string]string{ + "gpt-image-1.5": "gpt-image-1", + }, + }, + }) + return data +}() + // 测试配置:Azure OpenAI多模型配置 var azureMultiModelConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ @@ -99,6 +119,74 @@ var azureMultiModelConfig = func() json.RawMessage { return data }() +func getMultipartTextField(body []byte, contentType string, fieldName string) (string, bool, error) { + _, params, err := mime.ParseMediaType(contentType) + if err != nil { + return "", false, err + } + boundary := params["boundary"] + if boundary == "" { + return "", false, nil + } + + reader := multipart.NewReader(bytes.NewReader(body), boundary) + for { + part, err := reader.NextPart() + if err == io.EOF { + return "", false, nil + } + if err != nil { + return "", false, err + } + + data, err := io.ReadAll(part) + _ = part.Close() + if err != nil { + return "", false, err + } + if part.FormName() == fieldName { + return string(data), true, nil + } + } +} + +func RunAzureMultipartHelperTests(t *testing.T) { + t.Run("multipart text field returns error for invalid content type", func(t *testing.T) { + _, _, err := getMultipartTextField([]byte("bad-body"), "multipart/form-data; boundary=\"", "model") + require.Error(t, err) + }) + + t.Run("multipart text field returns not found for missing boundary", func(t *testing.T) { + value, found, err := getMultipartTextField([]byte("bad-body"), "multipart/form-data", "model") + require.NoError(t, err) + require.False(t, found) + require.Equal(t, "", value) + }) + + t.Run("multipart text field returns not found on eof", func(t *testing.T) { + body, contentType := buildMultipartRequestBody(t, map[string]string{ + "model": "gpt-image-1.5", + }, nil) + + value, found, err := getMultipartTextField(body, contentType, "prompt") + require.NoError(t, err) + require.False(t, found) + require.Equal(t, "", value) + }) + + t.Run("multipart text field returns next part error on malformed body", func(t *testing.T) { + body := []byte("--abc\r\nnot-a-header\r\n\r\nvalue\r\n--abc--\r\n") + _, _, err := getMultipartTextField(body, "multipart/form-data; boundary=abc", "model") + require.Error(t, err) + }) + + t.Run("multipart text field returns read error on truncated part", func(t *testing.T) { + body := []byte("--abc\r\nContent-Disposition: form-data; name=\"model\"\r\n\r\nvalue\r\n--ab") + _, _, err := getMultipartTextField(body, "multipart/form-data; boundary=abc", "model") + require.Error(t, err) + }) +} + // 测试配置:Azure OpenAI无效配置(缺少azureServiceUrl) var azureInvalidConfigMissingUrl = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ @@ -615,6 +703,121 @@ func RunAzureOnHttpRequestBodyTests(t *testing.T) { require.Equal(t, pathValue, "/openai/deployments/gpt-3.5-turbo/chat/completions?api-version=2024-02-15-preview", "Path should use model from request body") }) + t.Run("azure domain only multipart image edit request body", func(t *testing.T) { + host, status := test.NewTestHost(azureDomainOnlyImageMultipartConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + body, contentType := buildMultipartRequestBody(t, map[string]string{ + "model": "gpt-image-1.5", + "prompt": "把小狗换成白色", + "size": "1024x1024", + "n": "1", + }, map[string][]byte{ + "image[]": []byte("fake-image-content"), + }) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/edits"}, + {":method", "POST"}, + {"Content-Type", contentType}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + action = host.CallOnHttpRequestBody(body) + require.Equal(t, types.ActionContinue, action) + + transformedBody := host.GetRequestBody() + require.NotNil(t, transformedBody) + + modelValue, found, err := getMultipartTextField(transformedBody, contentType, "model") + require.NoError(t, err) + require.True(t, found, "Model field should exist in multipart body") + require.Equal(t, "gpt-image-1", modelValue, "Model field should be mapped in multipart body") + require.Contains(t, string(transformedBody), "fake-image-content", "Image file content should remain in multipart body") + + requestHeaders := host.GetRequestHeaders() + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath, "Path header should exist") + require.Equal(t, "/openai/deployments/gpt-image-1/images/edits?api-version=2024-02-15-preview", pathValue, "Path should use mapped multipart model") + + contentTypeValue, hasContentType := test.GetHeaderValue(requestHeaders, "Content-Type") + require.True(t, hasContentType, "Content-Type header should exist") + require.Equal(t, contentType, contentTypeValue, "Multipart Content-Type should remain unchanged") + }) + + t.Run("azure domain only multipart image variation request body", func(t *testing.T) { + host, status := test.NewTestHost(azureDomainOnlyImageMultipartConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + body, contentType := buildMultipartRequestBody(t, map[string]string{ + "model": "gpt-image-1.5", + "prompt": "生成类似风格", + "size": "1024x1024", + "n": "1", + }, map[string][]byte{ + "image": []byte("fake-image-content"), + }) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/variations"}, + {":method", "POST"}, + {"Content-Type", contentType}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + action = host.CallOnHttpRequestBody(body) + require.Equal(t, types.ActionContinue, action) + + transformedBody := host.GetRequestBody() + require.NotNil(t, transformedBody) + + modelValue, found, err := getMultipartTextField(transformedBody, contentType, "model") + require.NoError(t, err) + require.True(t, found, "Model field should exist in multipart body") + require.Equal(t, "gpt-image-1", modelValue, "Model field should be mapped in multipart body") + require.Contains(t, string(transformedBody), "fake-image-content", "Image file content should remain in multipart body") + + requestHeaders := host.GetRequestHeaders() + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath, "Path header should exist") + require.Equal(t, "/openai/deployments/gpt-image-1/images/variations?api-version=2024-02-15-preview", pathValue, "Path should use mapped multipart model") + + contentTypeValue, hasContentType := test.GetHeaderValue(requestHeaders, "Content-Type") + require.True(t, hasContentType, "Content-Type header should exist") + require.Equal(t, contentType, contentTypeValue, "Multipart Content-Type should remain unchanged") + }) + + t.Run("azure domain only multipart malformed body logs transform failure", func(t *testing.T) { + host, status := test.NewTestHost(azureDomainOnlyImageMultipartConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/images/edits"}, + {":method", "POST"}, + {"Content-Type", "multipart/form-data"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + action = host.CallOnHttpRequestBody([]byte("bad-multipart-body")) + require.Equal(t, types.ActionContinue, action) + + debugLogs := host.GetDebugLogs() + hasMultipartTransformFailureLog := false + for _, debugLog := range debugLogs { + if strings.Contains(debugLog, "[azure multipart] body transform failed") { + hasMultipartTransformFailureLog = true + break + } + } + require.True(t, hasMultipartTransformFailureLog, "Should log azure multipart transform failure") + }) + // 测试Azure OpenAI模型无关请求处理(仅域名配置) t.Run("azure domain only model independent", func(t *testing.T) { host, status := test.NewTestHost(azureDomainOnlyConfig)