Add cohere embedding for ai-cache (#1572)

This commit is contained in:
ayanami-desu
2024-12-27 17:48:44 +08:00
committed by GitHub
parent 6dc4d43df5
commit 2d74c48e8a
5 changed files with 239 additions and 41 deletions

View File

@@ -10,11 +10,13 @@ import (
const (
PROVIDER_TYPE_DASHSCOPE = "dashscope"
PROVIDER_TYPE_TEXTIN = "textin"
PROVIDER_TYPE_COHERE = "cohere"
PROVIDER_TYPE_OPENAI = "openai"
)
type providerInitializer interface {
ValidateConfig(ProviderConfig) error
InitConfig(json gjson.Result)
ValidateConfig() error
CreateProvider(ProviderConfig) (Provider, error)
}
@@ -22,6 +24,7 @@ var (
providerInitializers = map[string]providerInitializer{
PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{},
PROVIDER_TYPE_TEXTIN: &textInProviderInitializer{},
PROVIDER_TYPE_COHERE: &cohereProviderInitializer{},
PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{},
}
)
@@ -39,35 +42,26 @@ type ProviderConfig struct {
// @Title zh-CN 文本特征提取服务端口
// @Description zh-CN 文本特征提取服务端口
servicePort int64
// @Title zh-CN 文本特征提取服务 API Key
// @Description zh-CN 文本特征提取服务 API Key
apiKey string
//@Title zh-CN TextIn x-ti-app-id
// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
textinAppId string
//@Title zh-CN TextIn x-ti-secret-code
// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
textinSecretCode string
//@Title zh-CN TextIn request matryoshka_dim
// @Description zh-CN 仅适用于 TextIn 服务, 指定返回的向量维度。参考 https://www.textin.com/document/acge_text_embedding
textinMatryoshkaDim int
// @Title zh-CN 文本特征提取服务超时时间
// @Description zh-CN 文本特征提取服务超时时间
timeout uint32
// @Title zh-CN 文本特征提取服务使用的模型
// @Description zh-CN 用于文本特征提取的模型名称, 在 DashScope 中默认为 "text-embedding-v1"
model string
initializer providerInitializer
}
func (c *ProviderConfig) FromJson(json gjson.Result) {
c.typ = json.Get("type").String()
i, has := providerInitializers[c.typ]
if has {
i.InitConfig(json)
c.initializer = i
}
c.serviceName = json.Get("serviceName").String()
c.serviceHost = json.Get("serviceHost").String()
c.servicePort = json.Get("servicePort").Int()
c.apiKey = json.Get("apiKey").String()
c.textinAppId = json.Get("textinAppId").String()
c.textinSecretCode = json.Get("textinSecretCode").String()
c.textinMatryoshkaDim = int(json.Get("textinMatryoshkaDim").Int())
c.timeout = uint32(json.Get("timeout").Int())
c.model = json.Get("model").String()
if c.timeout == 0 {
@@ -82,11 +76,10 @@ func (c *ProviderConfig) Validate() error {
if c.typ == "" {
return errors.New("embedding service type is required")
}
initializer, has := providerInitializers[c.typ]
if !has {
if c.initializer == nil {
return errors.New("unknown embedding service provider type: " + c.typ)
}
if err := initializer.ValidateConfig(*c); err != nil {
if err := c.initializer.ValidateConfig(); err != nil {
return err
}
return nil