diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
index 7e25ca7ee..70a873e40 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
@@ -1,15 +1,21 @@
 package provider
 
 import (
+	"bytes"
+	"encoding/base64"
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"net/http"
+	"net/url"
+	"strconv"
 	"strings"
 	"time"
 
 	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
 	"github.com/google/uuid"
+	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
 	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
 	"github.com/higress-group/wasm-go/pkg/log"
 	"github.com/higress-group/wasm-go/pkg/wrapper"
@@ -59,12 +65,17 @@ func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provi
 	return &geminiProvider{
 		config:       config,
 		contextCache: createContextCache(&config),
+		client: wrapper.NewClusterClient(wrapper.RouteCluster{
+			Host: geminiDomain,
+		}),
 	}, nil
 }
 
 type geminiProvider struct {
 	config       ProviderConfig
 	contextCache *contextCache
+
+	client wrapper.HttpClient
 }
 
 func (g *geminiProvider) GetProviderType() string {
@@ -83,11 +94,49 @@ func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
 	util.OverwriteRequestAuthorizationHeader(headers, "")
 }
 
+// OnRequestBody replaces the request body with the output of
+// TransformRequestBodyHeaders. It does not delegate to the shared
+// config.handleRequestBody because chat completions may carry image URLs that
+// have to be fetched (see processImageURL) before the request can proceed.
 func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
 	if !g.config.isSupportedAPI(apiName) {
 		return types.ActionContinue, errUnsupportedApiName
 	}
-	return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body)
+
+	if g.config.firstByteTimeout != 0 && g.config.isStreamingAPI(apiName, body) {
+		err := proxywasm.ReplaceHttpRequestHeader("x-envoy-upstream-rq-first-byte-timeout-ms",
+			strconv.FormatUint(uint64(g.config.firstByteTimeout), 10))
+		if err != nil {
+			log.Errorf("failed to set timeout header: %v", err)
+		}
+	}
+
+	if g.config.IsOriginal() {
+		return types.ActionContinue, nil
+	}
+
+	headers := util.GetRequestHeaders()
+	request, err := g.TransformRequestBodyHeaders(ctx, apiName, body, headers)
+	if err != nil {
+		return types.ActionContinue, err
+	}
+	util.ReplaceRequestHeaders(headers)
+
+	if apiName == ApiNameChatCompletion {
+		if g.config.context != nil {
+			err = g.contextCache.GetContextFromFile(ctx, g, body)
+			if err == nil {
+				return types.ActionPause, nil
+			}
+		}
+
+		action, imgErr := g.processImageURL(ctx, request)
+		if imgErr != nil {
+			return action, imgErr
+		}
+		return action, replaceRequestBody(request)
+	}
+	return types.ActionContinue, replaceRequestBody(request)
 }
 
 func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
@@ -407,12 +456,21 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
 	// shouldAddDummyModelMessage := false
 	for _, message := range request.Messages {
 		content := geminiChatContent{
-			Role: message.Role,
-			Parts: []geminiPart{
-				{
-					Text: message.StringContent(),
-				},
-			},
+			Role:  message.Role,
+			Parts: []geminiPart{},
+		}
+
+		for _, c := range message.ParseContent() {
+			switch c.Type {
+			case contentTypeText:
+				content.Parts = append(content.Parts, geminiPart{
+					Text: c.Text,
+				})
+			case contentTypeImageUrl:
+				content.Parts = append(content.Parts, g.handleContentTypeImageUrl(c.ImageUrl))
+			default:
+				log.Debugf("currently gemini did not support this type: %s", c.Type)
+			}
 		}
 
 		// there's no assistant role in gemini and API shall vomit if role is not user or model
@@ -431,6 +489,207 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
 	return &geminiRequest
 }
 
