mirror of
https://github.com/alibaba/higress.git
synced 2026-03-08 02:30:56 +08:00
575 lines
22 KiB
Go
575 lines
22 KiB
Go
package provider
|
||
|
||
import (
|
||
"encoding/json"
|
||
"errors"
|
||
"math/rand"
|
||
"net/http"
|
||
"strings"
|
||
|
||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||
"github.com/tidwall/gjson"
|
||
)
|
||
|
||
// ApiName identifies which AI API a request targets (e.g. chat completion or
// embeddings); see the ApiName* constants below.
type ApiName string

// Pointcut names a processing hook point. NOTE(review): not referenced in this
// chunk — confirm its usage elsewhere before documenting further.
type Pointcut string
|
||
|
||
const (
	// Supported AI API names.
	ApiNameChatCompletion ApiName = "chatCompletion"
	ApiNameEmbeddings     ApiName = "embeddings"

	// Provider type identifiers; each must have an entry in providerInitializers.
	providerTypeMoonshot   = "moonshot"
	providerTypeAzure      = "azure"
	providerTypeAi360      = "ai360"
	providerTypeGithub     = "github"
	providerTypeQwen       = "qwen"
	providerTypeOpenAI     = "openai"
	providerTypeGroq       = "groq"
	providerTypeBaichuan   = "baichuan"
	providerTypeYi         = "yi"
	providerTypeDeepSeek   = "deepseek"
	providerTypeZhipuAi    = "zhipuai"
	providerTypeOllama     = "ollama"
	providerTypeClaude     = "claude"
	providerTypeBaidu      = "baidu"
	providerTypeHunyuan    = "hunyuan"
	providerTypeStepfun    = "stepfun"
	providerTypeMinimax    = "minimax"
	providerTypeCloudflare = "cloudflare"
	providerTypeSpark      = "spark"
	providerTypeGemini     = "gemini"
	providerTypeDeepl      = "deepl"
	providerTypeMistral    = "mistral"
	providerTypeCohere     = "cohere"
	providerTypeDoubao     = "doubao"
	providerTypeCoze       = "coze"

	// Protocols exposed to callers: OpenAI-compatible or the provider's own.
	protocolOpenAI   = "openai"
	protocolOriginal = "original"

	// Chat message roles.
	roleSystem    = "system"
	roleAssistant = "assistant"
	roleUser      = "user"

	// Completion finish reasons.
	finishReasonStop   = "stop"
	finishReasonLength = "length"

	// Keys used to stash per-request state in the HTTP context.
	ctxKeyIncrementalStreaming = "incrementalStreaming"
	// NOTE(review): named "ApiName" but its value is "apiKey", and it is used to
	// cache the selected API token (see GetOrSetTokenWithContext). The
	// name/value mismatch looks accidental — confirm before renaming either.
	ctxKeyApiName              = "apiKey"
	ctxKeyStreamingBody        = "streamingBody"
	ctxKeyOriginalRequestModel = "originalRequestModel"
	ctxKeyFinalRequestModel    = "finalRequestModel"
	ctxKeyPushedMessage        = "pushedMessage"

	// OpenAI response object types.
	objectChatCompletion      = "chat.completion"
	objectChatCompletionChunk = "chat.completion.chunk"

	// Wildcard used in modelMapping keys ("*" and "prefix*").
	wildcard = "*"

	// Default request timeout: 2 minutes.
	defaultTimeout = 2 * 60 * 1000 // ms
)
|
||
|
||
// providerInitializer validates provider-specific configuration and constructs
// the corresponding Provider implementation. One initializer is registered per
// provider type in providerInitializers.
type providerInitializer interface {
	ValidateConfig(ProviderConfig) error
	CreateProvider(ProviderConfig) (Provider, error)
}
|
||
|
||
var (
	errUnsupportedApiName = errors.New("unsupported API name")

	// providerInitializers maps every supported provider type identifier to
	// its initializer; lookups on this map drive config validation and
	// provider construction (see Validate and CreateProvider).
	providerInitializers = map[string]providerInitializer{
		providerTypeMoonshot:   &moonshotProviderInitializer{},
		providerTypeAzure:      &azureProviderInitializer{},
		providerTypeAi360:      &ai360ProviderInitializer{},
		providerTypeGithub:     &githubProviderInitializer{},
		providerTypeQwen:       &qwenProviderInitializer{},
		providerTypeOpenAI:     &openaiProviderInitializer{},
		providerTypeGroq:       &groqProviderInitializer{},
		providerTypeBaichuan:   &baichuanProviderInitializer{},
		providerTypeYi:         &yiProviderInitializer{},
		providerTypeDeepSeek:   &deepseekProviderInitializer{},
		providerTypeZhipuAi:    &zhipuAiProviderInitializer{},
		providerTypeOllama:     &ollamaProviderInitializer{},
		providerTypeClaude:     &claudeProviderInitializer{},
		providerTypeBaidu:      &baiduProviderInitializer{},
		providerTypeHunyuan:    &hunyuanProviderInitializer{},
		providerTypeStepfun:    &stepfunProviderInitializer{},
		providerTypeMinimax:    &minimaxProviderInitializer{},
		providerTypeCloudflare: &cloudflareProviderInitializer{},
		providerTypeSpark:      &sparkProviderInitializer{},
		providerTypeGemini:     &geminiProviderInitializer{},
		providerTypeDeepl:      &deeplProviderInitializer{},
		providerTypeMistral:    &mistralProviderInitializer{},
		providerTypeCohere:     &cohereProviderInitializer{},
		providerTypeDoubao:     &doubaoProviderInitializer{},
		providerTypeCoze:       &cozeProviderInitializer{},
	}
)
|
||
|
||
// Provider is the minimal interface implemented by every AI service provider.
// Optional capabilities are expressed by additionally implementing the
// *Handler interfaces defined in this package.
type Provider interface {
	GetProviderType() string
}
|
||
|
||
// ApiNameHandler lets a provider map a request path to the ApiName it serves.
type ApiNameHandler interface {
	GetApiName(path string) ApiName
}
|
||
|
||
// RequestHeadersHandler is implemented by providers that need full control of
// the request-headers phase, returning the action to take next.
type RequestHeadersHandler interface {
	OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
}
|
||
|
||
// TransformRequestHeadersHandler is implemented by providers that rewrite the
// outgoing request headers in place (see handleRequestHeaders).
type TransformRequestHeadersHandler interface {
	TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
}
|
||
|
||
// RequestBodyHandler is implemented by providers that need full control of the
// request-body phase, returning the action to take next.
type RequestBodyHandler interface {
	OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
}
|
||
|
||
// TransformRequestBodyHandler is implemented by providers that rewrite the
// request body; it takes precedence over TransformRequestBodyHeadersHandler
// and the default transformation (see handleRequestBody).
type TransformRequestBodyHandler interface {
	TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
}
|
||
|
||
// TransformRequestBodyHeadersHandler allows to transform request headers based on the request body.
// Some providers (e.g. baidu, gemini) transform request headers (e.g., path) based on the request body (e.g., model).
type TransformRequestBodyHeadersHandler interface {
	TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
}
|
||
|
||
// ResponseHeadersHandler is implemented by providers that need to act on the
// upstream response headers, returning the action to take next.
type ResponseHeadersHandler interface {
	OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
}
|
||
|
||
// StreamingResponseBodyHandler is implemented by providers that transform a
// streaming (SSE) response chunk by chunk; isLastChunk marks end of stream.
type StreamingResponseBodyHandler interface {
	OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
}
|
||
|
||
// ResponseBodyHandler is implemented by providers that transform a complete
// (non-streaming) response body, returning the action to take next.
type ResponseBodyHandler interface {
	OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
}
|
||
|
||
// TickFuncHandler allows the provider to execute a function periodically
// Use case: the maximum expiration time of baidu apiToken is 24 hours, need to refresh periodically
type TickFuncHandler interface {
	// GetTickFunc returns the period (in milliseconds, presumably — confirm
	// against the wrapper's tick API) and the callback to run on each tick.
	GetTickFunc(log wrapper.Log) (tickPeriod int64, tickFunc func())
}
|
||
|
||
// ProviderConfig holds the per-provider configuration of the ai-proxy plugin.
// It is populated from JSON via FromJson and checked via Validate. The zh-CN
// @Title/@Description annotations are kept verbatim (they appear to feed doc
// tooling); an English gloss is added above each field.
type ProviderConfig struct {
	// @Title zh-CN ID
	// @Description zh-CN AI服务提供商标识
	// Unique identifier of this provider instance.
	id string `required:"true" yaml:"id" json:"id"`
	// @Title zh-CN 类型
	// @Description zh-CN AI服务提供商类型
	// Provider type; must match a key of providerInitializers.
	typ string `required:"true" yaml:"type" json:"type"`
	// @Title zh-CN API Tokens
	// @Description zh-CN 在请求AI服务时用于认证的API Token列表。不同的AI服务提供商可能有不同的名称。部分供应商只支持配置一个API Token(如Azure OpenAI)。
	// API tokens used to authenticate against the AI service; some vendors
	// (e.g. Azure OpenAI) support only one.
	apiTokens []string `required:"false" yaml:"apiToken" json:"apiTokens"`
	// @Title zh-CN 请求超时
	// @Description zh-CN 请求AI服务的超时时间,单位为毫秒。默认值为120000,即2分钟
	// Request timeout in milliseconds; defaults to 120000 (2 minutes).
	timeout uint32 `required:"false" yaml:"timeout" json:"timeout"`
	// @Title zh-CN apiToken 故障切换
	// @Description zh-CN 当 apiToken 不可用时移出 apiTokens 列表,对移除的 apiToken 进行健康检查,当重新可用后加回 apiTokens 列表
	// Failover policy: unavailable tokens are removed, health-checked, and
	// re-added once they recover. Always non-nil after FromJson.
	failover *failover `required:"false" yaml:"failover" json:"failover"`
	// @Title zh-CN 基于OpenAI协议的自定义后端URL
	// @Description zh-CN 仅适用于支持 openai 协议的服务。
	// Custom backend URL; only for services speaking the openai protocol.
	openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"`
	// @Title zh-CN Moonshot File ID
	// @Description zh-CN 仅适用于Moonshot AI服务。Moonshot AI服务的文件ID,其内容用于补充AI请求上下文
	// Moonshot-only: file ID whose content supplements the request context.
	moonshotFileId string `required:"false" yaml:"moonshotFileId" json:"moonshotFileId"`
	// @Title zh-CN Azure OpenAI Service URL
	// @Description zh-CN 仅适用于Azure OpenAI服务。要请求的OpenAI服务的完整URL,包含api-version等参数
	// Azure-only: full service URL including api-version etc.
	azureServiceUrl string `required:"false" yaml:"azureServiceUrl" json:"azureServiceUrl"`
	// @Title zh-CN 通义千问File ID
	// @Description zh-CN 仅适用于通义千问服务。上传到Dashscope的文件ID,其内容用于补充AI请求上下文。仅支持qwen-long模型。
	// Qwen-only: Dashscope file IDs supplementing request context (qwen-long only).
	qwenFileIds []string `required:"false" yaml:"qwenFileIds" json:"qwenFileIds"`
	// @Title zh-CN 启用通义千问搜索服务
	// @Description zh-CN 仅适用于通义千问服务,表示是否启用通义千问的互联网搜索功能。
	// Qwen-only: enable Qwen's internet search feature.
	qwenEnableSearch bool `required:"false" yaml:"qwenEnableSearch" json:"qwenEnableSearch"`
	// @Title zh-CN 开启通义千问兼容模式
	// @Description zh-CN 启用通义千问兼容模式后,将调用千问的兼容模式接口,同时对请求/响应不做修改。
	// Qwen-only: use the compatible-mode endpoint, passing requests/responses
	// through unmodified.
	qwenEnableCompatible bool `required:"false" yaml:"qwenEnableCompatible" json:"qwenEnableCompatible"`
	// @Title zh-CN Ollama Server IP/Domain
	// @Description zh-CN 仅适用于 Ollama 服务。Ollama 服务器的主机地址。
	// Ollama-only: server host.
	ollamaServerHost string `required:"false" yaml:"ollamaServerHost" json:"ollamaServerHost"`
	// @Title zh-CN Ollama Server Port
	// @Description zh-CN 仅适用于 Ollama 服务。Ollama 服务器的端口号。
	// Ollama-only: server port.
	ollamaServerPort uint32 `required:"false" yaml:"ollamaServerPort" json:"ollamaServerPort"`
	// @Title zh-CN hunyuan api key for authorization
	// @Description zh-CN 仅适用于Hun Yuan AI服务鉴权,API key/id 参考:https://cloud.tencent.com/document/api/1729/101843#Golang
	// Hunyuan-only: API key for authentication.
	hunyuanAuthKey string `required:"false" yaml:"hunyuanAuthKey" json:"hunyuanAuthKey"`
	// @Title zh-CN hunyuan api id for authorization
	// @Description zh-CN 仅适用于Hun Yuan AI服务鉴权
	// Hunyuan-only: API ID for authentication.
	hunyuanAuthId string `required:"false" yaml:"hunyuanAuthId" json:"hunyuanAuthId"`
	// @Title zh-CN minimax group id
	// @Description zh-CN 仅适用于minimax使用ChatCompletion Pro接口的模型
	// Minimax-only: group ID for models using the ChatCompletion Pro API.
	minimaxGroupId string `required:"false" yaml:"minimaxGroupId" json:"minimaxGroupId"`
	// @Title zh-CN 模型名称映射表
	// @Description zh-CN 用于将请求中的模型名称映射为目标AI服务商支持的模型名称。支持通过“*”来配置全局映射
	// Maps requested model names to names the target vendor supports;
	// "*" configures a global mapping (see doGetMappedModel).
	modelMapping map[string]string `required:"false" yaml:"modelMapping" json:"modelMapping"`
	// @Title zh-CN 对外接口协议
	// @Description zh-CN 通过本插件对外提供的AI服务接口协议。默认值为“openai”,即OpenAI的接口协议。如需保留原有接口协议,可配置为“original"
	// Protocol exposed to callers: "openai" (default) or "original".
	protocol string `required:"false" yaml:"protocol" json:"protocol"`
	// @Title zh-CN 模型对话上下文
	// @Description zh-CN 配置一个外部获取对话上下文的文件来源,用于在AI请求中补充对话上下文
	// External file source used to supplement conversation context.
	context *ContextConfig `required:"false" yaml:"context" json:"context"`
	// @Title zh-CN 版本
	// @Description zh-CN 请求AI服务的版本,目前仅适用于Claude AI服务
	// Claude-only: API version.
	claudeVersion string `required:"false" yaml:"version" json:"version"`
	// @Title zh-CN Cloudflare Account ID
	// @Description zh-CN 仅适用于 Cloudflare Workers AI 服务。参考:https://developers.cloudflare.com/workers-ai/get-started/rest-api/#2-run-a-model-via-api
	// Cloudflare Workers AI only: account ID.
	cloudflareAccountId string `required:"false" yaml:"cloudflareAccountId" json:"cloudflareAccountId"`
	// @Title zh-CN Gemini AI内容过滤和安全级别设定
	// @Description zh-CN 仅适用于 Gemini AI 服务。参考:https://ai.google.dev/gemini-api/docs/safety-settings
	// Gemini-only: content filtering / safety-level settings.
	geminiSafetySetting map[string]string `required:"false" yaml:"geminiSafetySetting" json:"geminiSafetySetting"`
	// @Title zh-CN 翻译服务需指定的目标语种
	// @Description zh-CN 翻译结果的语种,目前仅适用于DeepL服务。
	// DeepL-only: target language of the translation result.
	targetLang string `required:"false" yaml:"targetLang" json:"targetLang"`
	// @Title zh-CN 指定服务返回的响应需满足的JSON Schema
	// @Description zh-CN 目前仅适用于OpenAI部分模型服务。参考:https://platform.openai.com/docs/guides/structured-outputs
	// JSON Schema the response must satisfy (OpenAI structured outputs only).
	responseJsonSchema map[string]interface{} `required:"false" yaml:"responseJsonSchema" json:"responseJsonSchema"`
	// @Title zh-CN 自定义大模型参数配置
	// @Description zh-CN 用于填充或者覆盖大模型调用时的参数
	// Custom settings used to fill or override LLM call parameters.
	customSettings []CustomSetting
	// @Title zh-CN Baidu 的 Access Key 和 Secret Key,中间用 : 分隔,用于申请 apiToken
	// Baidu "AccessKey:SecretKey" pairs used to request an apiToken.
	baiduAccessKeyAndSecret []string `required:"false" yaml:"baiduAccessKeyAndSecret" json:"baiduAccessKeyAndSecret"`
	// @Title zh-CN 请求刷新百度 apiToken 服务名称
	// Service name used when refreshing the Baidu apiToken.
	baiduApiTokenServiceName string `required:"false" yaml:"baiduApiTokenServiceName" json:"baiduApiTokenServiceName"`
	// @Title zh-CN 请求刷新百度 apiToken 服务域名
	// Service host used when refreshing the Baidu apiToken (defaulted in FromJson).
	baiduApiTokenServiceHost string `required:"false" yaml:"baiduApiTokenServiceHost" json:"baiduApiTokenServiceHost"`
	// @Title zh-CN 请求刷新百度 apiToken 服务端口
	// Service port used when refreshing the Baidu apiToken (defaulted in FromJson).
	baiduApiTokenServicePort int64 `required:"false" yaml:"baiduApiTokenServicePort" json:"baiduApiTokenServicePort"`
	// @Title zh-CN 是否使用全局的 apiToken
	// @Description zh-CN 如果没有启用 apiToken failover,但是 apiToken 的状态又需要在多个 Wasm VM 中同步时需要将该参数设置为 true,例如 Baidu 的 apiToken 需要定时刷新
	// Set true when apiToken state must be shared across Wasm VMs without
	// failover being enabled (e.g. Baidu's periodically refreshed token).
	useGlobalApiToken bool `required:"false" yaml:"useGlobalApiToken" json:"useGlobalApiToken"`
}
|
||
|
||
// GetId returns the configured provider instance ID.
func (c *ProviderConfig) GetId() string {
	return c.id
}
|
||
|
||
// GetType returns the configured provider type (a providerType* value).
func (c *ProviderConfig) GetType() string {
	return c.typ
}
|
||
|
||
// GetProtocol returns the externally exposed protocol ("openai" or "original").
func (c *ProviderConfig) GetProtocol() string {
	return c.protocol
}
|
||
|
||
// FromJson populates the ProviderConfig from its JSON representation,
// applying defaults for the timeout, the protocol, and the Baidu token
// service host/port when they are absent. It always leaves c.failover
// non-nil so later code may dereference it unconditionally.
func (c *ProviderConfig) FromJson(json gjson.Result) {
	c.id = json.Get("id").String()
	c.typ = json.Get("type").String()
	c.apiTokens = make([]string, 0)
	for _, token := range json.Get("apiTokens").Array() {
		c.apiTokens = append(c.apiTokens, token.String())
	}
	c.timeout = uint32(json.Get("timeout").Uint())
	if c.timeout == 0 {
		// Default to 2 minutes (defaultTimeout is in milliseconds).
		c.timeout = defaultTimeout
	}
	c.openaiCustomUrl = json.Get("openaiCustomUrl").String()
	c.moonshotFileId = json.Get("moonshotFileId").String()
	c.azureServiceUrl = json.Get("azureServiceUrl").String()
	c.qwenFileIds = make([]string, 0)
	for _, fileId := range json.Get("qwenFileIds").Array() {
		c.qwenFileIds = append(c.qwenFileIds, fileId.String())
	}
	c.qwenEnableSearch = json.Get("qwenEnableSearch").Bool()
	c.qwenEnableCompatible = json.Get("qwenEnableCompatible").Bool()
	c.ollamaServerHost = json.Get("ollamaServerHost").String()
	c.ollamaServerPort = uint32(json.Get("ollamaServerPort").Uint())
	c.modelMapping = make(map[string]string)
	for k, v := range json.Get("modelMapping").Map() {
		c.modelMapping[k] = v.String()
	}
	c.protocol = json.Get("protocol").String()
	if c.protocol == "" {
		c.protocol = protocolOpenAI
	}
	contextJson := json.Get("context")
	if contextJson.Exists() {
		c.context = &ContextConfig{}
		c.context.FromJson(contextJson)
	}
	// NOTE(review): the struct tag on claudeVersion is yaml:"version", but the
	// JSON key read here is "claudeVersion" — confirm which key callers send.
	c.claudeVersion = json.Get("claudeVersion").String()
	c.hunyuanAuthId = json.Get("hunyuanAuthId").String()
	c.hunyuanAuthKey = json.Get("hunyuanAuthKey").String()
	c.minimaxGroupId = json.Get("minimaxGroupId").String()
	c.cloudflareAccountId = json.Get("cloudflareAccountId").String()
	// Gemini safety settings are only parsed for the gemini provider.
	if c.typ == providerTypeGemini {
		c.geminiSafetySetting = make(map[string]string)
		for k, v := range json.Get("geminiSafetySetting").Map() {
			c.geminiSafetySetting[k] = v.String()
		}
	}
	c.targetLang = json.Get("targetLang").String()

	// responseJsonSchema is kept only when it deserializes to a JSON object.
	if schemaValue, ok := json.Get("responseJsonSchema").Value().(map[string]interface{}); ok {
		c.responseJsonSchema = schemaValue
	} else {
		c.responseJsonSchema = nil
	}

	c.customSettings = make([]CustomSetting, 0)
	customSettingsJson := json.Get("customSettings")
	if customSettingsJson.Exists() {
		protocol := protocolOpenAI
		if c.protocol == protocolOriginal {
			// use provider name to represent original protocol name
			protocol = c.typ
		}
		for _, settingJson := range customSettingsJson.Array() {
			setting := CustomSetting{}
			setting.FromJson(settingJson)
			// use protocol info to rewrite setting
			setting.AdjustWithProtocol(protocol)
			// Settings that fail validation are silently dropped.
			if setting.Validate() {
				c.customSettings = append(c.customSettings, setting)
			}
		}
	}

	// failover always gets a (disabled) value so callers never see nil.
	failoverJson := json.Get("failover")
	c.failover = &failover{
		enabled: false,
	}
	if failoverJson.Exists() {
		c.failover.FromJson(failoverJson)
	}

	for _, accessKeyAndSecret := range json.Get("baiduAccessKeyAndSecret").Array() {
		c.baiduAccessKeyAndSecret = append(c.baiduAccessKeyAndSecret, accessKeyAndSecret.String())
	}
	c.baiduApiTokenServiceName = json.Get("baiduApiTokenServiceName").String()
	c.baiduApiTokenServiceHost = json.Get("baiduApiTokenServiceHost").String()
	if c.baiduApiTokenServiceHost == "" {
		c.baiduApiTokenServiceHost = baiduApiTokenDomain
	}
	c.baiduApiTokenServicePort = json.Get("baiduApiTokenServicePort").Int()
	if c.baiduApiTokenServicePort == 0 {
		c.baiduApiTokenServicePort = baiduApiTokenPort
	}
}
|
||
|
||
func (c *ProviderConfig) Validate() error {
|
||
if c.timeout < 0 {
|
||
return errors.New("invalid timeout in config")
|
||
}
|
||
if c.protocol != protocolOpenAI && c.protocol != protocolOriginal {
|
||
return errors.New("invalid protocol in config")
|
||
}
|
||
if c.context != nil {
|
||
if err := c.context.Validate(); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
if c.failover.enabled {
|
||
if err := c.failover.Validate(); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
if c.typ == "" {
|
||
return errors.New("missing type in provider config")
|
||
}
|
||
initializer, has := providerInitializers[c.typ]
|
||
if !has {
|
||
return errors.New("unknown provider type: " + c.typ)
|
||
}
|
||
if err := initializer.ValidateConfig(*c); err != nil {
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) string {
|
||
ctxApiKey := ctx.GetContext(ctxKeyApiName)
|
||
if ctxApiKey == nil {
|
||
ctxApiKey = c.GetRandomToken()
|
||
ctx.SetContext(ctxKeyApiName, ctxApiKey)
|
||
}
|
||
return ctxApiKey.(string)
|
||
}
|
||
|
||
func (c *ProviderConfig) GetRandomToken() string {
|
||
apiTokens := c.apiTokens
|
||
count := len(apiTokens)
|
||
switch count {
|
||
case 0:
|
||
return ""
|
||
case 1:
|
||
return apiTokens[0]
|
||
default:
|
||
return apiTokens[rand.Intn(count)]
|
||
}
|
||
}
|
||
|
||
// IsOriginal reports whether the provider's native ("original") protocol is
// exposed to callers instead of the OpenAI-compatible protocol.
func (c *ProviderConfig) IsOriginal() bool {
	return c.protocol == protocolOriginal
}
|
||
|
||
// ReplaceByCustomSettings applies the configured custom LLM settings to the
// request body, delegating to the package-level ReplaceByCustomSettings.
func (c *ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) {
	return ReplaceByCustomSettings(body, c.customSettings)
}
|
||
|
||
func CreateProvider(pc ProviderConfig) (Provider, error) {
|
||
initializer, has := providerInitializers[pc.typ]
|
||
if !has {
|
||
return nil, errors.New("unknown provider type: " + pc.typ)
|
||
}
|
||
return initializer.CreateProvider(pc)
|
||
}
|
||
|
||
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte, log wrapper.Log) error {
|
||
switch req := request.(type) {
|
||
case *chatCompletionRequest:
|
||
if err := decodeChatCompletionRequest(body, req); err != nil {
|
||
return err
|
||
}
|
||
|
||
streaming := req.Stream
|
||
if streaming {
|
||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
|
||
}
|
||
|
||
return c.setRequestModel(ctx, req, log)
|
||
case *embeddingsRequest:
|
||
if err := decodeEmbeddingsRequest(body, req); err != nil {
|
||
return err
|
||
}
|
||
return c.setRequestModel(ctx, req, log)
|
||
default:
|
||
return errors.New("unsupported request type")
|
||
}
|
||
}
|
||
|
||
func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}, log wrapper.Log) error {
|
||
var model *string
|
||
|
||
switch req := request.(type) {
|
||
case *chatCompletionRequest:
|
||
model = &req.Model
|
||
case *embeddingsRequest:
|
||
model = &req.Model
|
||
default:
|
||
return errors.New("unsupported request type")
|
||
}
|
||
|
||
return c.mapModel(ctx, model, log)
|
||
}
|
||
|
||
func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string, log wrapper.Log) error {
|
||
if *model == "" {
|
||
return errors.New("missing model in request")
|
||
}
|
||
ctx.SetContext(ctxKeyOriginalRequestModel, *model)
|
||
|
||
mappedModel := getMappedModel(*model, c.modelMapping, log)
|
||
if mappedModel == "" {
|
||
return errors.New("model becomes empty after applying the configured mapping")
|
||
}
|
||
|
||
*model = mappedModel
|
||
ctx.SetContext(ctxKeyFinalRequestModel, *model)
|
||
return nil
|
||
}
|
||
|
||
func getMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
|
||
mappedModel := doGetMappedModel(model, modelMapping, log)
|
||
if len(mappedModel) != 0 {
|
||
return mappedModel
|
||
}
|
||
return model
|
||
}
|
||
|
||
func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
|
||
if modelMapping == nil || len(modelMapping) == 0 {
|
||
return ""
|
||
}
|
||
|
||
if v, ok := modelMapping[model]; ok {
|
||
log.Debugf("model [%s] is mapped to [%s] explictly", model, v)
|
||
return v
|
||
}
|
||
|
||
for k, v := range modelMapping {
|
||
if k == wildcard || !strings.HasSuffix(k, wildcard) {
|
||
continue
|
||
}
|
||
k = strings.TrimSuffix(k, wildcard)
|
||
if strings.HasPrefix(model, k) {
|
||
log.Debugf("model [%s] is mapped to [%s] via prefix [%s]", model, v, k)
|
||
return v
|
||
}
|
||
}
|
||
|
||
if v, ok := modelMapping[wildcard]; ok {
|
||
log.Debugf("model [%s] is mapped to [%s] via wildcard", model, v)
|
||
return v
|
||
}
|
||
|
||
return ""
|
||
}
|
||
|
||
func (c *ProviderConfig) handleRequestBody(
|
||
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log,
|
||
) (types.Action, error) {
|
||
// use original protocol
|
||
if c.protocol == protocolOriginal {
|
||
return types.ActionContinue, nil
|
||
}
|
||
|
||
// use openai protocol
|
||
var err error
|
||
if handler, ok := provider.(TransformRequestBodyHandler); ok {
|
||
body, err = handler.TransformRequestBody(ctx, apiName, body, log)
|
||
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
|
||
headers := util.GetOriginalHttpHeaders()
|
||
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log)
|
||
util.ReplaceOriginalHttpHeaders(headers)
|
||
} else {
|
||
body, err = c.defaultTransformRequestBody(ctx, apiName, body, log)
|
||
}
|
||
|
||
if err != nil {
|
||
return types.ActionContinue, err
|
||
}
|
||
|
||
if apiName == ApiNameChatCompletion {
|
||
if c.context == nil {
|
||
return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
|
||
}
|
||
err = contextCache.GetContextFromFile(ctx, provider, body, log)
|
||
|
||
if err == nil {
|
||
return types.ActionPause, nil
|
||
}
|
||
return types.ActionContinue, err
|
||
}
|
||
return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
|
||
}
|
||
|
||
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) {
|
||
if handler, ok := provider.(TransformRequestHeadersHandler); ok {
|
||
originalHeaders := util.GetOriginalHttpHeaders()
|
||
handler.TransformRequestHeaders(ctx, apiName, originalHeaders, log)
|
||
util.ReplaceOriginalHttpHeaders(originalHeaders)
|
||
}
|
||
}
|
||
|
||
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||
var request interface{}
|
||
if apiName == ApiNameChatCompletion {
|
||
request = &chatCompletionRequest{}
|
||
} else {
|
||
request = &embeddingsRequest{}
|
||
}
|
||
if err := c.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||
return nil, err
|
||
}
|
||
return json.Marshal(request)
|
||
}
|