mirror of
https://github.com/alibaba/higress.git
synced 2026-03-07 18:10:54 +08:00
feature: allow ai-proxy to forward standard AI capabilities that are … (#1704)
This commit is contained in:
@@ -42,7 +42,8 @@ description: AI 代理插件配置参考
|
||||
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
|
||||
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
|
||||
| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 |
|
||||
|
||||
| `capabilities` | map of string | 非必填 | - | 部分provider的部分ai能力原生兼容openai/v1格式,不需要重写,可以直接转发,通过此配置项指定来开启转发, key表示的是采用的厂商协议能力,values表示的真实的厂商该能力的api path, 厂商协议能力当前支持: openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, cohere/v1/rerank |
|
||||
| `passthrough` | bool | 非必填 | - | 只要是不支持的API能力都直接转发, 此配置是capabilities配置的放大版本,允许任意api透传,就像没有ai-proxy插件一样 |
|
||||
`context`的配置字段说明如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|
||||
@@ -78,7 +78,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
|
||||
rawPath := ctx.Path()
|
||||
path, _ := url.Parse(rawPath)
|
||||
apiName := getOpenAiApiName(path.Path)
|
||||
apiName := getApiName(path.Path)
|
||||
providerConfig := pluginConfig.GetProviderConfig()
|
||||
if providerConfig.IsOriginal() {
|
||||
if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
|
||||
@@ -103,20 +103,25 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
// Set the apiToken for the current request.
|
||||
providerConfig.SetApiTokenInUse(ctx, log)
|
||||
|
||||
hasRequestBody := wrapper.HasRequestBody()
|
||||
err := handler.OnRequestHeaders(ctx, apiName, log)
|
||||
if err == nil {
|
||||
if hasRequestBody {
|
||||
proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
|
||||
// Delay the header processing to allow changing in OnRequestBody
|
||||
return types.HeaderStopIteration
|
||||
if err != nil {
|
||||
if providerConfig.PassthroughUnsupportedAPI() {
|
||||
log.Warnf("[onHttpRequestHeader] passthrough unsupported API: %v", err)
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
ctx.DontReadRequestBody()
|
||||
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
|
||||
hasRequestBody := wrapper.HasRequestBody()
|
||||
if hasRequestBody {
|
||||
proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
|
||||
// Delay the header processing to allow changing in OnRequestBody
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
@@ -151,6 +156,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
||||
if err == nil {
|
||||
return action
|
||||
}
|
||||
if pluginConfig.GetProviderConfig().PassthroughUnsupportedAPI() {
|
||||
log.Warnf("[onHttpRequestBody] passthrough unsupported API: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
|
||||
}
|
||||
return types.ActionContinue
|
||||
@@ -267,12 +276,23 @@ func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
}
|
||||
}
|
||||
|
||||
func getOpenAiApiName(path string) provider.ApiName {
|
||||
func getApiName(path string) provider.ApiName {
|
||||
// openai style
|
||||
if strings.HasSuffix(path, "/v1/chat/completions") {
|
||||
return provider.ApiNameChatCompletion
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/embeddings") {
|
||||
return provider.ApiNameEmbeddings
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/audio/speech") {
|
||||
return provider.ApiNameAudioSpeech
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/images/generations") {
|
||||
return provider.ApiNameImageGeneration
|
||||
}
|
||||
// cohere style
|
||||
if strings.HasSuffix(path, "/v1/rerank") {
|
||||
return provider.ApiNameCohereV1Rerank
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -22,6 +22,13 @@ type ai360Provider struct {
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (m *ai360ProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||
return errors.New("no apiToken found in provider config")
|
||||
@@ -30,6 +37,7 @@ func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
}
|
||||
|
||||
func (m *ai360ProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &ai360Provider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -41,7 +49,7 @@ func (m *ai360Provider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
@@ -50,7 +58,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
|
||||
}
|
||||
|
||||
func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
@@ -58,5 +66,6 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
|
||||
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestHostHeader(headers, ai360Domain)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
|
||||
}
|
||||
|
||||
@@ -15,6 +15,14 @@ import (
|
||||
type azureProviderInitializer struct {
|
||||
}
|
||||
|
||||
func (m *azureProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
// TODO: azure's pattern is the same as openai, just need to handle the prefix, can be done in TransformRequestHeaders to support general capabilities
|
||||
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *azureProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if config.azureServiceUrl == "" {
|
||||
return errors.New("missing azureServiceUrl in provider config")
|
||||
@@ -35,6 +43,7 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid
|
||||
} else {
|
||||
serviceUrl = u
|
||||
}
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &azureProvider{
|
||||
config: config,
|
||||
serviceUrl: serviceUrl,
|
||||
@@ -54,7 +63,7 @@ func (m *azureProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
@@ -62,7 +71,7 @@ func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
|
||||
}
|
||||
|
||||
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
|
||||
@@ -12,8 +12,7 @@ import (
|
||||
// baichuanProvider is the provider for baichuan Ai service.
|
||||
|
||||
const (
|
||||
baichuanDomain = "api.baichuan-ai.com"
|
||||
baichuanChatCompletionPath = "/v1/chat/completions"
|
||||
baichuanDomain = "api.baichuan-ai.com"
|
||||
)
|
||||
|
||||
type baichuanProviderInitializer struct {
|
||||
@@ -26,7 +25,15 @@ func (m *baichuanProviderInitializer) ValidateConfig(config *ProviderConfig) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *baichuanProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *baichuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &baichuanProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -43,7 +50,7 @@ func (m *baichuanProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
@@ -51,14 +58,14 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
|
||||
}
|
||||
|
||||
func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, baichuanChatCompletionPath)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, baichuanDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
const (
|
||||
baiduDomain = "qianfan.baidubce.com"
|
||||
baiduChatCompletionPath = "/v2/chat/completions"
|
||||
baiduEmbeddings = "/v2/embeddings"
|
||||
)
|
||||
|
||||
type baiduProviderInitializer struct{}
|
||||
@@ -25,7 +26,15 @@ func (g *baiduProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *baiduProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): baiduChatCompletionPath,
|
||||
string(ApiNameEmbeddings): baiduEmbeddings,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(g.DefaultCapabilities())
|
||||
return &baiduProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -42,7 +51,7 @@ func (g *baiduProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !g.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
g.config.handleRequestHeaders(g, ctx, apiName, log)
|
||||
@@ -50,14 +59,14 @@ func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
|
||||
}
|
||||
|
||||
func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !g.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, baiduChatCompletionPath)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, baiduDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
|
||||
@@ -85,7 +85,16 @@ func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): claudeChatCompletionPath,
|
||||
// docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api
|
||||
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *claudeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(c.DefaultCapabilities())
|
||||
return &claudeProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -102,7 +111,7 @@ func (c *claudeProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !c.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
c.config.handleRequestHeaders(c, ctx, apiName, log)
|
||||
@@ -110,7 +119,7 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
}
|
||||
|
||||
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, claudeDomain)
|
||||
|
||||
headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx))
|
||||
@@ -123,13 +132,16 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !c.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return c.config.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
return nil, err
|
||||
@@ -139,6 +151,9 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A
|
||||
}
|
||||
|
||||
func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return body, nil
|
||||
}
|
||||
claudeResponse := &claudeTextGenResponse{}
|
||||
if err := json.Unmarshal(body, claudeResponse); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal claude response: %v", err)
|
||||
@@ -154,6 +169,10 @@ func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
// only process the response from chat completion, skip other responses
|
||||
if name != ApiNameChatCompletion {
|
||||
return chunk, nil
|
||||
}
|
||||
|
||||
responseBuilder := &strings.Builder{}
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
|
||||
@@ -25,8 +25,14 @@ func (c *cloudflareProviderInitializer) ValidateConfig(config *ProviderConfig) e
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (c *cloudflareProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): cloudflareChatCompletionPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cloudflareProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(c.DefaultCapabilities())
|
||||
return &cloudflareProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -43,7 +49,7 @@ func (c *cloudflareProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !c.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
c.config.handleRequestHeaders(c, ctx, apiName, log)
|
||||
@@ -51,7 +57,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A
|
||||
}
|
||||
|
||||
func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !c.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
|
||||
|
||||
@@ -12,8 +12,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
cohereDomain = "api.cohere.com"
|
||||
cohereDomain = "api.cohere.com"
|
||||
// TODO: support more capabilities, upgrade to v2, docs: https://docs.cohere.com/v2/reference/chat
|
||||
cohereChatCompletionPath = "/v1/chat"
|
||||
cohereRerankPath = "/v1/rerank"
|
||||
)
|
||||
|
||||
type cohereProviderInitializer struct{}
|
||||
@@ -25,7 +27,15 @@ func (m *cohereProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *cohereProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): cohereChatCompletionPath,
|
||||
string(ApiNameCohereV1Rerank): cohereRerankPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *cohereProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &cohereProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -56,7 +66,7 @@ func (m *cohereProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
@@ -64,7 +74,7 @@ func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
}
|
||||
|
||||
func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
@@ -90,13 +100,16 @@ func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohe
|
||||
}
|
||||
|
||||
func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, cohereChatCompletionPath)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, cohereDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -21,7 +21,12 @@ func (m *cozeProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *cozeProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (m *cozeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &cozeProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
|
||||
@@ -64,7 +64,14 @@ func (d *deeplProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *deeplProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): deeplChatCompletionPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *deeplProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(d.DefaultCapabilities())
|
||||
return &deeplProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -76,7 +83,7 @@ func (d *deeplProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !d.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
d.config.handleRequestHeaders(d, ctx, apiName, log)
|
||||
@@ -89,7 +96,7 @@ func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
|
||||
}
|
||||
|
||||
func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !d.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log)
|
||||
@@ -112,6 +119,9 @@ func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, api
|
||||
}
|
||||
|
||||
func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return body, nil
|
||||
}
|
||||
deeplResponse := &deeplResponse{}
|
||||
if err := json.Unmarshal(body, deeplResponse); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal deepl response: %v", err)
|
||||
|
||||
@@ -12,7 +12,9 @@ import (
|
||||
// deepseekProvider is the provider for deepseek Ai service.
|
||||
|
||||
const (
|
||||
deepseekDomain = "api.deepseek.com"
|
||||
deepseekDomain = "api.deepseek.com"
|
||||
// TODO: docs: https://api-docs.deepseek.com/api/create-chat-completion
|
||||
// accourding to the docs, the path should be /chat/completions, need to be verified
|
||||
deepseekChatCompletionPath = "/v1/chat/completions"
|
||||
)
|
||||
|
||||
@@ -26,7 +28,14 @@ func (m *deepseekProviderInitializer) ValidateConfig(config *ProviderConfig) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *deepseekProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): deepseekChatCompletionPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *deepseekProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &deepseekProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -43,7 +52,7 @@ func (m *deepseekProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
@@ -51,14 +60,14 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
|
||||
}
|
||||
|
||||
func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *deepseekProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, deepseekChatCompletionPath)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, deepseekDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -83,6 +84,9 @@ func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
|
||||
}
|
||||
|
||||
func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return d.config.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
err := d.config.parseRequestAndMapModel(ctx, request, body, log)
|
||||
if err != nil {
|
||||
@@ -95,6 +99,9 @@ func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiN
|
||||
}
|
||||
|
||||
func (d *difyProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return body, nil
|
||||
}
|
||||
difyResponse := &DifyChatResponse{}
|
||||
if err := json.Unmarshal(body, difyResponse); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal dify response: %v", err)
|
||||
@@ -146,6 +153,9 @@ func (d *difyProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if name != ApiNameChatCompletion {
|
||||
return chunk, nil
|
||||
}
|
||||
// sample event response:
|
||||
// data: {"event": "agent_thought", "id": "8dcf3648-fbad-407a-85dd-73a6f43aeb9f", "task_id": "9cf1ddd7-f94b-459b-b942-b77b26c59e9b", "message_id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "position": 1, "thought": "", "observation": "", "tool": "", "tool_input": "", "created_at": 1705639511, "message_files": [], "conversation_id": "c216c595-2d89-438c-b33c-aae5ddddd142"}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
const (
|
||||
doubaoDomain = "ark.cn-beijing.volces.com"
|
||||
doubaoChatCompletionPath = "/api/v3/chat/completions"
|
||||
doubaoEmbeddingsPath = "/api/v3/embeddings"
|
||||
)
|
||||
|
||||
type doubaoProviderInitializer struct{}
|
||||
@@ -24,7 +25,15 @@ func (m *doubaoProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *doubaoProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): doubaoChatCompletionPath,
|
||||
string(ApiNameEmbeddings): doubaoEmbeddingsPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *doubaoProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &doubaoProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -41,7 +50,7 @@ func (m *doubaoProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
@@ -49,14 +58,14 @@ func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
}
|
||||
|
||||
func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, doubaoChatCompletionPath)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, doubaoDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
@@ -66,5 +75,8 @@ func (m *doubaoProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, doubaoChatCompletionPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
if strings.Contains(path, doubaoEmbeddingsPath) {
|
||||
return ApiNameEmbeddings
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -35,7 +35,12 @@ func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(g.DefaultCapabilities())
|
||||
return &geminiProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -52,7 +57,7 @@ func (g *geminiProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
if !g.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
g.config.handleRequestHeaders(g, ctx, apiName, log)
|
||||
@@ -66,7 +71,7 @@ func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
|
||||
}
|
||||
|
||||
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
if !g.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
|
||||
@@ -110,6 +115,9 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if name != ApiNameChatCompletion {
|
||||
return chunk, nil
|
||||
}
|
||||
// sample end event response:
|
||||
// data: {"candidates": [{"content": {"parts": [{"text": "我是 Gemini,一个大型多模态模型,由 Google 训练。我的职责是尽我所能帮助您,并尽力提供全面且信息丰富的答复。"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 2,"candidatesTokenCount": 35,"totalTokenCount": 37}}
|
||||
responseBuilder := &strings.Builder{}
|
||||
|
||||
@@ -32,7 +32,15 @@ func (m *githubProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *githubProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): githubCompletionPath,
|
||||
string(ApiNameEmbeddings): githubEmbeddingPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *githubProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &githubProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -44,7 +52,7 @@ func (m *githubProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
@@ -53,7 +61,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
}
|
||||
|
||||
func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
@@ -61,12 +69,7 @@ func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
|
||||
func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestHostHeader(headers, githubDomain)
|
||||
if apiName == ApiNameChatCompletion {
|
||||
util.OverwriteRequestPathHeader(headers, githubCompletionPath)
|
||||
}
|
||||
if apiName == ApiNameEmbeddings {
|
||||
util.OverwriteRequestPathHeader(headers, githubEmbeddingPath)
|
||||
}
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
|
||||
}
|
||||
|
||||
|
||||
@@ -25,7 +25,14 @@ func (g *groqProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *groqProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): groqChatCompletionPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(g.DefaultCapabilities())
|
||||
return &groqProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -42,7 +49,7 @@ func (g *groqProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !g.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
g.config.handleRequestHeaders(g, ctx, apiName, log)
|
||||
@@ -50,14 +57,14 @@ func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName
|
||||
}
|
||||
|
||||
func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !g.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (g *groqProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, groqChatCompletionPath)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, groqDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
|
||||
@@ -39,6 +39,11 @@ const (
|
||||
|
||||
hunyuanAuthKeyLen = 32
|
||||
hunyuanAuthIdLen = 36
|
||||
|
||||
// docs: https://cloud.tencent.com/document/product/1729/111007
|
||||
hunyuanOpenAiDomain = "api.hunyuan.cloud.tencent.com"
|
||||
hunyuanOpenAiRequestPath = "/v1/chat/completions"
|
||||
hunyuanOpenAiEmbeddings = "/v1/embeddings"
|
||||
)
|
||||
|
||||
type hunyuanProviderInitializer struct {
|
||||
@@ -86,6 +91,10 @@ type hunyuanChatMessage struct {
|
||||
}
|
||||
|
||||
func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
// 允许 hunyuanauthid 和 hunyuanauthkey 为空, 当他们都为空的时候,认为是使用openai的 兼容接口
|
||||
if len(config.hunyuanAuthId) == 0 && len(config.hunyuanAuthKey) == 0 {
|
||||
return nil
|
||||
}
|
||||
// 校验hunyuan id 和 key的合法性
|
||||
if len(config.hunyuanAuthId) != hunyuanAuthIdLen || len(config.hunyuanAuthKey) != hunyuanAuthKeyLen {
|
||||
return errors.New("hunyuanAuthId / hunyuanAuthKey is illegal in config file")
|
||||
@@ -93,7 +102,15 @@ func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *hunyuanProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): hunyuanOpenAiRequestPath,
|
||||
string(ApiNameEmbeddings): hunyuanOpenAiEmbeddings,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *hunyuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &hunyuanProvider{
|
||||
config: config,
|
||||
client: wrapper.NewClusterClient(wrapper.RouteCluster{
|
||||
@@ -114,8 +131,12 @@ func (m *hunyuanProvider) GetProviderType() string {
|
||||
return providerTypeHunyuan
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) useOpenAICompatibleAPI() bool {
|
||||
return len(m.config.hunyuanAuthId) == 0 && len(m.config.hunyuanAuthKey) == 0
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
@@ -124,19 +145,27 @@ func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestHostHeader(headers, hunyuanDomain)
|
||||
util.OverwriteRequestPathHeader(headers, hunyuanRequestPath)
|
||||
|
||||
// 添加 hunyuan 需要的自定义字段
|
||||
headers.Set(actionKey, hunyuanChatCompletionTCAction)
|
||||
headers.Set(versionKey, versionValue)
|
||||
if m.useOpenAICompatibleAPI() {
|
||||
util.OverwriteRequestHostHeader(headers, hunyuanOpenAiDomain)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
} else {
|
||||
util.OverwriteRequestHostHeader(headers, hunyuanDomain)
|
||||
util.OverwriteRequestPathHeader(headers, hunyuanRequestPath)
|
||||
// 添加 hunyuan 需要的自定义字段
|
||||
headers.Set(actionKey, hunyuanChatCompletionTCAction)
|
||||
headers.Set(versionKey, versionValue)
|
||||
}
|
||||
}
|
||||
|
||||
// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法
|
||||
func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if m.useOpenAICompatibleAPI() {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
// 为header添加时间戳字段 (因为需要根据body进行签名时依赖时间戳,故于body处理部分创建时间戳)
|
||||
var timestamp int64 = time.Now().Unix()
|
||||
@@ -264,6 +293,9 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
|
||||
// hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用
|
||||
func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
if m.useOpenAICompatibleAPI() {
|
||||
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
err := m.config.parseRequestAndMapModel(ctx, request, body, log)
|
||||
if err != nil {
|
||||
@@ -289,7 +321,7 @@ func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
if m.config.protocol == protocolOriginal {
|
||||
if m.config.IsOriginal() || m.useOpenAICompatibleAPI() || name != ApiNameChatCompletion {
|
||||
return chunk, nil
|
||||
}
|
||||
|
||||
@@ -405,6 +437,12 @@ func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContex
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
if m.config.IsOriginal() || m.useOpenAICompatibleAPI() {
|
||||
return body, nil
|
||||
}
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return body, nil
|
||||
}
|
||||
log.Debugf("#debug nash5# onRespBody's resp is: %s", string(body))
|
||||
hunyuanResponse := &hunyuanTextGenResponseNonStreaming{}
|
||||
if err := json.Unmarshal(body, hunyuanResponse); err != nil {
|
||||
|
||||
@@ -41,7 +41,7 @@ type minimaxProviderInitializer struct {
|
||||
func (m *minimaxProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
// If using the chat completion Pro API, a group ID must be set.
|
||||
if minimaxApiTypePro == config.minimaxApiType && config.minimaxGroupId == "" {
|
||||
return errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when minimaxApiType is %s", minimaxApiTypePro))
|
||||
return fmt.Errorf("missing minimaxGroupId in provider config when minimaxApiType is %s", minimaxApiTypePro)
|
||||
}
|
||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||
return errors.New("no apiToken found in provider config")
|
||||
@@ -49,7 +49,15 @@ func (m *minimaxProviderInitializer) ValidateConfig(config *ProviderConfig) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *minimaxProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
// minimax 替换path的时候,要根据modelmapping替换,这儿的配置无实质作用,只是为了保持和其他provider的一致性
|
||||
string(ApiNameChatCompletion): minimaxChatCompletionV2Path,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *minimaxProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &minimaxProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -66,7 +74,7 @@ func (m *minimaxProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
@@ -81,7 +89,7 @@ func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if minimaxApiTypePro == m.config.minimaxApiType {
|
||||
@@ -159,6 +167,9 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if name != ApiNameChatCompletion {
|
||||
return chunk, nil
|
||||
}
|
||||
// Sample event response:
|
||||
// data: {"created":1689747645,"model":"abab6.5s-chat","reply":"","choices":[{"messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"am from China."}]}],"output_sensitive":false}
|
||||
|
||||
@@ -192,6 +203,9 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
|
||||
|
||||
// TransformResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
|
||||
func (m *minimaxProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return body, nil
|
||||
}
|
||||
minimaxResp := &minimaxChatCompletionProResp{}
|
||||
if err := json.Unmarshal(body, minimaxResp); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal minimax response: %v", err)
|
||||
@@ -268,18 +282,6 @@ type minimaxUsage struct {
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) parseModel(body []byte) (string, error) {
|
||||
var tempMap map[string]interface{}
|
||||
if err := json.Unmarshal(body, &tempMap); err != nil {
|
||||
return "", err
|
||||
}
|
||||
model, ok := tempMap["model"].(string)
|
||||
if !ok {
|
||||
return "", errors.New("missing model in chat completion request")
|
||||
}
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) setBotSettings(request *minimaxChatCompletionProRequest, botSettingContent string) {
|
||||
if len(request.BotSettings) == 0 {
|
||||
request.BotSettings = []minimaxBotSetting{
|
||||
|
||||
@@ -22,7 +22,16 @@ func (m *mistralProviderInitializer) ValidateConfig(config *ProviderConfig) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mistralProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
// The chat interface of mistral is the same as that of OpenAI. docs: https://docs.mistral.ai/api/
|
||||
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mistralProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &mistralProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -39,7 +48,7 @@ func (m *mistralProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
@@ -47,7 +56,7 @@ func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
|
||||
}
|
||||
|
||||
func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
|
||||
@@ -19,21 +19,21 @@ const (
|
||||
)
|
||||
|
||||
type chatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *streamOptions `json:"stream_options,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Tools []tool `json:"tools,omitempty"`
|
||||
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *streamOptions `json:"stream_options,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Tools []tool `json:"tools,omitempty"`
|
||||
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
ResponseFormat map[string]interface{} `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
@@ -230,6 +230,21 @@ func (e *streamEvent) setValue(key, value string) {
|
||||
}
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/guides/images
|
||||
type imageGenerationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/guides/speech-to-text
|
||||
type audioSpeechRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input string `json:"input"`
|
||||
Voice string `json:"voice"`
|
||||
}
|
||||
|
||||
type embeddingsRequest struct {
|
||||
Input interface{} `json:"input"`
|
||||
Model string `json:"model"`
|
||||
|
||||
@@ -34,7 +34,14 @@ func (m *moonshotProviderInitializer) ValidateConfig(config *ProviderConfig) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *moonshotProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): moonshotChatCompletionPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *moonshotProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &moonshotProvider{
|
||||
config: config,
|
||||
client: wrapper.NewClusterClient(wrapper.RouteCluster{
|
||||
@@ -57,7 +64,7 @@ func (m *moonshotProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
@@ -65,7 +72,7 @@ func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
|
||||
}
|
||||
|
||||
func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, moonshotChatCompletionPath)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, moonshotDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
@@ -74,9 +81,13 @@ func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiN
|
||||
// moonshot 有自己获取 context 的配置(moonshotFileId),因此无法复用 handleRequestBody 方法
|
||||
// moonshot 的 body 没有修改,无须实现TransformRequestBody,使用默认的 defaultTransformRequestBody 方法
|
||||
func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
// 非chat类型的请求,不做处理
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
request := &chatCompletionRequest{}
|
||||
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
@@ -154,6 +165,9 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba
|
||||
}
|
||||
|
||||
func (m *moonshotProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
if name != ApiNameChatCompletion {
|
||||
return chunk, nil
|
||||
}
|
||||
receivedBody := chunk
|
||||
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
|
||||
receivedBody = append(bufferedStreamingBody, chunk...)
|
||||
|
||||
@@ -12,10 +12,6 @@ import (
|
||||
|
||||
// ollamaProvider is the provider for Ollama service.
|
||||
|
||||
const (
|
||||
ollamaChatCompletionPath = "/v1/chat/completions"
|
||||
)
|
||||
|
||||
type ollamaProviderInitializer struct {
|
||||
}
|
||||
|
||||
@@ -29,9 +25,17 @@ func (m *ollamaProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ollamaProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
// ollama的chat接口path和OpenAI的chat接口一样
|
||||
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ollamaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
serverPortStr := fmt.Sprintf("%d", config.ollamaServerPort)
|
||||
serviceDomain := config.ollamaServerHost + ":" + serverPortStr
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &ollamaProvider{
|
||||
config: config,
|
||||
serviceDomain: serviceDomain,
|
||||
@@ -50,7 +54,7 @@ func (m *ollamaProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
@@ -58,14 +62,14 @@ func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
}
|
||||
|
||||
func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *ollamaProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, ollamaChatCompletionPath)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, m.serviceDomain)
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
@@ -26,6 +26,13 @@ func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
|
||||
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
if config.openaiCustomUrl == "" {
|
||||
return &openaiProvider{
|
||||
@@ -38,6 +45,7 @@ func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provi
|
||||
if len(pairs) != 2 {
|
||||
return nil, fmt.Errorf("invalid openaiCustomUrl:%s", config.openaiCustomUrl)
|
||||
}
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &openaiProvider{
|
||||
config: config,
|
||||
customDomain: pairs[0],
|
||||
@@ -64,13 +72,7 @@ func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
|
||||
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
if m.customPath == "" {
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
util.OverwriteRequestPathHeader(headers, defaultOpenaiChatCompletionPath)
|
||||
case ApiNameEmbeddings:
|
||||
ctx.DontReadRequestBody()
|
||||
util.OverwriteRequestPathHeader(headers, defaultOpenaiEmbeddingsPath)
|
||||
}
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
} else {
|
||||
util.OverwriteRequestPathHeader(headers, m.customPath)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
@@ -12,14 +11,27 @@ import (
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
type ApiName string
|
||||
type Pointcut string
|
||||
|
||||
const (
|
||||
ApiNameChatCompletion ApiName = "chatCompletion"
|
||||
ApiNameEmbeddings ApiName = "embeddings"
|
||||
|
||||
// ApiName 格式 {vendor}/{version}/{apitype}
|
||||
// 表示遵循 厂商/版本/接口类型 的格式
|
||||
// 目前openai是事实意义上的标准,但是也有其他厂商存在其他任务的一些可能的标准,比如cohere的rerank
|
||||
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
|
||||
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
|
||||
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
|
||||
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
|
||||
|
||||
PathOpenAIChatCompletions = "/v1/chat/completions"
|
||||
PathOpenAIEmbeddings = "/v1/embeddings"
|
||||
|
||||
// TODO: 以下是一些非标准的API名称,需要进一步确认是否支持
|
||||
ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank"
|
||||
|
||||
providerTypeMoonshot = "moonshot"
|
||||
providerTypeAzure = "azure"
|
||||
@@ -250,6 +262,12 @@ type ProviderConfig struct {
|
||||
inputVariable string `required:"false" yaml:"inputVariable" json:"inputVariable"`
|
||||
// @Title zh-CN dify中应用类型为workflow时需要设置输出变量,当botType为workflow时一起使用
|
||||
outputVariable string `required:"false" yaml:"outputVariable" json:"outputVariable"`
|
||||
// @Title zh-CN 额外支持的ai能力
|
||||
// @Description zh-CN 开放的ai能力和urlpath映射,例如: {"openai/v1/chatcompletions": "/v1/chat/completions"}
|
||||
capabilities map[string]string
|
||||
// @Title zh-CN 是否开启透传
|
||||
// @Description zh-CN 如果是插件不支持的API,是否透传请求, 默认为false
|
||||
passthrough bool
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetId() string {
|
||||
@@ -361,12 +379,22 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.botType = json.Get("botType").String()
|
||||
c.inputVariable = json.Get("inputVariable").String()
|
||||
c.outputVariable = json.Get("outputVariable").String()
|
||||
|
||||
c.capabilities = make(map[string]string)
|
||||
for capability, pathJson := range json.Get("capabilities").Map() {
|
||||
// 过滤掉不受支持的能力
|
||||
switch capability {
|
||||
case string(ApiNameChatCompletion),
|
||||
string(ApiNameEmbeddings),
|
||||
string(ApiNameImageGeneration),
|
||||
string(ApiNameAudioSpeech),
|
||||
string(ApiNameCohereV1Rerank):
|
||||
c.capabilities[capability] = pathJson.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) Validate() error {
|
||||
if c.timeout < 0 {
|
||||
return errors.New("invalid timeout in config")
|
||||
}
|
||||
if c.protocol != protocolOpenAI && c.protocol != protocolOriginal {
|
||||
return errors.New("invalid protocol in config")
|
||||
}
|
||||
@@ -425,6 +453,10 @@ func (c *ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) {
|
||||
return ReplaceByCustomSettings(body, c.customSettings)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) PassthroughUnsupportedAPI() bool {
|
||||
return c.passthrough
|
||||
}
|
||||
|
||||
func CreateProvider(pc ProviderConfig) (Provider, error) {
|
||||
initializer, has := providerInitializers[pc.typ]
|
||||
if !has {
|
||||
@@ -499,7 +531,7 @@ func getMappedModel(model string, modelMapping map[string]string, log wrapper.Lo
|
||||
}
|
||||
|
||||
func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
|
||||
if modelMapping == nil || len(modelMapping) == 0 {
|
||||
if len(modelMapping) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -527,11 +559,22 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) isSupportedAPI(apiName ApiName) bool {
|
||||
_, exist := c.capabilities[string(apiName)]
|
||||
return exist
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string) {
|
||||
for capability, path := range capabilities {
|
||||
c.capabilities[capability] = path
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) handleRequestBody(
|
||||
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log,
|
||||
) (types.Action, error) {
|
||||
// use original protocol
|
||||
if c.protocol == protocolOriginal {
|
||||
if c.IsOriginal() {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
@@ -578,17 +621,21 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt
|
||||
}
|
||||
}
|
||||
|
||||
// defaultTransformRequestBody 默认的请求体转换方法,只做模型映射,用slog替换模型名称,不用序列化和反序列化,提高性能
|
||||
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
var request interface{}
|
||||
if apiName == ApiNameChatCompletion {
|
||||
request = &chatCompletionRequest{}
|
||||
} else {
|
||||
request = &embeddingsRequest{}
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
stream := gjson.GetBytes(body, "stream").Bool()
|
||||
if stream {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
|
||||
ctx.SetContext(ctxKeyIsStreaming, true)
|
||||
} else {
|
||||
ctx.SetContext(ctxKeyIsStreaming, false)
|
||||
}
|
||||
}
|
||||
if err := c.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(request)
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping, log))
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {
|
||||
|
||||
@@ -52,7 +52,15 @@ func (m *qwenProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *qwenProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): qwenChatCompletionPath,
|
||||
string(ApiNameEmbeddings): qwenTextEmbeddingPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *qwenProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &qwenProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -75,18 +83,19 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
|
||||
if m.config.IsOriginal() {
|
||||
} else if m.config.qwenEnableCompatible {
|
||||
util.OverwriteRequestPathHeader(headers, qwenCompatiblePath)
|
||||
} else if apiName == ApiNameChatCompletion {
|
||||
util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath)
|
||||
} else if apiName == ApiNameEmbeddings {
|
||||
util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath)
|
||||
} else if apiName == ApiNameChatCompletion || apiName == ApiNameEmbeddings {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
if apiName == ApiNameChatCompletion {
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
return m.onChatCompletionRequestBody(ctx, body, headers, log)
|
||||
} else {
|
||||
case ApiNameEmbeddings:
|
||||
return m.onEmbeddingsRequestBody(ctx, body, log)
|
||||
default:
|
||||
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,7 +104,7 @@ func (m *qwenProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
|
||||
@@ -140,7 +149,7 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
@@ -278,6 +287,9 @@ func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName Ap
|
||||
if apiName == ApiNameEmbeddings {
|
||||
return m.onEmbeddingsResponseBody(ctx, body, log)
|
||||
}
|
||||
if m.config.isSupportedAPI(apiName) {
|
||||
return body, nil
|
||||
}
|
||||
return nil, errUnsupportedApiName
|
||||
}
|
||||
|
||||
|
||||
@@ -55,7 +55,14 @@ func (i *sparkProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *sparkProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): sparkChatCompletionPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *sparkProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(i.DefaultCapabilities())
|
||||
return &sparkProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -67,7 +74,7 @@ func (p *sparkProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !p.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
p.config.handleRequestHeaders(p, ctx, apiName, log)
|
||||
@@ -75,13 +82,16 @@ func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
|
||||
}
|
||||
|
||||
func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !p.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return body, nil
|
||||
}
|
||||
sparkResponse := &sparkResponse{}
|
||||
if err := json.Unmarshal(body, sparkResponse); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal spark response: %v", err)
|
||||
@@ -97,6 +107,9 @@ func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Ap
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if name != ApiNameChatCompletion {
|
||||
return chunk, nil
|
||||
}
|
||||
responseBuilder := &strings.Builder{}
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
for _, data := range lines {
|
||||
@@ -168,7 +181,7 @@ func (p *sparkProvider) appendResponse(responseBuilder *strings.Builder, respons
|
||||
}
|
||||
|
||||
func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), p.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, sparkHost)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx))
|
||||
}
|
||||
|
||||
@@ -24,7 +24,15 @@ func (m *stepfunProviderInitializer) ValidateConfig(config *ProviderConfig) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *stepfunProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
// stepfun的chat接口path和OpenAI的chat接口一样
|
||||
string(ApiNameChatCompletion): stepfunChatCompletionPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *stepfunProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &stepfunProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -41,7 +49,7 @@ func (m *stepfunProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
@@ -49,14 +57,14 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
|
||||
}
|
||||
|
||||
func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *stepfunProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, stepfunChatCompletionPath)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, stepfunDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
|
||||
@@ -2,11 +2,12 @@ package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -23,7 +24,14 @@ func (m *togetherAIProviderInitializer) ValidateConfig(config *ProviderConfig) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *togetherAIProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): togetherAICompletionPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *togetherAIProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &togetherAIProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -40,7 +48,7 @@ func (m *togetherAIProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
@@ -48,14 +56,14 @@ func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A
|
||||
}
|
||||
|
||||
func (m *togetherAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *togetherAIProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, togetherAICompletionPath)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, togetherAIDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
|
||||
@@ -24,7 +24,14 @@ func (m *yiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *yiProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): yiChatCompletionPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *yiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &yiProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -41,7 +48,7 @@ func (m *yiProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
@@ -49,14 +56,14 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName,
|
||||
}
|
||||
|
||||
func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *yiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, yiChatCompletionPath)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, yiDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
const (
|
||||
zhipuAiDomain = "open.bigmodel.cn"
|
||||
zhipuAiChatCompletionPath = "/api/paas/v4/chat/completions"
|
||||
zhipuAiEmbeddingsPath = "/api/paas/v4/embeddings"
|
||||
)
|
||||
|
||||
type zhipuAiProviderInitializer struct{}
|
||||
@@ -24,7 +25,15 @@ func (m *zhipuAiProviderInitializer) ValidateConfig(config *ProviderConfig) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *zhipuAiProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): zhipuAiChatCompletionPath,
|
||||
string(ApiNameEmbeddings): zhipuAiEmbeddingsPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *zhipuAiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &zhipuAiProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -41,7 +50,7 @@ func (m *zhipuAiProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
@@ -49,14 +58,14 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
|
||||
}
|
||||
|
||||
func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, zhipuAiChatCompletionPath)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, zhipuAiDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
@@ -66,5 +75,8 @@ func (m *zhipuAiProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, zhipuAiChatCompletionPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
if strings.Contains(path, zhipuAiEmbeddingsPath) {
|
||||
return ApiNameEmbeddings
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -57,6 +57,17 @@ func OverwriteRequestPathHeader(headers http.Header, path string) {
|
||||
headers.Set(":path", path)
|
||||
}
|
||||
|
||||
func OverwriteRequestPathHeaderByCapability(headers http.Header, apiName string, mapping map[string]string) {
|
||||
mappedPath, exist := mapping[apiName]
|
||||
if !exist {
|
||||
return
|
||||
}
|
||||
if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil {
|
||||
headers.Set("X-ENVOY-ORIGINAL-PATH", originPath)
|
||||
}
|
||||
headers.Set(":path", mappedPath)
|
||||
}
|
||||
|
||||
func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) {
|
||||
if exist := headers.Get("X-HI-ORIGINAL-AUTH"); exist == "" {
|
||||
if originAuth := headers.Get("Authorization"); originAuth != "" {
|
||||
|
||||
Reference in New Issue
Block a user