+// countImageUrl returns how many parts of request carry inline data whose Data
+// field is still an http(s) URL, as judged by isUrl.
+func (g *geminiProvider) countImageUrl(request *geminiGenerationContentRequest) int {
+	totalImages := 0
+	for _, c := range request.Contents {
+		for _, p := range c.Parts {
+			if p.InlineData != nil && g.isUrl(p.InlineData.Data) {
+				totalImages += 1
+			}
+		}
+	}
+	return totalImages
+}
+
+// processImageURL resolves the image URLs found in body, a JSON-encoded
+// geminiGenerationContentRequest.
+//
+// When body holds no image URL it replaces the request body with body and
+// returns ActionContinue. Otherwise it starts fetching the images and returns
+// ActionPause; the completion callback either replaces the request body or
+// reports the failure through util.ErrorHandler, and then resumes the request.
+func (g *geminiProvider) processImageURL(ctx wrapper.HttpContext, body []byte) (types.Action, error) {
+	request := &geminiGenerationContentRequest{}
+	err := json.Unmarshal(body, request)
+	if err != nil {
+		log.Errorf("failed to unmarshal geminiGenerationRequest while handle multi modal: %v", err)
+		return types.ActionContinue, err
+	}
+	totalImages := g.countImageUrl(request)
+	if totalImages == 0 {
+		// no images to fetch, forward the body as is
+		return types.ActionContinue, replaceRequestBody(body)
+	}
+
+	if err := g.processImageURLWithCallback(ctx, body, totalImages, func(body []byte, err error) {
+		defer func() {
+			_ = proxywasm.ResumeHttpRequest()
+		}()
+
+		if err != nil {
+			log.Errorf("failed to get image while handle multi modal: %v", err)
+			util.ErrorHandler("ai-proxy.gemini.fetch_image_failed", err)
+			return
+		}
+		// replace the request
+		if err := replaceRequestBody(body); err != nil {
+			util.ErrorHandler("ai-proxy.gemini.replace_request_body_failed", err)
+		}
+	}); err != nil {
+		return types.ActionContinue, err
+	}
+
+	return types.ActionPause, nil
+}
+
+// processImageURLWithCallback fetches every URL-valued inline data part of body
+// and calls callback once, after the last of the totalImages fetches has
+// reported back. callback receives the re-marshaled request and the joined
+// fetch/marshal errors, which is nil when everything succeeded.
+//
+// totalImages must match countImageUrl for body; the pending counter relies on
+// it to detect completion.
+func (g *geminiProvider) processImageURLWithCallback(ctx wrapper.HttpContext, body []byte, totalImages int, callback func([]byte, error)) error {
+	request := &geminiGenerationContentRequest{}
+	err := json.Unmarshal(body, request)
+	if err != nil {
+		log.Errorf("failed to unmarshal geminiGenerationRequest while handle multi modal: %v", err)
+		return err
+	}
+
+	pending := totalImages
+	var callbackErr []error
+
+	for ci, c := range request.Contents {
+		for pi := range c.Parts {
+			// point into request so the callback can overwrite the part in place
+			p := &request.Contents[ci].Parts[pi]
+			if p.InlineData != nil && g.isUrl(p.InlineData.Data) {
+				g.getImageInlineDataWithCallback(p.InlineData.Data, func(gid *geminiInlineData, err error) {
+					if err != nil {
+						log.Errorf("image %s fetch failed: %v", p.InlineData.Data, err)
+						callbackErr = append(callbackErr, err)
+					} else {
+						*p.InlineData = *gid
+					}
+
+					pending -= 1
+					if pending == 0 {
+						body, err := json.Marshal(request)
+						if err != nil {
+							log.Errorf("failed to marshal request while processImageURL: %v", err)
+							callbackErr = append(callbackErr, err)
+						}
+						callback(body, errors.Join(callbackErr...))
+					}
+				})
+			}
+		}
+	}
+	return nil
+}
+
+// handleContentTypeImageUrl converts an image_url content item into a
+// geminiPart. An http(s) URL is stored verbatim in InlineData.Data so that
+// processImageURL can fetch it later; any other value is parsed as a data URI.
+// A nil c yields an empty part instead of a nil-pointer panic.
+func (g *geminiProvider) handleContentTypeImageUrl(c *chatMessageContentImageUrl) (part geminiPart) {
+	if c == nil {
+		log.Errorf("image_url content is missing its image_url payload")
+		return
+	}
+	if g.isUrl(c.Url) {
+		part.InlineData = &geminiInlineData{
+			Data: c.Url,
+		}
+		return
+	}
+	part.InlineData = g.baseStr2InlineData(c.Url)
+	return
+}
+
+// isUrl reports whether raw parses as a URL with an http or https scheme.
+func (g *geminiProvider) isUrl(raw string) bool {
+	u, err := url.Parse(raw)
+	return err == nil && (u.Scheme == "http" || u.Scheme == "https")
+}
+
+// baseStr2InlineData splits a data URI of the form "data:<mime>;base64,<data>"
+// into its MIME type and payload. It returns nil when the string starts with
+// "data:" but has no ';' separator, and an empty geminiInlineData when the
+// "data:" prefix is missing.
+func (g *geminiProvider) baseStr2InlineData(baseStr string) *geminiInlineData {
+	if strings.HasPrefix(baseStr, "data:") {
+		p := strings.SplitN(baseStr, ";", 2)
+		if len(p) != 2 {
+			log.Errorf("invalid base64 string: %s", p)
+			return nil
+		}
+
+		mime := strings.TrimPrefix(p[0], "data:")
+		baseData := strings.TrimPrefix(p[1], "base64,")
+		return &geminiInlineData{
+			MimeType: mime,
+			Data:     baseData,
+		}
+	}
+	log.Errorf("invalid base64 string: %s", baseStr)
+	return &geminiInlineData{
+		MimeType: "",
+		Data:     "",
+	}
+}
+
+// getImageInlineDataWithCallback downloads the image at raw through g.client and
+// passes it to callback as base64 data with a sniffed MIME type. callback is
+// also invoked with an error when the client is missing or the request cannot
+// be dispatched, so callers counting outstanding fetches are not left waiting.
+func (g *geminiProvider) getImageInlineDataWithCallback(raw string, callback func(*geminiInlineData, error)) {
+
+	responseCallback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
+		if statusCode != http.StatusOK {
+			callback(nil, fmt.Errorf("get %s failed, status: %v", raw, statusCode))
+			return
+		}
+		resReader := bytes.NewReader(responseBody)
+		const maxSize = 100 << 20
+		data, err := io.ReadAll(io.LimitReader(resReader, maxSize+1))
+		if err != nil {
+			callback(nil, fmt.Errorf("read %v response data failed: %v", raw, err))
+			return
+		}
+		if len(data) > maxSize {
+			callback(nil, fmt.Errorf("%v exceed max image size 100MB", raw))
+			return
+		}
+
+		mimeType := http.DetectContentType(data)
+		base64Data := base64.StdEncoding.EncodeToString(data)
+
+		callback(&geminiInlineData{
+			MimeType: mimeType,
+			Data:     base64Data,
+		}, nil)
+	}
+
+	timeout := (time.Second * 30).Milliseconds()
+
+	headers := [][2]string{
+		{"Accept", "image/*"},
+		{"User-Agent", "Mozilla/5.0 (compatible; AI-Proxy/1.0)"},
+		{"Referer", "https://www.google.com/"},
+	}
+	if g.client == nil {
+		callback(nil, errors.New("client is nil"))
+		return
+	}
+	if err := g.client.Get(raw, headers, responseCallback, uint32(timeout)); err != nil {
+		callback(nil, fmt.Errorf("failed to get image %s: %w", raw, err))
+	}
+}
+
 func (g *geminiProvider) setSystemContent(request *geminiGenerationContentRequest, content string) {
 	systemContents := []geminiChatContent{{
 		Role: roleUser,