feature: allow ai-proxy to forward standard AI capabilities that are … (#1704)

This commit is contained in:
pepesi
2025-02-12 15:23:44 +08:00
committed by GitHub
parent 477e44b9f1
commit a84a382f1d
32 changed files with 517 additions and 158 deletions

View File

@@ -42,7 +42,8 @@ description: AI 代理插件配置参考
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 | | `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 | | `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 | | `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 |
| `capabilities` | map of string | 非必填 | - | 部分provider的部分ai能力原生兼容openai/v1格式不需要重写可以直接转发通过此配置项指定来开启转发, key表示的是采用的厂商协议能力values表示的真实的厂商该能力的api path, 厂商协议能力当前支持: openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, cohere/v1/rerank |
| `passthrough` | bool | 非必填 | - | 只要是不支持的API能力都直接转发, 此配置是capabilities配置的放大版本允许任意api透传就像没有ai-proxy插件一样 |
`context`的配置字段说明如下: `context`的配置字段说明如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |

View File

@@ -78,7 +78,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
rawPath := ctx.Path() rawPath := ctx.Path()
path, _ := url.Parse(rawPath) path, _ := url.Parse(rawPath)
apiName := getOpenAiApiName(path.Path) apiName := getApiName(path.Path)
providerConfig := pluginConfig.GetProviderConfig() providerConfig := pluginConfig.GetProviderConfig()
if providerConfig.IsOriginal() { if providerConfig.IsOriginal() {
if handler, ok := activeProvider.(provider.ApiNameHandler); ok { if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
@@ -103,9 +103,18 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
// Set the apiToken for the current request. // Set the apiToken for the current request.
providerConfig.SetApiTokenInUse(ctx, log) providerConfig.SetApiTokenInUse(ctx, log)
hasRequestBody := wrapper.HasRequestBody()
err := handler.OnRequestHeaders(ctx, apiName, log) err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil { if err != nil {
if providerConfig.PassthroughUnsupportedAPI() {
log.Warnf("[onHttpRequestHeader] passthrough unsupported API: %v", err)
ctx.DontReadRequestBody()
return types.ActionContinue
}
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
return types.ActionContinue
}
hasRequestBody := wrapper.HasRequestBody()
if hasRequestBody { if hasRequestBody {
proxywasm.RemoveHttpRequestHeader("Content-Length") proxywasm.RemoveHttpRequestHeader("Content-Length")
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
@@ -116,10 +125,6 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
return types.ActionContinue return types.ActionContinue
} }
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
return types.ActionContinue
}
return types.ActionContinue return types.ActionContinue
} }
@@ -151,6 +156,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
if err == nil { if err == nil {
return action return action
} }
if pluginConfig.GetProviderConfig().PassthroughUnsupportedAPI() {
log.Warnf("[onHttpRequestBody] passthrough unsupported API: %v", err)
return types.ActionContinue
}
util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err)) util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
} }
return types.ActionContinue return types.ActionContinue
@@ -267,12 +276,23 @@ func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
} }
} }
func getOpenAiApiName(path string) provider.ApiName { func getApiName(path string) provider.ApiName {
// openai style
if strings.HasSuffix(path, "/v1/chat/completions") { if strings.HasSuffix(path, "/v1/chat/completions") {
return provider.ApiNameChatCompletion return provider.ApiNameChatCompletion
} }
if strings.HasSuffix(path, "/v1/embeddings") { if strings.HasSuffix(path, "/v1/embeddings") {
return provider.ApiNameEmbeddings return provider.ApiNameEmbeddings
} }
if strings.HasSuffix(path, "/v1/audio/speech") {
return provider.ApiNameAudioSpeech
}
if strings.HasSuffix(path, "/v1/images/generations") {
return provider.ApiNameImageGeneration
}
// cohere style
if strings.HasSuffix(path, "/v1/rerank") {
return provider.ApiNameCohereV1Rerank
}
return "" return ""
} }

View File

@@ -22,6 +22,13 @@ type ai360Provider struct {
contextCache *contextCache contextCache *contextCache
} }
func (m *ai360ProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
}
func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error { func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
@@ -30,6 +37,7 @@ func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error
} }
func (m *ai360ProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *ai360ProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &ai360Provider{ return &ai360Provider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -41,7 +49,7 @@ func (m *ai360Provider) GetProviderType() string {
} }
func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -50,7 +58,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
} }
func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
@@ -58,5 +66,6 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, ai360Domain) util.OverwriteRequestHostHeader(headers, ai360Domain)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
} }

View File

@@ -15,6 +15,14 @@ import (
type azureProviderInitializer struct { type azureProviderInitializer struct {
} }
func (m *azureProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
// TODO: azure's pattern is the same as openai, just need to handle the prefix, can be done in TransformRequestHeaders to support general capabilities
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
}
func (m *azureProviderInitializer) ValidateConfig(config *ProviderConfig) error { func (m *azureProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.azureServiceUrl == "" { if config.azureServiceUrl == "" {
return errors.New("missing azureServiceUrl in provider config") return errors.New("missing azureServiceUrl in provider config")
@@ -35,6 +43,7 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid
} else { } else {
serviceUrl = u serviceUrl = u
} }
config.setDefaultCapabilities(m.DefaultCapabilities())
return &azureProvider{ return &azureProvider{
config: config, config: config,
serviceUrl: serviceUrl, serviceUrl: serviceUrl,
@@ -54,7 +63,7 @@ func (m *azureProvider) GetProviderType() string {
} }
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -62,7 +71,7 @@ func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
} }
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)

View File

@@ -13,7 +13,6 @@ import (
const ( const (
baichuanDomain = "api.baichuan-ai.com" baichuanDomain = "api.baichuan-ai.com"
baichuanChatCompletionPath = "/v1/chat/completions"
) )
type baichuanProviderInitializer struct { type baichuanProviderInitializer struct {
@@ -26,7 +25,15 @@ func (m *baichuanProviderInitializer) ValidateConfig(config *ProviderConfig) err
return nil return nil
} }
func (m *baichuanProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
}
func (m *baichuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *baichuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &baichuanProvider{ return &baichuanProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -43,7 +50,7 @@ func (m *baichuanProvider) GetProviderType() string {
} }
func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -51,14 +58,14 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
} }
func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
} }
func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, baichuanChatCompletionPath) util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, baichuanDomain) util.OverwriteRequestHostHeader(headers, baichuanDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length") headers.Del("Content-Length")

View File

@@ -14,6 +14,7 @@ import (
const ( const (
baiduDomain = "qianfan.baidubce.com" baiduDomain = "qianfan.baidubce.com"
baiduChatCompletionPath = "/v2/chat/completions" baiduChatCompletionPath = "/v2/chat/completions"
baiduEmbeddings = "/v2/embeddings"
) )
type baiduProviderInitializer struct{} type baiduProviderInitializer struct{}
@@ -25,7 +26,15 @@ func (g *baiduProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil return nil
} }
func (g *baiduProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): baiduChatCompletionPath,
string(ApiNameEmbeddings): baiduEmbeddings,
}
}
func (g *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (g *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(g.DefaultCapabilities())
return &baiduProvider{ return &baiduProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -42,7 +51,7 @@ func (g *baiduProvider) GetProviderType() string {
} }
func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !g.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
g.config.handleRequestHeaders(g, ctx, apiName, log) g.config.handleRequestHeaders(g, ctx, apiName, log)
@@ -50,14 +59,14 @@ func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
} }
func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
} }
func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, baiduChatCompletionPath) util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities)
util.OverwriteRequestHostHeader(headers, baiduDomain) util.OverwriteRequestHostHeader(headers, baiduDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length") headers.Del("Content-Length")

View File

@@ -85,7 +85,16 @@ func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil return nil
} }
func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): claudeChatCompletionPath,
// docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
}
func (c *claudeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (c *claudeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(c.DefaultCapabilities())
return &claudeProvider{ return &claudeProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -102,7 +111,7 @@ func (c *claudeProvider) GetProviderType() string {
} }
func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !c.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
c.config.handleRequestHeaders(c, ctx, apiName, log) c.config.handleRequestHeaders(c, ctx, apiName, log)
@@ -110,7 +119,7 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
} }
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath) util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities)
util.OverwriteRequestHostHeader(headers, claudeDomain) util.OverwriteRequestHostHeader(headers, claudeDomain)
headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx)) headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx))
@@ -123,13 +132,16 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
} }
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !c.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log) return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
} }
func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return c.config.defaultTransformRequestBody(ctx, apiName, body, log)
}
request := &chatCompletionRequest{} request := &chatCompletionRequest{}
if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err return nil, err
@@ -139,6 +151,9 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A
} }
func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
claudeResponse := &claudeTextGenResponse{} claudeResponse := &claudeTextGenResponse{}
if err := json.Unmarshal(body, claudeResponse); err != nil { if err := json.Unmarshal(body, claudeResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal claude response: %v", err) return nil, fmt.Errorf("unable to unmarshal claude response: %v", err)
@@ -154,6 +169,10 @@ func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
if isLastChunk || len(chunk) == 0 { if isLastChunk || len(chunk) == 0 {
return nil, nil return nil, nil
} }
// only process the response from chat completion, skip other responses
if name != ApiNameChatCompletion {
return chunk, nil
}
responseBuilder := &strings.Builder{} responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n") lines := strings.Split(string(chunk), "\n")

View File

@@ -25,8 +25,14 @@ func (c *cloudflareProviderInitializer) ValidateConfig(config *ProviderConfig) e
} }
return nil return nil
} }
func (c *cloudflareProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): cloudflareChatCompletionPath,
}
}
func (c *cloudflareProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (c *cloudflareProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(c.DefaultCapabilities())
return &cloudflareProvider{ return &cloudflareProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -43,7 +49,7 @@ func (c *cloudflareProvider) GetProviderType() string {
} }
func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !c.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
c.config.handleRequestHeaders(c, ctx, apiName, log) c.config.handleRequestHeaders(c, ctx, apiName, log)
@@ -51,7 +57,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A
} }
func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !c.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log) return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)

View File

@@ -13,7 +13,9 @@ import (
const ( const (
cohereDomain = "api.cohere.com" cohereDomain = "api.cohere.com"
// TODO: support more capabilities, upgrade to v2, docs: https://docs.cohere.com/v2/reference/chat
cohereChatCompletionPath = "/v1/chat" cohereChatCompletionPath = "/v1/chat"
cohereRerankPath = "/v1/rerank"
) )
type cohereProviderInitializer struct{} type cohereProviderInitializer struct{}
@@ -25,7 +27,15 @@ func (m *cohereProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil return nil
} }
func (m *cohereProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): cohereChatCompletionPath,
string(ApiNameCohereV1Rerank): cohereRerankPath,
}
}
func (m *cohereProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *cohereProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &cohereProvider{ return &cohereProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -56,7 +66,7 @@ func (m *cohereProvider) GetProviderType() string {
} }
func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -64,7 +74,7 @@ func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
} }
func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
@@ -90,13 +100,16 @@ func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohe
} }
func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, cohereChatCompletionPath) util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, cohereDomain) util.OverwriteRequestHostHeader(headers, cohereDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length") headers.Del("Content-Length")
} }
func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
}
request := &chatCompletionRequest{} request := &chatCompletionRequest{}
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err return nil, err

View File

@@ -21,7 +21,12 @@ func (m *cozeProviderInitializer) ValidateConfig(config *ProviderConfig) error {
return nil return nil
} }
func (m *cozeProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{}
}
func (m *cozeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *cozeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &cozeProvider{ return &cozeProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),

View File

@@ -64,7 +64,14 @@ func (d *deeplProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil return nil
} }
func (d *deeplProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): deeplChatCompletionPath,
}
}
func (d *deeplProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (d *deeplProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(d.DefaultCapabilities())
return &deeplProvider{ return &deeplProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -76,7 +83,7 @@ func (d *deeplProvider) GetProviderType() string {
} }
func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !d.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
d.config.handleRequestHeaders(d, ctx, apiName, log) d.config.handleRequestHeaders(d, ctx, apiName, log)
@@ -89,7 +96,7 @@ func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
} }
func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !d.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log) return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log)
@@ -112,6 +119,9 @@ func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, api
} }
func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
deeplResponse := &deeplResponse{} deeplResponse := &deeplResponse{}
if err := json.Unmarshal(body, deeplResponse); err != nil { if err := json.Unmarshal(body, deeplResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal deepl response: %v", err) return nil, fmt.Errorf("unable to unmarshal deepl response: %v", err)

View File

@@ -13,6 +13,8 @@ import (
const ( const (
deepseekDomain = "api.deepseek.com" deepseekDomain = "api.deepseek.com"
// TODO: docs: https://api-docs.deepseek.com/api/create-chat-completion
// accourding to the docs, the path should be /chat/completions, need to be verified
deepseekChatCompletionPath = "/v1/chat/completions" deepseekChatCompletionPath = "/v1/chat/completions"
) )
@@ -26,7 +28,14 @@ func (m *deepseekProviderInitializer) ValidateConfig(config *ProviderConfig) err
return nil return nil
} }
func (m *deepseekProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): deepseekChatCompletionPath,
}
}
func (m *deepseekProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *deepseekProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &deepseekProvider{ return &deepseekProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -43,7 +52,7 @@ func (m *deepseekProvider) GetProviderType() string {
} }
func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -51,14 +60,14 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
} }
func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
} }
func (m *deepseekProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *deepseekProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, deepseekChatCompletionPath) util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, deepseekDomain) util.OverwriteRequestHostHeader(headers, deepseekDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length") headers.Del("Content-Length")

View File

@@ -4,13 +4,14 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"strings"
"time"
) )
const ( const (
@@ -83,6 +84,9 @@ func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
} }
func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return d.config.defaultTransformRequestBody(ctx, apiName, body, log)
}
request := &chatCompletionRequest{} request := &chatCompletionRequest{}
err := d.config.parseRequestAndMapModel(ctx, request, body, log) err := d.config.parseRequestAndMapModel(ctx, request, body, log)
if err != nil { if err != nil {
@@ -95,6 +99,9 @@ func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiN
} }
func (d *difyProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { func (d *difyProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
difyResponse := &DifyChatResponse{} difyResponse := &DifyChatResponse{}
if err := json.Unmarshal(body, difyResponse); err != nil { if err := json.Unmarshal(body, difyResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal dify response: %v", err) return nil, fmt.Errorf("unable to unmarshal dify response: %v", err)
@@ -146,6 +153,9 @@ func (d *difyProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
if isLastChunk || len(chunk) == 0 { if isLastChunk || len(chunk) == 0 {
return nil, nil return nil, nil
} }
if name != ApiNameChatCompletion {
return chunk, nil
}
// sample event response: // sample event response:
// data: {"event": "agent_thought", "id": "8dcf3648-fbad-407a-85dd-73a6f43aeb9f", "task_id": "9cf1ddd7-f94b-459b-b942-b77b26c59e9b", "message_id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "position": 1, "thought": "", "observation": "", "tool": "", "tool_input": "", "created_at": 1705639511, "message_files": [], "conversation_id": "c216c595-2d89-438c-b33c-aae5ddddd142"} // data: {"event": "agent_thought", "id": "8dcf3648-fbad-407a-85dd-73a6f43aeb9f", "task_id": "9cf1ddd7-f94b-459b-b942-b77b26c59e9b", "message_id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "position": 1, "thought": "", "observation": "", "tool": "", "tool_input": "", "created_at": 1705639511, "message_files": [], "conversation_id": "c216c595-2d89-438c-b33c-aae5ddddd142"}

View File

@@ -13,6 +13,7 @@ import (
const ( const (
doubaoDomain = "ark.cn-beijing.volces.com" doubaoDomain = "ark.cn-beijing.volces.com"
doubaoChatCompletionPath = "/api/v3/chat/completions" doubaoChatCompletionPath = "/api/v3/chat/completions"
doubaoEmbeddingsPath = "/api/v3/embeddings"
) )
type doubaoProviderInitializer struct{} type doubaoProviderInitializer struct{}
@@ -24,7 +25,15 @@ func (m *doubaoProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil return nil
} }
func (m *doubaoProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): doubaoChatCompletionPath,
string(ApiNameEmbeddings): doubaoEmbeddingsPath,
}
}
func (m *doubaoProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *doubaoProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &doubaoProvider{ return &doubaoProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -41,7 +50,7 @@ func (m *doubaoProvider) GetProviderType() string {
} }
func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -49,14 +58,14 @@ func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
} }
func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
} }
func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, doubaoChatCompletionPath) util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, doubaoDomain) util.OverwriteRequestHostHeader(headers, doubaoDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length") headers.Del("Content-Length")
@@ -66,5 +75,8 @@ func (m *doubaoProvider) GetApiName(path string) ApiName {
if strings.Contains(path, doubaoChatCompletionPath) { if strings.Contains(path, doubaoChatCompletionPath) {
return ApiNameChatCompletion return ApiNameChatCompletion
} }
if strings.Contains(path, doubaoEmbeddingsPath) {
return ApiNameEmbeddings
}
return "" return ""
} }

View File

@@ -35,7 +35,12 @@ func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil return nil
} }
func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{}
}
func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(g.DefaultCapabilities())
return &geminiProvider{ return &geminiProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -52,7 +57,7 @@ func (g *geminiProvider) GetProviderType() string {
} }
func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { if !g.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
g.config.handleRequestHeaders(g, ctx, apiName, log) g.config.handleRequestHeaders(g, ctx, apiName, log)
@@ -66,7 +71,7 @@ func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
} }
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
@@ -110,6 +115,9 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
if isLastChunk || len(chunk) == 0 { if isLastChunk || len(chunk) == 0 {
return nil, nil return nil, nil
} }
if name != ApiNameChatCompletion {
return chunk, nil
}
// sample end event response: // sample end event response:
// data: {"candidates": [{"content": {"parts": [{"text": "我是 Gemini一个大型多模态模型由 Google 训练。我的职责是尽我所能帮助您,并尽力提供全面且信息丰富的答复。"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 2,"candidatesTokenCount": 35,"totalTokenCount": 37}} // data: {"candidates": [{"content": {"parts": [{"text": "我是 Gemini一个大型多模态模型由 Google 训练。我的职责是尽我所能帮助您,并尽力提供全面且信息丰富的答复。"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 2,"candidatesTokenCount": 35,"totalTokenCount": 37}}
responseBuilder := &strings.Builder{} responseBuilder := &strings.Builder{}

View File

@@ -32,7 +32,15 @@ func (m *githubProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil return nil
} }
func (m *githubProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): githubCompletionPath,
string(ApiNameEmbeddings): githubEmbeddingPath,
}
}
func (m *githubProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *githubProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &githubProvider{ return &githubProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -44,7 +52,7 @@ func (m *githubProvider) GetProviderType() string {
} }
func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -53,7 +61,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
} }
func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
@@ -61,12 +69,7 @@ func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, githubDomain) util.OverwriteRequestHostHeader(headers, githubDomain)
if apiName == ApiNameChatCompletion { util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestPathHeader(headers, githubCompletionPath)
}
if apiName == ApiNameEmbeddings {
util.OverwriteRequestPathHeader(headers, githubEmbeddingPath)
}
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
} }

View File

@@ -25,7 +25,14 @@ func (g *groqProviderInitializer) ValidateConfig(config *ProviderConfig) error {
return nil return nil
} }
func (g *groqProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): groqChatCompletionPath,
}
}
func (g *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (g *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(g.DefaultCapabilities())
return &groqProvider{ return &groqProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -42,7 +49,7 @@ func (g *groqProvider) GetProviderType() string {
} }
func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !g.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
g.config.handleRequestHeaders(g, ctx, apiName, log) g.config.handleRequestHeaders(g, ctx, apiName, log)
@@ -50,14 +57,14 @@ func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName
} }
func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
} }
func (g *groqProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (g *groqProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, groqChatCompletionPath) util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities)
util.OverwriteRequestHostHeader(headers, groqDomain) util.OverwriteRequestHostHeader(headers, groqDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length") headers.Del("Content-Length")

View File

@@ -39,6 +39,11 @@ const (
hunyuanAuthKeyLen = 32 hunyuanAuthKeyLen = 32
hunyuanAuthIdLen = 36 hunyuanAuthIdLen = 36
// docs: https://cloud.tencent.com/document/product/1729/111007
hunyuanOpenAiDomain = "api.hunyuan.cloud.tencent.com"
hunyuanOpenAiRequestPath = "/v1/chat/completions"
hunyuanOpenAiEmbeddings = "/v1/embeddings"
) )
type hunyuanProviderInitializer struct { type hunyuanProviderInitializer struct {
@@ -86,6 +91,10 @@ type hunyuanChatMessage struct {
} }
func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) error { func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) error {
// 允许 hunyuanauthid 和 hunyuanauthkey 为空, 当他们都为空的时候认为是使用openai的 兼容接口
if len(config.hunyuanAuthId) == 0 && len(config.hunyuanAuthKey) == 0 {
return nil
}
// 校验hunyuan id 和 key的合法性 // 校验hunyuan id 和 key的合法性
if len(config.hunyuanAuthId) != hunyuanAuthIdLen || len(config.hunyuanAuthKey) != hunyuanAuthKeyLen { if len(config.hunyuanAuthId) != hunyuanAuthIdLen || len(config.hunyuanAuthKey) != hunyuanAuthKeyLen {
return errors.New("hunyuanAuthId / hunyuanAuthKey is illegal in config file") return errors.New("hunyuanAuthId / hunyuanAuthKey is illegal in config file")
@@ -93,7 +102,15 @@ func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) erro
return nil return nil
} }
func (m *hunyuanProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): hunyuanOpenAiRequestPath,
string(ApiNameEmbeddings): hunyuanOpenAiEmbeddings,
}
}
func (m *hunyuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *hunyuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &hunyuanProvider{ return &hunyuanProvider{
config: config, config: config,
client: wrapper.NewClusterClient(wrapper.RouteCluster{ client: wrapper.NewClusterClient(wrapper.RouteCluster{
@@ -114,8 +131,12 @@ func (m *hunyuanProvider) GetProviderType() string {
return providerTypeHunyuan return providerTypeHunyuan
} }
func (m *hunyuanProvider) useOpenAICompatibleAPI() bool {
return len(m.config.hunyuanAuthId) == 0 && len(m.config.hunyuanAuthKey) == 0
}
func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -124,19 +145,27 @@ func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
} }
func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
if m.useOpenAICompatibleAPI() {
util.OverwriteRequestHostHeader(headers, hunyuanOpenAiDomain)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
} else {
util.OverwriteRequestHostHeader(headers, hunyuanDomain) util.OverwriteRequestHostHeader(headers, hunyuanDomain)
util.OverwriteRequestPathHeader(headers, hunyuanRequestPath) util.OverwriteRequestPathHeader(headers, hunyuanRequestPath)
// 添加 hunyuan 需要的自定义字段 // 添加 hunyuan 需要的自定义字段
headers.Set(actionKey, hunyuanChatCompletionTCAction) headers.Set(actionKey, hunyuanChatCompletionTCAction)
headers.Set(versionKey, versionValue) headers.Set(versionKey, versionValue)
}
} }
// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法 // hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法
func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
if m.useOpenAICompatibleAPI() {
return types.ActionContinue, nil
}
// 为header添加时间戳字段 因为需要根据body进行签名时依赖时间戳故于body处理部分创建时间戳 // 为header添加时间戳字段 因为需要根据body进行签名时依赖时间戳故于body处理部分创建时间戳
var timestamp int64 = time.Now().Unix() var timestamp int64 = time.Now().Unix()
@@ -264,6 +293,9 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
// hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用 // hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用
func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
if m.useOpenAICompatibleAPI() {
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
}
request := &chatCompletionRequest{} request := &chatCompletionRequest{}
err := m.config.parseRequestAndMapModel(ctx, request, body, log) err := m.config.parseRequestAndMapModel(ctx, request, body, log)
if err != nil { if err != nil {
@@ -289,7 +321,7 @@ func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a
} }
func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if m.config.protocol == protocolOriginal { if m.config.IsOriginal() || m.useOpenAICompatibleAPI() || name != ApiNameChatCompletion {
return chunk, nil return chunk, nil
} }
@@ -405,6 +437,12 @@ func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContex
} }
func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if m.config.IsOriginal() || m.useOpenAICompatibleAPI() {
return body, nil
}
if apiName != ApiNameChatCompletion {
return body, nil
}
log.Debugf("#debug nash5# onRespBody's resp is: %s", string(body)) log.Debugf("#debug nash5# onRespBody's resp is: %s", string(body))
hunyuanResponse := &hunyuanTextGenResponseNonStreaming{} hunyuanResponse := &hunyuanTextGenResponseNonStreaming{}
if err := json.Unmarshal(body, hunyuanResponse); err != nil { if err := json.Unmarshal(body, hunyuanResponse); err != nil {

View File

@@ -41,7 +41,7 @@ type minimaxProviderInitializer struct {
func (m *minimaxProviderInitializer) ValidateConfig(config *ProviderConfig) error { func (m *minimaxProviderInitializer) ValidateConfig(config *ProviderConfig) error {
// If using the chat completion Pro API, a group ID must be set. // If using the chat completion Pro API, a group ID must be set.
if minimaxApiTypePro == config.minimaxApiType && config.minimaxGroupId == "" { if minimaxApiTypePro == config.minimaxApiType && config.minimaxGroupId == "" {
return errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when minimaxApiType is %s", minimaxApiTypePro)) return fmt.Errorf("missing minimaxGroupId in provider config when minimaxApiType is %s", minimaxApiTypePro)
} }
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
@@ -49,7 +49,15 @@ func (m *minimaxProviderInitializer) ValidateConfig(config *ProviderConfig) erro
return nil return nil
} }
func (m *minimaxProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
// minimax 替换path的时候要根据modelmapping替换这儿的配置无实质作用只是为了保持和其他provider的一致性
string(ApiNameChatCompletion): minimaxChatCompletionV2Path,
}
}
func (m *minimaxProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *minimaxProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &minimaxProvider{ return &minimaxProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -66,7 +74,7 @@ func (m *minimaxProvider) GetProviderType() string {
} }
func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -81,7 +89,7 @@ func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
} }
func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
if minimaxApiTypePro == m.config.minimaxApiType { if minimaxApiTypePro == m.config.minimaxApiType {
@@ -159,6 +167,9 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
if isLastChunk || len(chunk) == 0 { if isLastChunk || len(chunk) == 0 {
return nil, nil return nil, nil
} }
if name != ApiNameChatCompletion {
return chunk, nil
}
// Sample event response: // Sample event response:
// data: {"created":1689747645,"model":"abab6.5s-chat","reply":"","choices":[{"messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"am from China."}]}],"output_sensitive":false} // data: {"created":1689747645,"model":"abab6.5s-chat","reply":"","choices":[{"messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"am from China."}]}],"output_sensitive":false}
@@ -192,6 +203,9 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
// TransformResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API. // TransformResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
func (m *minimaxProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { func (m *minimaxProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
minimaxResp := &minimaxChatCompletionProResp{} minimaxResp := &minimaxChatCompletionProResp{}
if err := json.Unmarshal(body, minimaxResp); err != nil { if err := json.Unmarshal(body, minimaxResp); err != nil {
return nil, fmt.Errorf("unable to unmarshal minimax response: %v", err) return nil, fmt.Errorf("unable to unmarshal minimax response: %v", err)
@@ -268,18 +282,6 @@ type minimaxUsage struct {
CompletionTokens int64 `json:"completion_tokens"` CompletionTokens int64 `json:"completion_tokens"`
} }
func (m *minimaxProvider) parseModel(body []byte) (string, error) {
var tempMap map[string]interface{}
if err := json.Unmarshal(body, &tempMap); err != nil {
return "", err
}
model, ok := tempMap["model"].(string)
if !ok {
return "", errors.New("missing model in chat completion request")
}
return model, nil
}
func (m *minimaxProvider) setBotSettings(request *minimaxChatCompletionProRequest, botSettingContent string) { func (m *minimaxProvider) setBotSettings(request *minimaxChatCompletionProRequest, botSettingContent string) {
if len(request.BotSettings) == 0 { if len(request.BotSettings) == 0 {
request.BotSettings = []minimaxBotSetting{ request.BotSettings = []minimaxBotSetting{

View File

@@ -22,7 +22,16 @@ func (m *mistralProviderInitializer) ValidateConfig(config *ProviderConfig) erro
return nil return nil
} }
func (m *mistralProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
// The chat interface of mistral is the same as that of OpenAI. docs: https://docs.mistral.ai/api/
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
}
func (m *mistralProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *mistralProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &mistralProvider{ return &mistralProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -39,7 +48,7 @@ func (m *mistralProvider) GetProviderType() string {
} }
func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -47,7 +56,7 @@ func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
} }
func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)

View File

@@ -230,6 +230,21 @@ func (e *streamEvent) setValue(key, value string) {
} }
} }
// https://platform.openai.com/docs/guides/images
type imageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
}
// https://platform.openai.com/docs/guides/speech-to-text
type audioSpeechRequest struct {
Model string `json:"model"`
Input string `json:"input"`
Voice string `json:"voice"`
}
type embeddingsRequest struct { type embeddingsRequest struct {
Input interface{} `json:"input"` Input interface{} `json:"input"`
Model string `json:"model"` Model string `json:"model"`

View File

@@ -34,7 +34,14 @@ func (m *moonshotProviderInitializer) ValidateConfig(config *ProviderConfig) err
return nil return nil
} }
func (m *moonshotProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): moonshotChatCompletionPath,
}
}
func (m *moonshotProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *moonshotProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &moonshotProvider{ return &moonshotProvider{
config: config, config: config,
client: wrapper.NewClusterClient(wrapper.RouteCluster{ client: wrapper.NewClusterClient(wrapper.RouteCluster{
@@ -57,7 +64,7 @@ func (m *moonshotProvider) GetProviderType() string {
} }
func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -65,7 +72,7 @@ func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
} }
func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, moonshotChatCompletionPath) util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, moonshotDomain) util.OverwriteRequestHostHeader(headers, moonshotDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length") headers.Del("Content-Length")
@@ -74,9 +81,13 @@ func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiN
// moonshot 有自己获取 context 的配置moonshotFileId因此无法复用 handleRequestBody 方法 // moonshot 有自己获取 context 的配置moonshotFileId因此无法复用 handleRequestBody 方法
// moonshot 的 body 没有修改无须实现TransformRequestBody使用默认的 defaultTransformRequestBody 方法 // moonshot 的 body 没有修改无须实现TransformRequestBody使用默认的 defaultTransformRequestBody 方法
func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
// 非chat类型的请求不做处理
if apiName != ApiNameChatCompletion {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{} request := &chatCompletionRequest{}
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
@@ -154,6 +165,9 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba
} }
func (m *moonshotProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { func (m *moonshotProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if name != ApiNameChatCompletion {
return chunk, nil
}
receivedBody := chunk receivedBody := chunk
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has { if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
receivedBody = append(bufferedStreamingBody, chunk...) receivedBody = append(bufferedStreamingBody, chunk...)

View File

@@ -12,10 +12,6 @@ import (
// ollamaProvider is the provider for Ollama service. // ollamaProvider is the provider for Ollama service.
const (
ollamaChatCompletionPath = "/v1/chat/completions"
)
type ollamaProviderInitializer struct { type ollamaProviderInitializer struct {
} }
@@ -29,9 +25,17 @@ func (m *ollamaProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil return nil
} }
func (m *ollamaProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
// ollama的chat接口path和OpenAI的chat接口一样
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
}
}
func (m *ollamaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *ollamaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
serverPortStr := fmt.Sprintf("%d", config.ollamaServerPort) serverPortStr := fmt.Sprintf("%d", config.ollamaServerPort)
serviceDomain := config.ollamaServerHost + ":" + serverPortStr serviceDomain := config.ollamaServerHost + ":" + serverPortStr
config.setDefaultCapabilities(m.DefaultCapabilities())
return &ollamaProvider{ return &ollamaProvider{
config: config, config: config,
serviceDomain: serviceDomain, serviceDomain: serviceDomain,
@@ -50,7 +54,7 @@ func (m *ollamaProvider) GetProviderType() string {
} }
func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -58,14 +62,14 @@ func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
} }
func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
} }
func (m *ollamaProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *ollamaProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, ollamaChatCompletionPath) util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, m.serviceDomain) util.OverwriteRequestHostHeader(headers, m.serviceDomain)
headers.Del("Content-Length") headers.Del("Content-Length")
} }

View File

@@ -26,6 +26,13 @@ func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil return nil
} }
func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
}
}
func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
if config.openaiCustomUrl == "" { if config.openaiCustomUrl == "" {
return &openaiProvider{ return &openaiProvider{
@@ -38,6 +45,7 @@ func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provi
if len(pairs) != 2 { if len(pairs) != 2 {
return nil, fmt.Errorf("invalid openaiCustomUrl:%s", config.openaiCustomUrl) return nil, fmt.Errorf("invalid openaiCustomUrl:%s", config.openaiCustomUrl)
} }
config.setDefaultCapabilities(m.DefaultCapabilities())
return &openaiProvider{ return &openaiProvider{
config: config, config: config,
customDomain: pairs[0], customDomain: pairs[0],
@@ -64,13 +72,7 @@ func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
if m.customPath == "" { if m.customPath == "" {
switch apiName { util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
case ApiNameChatCompletion:
util.OverwriteRequestPathHeader(headers, defaultOpenaiChatCompletionPath)
case ApiNameEmbeddings:
ctx.DontReadRequestBody()
util.OverwriteRequestPathHeader(headers, defaultOpenaiEmbeddingsPath)
}
} else { } else {
util.OverwriteRequestPathHeader(headers, m.customPath) util.OverwriteRequestPathHeader(headers, m.customPath)
} }

View File

@@ -1,7 +1,6 @@
package provider package provider
import ( import (
"encoding/json"
"errors" "errors"
"math/rand" "math/rand"
"net/http" "net/http"
@@ -12,14 +11,27 @@ import (
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson"
) )
type ApiName string type ApiName string
type Pointcut string type Pointcut string
const ( const (
ApiNameChatCompletion ApiName = "chatCompletion"
ApiNameEmbeddings ApiName = "embeddings" // ApiName 格式 {vendor}/{version}/{apitype}
// 表示遵循 厂商/版本/接口类型 的格式
// 目前openai是事实意义上的标准但是也有其他厂商存在其他任务的一些可能的标准比如cohere的rerank
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
PathOpenAIChatCompletions = "/v1/chat/completions"
PathOpenAIEmbeddings = "/v1/embeddings"
// TODO: 以下是一些非标准的API名称需要进一步确认是否支持
ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank"
providerTypeMoonshot = "moonshot" providerTypeMoonshot = "moonshot"
providerTypeAzure = "azure" providerTypeAzure = "azure"
@@ -250,6 +262,12 @@ type ProviderConfig struct {
inputVariable string `required:"false" yaml:"inputVariable" json:"inputVariable"` inputVariable string `required:"false" yaml:"inputVariable" json:"inputVariable"`
// @Title zh-CN dify中应用类型为workflow时需要设置输出变量当botType为workflow时一起使用 // @Title zh-CN dify中应用类型为workflow时需要设置输出变量当botType为workflow时一起使用
outputVariable string `required:"false" yaml:"outputVariable" json:"outputVariable"` outputVariable string `required:"false" yaml:"outputVariable" json:"outputVariable"`
// @Title zh-CN 额外支持的ai能力
// @Description zh-CN 开放的ai能力和urlpath映射例如 {"openai/v1/chatcompletions": "/v1/chat/completions"}
capabilities map[string]string
// @Title zh-CN 是否开启透传
// @Description zh-CN 如果是插件不支持的API是否透传请求, 默认为false
passthrough bool
} }
func (c *ProviderConfig) GetId() string { func (c *ProviderConfig) GetId() string {
@@ -361,12 +379,22 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.botType = json.Get("botType").String() c.botType = json.Get("botType").String()
c.inputVariable = json.Get("inputVariable").String() c.inputVariable = json.Get("inputVariable").String()
c.outputVariable = json.Get("outputVariable").String() c.outputVariable = json.Get("outputVariable").String()
c.capabilities = make(map[string]string)
for capability, pathJson := range json.Get("capabilities").Map() {
// 过滤掉不受支持的能力
switch capability {
case string(ApiNameChatCompletion),
string(ApiNameEmbeddings),
string(ApiNameImageGeneration),
string(ApiNameAudioSpeech),
string(ApiNameCohereV1Rerank):
c.capabilities[capability] = pathJson.String()
}
}
} }
func (c *ProviderConfig) Validate() error { func (c *ProviderConfig) Validate() error {
if c.timeout < 0 {
return errors.New("invalid timeout in config")
}
if c.protocol != protocolOpenAI && c.protocol != protocolOriginal { if c.protocol != protocolOpenAI && c.protocol != protocolOriginal {
return errors.New("invalid protocol in config") return errors.New("invalid protocol in config")
} }
@@ -425,6 +453,10 @@ func (c *ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) {
return ReplaceByCustomSettings(body, c.customSettings) return ReplaceByCustomSettings(body, c.customSettings)
} }
func (c *ProviderConfig) PassthroughUnsupportedAPI() bool {
return c.passthrough
}
func CreateProvider(pc ProviderConfig) (Provider, error) { func CreateProvider(pc ProviderConfig) (Provider, error) {
initializer, has := providerInitializers[pc.typ] initializer, has := providerInitializers[pc.typ]
if !has { if !has {
@@ -499,7 +531,7 @@ func getMappedModel(model string, modelMapping map[string]string, log wrapper.Lo
} }
func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string { func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
if modelMapping == nil || len(modelMapping) == 0 { if len(modelMapping) == 0 {
return "" return ""
} }
@@ -527,11 +559,22 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
return "" return ""
} }
func (c *ProviderConfig) isSupportedAPI(apiName ApiName) bool {
_, exist := c.capabilities[string(apiName)]
return exist
}
func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string) {
for capability, path := range capabilities {
c.capabilities[capability] = path
}
}
func (c *ProviderConfig) handleRequestBody( func (c *ProviderConfig) handleRequestBody(
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log, provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log,
) (types.Action, error) { ) (types.Action, error) {
// use original protocol // use original protocol
if c.protocol == protocolOriginal { if c.IsOriginal() {
return types.ActionContinue, nil return types.ActionContinue, nil
} }
@@ -578,17 +621,21 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt
} }
} }
// defaultTransformRequestBody 默认的请求体转换方法只做模型映射用slog替换模型名称不用序列化和反序列化提高性能
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
var request interface{} switch apiName {
if apiName == ApiNameChatCompletion { case ApiNameChatCompletion:
request = &chatCompletionRequest{} stream := gjson.GetBytes(body, "stream").Bool()
if stream {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
ctx.SetContext(ctxKeyIsStreaming, true)
} else { } else {
request = &embeddingsRequest{} ctx.SetContext(ctxKeyIsStreaming, false)
} }
if err := c.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
} }
return json.Marshal(request) model := gjson.GetBytes(body, "model").String()
ctx.SetContext(ctxKeyOriginalRequestModel, model)
return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping, log))
} }
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) { func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {

View File

@@ -52,7 +52,15 @@ func (m *qwenProviderInitializer) ValidateConfig(config *ProviderConfig) error {
return nil return nil
} }
func (m *qwenProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): qwenChatCompletionPath,
string(ApiNameEmbeddings): qwenTextEmbeddingPath,
}
}
func (m *qwenProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *qwenProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &qwenProvider{ return &qwenProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -75,18 +83,19 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
if m.config.IsOriginal() { if m.config.IsOriginal() {
} else if m.config.qwenEnableCompatible { } else if m.config.qwenEnableCompatible {
util.OverwriteRequestPathHeader(headers, qwenCompatiblePath) util.OverwriteRequestPathHeader(headers, qwenCompatiblePath)
} else if apiName == ApiNameChatCompletion { } else if apiName == ApiNameChatCompletion || apiName == ApiNameEmbeddings {
util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath) util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
} else if apiName == ApiNameEmbeddings {
util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath)
} }
} }
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
if apiName == ApiNameChatCompletion { switch apiName {
case ApiNameChatCompletion:
return m.onChatCompletionRequestBody(ctx, body, headers, log) return m.onChatCompletionRequestBody(ctx, body, headers, log)
} else { case ApiNameEmbeddings:
return m.onEmbeddingsRequestBody(ctx, body, log) return m.onEmbeddingsRequestBody(ctx, body, log)
default:
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
} }
} }
@@ -95,7 +104,7 @@ func (m *qwenProvider) GetProviderType() string {
} }
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
@@ -140,7 +149,7 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
return types.ActionContinue, nil return types.ActionContinue, nil
} }
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
@@ -278,6 +287,9 @@ func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName Ap
if apiName == ApiNameEmbeddings { if apiName == ApiNameEmbeddings {
return m.onEmbeddingsResponseBody(ctx, body, log) return m.onEmbeddingsResponseBody(ctx, body, log)
} }
if m.config.isSupportedAPI(apiName) {
return body, nil
}
return nil, errUnsupportedApiName return nil, errUnsupportedApiName
} }

View File

@@ -55,7 +55,14 @@ func (i *sparkProviderInitializer) ValidateConfig(config *ProviderConfig) error
return nil return nil
} }
func (i *sparkProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): sparkChatCompletionPath,
}
}
func (i *sparkProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (i *sparkProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(i.DefaultCapabilities())
return &sparkProvider{ return &sparkProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -67,7 +74,7 @@ func (p *sparkProvider) GetProviderType() string {
} }
func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !p.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
p.config.handleRequestHeaders(p, ctx, apiName, log) p.config.handleRequestHeaders(p, ctx, apiName, log)
@@ -75,13 +82,16 @@ func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
} }
func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !p.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log) return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log)
} }
func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
sparkResponse := &sparkResponse{} sparkResponse := &sparkResponse{}
if err := json.Unmarshal(body, sparkResponse); err != nil { if err := json.Unmarshal(body, sparkResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal spark response: %v", err) return nil, fmt.Errorf("unable to unmarshal spark response: %v", err)
@@ -97,6 +107,9 @@ func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Ap
if isLastChunk || len(chunk) == 0 { if isLastChunk || len(chunk) == 0 {
return nil, nil return nil, nil
} }
if name != ApiNameChatCompletion {
return chunk, nil
}
responseBuilder := &strings.Builder{} responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n") lines := strings.Split(string(chunk), "\n")
for _, data := range lines { for _, data := range lines {
@@ -168,7 +181,7 @@ func (p *sparkProvider) appendResponse(responseBuilder *strings.Builder, respons
} }
func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath) util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), p.config.capabilities)
util.OverwriteRequestHostHeader(headers, sparkHost) util.OverwriteRequestHostHeader(headers, sparkHost)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx))
} }

View File

@@ -24,7 +24,15 @@ func (m *stepfunProviderInitializer) ValidateConfig(config *ProviderConfig) erro
return nil return nil
} }
func (m *stepfunProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
// stepfun的chat接口path和OpenAI的chat接口一样
string(ApiNameChatCompletion): stepfunChatCompletionPath,
}
}
func (m *stepfunProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *stepfunProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &stepfunProvider{ return &stepfunProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -41,7 +49,7 @@ func (m *stepfunProvider) GetProviderType() string {
} }
func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -49,14 +57,14 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
} }
func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
} }
func (m *stepfunProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *stepfunProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, stepfunChatCompletionPath) util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, stepfunDomain) util.OverwriteRequestHostHeader(headers, stepfunDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length") headers.Del("Content-Length")

View File

@@ -2,11 +2,12 @@ package provider
import ( import (
"errors" "errors"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"strings"
) )
const ( const (
@@ -23,7 +24,14 @@ func (m *togetherAIProviderInitializer) ValidateConfig(config *ProviderConfig) e
return nil return nil
} }
func (m *togetherAIProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): togetherAICompletionPath,
}
}
func (m *togetherAIProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *togetherAIProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &togetherAIProvider{ return &togetherAIProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -40,7 +48,7 @@ func (m *togetherAIProvider) GetProviderType() string {
} }
func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -48,14 +56,14 @@ func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A
} }
func (m *togetherAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *togetherAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
} }
func (m *togetherAIProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *togetherAIProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, togetherAICompletionPath) util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, togetherAIDomain) util.OverwriteRequestHostHeader(headers, togetherAIDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length") headers.Del("Content-Length")

View File

@@ -24,7 +24,14 @@ func (m *yiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
return nil return nil
} }
func (m *yiProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): yiChatCompletionPath,
}
}
func (m *yiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *yiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &yiProvider{ return &yiProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -41,7 +48,7 @@ func (m *yiProvider) GetProviderType() string {
} }
func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -49,14 +56,14 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName,
} }
func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
} }
func (m *yiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *yiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, yiChatCompletionPath) util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, yiDomain) util.OverwriteRequestHostHeader(headers, yiDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length") headers.Del("Content-Length")

View File

@@ -13,6 +13,7 @@ import (
const ( const (
zhipuAiDomain = "open.bigmodel.cn" zhipuAiDomain = "open.bigmodel.cn"
zhipuAiChatCompletionPath = "/api/paas/v4/chat/completions" zhipuAiChatCompletionPath = "/api/paas/v4/chat/completions"
zhipuAiEmbeddingsPath = "/api/paas/v4/embeddings"
) )
type zhipuAiProviderInitializer struct{} type zhipuAiProviderInitializer struct{}
@@ -24,7 +25,15 @@ func (m *zhipuAiProviderInitializer) ValidateConfig(config *ProviderConfig) erro
return nil return nil
} }
func (m *zhipuAiProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): zhipuAiChatCompletionPath,
string(ApiNameEmbeddings): zhipuAiEmbeddingsPath,
}
}
func (m *zhipuAiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *zhipuAiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &zhipuAiProvider{ return &zhipuAiProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -41,7 +50,7 @@ func (m *zhipuAiProvider) GetProviderType() string {
} }
func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -49,14 +58,14 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
} }
func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
} }
func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, zhipuAiChatCompletionPath) util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, zhipuAiDomain) util.OverwriteRequestHostHeader(headers, zhipuAiDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length") headers.Del("Content-Length")
@@ -66,5 +75,8 @@ func (m *zhipuAiProvider) GetApiName(path string) ApiName {
if strings.Contains(path, zhipuAiChatCompletionPath) { if strings.Contains(path, zhipuAiChatCompletionPath) {
return ApiNameChatCompletion return ApiNameChatCompletion
} }
if strings.Contains(path, zhipuAiEmbeddingsPath) {
return ApiNameEmbeddings
}
return "" return ""
} }

View File

@@ -57,6 +57,17 @@ func OverwriteRequestPathHeader(headers http.Header, path string) {
headers.Set(":path", path) headers.Set(":path", path)
} }
func OverwriteRequestPathHeaderByCapability(headers http.Header, apiName string, mapping map[string]string) {
mappedPath, exist := mapping[apiName]
if !exist {
return
}
if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil {
headers.Set("X-ENVOY-ORIGINAL-PATH", originPath)
}
headers.Set(":path", mappedPath)
}
func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) { func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) {
if exist := headers.Get("X-HI-ORIGINAL-AUTH"); exist == "" { if exist := headers.Get("X-HI-ORIGINAL-AUTH"); exist == "" {
if originAuth := headers.Get("Authorization"); originAuth != "" { if originAuth := headers.Get("Authorization"); originAuth != "" {