feat: support gemini ai model (#1173)

This commit is contained in:
韩贤涛
2024-08-09 09:55:40 +08:00
committed by GitHub
parent 564f8c770a
commit 04a9104062
5 changed files with 706 additions and 4 deletions

View File

@@ -34,6 +34,7 @@ const (
providerTypeMinimax = "minimax"
providerTypeCloudflare = "cloudflare"
providerTypeSpark = "spark"
providerTypeGemini = "gemini"
protocolOpenAI = "openai"
protocolOriginal = "original"
@@ -86,6 +87,7 @@ var (
providerTypeMinimax: &minimaxProviderInitializer{},
providerTypeCloudflare: &cloudflareProviderInitializer{},
providerTypeSpark: &sparkProviderInitializer{},
providerTypeGemini: &geminiProviderInitializer{},
}
)
@@ -168,6 +170,9 @@ type ProviderConfig struct {
// @Title zh-CN Cloudflare Account ID
// @Description zh-CN 仅适用于 Cloudflare Workers AI 服务。参考https://developers.cloudflare.com/workers-ai/get-started/rest-api/#2-run-a-model-via-api
cloudflareAccountId string `required:"false" yaml:"cloudflareAccountId" json:"cloudflareAccountId"`
// @Title zh-CN Gemini AI内容过滤和安全级别设定
// @Description zh-CN 仅适用于 Gemini AI 服务。参考https://ai.google.dev/gemini-api/docs/safety-settings
geminiSafetySetting map[string]string `required:"false" yaml:"geminiSafetySetting" json:"geminiSafetySetting"`
}
func (c *ProviderConfig) FromJson(json gjson.Result) {
@@ -208,6 +213,12 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.hunyuanAuthKey = json.Get("hunyuanAuthKey").String()
c.minimaxGroupId = json.Get("minimaxGroupId").String()
c.cloudflareAccountId = json.Get("cloudflareAccountId").String()
if c.typ == providerTypeGemini {
c.geminiSafetySetting = make(map[string]string)
for k, v := range json.Get("geminiSafetySetting").Map() {
c.geminiSafetySetting[k] = v.String()
}
}
}
func (c *ProviderConfig) Validate() error {