feat(ai-proxy): add models & image generation support for gemini (#2380)

Signed-off-by: Xijun Dai <daixijun1990@gmail.com>
Co-authored-by: Kent Dong <ch3cho@qq.com>
Author: Xijun Dai
Date: 2025-06-08 15:25:22 +08:00
Committed by: GitHub
Parent: 26cd6837d5
Commit: e674c780c6
4 changed files with 196 additions and 82 deletions


@@ -230,9 +230,10 @@ The `type` for Cloudflare Workers AI is `cloudflare`. Its provider-specific configuration
The `type` for Gemini is `gemini`. Its provider-specific configuration fields are as follows:
| Name | Data Type | Required | Default | Description |
| ---- | --------- | -------- | ------- | ----------- |
| `geminiSafetySetting` | map of string | Optional | - | Content filtering and safety level settings for Gemini AI. See [Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings). |
| Name | Data Type | Required | Default | Description |
| ---- | --------- | -------- | ------- | ----------- |
| `geminiSafetySetting` | map of string | Optional | - | Content filtering and safety level settings for Gemini AI. See [Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings). |
| `apiVersion` | string | Optional | `v1beta` | Specifies the API version; either `v1` or `v1beta`. For the differences between versions, see [API versions explained](https://ai.google.dev/gemini-api/docs/api-versions). |
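For illustration, a provider configuration using these two fields might look like the sketch below; the surrounding keys (`apiTokens`, the safety categories and thresholds) are placeholder values, not part of this change:

```yaml
provider:
  type: gemini
  apiTokens:
    - "YOUR_GEMINI_API_KEY"
  apiVersion: v1beta
  geminiSafetySetting:
    HARM_CATEGORY_HATE_SPEECH: BLOCK_ONLY_HIGH
    HARM_CATEGORY_SEXUALLY_EXPLICIT: BLOCK_NONE
```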
#### DeepL


@@ -19,8 +19,8 @@ const (
claudeDomain = "api.anthropic.com"
claudeChatCompletionPath = "/v1/messages"
claudeCompletionPath = "/v1/complete"
defaultVersion = "2023-06-01"
defaultMaxTokens = 4096
claudeDefaultVersion = "2023-06-01"
claudeDefaultMaxTokens = 4096
)
type claudeProviderInitializer struct{}
@@ -124,11 +124,11 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx))
if c.config.claudeVersion == "" {
c.config.claudeVersion = defaultVersion
if c.config.apiVersion == "" {
c.config.apiVersion = claudeDefaultVersion
}
headers.Set("anthropic-version", c.config.claudeVersion)
headers.Set("anthropic-version", c.config.apiVersion)
}
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
@@ -212,7 +212,7 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
TopP: origRequest.TopP,
}
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = defaultMaxTokens
claudeRequest.MaxTokens = claudeDefaultMaxTokens
}
for _, message := range origRequest.Messages {


@@ -19,10 +19,13 @@ import (
const (
geminiApiKeyHeader = "x-goog-api-key"
geminiDefaultApiVersion = "v1beta" // 可选: v1, v1beta
geminiDomain = "generativelanguage.googleapis.com"
geminiChatCompletionPath = "generateContent"
geminiChatCompletionStreamPath = "streamGenerateContent?alt=sse"
geminiEmbeddingPath = "batchEmbedContents"
geminiModelsPath = "models"
geminiImageGenerationPath = "predict"
)
type geminiProviderInitializer struct{}
@@ -36,8 +39,10 @@ func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error
func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): "",
string(ApiNameEmbeddings): "",
string(ApiNameChatCompletion): "",
string(ApiNameEmbeddings): "",
string(ApiNameModels): "",
string(ApiNameImageGeneration): "",
}
}
@@ -78,11 +83,38 @@ func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
}
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
if apiName == ApiNameChatCompletion {
switch apiName {
case ApiNameChatCompletion:
return g.onChatCompletionRequestBody(ctx, body, headers)
} else {
case ApiNameEmbeddings:
return g.onEmbeddingsRequestBody(ctx, body, headers)
case ApiNameImageGeneration:
return g.onImageGenerationRequestBody(ctx, body, headers)
}
return body, nil
}
func (g *geminiProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
request := &imageGenerationRequest{}
if err := g.config.parseRequestAndMapModel(ctx, request, body); err != nil {
return nil, err
}
path := g.getRequestPath(ApiNameImageGeneration, request.Model, false)
log.Debugf("request path:%s", path)
util.OverwriteRequestPathHeader(headers, path)
geminiRequest := g.buildGeminiImageGenerationRequest(request)
return json.Marshal(geminiRequest)
}
func (g *geminiProvider) buildGeminiImageGenerationRequest(request *imageGenerationRequest) *geminiImageGenerationRequest {
geminiRequest := &geminiImageGenerationRequest{
Instances: []geminiImageGenerationInstance{{Prompt: request.Prompt}},
Parameters: &geminiImageGenerationParameters{
SampleCount: request.N,
},
}
return geminiRequest
}
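// Illustrative sketch, not part of this change: an OpenAI-style image request such as
//   {"model": "imagen-3.0-generate-002", "prompt": "a red bicycle", "n": 2}
// is mapped by buildGeminiImageGenerationRequest to the Imagen predict body
//   {"instances": [{"prompt": "a red bicycle"}], "parameters": {"sampleCount": 2}}
// The model name above is only an example; the actual model is taken from the request
// after model mapping.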
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
@@ -111,7 +143,7 @@ func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
}
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
log.Infof("chunk body:%s", string(chunk))
log.Debugf("chunk body:%s", string(chunk))
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
@@ -147,14 +179,43 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
}
func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if apiName == ApiNameChatCompletion {
switch apiName {
case ApiNameChatCompletion:
return g.onChatCompletionResponseBody(ctx, body)
} else {
case ApiNameEmbeddings:
return g.onEmbeddingsResponseBody(ctx, body)
case ApiNameImageGeneration:
return g.onImageGenerationResponseBody(ctx, body)
default:
return body, nil
}
}
func (g *geminiProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
geminiResponse := &geminiImageGenerationResponse{}
if err := json.Unmarshal(body, geminiResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal gemini image generation response: %v", err)
}
response := g.buildImageGenerationResponse(ctx, geminiResponse)
return json.Marshal(response)
}
func (g *geminiProvider) buildImageGenerationResponse(ctx wrapper.HttpContext, geminiResponse *geminiImageGenerationResponse) *imageGenerationResponse {
data := make([]imageGenerationData, len(geminiResponse.Predictions))
for i, prediction := range geminiResponse.Predictions {
data[i] = imageGenerationData{
B64: prediction.BytesBase64Encoded,
}
}
response := &imageGenerationResponse{
Created: time.Now().UnixMilli() / 1000,
Data: data,
}
return response
}
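// Illustrative sketch, not part of this change: a predict response such as
//   {"predictions": [{"bytesBase64Encoded": "<base64 image>", "mimeType": "image/png"}]}
// is converted into an OpenAI-style image response whose data entries carry the base64
// payload, with Created set to the current Unix timestamp in seconds.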
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
log.Debugf("chat completion response body:%s", string(body))
geminiResponse := &geminiChatResponse{}
if err := json.Unmarshal(body, geminiResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
@@ -180,26 +241,37 @@ func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body
return json.Marshal(response)
}
func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string {
func (g *geminiProvider) getRequestPath(apiName ApiName, model string, stream bool) string {
action := ""
if apiName == ApiNameEmbeddings {
action = geminiEmbeddingPath
} else if stream {
action = geminiChatCompletionStreamPath
} else {
action = geminiChatCompletionPath
if g.config.apiVersion == "" {
g.config.apiVersion = geminiDefaultApiVersion
}
return fmt.Sprintf("/v1/models/%s:%s", geminiModel, action)
switch apiName {
case ApiNameModels:
return fmt.Sprintf("/%s/%s", g.config.apiVersion, geminiModelsPath)
case ApiNameEmbeddings:
action = geminiEmbeddingPath
case ApiNameChatCompletion:
if stream {
action = geminiChatCompletionStreamPath
} else {
action = geminiChatCompletionPath
}
case ApiNameImageGeneration:
action = geminiImageGenerationPath
}
return fmt.Sprintf("/%s/models/%s:%s", g.config.apiVersion, model, action)
}
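// For illustration, with apiVersion left at the default "v1beta" the resulting paths are:
//   models           -> /v1beta/models
//   chat completion  -> /v1beta/models/<model>:generateContent
//   chat (streaming) -> /v1beta/models/<model>:streamGenerateContent?alt=sse
//   embeddings       -> /v1beta/models/<model>:batchEmbedContents
//   image generation -> /v1beta/models/<model>:predict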
type geminiChatRequest struct {
type geminiGenerationContentRequest struct {
// Model and Stream are only used when using the gemini original protocol
Model string `json:"model,omitempty"`
Stream bool `json:"stream,omitempty"`
Contents []geminiChatContent `json:"contents"`
SafetySettings []geminiChatSafetySetting `json:"safety_settings,omitempty"`
GenerationConfig geminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []geminiChatTools `json:"tools,omitempty"`
Model string `json:"model,omitempty"`
Stream bool `json:"stream,omitempty"`
Contents []geminiChatContent `json:"contents"`
SystemInstruction *geminiChatContent `json:"system_instruction,omitempty"`
SafetySettings []geminiChatSafetySetting `json:"safetySettings,omitempty"`
GenerationConfig geminiChatGenerationConfig `json:"generationConfig,omitempty"`
Tools []geminiChatTools `json:"tools,omitempty"`
}
type geminiChatContent struct {
@@ -212,13 +284,26 @@ type geminiChatSafetySetting struct {
Threshold string `json:"threshold"`
}
type geminiThinkingConfig struct {
IncludeThoughts bool `json:"includeThoughts,omitempty"`
ThinkingBudget int64 `json:"thinkingBudget,omitempty"`
}
type geminiChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int64 `json:"topK,omitempty"`
Seed int64 `json:"seed,omitempty"`
Logprobs bool `json:"logprobs,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
PresencePenalty int64 `json:"presencePenalty,omitempty"`
FrequencyPenalty int64 `json:"frequencyPenalty,omitempty"`
ResponseModalities []string `json:"responseModalities,omitempty"`
NegativePrompt string `json:"negativePrompt,omitempty"`
ThinkingConfig *geminiThinkingConfig `json:"thinkingConfig,omitempty"`
MediaResolution string `json:"mediaResolution,omitempty"`
}
type geminiChatTools struct {
@@ -241,25 +326,52 @@ type geminiFunctionCall struct {
Arguments any `json:"args"`
}
func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest) *geminiChatRequest {
// geminiImageGenerationRequest is the request body for generate image using Imagen 3
type geminiImageGenerationRequest struct {
Instances []geminiImageGenerationInstance `json:"instances"`
Parameters *geminiImageGenerationParameters `json:"parameters,omitempty"`
}
type geminiImageGenerationInstance struct {
Prompt string `json:"prompt"`
}
type geminiImageGenerationParameters struct {
SampleCount int `json:"sampleCount,omitempty"`
AspectRatio string `json:"aspectRatio,omitempty"`
}
type geminiImageGenerationPrediction struct {
BytesBase64Encoded string `json:"bytesBase64Encoded"`
MimeType string `json:"mimeType"`
}
type geminiImageGenerationResponse struct {
Predictions []geminiImageGenerationPrediction `json:"predictions"`
}
func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest) *geminiGenerationContentRequest {
var safetySettings []geminiChatSafetySetting
for category, threshold := range g.config.geminiSafetySetting {
safetySettings = append(safetySettings, geminiChatSafetySetting{
Category: category,
Threshold: threshold,
})
}
geminiRequest := geminiChatRequest{
geminiRequest := geminiGenerationContentRequest{
Contents: make([]geminiChatContent, 0, len(request.Messages)),
SafetySettings: safetySettings,
GenerationConfig: geminiChatGenerationConfig{
Temperature: request.Temperature,
TopP: request.TopP,
MaxOutputTokens: request.MaxTokens,
Temperature: request.Temperature,
TopP: request.TopP,
MaxOutputTokens: request.MaxTokens,
PresencePenalty: int64(request.PresencePenalty),
FrequencyPenalty: int64(request.FrequencyPenalty),
Logprobs: request.Logprobs,
ResponseModalities: request.Modalities,
},
}
if request.Tools != nil {
functions := make([]function, 0, len(request.Tools))
for _, tool := range request.Tools {
@@ -271,7 +383,7 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
},
}
}
shouldAddDummyModelMessage := false
// shouldAddDummyModelMessage := false
for _, message := range request.Messages {
content := geminiChatContent{
Role: message.Role,
@@ -283,32 +395,22 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
}
// there's no assistant role in gemini and API shall vomit if role is not user or model
if content.Role == roleAssistant {
switch content.Role {
case roleSystem:
content.Role = ""
geminiRequest.SystemInstruction = &content
continue
case roleAssistant:
content.Role = "model"
} else if content.Role == roleSystem { // converting system prompt to prompt from user for the same reason
content.Role = roleUser
shouldAddDummyModelMessage = true
}
geminiRequest.Contents = append(geminiRequest.Contents, content)
// if a system message is the last message, we need to add a dummy model message to make gemini happy
if shouldAddDummyModelMessage {
geminiRequest.Contents = append(geminiRequest.Contents, geminiChatContent{
Role: "model",
Parts: []geminiPart{
{
Text: "Okay",
},
},
})
shouldAddDummyModelMessage = false
}
}
return &geminiRequest
}
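// Illustrative sketch, not part of this change: with the role handling above, an
// OpenAI-style message list like
//   [{"role": "system", "content": "Be terse."}, {"role": "user", "content": "Hi"}]
// is forwarded to Gemini roughly as
//   {"system_instruction": {"parts": [{"text": "Be terse."}]},
//    "contents": [{"role": "user", "parts": [{"text": "Hi"}]}]}
// replacing the earlier workaround of demoting system prompts to user messages and
// appending a dummy "Okay" model message.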
func (g *geminiProvider) setSystemContent(request *geminiChatRequest, content string) {
func (g *geminiProvider) setSystemContent(request *geminiGenerationContentRequest, content string) {
systemContents := []geminiChatContent{{
Role: roleUser,
Parts: []geminiPart{
@@ -398,32 +500,34 @@ func (g *geminiProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, re
Object: objectChatCompletion,
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
Choices: make([]chatCompletionChoice, 0, len(response.Candidates)),
Usage: usage{
PromptTokens: response.UsageMetadata.PromptTokenCount,
CompletionTokens: response.UsageMetadata.CandidatesTokenCount,
TotalTokens: response.UsageMetadata.TotalTokenCount,
},
}
for i, candidate := range response.Candidates {
choice := chatCompletionChoice{
Index: i,
Message: &chatMessage{
Role: roleAssistant,
},
FinishReason: finishReasonStop,
}
if len(candidate.Content.Parts) > 0 {
if candidate.Content.Parts[0].FunctionCall != nil {
choice.Message.ToolCalls = g.buildToolCalls(&candidate)
} else {
choice.Message.Content = candidate.Content.Parts[0].Text
choiceIndex := 0
for _, candidate := range response.Candidates {
for _, part := range candidate.Content.Parts {
choice := chatCompletionChoice{
Index: choiceIndex,
Message: &chatMessage{
Role: roleAssistant,
},
FinishReason: finishReasonStop,
}
} else {
choice.Message.Content = ""
if part.FunctionCall != nil {
choice.Message.ToolCalls = g.buildToolCalls(&candidate)
} else if part.InlineData != nil {
choice.Message.Content = part.InlineData.Data
} else {
choice.Message.Content = part.Text
}
choice.FinishReason = candidate.FinishReason
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
choiceIndex += 1
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
@@ -511,5 +615,8 @@ func (g *geminiProvider) GetApiName(path string) ApiName {
if strings.Contains(path, geminiEmbeddingPath) {
return ApiNameEmbeddings
}
if strings.Contains(path, geminiImageGenerationPath) {
return ApiNameImageGeneration
}
return ""
}


@@ -290,8 +290,8 @@ type ProviderConfig struct {
// @Description zh-CN Configures an external file source from which conversation context is fetched to supplement the context of AI requests
context *ContextConfig `required:"false" yaml:"context" json:"context"`
// @Title zh-CN Version
// @Description zh-CN The API version used when calling the AI service; currently only applicable to the Claude AI service
claudeVersion string `required:"false" yaml:"version" json:"version"`
// @Description zh-CN The API version used when calling the AI service; currently applicable to the Gemini and Claude AI services
apiVersion string `required:"false" yaml:"apiVersion" json:"apiVersion"`
// @Title zh-CN Cloudflare Account ID
// @Description zh-CN Only applicable to the Cloudflare Workers AI service. See: https://developers.cloudflare.com/workers-ai/get-started/rest-api/#2-run-a-model-via-api
cloudflareAccountId string `required:"false" yaml:"cloudflareAccountId" json:"cloudflareAccountId"`
@@ -375,7 +375,13 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.context = &ContextConfig{}
c.context.FromJson(contextJson)
}
c.claudeVersion = json.Get("claudeVersion").String()
// Read the legacy claudeVersion field here; note it does not match the yaml/json tags defined on the struct
c.apiVersion = json.Get("claudeVersion").String()
if c.apiVersion == "" {
// Also read the apiVersion field to support other providers' configuration, consistent with the tag defined on the struct
c.apiVersion = json.Get("apiVersion").String()
}
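// Illustrative note, not part of this change: an existing configuration that still sets the
// legacy claudeVersion key keeps working, because it is read first and apiVersion is only
// consulted when claudeVersion is empty.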
c.hunyuanAuthId = json.Get("hunyuanAuthId").String()
c.hunyuanAuthKey = json.Get("hunyuanAuthKey").String()
c.awsAccessKey = json.Get("awsAccessKey").String()