Compare commits

...

8 Commits

15 changed files with 1174 additions and 59 deletions

View File

@@ -1 +1 @@
v2.2.0 v2.2.1

View File

@@ -4,6 +4,6 @@ dependencies:
version: 2.2.0 version: 2.2.0
- name: higress-console - name: higress-console
repository: https://higress.io/helm-charts/ repository: https://higress.io/helm-charts/
version: 2.2.0 version: 2.2.1
digest: sha256:2cb148fa6d52856344e1905d3fea018466c2feb52013e08997c2d5c7d50f2e5d digest: sha256:23fe7b0f84965c13ac7ceabe6334212fc3d323b7b781277a6d2b6fd38e935dda
generated: "2026-02-11T17:45:59.187965929+08:00" generated: "2026-03-07T12:45:44.267732+08:00"

View File

@@ -1,5 +1,5 @@
apiVersion: v2 apiVersion: v2
appVersion: 2.2.0 appVersion: 2.2.1
description: Helm chart for deploying Higress gateways description: Helm chart for deploying Higress gateways
icon: https://higress.io/img/higress_logo_small.png icon: https://higress.io/img/higress_logo_small.png
home: http://higress.io/ home: http://higress.io/
@@ -15,6 +15,6 @@ dependencies:
version: 2.2.0 version: 2.2.0
- name: higress-console - name: higress-console
repository: "https://higress.io/helm-charts/" repository: "https://higress.io/helm-charts/"
version: 2.2.0 version: 2.2.1
type: application type: application
version: 2.2.0 version: 2.2.1

View File

@@ -140,10 +140,16 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct
// Send the initial endpoint event // Send the initial endpoint event
initialEvent := fmt.Sprintf("event: endpoint\ndata: %s\n\n", messageEndpoint) initialEvent := fmt.Sprintf("event: endpoint\ndata: %s\n\n", messageEndpoint)
err = s.redisClient.Publish(channel, initialEvent) go func() {
if err != nil { defer func() {
api.LogErrorf("Failed to send initial event: %v", err) if r := recover(); r != nil {
api.LogErrorf("Failed to send initial event: %v", r)
} }
}()
defer cb.EncoderFilterCallbacks().RecoverPanic()
api.LogDebugf("SSE Send message: %s", initialEvent)
cb.EncoderFilterCallbacks().InjectData([]byte(initialEvent))
}()
// Start health check handler // Start health check handler
go func() { go func() {

View File

@@ -225,9 +225,9 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
} }
} }
if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !strings.Contains(contentType, util.MimeTypeApplicationJson) { if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !isSupportedRequestContentType(apiName, contentType) {
ctx.DontReadRequestBody() ctx.DontReadRequestBody()
log.Debugf("[onHttpRequestHeader] unsupported content type: %s, will not process the request body", contentType) log.Debugf("[onHttpRequestHeader] unsupported content type for api %s: %s, will not process the request body", apiName, contentType)
} }
if apiName == "" { if apiName == "" {
@@ -306,6 +306,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
if err == nil { if err == nil {
return action return action
} }
log.Errorf("[onHttpRequestBody] failed to process request body, apiName=%s, err=%v", apiName, err)
_ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err)) _ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
} }
return types.ActionContinue return types.ActionContinue
@@ -594,3 +595,14 @@ func getApiName(path string) provider.ApiName {
return "" return ""
} }
func isSupportedRequestContentType(apiName provider.ApiName, contentType string) bool {
if strings.Contains(contentType, util.MimeTypeApplicationJson) {
return true
}
contentType = strings.ToLower(contentType)
if strings.HasPrefix(contentType, "multipart/form-data") {
return apiName == provider.ApiNameImageEdit || apiName == provider.ApiNameImageVariation
}
return false
}

View File

