mirror of
https://github.com/alibaba/higress.git
synced 2026-02-06 23:21:08 +08:00
feat(ai-proxy): gemini model multimodal support (#2698)
This commit is contained in:
@@ -1,15 +1,21 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/google/uuid"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
@@ -59,12 +65,17 @@ func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provi
|
||||
return &geminiProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
client: wrapper.NewClusterClient(wrapper.RouteCluster{
|
||||
Host: geminiDomain,
|
||||
}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type geminiProvider struct {
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
|
||||
client wrapper.HttpClient
|
||||
}
|
||||
|
||||
func (g *geminiProvider) GetProviderType() string {
|
||||
@@ -83,11 +94,47 @@ func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "")
|
||||
}
|
||||
|
||||
// to support the multimodal for gemini, we can't reuse the config's handleRequestBody
|
||||
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !g.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body)
|
||||
|
||||
if g.config.firstByteTimeout != 0 && g.config.isStreamingAPI(apiName, body) {
|
||||
err := proxywasm.ReplaceHttpRequestHeader("x-envoy-upstream-rq-first-byte-timeout-ms",
|
||||
strconv.FormatUint(uint64(g.config.firstByteTimeout), 10))
|
||||
if err != nil {
|
||||
log.Errorf("failed to set timeout header: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if g.config.IsOriginal() {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
headers := util.GetRequestHeaders()
|
||||
request, err := g.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
||||
if err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
|
||||
if apiName == ApiNameChatCompletion {
|
||||
if g.config.context != nil {
|
||||
err = g.contextCache.GetContextFromFile(ctx, g, body)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
}
|
||||
|
||||
if action, err := g.processImageURL(ctx, request); err != nil {
|
||||
return action, err
|
||||
} else {
|
||||
return action, replaceRequestBody(request)
|
||||
}
|
||||
|
||||
}
|
||||
return types.ActionContinue, replaceRequestBody(request)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
@@ -407,12 +454,21 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
|
||||
// shouldAddDummyModelMessage := false
|
||||
for _, message := range request.Messages {
|
||||
content := geminiChatContent{
|
||||
Role: message.Role,
|
||||
Parts: []geminiPart{
|
||||
{
|
||||
Text: message.StringContent(),
|
||||
},
|
||||
},
|
||||
Role: message.Role,
|
||||
Parts: []geminiPart{},
|
||||
}
|
||||
|
||||
for _, c := range message.ParseContent() {
|
||||
switch c.Type {
|
||||
case contentTypeText:
|
||||
content.Parts = append(content.Parts, geminiPart{
|
||||
Text: c.Text,
|
||||
})
|
||||
case contentTypeImageUrl:
|
||||
content.Parts = append(content.Parts, g.handleContentTypeImageUrl(c.ImageUrl))
|
||||
default:
|
||||
log.Debugf("currently gemini did not support this type: %s", c.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// there's no assistant role in gemini and API shall vomit if role is not user or model
|
||||
@@ -431,6 +487,176 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
|
||||
return &geminiRequest
|
||||
}
|
||||
|
||||
func (g *geminiProvider) countImageUrl(request *geminiGenerationContentRequest) int {
|
||||
totalImages := 0
|
||||
for _, c := range request.Contents {
|
||||
for _, p := range c.Parts {
|
||||
if p.InlineData != nil && g.isUrl(p.InlineData.Data) {
|
||||
totalImages += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
return totalImages
|
||||
}
|
||||
|
||||
func (g *geminiProvider) processImageURL(ctx wrapper.HttpContext, body []byte) (types.Action, error) {
|
||||
request := &geminiGenerationContentRequest{}
|
||||
err := json.Unmarshal(body, request)
|
||||
if err != nil {
|
||||
log.Errorf("failed to unmarshal geminiGenerationRequest while handle multi modal")
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
var totalImages int
|
||||
if totalImages = g.countImageUrl(request); totalImages == 0 {
|
||||
// there are no images return directly
|
||||
return types.ActionContinue, replaceRequestBody(body)
|
||||
}
|
||||
|
||||
if err := g.processImageURLWithCallback(ctx, body, totalImages, func(body []byte, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to get image while handle multi modal: %v", err)
|
||||
util.ErrorHandler("ai-proxy.gemini.fetch_image_failed", err)
|
||||
return
|
||||
}
|
||||
// replace the request
|
||||
if err := replaceRequestBody(body); err != nil {
|
||||
util.ErrorHandler("ai-proxy.gemini.replace_request_body_failed", err)
|
||||
}
|
||||
}); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
|
||||
func (g *geminiProvider) processImageURLWithCallback(ctx wrapper.HttpContext, body []byte, totalImages int, callback func([]byte, error)) error {
|
||||
request := &geminiGenerationContentRequest{}
|
||||
err := json.Unmarshal(body, request)
|
||||
if err != nil {
|
||||
log.Errorf("failed to unmarshal geminiGenerationRequest while handle multi modal: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
pending := totalImages
|
||||
var callbackErr []error
|
||||
|
||||
for ci, c := range request.Contents {
|
||||
for pi := range c.Parts {
|
||||
p := &request.Contents[ci].Parts[pi]
|
||||
if p.InlineData != nil && g.isUrl(p.InlineData.Data) {
|
||||
g.getImageInlineDataWithCallback(p.InlineData.Data, func(gid *geminiInlineData, err error) {
|
||||
if err != nil {
|
||||
log.Errorf("image %s fetch failed: %v", p.InlineData.Data, err)
|
||||
callbackErr = append(callbackErr, err)
|
||||
} else {
|
||||
*p.InlineData = *gid
|
||||
}
|
||||
|
||||
pending -= 1
|
||||
if pending == 0 {
|
||||
body, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal request while processImageURL: %v", err)
|
||||
callbackErr = append(callbackErr, err)
|
||||
}
|
||||
callback(body, errors.Join(callbackErr...))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *geminiProvider) handleContentTypeImageUrl(c *chatMessageContentImageUrl) (part geminiPart) {
|
||||
if g.isUrl(c.Url) {
|
||||
part.InlineData = &geminiInlineData{
|
||||
Data: c.Url,
|
||||
}
|
||||
return
|
||||
}
|
||||
part.InlineData = g.baseStr2InlineData(c.Url)
|
||||
return
|
||||
}
|
||||
|
||||
func (g *geminiProvider) isUrl(raw string) bool {
|
||||
u, err := url.Parse(raw)
|
||||
return err == nil && (u.Scheme == "http" || u.Scheme == "https")
|
||||
}
|
||||
|
||||
func (g *geminiProvider) baseStr2InlineData(baseStr string) *geminiInlineData {
|
||||
if strings.HasPrefix(baseStr, "data:") {
|
||||
p := strings.SplitN(baseStr, ";", 2)
|
||||
if len(p) != 2 {
|
||||
log.Errorf("invalid base64 string: %s", p)
|
||||
return nil
|
||||
}
|
||||
|
||||
mime := strings.TrimPrefix(p[0], "data:")
|
||||
baseData := strings.TrimPrefix(p[1], "base64,")
|
||||
return &geminiInlineData{
|
||||
MimeType: mime,
|
||||
Data: baseData,
|
||||
}
|
||||
}
|
||||
log.Errorf("invalid base64 string: %s", baseStr)
|
||||
return &geminiInlineData{
|
||||
MimeType: "",
|
||||
Data: "",
|
||||
}
|
||||
}
|
||||
|
||||
func (g *geminiProvider) getImageInlineDataWithCallback(raw string, callback func(*geminiInlineData, error)) {
|
||||
|
||||
responseCallback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode != http.StatusOK {
|
||||
callback(nil, fmt.Errorf("get %s failed, status: %v", raw, statusCode))
|
||||
return
|
||||
}
|
||||
resReader := bytes.NewReader(responseBody)
|
||||
const maxSize = 100 << 20
|
||||
data, err := io.ReadAll(io.LimitReader(resReader, maxSize+1))
|
||||
if err != nil {
|
||||
callback(nil, fmt.Errorf("read %v response data failed: %v", raw, err))
|
||||
return
|
||||
}
|
||||
if len(data) > maxSize {
|
||||
callback(nil, fmt.Errorf("%v exceed max image size 100MB", raw))
|
||||
return
|
||||
}
|
||||
|
||||
mimeType := http.DetectContentType(data)
|
||||
base64Data := base64.StdEncoding.EncodeToString(data)
|
||||
|
||||
callback(&geminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: base64Data,
|
||||
}, nil)
|
||||
}
|
||||
|
||||
timeout := (time.Second * 30).Milliseconds()
|
||||
|
||||
headers := [][2]string{
|
||||
{"Accept", "image/*"},
|
||||
{"User-Agent", "Mozilla/5.0 (compatible; AI-Proxy/1.0)"},
|
||||
{"Referer", "https://www.google.com/"},
|
||||
}
|
||||
if g.client == nil {
|
||||
log.Error("client is nil")
|
||||
return
|
||||
}
|
||||
err := g.client.Get(raw, headers, responseCallback, uint32(timeout))
|
||||
if err != nil {
|
||||
log.Errorf("failed to get image %s data", raw)
|
||||
callback(nil, fmt.Errorf("failed to get image %s", raw))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (g *geminiProvider) setSystemContent(request *geminiGenerationContentRequest, content string) {
|
||||
systemContents := []geminiChatContent{{
|
||||
Role: roleUser,
|
||||
|
||||
Reference in New Issue
Block a user