mirror of
https://github.com/alibaba/higress.git
synced 2026-06-01 08:37:26 +08:00
feature: allow ai-proxy to forward standard AI capabilities that are … (#1704)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
@@ -12,14 +11,27 @@ import (
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ApiName identifies an AI capability exposed through the proxy. Values follow
// the "{vendor}/{version}/{apitype}" layout, e.g. "openai/v1/chatcompletions".
type ApiName string
|
||||
// Pointcut is a named string type.
// NOTE(review): its constants and call sites are outside this excerpt;
// presumably it names a hook point in request/response processing — confirm.
type Pointcut string
|
||||
|
||||
const (
|
||||
ApiNameChatCompletion ApiName = "chatCompletion"
|
||||
ApiNameEmbeddings ApiName = "embeddings"
|
||||
|
||||
// ApiName format: {vendor}/{version}/{apitype}
|
||||
// i.e. every name follows the vendor/version/API-type layout.
|
||||
// OpenAI is currently the de facto standard, but other vendors may define standards for other tasks, e.g. Cohere's rerank.
|
||||
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
|
||||
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
|
||||
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
|
||||
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
|
||||
|
||||
PathOpenAIChatCompletions = "/v1/chat/completions"
|
||||
PathOpenAIEmbeddings = "/v1/embeddings"
|
||||
|
||||
// TODO: the API names below are non-standard; whether they are supported still needs to be confirmed.
|
||||
ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank"
|
||||
|
||||
providerTypeMoonshot = "moonshot"
|
||||
providerTypeAzure = "azure"
|
||||
@@ -250,6 +262,12 @@ type ProviderConfig struct {
|
||||
inputVariable string `required:"false" yaml:"inputVariable" json:"inputVariable"`
|
||||
// @Title zh-CN dify中应用类型为workflow时需要设置输出变量,当botType为workflow时一起使用
|
||||
outputVariable string `required:"false" yaml:"outputVariable" json:"outputVariable"`
|
||||
// @Title zh-CN 额外支持的ai能力
|
||||
// @Description zh-CN 开放的ai能力和urlpath映射,例如: {"openai/v1/chatcompletions": "/v1/chat/completions"}
|
||||
capabilities map[string]string
|
||||
// @Title zh-CN 是否开启透传
|
||||
// @Description zh-CN 如果是插件不支持的API,是否透传请求, 默认为false
|
||||
passthrough bool
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetId() string {
|
||||
@@ -361,12 +379,22 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.botType = json.Get("botType").String()
|
||||
c.inputVariable = json.Get("inputVariable").String()
|
||||
c.outputVariable = json.Get("outputVariable").String()
|
||||
|
||||
c.capabilities = make(map[string]string)
|
||||
for capability, pathJson := range json.Get("capabilities").Map() {
|
||||
// Filter out capabilities that are not supported.
|
||||
switch capability {
|
||||
case string(ApiNameChatCompletion),
|
||||
string(ApiNameEmbeddings),
|
||||
string(ApiNameImageGeneration),
|
||||
string(ApiNameAudioSpeech),
|
||||
string(ApiNameCohereV1Rerank):
|
||||
c.capabilities[capability] = pathJson.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) Validate() error {
|
||||
if c.timeout < 0 {
|
||||
return errors.New("invalid timeout in config")
|
||||
}
|
||||
if c.protocol != protocolOpenAI && c.protocol != protocolOriginal {
|
||||
return errors.New("invalid protocol in config")
|
||||
}
|
||||
@@ -425,6 +453,10 @@ func (c *ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) {
|
||||
return ReplaceByCustomSettings(body, c.customSettings)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) PassthroughUnsupportedAPI() bool {
|
||||
return c.passthrough
|
||||
}
|
||||
|
||||
func CreateProvider(pc ProviderConfig) (Provider, error) {
|
||||
initializer, has := providerInitializers[pc.typ]
|
||||
if !has {
|
||||
@@ -499,7 +531,7 @@ func getMappedModel(model string, modelMapping map[string]string, log wrapper.Lo
|
||||
}
|
||||
|
||||
func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
|
||||
if modelMapping == nil || len(modelMapping) == 0 {
|
||||
if len(modelMapping) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -527,11 +559,22 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) isSupportedAPI(apiName ApiName) bool {
|
||||
_, exist := c.capabilities[string(apiName)]
|
||||
return exist
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string) {
|
||||
for capability, path := range capabilities {
|
||||
c.capabilities[capability] = path
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) handleRequestBody(
|
||||
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log,
|
||||
) (types.Action, error) {
|
||||
// use original protocol
|
||||
if c.protocol == protocolOriginal {
|
||||
if c.IsOriginal() {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
@@ -578,17 +621,21 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt
|
||||
}
|
||||
}
|
||||
|
||||
// defaultTransformRequestBody 默认的请求体转换方法,只做模型映射,用slog替换模型名称,不用序列化和反序列化,提高性能
|
||||
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
var request interface{}
|
||||
if apiName == ApiNameChatCompletion {
|
||||
request = &chatCompletionRequest{}
|
||||
} else {
|
||||
request = &embeddingsRequest{}
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
stream := gjson.GetBytes(body, "stream").Bool()
|
||||
if stream {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
|
||||
ctx.SetContext(ctxKeyIsStreaming, true)
|
||||
} else {
|
||||
ctx.SetContext(ctxKeyIsStreaming, false)
|
||||
}
|
||||
}
|
||||
if err := c.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(request)
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping, log))
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {
|
||||
|
||||
Reference in New Issue
Block a user