mirror of
https://github.com/alibaba/higress.git
synced 2026-02-06 23:21:08 +08:00
feat(ai-proxy): support Amazon Bedrock Image Generation (#2212)
Signed-off-by: Xijun Dai <daixijun1990@gmail.com> Co-authored-by: Kent Dong <ch3cho@qq.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
ARG BUILDER=higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/wasm-go-builder:go1.20.14-tinygo0.29.0-oras1.0.0
|
||||
FROM $BUILDER as builder
|
||||
FROM $BUILDER AS builder
|
||||
|
||||
|
||||
ARG GOPROXY
|
||||
@@ -26,6 +26,6 @@ RUN \
|
||||
tinygo build -o /main.wasm -scheduler=none -gc=custom -tags="custommalloc nottinygc_finalizer $EXTRA_TAGS" -target=wasi ./ ; \
|
||||
fi
|
||||
|
||||
FROM scratch as output
|
||||
FROM scratch AS output
|
||||
|
||||
COPY --from=builder /main.wasm plugin.wasm
|
||||
|
||||
@@ -11,11 +11,11 @@ require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/tidwall/gjson v1.17.3
|
||||
github.com/wasilibs/go-re2 v1.6.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/wasilibs/go-re2 v1.6.0 // indirect
|
||||
golang.org/x/sys v0.21.0 // indirect
|
||||
)
|
||||
|
||||
|
||||
@@ -6,8 +6,6 @@ github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbG
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a h1:tdPcGgyiH0K+SbsJBBm2oPyEIOTAvLBwD9TuUwVtZho=
|
||||
github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
@@ -29,6 +27,7 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/wasilibs/go-re2 v1.6.0 h1:CLlhDebt38wtl/zz4ww+hkXBMcxjrKFvTDXzFW2VOz8=
|
||||
github.com/wasilibs/go-re2 v1.6.0/go.mod h1:prArCyErsypRBI/jFAFJEbzyHzjABKqkzlidF0SNA04=
|
||||
github.com/wasilibs/nottinygc v0.4.0 h1:h1TJMihMC4neN6Zq+WKpLxgd9xCFMw7O9ETLwY2exJQ=
|
||||
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
|
||||
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -32,7 +33,10 @@ const (
|
||||
bedrockChatCompletionPath = "/model/%s/converse"
|
||||
// converseStream路径 /model/{modelId}/converse-stream
|
||||
bedrockStreamChatCompletionPath = "/model/%s/converse-stream"
|
||||
bedrockSignedHeaders = "host;x-amz-date"
|
||||
// invoke_model 路径 /model/{modelId}/invoke
|
||||
bedrockInvokeModelPath = "/model/%s/invoke"
|
||||
bedrockSignedHeaders = "host;x-amz-date"
|
||||
requestIdHeader = "X-Amzn-Requestid"
|
||||
)
|
||||
|
||||
type bedrockProviderInitializer struct {
|
||||
@@ -50,7 +54,8 @@ func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) erro
|
||||
|
||||
func (b *bedrockProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): bedrockChatCompletionPath,
|
||||
string(ApiNameChatCompletion): bedrockChatCompletionPath,
|
||||
string(ApiNameImageGeneration): bedrockInvokeModelPath,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,7 +104,7 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
|
||||
chatChoice.FinishReason = stopReasonBedrock2OpenAI(*bedrockEvent.StopReason)
|
||||
}
|
||||
choices = append(choices, chatChoice)
|
||||
requestId := ctx.GetStringContext("X-Amzn-Requestid", "")
|
||||
requestId := ctx.GetStringContext(requestIdHeader, "")
|
||||
openAIFormattedChunk := &chatCompletionResponse{
|
||||
Id: requestId,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
@@ -152,6 +157,74 @@ type toolUseBlockDelta struct {
|
||||
Input string `json:"input"`
|
||||
}
|
||||
|
||||
// bedrockImageGenerationResponse is the JSON shape decoded from the body of a
// Bedrock image generation response.
type bedrockImageGenerationResponse struct {
	// Images are copied one-to-one into the b64_json entries of the
	// OpenAI-style response.
	Images []string `json:"images"`
	// Error carries the "error" field of the response body.
	Error string `json:"error"`
}
|
||||
|
||||
// bedrockImageGenerationTextToImageParams is the "textToImageParams" object of
// a Bedrock image generation request.
type bedrockImageGenerationTextToImageParams struct {
	// Text is the prompt.
	Text string `json:"text"`
	// NegativeText is omitted from the payload when empty.
	NegativeText string `json:"negativeText,omitempty"`
	// ConditionImage is omitted from the payload when empty.
	ConditionImage string `json:"conditionImage,omitempty"`
	// ControlMode is omitted from the payload when empty.
	ControlMode string `json:"controlMode,omitempty"`
	// ControlStrength is serialized as "controlStrength", the key that pairs
	// with conditionImage/controlMode in the Bedrock request schema. The
	// previous tag ("controlLength") did not match the field name, so the
	// value was sent under a key the service does not define.
	ControlStrength float32 `json:"controlStrength,omitempty"`
}
|
||||
|
||||
// bedrockImageGenerationConfig is the "imageGenerationConfig" object of a
// Bedrock image generation request.
type bedrockImageGenerationConfig struct {
	// Width and Height are always serialized, even when zero.
	Width  int `json:"width"`
	Height int `json:"height"`

	// Quality, CfgScale and Seed are dropped from the payload when they hold
	// their zero value.
	Quality  string  `json:"quality,omitempty"`
	CfgScale float32 `json:"cfgScale,omitempty"`
	Seed     int     `json:"seed,omitempty"`

	// NumberOfImages is always serialized, even when zero.
	NumberOfImages int `json:"numberOfImages"`
}
|
||||
|
||||
// bedrockImageGenerationColorGuidedGenerationParams is the
// "colorGuidedGenerationParams" object of a Bedrock image generation request.
// Only NegativeText is optional in the serialized form.
type bedrockImageGenerationColorGuidedGenerationParams struct {
	Colors         []string `json:"colors"`
	ReferenceImage string   `json:"referenceImage"`
	Text           string   `json:"text"`
	NegativeText   string   `json:"negativeText,omitempty"`
}
|
||||
|
||||
// bedrockImageGenerationImageVariationParams is the "imageVariationParams"
// object of a Bedrock image generation request. Only NegativeText is optional
// in the serialized form.
type bedrockImageGenerationImageVariationParams struct {
	Images             []string `json:"images"`
	SimilarityStrength float32  `json:"similarityStrength"`
	Text               string   `json:"text"`
	NegativeText       string   `json:"negativeText,omitempty"`
}
|
||||
|
||||
// bedrockImageGenerationInPaintingParams is the "inPaintingParams" object of a
// Bedrock image generation request. Only NegativeText is optional in the
// serialized form.
type bedrockImageGenerationInPaintingParams struct {
	Image        string `json:"image"`
	MaskPrompt   string `json:"maskPrompt"`
	MaskImage    string `json:"maskImage"`
	Text         string `json:"text"`
	NegativeText string `json:"negativeText,omitempty"`
}
|
||||
|
||||
// bedrockImageGenerationOutPaintingParams is the "outPaintingParams" object of
// a Bedrock image generation request. Only NegativeText is optional in the
// serialized form.
type bedrockImageGenerationOutPaintingParams struct {
	Image           string `json:"image"`
	MaskPrompt      string `json:"maskPrompt"`
	MaskImage       string `json:"maskImage"`
	OutPaintingMode string `json:"outPaintingMode"`
	Text            string `json:"text"`
	NegativeText    string `json:"negativeText,omitempty"`
}
|
||||
|
||||
// bedrockImageGenerationBackgroundRemovalParams is the
// "backgroundRemovalParams" object of a Bedrock image generation request.
type bedrockImageGenerationBackgroundRemovalParams struct {
	// Image is always serialized, even when empty.
	Image string `json:"image"`
}
|
||||
|
||||
type bedrockImageGenerationRequest struct {
|
||||
TaskType string `json:"taskType"`
|
||||
ImageGenerationConfig *bedrockImageGenerationConfig `json:"imageGenerationConfig"`
|
||||
TextToImageParams *bedrockImageGenerationTextToImageParams `json:"textToImageParams,omitempty"`
|
||||
ColorGuidedGenerationParams *bedrockImageGenerationColorGuidedGenerationParams `json:"colorGuidedGenerationParams,omitempty"`
|
||||
ImageVariationParams *bedrockImageGenerationImageVariationParams `json:"imageVariationParams,omitempty"`
|
||||
InPaintingParams *bedrockImageGenerationInPaintingParams `json:"inPaintingParams,omitempty"`
|
||||
OutPaintingParams *bedrockImageGenerationOutPaintingParams `json:"outPaintingParams,omitempty"`
|
||||
BackgroundRemovalParams *bedrockImageGenerationBackgroundRemovalParams `json:"backgroundRemovalParams,omitempty"`
|
||||
}
|
||||
|
||||
func extractAmazonEventStreamEvents(ctx wrapper.HttpContext, chunk []byte) []ConverseStreamEvent {
|
||||
body := chunk
|
||||
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
|
||||
@@ -489,7 +562,7 @@ func validateCRC(r io.Reader, expect uint32) error {
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
ctx.SetContext("X-Amzn-Requestid", headers.Get("X-Amzn-Requestid"))
|
||||
ctx.SetContext(requestIdHeader, headers.Get(requestIdHeader))
|
||||
if headers.Get("Content-Type") == "application/vnd.amazon.eventstream" {
|
||||
headers.Set("Content-Type", "text/event-stream; charset=utf-8")
|
||||
}
|
||||
@@ -537,18 +610,83 @@ func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
return b.onChatCompletionRequestBody(ctx, body, headers)
|
||||
case ApiNameImageGeneration:
|
||||
return b.onImageGenerationRequestBody(ctx, body, headers)
|
||||
default:
|
||||
return b.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if apiName == ApiNameChatCompletion {
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
return b.onChatCompletionResponseBody(ctx, body)
|
||||
case ApiNameImageGeneration:
|
||||
return b.onImageGenerationResponseBody(ctx, body)
|
||||
}
|
||||
return nil, errUnsupportedApiName
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
bedrockResponse := &bedrockImageGenerationResponse{}
|
||||
if err := json.Unmarshal(body, bedrockResponse); err != nil {
|
||||
log.Errorf("unable to unmarshal bedrock image gerneration response: %v", err)
|
||||
return nil, fmt.Errorf("unable to unmarshal bedrock image generation response: %v", err)
|
||||
}
|
||||
response := b.buildBedrockImageGenerationResponse(ctx, bedrockResponse)
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &imageGenerationRequest{}
|
||||
err := b.config.parseRequestAndMapModel(ctx, request, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
headers.Set("Accept", "*/*")
|
||||
util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockInvokeModelPath, request.Model))
|
||||
return b.buildBedrockImageGenerationRequest(request, headers)
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageGenerationRequest, headers http.Header) ([]byte, error) {
|
||||
width, height := 1024, 1024
|
||||
pairs := strings.Split(origRequest.Size, "x")
|
||||
if len(pairs) == 2 {
|
||||
width, _ = strconv.Atoi(pairs[0])
|
||||
height, _ = strconv.Atoi(pairs[1])
|
||||
}
|
||||
|
||||
request := &bedrockImageGenerationRequest{
|
||||
TaskType: "TEXT_IMAGE",
|
||||
TextToImageParams: &bedrockImageGenerationTextToImageParams{
|
||||
Text: origRequest.Prompt,
|
||||
},
|
||||
ImageGenerationConfig: &bedrockImageGenerationConfig{
|
||||
NumberOfImages: origRequest.N,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Quality: origRequest.Quality,
|
||||
},
|
||||
}
|
||||
util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockInvokeModelPath, origRequest.Model))
|
||||
requestBytes, err := json.Marshal(request)
|
||||
b.setAuthHeaders(requestBytes, headers)
|
||||
return requestBytes, err
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) buildBedrockImageGenerationResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
|
||||
data := make([]imageGenerationData, len(bedrockResponse.Images))
|
||||
for i, image := range bedrockResponse.Images {
|
||||
data[i] = imageGenerationData{
|
||||
B64: image,
|
||||
}
|
||||
}
|
||||
return &imageGenerationResponse{
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
bedrockResponse := &bedrockConverseResponse{}
|
||||
if err := json.Unmarshal(body, bedrockResponse); err != nil {
|
||||
@@ -613,7 +751,7 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b
|
||||
FinishReason: stopReasonBedrock2OpenAI(bedrockResponse.StopReason),
|
||||
},
|
||||
}
|
||||
requestId := ctx.GetStringContext("X-Amzn-Requestid", "")
|
||||
requestId := ctx.GetStringContext(requestIdHeader, "")
|
||||
return &chatCompletionResponse{
|
||||
Id: requestId,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
|
||||
@@ -356,10 +356,39 @@ func (e *StreamEvent) ToHttpString() string {
|
||||
|
||||
// https://platform.openai.com/docs/guides/images
|
||||
type imageGenerationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Background string `json:"background,omitempty"`
|
||||
Moderation string `json:"moderation,omitempty"`
|
||||
OutputCompression int `json:"output_compression,omitempty"`
|
||||
OutputFormat string `json:"output_format,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
}
|
||||
|
||||
// imageGenerationData is one entry of the "data" array in an OpenAI-style
// image generation response. Every field is omitted from the JSON when empty.
type imageGenerationData struct {
	URL           string `json:"url,omitempty"`
	B64           string `json:"b64_json,omitempty"`
	RevisedPrompt string `json:"revised_prompt,omitempty"`
}
|
||||
|
||||
// imageGenerationUsage is the "usage" object of an OpenAI-style image
// generation response. All counters are always serialized, even when zero.
type imageGenerationUsage struct {
	TotalTokens  int `json:"total_tokens"`
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`

	// InputTokensDetails is the nested "input_tokens_details" object.
	InputTokensDetails struct {
		TextTokens  int `json:"text_tokens"`
		ImageTokens int `json:"image_tokens"`
	} `json:"input_tokens_details"`
}
|
||||
|
||||
type imageGenerationResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []imageGenerationData `json:"data"`
|
||||
Usage *imageGenerationUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/guides/speech-to-text
|
||||
|
||||
@@ -534,6 +534,11 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
|
||||
return err
|
||||
}
|
||||
return c.setRequestModel(ctx, req)
|
||||
case *imageGenerationRequest:
|
||||
if err := decodeImageGenerationRequest(body, req); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.setRequestModel(ctx, req)
|
||||
default:
|
||||
return errors.New("unsupported request type")
|
||||
}
|
||||
@@ -547,6 +552,8 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf
|
||||
model = &req.Model
|
||||
case *embeddingsRequest:
|
||||
model = &req.Model
|
||||
case *imageGenerationRequest:
|
||||
model = &req.Model
|
||||
default:
|
||||
return errors.New("unsupported request type")
|
||||
}
|
||||
|
||||
@@ -25,6 +25,13 @@ func decodeEmbeddingsRequest(body []byte, request *embeddingsRequest) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeImageGenerationRequest(body []byte, request *imageGenerationRequest) error {
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func replaceJsonRequestBody(request interface{}) error {
|
||||
body, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user