mirror of
https://github.com/alibaba/higress.git
synced 2026-05-12 06:47:28 +08:00
feat(provider): 优化 Azure multipart 处理 || feat(provider): Optimize Azure multipart processing (#3651)
This commit is contained in:
@@ -158,6 +158,7 @@ func TestGemini(t *testing.T) {
|
||||
|
||||
func TestAzure(t *testing.T) {
|
||||
test.RunAzureParseConfigTests(t)
|
||||
test.RunAzureMultipartHelperTests(t)
|
||||
test.RunAzureOnHttpRequestHeadersTests(t)
|
||||
test.RunAzureOnHttpRequestBodyTests(t)
|
||||
test.RunAzureOnHttpResponseHeadersTests(t)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
@@ -151,17 +152,44 @@ func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func isAzureMultipartImageRequest(apiName ApiName, contentType string) bool {
|
||||
if apiName != ApiNameImageEdit && apiName != ApiNameImageVariation {
|
||||
return false
|
||||
}
|
||||
return isMultipartFormData(contentType)
|
||||
}
|
||||
|
||||
func (m *azureProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (transformedBody []byte, err error) {
|
||||
transformedBody = body
|
||||
err = nil
|
||||
|
||||
contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType)
|
||||
isMultipartImageRequest := isAzureMultipartImageRequest(apiName, contentType)
|
||||
|
||||
transformedBody, err = m.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
if isMultipartImageRequest {
|
||||
if err != nil {
|
||||
log.Debugf("[azure multipart] body transform failed: api=%s, err=%v", apiName, err)
|
||||
} else {
|
||||
log.Debugf("[azure multipart] body transformed: api=%s, originalModel=%s, mappedModel=%s, bodyBytes=%d->%d",
|
||||
apiName,
|
||||
ctx.GetStringContext(ctxKeyOriginalRequestModel, ""),
|
||||
ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||||
len(body),
|
||||
len(transformedBody),
|
||||
)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// This must be called after the body is transformed, because it uses the model from the context filled by that call.
|
||||
if path := m.transformRequestPath(ctx, apiName); path != "" {
|
||||
if isMultipartImageRequest {
|
||||
log.Debugf("[azure multipart] body path overwrite: api=%s, path=%s, modelInContext=%s",
|
||||
apiName, path, ctx.GetStringContext(ctxKeyFinalRequestModel, ""))
|
||||
}
|
||||
err = util.OverwriteRequestPath(path)
|
||||
if err == nil {
|
||||
log.Debugf("azureProvider: overwrite request path to %s succeeded", path)
|
||||
@@ -222,16 +250,30 @@ func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName Ap
|
||||
}
|
||||
|
||||
func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
contentType := headers.Get(util.HeaderContentType)
|
||||
isMultipartImageRequest := isAzureMultipartImageRequest(apiName, contentType)
|
||||
|
||||
// We need to overwrite the request path in the request headers stage,
|
||||
// because for some APIs, we don't read the request body and the path is model irrelevant.
|
||||
if overwrittenPath := m.transformRequestPath(ctx, apiName); overwrittenPath != "" {
|
||||
util.OverwriteRequestPathHeader(headers, overwrittenPath)
|
||||
if isMultipartImageRequest {
|
||||
log.Debugf("[azure multipart] header path overwrite: api=%s, path=%s, modelInContext=%s",
|
||||
apiName, overwrittenPath, ctx.GetStringContext(ctxKeyFinalRequestModel, ""))
|
||||
}
|
||||
}
|
||||
util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
|
||||
headers.Set("api-key", m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
|
||||
if !m.config.isSupportedAPI(apiName) || !m.config.needToProcessRequestBody(apiName) {
|
||||
supportedAPI := m.config.isSupportedAPI(apiName)
|
||||
needProcessBody := m.config.needToProcessRequestBody(apiName)
|
||||
if isMultipartImageRequest {
|
||||
log.Debugf("[azure multipart] body processing decision: api=%s, supported=%t, needProcessBody=%t",
|
||||
apiName, supportedAPI, needProcessBody)
|
||||
}
|
||||
|
||||
if !supportedAPI || !needProcessBody {
|
||||
// If the API is not supported or there is no need to process the body,
|
||||
// we should not read the request body and keep it as it is.
|
||||
ctx.DontReadRequestBody()
|
||||
|
||||
@@ -8,10 +8,15 @@ import (
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var newMultipartWriter = func(w io.Writer) *multipart.Writer {
|
||||
return multipart.NewWriter(w)
|
||||
}
|
||||
|
||||
type multipartImageRequest struct {
|
||||
Model string
|
||||
Prompt string
|
||||
@@ -30,14 +35,22 @@ func isMultipartFormData(contentType string) bool {
|
||||
return strings.EqualFold(mediaType, "multipart/form-data")
|
||||
}
|
||||
|
||||
func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) {
|
||||
func parseMultipartBoundary(contentType string) (string, error) {
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse content-type: %v", err)
|
||||
return "", fmt.Errorf("unable to parse content-type: %v", err)
|
||||
}
|
||||
boundary := params["boundary"]
|
||||
if boundary == "" {
|
||||
return nil, fmt.Errorf("missing multipart boundary")
|
||||
return "", fmt.Errorf("missing multipart boundary")
|
||||
}
|
||||
return boundary, nil
|
||||
}
|
||||
|
||||
func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) {
|
||||
boundary, err := parseMultipartBoundary(contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := &multipartImageRequest{
|
||||
@@ -111,6 +124,110 @@ func parseMultipartImageRequest(body []byte, contentType string) (*multipartImag
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func extractMultipartModel(body []byte, contentType string) (string, error) {
|
||||
boundary, err := parseMultipartBoundary(contentType)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
reader := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
model := ""
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to read multipart part: %v", err)
|
||||
}
|
||||
|
||||
fieldName := part.FormName()
|
||||
var readErr error
|
||||
if fieldName == "model" {
|
||||
var partData []byte
|
||||
partData, readErr = io.ReadAll(part)
|
||||
if readErr == nil {
|
||||
model = strings.TrimSpace(string(partData))
|
||||
}
|
||||
} else {
|
||||
_, readErr = io.Copy(io.Discard, part)
|
||||
}
|
||||
_ = part.Close()
|
||||
if readErr != nil {
|
||||
return "", fmt.Errorf("unable to read multipart field %s: %v", fieldName, readErr)
|
||||
}
|
||||
}
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func rewriteMultipartFormModel(body []byte, contentType string, model string) ([]byte, error) {
|
||||
boundary, err := parseMultipartBoundary(contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var buffer bytes.Buffer
|
||||
writer := newMultipartWriter(&buffer)
|
||||
if err := writer.SetBoundary(boundary); err != nil {
|
||||
return nil, fmt.Errorf("unable to set multipart boundary: %v", err)
|
||||
}
|
||||
|
||||
reader := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
modelFound := false
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read multipart part: %v", err)
|
||||
}
|
||||
|
||||
fieldName := part.FormName()
|
||||
newPart, err := writer.CreatePart(cloneMultipartPartHeader(part.Header))
|
||||
if err != nil {
|
||||
_ = part.Close()
|
||||
return nil, fmt.Errorf("unable to create multipart field %s: %v", fieldName, err)
|
||||
}
|
||||
|
||||
var copyErr error
|
||||
if fieldName == "model" {
|
||||
modelFound = true
|
||||
if _, copyErr = io.WriteString(newPart, model); copyErr == nil {
|
||||
_, copyErr = io.Copy(io.Discard, part)
|
||||
}
|
||||
} else {
|
||||
_, copyErr = io.Copy(newPart, part)
|
||||
}
|
||||
_ = part.Close()
|
||||
if copyErr != nil {
|
||||
return nil, fmt.Errorf("unable to write multipart field %s: %v", fieldName, copyErr)
|
||||
}
|
||||
}
|
||||
|
||||
if !modelFound && model != "" {
|
||||
if err := writer.WriteField("model", model); err != nil {
|
||||
return nil, fmt.Errorf("unable to append multipart model field: %v", err)
|
||||
}
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
return nil, fmt.Errorf("unable to finalize multipart body: %v", err)
|
||||
}
|
||||
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
func cloneMultipartPartHeader(header textproto.MIMEHeader) textproto.MIMEHeader {
|
||||
cloned := make(textproto.MIMEHeader, len(header))
|
||||
for key, values := range header {
|
||||
copied := make([]string, len(values))
|
||||
copy(copied, values)
|
||||
cloned[key] = copied
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func isMultipartImageField(fieldName string) bool {
|
||||
return fieldName == "image" || fieldName == "image[]" || strings.HasPrefix(fieldName, "image[")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,363 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/iface"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockMultipartHttpContext struct {
|
||||
contextMap map[string]interface{}
|
||||
}
|
||||
|
||||
func newMockMultipartHttpContext() *mockMultipartHttpContext {
|
||||
return &mockMultipartHttpContext{contextMap: make(map[string]interface{})}
|
||||
}
|
||||
|
||||
func (m *mockMultipartHttpContext) SetContext(key string, value interface{}) {
|
||||
m.contextMap[key] = value
|
||||
}
|
||||
func (m *mockMultipartHttpContext) GetContext(key string) interface{} { return m.contextMap[key] }
|
||||
func (m *mockMultipartHttpContext) GetBoolContext(key string, def bool) bool { return def }
|
||||
func (m *mockMultipartHttpContext) GetStringContext(key, def string) string { return def }
|
||||
func (m *mockMultipartHttpContext) GetByteSliceContext(key string, def []byte) []byte { return def }
|
||||
func (m *mockMultipartHttpContext) Scheme() string { return "" }
|
||||
func (m *mockMultipartHttpContext) Host() string { return "" }
|
||||
func (m *mockMultipartHttpContext) Path() string { return "" }
|
||||
func (m *mockMultipartHttpContext) Method() string { return "" }
|
||||
func (m *mockMultipartHttpContext) GetUserAttribute(key string) interface{} { return nil }
|
||||
func (m *mockMultipartHttpContext) SetUserAttribute(key string, value interface{}) {}
|
||||
func (m *mockMultipartHttpContext) SetUserAttributeMap(kvmap map[string]interface{}) {}
|
||||
func (m *mockMultipartHttpContext) GetUserAttributeMap() map[string]interface{} { return nil }
|
||||
func (m *mockMultipartHttpContext) WriteUserAttributeToLog() error { return nil }
|
||||
func (m *mockMultipartHttpContext) WriteUserAttributeToLogWithKey(key string) error { return nil }
|
||||
func (m *mockMultipartHttpContext) WriteUserAttributeToTrace() error { return nil }
|
||||
func (m *mockMultipartHttpContext) DontReadRequestBody() {}
|
||||
func (m *mockMultipartHttpContext) DontReadResponseBody() {}
|
||||
func (m *mockMultipartHttpContext) BufferRequestBody() {}
|
||||
func (m *mockMultipartHttpContext) BufferResponseBody() {}
|
||||
func (m *mockMultipartHttpContext) NeedPauseStreamingResponse() {}
|
||||
func (m *mockMultipartHttpContext) PushBuffer(buffer []byte) {}
|
||||
func (m *mockMultipartHttpContext) PopBuffer() []byte { return nil }
|
||||
func (m *mockMultipartHttpContext) BufferQueueSize() int { return 0 }
|
||||
func (m *mockMultipartHttpContext) DisableReroute() {}
|
||||
func (m *mockMultipartHttpContext) SetRequestBodyBufferLimit(byteSize uint32) {}
|
||||
func (m *mockMultipartHttpContext) SetResponseBodyBufferLimit(byteSize uint32) {}
|
||||
func (m *mockMultipartHttpContext) RouteCall(method, url string, headers [][2]string, body []byte, callback iface.RouteResponseCallback) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockMultipartHttpContext) GetExecutionPhase() iface.HTTPExecutionPhase { return 0 }
|
||||
func (m *mockMultipartHttpContext) HasRequestBody() bool { return false }
|
||||
func (m *mockMultipartHttpContext) HasResponseBody() bool { return false }
|
||||
func (m *mockMultipartHttpContext) IsWebsocket() bool { return false }
|
||||
func (m *mockMultipartHttpContext) IsBinaryRequestBody() bool { return false }
|
||||
func (m *mockMultipartHttpContext) IsBinaryResponseBody() bool { return false }
|
||||
|
||||
func buildProviderMultipartRequestBody(t *testing.T, fields map[string]string, files map[string][]byte) ([]byte, string) {
|
||||
t.Helper()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
writer := multipart.NewWriter(&buffer)
|
||||
|
||||
for key, value := range fields {
|
||||
require.NoError(t, writer.WriteField(key, value))
|
||||
}
|
||||
for fieldName, data := range files {
|
||||
part, err := writer.CreateFormFile(fieldName, "upload-image.png")
|
||||
require.NoError(t, err)
|
||||
_, err = part.Write(data)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.NoError(t, writer.Close())
|
||||
return buffer.Bytes(), writer.FormDataContentType()
|
||||
}
|
||||
|
||||
type failAfterNWriteWriter struct {
|
||||
target io.Writer
|
||||
failAtCall int
|
||||
writeCalls int
|
||||
}
|
||||
|
||||
func (w *failAfterNWriteWriter) Write(p []byte) (int, error) {
|
||||
w.writeCalls++
|
||||
if w.writeCalls >= w.failAtCall {
|
||||
return 0, errors.New("injected write failure")
|
||||
}
|
||||
return w.target.Write(p)
|
||||
}
|
||||
|
||||
func withInjectedMultipartWriterFactory(t *testing.T, failAtCall int, testFunc func()) {
|
||||
t.Helper()
|
||||
|
||||
originalFactory := newMultipartWriter
|
||||
newMultipartWriter = func(target io.Writer) *multipart.Writer {
|
||||
return multipart.NewWriter(&failAfterNWriteWriter{
|
||||
target: target,
|
||||
failAtCall: failAtCall,
|
||||
})
|
||||
}
|
||||
defer func() {
|
||||
newMultipartWriter = originalFactory
|
||||
}()
|
||||
|
||||
testFunc()
|
||||
}
|
||||
|
||||
func TestRewriteMultipartFormModel(t *testing.T) {
|
||||
t.Run("rewrites existing model field", func(t *testing.T) {
|
||||
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
|
||||
"model": "gpt-image-1.5",
|
||||
"prompt": "Turn the dog white",
|
||||
}, map[string][]byte{
|
||||
"image[]": []byte("fake-image-content"),
|
||||
})
|
||||
|
||||
transformed, err := rewriteMultipartFormModel(body, contentType, "gpt-image-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
req, err := parseMultipartImageRequest(transformed, contentType)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-image-1", req.Model)
|
||||
assert.Equal(t, "Turn the dog white", req.Prompt)
|
||||
assert.Len(t, req.ImageURLs, 1)
|
||||
assert.Contains(t, string(transformed), "fake-image-content")
|
||||
})
|
||||
|
||||
t.Run("appends model field when missing", func(t *testing.T) {
|
||||
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
|
||||
"prompt": "Turn the dog white",
|
||||
}, map[string][]byte{
|
||||
"image": []byte("fake-image-content"),
|
||||
})
|
||||
|
||||
transformed, err := rewriteMultipartFormModel(body, contentType, "gpt-image-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
req, err := parseMultipartImageRequest(transformed, contentType)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-image-1", req.Model)
|
||||
assert.Equal(t, "Turn the dog white", req.Prompt)
|
||||
assert.Len(t, req.ImageURLs, 1)
|
||||
})
|
||||
|
||||
t.Run("returns error on invalid content type", func(t *testing.T) {
|
||||
_, err := rewriteMultipartFormModel([]byte("not-multipart"), "multipart/form-data", "gpt-image-1")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "missing multipart boundary")
|
||||
})
|
||||
|
||||
t.Run("returns error when boundary cannot be set", func(t *testing.T) {
|
||||
longBoundary := strings.Repeat("a", 71)
|
||||
_, err := rewriteMultipartFormModel([]byte(""), "multipart/form-data; boundary="+longBoundary, "gpt-image-1")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unable to set multipart boundary")
|
||||
})
|
||||
|
||||
t.Run("returns error on malformed multipart header", func(t *testing.T) {
|
||||
body := []byte("--abc\r\nnot-a-header\r\n\r\nvalue\r\n--abc--\r\n")
|
||||
_, err := rewriteMultipartFormModel(body, "multipart/form-data; boundary=abc", "gpt-image-1")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unable to read multipart part")
|
||||
})
|
||||
|
||||
t.Run("returns error when multipart part copy fails", func(t *testing.T) {
|
||||
body := []byte("--abc\r\nContent-Disposition: form-data; name=\"image\"; filename=\"a.png\"\r\nContent-Type: image/png\r\n\r\nabc\r\n--ab")
|
||||
_, err := rewriteMultipartFormModel(body, "multipart/form-data; boundary=abc", "gpt-image-1")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unable to write multipart field image")
|
||||
})
|
||||
|
||||
t.Run("returns error when creating rewritten multipart part fails", func(t *testing.T) {
|
||||
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
|
||||
"prompt": "Turn the dog white",
|
||||
}, map[string][]byte{
|
||||
"image": []byte("fake-image-content"),
|
||||
})
|
||||
|
||||
withInjectedMultipartWriterFactory(t, 1, func() {
|
||||
_, err := rewriteMultipartFormModel(body, contentType, "gpt-image-1")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unable to create multipart field")
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("returns error when appending model field fails", func(t *testing.T) {
|
||||
withInjectedMultipartWriterFactory(t, 1, func() {
|
||||
_, err := rewriteMultipartFormModel([]byte("--abc--\r\n"), "multipart/form-data; boundary=abc", "gpt-image-1")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unable to append multipart model field")
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("returns error when finalizing multipart body fails", func(t *testing.T) {
|
||||
withInjectedMultipartWriterFactory(t, 1, func() {
|
||||
_, err := rewriteMultipartFormModel([]byte("--abc--\r\n"), "multipart/form-data; boundary=abc", "")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unable to finalize multipart body")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultTransformMultipartRequestBody(t *testing.T) {
|
||||
t.Run("maps multipart model and keeps body valid", func(t *testing.T) {
|
||||
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
|
||||
"model": "gpt-image-1.5",
|
||||
"prompt": "Turn the dog white",
|
||||
"size": "1024x1024",
|
||||
}, map[string][]byte{
|
||||
"image[]": []byte("fake-image-content"),
|
||||
})
|
||||
|
||||
config := &ProviderConfig{
|
||||
modelMapping: map[string]string{
|
||||
"gpt-image-1.5": "gpt-image-1",
|
||||
},
|
||||
}
|
||||
ctx := newMockMultipartHttpContext()
|
||||
|
||||
transformed, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameImageEdit, body, contentType)
|
||||
require.NoError(t, err)
|
||||
|
||||
req, err := parseMultipartImageRequest(transformed, contentType)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-image-1.5", ctx.GetContext(ctxKeyOriginalRequestModel))
|
||||
assert.Equal(t, "gpt-image-1", ctx.GetContext(ctxKeyFinalRequestModel))
|
||||
assert.Equal(t, "gpt-image-1", req.Model)
|
||||
assert.Equal(t, "Turn the dog white", req.Prompt)
|
||||
assert.Len(t, req.ImageURLs, 1)
|
||||
assert.Contains(t, string(transformed), "fake-image-content")
|
||||
})
|
||||
|
||||
t.Run("appends mapped model when multipart request omits model", func(t *testing.T) {
|
||||
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
|
||||
"prompt": "Turn the dog white",
|
||||
}, map[string][]byte{
|
||||
"image": []byte("fake-image-content"),
|
||||
})
|
||||
|
||||
config := &ProviderConfig{
|
||||
modelMapping: map[string]string{
|
||||
"*": "gpt-image-1",
|
||||
},
|
||||
}
|
||||
ctx := newMockMultipartHttpContext()
|
||||
|
||||
transformed, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameImageVariation, body, contentType)
|
||||
require.NoError(t, err)
|
||||
|
||||
req, err := parseMultipartImageRequest(transformed, contentType)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "", ctx.GetContext(ctxKeyOriginalRequestModel))
|
||||
assert.Equal(t, "gpt-image-1", ctx.GetContext(ctxKeyFinalRequestModel))
|
||||
assert.Equal(t, "gpt-image-1", req.Model)
|
||||
})
|
||||
|
||||
t.Run("returns original body when multipart model is unchanged", func(t *testing.T) {
|
||||
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
|
||||
"model": "gpt-image-1",
|
||||
"prompt": "Turn the dog white",
|
||||
}, map[string][]byte{
|
||||
"image": []byte("fake-image-content"),
|
||||
})
|
||||
|
||||
config := &ProviderConfig{}
|
||||
ctx := newMockMultipartHttpContext()
|
||||
|
||||
transformed, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameImageEdit, body, contentType)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, transformed)
|
||||
assert.Equal(t, "gpt-image-1", ctx.GetContext(ctxKeyOriginalRequestModel))
|
||||
assert.Equal(t, "gpt-image-1", ctx.GetContext(ctxKeyFinalRequestModel))
|
||||
})
|
||||
|
||||
t.Run("ignores non image multipart apis", func(t *testing.T) {
|
||||
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
|
||||
"model": "gpt-image-1",
|
||||
}, nil)
|
||||
|
||||
config := &ProviderConfig{
|
||||
modelMapping: map[string]string{
|
||||
"gpt-image-1": "mapped-model",
|
||||
},
|
||||
}
|
||||
ctx := newMockMultipartHttpContext()
|
||||
|
||||
transformed, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameChatCompletion, body, contentType)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, transformed)
|
||||
assert.Nil(t, ctx.GetContext(ctxKeyOriginalRequestModel))
|
||||
assert.Nil(t, ctx.GetContext(ctxKeyFinalRequestModel))
|
||||
})
|
||||
|
||||
t.Run("surfaces multipart parse errors", func(t *testing.T) {
|
||||
config := &ProviderConfig{}
|
||||
ctx := newMockMultipartHttpContext()
|
||||
|
||||
_, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameImageEdit, []byte("bad-body"), "multipart/form-data")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "missing multipart boundary")
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractMultipartModel(t *testing.T) {
|
||||
t.Run("extracts model value", func(t *testing.T) {
|
||||
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
|
||||
"model": "gpt-image-1.5",
|
||||
"prompt": "Turn the dog white",
|
||||
}, nil)
|
||||
|
||||
model, err := extractMultipartModel(body, contentType)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-image-1.5", model)
|
||||
})
|
||||
|
||||
t.Run("returns empty model when field missing", func(t *testing.T) {
|
||||
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
|
||||
"prompt": "Turn the dog white",
|
||||
}, nil)
|
||||
|
||||
model, err := extractMultipartModel(body, contentType)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "", model)
|
||||
})
|
||||
|
||||
t.Run("returns parse error for invalid content type", func(t *testing.T) {
|
||||
_, err := extractMultipartModel([]byte("bad-body"), "multipart/form-data; boundary=\"")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unable to parse content-type")
|
||||
})
|
||||
|
||||
t.Run("returns parse error for malformed multipart header", func(t *testing.T) {
|
||||
body := []byte("--abc\r\nnot-a-header\r\n\r\nvalue\r\n--abc--\r\n")
|
||||
_, err := extractMultipartModel(body, "multipart/form-data; boundary=abc")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unable to read multipart part")
|
||||
})
|
||||
|
||||
t.Run("returns field read error on truncated model part", func(t *testing.T) {
|
||||
body := []byte("--abc\r\nContent-Disposition: form-data; name=\"model\"\r\n\r\nvalue\r\n--ab")
|
||||
_, err := extractMultipartModel(body, "multipart/form-data; boundary=abc")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unable to read multipart field model")
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseMultipartImageRequestContentTypeError(t *testing.T) {
|
||||
_, err := parseMultipartImageRequest([]byte("bad-body"), "multipart/form-data; boundary=\"")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unable to parse content-type")
|
||||
}
|
||||
|
||||
func TestIsMultipartFormData(t *testing.T) {
|
||||
assert.True(t, isMultipartFormData("multipart/form-data; boundary=abc"))
|
||||
assert.False(t, isMultipartFormData("application/json"))
|
||||
assert.False(t, isMultipartFormData("multipart/form-data; boundary=\""))
|
||||
}
|
||||
@@ -1264,6 +1264,10 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt
|
||||
|
||||
// defaultTransformRequestBody 默认的请求体转换方法,只做模型映射,用slog替换模型名称,不用序列化和反序列化,提高性能
|
||||
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if contentType, err := proxywasm.GetHttpRequestHeader(util.HeaderContentType); err == nil && isMultipartFormData(contentType) {
|
||||
return c.defaultTransformMultipartRequestBody(ctx, apiName, body, contentType)
|
||||
}
|
||||
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion,
|
||||
ApiNameVideos,
|
||||
@@ -1283,6 +1287,28 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap
|
||||
return sjson.SetBytes(body, "model", mappedModel)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) defaultTransformMultipartRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, contentType string) ([]byte, error) {
|
||||
if apiName != ApiNameImageEdit && apiName != ApiNameImageVariation {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
model, err := extractMultipartModel(body, contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
|
||||
mappedModel := getMappedModel(model, c.modelMapping)
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
|
||||
|
||||
if mappedModel == model || (mappedModel == "" && model == "") {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
return rewriteMultipartFormModel(body, contentType, mappedModel)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {
|
||||
if c.protocol == protocolOriginal {
|
||||
ctx.DontReadResponseBody()
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -80,6 +84,22 @@ var azureDomainOnlyConfig = func() json.RawMessage {
|
||||
return data
|
||||
}()
|
||||
|
||||
var azureDomainOnlyImageMultipartConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "azure",
|
||||
"apiTokens": []string{
|
||||
"sk-azure-image-multipart",
|
||||
},
|
||||
"azureServiceUrl": "https://domain-resource.openai.azure.com?api-version=2024-02-15-preview",
|
||||
"modelMapping": map[string]string{
|
||||
"gpt-image-1.5": "gpt-image-1",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Azure OpenAI多模型配置
|
||||
var azureMultiModelConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
@@ -99,6 +119,74 @@ var azureMultiModelConfig = func() json.RawMessage {
|
||||
return data
|
||||
}()
|
||||
|
||||
func getMultipartTextField(body []byte, contentType string, fieldName string) (string, bool, error) {
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
boundary := params["boundary"]
|
||||
if boundary == "" {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
reader := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err == io.EOF {
|
||||
return "", false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(part)
|
||||
_ = part.Close()
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
if part.FormName() == fieldName {
|
||||
return string(data), true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RunAzureMultipartHelperTests(t *testing.T) {
|
||||
t.Run("multipart text field returns error for invalid content type", func(t *testing.T) {
|
||||
_, _, err := getMultipartTextField([]byte("bad-body"), "multipart/form-data; boundary=\"", "model")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("multipart text field returns not found for missing boundary", func(t *testing.T) {
|
||||
value, found, err := getMultipartTextField([]byte("bad-body"), "multipart/form-data", "model")
|
||||
require.NoError(t, err)
|
||||
require.False(t, found)
|
||||
require.Equal(t, "", value)
|
||||
})
|
||||
|
||||
t.Run("multipart text field returns not found on eof", func(t *testing.T) {
|
||||
body, contentType := buildMultipartRequestBody(t, map[string]string{
|
||||
"model": "gpt-image-1.5",
|
||||
}, nil)
|
||||
|
||||
value, found, err := getMultipartTextField(body, contentType, "prompt")
|
||||
require.NoError(t, err)
|
||||
require.False(t, found)
|
||||
require.Equal(t, "", value)
|
||||
})
|
||||
|
||||
t.Run("multipart text field returns next part error on malformed body", func(t *testing.T) {
|
||||
body := []byte("--abc\r\nnot-a-header\r\n\r\nvalue\r\n--abc--\r\n")
|
||||
_, _, err := getMultipartTextField(body, "multipart/form-data; boundary=abc", "model")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("multipart text field returns read error on truncated part", func(t *testing.T) {
|
||||
body := []byte("--abc\r\nContent-Disposition: form-data; name=\"model\"\r\n\r\nvalue\r\n--ab")
|
||||
_, _, err := getMultipartTextField(body, "multipart/form-data; boundary=abc", "model")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// 测试配置:Azure OpenAI无效配置(缺少azureServiceUrl)
|
||||
var azureInvalidConfigMissingUrl = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
@@ -615,6 +703,121 @@ func RunAzureOnHttpRequestBodyTests(t *testing.T) {
|
||||
require.Equal(t, pathValue, "/openai/deployments/gpt-3.5-turbo/chat/completions?api-version=2024-02-15-preview", "Path should use model from request body")
|
||||
})
|
||||
|
||||
t.Run("azure domain only multipart image edit request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureDomainOnlyImageMultipartConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
body, contentType := buildMultipartRequestBody(t, map[string]string{
|
||||
"model": "gpt-image-1.5",
|
||||
"prompt": "把小狗换成白色",
|
||||
"size": "1024x1024",
|
||||
"n": "1",
|
||||
}, map[string][]byte{
|
||||
"image[]": []byte("fake-image-content"),
|
||||
})
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/edits"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", contentType},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
action = host.CallOnHttpRequestBody(body)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
transformedBody := host.GetRequestBody()
|
||||
require.NotNil(t, transformedBody)
|
||||
|
||||
modelValue, found, err := getMultipartTextField(transformedBody, contentType, "model")
|
||||
require.NoError(t, err)
|
||||
require.True(t, found, "Model field should exist in multipart body")
|
||||
require.Equal(t, "gpt-image-1", modelValue, "Model field should be mapped in multipart body")
|
||||
require.Contains(t, string(transformedBody), "fake-image-content", "Image file content should remain in multipart body")
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.Equal(t, "/openai/deployments/gpt-image-1/images/edits?api-version=2024-02-15-preview", pathValue, "Path should use mapped multipart model")
|
||||
|
||||
contentTypeValue, hasContentType := test.GetHeaderValue(requestHeaders, "Content-Type")
|
||||
require.True(t, hasContentType, "Content-Type header should exist")
|
||||
require.Equal(t, contentType, contentTypeValue, "Multipart Content-Type should remain unchanged")
|
||||
})
|
||||
|
||||
t.Run("azure domain only multipart image variation request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureDomainOnlyImageMultipartConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
body, contentType := buildMultipartRequestBody(t, map[string]string{
|
||||
"model": "gpt-image-1.5",
|
||||
"prompt": "生成类似风格",
|
||||
"size": "1024x1024",
|
||||
"n": "1",
|
||||
}, map[string][]byte{
|
||||
"image": []byte("fake-image-content"),
|
||||
})
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/variations"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", contentType},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
action = host.CallOnHttpRequestBody(body)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
transformedBody := host.GetRequestBody()
|
||||
require.NotNil(t, transformedBody)
|
||||
|
||||
modelValue, found, err := getMultipartTextField(transformedBody, contentType, "model")
|
||||
require.NoError(t, err)
|
||||
require.True(t, found, "Model field should exist in multipart body")
|
||||
require.Equal(t, "gpt-image-1", modelValue, "Model field should be mapped in multipart body")
|
||||
require.Contains(t, string(transformedBody), "fake-image-content", "Image file content should remain in multipart body")
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.Equal(t, "/openai/deployments/gpt-image-1/images/variations?api-version=2024-02-15-preview", pathValue, "Path should use mapped multipart model")
|
||||
|
||||
contentTypeValue, hasContentType := test.GetHeaderValue(requestHeaders, "Content-Type")
|
||||
require.True(t, hasContentType, "Content-Type header should exist")
|
||||
require.Equal(t, contentType, contentTypeValue, "Multipart Content-Type should remain unchanged")
|
||||
})
|
||||
|
||||
t.Run("azure domain only multipart malformed body logs transform failure", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureDomainOnlyImageMultipartConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/edits"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "multipart/form-data"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte("bad-multipart-body"))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasMultipartTransformFailureLog := false
|
||||
for _, debugLog := range debugLogs {
|
||||
if strings.Contains(debugLog, "[azure multipart] body transform failed") {
|
||||
hasMultipartTransformFailureLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasMultipartTransformFailureLog, "Should log azure multipart transform failure")
|
||||
})
|
||||
|
||||
// 测试Azure OpenAI模型无关请求处理(仅域名配置)
|
||||
t.Run("azure domain only model independent", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureDomainOnlyConfig)
|
||||
|
||||
Reference in New Issue
Block a user