@@ -63,6 +63,54 @@ func Test_getApiName(t *testing.T) {
} }
} }
func Test_isSupportedRequestContentType(t *testing.T) {
tests := []struct {
name string
apiName provider.ApiName
contentType string
want bool
}{
{
name: "json chat completion",
apiName: provider.ApiNameChatCompletion,
contentType: "application/json",
want: true,
},
{
name: "multipart image edit",
apiName: provider.ApiNameImageEdit,
contentType: "multipart/form-data; boundary=----boundary",
want: true,
},
{
name: "multipart image variation",
apiName: provider.ApiNameImageVariation,
contentType: "multipart/form-data; boundary=----boundary",
want: true,
},
{
name: "multipart chat completion",
apiName: provider.ApiNameChatCompletion,
contentType: "multipart/form-data; boundary=----boundary",
want: false,
},
{
name: "text plain image edit",
apiName: provider.ApiNameImageEdit,
contentType: "text/plain",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isSupportedRequestContentType(tt.apiName, tt.contentType)
if got != tt.want {
t.Errorf("isSupportedRequestContentType(%v, %q) = %v, want %v", tt.apiName, tt.contentType, got, tt.want)
}
})
}
}
func TestAi360(t *testing.T) { func TestAi360(t *testing.T) {
test.RunAi360ParseConfigTests(t) test.RunAi360ParseConfigTests(t)
test.RunAi360OnHttpRequestHeadersTests(t) test.RunAi360OnHttpRequestHeadersTests(t)
@@ -137,6 +185,8 @@ func TestVertex(t *testing.T) {
test.RunVertexExpressModeOnStreamingResponseBodyTests(t) test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
test.RunVertexExpressModeImageGenerationRequestBodyTests(t) test.RunVertexExpressModeImageGenerationRequestBodyTests(t)
test.RunVertexExpressModeImageGenerationResponseBodyTests(t) test.RunVertexExpressModeImageGenerationResponseBodyTests(t)
test.RunVertexExpressModeImageEditVariationRequestBodyTests(t)
test.RunVertexExpressModeImageEditVariationResponseBodyTests(t)
// Vertex Raw 模式测试 // Vertex Raw 模式测试
test.RunVertexRawModeOnHttpRequestHeadersTests(t) test.RunVertexRawModeOnHttpRequestHeadersTests(t)
test.RunVertexRawModeOnHttpRequestBodyTests(t) test.RunVertexRawModeOnHttpRequestBodyTests(t)

View File

@@ -40,6 +40,11 @@ const (
requestIdHeader = "X-Amzn-Requestid" requestIdHeader = "X-Amzn-Requestid"
) )
var (
bedrockConversePathPattern = regexp.MustCompile(`/model/[^/]+/converse(-stream)?$`)
bedrockInvokePathPattern = regexp.MustCompile(`/model/[^/]+/invoke(-with-response-stream)?$`)
)
type bedrockProviderInitializer struct{} type bedrockProviderInitializer struct{}
func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) error { func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) error {
@@ -630,13 +635,24 @@ func (b *bedrockProvider) GetProviderType() string {
return providerTypeBedrock return providerTypeBedrock
} }
func (b *bedrockProvider) GetApiName(path string) ApiName {
switch {
case bedrockConversePathPattern.MatchString(path):
return ApiNameChatCompletion
case bedrockInvokePathPattern.MatchString(path):
return ApiNameImageGeneration
default:
return ""
}
}
func (b *bedrockProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error { func (b *bedrockProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
b.config.handleRequestHeaders(b, ctx, apiName) b.config.handleRequestHeaders(b, ctx, apiName)
return nil return nil
} }
func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) { func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion)) util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, strings.TrimSpace(b.config.awsRegion)))
// If apiTokens is configured, set Bearer token authentication here // If apiTokens is configured, set Bearer token authentication here
// This follows the same pattern as other providers (qwen, zhipuai, etc.) // This follows the same pattern as other providers (qwen, zhipuai, etc.)
@@ -647,6 +663,15 @@ func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
} }
func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
// In original protocol mode (e.g. /model/{modelId}/converse-stream), keep the body/path untouched
// and only apply auth headers.
if b.config.IsOriginal() {
headers := util.GetRequestHeaders()
b.setAuthHeaders(body, headers)
util.ReplaceRequestHeaders(headers)
return types.ActionContinue, replaceRequestBody(body)
}
if !b.config.isSupportedAPI(apiName) { if !b.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
@@ -654,14 +679,25 @@ func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
} }
func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) { func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
var transformedBody []byte
var err error
switch apiName { switch apiName {
case ApiNameChatCompletion: case ApiNameChatCompletion:
return b.onChatCompletionRequestBody(ctx, body, headers) transformedBody, err = b.onChatCompletionRequestBody(ctx, body, headers)
case ApiNameImageGeneration: case ApiNameImageGeneration:
return b.onImageGenerationRequestBody(ctx, body, headers) transformedBody, err = b.onImageGenerationRequestBody(ctx, body, headers)
default: default:
return b.config.defaultTransformRequestBody(ctx, apiName, body) transformedBody, err = b.config.defaultTransformRequestBody(ctx, apiName, body)
} }
if err != nil {
return nil, err
}
// Always apply auth after request body/path are finalized.
// For Bearer token mode this is a no-op; for AK/SK mode this generates SigV4 headers.
b.setAuthHeaders(transformedBody, headers)
return transformedBody, nil
} }
func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) { func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
@@ -715,9 +751,7 @@ func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageG
Quality: origRequest.Quality, Quality: origRequest.Quality,
}, },
} }
requestBytes, err := json.Marshal(request) return json.Marshal(request)
b.setAuthHeaders(requestBytes, headers)
return requestBytes, err
} }
func (b *bedrockProvider) buildBedrockImageGenerationResponse(bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse { func (b *bedrockProvider) buildBedrockImageGenerationResponse(bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
@@ -847,9 +881,7 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
request.AdditionalModelRequestFields[key] = value request.AdditionalModelRequestFields[key] = value
} }
requestBytes, err := json.Marshal(request) return json.Marshal(request)
b.setAuthHeaders(requestBytes, headers)
return requestBytes, err
} }
func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockConverseResponse) *chatCompletionResponse { func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockConverseResponse) *chatCompletionResponse {
@@ -1163,45 +1195,88 @@ func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) {
} }
// Use AWS Signature V4 authentication // Use AWS Signature V4 authentication
accessKey := strings.TrimSpace(b.config.awsAccessKey)
region := strings.TrimSpace(b.config.awsRegion)
t := time.Now().UTC() t := time.Now().UTC()
amzDate := t.Format("20060102T150405Z") amzDate := t.Format("20060102T150405Z")
dateStamp := t.Format("20060102") dateStamp := t.Format("20060102")
path := headers.Get(":path") path := headers.Get(":path")
signature := b.generateSignature(path, amzDate, dateStamp, body) signature := b.generateSignature(path, amzDate, dateStamp, body)
headers.Set("X-Amz-Date", amzDate) headers.Set("X-Amz-Date", amzDate)
util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", b.config.awsAccessKey, dateStamp, b.config.awsRegion, awsService, bedrockSignedHeaders, signature)) util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", accessKey, dateStamp, region, awsService, bedrockSignedHeaders, signature))
} }
func (b *bedrockProvider) generateSignature(path, amzDate, dateStamp string, body []byte) string { func (b *bedrockProvider) generateSignature(path, amzDate, dateStamp string, body []byte) string {
path = encodeSigV4Path(path) canonicalURI := encodeSigV4Path(path)
hashedPayload := sha256Hex(body) hashedPayload := sha256Hex(body)
region := strings.TrimSpace(b.config.awsRegion)
secretKey := strings.TrimSpace(b.config.awsSecretKey)
endpoint := fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion) endpoint := fmt.Sprintf(bedrockDefaultDomain, region)
canonicalHeaders := fmt.Sprintf("host:%s\nx-amz-date:%s\n", endpoint, amzDate) canonicalHeaders := fmt.Sprintf("host:%s\nx-amz-date:%s\n", endpoint, amzDate)
canonicalRequest := fmt.Sprintf("%s\n%s\n\n%s\n%s\n%s", canonicalRequest := fmt.Sprintf("%s\n%s\n\n%s\n%s\n%s",
httpPostMethod, path, canonicalHeaders, bedrockSignedHeaders, hashedPayload) httpPostMethod, canonicalURI, canonicalHeaders, bedrockSignedHeaders, hashedPayload)
credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, b.config.awsRegion, awsService) credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, region, awsService)
hashedCanonReq := sha256Hex([]byte(canonicalRequest)) hashedCanonReq := sha256Hex([]byte(canonicalRequest))
stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s", stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s",
amzDate, credentialScope, hashedCanonReq) amzDate, credentialScope, hashedCanonReq)
signingKey := getSignatureKey(b.config.awsSecretKey, dateStamp, b.config.awsRegion, awsService) signingKey := getSignatureKey(secretKey, dateStamp, region, awsService)
signature := hmacHex(signingKey, stringToSign) signature := hmacHex(signingKey, stringToSign)
return signature return signature
} }
func encodeSigV4Path(path string) string { func encodeSigV4Path(path string) string {
// Keep only the URI path for canonical URI. Query string is handled separately in SigV4,
// and this implementation uses an empty canonical query string.
if queryIndex := strings.Index(path, "?"); queryIndex >= 0 {
path = path[:queryIndex]
}
segments := strings.Split(path, "/") segments := strings.Split(path, "/")
for i, seg := range segments { for i, seg := range segments {
if seg == "" { if seg == "" {
continue continue
} }
segments[i] = url.PathEscape(seg) // Normalize to "single-encoded" form:
// - raw ":" -> %3A
// - already encoded "%3A" -> still %3A (not %253A)
decoded, err := url.PathUnescape(seg)
if err == nil {
segments[i] = sigV4EscapePathSegment(decoded)
} else {
// If segment has invalid escape sequence, fall back to escaping raw segment.
segments[i] = sigV4EscapePathSegment(seg)
}
} }
return strings.Join(segments, "/") return strings.Join(segments, "/")
} }
func sigV4EscapePathSegment(segment string) string {
const upperHex = "0123456789ABCDEF"
var b strings.Builder
b.Grow(len(segment) * 3)
for i := 0; i < len(segment); i++ {
c := segment[i]
if isSigV4Unreserved(c) {
b.WriteByte(c)
continue
}
b.WriteByte('%')
b.WriteByte(upperHex[c>>4])
b.WriteByte(upperHex[c&0x0F])
}
return b.String()
}
func isSigV4Unreserved(c byte) bool {
return (c >= 'A' && c <= 'Z') ||
(c >= 'a' && c <= 'z') ||
(c >= '0' && c <= '9') ||
c == '-' || c == '_' || c == '.' || c == '~'
}
func getSignatureKey(key, dateStamp, region, service string) []byte { func getSignatureKey(key, dateStamp, region, service string) []byte {
kDate := hmacSha256([]byte("AWS4"+key), dateStamp) kDate := hmacSha256([]byte("AWS4"+key), dateStamp)
kRegion := hmacSha256(kDate, region) kRegion := hmacSha256(kDate, region)

View File

@@ -1,6 +1,7 @@
package provider package provider
import ( import (
"encoding/json"
"fmt" "fmt"
"strings" "strings"
@@ -461,6 +462,122 @@ type imageGenerationRequest struct {
Size string `json:"size,omitempty"` Size string `json:"size,omitempty"`
} }
type imageInputURL struct {
URL string `json:"url,omitempty"`
ImageURL *chatMessageContentImageUrl `json:"image_url,omitempty"`
}
func (i *imageInputURL) UnmarshalJSON(data []byte) error {
// Support a plain string payload, e.g. "data:image/png;base64,..."
var rawURL string
if err := json.Unmarshal(data, &rawURL); err == nil {
i.URL = rawURL
i.ImageURL = nil
return nil
}
type alias imageInputURL
var value alias
if err := json.Unmarshal(data, &value); err != nil {
return err
}
*i = imageInputURL(value)
return nil
}
func (i *imageInputURL) GetURL() string {
if i == nil {
return ""
}
if i.ImageURL != nil && i.ImageURL.Url != "" {
return i.ImageURL.Url
}
return i.URL
}
type imageEditRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Image *imageInputURL `json:"image,omitempty"`
Images []imageInputURL `json:"images,omitempty"`
ImageURL *imageInputURL `json:"image_url,omitempty"`
Mask *imageInputURL `json:"mask,omitempty"`
MaskURL *imageInputURL `json:"mask_url,omitempty"`
Background string `json:"background,omitempty"`
Moderation string `json:"moderation,omitempty"`
OutputCompression int `json:"output_compression,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
}
func (r *imageEditRequest) GetImageURLs() []string {
urls := make([]string, 0, len(r.Images)+2)
for _, image := range r.Images {
if url := image.GetURL(); url != "" {
urls = append(urls, url)
}
}
if r.Image != nil {
if url := r.Image.GetURL(); url != "" {
urls = append(urls, url)
}
}
if r.ImageURL != nil {
if url := r.ImageURL.GetURL(); url != "" {
urls = append(urls, url)
}
}
return urls
}
func (r *imageEditRequest) HasMask() bool {
if r.Mask != nil && r.Mask.GetURL() != "" {
return true
}
return r.MaskURL != nil && r.MaskURL.GetURL() != ""
}
type imageVariationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
Image *imageInputURL `json:"image,omitempty"`
Images []imageInputURL `json:"images,omitempty"`
ImageURL *imageInputURL `json:"image_url,omitempty"`
Background string `json:"background,omitempty"`
Moderation string `json:"moderation,omitempty"`
OutputCompression int `json:"output_compression,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
}
func (r *imageVariationRequest) GetImageURLs() []string {
urls := make([]string, 0, len(r.Images)+2)
for _, image := range r.Images {
if url := image.GetURL(); url != "" {
urls = append(urls, url)
}
}
if r.Image != nil {
if url := r.Image.GetURL(); url != "" {
urls = append(urls, url)
}
}
if r.ImageURL != nil {
if url := r.ImageURL.GetURL(); url != "" {
urls = append(urls, url)
}
}
return urls
}
type imageGenerationData struct { type imageGenerationData struct {
URL string `json:"url,omitempty"` URL string `json:"url,omitempty"`
B64 string `json:"b64_json,omitempty"` B64 string `json:"b64_json,omitempty"`

View File

@@ -0,0 +1,156 @@
package provider
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"mime"
"mime/multipart"
"net/http"
"strconv"
"strings"
)
type multipartImageRequest struct {
Model string
Prompt string
Size string
OutputFormat string
N int
ImageURLs []string
HasMask bool
}
func isMultipartFormData(contentType string) bool {
mediaType, _, err := mime.ParseMediaType(contentType)
if err != nil {
return false
}
return strings.EqualFold(mediaType, "multipart/form-data")
}
func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) {
_, params, err := mime.ParseMediaType(contentType)
if err != nil {
return nil, fmt.Errorf("unable to parse content-type: %v", err)
}
boundary := params["boundary"]
if boundary == "" {
return nil, fmt.Errorf("missing multipart boundary")
}
req := &multipartImageRequest{
ImageURLs: make([]string, 0),
}
reader := multipart.NewReader(bytes.NewReader(body), boundary)
for {
part, err := reader.NextPart()
if err == io.EOF {
break
}
if err != nil {
return nil, fmt.Errorf("unable to read multipart part: %v", err)
}
fieldName := part.FormName()
if fieldName == "" {
_ = part.Close()
continue
}
partContentType := strings.TrimSpace(part.Header.Get("Content-Type"))
partData, err := io.ReadAll(part)
_ = part.Close()
if err != nil {
return nil, fmt.Errorf("unable to read multipart field %s: %v", fieldName, err)
}
value := strings.TrimSpace(string(partData))
switch fieldName {
case "model":
req.Model = value
continue
case "prompt":
req.Prompt = value
continue
case "size":
req.Size = value
continue
case "output_format":
req.OutputFormat = value
continue
case "n":
if value != "" {
if parsed, err := strconv.Atoi(value); err == nil {
req.N = parsed
}
}
continue
}
if isMultipartImageField(fieldName) {
if isMultipartImageURLValue(value) {
req.ImageURLs = append(req.ImageURLs, value)
continue
}
if len(partData) == 0 {
continue
}
imageURL := buildMultipartDataURL(partContentType, partData)
req.ImageURLs = append(req.ImageURLs, imageURL)
continue
}
if isMultipartMaskField(fieldName) {
if len(partData) > 0 || value != "" {
req.HasMask = true
}
continue
}
}
return req, nil
}
func isMultipartImageField(fieldName string) bool {
return fieldName == "image" || fieldName == "image[]" || strings.HasPrefix(fieldName, "image[")
}
func isMultipartMaskField(fieldName string) bool {
return fieldName == "mask" || fieldName == "mask[]" || strings.HasPrefix(fieldName, "mask[")
}
func isMultipartImageURLValue(value string) bool {
if value == "" {
return false
}
loweredValue := strings.ToLower(value)
return strings.HasPrefix(loweredValue, "data:") || strings.HasPrefix(loweredValue, "http://") || strings.HasPrefix(loweredValue, "https://")
}
func buildMultipartDataURL(contentType string, data []byte) string {
mimeType := strings.TrimSpace(contentType)
if mimeType == "" || strings.EqualFold(mimeType, "application/octet-stream") {
mimeType = http.DetectContentType(data)
}
mimeType = normalizeMultipartMimeType(mimeType)
if mimeType == "" {
mimeType = "application/octet-stream"
}
encoded := base64.StdEncoding.EncodeToString(data)
return fmt.Sprintf("data:%s;base64,%s", mimeType, encoded)
}
func normalizeMultipartMimeType(contentType string) string {
contentType = strings.TrimSpace(contentType)
if contentType == "" {
return ""
}
mediaType, _, err := mime.ParseMediaType(contentType)
if err == nil && mediaType != "" {
return strings.TrimSpace(mediaType)
}
if idx := strings.Index(contentType, ";"); idx > 0 {
return strings.TrimSpace(contentType[:idx])
}
return contentType
}

View File

@@ -198,7 +198,6 @@ var (
// Providers that support the "developer" role. Other providers will have "developer" roles converted to "system". // Providers that support the "developer" role. Other providers will have "developer" roles converted to "system".
developerRoleSupportedProviders = map[string]bool{ developerRoleSupportedProviders = map[string]bool{
providerTypeOpenAI: true,
providerTypeAzure: true, providerTypeAzure: true,
} }
@@ -845,6 +844,16 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
return err return err
} }
return c.setRequestModel(ctx, req) return c.setRequestModel(ctx, req)
case *imageEditRequest:
if err := decodeImageEditRequest(body, req); err != nil {
return err
}
return c.setRequestModel(ctx, req)
case *imageVariationRequest:
if err := decodeImageVariationRequest(body, req); err != nil {
return err
}
return c.setRequestModel(ctx, req)
default: default:
return errors.New("unsupported request type") return errors.New("unsupported request type")
} }
@@ -860,6 +869,10 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf
model = &req.Model model = &req.Model
case *imageGenerationRequest: case *imageGenerationRequest:
model = &req.Model model = &req.Model
case *imageEditRequest:
model = &req.Model
case *imageVariationRequest:
model = &req.Model
default: default:
return errors.New("unsupported request type") return errors.New("unsupported request type")
} }

