diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 643a0a083..ce4b15581 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -230,9 +230,10 @@ Cloudflare Workers AI 所对应的 `type` 为 `cloudflare`。它特有的配置 Gemini 所对应的 `type` 为 `gemini`。它特有的配置字段如下: -| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | -| --------------------- | ------------- | -------- | ------ | -------------------------------------------------------------------------------------------------------------- | -| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) | +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +| --------------------- | ------------- | -------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------ | +| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) | +| `apiVersion` | string | 非必填 | `v1beta` | 用于指定 API 的版本, 可选择 `v1` 或 `v1beta` 。 版本差异请参考[API versions explained](https://ai.google.dev/gemini-api/docs/api-versions)。 | #### DeepL diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index e5140ef0c..cae6d85e7 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -19,8 +19,8 @@ const ( claudeDomain = "api.anthropic.com" claudeChatCompletionPath = "/v1/messages" claudeCompletionPath = "/v1/complete" - defaultVersion = "2023-06-01" - defaultMaxTokens = 4096 + claudeDefaultVersion = "2023-06-01" + claudeDefaultMaxTokens = 4096 ) type claudeProviderInitializer struct{} @@ -124,11 +124,11 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx)) - if c.config.claudeVersion == "" { - c.config.claudeVersion = defaultVersion + if c.config.apiVersion == "" { + c.config.apiVersion = claudeDefaultVersion } - headers.Set("anthropic-version", c.config.claudeVersion) + headers.Set("anthropic-version", c.config.apiVersion) } func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { @@ -212,7 +212,7 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe TopP: origRequest.TopP, } if claudeRequest.MaxTokens == 0 { - claudeRequest.MaxTokens = defaultMaxTokens + claudeRequest.MaxTokens = claudeDefaultMaxTokens } for _, message := range origRequest.Messages { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 4c60c2320..b86f9e03a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -19,10 +19,13 @@ import ( const ( geminiApiKeyHeader = "x-goog-api-key" + geminiDefaultApiVersion = "v1beta" // 可选: v1, v1beta geminiDomain = "generativelanguage.googleapis.com" geminiChatCompletionPath = "generateContent" geminiChatCompletionStreamPath = "streamGenerateContent?alt=sse" geminiEmbeddingPath = "batchEmbedContents" + geminiModelsPath = "models" + geminiImageGenerationPath = "predict" ) type geminiProviderInitializer struct{} @@ -36,8 +39,10 @@ func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string { return map[string]string{ - string(ApiNameChatCompletion): "", - string(ApiNameEmbeddings): "", + string(ApiNameChatCompletion): "", + string(ApiNameEmbeddings): "", + string(ApiNameModels): "", + string(ApiNameImageGeneration): "", } } @@ -78,11 +83,38 @@ func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, } func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) { - if apiName == ApiNameChatCompletion { + switch apiName { + case ApiNameChatCompletion: return g.onChatCompletionRequestBody(ctx, body, headers) - } else { + case ApiNameEmbeddings: return g.onEmbeddingsRequestBody(ctx, body, headers) + case ApiNameImageGeneration: + return g.onImageGenerationRequestBody(ctx, body, headers) } + return body, nil +} + +func (g *geminiProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) { + request := &imageGenerationRequest{} + if err := g.config.parseRequestAndMapModel(ctx, request, body); err != nil { + return nil, err + } + path := g.getRequestPath(ApiNameImageGeneration, request.Model, false) + log.Debugf("request path:%s", path) + util.OverwriteRequestPathHeader(headers, path) + geminiRequest := g.buildGeminiImageGenerationRequest(request) + return json.Marshal(geminiRequest) +} + +func (g *geminiProvider) buildGeminiImageGenerationRequest(request *imageGenerationRequest) *geminiImageGenerationRequest { + geminiRequest := &geminiImageGenerationRequest{ + Instances: []geminiImageGenerationInstance{{Prompt: request.Prompt}}, + Parameters: &geminiImageGenerationParameters{ + SampleCount: request.N, + }, + } + + return geminiRequest } func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) { @@ -111,7 +143,7 @@ func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [ } func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) { - log.Infof("chunk body:%s", string(chunk)) + log.Debugf("chunk body:%s", string(chunk)) if isLastChunk || len(chunk) == 0 { return nil, nil } @@ -147,14 +179,43 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A } func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) { - if apiName == ApiNameChatCompletion { + switch apiName { + case ApiNameChatCompletion: return g.onChatCompletionResponseBody(ctx, body) - } else { + case ApiNameEmbeddings: return g.onEmbeddingsResponseBody(ctx, body) + case ApiNameImageGeneration: + return g.onImageGenerationResponseBody(ctx, body) + default: + return body, nil } } +func (g *geminiProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) { + geminiResponse := &geminiImageGenerationResponse{} + if err := json.Unmarshal(body, geminiResponse); err != nil { + return nil, fmt.Errorf("unable to unmarshal gemini image generation response: %v", err) + } + response := g.buildImageGenerationResponse(ctx, geminiResponse) + return json.Marshal(response) +} + +func (g *geminiProvider) buildImageGenerationResponse(ctx wrapper.HttpContext, geminiResponse *geminiImageGenerationResponse) *imageGenerationResponse { + data := make([]imageGenerationData, len(geminiResponse.Predictions)) + for i, prediction := range geminiResponse.Predictions { + data[i] = imageGenerationData{ + B64: prediction.BytesBase64Encoded, + } + } + response := &imageGenerationResponse{ + Created: time.Now().UnixMilli() / 1000, + Data: data, + } + return response +} + func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) { + log.Debugf("chat completion response body:%s", string(body)) geminiResponse := &geminiChatResponse{} if err := json.Unmarshal(body, geminiResponse); err != nil { return nil, fmt.Errorf("unable to unmarshal gemini chat response: %v", err) @@ -180,26 +241,37 @@ func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body return json.Marshal(response) } -func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string { +func (g *geminiProvider) getRequestPath(apiName ApiName, model string, stream bool) string { action := "" - if apiName == ApiNameEmbeddings { - action = geminiEmbeddingPath - } else if stream { - action = geminiChatCompletionStreamPath - } else { - action = geminiChatCompletionPath + if g.config.apiVersion == "" { + g.config.apiVersion = geminiDefaultApiVersion } - return fmt.Sprintf("/v1/models/%s:%s", geminiModel, action) + switch apiName { + case ApiNameModels: + return fmt.Sprintf("/%s/%s", g.config.apiVersion, geminiModelsPath) + case ApiNameEmbeddings: + action = geminiEmbeddingPath + case ApiNameChatCompletion: + if stream { + action = geminiChatCompletionStreamPath + } else { + action = geminiChatCompletionPath + } + case ApiNameImageGeneration: + action = geminiImageGenerationPath + } + return fmt.Sprintf("/%s/models/%s:%s", g.config.apiVersion, model, action) } -type geminiChatRequest struct { +type geminiGenerationContentRequest struct { // Model and Stream are only used when using the gemini original protocol - Model string `json:"model,omitempty"` - Stream bool `json:"stream,omitempty"` - Contents []geminiChatContent `json:"contents"` - SafetySettings []geminiChatSafetySetting `json:"safety_settings,omitempty"` - GenerationConfig geminiChatGenerationConfig `json:"generation_config,omitempty"` - Tools []geminiChatTools `json:"tools,omitempty"` + Model string `json:"model,omitempty"` + Stream bool `json:"stream,omitempty"` + Contents []geminiChatContent `json:"contents"` + SystemInstruction *geminiChatContent `json:"system_instruction,omitempty"` + SafetySettings []geminiChatSafetySetting `json:"safetySettings,omitempty"` + GenerationConfig geminiChatGenerationConfig `json:"generationConfig,omitempty"` + Tools []geminiChatTools `json:"tools,omitempty"` } type geminiChatContent struct { @@ -212,13 +284,26 @@ type geminiChatSafetySetting struct { Threshold string `json:"threshold"` } +type geminiThinkingConfig struct { + IncludeThoughts bool `json:"includeThoughts,omitempty"` + ThinkingBudget int64 `json:"thinkingBudget,omitempty"` +} + type geminiChatGenerationConfig struct { - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK float64 `json:"topK,omitempty"` - MaxOutputTokens int `json:"maxOutputTokens,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - StopSequences []string `json:"stopSequences,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK int64 `json:"topK,omitempty"` + Seed int64 `json:"seed,omitempty"` + Logprobs bool `json:"logprobs,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` + PresencePenalty int64 `json:"presencePenalty,omitempty"` + FrequencyPenalty int64 `json:"frequencyPenalty,omitempty"` + ResponseModalities []string `json:"responseModalities,omitempty"` + NegativePrompt string `json:"negativePrompt,omitempty"` + ThinkingConfig *geminiThinkingConfig `json:"thinkingConfig,omitempty"` + MediaResolution string `json:"mediaResolution,omitempty"` } type geminiChatTools struct { @@ -241,25 +326,52 @@ type geminiFunctionCall struct { Arguments any `json:"args"` } -func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest) *geminiChatRequest { +// geminiImageGenerationRequest is the request body for generate image using Imagen 3 +type geminiImageGenerationRequest struct { + Instances []geminiImageGenerationInstance `json:"instances"` + Parameters *geminiImageGenerationParameters `json:"parameters,omitempty"` +} + +type geminiImageGenerationInstance struct { + Prompt string `json:"prompt"` +} + +type geminiImageGenerationParameters struct { + SampleCount int `json:"sampleCount,omitempty"` + AspectRatio string `json:"aspectRatio,omitempty"` +} + +type geminiImageGenerationPrediction struct { + BytesBase64Encoded string `json:"bytesBase64Encoded"` + MimeType string `json:"mimeType"` +} + +type geminiImageGenerationResponse struct { + Predictions []geminiImageGenerationPrediction `json:"predictions"` +} + +func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest) *geminiGenerationContentRequest { var safetySettings []geminiChatSafetySetting - { - } for category, threshold := range g.config.geminiSafetySetting { safetySettings = append(safetySettings, geminiChatSafetySetting{ Category: category, Threshold: threshold, }) } - geminiRequest := geminiChatRequest{ + geminiRequest := geminiGenerationContentRequest{ Contents: make([]geminiChatContent, 0, len(request.Messages)), SafetySettings: safetySettings, GenerationConfig: geminiChatGenerationConfig{ - Temperature: request.Temperature, - TopP: request.TopP, - MaxOutputTokens: request.MaxTokens, + Temperature: request.Temperature, + TopP: request.TopP, + MaxOutputTokens: request.MaxTokens, + PresencePenalty: int64(request.PresencePenalty), + FrequencyPenalty: int64(request.FrequencyPenalty), + Logprobs: request.Logprobs, + ResponseModalities: request.Modalities, }, } + if request.Tools != nil { functions := make([]function, 0, len(request.Tools)) for _, tool := range request.Tools { @@ -271,7 +383,7 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest) }, } } - shouldAddDummyModelMessage := false + // shouldAddDummyModelMessage := false for _, message := range request.Messages { content := geminiChatContent{ Role: message.Role, @@ -283,32 +395,22 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest) } // there's no assistant role in gemini and API shall vomit if role is not user or model - if content.Role == roleAssistant { + switch content.Role { + case roleSystem: + content.Role = "" + geminiRequest.SystemInstruction = &content + continue + case roleAssistant: content.Role = "model" - } else if content.Role == roleSystem { // converting system prompt to prompt from user for the same reason - content.Role = roleUser - shouldAddDummyModelMessage = true } geminiRequest.Contents = append(geminiRequest.Contents, content) - // if a system message is the last message, we need to add a dummy model message to make gemini happy - if shouldAddDummyModelMessage { - geminiRequest.Contents = append(geminiRequest.Contents, geminiChatContent{ - Role: "model", - Parts: []geminiPart{ - { - Text: "Okay", - }, - }, - }) - shouldAddDummyModelMessage = false - } } return &geminiRequest } -func (g *geminiProvider) setSystemContent(request *geminiChatRequest, content string) { +func (g *geminiProvider) setSystemContent(request *geminiGenerationContentRequest, content string) { systemContents := []geminiChatContent{{ Role: roleUser, Parts: []geminiPart{ @@ -398,32 +500,34 @@ func (g *geminiProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, re Object: objectChatCompletion, Created: time.Now().UnixMilli() / 1000, Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), - Choices: make([]chatCompletionChoice, 0, len(response.Candidates)), Usage: usage{ PromptTokens: response.UsageMetadata.PromptTokenCount, CompletionTokens: response.UsageMetadata.CandidatesTokenCount, TotalTokens: response.UsageMetadata.TotalTokenCount, }, } - for i, candidate := range response.Candidates { - choice := chatCompletionChoice{ - Index: i, - Message: &chatMessage{ - Role: roleAssistant, - }, - FinishReason: finishReasonStop, - } - if len(candidate.Content.Parts) > 0 { - if candidate.Content.Parts[0].FunctionCall != nil { - choice.Message.ToolCalls = g.buildToolCalls(&candidate) - } else { - choice.Message.Content = candidate.Content.Parts[0].Text + choiceIndex := 0 + for _, candidate := range response.Candidates { + for _, part := range candidate.Content.Parts { + choice := chatCompletionChoice{ + Index: choiceIndex, + Message: &chatMessage{ + Role: roleAssistant, + }, + FinishReason: finishReasonStop, } - } else { - choice.Message.Content = "" + if part.FunctionCall != nil { + choice.Message.ToolCalls = g.buildToolCalls(&candidate) + } else if part.InlineData != nil { + choice.Message.Content = part.InlineData.Data + } else { + choice.Message.Content = part.Text + } + choice.FinishReason = candidate.FinishReason + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) + choiceIndex += 1 } - fullTextResponse.Choices = append(fullTextResponse.Choices, choice) } return &fullTextResponse } @@ -511,5 +615,8 @@ func (g *geminiProvider) GetApiName(path string) ApiName { if strings.Contains(path, geminiEmbeddingPath) { return ApiNameEmbeddings } + if strings.Contains(path, geminiImageGenerationPath) { + return ApiNameImageGeneration + } return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 81a4c0add..0cd55bdb3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -290,8 +290,8 @@ type ProviderConfig struct { // @Description zh-CN 配置一个外部获取对话上下文的文件来源,用于在AI请求中补充对话上下文 context *ContextConfig `required:"false" yaml:"context" json:"context"` // @Title zh-CN 版本 - // @Description zh-CN 请求AI服务的版本,目前仅适用于Claude AI服务 - claudeVersion string `required:"false" yaml:"version" json:"version"` + // @Description zh-CN 请求AI服务的版本,目前仅适用于 Gemini 和 Claude AI服务 + apiVersion string `required:"false" yaml:"apiVersion" json:"apiVersion"` // @Title zh-CN Cloudflare Account ID // @Description zh-CN 仅适用于 Cloudflare Workers AI 服务。参考:https://developers.cloudflare.com/workers-ai/get-started/rest-api/#2-run-a-model-via-api cloudflareAccountId string `required:"false" yaml:"cloudflareAccountId" json:"cloudflareAccountId"` @@ -375,7 +375,13 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.context = &ContextConfig{} c.context.FromJson(contextJson) } - c.claudeVersion = json.Get("claudeVersion").String() + + // 这里获取 claudeVersion 字段,与结构体中定义 yaml/json 的 tag 不一致 + c.apiVersion = json.Get("claudeVersion").String() + if c.apiVersion == "" { + // 增加获取 version 字段,用于适配其他模型的配置,并保持与结构体中定义的 tag 一致 + c.apiVersion = json.Get("apiVersion").String() + } c.hunyuanAuthId = json.Get("hunyuanAuthId").String() c.hunyuanAuthKey = json.Get("hunyuanAuthKey").String() c.awsAccessKey = json.Get("awsAccessKey").String()