mirror of
https://github.com/alibaba/higress.git
synced 2026-05-28 22:57:31 +08:00
feat(provider): 优化 Azure multipart 处理 || feat(provider): Optimize Azure multipart processing (#3651)
This commit is contained in:
@@ -8,10 +8,15 @@ import (
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var newMultipartWriter = func(w io.Writer) *multipart.Writer {
|
||||
return multipart.NewWriter(w)
|
||||
}
|
||||
|
||||
type multipartImageRequest struct {
|
||||
Model string
|
||||
Prompt string
|
||||
@@ -30,14 +35,22 @@ func isMultipartFormData(contentType string) bool {
|
||||
return strings.EqualFold(mediaType, "multipart/form-data")
|
||||
}
|
||||
|
||||
func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) {
|
||||
func parseMultipartBoundary(contentType string) (string, error) {
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse content-type: %v", err)
|
||||
return "", fmt.Errorf("unable to parse content-type: %v", err)
|
||||
}
|
||||
boundary := params["boundary"]
|
||||
if boundary == "" {
|
||||
return nil, fmt.Errorf("missing multipart boundary")
|
||||
return "", fmt.Errorf("missing multipart boundary")
|
||||
}
|
||||
return boundary, nil
|
||||
}
|
||||
|
||||
func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) {
|
||||
boundary, err := parseMultipartBoundary(contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := &multipartImageRequest{
|
||||
@@ -111,6 +124,110 @@ func parseMultipartImageRequest(body []byte, contentType string) (*multipartImag
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func extractMultipartModel(body []byte, contentType string) (string, error) {
|
||||
boundary, err := parseMultipartBoundary(contentType)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
reader := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
model := ""
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to read multipart part: %v", err)
|
||||
}
|
||||
|
||||
fieldName := part.FormName()
|
||||
var readErr error
|
||||
if fieldName == "model" {
|
||||
var partData []byte
|
||||
partData, readErr = io.ReadAll(part)
|
||||
if readErr == nil {
|
||||
model = strings.TrimSpace(string(partData))
|
||||
}
|
||||
} else {
|
||||
_, readErr = io.Copy(io.Discard, part)
|
||||
}
|
||||
_ = part.Close()
|
||||
if readErr != nil {
|
||||
return "", fmt.Errorf("unable to read multipart field %s: %v", fieldName, readErr)
|
||||
}
|
||||
}
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func rewriteMultipartFormModel(body []byte, contentType string, model string) ([]byte, error) {
|
||||
boundary, err := parseMultipartBoundary(contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var buffer bytes.Buffer
|
||||
writer := newMultipartWriter(&buffer)
|
||||
if err := writer.SetBoundary(boundary); err != nil {
|
||||
return nil, fmt.Errorf("unable to set multipart boundary: %v", err)
|
||||
}
|
||||
|
||||
reader := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
modelFound := false
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read multipart part: %v", err)
|
||||
}
|
||||
|
||||
fieldName := part.FormName()
|
||||
newPart, err := writer.CreatePart(cloneMultipartPartHeader(part.Header))
|
||||
if err != nil {
|
||||
_ = part.Close()
|
||||
return nil, fmt.Errorf("unable to create multipart field %s: %v", fieldName, err)
|
||||
}
|
||||
|
||||
var copyErr error
|
||||
if fieldName == "model" {
|
||||
modelFound = true
|
||||
if _, copyErr = io.WriteString(newPart, model); copyErr == nil {
|
||||
_, copyErr = io.Copy(io.Discard, part)
|
||||
}
|
||||
} else {
|
||||
_, copyErr = io.Copy(newPart, part)
|
||||
}
|
||||
_ = part.Close()
|
||||
if copyErr != nil {
|
||||
return nil, fmt.Errorf("unable to write multipart field %s: %v", fieldName, copyErr)
|
||||
}
|
||||
}
|
||||
|
||||
if !modelFound && model != "" {
|
||||
if err := writer.WriteField("model", model); err != nil {
|
||||
return nil, fmt.Errorf("unable to append multipart model field: %v", err)
|
||||
}
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
return nil, fmt.Errorf("unable to finalize multipart body: %v", err)
|
||||
}
|
||||
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
func cloneMultipartPartHeader(header textproto.MIMEHeader) textproto.MIMEHeader {
|
||||
cloned := make(textproto.MIMEHeader, len(header))
|
||||
for key, values := range header {
|
||||
copied := make([]string, len(values))
|
||||
copy(copied, values)
|
||||
cloned[key] = copied
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func isMultipartImageField(fieldName string) bool {
|
||||
return fieldName == "image" || fieldName == "image[]" || strings.HasPrefix(fieldName, "image[")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user