diff --git a/plugins/wasm-go/Dockerfile b/plugins/wasm-go/Dockerfile index 6b483aaa0..ab9cdf029 100644 --- a/plugins/wasm-go/Dockerfile +++ b/plugins/wasm-go/Dockerfile @@ -1,5 +1,5 @@ ARG BUILDER=higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/wasm-go-builder:go1.20.14-tinygo0.29.0-oras1.0.0 -FROM $BUILDER as builder +FROM $BUILDER AS builder ARG GOPROXY @@ -26,6 +26,6 @@ RUN \ tinygo build -o /main.wasm -scheduler=none -gc=custom -tags="custommalloc nottinygc_finalizer $EXTRA_TAGS" -target=wasi ./ ; \ fi -FROM scratch as output +FROM scratch AS output COPY --from=builder /main.wasm plugin.wasm diff --git a/plugins/wasm-go/extensions/ai-proxy/go.mod b/plugins/wasm-go/extensions/ai-proxy/go.mod index 3a9baaa2e..30341e643 100644 --- a/plugins/wasm-go/extensions/ai-proxy/go.mod +++ b/plugins/wasm-go/extensions/ai-proxy/go.mod @@ -11,11 +11,11 @@ require ( github.com/higress-group/proxy-wasm-go-sdk v1.0.0 github.com/stretchr/testify v1.8.4 github.com/tidwall/gjson v1.17.3 + github.com/wasilibs/go-re2 v1.6.0 ) require ( github.com/tetratelabs/wazero v1.7.2 // indirect - github.com/wasilibs/go-re2 v1.6.0 // indirect golang.org/x/sys v0.21.0 // indirect ) diff --git a/plugins/wasm-go/extensions/ai-proxy/go.sum b/plugins/wasm-go/extensions/ai-proxy/go.sum index b1b7172ac..066d7a2f2 100644 --- a/plugins/wasm-go/extensions/ai-proxy/go.sum +++ b/plugins/wasm-go/extensions/ai-proxy/go.sum @@ -6,8 +6,6 @@ github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbG github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU= github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0= -github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= -github.com/magefile/mage v1.14.0/go.mod 
h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a h1:tdPcGgyiH0K+SbsJBBm2oPyEIOTAvLBwD9TuUwVtZho= github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -29,6 +27,7 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/wasilibs/go-re2 v1.6.0 h1:CLlhDebt38wtl/zz4ww+hkXBMcxjrKFvTDXzFW2VOz8= github.com/wasilibs/go-re2 v1.6.0/go.mod h1:prArCyErsypRBI/jFAFJEbzyHzjABKqkzlidF0SNA04= +github.com/wasilibs/nottinygc v0.4.0 h1:h1TJMihMC4neN6Zq+WKpLxgd9xCFMw7O9ETLwY2exJQ= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go index 81a07bf20..1e4ceaddf 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go @@ -13,6 +13,7 @@ import ( "hash/crc32" "io" "net/http" + "strconv" "strings" "time" @@ -32,7 +33,10 @@ const ( bedrockChatCompletionPath = "/model/%s/converse" // converseStream路径 /model/{modelId}/converse-stream bedrockStreamChatCompletionPath = "/model/%s/converse-stream" - bedrockSignedHeaders = "host;x-amz-date" + // invoke_model 路径 /model/{modelId}/invoke + bedrockInvokeModelPath = "/model/%s/invoke" + bedrockSignedHeaders = "host;x-amz-date" + requestIdHeader = "X-Amzn-Requestid" ) type bedrockProviderInitializer struct { @@ -50,7 +54,8 @@ func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) erro func (b 
// bedrockImageGenerationResponse is the JSON body decoded from a Bedrock
// invoke-model image generation response. Images holds the generated images
// as strings that are copied verbatim into the OpenAI-style b64_json field;
// Error carries the message Bedrock reports instead of (or alongside) images.
type bedrockImageGenerationResponse struct {
	Images []string `json:"images"`
	Error  string   `json:"error"`
}

// bedrockImageGenerationTextToImageParams is the "textToImageParams" object
// of a TEXT_IMAGE task. Only Text is mandatory; every other field is dropped
// from the payload when left at its zero value.
//
// ControlStrength is serialized as "controlStrength". It was previously
// tagged "controlLength", a key Bedrock does not recognise, so the value
// never reached the model.
type bedrockImageGenerationTextToImageParams struct {
	Text            string  `json:"text"`
	NegativeText    string  `json:"negativeText,omitempty"`
	ConditionImage  string  `json:"conditionImage,omitempty"`
	ControlMode     string  `json:"controlMode,omitempty"`
	ControlStrength float32 `json:"controlStrength,omitempty"`
}

// bedrockImageGenerationConfig is the "imageGenerationConfig" object shared
// by the image generation task types. Width, Height and NumberOfImages are
// always serialized, even when zero; the remaining fields are omitted when
// unset.
type bedrockImageGenerationConfig struct {
	Width          int     `json:"width"`
	Height         int     `json:"height"`
	Quality        string  `json:"quality,omitempty"`
	CfgScale       float32 `json:"cfgScale,omitempty"`
	Seed           int     `json:"seed,omitempty"`
	NumberOfImages int     `json:"numberOfImages"`
}

// bedrockImageGenerationColorGuidedGenerationParams is the
// "colorGuidedGenerationParams" object of the request.
type bedrockImageGenerationColorGuidedGenerationParams struct {
	Colors         []string `json:"colors"`
	ReferenceImage string   `json:"referenceImage"`
	Text           string   `json:"text"`
	NegativeText   string   `json:"negativeText,omitempty"`
}

// bedrockImageGenerationImageVariationParams is the "imageVariationParams"
// object of the request.
type bedrockImageGenerationImageVariationParams struct {
	Images             []string `json:"images"`
	SimilarityStrength float32  `json:"similarityStrength"`
	Text               string   `json:"text"`
	NegativeText       string   `json:"negativeText,omitempty"`
}

// bedrockImageGenerationInPaintingParams is the "inPaintingParams" object of
// the request.
//
// MaskPrompt and MaskImage are alternatives (NOTE(review): per the Titan
// documentation exactly one of them may be supplied — confirm), so an unset
// one is omitted rather than sent as an empty string.
type bedrockImageGenerationInPaintingParams struct {
	Image        string `json:"image"`
	MaskPrompt   string `json:"maskPrompt,omitempty"`
	MaskImage    string `json:"maskImage,omitempty"`
	Text         string `json:"text"`
	NegativeText string `json:"negativeText,omitempty"`
}

// bedrockImageGenerationOutPaintingParams is the "outPaintingParams" object
// of the request. As with in-painting, an unset MaskPrompt or MaskImage is
// omitted from the payload.
type bedrockImageGenerationOutPaintingParams struct {
	Image           string `json:"image"`
	MaskPrompt      string `json:"maskPrompt,omitempty"`
	MaskImage       string `json:"maskImage,omitempty"`
	OutPaintingMode string `json:"outPaintingMode"`
	Text            string `json:"text"`
	NegativeText    string `json:"negativeText,omitempty"`
}

// bedrockImageGenerationBackgroundRemovalParams is the
// "backgroundRemovalParams" object of the request.
type bedrockImageGenerationBackgroundRemovalParams struct {
	Image string `json:"image"`
}

// bedrockImageGenerationRequest is the JSON body posted to the Bedrock
// invoke-model endpoint for image generation. TaskType names the task, and
// the parameter object for that task is set while the others stay nil and
// are therefore omitted from the payload.
type bedrockImageGenerationRequest struct {
	TaskType                    string                                             `json:"taskType"`
	ImageGenerationConfig       *bedrockImageGenerationConfig                      `json:"imageGenerationConfig"`
	TextToImageParams           *bedrockImageGenerationTextToImageParams           `json:"textToImageParams,omitempty"`
	ColorGuidedGenerationParams *bedrockImageGenerationColorGuidedGenerationParams `json:"colorGuidedGenerationParams,omitempty"`
	ImageVariationParams        *bedrockImageGenerationImageVariationParams        `json:"imageVariationParams,omitempty"`
	InPaintingParams            *bedrockImageGenerationInPaintingParams            `json:"inPaintingParams,omitempty"`
	OutPaintingParams           *bedrockImageGenerationOutPaintingParams           `json:"outPaintingParams,omitempty"`
	BackgroundRemovalParams     *bedrockImageGenerationBackgroundRemovalParams     `json:"backgroundRemovalParams,omitempty"`
}
headers.Get("Content-Type") == "application/vnd.amazon.eventstream" { headers.Set("Content-Type", "text/event-stream; charset=utf-8") } @@ -537,18 +610,83 @@ func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a switch apiName { case ApiNameChatCompletion: return b.onChatCompletionRequestBody(ctx, body, headers) + case ApiNameImageGeneration: + return b.onImageGenerationRequestBody(ctx, body, headers) default: return b.config.defaultTransformRequestBody(ctx, apiName, body) } } func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) { - if apiName == ApiNameChatCompletion { + switch apiName { + case ApiNameChatCompletion: return b.onChatCompletionResponseBody(ctx, body) + case ApiNameImageGeneration: + return b.onImageGenerationResponseBody(ctx, body) } return nil, errUnsupportedApiName } +func (b *bedrockProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) { + bedrockResponse := &bedrockImageGenerationResponse{} + if err := json.Unmarshal(body, bedrockResponse); err != nil { + log.Errorf("unable to unmarshal bedrock image gerneration response: %v", err) + return nil, fmt.Errorf("unable to unmarshal bedrock image generation response: %v", err) + } + response := b.buildBedrockImageGenerationResponse(ctx, bedrockResponse) + return json.Marshal(response) +} + +func (b *bedrockProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) { + request := &imageGenerationRequest{} + err := b.config.parseRequestAndMapModel(ctx, request, body) + if err != nil { + return nil, err + } + headers.Set("Accept", "*/*") + util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockInvokeModelPath, request.Model)) + return b.buildBedrockImageGenerationRequest(request, headers) +} + +func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageGenerationRequest, headers http.Header) 
([]byte, error) { + width, height := 1024, 1024 + pairs := strings.Split(origRequest.Size, "x") + if len(pairs) == 2 { + width, _ = strconv.Atoi(pairs[0]) + height, _ = strconv.Atoi(pairs[1]) + } + + request := &bedrockImageGenerationRequest{ + TaskType: "TEXT_IMAGE", + TextToImageParams: &bedrockImageGenerationTextToImageParams{ + Text: origRequest.Prompt, + }, + ImageGenerationConfig: &bedrockImageGenerationConfig{ + NumberOfImages: origRequest.N, + Width: width, + Height: height, + Quality: origRequest.Quality, + }, + } + util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockInvokeModelPath, origRequest.Model)) + requestBytes, err := json.Marshal(request) + b.setAuthHeaders(requestBytes, headers) + return requestBytes, err +} + +func (b *bedrockProvider) buildBedrockImageGenerationResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse { + data := make([]imageGenerationData, len(bedrockResponse.Images)) + for i, image := range bedrockResponse.Images { + data[i] = imageGenerationData{ + B64: image, + } + } + return &imageGenerationResponse{ + Created: time.Now().UnixMilli() / 1000, + Data: data, + } +} + func (b *bedrockProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) { bedrockResponse := &bedrockConverseResponse{} if err := json.Unmarshal(body, bedrockResponse); err != nil { @@ -613,7 +751,7 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b FinishReason: stopReasonBedrock2OpenAI(bedrockResponse.StopReason), }, } - requestId := ctx.GetStringContext("X-Amzn-Requestid", "") + requestId := ctx.GetStringContext(requestIdHeader, "") return &chatCompletionResponse{ Id: requestId, Created: time.Now().UnixMilli() / 1000, diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index bb523296c..33de57293 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ 
// https://platform.openai.com/docs/guides/images
//
// imageGenerationRequest is the image generation request body accepted from
// the client. Model and Prompt are always serialized; every other field is
// optional and omitted from the JSON when left at its zero value.
type imageGenerationRequest struct {
	Model             string `json:"model"`
	Prompt            string `json:"prompt"`
	Background        string `json:"background,omitempty"`
	Moderation        string `json:"moderation,omitempty"`
	OutputCompression int    `json:"output_compression,omitempty"`
	OutputFormat      string `json:"output_format,omitempty"`
	Quality           string `json:"quality,omitempty"`
	ResponseFormat    string `json:"response_format,omitempty"`
	Style             string `json:"style,omitempty"`
	N                 int    `json:"n,omitempty"`
	Size              string `json:"size,omitempty"`
}

// imageGenerationData is one entry of the "data" array in an image
// generation response. All fields are optional; the Bedrock provider fills
// only B64 (serialized as "b64_json").
type imageGenerationData struct {
	URL           string `json:"url,omitempty"`
	B64           string `json:"b64_json,omitempty"`
	RevisedPrompt string `json:"revised_prompt,omitempty"`
}

// imageGenerationUsage is the optional "usage" object of an image generation
// response: total, input and output token counts, with the input count
// further split into text and image tokens.
type imageGenerationUsage struct {
	TotalTokens        int `json:"total_tokens"`
	InputTokens        int `json:"input_tokens"`
	OutputTokens       int `json:"output_tokens"`
	InputTokensDetails struct {
		TextTokens  int `json:"text_tokens"`
		ImageTokens int `json:"image_tokens"`
	} `json:"input_tokens_details"`
}

// imageGenerationResponse is the image generation response body returned to
// the client. Created is a Unix timestamp (the Bedrock provider writes
// seconds); Usage is omitted from the JSON when nil.
type imageGenerationResponse struct {
	Created int64                 `json:"created"`
	Data    []imageGenerationData `json:"data"`
	Usage   *imageGenerationUsage `json:"usage,omitempty"`
}
*imageGenerationRequest: + if err := decodeImageGenerationRequest(body, req); err != nil { + return err + } + return c.setRequestModel(ctx, req) default: return errors.New("unsupported request type") } @@ -547,6 +552,8 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf model = &req.Model case *embeddingsRequest: model = &req.Model + case *imageGenerationRequest: + model = &req.Model default: return errors.New("unsupported request type") } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go index 7907e33b5..52fa8a64d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go @@ -25,6 +25,13 @@ func decodeEmbeddingsRequest(body []byte, request *embeddingsRequest) error { return nil } +func decodeImageGenerationRequest(body []byte, request *imageGenerationRequest) error { + if err := json.Unmarshal(body, request); err != nil { + return fmt.Errorf("unable to unmarshal request: %v", err) + } + return nil +} + func replaceJsonRequestBody(request interface{}) error { body, err := json.Marshal(request) if err != nil {