feat(ai-proxy): gemini model multimodal support (#2698)

This commit is contained in:
xingpiaoliang
2025-08-11 15:54:34 +08:00
committed by GitHub
parent 953b95cf92
commit 0af00bef6b

View File

@@ -1,15 +1,21 @@
package provider
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/google/uuid"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
@@ -59,12 +65,17 @@ func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provi
return &geminiProvider{
config: config,
contextCache: createContextCache(&config),
client: wrapper.NewClusterClient(wrapper.RouteCluster{
Host: geminiDomain,
}),
}, nil
}
// geminiProvider is the provider value returned by
// geminiProviderInitializer.CreateProvider.
type geminiProvider struct {
// config is the provider configuration handed to CreateProvider.
config ProviderConfig
// contextCache is built from config via createContextCache.
contextCache *contextCache
// client performs the outbound HTTP GETs that download images referenced by
// URL (see getImageInlineDataWithCallback); CreateProvider points it at
// geminiDomain.
client wrapper.HttpClient
}
func (g *geminiProvider) GetProviderType() string {
@@ -83,11 +94,47 @@ func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
util.OverwriteRequestAuthorizationHeader(headers, "")
}
// OnRequestBody handles the request body for the Gemini provider.
//
// The shared ProviderConfig.handleRequestBody helper is deliberately not used:
// multimodal requests may reference images by URL, and those have to be fetched
// before the body is forwarded, which requires pausing the request here.
//
// Flow:
//  1. reject API names the config does not support;
//  2. for streaming APIs with a configured first-byte timeout, set the
//     x-envoy-upstream-rq-first-byte-timeout-ms request header;
//  3. pass the body through untouched when the config is in "original" mode;
//  4. run TransformRequestBodyHeaders and write the resulting headers back;
//  5. for chat completions, either wait for the context file or hand the
//     transformed body to processImageURL.
//
// It returns types.ActionPause when the request is suspended until an
// asynchronous callback resumes it, and types.ActionContinue otherwise.
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
	if !g.config.isSupportedAPI(apiName) {
		return types.ActionContinue, errUnsupportedApiName
	}

	if g.config.firstByteTimeout != 0 && g.config.isStreamingAPI(apiName, body) {
		err := proxywasm.ReplaceHttpRequestHeader("x-envoy-upstream-rq-first-byte-timeout-ms",
			strconv.FormatUint(uint64(g.config.firstByteTimeout), 10))
		if err != nil {
			// Best effort: a missing timeout header must not fail the request.
			log.Errorf("failed to set timeout header: %v", err)
		}
	}

	if g.config.IsOriginal() {
		return types.ActionContinue, nil
	}

	headers := util.GetRequestHeaders()
	request, err := g.TransformRequestBodyHeaders(ctx, apiName, body, headers)
	if err != nil {
		return types.ActionContinue, err
	}
	util.ReplaceRequestHeaders(headers)

	if apiName != ApiNameChatCompletion {
		return types.ActionContinue, replaceRequestBody(request)
	}

	if g.config.context != nil {
		// NOTE(review): the untransformed body is passed here while the
		// transformed request is used everywhere else - confirm this is intended.
		// When the context load is started successfully the request is paused and
		// processImageURL below is not reached.
		if err := g.contextCache.GetContextFromFile(ctx, g, body); err == nil {
			return types.ActionPause, nil
		}
	}

	action, err := g.processImageURL(ctx, request)
	if err != nil {
		return action, err
	}
	return action, replaceRequestBody(request)
}
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
@@ -407,12 +454,21 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
// shouldAddDummyModelMessage := false
for _, message := range request.Messages {
content := geminiChatContent{
Role: message.Role,
Parts: []geminiPart{
{
Text: message.StringContent(),
},
},
Role: message.Role,
Parts: []geminiPart{},
}
for _, c := range message.ParseContent() {
switch c.Type {
case contentTypeText:
content.Parts = append(content.Parts, geminiPart{
Text: c.Text,
})
case contentTypeImageUrl:
content.Parts = append(content.Parts, g.handleContentTypeImageUrl(c.ImageUrl))
default:
log.Debugf("currently gemini did not support this type: %s", c.Type)
}
}
// there's no assistant role in gemini and API shall vomit if role is not user or model
@@ -431,6 +487,176 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
return &geminiRequest
}
// countImageUrl reports how many parts across all contents of the request carry
// inline data whose Data field is an http(s) URL according to isUrl.
func (g *geminiProvider) countImageUrl(request *geminiGenerationContentRequest) int {
	count := 0
	for ci := range request.Contents {
		parts := request.Contents[ci].Parts
		for pi := range parts {
			inline := parts[pi].InlineData
			if inline == nil {
				continue
			}
			if g.isUrl(inline.Data) {
				count++
			}
		}
	}
	return count
}
// processImageURL inlines remote images referenced by a Gemini request body.
//
// body is decoded as a geminiGenerationContentRequest. When no part references
// an image by URL, the body is written back as-is and the request continues.
// Otherwise the image downloads are started and types.ActionPause is returned;
// once every download has finished, the completion callback either replaces the
// request body with the re-encoded request or reports the failure through
// util.ErrorHandler, and in both cases resumes the request.
//
// A non-nil error is returned only for failures detected before pausing
// (invalid JSON, or processImageURLWithCallback refusing to start).
func (g *geminiProvider) processImageURL(ctx wrapper.HttpContext, body []byte) (types.Action, error) {
	request := &geminiGenerationContentRequest{}
	if err := json.Unmarshal(body, request); err != nil {
		// Include the cause, matching the equivalent log in processImageURLWithCallback.
		log.Errorf("failed to unmarshal geminiGenerationRequest while handle multi modal: %v", err)
		return types.ActionContinue, err
	}

	totalImages := g.countImageUrl(request)
	if totalImages == 0 {
		// there are no images return directly
		return types.ActionContinue, replaceRequestBody(body)
	}

	onFetched := func(body []byte, err error) {
		// The request was paused by this method, so it must be resumed on every
		// path out of the callback.
		defer func() {
			_ = proxywasm.ResumeHttpRequest()
		}()
		if err != nil {
			log.Errorf("failed to get image while handle multi modal: %v", err)
			util.ErrorHandler("ai-proxy.gemini.fetch_image_failed", err)
			return
		}
		// replace the request
		if err := replaceRequestBody(body); err != nil {
			util.ErrorHandler("ai-proxy.gemini.replace_request_body_failed", err)
		}
	}

	if err := g.processImageURLWithCallback(ctx, body, totalImages, onFetched); err != nil {
		return types.ActionContinue, err
	}
	return types.ActionPause, nil
}
// processImageURLWithCallback starts one image download per URL-valued inline
// data part of the request encoded in body.
//
// totalImages is the number of downloads the caller expects; callback is
// invoked once that many downloads have reported back. It receives the request
// re-encoded as JSON (with successfully fetched parts replaced by their inline
// data) together with the joined download/marshal errors, which is nil when
// nothing failed.
//
// The returned error only reflects a failure to decode body.
func (g *geminiProvider) processImageURLWithCallback(ctx wrapper.HttpContext, body []byte, totalImages int, callback func([]byte, error)) error {
	request := new(geminiGenerationContentRequest)
	if err := json.Unmarshal(body, request); err != nil {
		log.Errorf("failed to unmarshal geminiGenerationRequest while handle multi modal: %v", err)
		return err
	}

	remaining := totalImages
	var failures []error

	// finish re-encodes the (possibly updated) request and reports the outcome.
	finish := func() {
		encoded, err := json.Marshal(request)
		if err != nil {
			log.Errorf("failed to marshal request while processImageURL: %v", err)
			failures = append(failures, err)
		}
		callback(encoded, errors.Join(failures...))
	}

	for ci := range request.Contents {
		parts := request.Contents[ci].Parts
		for pi := range parts {
			// part points into the request so the closure below updates it in place.
			part := &parts[pi]
			if part.InlineData == nil || !g.isUrl(part.InlineData.Data) {
				continue
			}
			g.getImageInlineDataWithCallback(part.InlineData.Data, func(fetched *geminiInlineData, err error) {
				if err == nil {
					*part.InlineData = *fetched
				} else {
					log.Errorf("image %s fetch failed: %v", part.InlineData.Data, err)
					failures = append(failures, err)
				}
				remaining--
				if remaining == 0 {
					finish()
				}
			})
		}
	}
	return nil
}
// handleContentTypeImageUrl converts an image_url content item into a Gemini
// part.
//
// When c.Url is an http(s) URL the URL itself is stored in InlineData.Data as a
// placeholder; otherwise c.Url is parsed by baseStr2InlineData. A nil c yields
// an empty part instead of a nil-pointer panic.
func (g *geminiProvider) handleContentTypeImageUrl(c *chatMessageContentImageUrl) (part geminiPart) {
	if c == nil {
		log.Error("image_url content has no image_url payload")
		return geminiPart{}
	}
	if g.isUrl(c.Url) {
		return geminiPart{InlineData: &geminiInlineData{Data: c.Url}}
	}
	return geminiPart{InlineData: g.baseStr2InlineData(c.Url)}
}
// isUrl reports whether raw parses as a URL whose scheme is http or https.
func (g *geminiProvider) isUrl(raw string) bool {
	parsed, err := url.Parse(raw)
	if err != nil {
		return false
	}
	switch parsed.Scheme {
	case "http", "https":
		return true
	default:
		return false
	}
}
// baseStr2InlineData converts a base64 data URI of the form
// "data:<mime>[;param=value...];base64,<payload>" into Gemini inline data.
//
// The media type is the text between "data:" and the first ';'; extra
// parameters such as ";charset=utf-8" are ignored rather than leaking into the
// payload. The payload is everything after the first ','.
//
// Return values:
//   - well-formed base64 data URI: the parsed mime type and payload;
//   - "data:" URI that is not marked ";base64" or has no payload: nil;
//   - anything not starting with "data:": a non-nil value with empty fields.
//
// Only a short prefix of the input is logged, since inline images can be
// several megabytes long.
func (g *geminiProvider) baseStr2InlineData(baseStr string) *geminiInlineData {
	const maxLogged = 64
	preview := baseStr
	if len(preview) > maxLogged {
		preview = preview[:maxLogged] + "..."
	}

	if !strings.HasPrefix(baseStr, "data:") {
		log.Errorf("invalid base64 string: %s", preview)
		return &geminiInlineData{
			MimeType: "",
			Data:     "",
		}
	}

	header, payload, found := strings.Cut(strings.TrimPrefix(baseStr, "data:"), ",")
	if !found || !strings.HasSuffix(header, ";base64") {
		log.Errorf("invalid base64 string: %s", preview)
		return nil
	}

	mime, _, _ := strings.Cut(header, ";")
	return &geminiInlineData{
		MimeType: mime,
		Data:     payload,
	}
}
// getImageInlineDataWithCallback downloads the image at raw through g.client
// and hands the result to callback as base64 inline data.
//
// callback is invoked exactly once on every path:
//   - with the encoded image and a nil error after a 200 response no larger
//     than 100MB; the mime type is sniffed from the bytes with
//     http.DetectContentType;
//   - with a nil value and an error for a non-200 status, an oversized body, a
//     missing client, or a request that could not be dispatched. In the last
//     two cases callback runs before this method returns.
//
// Callers count completions to decide when to resume a paused request, so a
// path that skipped callback would leave that request paused.
func (g *geminiProvider) getImageInlineDataWithCallback(raw string, callback func(*geminiInlineData, error)) {
	if g.client == nil {
		log.Error("client is nil")
		callback(nil, fmt.Errorf("failed to get image %s: http client is nil", raw))
		return
	}

	responseCallback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		if statusCode != http.StatusOK {
			callback(nil, fmt.Errorf("get %s failed, status: %v", raw, statusCode))
			return
		}
		resReader := bytes.NewReader(responseBody)
		const maxSize = 100 << 20
		// Read one byte past the cap so an oversized body can be told apart from
		// one that is exactly maxSize long.
		data, err := io.ReadAll(io.LimitReader(resReader, maxSize+1))
		if err != nil {
			callback(nil, fmt.Errorf("read %v response data failed: %v", raw, err))
			return
		}
		if len(data) > maxSize {
			callback(nil, fmt.Errorf("%v exceed max image size 100MB", raw))
			return
		}
		mimeType := http.DetectContentType(data)
		base64Data := base64.StdEncoding.EncodeToString(data)
		callback(&geminiInlineData{
			MimeType: mimeType,
			Data:     base64Data,
		}, nil)
	}

	timeout := (time.Second * 30).Milliseconds()
	headers := [][2]string{
		{"Accept", "image/*"},
		{"User-Agent", "Mozilla/5.0 (compatible; AI-Proxy/1.0)"},
		{"Referer", "https://www.google.com/"},
	}
	if err := g.client.Get(raw, headers, responseCallback, uint32(timeout)); err != nil {
		log.Errorf("failed to get image %s data: %v", raw, err)
		callback(nil, fmt.Errorf("failed to get image %s: %w", raw, err))
	}
}
func (g *geminiProvider) setSystemContent(request *geminiGenerationContentRequest, content string) {
systemContents := []geminiChatContent{{
Role: roleUser,