View File

@@ -37,7 +37,7 @@ const (
qwenCompatibleRetrieveBatchPath = "/compatible-mode/v1/batches/{batch_id}" qwenCompatibleRetrieveBatchPath = "/compatible-mode/v1/batches/{batch_id}"
qwenBailianPath = "/api/v1/apps" qwenBailianPath = "/api/v1/apps"
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation" qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
qwenAnthropicMessagesPath = "/api/v2/apps/claude-code-proxy/v1/messages" qwenAnthropicMessagesPath = "/apps/anthropic/v1/messages"
qwenAsyncAIGCPath = "/api/v1/services/aigc/" qwenAsyncAIGCPath = "/api/v1/services/aigc/"
qwenAsyncTaskPath = "/api/v1/tasks/" qwenAsyncTaskPath = "/api/v1/tasks/"

View File

@@ -4,8 +4,8 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/log"
) )
func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) error { func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) error {
@@ -32,6 +32,20 @@ func decodeImageGenerationRequest(body []byte, request *imageGenerationRequest)
return nil return nil
} }
func decodeImageEditRequest(body []byte, request *imageEditRequest) error {
if err := json.Unmarshal(body, request); err != nil {
return fmt.Errorf("unable to unmarshal request: %v", err)
}
return nil
}
func decodeImageVariationRequest(body []byte, request *imageVariationRequest) error {
if err := json.Unmarshal(body, request); err != nil {
return fmt.Errorf("unable to unmarshal request: %v", err)
}
return nil
}
func replaceJsonRequestBody(request interface{}) error { func replaceJsonRequestBody(request interface{}) error {
body, err := json.Marshal(request) body, err := json.Marshal(request)
if err != nil { if err != nil {

View File

@@ -46,6 +46,7 @@ const (
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest" contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
contextVertexRawMarker = "isVertexRawRequest" contextVertexRawMarker = "isVertexRawRequest"
vertexAnthropicVersion = "vertex-2023-10-16" vertexAnthropicVersion = "vertex-2023-10-16"
vertexImageVariationDefaultPrompt = "Create variations of the provided image."
) )
// vertexRawPathRegex 匹配原生 Vertex AI REST API 路径 // vertexRawPathRegex 匹配原生 Vertex AI REST API 路径
@@ -98,6 +99,8 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
string(ApiNameChatCompletion): vertexPathTemplate, string(ApiNameChatCompletion): vertexPathTemplate,
string(ApiNameEmbeddings): vertexPathTemplate, string(ApiNameEmbeddings): vertexPathTemplate,
string(ApiNameImageGeneration): vertexPathTemplate, string(ApiNameImageGeneration): vertexPathTemplate,
string(ApiNameImageEdit): vertexPathTemplate,
string(ApiNameImageVariation): vertexPathTemplate,
string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换 string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换
} }
} }
@@ -307,6 +310,10 @@ func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap
return v.onEmbeddingsRequestBody(ctx, body, headers) return v.onEmbeddingsRequestBody(ctx, body, headers)
case ApiNameImageGeneration: case ApiNameImageGeneration:
return v.onImageGenerationRequestBody(ctx, body, headers) return v.onImageGenerationRequestBody(ctx, body, headers)
case ApiNameImageEdit:
return v.onImageEditRequestBody(ctx, body, headers)
case ApiNameImageVariation:
return v.onImageVariationRequestBody(ctx, body, headers)
default: default:
return body, nil return body, nil
} }
@@ -387,11 +394,108 @@ func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, b
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false) path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
util.OverwriteRequestPathHeader(headers, path) util.OverwriteRequestPathHeader(headers, path)
vertexRequest := v.buildVertexImageGenerationRequest(request) vertexRequest, err := v.buildVertexImageGenerationRequest(request)
if err != nil {
return nil, err
}
return json.Marshal(vertexRequest) return json.Marshal(vertexRequest)
} }
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest { func (v *vertexProvider) onImageEditRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
request := &imageEditRequest{}
imageURLs := make([]string, 0)
contentType := headers.Get("Content-Type")
if isMultipartFormData(contentType) {
parsedRequest, err := parseMultipartImageRequest(body, contentType)
if err != nil {
return nil, err
}
request.Model = parsedRequest.Model
request.Prompt = parsedRequest.Prompt
request.Size = parsedRequest.Size
request.OutputFormat = parsedRequest.OutputFormat
request.N = parsedRequest.N
imageURLs = parsedRequest.ImageURLs
if err := v.config.mapModel(ctx, &request.Model); err != nil {
return nil, err
}
if parsedRequest.HasMask {
return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
}
} else {
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
return nil, err
}
if request.HasMask() {
return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
}
imageURLs = request.GetImageURLs()
}
if len(imageURLs) == 0 {
return nil, fmt.Errorf("missing image_url in request")
}
if request.Prompt == "" {
return nil, fmt.Errorf("missing prompt in request")
}
path := v.getRequestPath(ApiNameImageEdit, request.Model, false)
util.OverwriteRequestPathHeader(headers, path)
headers.Set("Content-Type", util.MimeTypeApplicationJson)
vertexRequest, err := v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, imageURLs)
if err != nil {
return nil, err
}
return json.Marshal(vertexRequest)
}
func (v *vertexProvider) onImageVariationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
request := &imageVariationRequest{}
imageURLs := make([]string, 0)
contentType := headers.Get("Content-Type")
if isMultipartFormData(contentType) {
parsedRequest, err := parseMultipartImageRequest(body, contentType)
if err != nil {
return nil, err
}
request.Model = parsedRequest.Model
request.Prompt = parsedRequest.Prompt
request.Size = parsedRequest.Size
request.OutputFormat = parsedRequest.OutputFormat
request.N = parsedRequest.N
imageURLs = parsedRequest.ImageURLs
if err := v.config.mapModel(ctx, &request.Model); err != nil {
return nil, err
}
} else {
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
return nil, err
}
imageURLs = request.GetImageURLs()
}
if len(imageURLs) == 0 {
return nil, fmt.Errorf("missing image_url in request")
}
prompt := request.Prompt
if prompt == "" {
prompt = vertexImageVariationDefaultPrompt
}
path := v.getRequestPath(ApiNameImageVariation, request.Model, false)
util.OverwriteRequestPathHeader(headers, path)
headers.Set("Content-Type", util.MimeTypeApplicationJson)
vertexRequest, err := v.buildVertexImageRequest(prompt, request.Size, request.OutputFormat, imageURLs)
if err != nil {
return nil, err
}
return json.Marshal(vertexRequest)
}
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) (*vertexChatRequest, error) {
return v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, nil)
}
func (v *vertexProvider) buildVertexImageRequest(prompt string, size string, outputFormat string, imageURLs []string) (*vertexChatRequest, error) {
// 构建安全设置 // 构建安全设置
safetySettings := make([]vertexChatSafetySetting, 0) safetySettings := make([]vertexChatSafetySetting, 0)
for category, threshold := range v.config.geminiSafetySetting { for category, threshold := range v.config.geminiSafetySetting {
@@ -402,12 +506,12 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
} }
// 解析尺寸参数 // 解析尺寸参数
aspectRatio, imageSize := v.parseImageSize(request.Size) aspectRatio, imageSize := v.parseImageSize(size)
// 确定输出 MIME 类型 // 确定输出 MIME 类型
mimeType := "image/png" mimeType := "image/png"
if request.OutputFormat != "" { if outputFormat != "" {
switch request.OutputFormat { switch outputFormat {
case "jpeg", "jpg": case "jpeg", "jpg":
mimeType = "image/jpeg" mimeType = "image/jpeg"
case "webp": case "webp":
@@ -417,12 +521,27 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
} }
} }
parts := make([]vertexPart, 0, len(imageURLs)+1)
for _, imageURL := range imageURLs {
part, err := convertMediaContent(imageURL)
if err != nil {
return nil, err
}
parts = append(parts, part)
}
if prompt != "" {
parts = append(parts, vertexPart{
Text: prompt,
})
}
if len(parts) == 0 {
return nil, fmt.Errorf("missing prompt and image_url in request")
}
vertexRequest := &vertexChatRequest{ vertexRequest := &vertexChatRequest{
Contents: []vertexChatContent{{ Contents: []vertexChatContent{{
Role: roleUser, Role: roleUser,
Parts: []vertexPart{{ Parts: parts,
Text: request.Prompt,
}},
}}, }},
SafetySettings: safetySettings, SafetySettings: safetySettings,
GenerationConfig: vertexChatGenerationConfig{ GenerationConfig: vertexChatGenerationConfig{
@@ -440,7 +559,7 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
}, },
} }
return vertexRequest return vertexRequest, nil
} }
// parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize // parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize
@@ -553,7 +672,7 @@ func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
return v.onChatCompletionResponseBody(ctx, body) return v.onChatCompletionResponseBody(ctx, body)
case ApiNameEmbeddings: case ApiNameEmbeddings:
return v.onEmbeddingsResponseBody(ctx, body) return v.onEmbeddingsResponseBody(ctx, body)
case ApiNameImageGeneration: case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
return v.onImageGenerationResponseBody(ctx, body) return v.onImageGenerationResponseBody(ctx, body)
default: default:
return body, nil return body, nil
@@ -784,7 +903,7 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
switch apiName { switch apiName {
case ApiNameEmbeddings: case ApiNameEmbeddings:
action = vertexEmbeddingAction action = vertexEmbeddingAction
case ApiNameImageGeneration: case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
// 图片生成使用非流式端点,需要完整响应 // 图片生成使用非流式端点,需要完整响应
action = vertexChatCompletionAction action = vertexChatCompletionAction
default: default:

View File

@@ -25,6 +25,76 @@ var basicBedrockConfig = func() json.RawMessage {
return data return data
}() }()
// Test config: Bedrock original protocol config with AWS Access Key/Secret Key
var bedrockOriginalAkSkConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "bedrock",
"protocol": "original",
"awsAccessKey": "test-ak-for-unit-test",
"awsSecretKey": "test-sk-for-unit-test",
"awsRegion": "us-east-1",
},
})
return data
}()
// Test config: Bedrock original protocol config with api token
var bedrockOriginalApiTokenConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "bedrock",
"protocol": "original",
"awsRegion": "us-east-1",
"apiTokens": []string{
"test-token-for-unit-test",
},
},
})
return data
}()
// Test config: Bedrock original protocol config with AWS Access Key/Secret Key and custom settings
var bedrockOriginalAkSkWithCustomSettingsConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "bedrock",
"protocol": "original",
"awsAccessKey": "test-ak-for-unit-test",
"awsSecretKey": "test-sk-for-unit-test",
"awsRegion": "us-east-1",
"customSettings": []map[string]interface{}{
{
"name": "foo",
"value": "\"bar\"",
"mode": "raw",
"overwrite": true,
},
},
},
})
return data
}()
// Test config: Bedrock config with embeddings capability to verify generic SigV4 flow
var bedrockEmbeddingsCapabilityConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "bedrock",
"awsAccessKey": "test-ak-for-unit-test",
"awsSecretKey": "test-sk-for-unit-test",
"awsRegion": "us-east-1",
"capabilities": map[string]string{
"openai/v1/embeddings": "/model/amazon.titan-embed-text-v2:0/invoke",
},
"modelMapping": map[string]string{
"*": "amazon.titan-embed-text-v2:0",
},
},
})
return data
}()
// Test config: Bedrock config with Bearer Token authentication // Test config: Bedrock config with Bearer Token authentication
var bedrockApiTokenConfig = func() json.RawMessage { var bedrockApiTokenConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{ data, _ := json.Marshal(map[string]interface{}{
@@ -352,6 +422,169 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint") require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint")
}) })
// Test Bedrock generic request body processing with AWS Signature V4 authentication
t.Run("bedrock embeddings request body with ak/sk should use sigv4", func(t *testing.T) {
host, status := test.NewTestHost(bedrockEmbeddingsCapabilityConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/embeddings"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"model": "text-embedding-3-small",
"input": "Hello from embeddings"
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
require.True(t, hasAuth, "Authorization header should exist")
require.Contains(t, authValue, "AWS4-HMAC-SHA256", "Authorization should use AWS4-HMAC-SHA256 signature")
require.Contains(t, authValue, "Credential=", "Authorization should contain Credential")
require.Contains(t, authValue, "Signature=", "Authorization should contain Signature")
dateValue, hasDate := test.GetHeaderValue(requestHeaders, "X-Amz-Date")
require.True(t, hasDate, "X-Amz-Date header should exist for AWS Signature V4")
require.NotEmpty(t, dateValue, "X-Amz-Date should not be empty")
})
// Test Bedrock original converse-stream path with AWS Signature V4 authentication
t.Run("bedrock original converse-stream with ak/sk should use sigv4", func(t *testing.T) {
host, status := test.NewTestHost(bedrockOriginalAkSkConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
originalPath := "/model/anthropic.claude-3-5-haiku-20241022-v1%3A0/converse-stream"
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", originalPath},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"messages": [
{
"role": "user",
"content": [{"text": "Hello from original bedrock path"}]
}
],
"inferenceConfig": {
"maxTokens": 64
}
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
require.True(t, hasAuth, "Authorization header should exist")
require.Contains(t, authValue, "AWS4-HMAC-SHA256", "Authorization should use AWS4-HMAC-SHA256 signature")
require.Contains(t, authValue, "Credential=", "Authorization should contain Credential")
require.Contains(t, authValue, "Signature=", "Authorization should contain Signature")
dateValue, hasDate := test.GetHeaderValue(requestHeaders, "X-Amz-Date")
require.True(t, hasDate, "X-Amz-Date header should exist for AWS Signature V4")
require.NotEmpty(t, dateValue, "X-Amz-Date should not be empty")
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath, "Path header should exist")
require.Equal(t, originalPath, pathValue, "Original Bedrock path should be kept unchanged")
})
// Test Bedrock original converse-stream path with Bearer Token authentication
t.Run("bedrock original converse-stream with api token should pass bearer auth", func(t *testing.T) {
host, status := test.NewTestHost(bedrockOriginalApiTokenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
originalPath := "/model/anthropic.claude-3-5-haiku-20241022-v1%3A0/converse-stream"
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", originalPath},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"messages": [
{
"role": "user",
"content": [{"text": "Hello from original bedrock path"}]
}
]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
require.True(t, hasAuth, "Authorization header should exist")
require.Contains(t, authValue, "Bearer ", "Authorization should use Bearer token")
require.Contains(t, authValue, "test-token-for-unit-test", "Authorization should contain configured token")
_, hasDate := test.GetHeaderValue(requestHeaders, "X-Amz-Date")
require.False(t, hasDate, "X-Amz-Date should not be set in Bearer token mode")
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath, "Path header should exist")
require.Equal(t, originalPath, pathValue, "Original Bedrock path should be kept unchanged")
})
// Test Bedrock original converse-stream path keeps signed body consistent with custom settings
t.Run("bedrock original converse-stream with custom settings should replace body before forwarding", func(t *testing.T) {
host, status := test.NewTestHost(bedrockOriginalAkSkWithCustomSettingsConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
originalPath := "/model/amazon.nova-2-lite-v1:0/converse-stream"
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", originalPath},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"messages": [
{
"role": "user",
"content": [{"text": "Hello"}]
}
]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
var bodyMap map[string]interface{}
err := json.Unmarshal(processedBody, &bodyMap)
require.NoError(t, err)
require.Equal(t, "\"bar\"", bodyMap["foo"], "Custom settings should be applied to forwarded body")
authValue, hasAuth := test.GetHeaderValue(host.GetRequestHeaders(), "Authorization")
require.True(t, hasAuth, "Authorization header should exist")
require.Contains(t, authValue, "AWS4-HMAC-SHA256", "Authorization should use AWS4-HMAC-SHA256 signature")
})
// Test Bedrock streaming request // Test Bedrock streaming request
t.Run("bedrock streaming request", func(t *testing.T) { t.Run("bedrock streaming request", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig) host, status := test.NewTestHost(bedrockApiTokenConfig)

View File

@@ -1,7 +1,9 @@
package test package test
import ( import (
"bytes"
"encoding/json" "encoding/json"
"mime/multipart"
"strings" "strings"
"testing" "testing"
@@ -1273,6 +1275,324 @@ func RunVertexExpressModeImageGenerationResponseBodyTests(t *testing.T) {
}) })
} }
func buildMultipartRequestBody(t *testing.T, fields map[string]string, files map[string][]byte) ([]byte, string) {
var buffer bytes.Buffer
writer := multipart.NewWriter(&buffer)
for key, value := range fields {
require.NoError(t, writer.WriteField(key, value))
}
for fieldName, data := range files {
part, err := writer.CreateFormFile(fieldName, "upload-image.png")
require.NoError(t, err)
_, err = part.Write(data)
require.NoError(t, err)
}
require.NoError(t, writer.Close())
return buffer.Bytes(), writer.FormDataContentType()
}
func RunVertexExpressModeImageEditVariationRequestBodyTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
const testDataURL = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
t.Run("vertex express mode image edit request body with image_url", func(t *testing.T) {
host, status := test.NewTestHost(vertexExpressModeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/edits"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add sunglasses to the cat","image":{"image_url":{"url":"` + testDataURL + `"}},"size":"1024x1024"}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
bodyStr := string(processedBody)
require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image_url")
require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
require.NotContains(t, bodyStr, "image_url", "OpenAI image_url field should be converted to Vertex format")
requestHeaders := host.GetRequestHeaders()
pathHeader := ""
for _, header := range requestHeaders {
if header[0] == ":path" {
pathHeader = header[1]
break
}
}
require.Contains(t, pathHeader, "generateContent", "Image edit should use generateContent action")
require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key")
})
t.Run("vertex express mode image edit request body with image string", func(t *testing.T) {
host, status := test.NewTestHost(vertexExpressModeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/edits"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add sunglasses to the cat","image":"` + testDataURL + `","size":"1024x1024"}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
bodyStr := string(processedBody)
require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image string")
require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
})
t.Run("vertex express mode image edit multipart request body", func(t *testing.T) {
host, status := test.NewTestHost(vertexExpressModeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
body, contentType := buildMultipartRequestBody(t, map[string]string{
"model": "gemini-2.0-flash-exp",
"prompt": "Add sunglasses to the cat",
"size": "1024x1024",
}, map[string][]byte{
"image": []byte("fake-image-content"),
})
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/edits"},
{":method", "POST"},
{"Content-Type", contentType},
})
action := host.CallOnHttpRequestBody(body)
require.Equal(t, types.ActionContinue, action)
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
bodyStr := string(processedBody)
require.Contains(t, bodyStr, "inlineData", "Multipart image should be converted to inlineData")
require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
requestHeaders := host.GetRequestHeaders()
require.True(t, test.HasHeaderWithValue(requestHeaders, "Content-Type", "application/json"), "Content-Type should be rewritten to application/json")
})
t.Run("vertex express mode image variation multipart request body", func(t *testing.T) {
host, status := test.NewTestHost(vertexExpressModeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
body, contentType := buildMultipartRequestBody(t, map[string]string{
"model": "gemini-2.0-flash-exp",
"size": "1024x1024",
}, map[string][]byte{
"image": []byte("fake-image-content"),
})
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/variations"},
{":method", "POST"},
{"Content-Type", contentType},
})
action := host.CallOnHttpRequestBody(body)
require.Equal(t, types.ActionContinue, action)
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
bodyStr := string(processedBody)
require.Contains(t, bodyStr, "inlineData", "Multipart image should be converted to inlineData")
require.Contains(t, bodyStr, "Create variations of the provided image.", "Variation request should inject a default prompt")
requestHeaders := host.GetRequestHeaders()
require.True(t, test.HasHeaderWithValue(requestHeaders, "Content-Type", "application/json"), "Content-Type should be rewritten to application/json")
})
t.Run("vertex express mode image edit with model mapping", func(t *testing.T) {
host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/edits"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
requestBody := `{"model":"gpt-4","prompt":"Turn it into watercolor","image_url":{"url":"` + testDataURL + `"}}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
requestHeaders := host.GetRequestHeaders()
pathHeader := ""
for _, header := range requestHeaders {
if header[0] == ":path" {
pathHeader = header[1]
break
}
}
require.Contains(t, pathHeader, "gemini-2.5-flash", "Path should contain mapped model name")
})
t.Run("vertex express mode image variation request body with image_url", func(t *testing.T) {
host, status := test.NewTestHost(vertexExpressModeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/variations"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
requestBody := `{"model":"gemini-2.0-flash-exp","image_url":{"url":"` + testDataURL + `"}}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
bodyStr := string(processedBody)
require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image_url")
require.Contains(t, bodyStr, "Create variations of the provided image.", "Variation request should inject a default prompt")
requestHeaders := host.GetRequestHeaders()
pathHeader := ""
for _, header := range requestHeaders {
if header[0] == ":path" {
pathHeader = header[1]
break
}
}
require.Contains(t, pathHeader, "generateContent", "Image variation should use generateContent action")
require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key")
})
})
}
func RunVertexExpressModeImageEditVariationResponseBodyTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
const testDataURL = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
t.Run("vertex express mode image edit response body", func(t *testing.T) {
host, status := test.NewTestHost(vertexExpressModeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/edits"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add glasses","image_url":{"url":"` + testDataURL + `"}}`
host.CallOnHttpRequestBody([]byte(requestBody))
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"Content-Type", "application/json"},
})
responseBody := `{
"candidates": [{
"content": {
"role": "model",
"parts": [{
"inlineData": {
"mimeType": "image/png",
"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
}
}]
}
}],
"usageMetadata": {
"promptTokenCount": 12,
"candidatesTokenCount": 1024,
"totalTokenCount": 1036
}
}`
action := host.CallOnHttpResponseBody([]byte(responseBody))
require.Equal(t, types.ActionContinue, action)
processedResponseBody := host.GetResponseBody()
require.NotNil(t, processedResponseBody)
responseStr := string(processedResponseBody)
require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field")
require.Contains(t, responseStr, "usage", "Response should contain usage field")
})
t.Run("vertex express mode image variation response body", func(t *testing.T) {
host, status := test.NewTestHost(vertexExpressModeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/variations"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
requestBody := `{"model":"gemini-2.0-flash-exp","image_url":{"url":"` + testDataURL + `"}}`
host.CallOnHttpRequestBody([]byte(requestBody))
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"Content-Type", "application/json"},
})
responseBody := `{
"candidates": [{
"content": {
"role": "model",
"parts": [{
"inlineData": {
"mimeType": "image/png",
"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
}
}]
}
}],
"usageMetadata": {
"promptTokenCount": 8,
"candidatesTokenCount": 768,
"totalTokenCount": 776
}
}`
action := host.CallOnHttpResponseBody([]byte(responseBody))
require.Equal(t, types.ActionContinue, action)
processedResponseBody := host.GetResponseBody()
require.NotNil(t, processedResponseBody)
responseStr := string(processedResponseBody)
require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field")
require.Contains(t, responseStr, "usage", "Response should contain usage field")
})
})
}
// ==================== Vertex Raw 模式测试 ==================== // ==================== Vertex Raw 模式测试 ====================
func RunVertexRawModeOnHttpRequestHeadersTests(t *testing.T) { func RunVertexRawModeOnHttpRequestHeadersTests(t *testing.T) {