feat(ai-proxy): support Amazon Bedrock Image Generation (#2212)

Signed-off-by: Xijun Dai <daixijun1990@gmail.com>
Co-authored-by: Kent Dong <ch3cho@qq.com>
This commit is contained in:
Xijun Dai
2025-05-10 09:54:31 +08:00
committed by GitHub
parent b5eadcdbee
commit 8b3f1aab1a
7 changed files with 195 additions and 15 deletions

View File

@@ -1,5 +1,5 @@
ARG BUILDER=higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/wasm-go-builder:go1.20.14-tinygo0.29.0-oras1.0.0
FROM $BUILDER as builder
FROM $BUILDER AS builder
ARG GOPROXY
@@ -26,6 +26,6 @@ RUN \
tinygo build -o /main.wasm -scheduler=none -gc=custom -tags="custommalloc nottinygc_finalizer $EXTRA_TAGS" -target=wasi ./ ; \
fi
FROM scratch as output
FROM scratch AS output
COPY --from=builder /main.wasm plugin.wasm

View File

@@ -11,11 +11,11 @@ require (
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
github.com/stretchr/testify v1.8.4
github.com/tidwall/gjson v1.17.3
github.com/wasilibs/go-re2 v1.6.0
)
require (
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/wasilibs/go-re2 v1.6.0 // indirect
golang.org/x/sys v0.21.0 // indirect
)

View File

@@ -6,8 +6,6 @@ github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbG
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a h1:tdPcGgyiH0K+SbsJBBm2oPyEIOTAvLBwD9TuUwVtZho=
github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -29,6 +27,7 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/wasilibs/go-re2 v1.6.0 h1:CLlhDebt38wtl/zz4ww+hkXBMcxjrKFvTDXzFW2VOz8=
github.com/wasilibs/go-re2 v1.6.0/go.mod h1:prArCyErsypRBI/jFAFJEbzyHzjABKqkzlidF0SNA04=
github.com/wasilibs/nottinygc v0.4.0 h1:h1TJMihMC4neN6Zq+WKpLxgd9xCFMw7O9ETLwY2exJQ=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=

View File

@@ -13,6 +13,7 @@ import (
"hash/crc32"
"io"
"net/http"
"strconv"
"strings"
"time"
@@ -32,7 +33,10 @@ const (
bedrockChatCompletionPath = "/model/%s/converse"
// converseStream路径 /model/{modelId}/converse-stream
bedrockStreamChatCompletionPath = "/model/%s/converse-stream"
bedrockSignedHeaders = "host;x-amz-date"
// invoke_model 路径 /model/{modelId}/invoke
bedrockInvokeModelPath = "/model/%s/invoke"
bedrockSignedHeaders = "host;x-amz-date"
requestIdHeader = "X-Amzn-Requestid"
)
type bedrockProviderInitializer struct {
@@ -50,7 +54,8 @@ func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) erro
func (b *bedrockProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): bedrockChatCompletionPath,
string(ApiNameChatCompletion): bedrockChatCompletionPath,
string(ApiNameImageGeneration): bedrockInvokeModelPath,
}
}
@@ -99,7 +104,7 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
chatChoice.FinishReason = stopReasonBedrock2OpenAI(*bedrockEvent.StopReason)
}
choices = append(choices, chatChoice)
requestId := ctx.GetStringContext("X-Amzn-Requestid", "")
requestId := ctx.GetStringContext(requestIdHeader, "")
openAIFormattedChunk := &chatCompletionResponse{
Id: requestId,
Created: time.Now().UnixMilli() / 1000,
@@ -152,6 +157,74 @@ type toolUseBlockDelta struct {
Input string `json:"input"`
}
type bedrockImageGenerationResponse struct {
Images []string `json:"images"`
Error string `json:"error"`
}
type bedrockImageGenerationTextToImageParams struct {
Text string `json:"text"`
NegativeText string `json:"negativeText,omitempty"`
ConditionImage string `json:"conditionImage,omitempty"`
ControlMode string `json:"controlMode,omitempty"`
ControlStrength float32 `json:"controlLength,omitempty"`
}
type bedrockImageGenerationConfig struct {
Width int `json:"width"`
Height int `json:"height"`
Quality string `json:"quality,omitempty"`
CfgScale float32 `json:"cfgScale,omitempty"`
Seed int `json:"seed,omitempty"`
NumberOfImages int `json:"numberOfImages"`
}
type bedrockImageGenerationColorGuidedGenerationParams struct {
Colors []string `json:"colors"`
ReferenceImage string `json:"referenceImage"`
Text string `json:"text"`
NegativeText string `json:"negativeText,omitempty"`
}
type bedrockImageGenerationImageVariationParams struct {
Images []string `json:"images"`
SimilarityStrength float32 `json:"similarityStrength"`
Text string `json:"text"`
NegativeText string `json:"negativeText,omitempty"`
}
type bedrockImageGenerationInPaintingParams struct {
Image string `json:"image"`
MaskPrompt string `json:"maskPrompt"`
MaskImage string `json:"maskImage"`
Text string `json:"text"`
NegativeText string `json:"negativeText,omitempty"`
}
type bedrockImageGenerationOutPaintingParams struct {
Image string `json:"image"`
MaskPrompt string `json:"maskPrompt"`
MaskImage string `json:"maskImage"`
OutPaintingMode string `json:"outPaintingMode"`
Text string `json:"text"`
NegativeText string `json:"negativeText,omitempty"`
}
type bedrockImageGenerationBackgroundRemovalParams struct {
Image string `json:"image"`
}
type bedrockImageGenerationRequest struct {
TaskType string `json:"taskType"`
ImageGenerationConfig *bedrockImageGenerationConfig `json:"imageGenerationConfig"`
TextToImageParams *bedrockImageGenerationTextToImageParams `json:"textToImageParams,omitempty"`
ColorGuidedGenerationParams *bedrockImageGenerationColorGuidedGenerationParams `json:"colorGuidedGenerationParams,omitempty"`
ImageVariationParams *bedrockImageGenerationImageVariationParams `json:"imageVariationParams,omitempty"`
InPaintingParams *bedrockImageGenerationInPaintingParams `json:"inPaintingParams,omitempty"`
OutPaintingParams *bedrockImageGenerationOutPaintingParams `json:"outPaintingParams,omitempty"`
BackgroundRemovalParams *bedrockImageGenerationBackgroundRemovalParams `json:"backgroundRemovalParams,omitempty"`
}
func extractAmazonEventStreamEvents(ctx wrapper.HttpContext, chunk []byte) []ConverseStreamEvent {
body := chunk
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
@@ -489,7 +562,7 @@ func validateCRC(r io.Reader, expect uint32) error {
}
func (b *bedrockProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
ctx.SetContext("X-Amzn-Requestid", headers.Get("X-Amzn-Requestid"))
ctx.SetContext(requestIdHeader, headers.Get(requestIdHeader))
if headers.Get("Content-Type") == "application/vnd.amazon.eventstream" {
headers.Set("Content-Type", "text/event-stream; charset=utf-8")
}
@@ -537,18 +610,83 @@ func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a
switch apiName {
case ApiNameChatCompletion:
return b.onChatCompletionRequestBody(ctx, body, headers)
case ApiNameImageGeneration:
return b.onImageGenerationRequestBody(ctx, body, headers)
default:
return b.config.defaultTransformRequestBody(ctx, apiName, body)
}
}
func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if apiName == ApiNameChatCompletion {
switch apiName {
case ApiNameChatCompletion:
return b.onChatCompletionResponseBody(ctx, body)
case ApiNameImageGeneration:
return b.onImageGenerationResponseBody(ctx, body)
}
return nil, errUnsupportedApiName
}
func (b *bedrockProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
bedrockResponse := &bedrockImageGenerationResponse{}
if err := json.Unmarshal(body, bedrockResponse); err != nil {
log.Errorf("unable to unmarshal bedrock image gerneration response: %v", err)
return nil, fmt.Errorf("unable to unmarshal bedrock image generation response: %v", err)
}
response := b.buildBedrockImageGenerationResponse(ctx, bedrockResponse)
return json.Marshal(response)
}
func (b *bedrockProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
request := &imageGenerationRequest{}
err := b.config.parseRequestAndMapModel(ctx, request, body)
if err != nil {
return nil, err
}
headers.Set("Accept", "*/*")
util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockInvokeModelPath, request.Model))
return b.buildBedrockImageGenerationRequest(request, headers)
}
func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageGenerationRequest, headers http.Header) ([]byte, error) {
width, height := 1024, 1024
pairs := strings.Split(origRequest.Size, "x")
if len(pairs) == 2 {
width, _ = strconv.Atoi(pairs[0])
height, _ = strconv.Atoi(pairs[1])
}
request := &bedrockImageGenerationRequest{
TaskType: "TEXT_IMAGE",
TextToImageParams: &bedrockImageGenerationTextToImageParams{
Text: origRequest.Prompt,
},
ImageGenerationConfig: &bedrockImageGenerationConfig{
NumberOfImages: origRequest.N,
Width: width,
Height: height,
Quality: origRequest.Quality,
},
}
util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockInvokeModelPath, origRequest.Model))
requestBytes, err := json.Marshal(request)
b.setAuthHeaders(requestBytes, headers)
return requestBytes, err
}
func (b *bedrockProvider) buildBedrockImageGenerationResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
data := make([]imageGenerationData, len(bedrockResponse.Images))
for i, image := range bedrockResponse.Images {
data[i] = imageGenerationData{
B64: image,
}
}
return &imageGenerationResponse{
Created: time.Now().UnixMilli() / 1000,
Data: data,
}
}
func (b *bedrockProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
bedrockResponse := &bedrockConverseResponse{}
if err := json.Unmarshal(body, bedrockResponse); err != nil {
@@ -613,7 +751,7 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b
FinishReason: stopReasonBedrock2OpenAI(bedrockResponse.StopReason),
},
}
requestId := ctx.GetStringContext("X-Amzn-Requestid", "")
requestId := ctx.GetStringContext(requestIdHeader, "")
return &chatCompletionResponse{
Id: requestId,
Created: time.Now().UnixMilli() / 1000,

View File

@@ -356,10 +356,39 @@ func (e *StreamEvent) ToHttpString() string {
// https://platform.openai.com/docs/guides/images
type imageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Background string `json:"background,omitempty"`
Moderation string `json:"moderation,omitempty"`
OutputCompression int `json:"output_compression,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
}
type imageGenerationData struct {
URL string `json:"url,omitempty"`
B64 string `json:"b64_json,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
}
type imageGenerationUsage struct {
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokensDetails struct {
TextTokens int `json:"text_tokens"`
ImageTokens int `json:"image_tokens"`
} `json:"input_tokens_details"`
}
type imageGenerationResponse struct {
Created int64 `json:"created"`
Data []imageGenerationData `json:"data"`
Usage *imageGenerationUsage `json:"usage,omitempty"`
}
// https://platform.openai.com/docs/guides/speech-to-text

View File

@@ -534,6 +534,11 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
return err
}
return c.setRequestModel(ctx, req)
case *imageGenerationRequest:
if err := decodeImageGenerationRequest(body, req); err != nil {
return err
}
return c.setRequestModel(ctx, req)
default:
return errors.New("unsupported request type")
}
@@ -547,6 +552,8 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf
model = &req.Model
case *embeddingsRequest:
model = &req.Model
case *imageGenerationRequest:
model = &req.Model
default:
return errors.New("unsupported request type")
}

View File

@@ -25,6 +25,13 @@ func decodeEmbeddingsRequest(body []byte, request *embeddingsRequest) error {
return nil
}
func decodeImageGenerationRequest(body []byte, request *imageGenerationRequest) error {
if err := json.Unmarshal(body, request); err != nil {
return fmt.Errorf("unable to unmarshal request: %v", err)
}
return nil
}
func replaceJsonRequestBody(request interface{}) error {
body, err := json.Marshal(request)
if err != nil {