feat(provider): 优化 Azure multipart 处理 || feat(provider): Optimize Azure multipart processing (#3651)

This commit is contained in:
woody
2026-03-30 13:45:41 +08:00
committed by GitHub
parent 83461887dc
commit 889ea67013
6 changed files with 756 additions and 4 deletions

View File

@@ -158,6 +158,7 @@ func TestGemini(t *testing.T) {
func TestAzure(t *testing.T) {
test.RunAzureParseConfigTests(t)
test.RunAzureMultipartHelperTests(t)
test.RunAzureOnHttpRequestHeadersTests(t)
test.RunAzureOnHttpRequestBodyTests(t)
test.RunAzureOnHttpResponseHeadersTests(t)

View File

@@ -9,6 +9,7 @@ import (
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
@@ -151,17 +152,44 @@ func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func isAzureMultipartImageRequest(apiName ApiName, contentType string) bool {
if apiName != ApiNameImageEdit && apiName != ApiNameImageVariation {
return false
}
return isMultipartFormData(contentType)
}
func (m *azureProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (transformedBody []byte, err error) {
transformedBody = body
err = nil
contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType)
isMultipartImageRequest := isAzureMultipartImageRequest(apiName, contentType)
transformedBody, err = m.config.defaultTransformRequestBody(ctx, apiName, body)
if isMultipartImageRequest {
if err != nil {
log.Debugf("[azure multipart] body transform failed: api=%s, err=%v", apiName, err)
} else {
log.Debugf("[azure multipart] body transformed: api=%s, originalModel=%s, mappedModel=%s, bodyBytes=%d->%d",
apiName,
ctx.GetStringContext(ctxKeyOriginalRequestModel, ""),
ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
len(body),
len(transformedBody),
)
}
}
if err != nil {
return
}
// This must be called after the body is transformed, because it uses the model from the context filled by that call.
if path := m.transformRequestPath(ctx, apiName); path != "" {
if isMultipartImageRequest {
log.Debugf("[azure multipart] body path overwrite: api=%s, path=%s, modelInContext=%s",
apiName, path, ctx.GetStringContext(ctxKeyFinalRequestModel, ""))
}
err = util.OverwriteRequestPath(path)
if err == nil {
log.Debugf("azureProvider: overwrite request path to %s succeeded", path)
@@ -222,16 +250,30 @@ func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName Ap
}
func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
contentType := headers.Get(util.HeaderContentType)
isMultipartImageRequest := isAzureMultipartImageRequest(apiName, contentType)
// We need to overwrite the request path in the request headers stage,
// because for some APIs, we don't read the request body and the path is model irrelevant.
if overwrittenPath := m.transformRequestPath(ctx, apiName); overwrittenPath != "" {
util.OverwriteRequestPathHeader(headers, overwrittenPath)
if isMultipartImageRequest {
log.Debugf("[azure multipart] header path overwrite: api=%s, path=%s, modelInContext=%s",
apiName, overwrittenPath, ctx.GetStringContext(ctxKeyFinalRequestModel, ""))
}
}
util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
headers.Set("api-key", m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
if !m.config.isSupportedAPI(apiName) || !m.config.needToProcessRequestBody(apiName) {
supportedAPI := m.config.isSupportedAPI(apiName)
needProcessBody := m.config.needToProcessRequestBody(apiName)
if isMultipartImageRequest {
log.Debugf("[azure multipart] body processing decision: api=%s, supported=%t, needProcessBody=%t",
apiName, supportedAPI, needProcessBody)
}
if !supportedAPI || !needProcessBody {
// If the API is not supported or there is no need to process the body,
// we should not read the request body and keep it as it is.
ctx.DontReadRequestBody()

View File

@@ -8,10 +8,15 @@ import (
"mime"
"mime/multipart"
"net/http"
"net/textproto"
"strconv"
"strings"
)
var newMultipartWriter = func(w io.Writer) *multipart.Writer {
return multipart.NewWriter(w)
}
type multipartImageRequest struct {
Model string
Prompt string
@@ -30,14 +35,22 @@ func isMultipartFormData(contentType string) bool {
return strings.EqualFold(mediaType, "multipart/form-data")
}
func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) {
func parseMultipartBoundary(contentType string) (string, error) {
_, params, err := mime.ParseMediaType(contentType)
if err != nil {
return nil, fmt.Errorf("unable to parse content-type: %v", err)
return "", fmt.Errorf("unable to parse content-type: %v", err)
}
boundary := params["boundary"]
if boundary == "" {
return nil, fmt.Errorf("missing multipart boundary")
return "", fmt.Errorf("missing multipart boundary")
}
return boundary, nil
}
func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) {
boundary, err := parseMultipartBoundary(contentType)
if err != nil {
return nil, err
}
req := &multipartImageRequest{
@@ -111,6 +124,110 @@ func parseMultipartImageRequest(body []byte, contentType string) (*multipartImag
return req, nil
}
func extractMultipartModel(body []byte, contentType string) (string, error) {
boundary, err := parseMultipartBoundary(contentType)
if err != nil {
return "", err
}
reader := multipart.NewReader(bytes.NewReader(body), boundary)
model := ""
for {
part, err := reader.NextPart()
if err == io.EOF {
break
}
if err != nil {
return "", fmt.Errorf("unable to read multipart part: %v", err)
}
fieldName := part.FormName()
var readErr error
if fieldName == "model" {
var partData []byte
partData, readErr = io.ReadAll(part)
if readErr == nil {
model = strings.TrimSpace(string(partData))
}
} else {
_, readErr = io.Copy(io.Discard, part)
}
_ = part.Close()
if readErr != nil {
return "", fmt.Errorf("unable to read multipart field %s: %v", fieldName, readErr)
}
}
return model, nil
}
func rewriteMultipartFormModel(body []byte, contentType string, model string) ([]byte, error) {
boundary, err := parseMultipartBoundary(contentType)
if err != nil {
return nil, err
}
var buffer bytes.Buffer
writer := newMultipartWriter(&buffer)
if err := writer.SetBoundary(boundary); err != nil {
return nil, fmt.Errorf("unable to set multipart boundary: %v", err)
}
reader := multipart.NewReader(bytes.NewReader(body), boundary)
modelFound := false
for {
part, err := reader.NextPart()
if err == io.EOF {
break
}
if err != nil {
return nil, fmt.Errorf("unable to read multipart part: %v", err)
}
fieldName := part.FormName()
newPart, err := writer.CreatePart(cloneMultipartPartHeader(part.Header))
if err != nil {
_ = part.Close()
return nil, fmt.Errorf("unable to create multipart field %s: %v", fieldName, err)
}
var copyErr error
if fieldName == "model" {
modelFound = true
if _, copyErr = io.WriteString(newPart, model); copyErr == nil {
_, copyErr = io.Copy(io.Discard, part)
}
} else {
_, copyErr = io.Copy(newPart, part)
}
_ = part.Close()
if copyErr != nil {
return nil, fmt.Errorf("unable to write multipart field %s: %v", fieldName, copyErr)
}
}
if !modelFound && model != "" {
if err := writer.WriteField("model", model); err != nil {
return nil, fmt.Errorf("unable to append multipart model field: %v", err)
}
}
if err := writer.Close(); err != nil {
return nil, fmt.Errorf("unable to finalize multipart body: %v", err)
}
return buffer.Bytes(), nil
}
func cloneMultipartPartHeader(header textproto.MIMEHeader) textproto.MIMEHeader {
cloned := make(textproto.MIMEHeader, len(header))
for key, values := range header {
copied := make([]string, len(values))
copy(copied, values)
cloned[key] = copied
}
return cloned
}
func isMultipartImageField(fieldName string) bool {
return fieldName == "image" || fieldName == "image[]" || strings.HasPrefix(fieldName, "image[")
}

View File

@@ -0,0 +1,363 @@
package provider
import (
"bytes"
"errors"
"io"
"mime/multipart"
"strings"
"testing"
"github.com/higress-group/wasm-go/pkg/iface"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockMultipartHttpContext struct {
contextMap map[string]interface{}
}
func newMockMultipartHttpContext() *mockMultipartHttpContext {
return &mockMultipartHttpContext{contextMap: make(map[string]interface{})}
}
func (m *mockMultipartHttpContext) SetContext(key string, value interface{}) {
m.contextMap[key] = value
}
func (m *mockMultipartHttpContext) GetContext(key string) interface{} { return m.contextMap[key] }
func (m *mockMultipartHttpContext) GetBoolContext(key string, def bool) bool { return def }
func (m *mockMultipartHttpContext) GetStringContext(key, def string) string { return def }
func (m *mockMultipartHttpContext) GetByteSliceContext(key string, def []byte) []byte { return def }
func (m *mockMultipartHttpContext) Scheme() string { return "" }
func (m *mockMultipartHttpContext) Host() string { return "" }
func (m *mockMultipartHttpContext) Path() string { return "" }
func (m *mockMultipartHttpContext) Method() string { return "" }
func (m *mockMultipartHttpContext) GetUserAttribute(key string) interface{} { return nil }
func (m *mockMultipartHttpContext) SetUserAttribute(key string, value interface{}) {}
func (m *mockMultipartHttpContext) SetUserAttributeMap(kvmap map[string]interface{}) {}
func (m *mockMultipartHttpContext) GetUserAttributeMap() map[string]interface{} { return nil }
func (m *mockMultipartHttpContext) WriteUserAttributeToLog() error { return nil }
func (m *mockMultipartHttpContext) WriteUserAttributeToLogWithKey(key string) error { return nil }
func (m *mockMultipartHttpContext) WriteUserAttributeToTrace() error { return nil }
func (m *mockMultipartHttpContext) DontReadRequestBody() {}
func (m *mockMultipartHttpContext) DontReadResponseBody() {}
func (m *mockMultipartHttpContext) BufferRequestBody() {}
func (m *mockMultipartHttpContext) BufferResponseBody() {}
func (m *mockMultipartHttpContext) NeedPauseStreamingResponse() {}
func (m *mockMultipartHttpContext) PushBuffer(buffer []byte) {}
func (m *mockMultipartHttpContext) PopBuffer() []byte { return nil }
func (m *mockMultipartHttpContext) BufferQueueSize() int { return 0 }
func (m *mockMultipartHttpContext) DisableReroute() {}
func (m *mockMultipartHttpContext) SetRequestBodyBufferLimit(byteSize uint32) {}
func (m *mockMultipartHttpContext) SetResponseBodyBufferLimit(byteSize uint32) {}
func (m *mockMultipartHttpContext) RouteCall(method, url string, headers [][2]string, body []byte, callback iface.RouteResponseCallback) error {
return nil
}
func (m *mockMultipartHttpContext) GetExecutionPhase() iface.HTTPExecutionPhase { return 0 }
func (m *mockMultipartHttpContext) HasRequestBody() bool { return false }
func (m *mockMultipartHttpContext) HasResponseBody() bool { return false }
func (m *mockMultipartHttpContext) IsWebsocket() bool { return false }
func (m *mockMultipartHttpContext) IsBinaryRequestBody() bool { return false }
func (m *mockMultipartHttpContext) IsBinaryResponseBody() bool { return false }
func buildProviderMultipartRequestBody(t *testing.T, fields map[string]string, files map[string][]byte) ([]byte, string) {
t.Helper()
var buffer bytes.Buffer
writer := multipart.NewWriter(&buffer)
for key, value := range fields {
require.NoError(t, writer.WriteField(key, value))
}
for fieldName, data := range files {
part, err := writer.CreateFormFile(fieldName, "upload-image.png")
require.NoError(t, err)
_, err = part.Write(data)
require.NoError(t, err)
}
require.NoError(t, writer.Close())
return buffer.Bytes(), writer.FormDataContentType()
}
type failAfterNWriteWriter struct {
target io.Writer
failAtCall int
writeCalls int
}
func (w *failAfterNWriteWriter) Write(p []byte) (int, error) {
w.writeCalls++
if w.writeCalls >= w.failAtCall {
return 0, errors.New("injected write failure")
}
return w.target.Write(p)
}
func withInjectedMultipartWriterFactory(t *testing.T, failAtCall int, testFunc func()) {
t.Helper()
originalFactory := newMultipartWriter
newMultipartWriter = func(target io.Writer) *multipart.Writer {
return multipart.NewWriter(&failAfterNWriteWriter{
target: target,
failAtCall: failAtCall,
})
}
defer func() {
newMultipartWriter = originalFactory
}()
testFunc()
}
func TestRewriteMultipartFormModel(t *testing.T) {
t.Run("rewrites existing model field", func(t *testing.T) {
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
"model": "gpt-image-1.5",
"prompt": "Turn the dog white",
}, map[string][]byte{
"image[]": []byte("fake-image-content"),
})
transformed, err := rewriteMultipartFormModel(body, contentType, "gpt-image-1")
require.NoError(t, err)
req, err := parseMultipartImageRequest(transformed, contentType)
require.NoError(t, err)
assert.Equal(t, "gpt-image-1", req.Model)
assert.Equal(t, "Turn the dog white", req.Prompt)
assert.Len(t, req.ImageURLs, 1)
assert.Contains(t, string(transformed), "fake-image-content")
})
t.Run("appends model field when missing", func(t *testing.T) {
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
"prompt": "Turn the dog white",
}, map[string][]byte{
"image": []byte("fake-image-content"),
})
transformed, err := rewriteMultipartFormModel(body, contentType, "gpt-image-1")
require.NoError(t, err)
req, err := parseMultipartImageRequest(transformed, contentType)
require.NoError(t, err)
assert.Equal(t, "gpt-image-1", req.Model)
assert.Equal(t, "Turn the dog white", req.Prompt)
assert.Len(t, req.ImageURLs, 1)
})
t.Run("returns error on invalid content type", func(t *testing.T) {
_, err := rewriteMultipartFormModel([]byte("not-multipart"), "multipart/form-data", "gpt-image-1")
require.Error(t, err)
assert.Contains(t, err.Error(), "missing multipart boundary")
})
t.Run("returns error when boundary cannot be set", func(t *testing.T) {
longBoundary := strings.Repeat("a", 71)
_, err := rewriteMultipartFormModel([]byte(""), "multipart/form-data; boundary="+longBoundary, "gpt-image-1")
require.Error(t, err)
assert.Contains(t, err.Error(), "unable to set multipart boundary")
})
t.Run("returns error on malformed multipart header", func(t *testing.T) {
body := []byte("--abc\r\nnot-a-header\r\n\r\nvalue\r\n--abc--\r\n")
_, err := rewriteMultipartFormModel(body, "multipart/form-data; boundary=abc", "gpt-image-1")
require.Error(t, err)
assert.Contains(t, err.Error(), "unable to read multipart part")
})
t.Run("returns error when multipart part copy fails", func(t *testing.T) {
body := []byte("--abc\r\nContent-Disposition: form-data; name=\"image\"; filename=\"a.png\"\r\nContent-Type: image/png\r\n\r\nabc\r\n--ab")
_, err := rewriteMultipartFormModel(body, "multipart/form-data; boundary=abc", "gpt-image-1")
require.Error(t, err)
assert.Contains(t, err.Error(), "unable to write multipart field image")
})
t.Run("returns error when creating rewritten multipart part fails", func(t *testing.T) {
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
"prompt": "Turn the dog white",
}, map[string][]byte{
"image": []byte("fake-image-content"),
})
withInjectedMultipartWriterFactory(t, 1, func() {
_, err := rewriteMultipartFormModel(body, contentType, "gpt-image-1")
require.Error(t, err)
assert.Contains(t, err.Error(), "unable to create multipart field")
})
})
t.Run("returns error when appending model field fails", func(t *testing.T) {
withInjectedMultipartWriterFactory(t, 1, func() {
_, err := rewriteMultipartFormModel([]byte("--abc--\r\n"), "multipart/form-data; boundary=abc", "gpt-image-1")
require.Error(t, err)
assert.Contains(t, err.Error(), "unable to append multipart model field")
})
})
t.Run("returns error when finalizing multipart body fails", func(t *testing.T) {
withInjectedMultipartWriterFactory(t, 1, func() {
_, err := rewriteMultipartFormModel([]byte("--abc--\r\n"), "multipart/form-data; boundary=abc", "")
require.Error(t, err)
assert.Contains(t, err.Error(), "unable to finalize multipart body")
})
})
}
func TestDefaultTransformMultipartRequestBody(t *testing.T) {
t.Run("maps multipart model and keeps body valid", func(t *testing.T) {
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
"model": "gpt-image-1.5",
"prompt": "Turn the dog white",
"size": "1024x1024",
}, map[string][]byte{
"image[]": []byte("fake-image-content"),
})
config := &ProviderConfig{
modelMapping: map[string]string{
"gpt-image-1.5": "gpt-image-1",
},
}
ctx := newMockMultipartHttpContext()
transformed, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameImageEdit, body, contentType)
require.NoError(t, err)
req, err := parseMultipartImageRequest(transformed, contentType)
require.NoError(t, err)
assert.Equal(t, "gpt-image-1.5", ctx.GetContext(ctxKeyOriginalRequestModel))
assert.Equal(t, "gpt-image-1", ctx.GetContext(ctxKeyFinalRequestModel))
assert.Equal(t, "gpt-image-1", req.Model)
assert.Equal(t, "Turn the dog white", req.Prompt)
assert.Len(t, req.ImageURLs, 1)
assert.Contains(t, string(transformed), "fake-image-content")
})
t.Run("appends mapped model when multipart request omits model", func(t *testing.T) {
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
"prompt": "Turn the dog white",
}, map[string][]byte{
"image": []byte("fake-image-content"),
})
config := &ProviderConfig{
modelMapping: map[string]string{
"*": "gpt-image-1",
},
}
ctx := newMockMultipartHttpContext()
transformed, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameImageVariation, body, contentType)
require.NoError(t, err)
req, err := parseMultipartImageRequest(transformed, contentType)
require.NoError(t, err)
assert.Equal(t, "", ctx.GetContext(ctxKeyOriginalRequestModel))
assert.Equal(t, "gpt-image-1", ctx.GetContext(ctxKeyFinalRequestModel))
assert.Equal(t, "gpt-image-1", req.Model)
})
t.Run("returns original body when multipart model is unchanged", func(t *testing.T) {
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
"model": "gpt-image-1",
"prompt": "Turn the dog white",
}, map[string][]byte{
"image": []byte("fake-image-content"),
})
config := &ProviderConfig{}
ctx := newMockMultipartHttpContext()
transformed, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameImageEdit, body, contentType)
require.NoError(t, err)
assert.Equal(t, body, transformed)
assert.Equal(t, "gpt-image-1", ctx.GetContext(ctxKeyOriginalRequestModel))
assert.Equal(t, "gpt-image-1", ctx.GetContext(ctxKeyFinalRequestModel))
})
t.Run("ignores non image multipart apis", func(t *testing.T) {
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
"model": "gpt-image-1",
}, nil)
config := &ProviderConfig{
modelMapping: map[string]string{
"gpt-image-1": "mapped-model",
},
}
ctx := newMockMultipartHttpContext()
transformed, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameChatCompletion, body, contentType)
require.NoError(t, err)
assert.Equal(t, body, transformed)
assert.Nil(t, ctx.GetContext(ctxKeyOriginalRequestModel))
assert.Nil(t, ctx.GetContext(ctxKeyFinalRequestModel))
})
t.Run("surfaces multipart parse errors", func(t *testing.T) {
config := &ProviderConfig{}
ctx := newMockMultipartHttpContext()
_, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameImageEdit, []byte("bad-body"), "multipart/form-data")
require.Error(t, err)
assert.Contains(t, err.Error(), "missing multipart boundary")
})
}
func TestExtractMultipartModel(t *testing.T) {
t.Run("extracts model value", func(t *testing.T) {
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
"model": "gpt-image-1.5",
"prompt": "Turn the dog white",
}, nil)
model, err := extractMultipartModel(body, contentType)
require.NoError(t, err)
assert.Equal(t, "gpt-image-1.5", model)
})
t.Run("returns empty model when field missing", func(t *testing.T) {
body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
"prompt": "Turn the dog white",
}, nil)
model, err := extractMultipartModel(body, contentType)
require.NoError(t, err)
assert.Equal(t, "", model)
})
t.Run("returns parse error for invalid content type", func(t *testing.T) {
_, err := extractMultipartModel([]byte("bad-body"), "multipart/form-data; boundary=\"")
require.Error(t, err)
assert.Contains(t, err.Error(), "unable to parse content-type")
})
t.Run("returns parse error for malformed multipart header", func(t *testing.T) {
body := []byte("--abc\r\nnot-a-header\r\n\r\nvalue\r\n--abc--\r\n")
_, err := extractMultipartModel(body, "multipart/form-data; boundary=abc")
require.Error(t, err)
assert.Contains(t, err.Error(), "unable to read multipart part")
})
t.Run("returns field read error on truncated model part", func(t *testing.T) {
body := []byte("--abc\r\nContent-Disposition: form-data; name=\"model\"\r\n\r\nvalue\r\n--ab")
_, err := extractMultipartModel(body, "multipart/form-data; boundary=abc")
require.Error(t, err)
assert.Contains(t, err.Error(), "unable to read multipart field model")
})
}
func TestParseMultipartImageRequestContentTypeError(t *testing.T) {
_, err := parseMultipartImageRequest([]byte("bad-body"), "multipart/form-data; boundary=\"")
require.Error(t, err)
assert.Contains(t, err.Error(), "unable to parse content-type")
}
func TestIsMultipartFormData(t *testing.T) {
assert.True(t, isMultipartFormData("multipart/form-data; boundary=abc"))
assert.False(t, isMultipartFormData("application/json"))
assert.False(t, isMultipartFormData("multipart/form-data; boundary=\""))
}

