mirror of
https://github.com/alibaba/higress.git
synced 2026-05-28 06:37:26 +08:00
feat(ai-proxy): gemini model multimodal support (#2698)
This commit is contained in:
@@ -1,15 +1,21 @@
|
|||||||
package provider
|
package provider
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
"github.com/higress-group/wasm-go/pkg/log"
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
@@ -59,12 +65,17 @@ func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provi
|
|||||||
return &geminiProvider{
|
return &geminiProvider{
|
||||||
config: config,
|
config: config,
|
||||||
contextCache: createContextCache(&config),
|
contextCache: createContextCache(&config),
|
||||||
|
client: wrapper.NewClusterClient(wrapper.RouteCluster{
|
||||||
|
Host: geminiDomain,
|
||||||
|
}),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// geminiProvider is the provider implementation for Gemini. Instances are
// built by geminiProviderInitializer.CreateProvider.
type geminiProvider struct {
	// config is the provider configuration handed to CreateProvider.
	config ProviderConfig
	// contextCache is derived from config via createContextCache.
	contextCache *contextCache

	// client issues the outbound GET requests used to download images that a
	// chat message references by URL. CreateProvider builds it as a cluster
	// client whose route cluster host is geminiDomain.
	client wrapper.HttpClient
}
|
||||||
|
|
||||||
func (g *geminiProvider) GetProviderType() string {
|
func (g *geminiProvider) GetProviderType() string {
|
||||||
@@ -83,11 +94,47 @@ func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
|
|||||||
util.OverwriteRequestAuthorizationHeader(headers, "")
|
util.OverwriteRequestAuthorizationHeader(headers, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// to support the multimodal for gemini, we can't reuse the config's handleRequestBody
|
||||||
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||||
if !g.config.isSupportedAPI(apiName) {
|
if !g.config.isSupportedAPI(apiName) {
|
||||||
return types.ActionContinue, errUnsupportedApiName
|
return types.ActionContinue, errUnsupportedApiName
|
||||||
}
|
}
|
||||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body)
|
|
||||||
|
if g.config.firstByteTimeout != 0 && g.config.isStreamingAPI(apiName, body) {
|
||||||
|
err := proxywasm.ReplaceHttpRequestHeader("x-envoy-upstream-rq-first-byte-timeout-ms",
|
||||||
|
strconv.FormatUint(uint64(g.config.firstByteTimeout), 10))
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to set timeout header: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if g.config.IsOriginal() {
|
||||||
|
return types.ActionContinue, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := util.GetRequestHeaders()
|
||||||
|
request, err := g.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
||||||
|
if err != nil {
|
||||||
|
return types.ActionContinue, err
|
||||||
|
}
|
||||||
|
util.ReplaceRequestHeaders(headers)
|
||||||
|
|
||||||
|
if apiName == ApiNameChatCompletion {
|
||||||
|
if g.config.context != nil {
|
||||||
|
err = g.contextCache.GetContextFromFile(ctx, g, body)
|
||||||
|
if err == nil {
|
||||||
|
return types.ActionPause, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if action, err := g.processImageURL(ctx, request); err != nil {
|
||||||
|
return action, err
|
||||||
|
} else {
|
||||||
|
return action, replaceRequestBody(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return types.ActionContinue, replaceRequestBody(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||||
@@ -407,12 +454,21 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
|
|||||||
// shouldAddDummyModelMessage := false
|
// shouldAddDummyModelMessage := false
|
||||||
for _, message := range request.Messages {
|
for _, message := range request.Messages {
|
||||||
content := geminiChatContent{
|
content := geminiChatContent{
|
||||||
Role: message.Role,
|
Role: message.Role,
|
||||||
Parts: []geminiPart{
|
Parts: []geminiPart{},
|
||||||
{
|
}
|
||||||
Text: message.StringContent(),
|
|
||||||
},
|
for _, c := range message.ParseContent() {
|
||||||
},
|
switch c.Type {
|
||||||
|
case contentTypeText:
|
||||||
|
content.Parts = append(content.Parts, geminiPart{
|
||||||
|
Text: c.Text,
|
||||||
|
})
|
||||||
|
case contentTypeImageUrl:
|
||||||
|
content.Parts = append(content.Parts, g.handleContentTypeImageUrl(c.ImageUrl))
|
||||||
|
default:
|
||||||
|
log.Debugf("currently gemini did not support this type: %s", c.Type)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// there's no assistant role in gemini and API shall vomit if role is not user or model
|
// there's no assistant role in gemini and API shall vomit if role is not user or model
|
||||||
@@ -431,6 +487,176 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
|
|||||||
return &geminiRequest
|
return &geminiRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *geminiProvider) countImageUrl(request *geminiGenerationContentRequest) int {
|
||||||
|
totalImages := 0
|
||||||
|
for _, c := range request.Contents {
|
||||||
|
for _, p := range c.Parts {
|
||||||
|
if p.InlineData != nil && g.isUrl(p.InlineData.Data) {
|
||||||
|
totalImages += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return totalImages
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *geminiProvider) processImageURL(ctx wrapper.HttpContext, body []byte) (types.Action, error) {
|
||||||
|
request := &geminiGenerationContentRequest{}
|
||||||
|
err := json.Unmarshal(body, request)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to unmarshal geminiGenerationRequest while handle multi modal")
|
||||||
|
return types.ActionContinue, err
|
||||||
|
}
|
||||||
|
var totalImages int
|
||||||
|
if totalImages = g.countImageUrl(request); totalImages == 0 {
|
||||||
|
// there are no images return directly
|
||||||
|
return types.ActionContinue, replaceRequestBody(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.processImageURLWithCallback(ctx, body, totalImages, func(body []byte, err error) {
|
||||||
|
defer func() {
|
||||||
|
_ = proxywasm.ResumeHttpRequest()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get image while handle multi modal: %v", err)
|
||||||
|
util.ErrorHandler("ai-proxy.gemini.fetch_image_failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// replace the request
|
||||||
|
if err := replaceRequestBody(body); err != nil {
|
||||||
|
util.ErrorHandler("ai-proxy.gemini.replace_request_body_failed", err)
|
||||||
|
}
|
||||||
|
}); err != nil {
|
||||||
|
return types.ActionContinue, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return types.ActionPause, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *geminiProvider) processImageURLWithCallback(ctx wrapper.HttpContext, body []byte, totalImages int, callback func([]byte, error)) error {
|
||||||
|
request := &geminiGenerationContentRequest{}
|
||||||
|
err := json.Unmarshal(body, request)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to unmarshal geminiGenerationRequest while handle multi modal: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pending := totalImages
|
||||||
|
var callbackErr []error
|
||||||
|
|
||||||
|
for ci, c := range request.Contents {
|
||||||
|
for pi := range c.Parts {
|
||||||
|
p := &request.Contents[ci].Parts[pi]
|
||||||
|
if p.InlineData != nil && g.isUrl(p.InlineData.Data) {
|
||||||
|
g.getImageInlineDataWithCallback(p.InlineData.Data, func(gid *geminiInlineData, err error) {
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("image %s fetch failed: %v", p.InlineData.Data, err)
|
||||||
|
callbackErr = append(callbackErr, err)
|
||||||
|
} else {
|
||||||
|
*p.InlineData = *gid
|
||||||
|
}
|
||||||
|
|
||||||
|
pending -= 1
|
||||||
|
if pending == 0 {
|
||||||
|
body, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to marshal request while processImageURL: %v", err)
|
||||||
|
callbackErr = append(callbackErr, err)
|
||||||
|
}
|
||||||
|
callback(body, errors.Join(callbackErr...))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *geminiProvider) handleContentTypeImageUrl(c *chatMessageContentImageUrl) (part geminiPart) {
|
||||||
|
if g.isUrl(c.Url) {
|
||||||
|
part.InlineData = &geminiInlineData{
|
||||||
|
Data: c.Url,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
part.InlineData = g.baseStr2InlineData(c.Url)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *geminiProvider) isUrl(raw string) bool {
|
||||||
|
u, err := url.Parse(raw)
|
||||||
|
return err == nil && (u.Scheme == "http" || u.Scheme == "https")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *geminiProvider) baseStr2InlineData(baseStr string) *geminiInlineData {
|
||||||
|
if strings.HasPrefix(baseStr, "data:") {
|
||||||
|
p := strings.SplitN(baseStr, ";", 2)
|
||||||
|
if len(p) != 2 {
|
||||||
|
log.Errorf("invalid base64 string: %s", p)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mime := strings.TrimPrefix(p[0], "data:")
|
||||||
|
baseData := strings.TrimPrefix(p[1], "base64,")
|
||||||
|
return &geminiInlineData{
|
||||||
|
MimeType: mime,
|
||||||
|
Data: baseData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Errorf("invalid base64 string: %s", baseStr)
|
||||||
|
return &geminiInlineData{
|
||||||
|
MimeType: "",
|
||||||
|
Data: "",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *geminiProvider) getImageInlineDataWithCallback(raw string, callback func(*geminiInlineData, error)) {
|
||||||
|
|
||||||
|
responseCallback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
if statusCode != http.StatusOK {
|
||||||
|
callback(nil, fmt.Errorf("get %s failed, status: %v", raw, statusCode))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resReader := bytes.NewReader(responseBody)
|
||||||
|
const maxSize = 100 << 20
|
||||||
|
data, err := io.ReadAll(io.LimitReader(resReader, maxSize+1))
|
||||||
|
if err != nil {
|
||||||
|
callback(nil, fmt.Errorf("read %v response data failed: %v", raw, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(data) > maxSize {
|
||||||
|
callback(nil, fmt.Errorf("%v exceed max image size 100MB", raw))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mimeType := http.DetectContentType(data)
|
||||||
|
base64Data := base64.StdEncoding.EncodeToString(data)
|
||||||
|
|
||||||
|
callback(&geminiInlineData{
|
||||||
|
MimeType: mimeType,
|
||||||
|
Data: base64Data,
|
||||||
|
}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := (time.Second * 30).Milliseconds()
|
||||||
|
|
||||||
|
headers := [][2]string{
|
||||||
|
{"Accept", "image/*"},
|
||||||
|
{"User-Agent", "Mozilla/5.0 (compatible; AI-Proxy/1.0)"},
|
||||||
|
{"Referer", "https://www.google.com/"},
|
||||||
|
}
|
||||||
|
if g.client == nil {
|
||||||
|
log.Error("client is nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := g.client.Get(raw, headers, responseCallback, uint32(timeout))
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get image %s data", raw)
|
||||||
|
callback(nil, fmt.Errorf("failed to get image %s", raw))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (g *geminiProvider) setSystemContent(request *geminiGenerationContentRequest, content string) {
|
func (g *geminiProvider) setSystemContent(request *geminiGenerationContentRequest, content string) {
|
||||||
systemContents := []geminiChatContent{{
|
systemContents := []geminiChatContent{{
|
||||||
Role: roleUser,
|
Role: roleUser,
|
||||||
|
|||||||
Reference in New Issue
Block a user