mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 20:57:32 +08:00
feat(ai-proxy): support Amazon Bedrock Image Generation (#2212)
Signed-off-by: Xijun Dai <daixijun1990@gmail.com> Co-authored-by: Kent Dong <ch3cho@qq.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
ARG BUILDER=higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/wasm-go-builder:go1.20.14-tinygo0.29.0-oras1.0.0
|
ARG BUILDER=higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/wasm-go-builder:go1.20.14-tinygo0.29.0-oras1.0.0
|
||||||
FROM $BUILDER as builder
|
FROM $BUILDER AS builder
|
||||||
|
|
||||||
|
|
||||||
ARG GOPROXY
|
ARG GOPROXY
|
||||||
@@ -26,6 +26,6 @@ RUN \
|
|||||||
tinygo build -o /main.wasm -scheduler=none -gc=custom -tags="custommalloc nottinygc_finalizer $EXTRA_TAGS" -target=wasi ./ ; \
|
tinygo build -o /main.wasm -scheduler=none -gc=custom -tags="custommalloc nottinygc_finalizer $EXTRA_TAGS" -target=wasi ./ ; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
FROM scratch as output
|
FROM scratch AS output
|
||||||
|
|
||||||
COPY --from=builder /main.wasm plugin.wasm
|
COPY --from=builder /main.wasm plugin.wasm
|
||||||
|
|||||||
@@ -11,11 +11,11 @@ require (
|
|||||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
|
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
|
||||||
github.com/stretchr/testify v1.8.4
|
github.com/stretchr/testify v1.8.4
|
||||||
github.com/tidwall/gjson v1.17.3
|
github.com/tidwall/gjson v1.17.3
|
||||||
|
github.com/wasilibs/go-re2 v1.6.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||||
github.com/wasilibs/go-re2 v1.6.0 // indirect
|
|
||||||
golang.org/x/sys v0.21.0 // indirect
|
golang.org/x/sys v0.21.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,6 @@ github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbG
|
|||||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
|
github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
|
||||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
|
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
|
||||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
|
||||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
|
||||||
github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a h1:tdPcGgyiH0K+SbsJBBm2oPyEIOTAvLBwD9TuUwVtZho=
|
github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a h1:tdPcGgyiH0K+SbsJBBm2oPyEIOTAvLBwD9TuUwVtZho=
|
||||||
github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
@@ -29,6 +27,7 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
|||||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
github.com/wasilibs/go-re2 v1.6.0 h1:CLlhDebt38wtl/zz4ww+hkXBMcxjrKFvTDXzFW2VOz8=
|
github.com/wasilibs/go-re2 v1.6.0 h1:CLlhDebt38wtl/zz4ww+hkXBMcxjrKFvTDXzFW2VOz8=
|
||||||
github.com/wasilibs/go-re2 v1.6.0/go.mod h1:prArCyErsypRBI/jFAFJEbzyHzjABKqkzlidF0SNA04=
|
github.com/wasilibs/go-re2 v1.6.0/go.mod h1:prArCyErsypRBI/jFAFJEbzyHzjABKqkzlidF0SNA04=
|
||||||
|
github.com/wasilibs/nottinygc v0.4.0 h1:h1TJMihMC4neN6Zq+WKpLxgd9xCFMw7O9ETLwY2exJQ=
|
||||||
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
|
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
|
||||||
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"hash/crc32"
|
"hash/crc32"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -32,7 +33,10 @@ const (
|
|||||||
bedrockChatCompletionPath = "/model/%s/converse"
|
bedrockChatCompletionPath = "/model/%s/converse"
|
||||||
// converseStream路径 /model/{modelId}/converse-stream
|
// converseStream路径 /model/{modelId}/converse-stream
|
||||||
bedrockStreamChatCompletionPath = "/model/%s/converse-stream"
|
bedrockStreamChatCompletionPath = "/model/%s/converse-stream"
|
||||||
|
// invoke_model 路径 /model/{modelId}/invoke
|
||||||
|
bedrockInvokeModelPath = "/model/%s/invoke"
|
||||||
bedrockSignedHeaders = "host;x-amz-date"
|
bedrockSignedHeaders = "host;x-amz-date"
|
||||||
|
requestIdHeader = "X-Amzn-Requestid"
|
||||||
)
|
)
|
||||||
|
|
||||||
type bedrockProviderInitializer struct {
|
type bedrockProviderInitializer struct {
|
||||||
@@ -51,6 +55,7 @@ func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) erro
|
|||||||
func (b *bedrockProviderInitializer) DefaultCapabilities() map[string]string {
|
func (b *bedrockProviderInitializer) DefaultCapabilities() map[string]string {
|
||||||
return map[string]string{
|
return map[string]string{
|
||||||
string(ApiNameChatCompletion): bedrockChatCompletionPath,
|
string(ApiNameChatCompletion): bedrockChatCompletionPath,
|
||||||
|
string(ApiNameImageGeneration): bedrockInvokeModelPath,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,7 +104,7 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
|
|||||||
chatChoice.FinishReason = stopReasonBedrock2OpenAI(*bedrockEvent.StopReason)
|
chatChoice.FinishReason = stopReasonBedrock2OpenAI(*bedrockEvent.StopReason)
|
||||||
}
|
}
|
||||||
choices = append(choices, chatChoice)
|
choices = append(choices, chatChoice)
|
||||||
requestId := ctx.GetStringContext("X-Amzn-Requestid", "")
|
requestId := ctx.GetStringContext(requestIdHeader, "")
|
||||||
openAIFormattedChunk := &chatCompletionResponse{
|
openAIFormattedChunk := &chatCompletionResponse{
|
||||||
Id: requestId,
|
Id: requestId,
|
||||||
Created: time.Now().UnixMilli() / 1000,
|
Created: time.Now().UnixMilli() / 1000,
|
||||||
@@ -152,6 +157,74 @@ type toolUseBlockDelta struct {
|
|||||||
Input string `json:"input"`
|
Input string `json:"input"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type bedrockImageGenerationResponse struct {
|
||||||
|
Images []string `json:"images"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type bedrockImageGenerationTextToImageParams struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
NegativeText string `json:"negativeText,omitempty"`
|
||||||
|
ConditionImage string `json:"conditionImage,omitempty"`
|
||||||
|
ControlMode string `json:"controlMode,omitempty"`
|
||||||
|
ControlStrength float32 `json:"controlLength,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type bedrockImageGenerationConfig struct {
|
||||||
|
Width int `json:"width"`
|
||||||
|
Height int `json:"height"`
|
||||||
|
Quality string `json:"quality,omitempty"`
|
||||||
|
CfgScale float32 `json:"cfgScale,omitempty"`
|
||||||
|
Seed int `json:"seed,omitempty"`
|
||||||
|
NumberOfImages int `json:"numberOfImages"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type bedrockImageGenerationColorGuidedGenerationParams struct {
|
||||||
|
Colors []string `json:"colors"`
|
||||||
|
ReferenceImage string `json:"referenceImage"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
NegativeText string `json:"negativeText,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type bedrockImageGenerationImageVariationParams struct {
|
||||||
|
Images []string `json:"images"`
|
||||||
|
SimilarityStrength float32 `json:"similarityStrength"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
NegativeText string `json:"negativeText,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type bedrockImageGenerationInPaintingParams struct {
|
||||||
|
Image string `json:"image"`
|
||||||
|
MaskPrompt string `json:"maskPrompt"`
|
||||||
|
MaskImage string `json:"maskImage"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
NegativeText string `json:"negativeText,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type bedrockImageGenerationOutPaintingParams struct {
|
||||||
|
Image string `json:"image"`
|
||||||
|
MaskPrompt string `json:"maskPrompt"`
|
||||||
|
MaskImage string `json:"maskImage"`
|
||||||
|
OutPaintingMode string `json:"outPaintingMode"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
NegativeText string `json:"negativeText,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type bedrockImageGenerationBackgroundRemovalParams struct {
|
||||||
|
Image string `json:"image"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type bedrockImageGenerationRequest struct {
|
||||||
|
TaskType string `json:"taskType"`
|
||||||
|
ImageGenerationConfig *bedrockImageGenerationConfig `json:"imageGenerationConfig"`
|
||||||
|
TextToImageParams *bedrockImageGenerationTextToImageParams `json:"textToImageParams,omitempty"`
|
||||||
|
ColorGuidedGenerationParams *bedrockImageGenerationColorGuidedGenerationParams `json:"colorGuidedGenerationParams,omitempty"`
|
||||||
|
ImageVariationParams *bedrockImageGenerationImageVariationParams `json:"imageVariationParams,omitempty"`
|
||||||
|
InPaintingParams *bedrockImageGenerationInPaintingParams `json:"inPaintingParams,omitempty"`
|
||||||
|
OutPaintingParams *bedrockImageGenerationOutPaintingParams `json:"outPaintingParams,omitempty"`
|
||||||
|
BackgroundRemovalParams *bedrockImageGenerationBackgroundRemovalParams `json:"backgroundRemovalParams,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
func extractAmazonEventStreamEvents(ctx wrapper.HttpContext, chunk []byte) []ConverseStreamEvent {
|
func extractAmazonEventStreamEvents(ctx wrapper.HttpContext, chunk []byte) []ConverseStreamEvent {
|
||||||
body := chunk
|
body := chunk
|
||||||
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
|
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
|
||||||
@@ -489,7 +562,7 @@ func validateCRC(r io.Reader, expect uint32) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *bedrockProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
func (b *bedrockProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||||
ctx.SetContext("X-Amzn-Requestid", headers.Get("X-Amzn-Requestid"))
|
ctx.SetContext(requestIdHeader, headers.Get(requestIdHeader))
|
||||||
if headers.Get("Content-Type") == "application/vnd.amazon.eventstream" {
|
if headers.Get("Content-Type") == "application/vnd.amazon.eventstream" {
|
||||||
headers.Set("Content-Type", "text/event-stream; charset=utf-8")
|
headers.Set("Content-Type", "text/event-stream; charset=utf-8")
|
||||||
}
|
}
|
||||||
@@ -537,18 +610,83 @@ func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a
|
|||||||
switch apiName {
|
switch apiName {
|
||||||
case ApiNameChatCompletion:
|
case ApiNameChatCompletion:
|
||||||
return b.onChatCompletionRequestBody(ctx, body, headers)
|
return b.onChatCompletionRequestBody(ctx, body, headers)
|
||||||
|
case ApiNameImageGeneration:
|
||||||
|
return b.onImageGenerationRequestBody(ctx, body, headers)
|
||||||
default:
|
default:
|
||||||
return b.config.defaultTransformRequestBody(ctx, apiName, body)
|
return b.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||||
if apiName == ApiNameChatCompletion {
|
switch apiName {
|
||||||
|
case ApiNameChatCompletion:
|
||||||
return b.onChatCompletionResponseBody(ctx, body)
|
return b.onChatCompletionResponseBody(ctx, body)
|
||||||
|
case ApiNameImageGeneration:
|
||||||
|
return b.onImageGenerationResponseBody(ctx, body)
|
||||||
}
|
}
|
||||||
return nil, errUnsupportedApiName
|
return nil, errUnsupportedApiName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *bedrockProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||||
|
bedrockResponse := &bedrockImageGenerationResponse{}
|
||||||
|
if err := json.Unmarshal(body, bedrockResponse); err != nil {
|
||||||
|
log.Errorf("unable to unmarshal bedrock image gerneration response: %v", err)
|
||||||
|
return nil, fmt.Errorf("unable to unmarshal bedrock image generation response: %v", err)
|
||||||
|
}
|
||||||
|
response := b.buildBedrockImageGenerationResponse(ctx, bedrockResponse)
|
||||||
|
return json.Marshal(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *bedrockProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||||
|
request := &imageGenerationRequest{}
|
||||||
|
err := b.config.parseRequestAndMapModel(ctx, request, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
headers.Set("Accept", "*/*")
|
||||||
|
util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockInvokeModelPath, request.Model))
|
||||||
|
return b.buildBedrockImageGenerationRequest(request, headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageGenerationRequest, headers http.Header) ([]byte, error) {
|
||||||
|
width, height := 1024, 1024
|
||||||
|
pairs := strings.Split(origRequest.Size, "x")
|
||||||
|
if len(pairs) == 2 {
|
||||||
|
width, _ = strconv.Atoi(pairs[0])
|
||||||
|
height, _ = strconv.Atoi(pairs[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
request := &bedrockImageGenerationRequest{
|
||||||
|
TaskType: "TEXT_IMAGE",
|
||||||
|
TextToImageParams: &bedrockImageGenerationTextToImageParams{
|
||||||
|
Text: origRequest.Prompt,
|
||||||
|
},
|
||||||
|
ImageGenerationConfig: &bedrockImageGenerationConfig{
|
||||||
|
NumberOfImages: origRequest.N,
|
||||||
|
Width: width,
|
||||||
|
Height: height,
|
||||||
|
Quality: origRequest.Quality,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockInvokeModelPath, origRequest.Model))
|
||||||
|
requestBytes, err := json.Marshal(request)
|
||||||
|
b.setAuthHeaders(requestBytes, headers)
|
||||||
|
return requestBytes, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *bedrockProvider) buildBedrockImageGenerationResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
|
||||||
|
data := make([]imageGenerationData, len(bedrockResponse.Images))
|
||||||
|
for i, image := range bedrockResponse.Images {
|
||||||
|
data[i] = imageGenerationData{
|
||||||
|
B64: image,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &imageGenerationResponse{
|
||||||
|
Created: time.Now().UnixMilli() / 1000,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (b *bedrockProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
func (b *bedrockProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||||
bedrockResponse := &bedrockConverseResponse{}
|
bedrockResponse := &bedrockConverseResponse{}
|
||||||
if err := json.Unmarshal(body, bedrockResponse); err != nil {
|
if err := json.Unmarshal(body, bedrockResponse); err != nil {
|
||||||
@@ -613,7 +751,7 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b
|
|||||||
FinishReason: stopReasonBedrock2OpenAI(bedrockResponse.StopReason),
|
FinishReason: stopReasonBedrock2OpenAI(bedrockResponse.StopReason),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
requestId := ctx.GetStringContext("X-Amzn-Requestid", "")
|
requestId := ctx.GetStringContext(requestIdHeader, "")
|
||||||
return &chatCompletionResponse{
|
return &chatCompletionResponse{
|
||||||
Id: requestId,
|
Id: requestId,
|
||||||
Created: time.Now().UnixMilli() / 1000,
|
Created: time.Now().UnixMilli() / 1000,
|
||||||
|
|||||||
@@ -358,10 +358,39 @@ func (e *StreamEvent) ToHttpString() string {
|
|||||||
type imageGenerationRequest struct {
|
type imageGenerationRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
Background string `json:"background,omitempty"`
|
||||||
|
Moderation string `json:"moderation,omitempty"`
|
||||||
|
OutputCompression int `json:"output_compression,omitempty"`
|
||||||
|
OutputFormat string `json:"output_format,omitempty"`
|
||||||
|
Quality string `json:"quality,omitempty"`
|
||||||
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
|
Style string `json:"style,omitempty"`
|
||||||
N int `json:"n,omitempty"`
|
N int `json:"n,omitempty"`
|
||||||
Size string `json:"size,omitempty"`
|
Size string `json:"size,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type imageGenerationData struct {
|
||||||
|
URL string `json:"url,omitempty"`
|
||||||
|
B64 string `json:"b64_json,omitempty"`
|
||||||
|
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type imageGenerationUsage struct {
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
InputTokensDetails struct {
|
||||||
|
TextTokens int `json:"text_tokens"`
|
||||||
|
ImageTokens int `json:"image_tokens"`
|
||||||
|
} `json:"input_tokens_details"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type imageGenerationResponse struct {
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Data []imageGenerationData `json:"data"`
|
||||||
|
Usage *imageGenerationUsage `json:"usage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// https://platform.openai.com/docs/guides/speech-to-text
|
// https://platform.openai.com/docs/guides/speech-to-text
|
||||||
type audioSpeechRequest struct {
|
type audioSpeechRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
|||||||
@@ -534,6 +534,11 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.setRequestModel(ctx, req)
|
return c.setRequestModel(ctx, req)
|
||||||
|
case *imageGenerationRequest:
|
||||||
|
if err := decodeImageGenerationRequest(body, req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.setRequestModel(ctx, req)
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported request type")
|
return errors.New("unsupported request type")
|
||||||
}
|
}
|
||||||
@@ -547,6 +552,8 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf
|
|||||||
model = &req.Model
|
model = &req.Model
|
||||||
case *embeddingsRequest:
|
case *embeddingsRequest:
|
||||||
model = &req.Model
|
model = &req.Model
|
||||||
|
case *imageGenerationRequest:
|
||||||
|
model = &req.Model
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported request type")
|
return errors.New("unsupported request type")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,6 +25,13 @@ func decodeEmbeddingsRequest(body []byte, request *embeddingsRequest) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func decodeImageGenerationRequest(body []byte, request *imageGenerationRequest) error {
|
||||||
|
if err := json.Unmarshal(body, request); err != nil {
|
||||||
|
return fmt.Errorf("unable to unmarshal request: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func replaceJsonRequestBody(request interface{}) error {
|
func replaceJsonRequestBody(request interface{}) error {
|
||||||
body, err := json.Marshal(request)
|
body, err := json.Marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user