feat: support ai-proxy custom settings (#1219)

This commit is contained in:
Pxl
2024-08-22 13:59:32 +08:00
committed by GitHub
parent 0e58042fa6
commit 29fcd330d5
7 changed files with 409 additions and 4 deletions

View File

@@ -0,0 +1,137 @@
package provider
import (
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
nameMaxTokens = "max_tokens"
nameTemperature = "temperature"
nameTopP = "top_p"
nameTopK = "top_k"
nameSeed = "seed"
)
var maxTokensMapping = map[string]string{
"openai": "max_tokens",
"baidu": "max_output_tokens",
"spark": "max_tokens",
"qwen": "max_tokens",
"gemini": "maxOutputTokens",
"claude": "max_tokens",
"minimax": "tokens_to_generate",
}
var temperatureMapping = map[string]string{
"openai": "temperature",
"baidu": "temperature",
"spark": "temperature",
"qwen": "temperature",
"gemini": "temperature",
"hunyuan": "Temperature",
"claude": "temperature",
"minimax": "temperature",
}
var topPMapping = map[string]string{
"openai": "top_p",
"baidu": "top_p",
"qwen": "top_p",
"gemini": "topP",
"hunyuan": "TopP",
"claude": "top_p",
"minimax": "top_p",
}
var topKMapping = map[string]string{
"spark": "top_k",
"gemini": "topK",
"claude": "top_k",
}
var seedMapping = map[string]string{
"openai": "seed",
"qwen": "seed",
}
var settingMapping = map[string]map[string]string{
nameMaxTokens: maxTokensMapping,
nameTemperature: temperatureMapping,
nameTopP: topPMapping,
nameTopK: topKMapping,
nameSeed: seedMapping,
}
type CustomSetting struct {
// @Title zh-CN 参数名称
// @Description zh-CN 想要设置的参数的名称例如max_tokens
name string
// @Title zh-CN 参数值
// @Description zh-CN 想要设置的参数的值例如0
value string
// @Title zh-CN 设置模式
// @Description zh-CN 参数设置的模式,可以设置为"auto"或者"raw",如果为"auto"则会根据 /plugins/wasm-go/extensions/ai-proxy/README.md中关于custom-setting部分的表格自动按照协议对参数名做改写如果为"raw"则不会有任何改写和限制检查
mode string
// @Title zh-CN json edit 模式
// @Description zh-CN 如果为false则只在用户没有设置这个参数时填充参数否则会直接覆盖用户原有的参数设置
overwrite bool
}
func (c *CustomSetting) FromJson(json gjson.Result) {
c.name = json.Get("name").String()
c.value = json.Get("value").Raw
if obj := json.Get("mode"); obj.Exists() {
c.mode = obj.String()
} else {
c.mode = "auto"
}
if obj := json.Get("overwrite"); obj.Exists() {
c.overwrite = obj.Bool()
} else {
c.overwrite = true
}
}
func (c *CustomSetting) Validate() bool {
return c.name != ""
}
func (c *CustomSetting) setInvalid() {
c.name = "" // set empty to represent invalid
}
func (c *CustomSetting) AdjustWithProtocol(protocol string) {
if !(c.mode == "raw") {
mapping, ok := settingMapping[c.name]
if ok {
c.name, ok = mapping[protocol]
}
if !ok {
c.setInvalid()
return
}
}
if protocol == providerTypeQwen {
c.name = "parameters." + c.name
}
if protocol == providerTypeGemini {
c.name = "generation_config." + c.name
}
}
func ReplaceByCustomSettings(body []byte, settings []CustomSetting) ([]byte, error) {
var err error
strBody := string(body)
for _, setting := range settings {
if !setting.overwrite && gjson.Get(strBody, setting.name).Exists() {
continue
}
strBody, err = sjson.SetRaw(strBody, setting.name, setting.value)
if err != nil {
break
}
}
return []byte(strBody), err
}

View File

@@ -184,6 +184,9 @@ type ProviderConfig struct {
// @Title zh-CN 指定服务返回的响应需满足的JSON Schema
// @Description zh-CN 目前仅适用于OpenAI部分模型服务。参考https://platform.openai.com/docs/guides/structured-outputs
responseJsonSchema map[string]interface{} `required:"false" yaml:"responseJsonSchema" json:"responseJsonSchema"`
// @Title zh-CN 自定义大模型参数配置
// @Description zh-CN 用于填充或者覆盖大模型调用时的参数
customSettings []CustomSetting
}
func (c *ProviderConfig) FromJson(json gjson.Result) {
@@ -239,6 +242,25 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.responseJsonSchema = nil
}
c.customSettings = make([]CustomSetting, 0)
customSettingsJson := json.Get("customSettings")
if customSettingsJson.Exists() {
protocol := protocolOpenAI
if c.protocol == protocolOriginal {
// use provider name to represent original protocol name
protocol = c.typ
}
for _, settingJson := range customSettingsJson.Array() {
setting := CustomSetting{}
setting.FromJson(settingJson)
// use protocol info to rewrite setting
setting.AdjustWithProtocol(protocol)
if setting.Validate() {
c.customSettings = append(c.customSettings, setting)
}
}
}
}
func (c *ProviderConfig) Validate() error {
@@ -324,3 +346,7 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
return ""
}
func (c ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) {
return ReplaceByCustomSettings(body, c.customSettings)
}