mirror of
https://github.com/alibaba/higress.git
synced 2026-03-09 19:20:51 +08:00
Compare commits
8 Commits
add-releas
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
528e6c9908 | ||
|
|
13b808c1e4 | ||
|
|
aa502e7e62 | ||
|
|
2e3f6868df | ||
|
|
6c9747d778 | ||
|
|
c12183cae5 | ||
|
|
e2a22d1171 | ||
|
|
e9aecb6e1f |
@@ -4,6 +4,6 @@ dependencies:
|
|||||||
version: 2.2.0
|
version: 2.2.0
|
||||||
- name: higress-console
|
- name: higress-console
|
||||||
repository: https://higress.io/helm-charts/
|
repository: https://higress.io/helm-charts/
|
||||||
version: 2.2.0
|
version: 2.2.1
|
||||||
digest: sha256:2cb148fa6d52856344e1905d3fea018466c2feb52013e08997c2d5c7d50f2e5d
|
digest: sha256:23fe7b0f84965c13ac7ceabe6334212fc3d323b7b781277a6d2b6fd38e935dda
|
||||||
generated: "2026-02-11T17:45:59.187965929+08:00"
|
generated: "2026-03-07T12:45:44.267732+08:00"
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
apiVersion: v2
|
apiVersion: v2
|
||||||
appVersion: 2.2.0
|
appVersion: 2.2.1
|
||||||
description: Helm chart for deploying Higress gateways
|
description: Helm chart for deploying Higress gateways
|
||||||
icon: https://higress.io/img/higress_logo_small.png
|
icon: https://higress.io/img/higress_logo_small.png
|
||||||
home: http://higress.io/
|
home: http://higress.io/
|
||||||
@@ -15,6 +15,6 @@ dependencies:
|
|||||||
version: 2.2.0
|
version: 2.2.0
|
||||||
- name: higress-console
|
- name: higress-console
|
||||||
repository: "https://higress.io/helm-charts/"
|
repository: "https://higress.io/helm-charts/"
|
||||||
version: 2.2.0
|
version: 2.2.1
|
||||||
type: application
|
type: application
|
||||||
version: 2.2.0
|
version: 2.2.1
|
||||||
|
|||||||
@@ -140,10 +140,16 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct
|
|||||||
|
|
||||||
// Send the initial endpoint event
|
// Send the initial endpoint event
|
||||||
initialEvent := fmt.Sprintf("event: endpoint\ndata: %s\n\n", messageEndpoint)
|
initialEvent := fmt.Sprintf("event: endpoint\ndata: %s\n\n", messageEndpoint)
|
||||||
err = s.redisClient.Publish(channel, initialEvent)
|
go func() {
|
||||||
if err != nil {
|
defer func() {
|
||||||
api.LogErrorf("Failed to send initial event: %v", err)
|
if r := recover(); r != nil {
|
||||||
|
api.LogErrorf("Failed to send initial event: %v", r)
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
|
defer cb.EncoderFilterCallbacks().RecoverPanic()
|
||||||
|
api.LogDebugf("SSE Send message: %s", initialEvent)
|
||||||
|
cb.EncoderFilterCallbacks().InjectData([]byte(initialEvent))
|
||||||
|
}()
|
||||||
|
|
||||||
// Start health check handler
|
// Start health check handler
|
||||||
go func() {
|
go func() {
|
||||||
|
|||||||
@@ -225,9 +225,9 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !strings.Contains(contentType, util.MimeTypeApplicationJson) {
|
if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !isSupportedRequestContentType(apiName, contentType) {
|
||||||
ctx.DontReadRequestBody()
|
ctx.DontReadRequestBody()
|
||||||
log.Debugf("[onHttpRequestHeader] unsupported content type: %s, will not process the request body", contentType)
|
log.Debugf("[onHttpRequestHeader] unsupported content type for api %s: %s, will not process the request body", apiName, contentType)
|
||||||
}
|
}
|
||||||
|
|
||||||
if apiName == "" {
|
if apiName == "" {
|
||||||
@@ -306,6 +306,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return action
|
return action
|
||||||
}
|
}
|
||||||
|
log.Errorf("[onHttpRequestBody] failed to process request body, apiName=%s, err=%v", apiName, err)
|
||||||
_ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
|
_ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
|
||||||
}
|
}
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
@@ -594,3 +595,14 @@ func getApiName(path string) provider.ApiName {
|
|||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isSupportedRequestContentType(apiName provider.ApiName, contentType string) bool {
|
||||||
|
if strings.Contains(contentType, util.MimeTypeApplicationJson) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
contentType = strings.ToLower(contentType)
|
||||||
|
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||||
|
return apiName == provider.ApiNameImageEdit || apiName == provider.ApiNameImageVariation
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -63,6 +63,54 @@ func Test_getApiName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_isSupportedRequestContentType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
apiName provider.ApiName
|
||||||
|
contentType string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "json chat completion",
|
||||||
|
apiName: provider.ApiNameChatCompletion,
|
||||||
|
contentType: "application/json",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multipart image edit",
|
||||||
|
apiName: provider.ApiNameImageEdit,
|
||||||
|
contentType: "multipart/form-data; boundary=----boundary",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multipart image variation",
|
||||||
|
apiName: provider.ApiNameImageVariation,
|
||||||
|
contentType: "multipart/form-data; boundary=----boundary",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multipart chat completion",
|
||||||
|
apiName: provider.ApiNameChatCompletion,
|
||||||
|
contentType: "multipart/form-data; boundary=----boundary",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text plain image edit",
|
||||||
|
apiName: provider.ApiNameImageEdit,
|
||||||
|
contentType: "text/plain",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := isSupportedRequestContentType(tt.apiName, tt.contentType)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("isSupportedRequestContentType(%v, %q) = %v, want %v", tt.apiName, tt.contentType, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAi360(t *testing.T) {
|
func TestAi360(t *testing.T) {
|
||||||
test.RunAi360ParseConfigTests(t)
|
test.RunAi360ParseConfigTests(t)
|
||||||
test.RunAi360OnHttpRequestHeadersTests(t)
|
test.RunAi360OnHttpRequestHeadersTests(t)
|
||||||
@@ -137,6 +185,8 @@ func TestVertex(t *testing.T) {
|
|||||||
test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
|
test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
|
||||||
test.RunVertexExpressModeImageGenerationRequestBodyTests(t)
|
test.RunVertexExpressModeImageGenerationRequestBodyTests(t)
|
||||||
test.RunVertexExpressModeImageGenerationResponseBodyTests(t)
|
test.RunVertexExpressModeImageGenerationResponseBodyTests(t)
|
||||||
|
test.RunVertexExpressModeImageEditVariationRequestBodyTests(t)
|
||||||
|
test.RunVertexExpressModeImageEditVariationResponseBodyTests(t)
|
||||||
// Vertex Raw 模式测试
|
// Vertex Raw 模式测试
|
||||||
test.RunVertexRawModeOnHttpRequestHeadersTests(t)
|
test.RunVertexRawModeOnHttpRequestHeadersTests(t)
|
||||||
test.RunVertexRawModeOnHttpRequestBodyTests(t)
|
test.RunVertexRawModeOnHttpRequestBodyTests(t)
|
||||||
|
|||||||
@@ -40,6 +40,11 @@ const (
|
|||||||
requestIdHeader = "X-Amzn-Requestid"
|
requestIdHeader = "X-Amzn-Requestid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
bedrockConversePathPattern = regexp.MustCompile(`/model/[^/]+/converse(-stream)?$`)
|
||||||
|
bedrockInvokePathPattern = regexp.MustCompile(`/model/[^/]+/invoke(-with-response-stream)?$`)
|
||||||
|
)
|
||||||
|
|
||||||
type bedrockProviderInitializer struct{}
|
type bedrockProviderInitializer struct{}
|
||||||
|
|
||||||
func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||||
@@ -630,13 +635,24 @@ func (b *bedrockProvider) GetProviderType() string {
|
|||||||
return providerTypeBedrock
|
return providerTypeBedrock
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *bedrockProvider) GetApiName(path string) ApiName {
|
||||||
|
switch {
|
||||||
|
case bedrockConversePathPattern.MatchString(path):
|
||||||
|
return ApiNameChatCompletion
|
||||||
|
case bedrockInvokePathPattern.MatchString(path):
|
||||||
|
return ApiNameImageGeneration
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (b *bedrockProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
func (b *bedrockProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||||
b.config.handleRequestHeaders(b, ctx, apiName)
|
b.config.handleRequestHeaders(b, ctx, apiName)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||||
util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion))
|
util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, strings.TrimSpace(b.config.awsRegion)))
|
||||||
|
|
||||||
// If apiTokens is configured, set Bearer token authentication here
|
// If apiTokens is configured, set Bearer token authentication here
|
||||||
// This follows the same pattern as other providers (qwen, zhipuai, etc.)
|
// This follows the same pattern as other providers (qwen, zhipuai, etc.)
|
||||||
@@ -647,6 +663,15 @@ func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||||
|
// In original protocol mode (e.g. /model/{modelId}/converse-stream), keep the body/path untouched
|
||||||
|
// and only apply auth headers.
|
||||||
|
if b.config.IsOriginal() {
|
||||||
|
headers := util.GetRequestHeaders()
|
||||||
|
b.setAuthHeaders(body, headers)
|
||||||
|
util.ReplaceRequestHeaders(headers)
|
||||||
|
return types.ActionContinue, replaceRequestBody(body)
|
||||||
|
}
|
||||||
|
|
||||||
if !b.config.isSupportedAPI(apiName) {
|
if !b.config.isSupportedAPI(apiName) {
|
||||||
return types.ActionContinue, errUnsupportedApiName
|
return types.ActionContinue, errUnsupportedApiName
|
||||||
}
|
}
|
||||||
@@ -654,14 +679,25 @@ func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||||
|
var transformedBody []byte
|
||||||
|
var err error
|
||||||
switch apiName {
|
switch apiName {
|
||||||
case ApiNameChatCompletion:
|
case ApiNameChatCompletion:
|
||||||
return b.onChatCompletionRequestBody(ctx, body, headers)
|
transformedBody, err = b.onChatCompletionRequestBody(ctx, body, headers)
|
||||||
case ApiNameImageGeneration:
|
case ApiNameImageGeneration:
|
||||||
return b.onImageGenerationRequestBody(ctx, body, headers)
|
transformedBody, err = b.onImageGenerationRequestBody(ctx, body, headers)
|
||||||
default:
|
default:
|
||||||
return b.config.defaultTransformRequestBody(ctx, apiName, body)
|
transformedBody, err = b.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always apply auth after request body/path are finalized.
|
||||||
|
// For Bearer token mode this is a no-op; for AK/SK mode this generates SigV4 headers.
|
||||||
|
b.setAuthHeaders(transformedBody, headers)
|
||||||
|
return transformedBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||||
@@ -715,9 +751,7 @@ func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageG
|
|||||||
Quality: origRequest.Quality,
|
Quality: origRequest.Quality,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
requestBytes, err := json.Marshal(request)
|
return json.Marshal(request)
|
||||||
b.setAuthHeaders(requestBytes, headers)
|
|
||||||
return requestBytes, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *bedrockProvider) buildBedrockImageGenerationResponse(bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
|
func (b *bedrockProvider) buildBedrockImageGenerationResponse(bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
|
||||||
@@ -847,9 +881,7 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
|
|||||||
request.AdditionalModelRequestFields[key] = value
|
request.AdditionalModelRequestFields[key] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
requestBytes, err := json.Marshal(request)
|
return json.Marshal(request)
|
||||||
b.setAuthHeaders(requestBytes, headers)
|
|
||||||
return requestBytes, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockConverseResponse) *chatCompletionResponse {
|
func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockConverseResponse) *chatCompletionResponse {
|
||||||
@@ -1163,45 +1195,88 @@ func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Use AWS Signature V4 authentication
|
// Use AWS Signature V4 authentication
|
||||||
|
accessKey := strings.TrimSpace(b.config.awsAccessKey)
|
||||||
|
region := strings.TrimSpace(b.config.awsRegion)
|
||||||
t := time.Now().UTC()
|
t := time.Now().UTC()
|
||||||
amzDate := t.Format("20060102T150405Z")
|
amzDate := t.Format("20060102T150405Z")
|
||||||
dateStamp := t.Format("20060102")
|
dateStamp := t.Format("20060102")
|
||||||
path := headers.Get(":path")
|
path := headers.Get(":path")
|
||||||
signature := b.generateSignature(path, amzDate, dateStamp, body)
|
signature := b.generateSignature(path, amzDate, dateStamp, body)
|
||||||
headers.Set("X-Amz-Date", amzDate)
|
headers.Set("X-Amz-Date", amzDate)
|
||||||
util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", b.config.awsAccessKey, dateStamp, b.config.awsRegion, awsService, bedrockSignedHeaders, signature))
|
util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", accessKey, dateStamp, region, awsService, bedrockSignedHeaders, signature))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *bedrockProvider) generateSignature(path, amzDate, dateStamp string, body []byte) string {
|
func (b *bedrockProvider) generateSignature(path, amzDate, dateStamp string, body []byte) string {
|
||||||
path = encodeSigV4Path(path)
|
canonicalURI := encodeSigV4Path(path)
|
||||||
hashedPayload := sha256Hex(body)
|
hashedPayload := sha256Hex(body)
|
||||||
|
region := strings.TrimSpace(b.config.awsRegion)
|
||||||
|
secretKey := strings.TrimSpace(b.config.awsSecretKey)
|
||||||
|
|
||||||
endpoint := fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion)
|
endpoint := fmt.Sprintf(bedrockDefaultDomain, region)
|
||||||
canonicalHeaders := fmt.Sprintf("host:%s\nx-amz-date:%s\n", endpoint, amzDate)
|
canonicalHeaders := fmt.Sprintf("host:%s\nx-amz-date:%s\n", endpoint, amzDate)
|
||||||
canonicalRequest := fmt.Sprintf("%s\n%s\n\n%s\n%s\n%s",
|
canonicalRequest := fmt.Sprintf("%s\n%s\n\n%s\n%s\n%s",
|
||||||
httpPostMethod, path, canonicalHeaders, bedrockSignedHeaders, hashedPayload)
|
httpPostMethod, canonicalURI, canonicalHeaders, bedrockSignedHeaders, hashedPayload)
|
||||||
|
|
||||||
credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, b.config.awsRegion, awsService)
|
credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, region, awsService)
|
||||||
hashedCanonReq := sha256Hex([]byte(canonicalRequest))
|
hashedCanonReq := sha256Hex([]byte(canonicalRequest))
|
||||||
stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s",
|
stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s",
|
||||||
amzDate, credentialScope, hashedCanonReq)
|
amzDate, credentialScope, hashedCanonReq)
|
||||||
|
|
||||||
signingKey := getSignatureKey(b.config.awsSecretKey, dateStamp, b.config.awsRegion, awsService)
|
signingKey := getSignatureKey(secretKey, dateStamp, region, awsService)
|
||||||
signature := hmacHex(signingKey, stringToSign)
|
signature := hmacHex(signingKey, stringToSign)
|
||||||
return signature
|
return signature
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodeSigV4Path(path string) string {
|
func encodeSigV4Path(path string) string {
|
||||||
|
// Keep only the URI path for canonical URI. Query string is handled separately in SigV4,
|
||||||
|
// and this implementation uses an empty canonical query string.
|
||||||
|
if queryIndex := strings.Index(path, "?"); queryIndex >= 0 {
|
||||||
|
path = path[:queryIndex]
|
||||||
|
}
|
||||||
|
|
||||||
segments := strings.Split(path, "/")
|
segments := strings.Split(path, "/")
|
||||||
for i, seg := range segments {
|
for i, seg := range segments {
|
||||||
if seg == "" {
|
if seg == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
segments[i] = url.PathEscape(seg)
|
// Normalize to "single-encoded" form:
|
||||||
|
// - raw ":" -> %3A
|
||||||
|
// - already encoded "%3A" -> still %3A (not %253A)
|
||||||
|
decoded, err := url.PathUnescape(seg)
|
||||||
|
if err == nil {
|
||||||
|
segments[i] = sigV4EscapePathSegment(decoded)
|
||||||
|
} else {
|
||||||
|
// If segment has invalid escape sequence, fall back to escaping raw segment.
|
||||||
|
segments[i] = sigV4EscapePathSegment(seg)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return strings.Join(segments, "/")
|
return strings.Join(segments, "/")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sigV4EscapePathSegment(segment string) string {
|
||||||
|
const upperHex = "0123456789ABCDEF"
|
||||||
|
var b strings.Builder
|
||||||
|
b.Grow(len(segment) * 3)
|
||||||
|
for i := 0; i < len(segment); i++ {
|
||||||
|
c := segment[i]
|
||||||
|
if isSigV4Unreserved(c) {
|
||||||
|
b.WriteByte(c)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b.WriteByte('%')
|
||||||
|
b.WriteByte(upperHex[c>>4])
|
||||||
|
b.WriteByte(upperHex[c&0x0F])
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSigV4Unreserved(c byte) bool {
|
||||||
|
return (c >= 'A' && c <= 'Z') ||
|
||||||
|
(c >= 'a' && c <= 'z') ||
|
||||||
|
(c >= '0' && c <= '9') ||
|
||||||
|
c == '-' || c == '_' || c == '.' || c == '~'
|
||||||
|
}
|
||||||
|
|
||||||
func getSignatureKey(key, dateStamp, region, service string) []byte {
|
func getSignatureKey(key, dateStamp, region, service string) []byte {
|
||||||
kDate := hmacSha256([]byte("AWS4"+key), dateStamp)
|
kDate := hmacSha256([]byte("AWS4"+key), dateStamp)
|
||||||
kRegion := hmacSha256(kDate, region)
|
kRegion := hmacSha256(kDate, region)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package provider
|
package provider
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -461,6 +462,122 @@ type imageGenerationRequest struct {
|
|||||||
Size string `json:"size,omitempty"`
|
Size string `json:"size,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type imageInputURL struct {
|
||||||
|
URL string `json:"url,omitempty"`
|
||||||
|
ImageURL *chatMessageContentImageUrl `json:"image_url,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *imageInputURL) UnmarshalJSON(data []byte) error {
|
||||||
|
// Support a plain string payload, e.g. "data:image/png;base64,..."
|
||||||
|
var rawURL string
|
||||||
|
if err := json.Unmarshal(data, &rawURL); err == nil {
|
||||||
|
i.URL = rawURL
|
||||||
|
i.ImageURL = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type alias imageInputURL
|
||||||
|
var value alias
|
||||||
|
if err := json.Unmarshal(data, &value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*i = imageInputURL(value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *imageInputURL) GetURL() string {
|
||||||
|
if i == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if i.ImageURL != nil && i.ImageURL.Url != "" {
|
||||||
|
return i.ImageURL.Url
|
||||||
|
}
|
||||||
|
return i.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
type imageEditRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Image *imageInputURL `json:"image,omitempty"`
|
||||||
|
Images []imageInputURL `json:"images,omitempty"`
|
||||||
|
ImageURL *imageInputURL `json:"image_url,omitempty"`
|
||||||
|
Mask *imageInputURL `json:"mask,omitempty"`
|
||||||
|
MaskURL *imageInputURL `json:"mask_url,omitempty"`
|
||||||
|
Background string `json:"background,omitempty"`
|
||||||
|
Moderation string `json:"moderation,omitempty"`
|
||||||
|
OutputCompression int `json:"output_compression,omitempty"`
|
||||||
|
OutputFormat string `json:"output_format,omitempty"`
|
||||||
|
Quality string `json:"quality,omitempty"`
|
||||||
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
|
Style string `json:"style,omitempty"`
|
||||||
|
N int `json:"n,omitempty"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *imageEditRequest) GetImageURLs() []string {
|
||||||
|
urls := make([]string, 0, len(r.Images)+2)
|
||||||
|
for _, image := range r.Images {
|
||||||
|
if url := image.GetURL(); url != "" {
|
||||||
|
urls = append(urls, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.Image != nil {
|
||||||
|
if url := r.Image.GetURL(); url != "" {
|
||||||
|
urls = append(urls, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.ImageURL != nil {
|
||||||
|
if url := r.ImageURL.GetURL(); url != "" {
|
||||||
|
urls = append(urls, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return urls
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *imageEditRequest) HasMask() bool {
|
||||||
|
if r.Mask != nil && r.Mask.GetURL() != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return r.MaskURL != nil && r.MaskURL.GetURL() != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type imageVariationRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
Image *imageInputURL `json:"image,omitempty"`
|
||||||
|
Images []imageInputURL `json:"images,omitempty"`
|
||||||
|
ImageURL *imageInputURL `json:"image_url,omitempty"`
|
||||||
|
Background string `json:"background,omitempty"`
|
||||||
|
Moderation string `json:"moderation,omitempty"`
|
||||||
|
OutputCompression int `json:"output_compression,omitempty"`
|
||||||
|
OutputFormat string `json:"output_format,omitempty"`
|
||||||
|
Quality string `json:"quality,omitempty"`
|
||||||
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
|
Style string `json:"style,omitempty"`
|
||||||
|
N int `json:"n,omitempty"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *imageVariationRequest) GetImageURLs() []string {
|
||||||
|
urls := make([]string, 0, len(r.Images)+2)
|
||||||
|
for _, image := range r.Images {
|
||||||
|
if url := image.GetURL(); url != "" {
|
||||||
|
urls = append(urls, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.Image != nil {
|
||||||
|
if url := r.Image.GetURL(); url != "" {
|
||||||
|
urls = append(urls, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.ImageURL != nil {
|
||||||
|
if url := r.ImageURL.GetURL(); url != "" {
|
||||||
|
urls = append(urls, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return urls
|
||||||
|
}
|
||||||
|
|
||||||
type imageGenerationData struct {
|
type imageGenerationData struct {
|
||||||
URL string `json:"url,omitempty"`
|
URL string `json:"url,omitempty"`
|
||||||
B64 string `json:"b64_json,omitempty"`
|
B64 string `json:"b64_json,omitempty"`
|
||||||
|
|||||||
156
plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go
Normal file
156
plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type multipartImageRequest struct {
|
||||||
|
Model string
|
||||||
|
Prompt string
|
||||||
|
Size string
|
||||||
|
OutputFormat string
|
||||||
|
N int
|
||||||
|
ImageURLs []string
|
||||||
|
HasMask bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func isMultipartFormData(contentType string) bool {
|
||||||
|
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.EqualFold(mediaType, "multipart/form-data")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) {
|
||||||
|
_, params, err := mime.ParseMediaType(contentType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to parse content-type: %v", err)
|
||||||
|
}
|
||||||
|
boundary := params["boundary"]
|
||||||
|
if boundary == "" {
|
||||||
|
return nil, fmt.Errorf("missing multipart boundary")
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &multipartImageRequest{
|
||||||
|
ImageURLs: make([]string, 0),
|
||||||
|
}
|
||||||
|
reader := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||||
|
for {
|
||||||
|
part, err := reader.NextPart()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to read multipart part: %v", err)
|
||||||
|
}
|
||||||
|
fieldName := part.FormName()
|
||||||
|
if fieldName == "" {
|
||||||
|
_ = part.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
partContentType := strings.TrimSpace(part.Header.Get("Content-Type"))
|
||||||
|
|
||||||
|
partData, err := io.ReadAll(part)
|
||||||
|
_ = part.Close()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to read multipart field %s: %v", fieldName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
value := strings.TrimSpace(string(partData))
|
||||||
|
switch fieldName {
|
||||||
|
case "model":
|
||||||
|
req.Model = value
|
||||||
|
continue
|
||||||
|
case "prompt":
|
||||||
|
req.Prompt = value
|
||||||
|
continue
|
||||||
|
case "size":
|
||||||
|
req.Size = value
|
||||||
|
continue
|
||||||
|
case "output_format":
|
||||||
|
req.OutputFormat = value
|
||||||
|
continue
|
||||||
|
case "n":
|
||||||
|
if value != "" {
|
||||||
|
if parsed, err := strconv.Atoi(value); err == nil {
|
||||||
|
req.N = parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if isMultipartImageField(fieldName) {
|
||||||
|
if isMultipartImageURLValue(value) {
|
||||||
|
req.ImageURLs = append(req.ImageURLs, value)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(partData) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
imageURL := buildMultipartDataURL(partContentType, partData)
|
||||||
|
req.ImageURLs = append(req.ImageURLs, imageURL)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isMultipartMaskField(fieldName) {
|
||||||
|
if len(partData) > 0 || value != "" {
|
||||||
|
req.HasMask = true
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isMultipartImageField(fieldName string) bool {
|
||||||
|
return fieldName == "image" || fieldName == "image[]" || strings.HasPrefix(fieldName, "image[")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isMultipartMaskField(fieldName string) bool {
|
||||||
|
return fieldName == "mask" || fieldName == "mask[]" || strings.HasPrefix(fieldName, "mask[")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isMultipartImageURLValue(value string) bool {
|
||||||
|
if value == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
loweredValue := strings.ToLower(value)
|
||||||
|
return strings.HasPrefix(loweredValue, "data:") || strings.HasPrefix(loweredValue, "http://") || strings.HasPrefix(loweredValue, "https://")
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildMultipartDataURL(contentType string, data []byte) string {
|
||||||
|
mimeType := strings.TrimSpace(contentType)
|
||||||
|
if mimeType == "" || strings.EqualFold(mimeType, "application/octet-stream") {
|
||||||
|
mimeType = http.DetectContentType(data)
|
||||||
|
}
|
||||||
|
mimeType = normalizeMultipartMimeType(mimeType)
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = "application/octet-stream"
|
||||||
|
}
|
||||||
|
encoded := base64.StdEncoding.EncodeToString(data)
|
||||||
|
return fmt.Sprintf("data:%s;base64,%s", mimeType, encoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeMultipartMimeType(contentType string) string {
|
||||||
|
contentType = strings.TrimSpace(contentType)
|
||||||
|
if contentType == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||||
|
if err == nil && mediaType != "" {
|
||||||
|
return strings.TrimSpace(mediaType)
|
||||||
|
}
|
||||||
|
if idx := strings.Index(contentType, ";"); idx > 0 {
|
||||||
|
return strings.TrimSpace(contentType[:idx])
|
||||||
|
}
|
||||||
|
return contentType
|
||||||
|
}
|
||||||
@@ -198,7 +198,6 @@ var (
|
|||||||
|
|
||||||
// Providers that support the "developer" role. Other providers will have "developer" roles converted to "system".
|
// Providers that support the "developer" role. Other providers will have "developer" roles converted to "system".
|
||||||
developerRoleSupportedProviders = map[string]bool{
|
developerRoleSupportedProviders = map[string]bool{
|
||||||
providerTypeOpenAI: true,
|
|
||||||
providerTypeAzure: true,
|
providerTypeAzure: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -845,6 +844,16 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.setRequestModel(ctx, req)
|
return c.setRequestModel(ctx, req)
|
||||||
|
case *imageEditRequest:
|
||||||
|
if err := decodeImageEditRequest(body, req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.setRequestModel(ctx, req)
|
||||||
|
case *imageVariationRequest:
|
||||||
|
if err := decodeImageVariationRequest(body, req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.setRequestModel(ctx, req)
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported request type")
|
return errors.New("unsupported request type")
|
||||||
}
|
}
|
||||||
@@ -860,6 +869,10 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf
|
|||||||
model = &req.Model
|
model = &req.Model
|
||||||
case *imageGenerationRequest:
|
case *imageGenerationRequest:
|
||||||
model = &req.Model
|
model = &req.Model
|
||||||
|
case *imageEditRequest:
|
||||||
|
model = &req.Model
|
||||||
|
case *imageVariationRequest:
|
||||||
|
model = &req.Model
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported request type")
|
return errors.New("unsupported request type")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ const (
|
|||||||
qwenCompatibleRetrieveBatchPath = "/compatible-mode/v1/batches/{batch_id}"
|
qwenCompatibleRetrieveBatchPath = "/compatible-mode/v1/batches/{batch_id}"
|
||||||
qwenBailianPath = "/api/v1/apps"
|
qwenBailianPath = "/api/v1/apps"
|
||||||
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
|
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
|
||||||
qwenAnthropicMessagesPath = "/api/v2/apps/claude-code-proxy/v1/messages"
|
qwenAnthropicMessagesPath = "/apps/anthropic/v1/messages"
|
||||||
|
|
||||||
qwenAsyncAIGCPath = "/api/v1/services/aigc/"
|
qwenAsyncAIGCPath = "/api/v1/services/aigc/"
|
||||||
qwenAsyncTaskPath = "/api/v1/tasks/"
|
qwenAsyncTaskPath = "/api/v1/tasks/"
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/higress-group/wasm-go/pkg/log"
|
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) error {
|
func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) error {
|
||||||
@@ -32,6 +32,20 @@ func decodeImageGenerationRequest(body []byte, request *imageGenerationRequest)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func decodeImageEditRequest(body []byte, request *imageEditRequest) error {
|
||||||
|
if err := json.Unmarshal(body, request); err != nil {
|
||||||
|
return fmt.Errorf("unable to unmarshal request: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeImageVariationRequest(body []byte, request *imageVariationRequest) error {
|
||||||
|
if err := json.Unmarshal(body, request); err != nil {
|
||||||
|
return fmt.Errorf("unable to unmarshal request: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func replaceJsonRequestBody(request interface{}) error {
|
func replaceJsonRequestBody(request interface{}) error {
|
||||||
body, err := json.Marshal(request)
|
body, err := json.Marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ const (
|
|||||||
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
|
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
|
||||||
contextVertexRawMarker = "isVertexRawRequest"
|
contextVertexRawMarker = "isVertexRawRequest"
|
||||||
vertexAnthropicVersion = "vertex-2023-10-16"
|
vertexAnthropicVersion = "vertex-2023-10-16"
|
||||||
|
vertexImageVariationDefaultPrompt = "Create variations of the provided image."
|
||||||
)
|
)
|
||||||
|
|
||||||
// vertexRawPathRegex 匹配原生 Vertex AI REST API 路径
|
// vertexRawPathRegex 匹配原生 Vertex AI REST API 路径
|
||||||
@@ -98,6 +99,8 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
|
|||||||
string(ApiNameChatCompletion): vertexPathTemplate,
|
string(ApiNameChatCompletion): vertexPathTemplate,
|
||||||
string(ApiNameEmbeddings): vertexPathTemplate,
|
string(ApiNameEmbeddings): vertexPathTemplate,
|
||||||
string(ApiNameImageGeneration): vertexPathTemplate,
|
string(ApiNameImageGeneration): vertexPathTemplate,
|
||||||
|
string(ApiNameImageEdit): vertexPathTemplate,
|
||||||
|
string(ApiNameImageVariation): vertexPathTemplate,
|
||||||
string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换
|
string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -307,6 +310,10 @@ func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap
|
|||||||
return v.onEmbeddingsRequestBody(ctx, body, headers)
|
return v.onEmbeddingsRequestBody(ctx, body, headers)
|
||||||
case ApiNameImageGeneration:
|
case ApiNameImageGeneration:
|
||||||
return v.onImageGenerationRequestBody(ctx, body, headers)
|
return v.onImageGenerationRequestBody(ctx, body, headers)
|
||||||
|
case ApiNameImageEdit:
|
||||||
|
return v.onImageEditRequestBody(ctx, body, headers)
|
||||||
|
case ApiNameImageVariation:
|
||||||
|
return v.onImageVariationRequestBody(ctx, body, headers)
|
||||||
default:
|
default:
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
@@ -387,11 +394,108 @@ func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, b
|
|||||||
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
|
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
|
||||||
util.OverwriteRequestPathHeader(headers, path)
|
util.OverwriteRequestPathHeader(headers, path)
|
||||||
|
|
||||||
vertexRequest := v.buildVertexImageGenerationRequest(request)
|
vertexRequest, err := v.buildVertexImageGenerationRequest(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return json.Marshal(vertexRequest)
|
return json.Marshal(vertexRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest {
|
func (v *vertexProvider) onImageEditRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||||
|
request := &imageEditRequest{}
|
||||||
|
imageURLs := make([]string, 0)
|
||||||
|
contentType := headers.Get("Content-Type")
|
||||||
|
if isMultipartFormData(contentType) {
|
||||||
|
parsedRequest, err := parseMultipartImageRequest(body, contentType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
request.Model = parsedRequest.Model
|
||||||
|
request.Prompt = parsedRequest.Prompt
|
||||||
|
request.Size = parsedRequest.Size
|
||||||
|
request.OutputFormat = parsedRequest.OutputFormat
|
||||||
|
request.N = parsedRequest.N
|
||||||
|
imageURLs = parsedRequest.ImageURLs
|
||||||
|
if err := v.config.mapModel(ctx, &request.Model); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if parsedRequest.HasMask {
|
||||||
|
return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if request.HasMask() {
|
||||||
|
return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
|
||||||
|
}
|
||||||
|
imageURLs = request.GetImageURLs()
|
||||||
|
}
|
||||||
|
if len(imageURLs) == 0 {
|
||||||
|
return nil, fmt.Errorf("missing image_url in request")
|
||||||
|
}
|
||||||
|
if request.Prompt == "" {
|
||||||
|
return nil, fmt.Errorf("missing prompt in request")
|
||||||
|
}
|
||||||
|
|
||||||
|
path := v.getRequestPath(ApiNameImageEdit, request.Model, false)
|
||||||
|
util.OverwriteRequestPathHeader(headers, path)
|
||||||
|
headers.Set("Content-Type", util.MimeTypeApplicationJson)
|
||||||
|
vertexRequest, err := v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, imageURLs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return json.Marshal(vertexRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *vertexProvider) onImageVariationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||||
|
request := &imageVariationRequest{}
|
||||||
|
imageURLs := make([]string, 0)
|
||||||
|
contentType := headers.Get("Content-Type")
|
||||||
|
if isMultipartFormData(contentType) {
|
||||||
|
parsedRequest, err := parseMultipartImageRequest(body, contentType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
request.Model = parsedRequest.Model
|
||||||
|
request.Prompt = parsedRequest.Prompt
|
||||||
|
request.Size = parsedRequest.Size
|
||||||
|
request.OutputFormat = parsedRequest.OutputFormat
|
||||||
|
request.N = parsedRequest.N
|
||||||
|
imageURLs = parsedRequest.ImageURLs
|
||||||
|
if err := v.config.mapModel(ctx, &request.Model); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
imageURLs = request.GetImageURLs()
|
||||||
|
}
|
||||||
|
if len(imageURLs) == 0 {
|
||||||
|
return nil, fmt.Errorf("missing image_url in request")
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt := request.Prompt
|
||||||
|
if prompt == "" {
|
||||||
|
prompt = vertexImageVariationDefaultPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
path := v.getRequestPath(ApiNameImageVariation, request.Model, false)
|
||||||
|
util.OverwriteRequestPathHeader(headers, path)
|
||||||
|
headers.Set("Content-Type", util.MimeTypeApplicationJson)
|
||||||
|
vertexRequest, err := v.buildVertexImageRequest(prompt, request.Size, request.OutputFormat, imageURLs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return json.Marshal(vertexRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) (*vertexChatRequest, error) {
|
||||||
|
return v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *vertexProvider) buildVertexImageRequest(prompt string, size string, outputFormat string, imageURLs []string) (*vertexChatRequest, error) {
|
||||||
// 构建安全设置
|
// 构建安全设置
|
||||||
safetySettings := make([]vertexChatSafetySetting, 0)
|
safetySettings := make([]vertexChatSafetySetting, 0)
|
||||||
for category, threshold := range v.config.geminiSafetySetting {
|
for category, threshold := range v.config.geminiSafetySetting {
|
||||||
@@ -402,12 +506,12 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 解析尺寸参数
|
// 解析尺寸参数
|
||||||
aspectRatio, imageSize := v.parseImageSize(request.Size)
|
aspectRatio, imageSize := v.parseImageSize(size)
|
||||||
|
|
||||||
// 确定输出 MIME 类型
|
// 确定输出 MIME 类型
|
||||||
mimeType := "image/png"
|
mimeType := "image/png"
|
||||||
if request.OutputFormat != "" {
|
if outputFormat != "" {
|
||||||
switch request.OutputFormat {
|
switch outputFormat {
|
||||||
case "jpeg", "jpg":
|
case "jpeg", "jpg":
|
||||||
mimeType = "image/jpeg"
|
mimeType = "image/jpeg"
|
||||||
case "webp":
|
case "webp":
|
||||||
@@ -417,12 +521,27 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
parts := make([]vertexPart, 0, len(imageURLs)+1)
|
||||||
|
for _, imageURL := range imageURLs {
|
||||||
|
part, err := convertMediaContent(imageURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
parts = append(parts, part)
|
||||||
|
}
|
||||||
|
if prompt != "" {
|
||||||
|
parts = append(parts, vertexPart{
|
||||||
|
Text: prompt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return nil, fmt.Errorf("missing prompt and image_url in request")
|
||||||
|
}
|
||||||
|
|
||||||
vertexRequest := &vertexChatRequest{
|
vertexRequest := &vertexChatRequest{
|
||||||
Contents: []vertexChatContent{{
|
Contents: []vertexChatContent{{
|
||||||
Role: roleUser,
|
Role: roleUser,
|
||||||
Parts: []vertexPart{{
|
Parts: parts,
|
||||||
Text: request.Prompt,
|
|
||||||
}},
|
|
||||||
}},
|
}},
|
||||||
SafetySettings: safetySettings,
|
SafetySettings: safetySettings,
|
||||||
GenerationConfig: vertexChatGenerationConfig{
|
GenerationConfig: vertexChatGenerationConfig{
|
||||||
@@ -440,7 +559,7 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return vertexRequest
|
return vertexRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize
|
// parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize
|
||||||
@@ -553,7 +672,7 @@ func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
|
|||||||
return v.onChatCompletionResponseBody(ctx, body)
|
return v.onChatCompletionResponseBody(ctx, body)
|
||||||
case ApiNameEmbeddings:
|
case ApiNameEmbeddings:
|
||||||
return v.onEmbeddingsResponseBody(ctx, body)
|
return v.onEmbeddingsResponseBody(ctx, body)
|
||||||
case ApiNameImageGeneration:
|
case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
|
||||||
return v.onImageGenerationResponseBody(ctx, body)
|
return v.onImageGenerationResponseBody(ctx, body)
|
||||||
default:
|
default:
|
||||||
return body, nil
|
return body, nil
|
||||||
@@ -784,7 +903,7 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
|
|||||||
switch apiName {
|
switch apiName {
|
||||||
case ApiNameEmbeddings:
|
case ApiNameEmbeddings:
|
||||||
action = vertexEmbeddingAction
|
action = vertexEmbeddingAction
|
||||||
case ApiNameImageGeneration:
|
case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
|
||||||
// 图片生成使用非流式端点,需要完整响应
|
// 图片生成使用非流式端点,需要完整响应
|
||||||
action = vertexChatCompletionAction
|
action = vertexChatCompletionAction
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -25,6 +25,76 @@ var basicBedrockConfig = func() json.RawMessage {
|
|||||||
return data
|
return data
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Test config: Bedrock original protocol config with AWS Access Key/Secret Key
|
||||||
|
var bedrockOriginalAkSkConfig = func() json.RawMessage {
|
||||||
|
data, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"provider": map[string]interface{}{
|
||||||
|
"type": "bedrock",
|
||||||
|
"protocol": "original",
|
||||||
|
"awsAccessKey": "test-ak-for-unit-test",
|
||||||
|
"awsSecretKey": "test-sk-for-unit-test",
|
||||||
|
"awsRegion": "us-east-1",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return data
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Test config: Bedrock original protocol config with api token
|
||||||
|
var bedrockOriginalApiTokenConfig = func() json.RawMessage {
|
||||||
|
data, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"provider": map[string]interface{}{
|
||||||
|
"type": "bedrock",
|
||||||
|
"protocol": "original",
|
||||||
|
"awsRegion": "us-east-1",
|
||||||
|
"apiTokens": []string{
|
||||||
|
"test-token-for-unit-test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return data
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Test config: Bedrock original protocol config with AWS Access Key/Secret Key and custom settings
|
||||||
|
var bedrockOriginalAkSkWithCustomSettingsConfig = func() json.RawMessage {
|
||||||
|
data, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"provider": map[string]interface{}{
|
||||||
|
"type": "bedrock",
|
||||||
|
"protocol": "original",
|
||||||
|
"awsAccessKey": "test-ak-for-unit-test",
|
||||||
|
"awsSecretKey": "test-sk-for-unit-test",
|
||||||
|
"awsRegion": "us-east-1",
|
||||||
|
"customSettings": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"name": "foo",
|
||||||
|
"value": "\"bar\"",
|
||||||
|
"mode": "raw",
|
||||||
|
"overwrite": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return data
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Test config: Bedrock config with embeddings capability to verify generic SigV4 flow
|
||||||
|
var bedrockEmbeddingsCapabilityConfig = func() json.RawMessage {
|
||||||
|
data, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"provider": map[string]interface{}{
|
||||||
|
"type": "bedrock",
|
||||||
|
"awsAccessKey": "test-ak-for-unit-test",
|
||||||
|
"awsSecretKey": "test-sk-for-unit-test",
|
||||||
|
"awsRegion": "us-east-1",
|
||||||
|
"capabilities": map[string]string{
|
||||||
|
"openai/v1/embeddings": "/model/amazon.titan-embed-text-v2:0/invoke",
|
||||||
|
},
|
||||||
|
"modelMapping": map[string]string{
|
||||||
|
"*": "amazon.titan-embed-text-v2:0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return data
|
||||||
|
}()
|
||||||
|
|
||||||
// Test config: Bedrock config with Bearer Token authentication
|
// Test config: Bedrock config with Bearer Token authentication
|
||||||
var bedrockApiTokenConfig = func() json.RawMessage {
|
var bedrockApiTokenConfig = func() json.RawMessage {
|
||||||
data, _ := json.Marshal(map[string]interface{}{
|
data, _ := json.Marshal(map[string]interface{}{
|
||||||
@@ -352,6 +422,169 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
|
|||||||
require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint")
|
require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Test Bedrock generic request body processing with AWS Signature V4 authentication
|
||||||
|
t.Run("bedrock embeddings request body with ak/sk should use sigv4", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(bedrockEmbeddingsCapabilityConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/embeddings"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
requestBody := `{
|
||||||
|
"model": "text-embedding-3-small",
|
||||||
|
"input": "Hello from embeddings"
|
||||||
|
}`
|
||||||
|
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
require.NotNil(t, requestHeaders)
|
||||||
|
|
||||||
|
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||||
|
require.True(t, hasAuth, "Authorization header should exist")
|
||||||
|
require.Contains(t, authValue, "AWS4-HMAC-SHA256", "Authorization should use AWS4-HMAC-SHA256 signature")
|
||||||
|
require.Contains(t, authValue, "Credential=", "Authorization should contain Credential")
|
||||||
|
require.Contains(t, authValue, "Signature=", "Authorization should contain Signature")
|
||||||
|
|
||||||
|
dateValue, hasDate := test.GetHeaderValue(requestHeaders, "X-Amz-Date")
|
||||||
|
require.True(t, hasDate, "X-Amz-Date header should exist for AWS Signature V4")
|
||||||
|
require.NotEmpty(t, dateValue, "X-Amz-Date should not be empty")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test Bedrock original converse-stream path with AWS Signature V4 authentication
|
||||||
|
t.Run("bedrock original converse-stream with ak/sk should use sigv4", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(bedrockOriginalAkSkConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
originalPath := "/model/anthropic.claude-3-5-haiku-20241022-v1%3A0/converse-stream"
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", originalPath},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
requestBody := `{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"text": "Hello from original bedrock path"}]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"inferenceConfig": {
|
||||||
|
"maxTokens": 64
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
require.NotNil(t, requestHeaders)
|
||||||
|
|
||||||
|
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||||
|
require.True(t, hasAuth, "Authorization header should exist")
|
||||||
|
require.Contains(t, authValue, "AWS4-HMAC-SHA256", "Authorization should use AWS4-HMAC-SHA256 signature")
|
||||||
|
require.Contains(t, authValue, "Credential=", "Authorization should contain Credential")
|
||||||
|
require.Contains(t, authValue, "Signature=", "Authorization should contain Signature")
|
||||||
|
|
||||||
|
dateValue, hasDate := test.GetHeaderValue(requestHeaders, "X-Amz-Date")
|
||||||
|
require.True(t, hasDate, "X-Amz-Date header should exist for AWS Signature V4")
|
||||||
|
require.NotEmpty(t, dateValue, "X-Amz-Date should not be empty")
|
||||||
|
|
||||||
|
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||||
|
require.True(t, hasPath, "Path header should exist")
|
||||||
|
require.Equal(t, originalPath, pathValue, "Original Bedrock path should be kept unchanged")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test Bedrock original converse-stream path with Bearer Token authentication
|
||||||
|
t.Run("bedrock original converse-stream with api token should pass bearer auth", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(bedrockOriginalApiTokenConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
originalPath := "/model/anthropic.claude-3-5-haiku-20241022-v1%3A0/converse-stream"
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", originalPath},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
requestBody := `{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"text": "Hello from original bedrock path"}]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
require.NotNil(t, requestHeaders)
|
||||||
|
|
||||||
|
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||||
|
require.True(t, hasAuth, "Authorization header should exist")
|
||||||
|
require.Contains(t, authValue, "Bearer ", "Authorization should use Bearer token")
|
||||||
|
require.Contains(t, authValue, "test-token-for-unit-test", "Authorization should contain configured token")
|
||||||
|
|
||||||
|
_, hasDate := test.GetHeaderValue(requestHeaders, "X-Amz-Date")
|
||||||
|
require.False(t, hasDate, "X-Amz-Date should not be set in Bearer token mode")
|
||||||
|
|
||||||
|
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||||
|
require.True(t, hasPath, "Path header should exist")
|
||||||
|
require.Equal(t, originalPath, pathValue, "Original Bedrock path should be kept unchanged")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test Bedrock original converse-stream path keeps signed body consistent with custom settings
|
||||||
|
t.Run("bedrock original converse-stream with custom settings should replace body before forwarding", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(bedrockOriginalAkSkWithCustomSettingsConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
originalPath := "/model/amazon.nova-2-lite-v1:0/converse-stream"
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", originalPath},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
requestBody := `{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"text": "Hello"}]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
|
||||||
|
var bodyMap map[string]interface{}
|
||||||
|
err := json.Unmarshal(processedBody, &bodyMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "\"bar\"", bodyMap["foo"], "Custom settings should be applied to forwarded body")
|
||||||
|
|
||||||
|
authValue, hasAuth := test.GetHeaderValue(host.GetRequestHeaders(), "Authorization")
|
||||||
|
require.True(t, hasAuth, "Authorization header should exist")
|
||||||
|
require.Contains(t, authValue, "AWS4-HMAC-SHA256", "Authorization should use AWS4-HMAC-SHA256 signature")
|
||||||
|
})
|
||||||
|
|
||||||
// Test Bedrock streaming request
|
// Test Bedrock streaming request
|
||||||
t.Run("bedrock streaming request", func(t *testing.T) {
|
t.Run("bedrock streaming request", func(t *testing.T) {
|
||||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
package test
|
package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"mime/multipart"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -1273,6 +1275,324 @@ func RunVertexExpressModeImageGenerationResponseBodyTests(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildMultipartRequestBody(t *testing.T, fields map[string]string, files map[string][]byte) ([]byte, string) {
|
||||||
|
var buffer bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&buffer)
|
||||||
|
|
||||||
|
for key, value := range fields {
|
||||||
|
require.NoError(t, writer.WriteField(key, value))
|
||||||
|
}
|
||||||
|
|
||||||
|
for fieldName, data := range files {
|
||||||
|
part, err := writer.CreateFormFile(fieldName, "upload-image.png")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = part.Write(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, writer.Close())
|
||||||
|
return buffer.Bytes(), writer.FormDataContentType()
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunVertexExpressModeImageEditVariationRequestBodyTests(t *testing.T) {
|
||||||
|
test.RunTest(t, func(t *testing.T) {
|
||||||
|
const testDataURL = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||||
|
|
||||||
|
t.Run("vertex express mode image edit request body with image_url", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/images/edits"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add sunglasses to the cat","image":{"image_url":{"url":"` + testDataURL + `"}},"size":"1024x1024"}`
|
||||||
|
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
|
||||||
|
bodyStr := string(processedBody)
|
||||||
|
require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image_url")
|
||||||
|
require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
|
||||||
|
require.NotContains(t, bodyStr, "image_url", "OpenAI image_url field should be converted to Vertex format")
|
||||||
|
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
pathHeader := ""
|
||||||
|
for _, header := range requestHeaders {
|
||||||
|
if header[0] == ":path" {
|
||||||
|
pathHeader = header[1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Contains(t, pathHeader, "generateContent", "Image edit should use generateContent action")
|
||||||
|
require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("vertex express mode image edit request body with image string", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/images/edits"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add sunglasses to the cat","image":"` + testDataURL + `","size":"1024x1024"}`
|
||||||
|
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
|
||||||
|
bodyStr := string(processedBody)
|
||||||
|
require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image string")
|
||||||
|
require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("vertex express mode image edit multipart request body", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
body, contentType := buildMultipartRequestBody(t, map[string]string{
|
||||||
|
"model": "gemini-2.0-flash-exp",
|
||||||
|
"prompt": "Add sunglasses to the cat",
|
||||||
|
"size": "1024x1024",
|
||||||
|
}, map[string][]byte{
|
||||||
|
"image": []byte("fake-image-content"),
|
||||||
|
})
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/images/edits"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", contentType},
|
||||||
|
})
|
||||||
|
|
||||||
|
action := host.CallOnHttpRequestBody(body)
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
bodyStr := string(processedBody)
|
||||||
|
require.Contains(t, bodyStr, "inlineData", "Multipart image should be converted to inlineData")
|
||||||
|
require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
|
||||||
|
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
require.True(t, test.HasHeaderWithValue(requestHeaders, "Content-Type", "application/json"), "Content-Type should be rewritten to application/json")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("vertex express mode image variation multipart request body", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
body, contentType := buildMultipartRequestBody(t, map[string]string{
|
||||||
|
"model": "gemini-2.0-flash-exp",
|
||||||
|
"size": "1024x1024",
|
||||||
|
}, map[string][]byte{
|
||||||
|
"image": []byte("fake-image-content"),
|
||||||
|
})
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/images/variations"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", contentType},
|
||||||
|
})
|
||||||
|
|
||||||
|
action := host.CallOnHttpRequestBody(body)
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
bodyStr := string(processedBody)
|
||||||
|
require.Contains(t, bodyStr, "inlineData", "Multipart image should be converted to inlineData")
|
||||||
|
require.Contains(t, bodyStr, "Create variations of the provided image.", "Variation request should inject a default prompt")
|
||||||
|
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
require.True(t, test.HasHeaderWithValue(requestHeaders, "Content-Type", "application/json"), "Content-Type should be rewritten to application/json")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("vertex express mode image edit with model mapping", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/images/edits"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
requestBody := `{"model":"gpt-4","prompt":"Turn it into watercolor","image_url":{"url":"` + testDataURL + `"}}`
|
||||||
|
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
pathHeader := ""
|
||||||
|
for _, header := range requestHeaders {
|
||||||
|
if header[0] == ":path" {
|
||||||
|
pathHeader = header[1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Contains(t, pathHeader, "gemini-2.5-flash", "Path should contain mapped model name")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("vertex express mode image variation request body with image_url", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/images/variations"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
requestBody := `{"model":"gemini-2.0-flash-exp","image_url":{"url":"` + testDataURL + `"}}`
|
||||||
|
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
|
||||||
|
bodyStr := string(processedBody)
|
||||||
|
require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image_url")
|
||||||
|
require.Contains(t, bodyStr, "Create variations of the provided image.", "Variation request should inject a default prompt")
|
||||||
|
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
pathHeader := ""
|
||||||
|
for _, header := range requestHeaders {
|
||||||
|
if header[0] == ":path" {
|
||||||
|
pathHeader = header[1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Contains(t, pathHeader, "generateContent", "Image variation should use generateContent action")
|
||||||
|
require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunVertexExpressModeImageEditVariationResponseBodyTests(t *testing.T) {
|
||||||
|
test.RunTest(t, func(t *testing.T) {
|
||||||
|
const testDataURL = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||||
|
|
||||||
|
t.Run("vertex express mode image edit response body", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/images/edits"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add glasses","image_url":{"url":"` + testDataURL + `"}}`
|
||||||
|
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
|
||||||
|
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||||
|
host.CallOnHttpResponseHeaders([][2]string{
|
||||||
|
{":status", "200"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
responseBody := `{
|
||||||
|
"candidates": [{
|
||||||
|
"content": {
|
||||||
|
"role": "model",
|
||||||
|
"parts": [{
|
||||||
|
"inlineData": {
|
||||||
|
"mimeType": "image/png",
|
||||||
|
"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
"usageMetadata": {
|
||||||
|
"promptTokenCount": 12,
|
||||||
|
"candidatesTokenCount": 1024,
|
||||||
|
"totalTokenCount": 1036
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedResponseBody := host.GetResponseBody()
|
||||||
|
require.NotNil(t, processedResponseBody)
|
||||||
|
|
||||||
|
responseStr := string(processedResponseBody)
|
||||||
|
require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field")
|
||||||
|
require.Contains(t, responseStr, "usage", "Response should contain usage field")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("vertex express mode image variation response body", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/images/variations"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
requestBody := `{"model":"gemini-2.0-flash-exp","image_url":{"url":"` + testDataURL + `"}}`
|
||||||
|
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
|
||||||
|
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||||
|
host.CallOnHttpResponseHeaders([][2]string{
|
||||||
|
{":status", "200"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
responseBody := `{
|
||||||
|
"candidates": [{
|
||||||
|
"content": {
|
||||||
|
"role": "model",
|
||||||
|
"parts": [{
|
||||||
|
"inlineData": {
|
||||||
|
"mimeType": "image/png",
|
||||||
|
"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
"usageMetadata": {
|
||||||
|
"promptTokenCount": 8,
|
||||||
|
"candidatesTokenCount": 768,
|
||||||
|
"totalTokenCount": 776
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedResponseBody := host.GetResponseBody()
|
||||||
|
require.NotNil(t, processedResponseBody)
|
||||||
|
|
||||||
|
responseStr := string(processedResponseBody)
|
||||||
|
require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field")
|
||||||
|
require.Contains(t, responseStr, "usage", "Response should contain usage field")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// ==================== Vertex Raw 模式测试 ====================
|
// ==================== Vertex Raw 模式测试 ====================
|
||||||
|
|
||||||
func RunVertexRawModeOnHttpRequestHeadersTests(t *testing.T) {
|
func RunVertexRawModeOnHttpRequestHeadersTests(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user