View File

@@ -1264,6 +1264,10 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt
// defaultTransformRequestBody 默认的请求体转换方法只做模型映射用slog替换模型名称不用序列化和反序列化提高性能
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if contentType, err := proxywasm.GetHttpRequestHeader(util.HeaderContentType); err == nil && isMultipartFormData(contentType) {
return c.defaultTransformMultipartRequestBody(ctx, apiName, body, contentType)
}
switch apiName {
case ApiNameChatCompletion,
ApiNameVideos,
@@ -1283,6 +1287,28 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap
return sjson.SetBytes(body, "model", mappedModel)
}
func (c *ProviderConfig) defaultTransformMultipartRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, contentType string) ([]byte, error) {
if apiName != ApiNameImageEdit && apiName != ApiNameImageVariation {
return body, nil
}
model, err := extractMultipartModel(body, contentType)
if err != nil {
return nil, err
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, c.modelMapping)
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
if mappedModel == model || (mappedModel == "" && model == "") {
return body, nil
}
return rewriteMultipartFormModel(body, contentType, mappedModel)
}
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {
if c.protocol == protocolOriginal {
ctx.DontReadResponseBody()

View File

@@ -1,7 +1,11 @@
package test
import (
"bytes"
"encoding/json"
"io"
"mime"
"mime/multipart"
"strings"
"testing"
@@ -80,6 +84,22 @@ var azureDomainOnlyConfig = func() json.RawMessage {
return data
}()
var azureDomainOnlyImageMultipartConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "azure",
"apiTokens": []string{
"sk-azure-image-multipart",
},
"azureServiceUrl": "https://domain-resource.openai.azure.com?api-version=2024-02-15-preview",
"modelMapping": map[string]string{
"gpt-image-1.5": "gpt-image-1",
},
},
})
return data
}()
// 测试配置Azure OpenAI多模型配置
var azureMultiModelConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
@@ -99,6 +119,74 @@ var azureMultiModelConfig = func() json.RawMessage {
return data
}()
func getMultipartTextField(body []byte, contentType string, fieldName string) (string, bool, error) {
_, params, err := mime.ParseMediaType(contentType)
if err != nil {
return "", false, err
}
boundary := params["boundary"]
if boundary == "" {
return "", false, nil
}
reader := multipart.NewReader(bytes.NewReader(body), boundary)
for {
part, err := reader.NextPart()
if err == io.EOF {
return "", false, nil
}
if err != nil {
return "", false, err
}
data, err := io.ReadAll(part)
_ = part.Close()
if err != nil {
return "", false, err
}
if part.FormName() == fieldName {
return string(data), true, nil
}
}
}
func RunAzureMultipartHelperTests(t *testing.T) {
t.Run("multipart text field returns error for invalid content type", func(t *testing.T) {
_, _, err := getMultipartTextField([]byte("bad-body"), "multipart/form-data; boundary=\"", "model")
require.Error(t, err)
})
t.Run("multipart text field returns not found for missing boundary", func(t *testing.T) {
value, found, err := getMultipartTextField([]byte("bad-body"), "multipart/form-data", "model")
require.NoError(t, err)
require.False(t, found)
require.Equal(t, "", value)
})
t.Run("multipart text field returns not found on eof", func(t *testing.T) {
body, contentType := buildMultipartRequestBody(t, map[string]string{
"model": "gpt-image-1.5",
}, nil)
value, found, err := getMultipartTextField(body, contentType, "prompt")
require.NoError(t, err)
require.False(t, found)
require.Equal(t, "", value)
})
t.Run("multipart text field returns next part error on malformed body", func(t *testing.T) {
body := []byte("--abc\r\nnot-a-header\r\n\r\nvalue\r\n--abc--\r\n")
_, _, err := getMultipartTextField(body, "multipart/form-data; boundary=abc", "model")
require.Error(t, err)
})
t.Run("multipart text field returns read error on truncated part", func(t *testing.T) {
body := []byte("--abc\r\nContent-Disposition: form-data; name=\"model\"\r\n\r\nvalue\r\n--ab")
_, _, err := getMultipartTextField(body, "multipart/form-data; boundary=abc", "model")
require.Error(t, err)
})
}
// 测试配置Azure OpenAI无效配置缺少azureServiceUrl
var azureInvalidConfigMissingUrl = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
@@ -615,6 +703,121 @@ func RunAzureOnHttpRequestBodyTests(t *testing.T) {
require.Equal(t, pathValue, "/openai/deployments/gpt-3.5-turbo/chat/completions?api-version=2024-02-15-preview", "Path should use model from request body")
})
t.Run("azure domain only multipart image edit request body", func(t *testing.T) {
host, status := test.NewTestHost(azureDomainOnlyImageMultipartConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
body, contentType := buildMultipartRequestBody(t, map[string]string{
"model": "gpt-image-1.5",
"prompt": "把小狗换成白色",
"size": "1024x1024",
"n": "1",
}, map[string][]byte{
"image[]": []byte("fake-image-content"),
})
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/edits"},
{":method", "POST"},
{"Content-Type", contentType},
})
require.Equal(t, types.HeaderStopIteration, action)
action = host.CallOnHttpRequestBody(body)
require.Equal(t, types.ActionContinue, action)
transformedBody := host.GetRequestBody()
require.NotNil(t, transformedBody)
modelValue, found, err := getMultipartTextField(transformedBody, contentType, "model")
require.NoError(t, err)
require.True(t, found, "Model field should exist in multipart body")
require.Equal(t, "gpt-image-1", modelValue, "Model field should be mapped in multipart body")
require.Contains(t, string(transformedBody), "fake-image-content", "Image file content should remain in multipart body")
requestHeaders := host.GetRequestHeaders()
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath, "Path header should exist")
require.Equal(t, "/openai/deployments/gpt-image-1/images/edits?api-version=2024-02-15-preview", pathValue, "Path should use mapped multipart model")
contentTypeValue, hasContentType := test.GetHeaderValue(requestHeaders, "Content-Type")
require.True(t, hasContentType, "Content-Type header should exist")
require.Equal(t, contentType, contentTypeValue, "Multipart Content-Type should remain unchanged")
})
t.Run("azure domain only multipart image variation request body", func(t *testing.T) {
host, status := test.NewTestHost(azureDomainOnlyImageMultipartConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
body, contentType := buildMultipartRequestBody(t, map[string]string{
"model": "gpt-image-1.5",
"prompt": "生成类似风格",
"size": "1024x1024",
"n": "1",
}, map[string][]byte{
"image": []byte("fake-image-content"),
})
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/variations"},
{":method", "POST"},
{"Content-Type", contentType},
})
require.Equal(t, types.HeaderStopIteration, action)
action = host.CallOnHttpRequestBody(body)
require.Equal(t, types.ActionContinue, action)
transformedBody := host.GetRequestBody()
require.NotNil(t, transformedBody)
modelValue, found, err := getMultipartTextField(transformedBody, contentType, "model")
require.NoError(t, err)
require.True(t, found, "Model field should exist in multipart body")
require.Equal(t, "gpt-image-1", modelValue, "Model field should be mapped in multipart body")
require.Contains(t, string(transformedBody), "fake-image-content", "Image file content should remain in multipart body")
requestHeaders := host.GetRequestHeaders()
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath, "Path header should exist")
require.Equal(t, "/openai/deployments/gpt-image-1/images/variations?api-version=2024-02-15-preview", pathValue, "Path should use mapped multipart model")
contentTypeValue, hasContentType := test.GetHeaderValue(requestHeaders, "Content-Type")
require.True(t, hasContentType, "Content-Type header should exist")
require.Equal(t, contentType, contentTypeValue, "Multipart Content-Type should remain unchanged")
})
t.Run("azure domain only multipart malformed body logs transform failure", func(t *testing.T) {
host, status := test.NewTestHost(azureDomainOnlyImageMultipartConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/edits"},
{":method", "POST"},
{"Content-Type", "multipart/form-data"},
})
require.Equal(t, types.HeaderStopIteration, action)
action = host.CallOnHttpRequestBody([]byte("bad-multipart-body"))
require.Equal(t, types.ActionContinue, action)
debugLogs := host.GetDebugLogs()
hasMultipartTransformFailureLog := false
for _, debugLog := range debugLogs {
if strings.Contains(debugLog, "[azure multipart] body transform failed") {
hasMultipartTransformFailureLog = true
break
}
}
require.True(t, hasMultipartTransformFailureLog, "Should log azure multipart transform failure")
})
// 测试Azure OpenAI模型无关请求处理仅域名配置
t.Run("azure domain only model independent", func(t *testing.T) {
host, status := test.NewTestHost(azureDomainOnlyConfig)