feat: implement apiToken failover mechanism (#1256)

This commit is contained in:
Se7en
2024-11-16 19:03:09 +08:00
committed by GitHub
parent f2a5df3949
commit d24123a55f
33 changed files with 1558 additions and 1231 deletions

View File

@@ -31,15 +31,16 @@ description: AI 代理插件配置参考
`provider`的配置字段说明如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `type` | string | 必填 | - | AI 服务提供商名称 |
| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000即 2 分钟 |
| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值openai默认值使用 OpenAI 的接口契约、original使用目标服务提供商的原始接口契约 |
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|------------------| --------------- | -------- | ------ |-----------------------------------------------------------------------------------------------------------------------------------------------------------|
| `type` | string | 必填 | - | AI 服务提供商名称 |
| `apiTokens`      | array of string | 非必填   | -      | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
| `timeout`        | number | 非必填   | -      | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 |
| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
| `protocol`       | string | 非必填   | -      | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
`context`的配置字段说明如下:
@@ -75,6 +76,16 @@ custom-setting会遵循如下表格根据`name`和协议来替换对应的字
如果启用了raw模式,custom-setting会直接用输入的`name`和`value`去更改请求中的json内容,而不对参数名称做任何限制和修改。
对于大多数协议custom-setting都会在json内容的根路径修改或者填充参数。对于`qwen`协议ai-proxy会在json的`parameters`子路径下做配置。对于`gemini`协议,则会在`generation_config`子路径下做配置。
`failover` 的配置字段说明如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|------------------|--------|------|-------|-----------------------------|
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
| healthCheckModel | string | 必填(启用 failover 时) | - | 健康检测使用的模型 |
### 提供商特有配置

View File

@@ -1,9 +1,9 @@
package config
import (
"github.com/tidwall/gjson"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
// @Name ai-proxy
@@ -75,13 +75,17 @@ func (c *PluginConfig) Validate() error {
return nil
}
func (c *PluginConfig) Complete() error {
func (c *PluginConfig) Complete(log wrapper.Log) error {
if c.activeProviderConfig == nil {
c.activeProvider = nil
return nil
}
var err error
c.activeProvider, err = provider.CreateProvider(*c.activeProviderConfig)
providerConfig := c.GetProviderConfig()
err = providerConfig.SetApiTokensFailover(log, c.activeProvider)
return err
}

View File

@@ -44,9 +44,10 @@ func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log
if err := pluginConfig.Validate(); err != nil {
return err
}
if err := pluginConfig.Complete(); err != nil {
if err := pluginConfig.Complete(log); err != nil {
return err
}
return nil
}
@@ -59,9 +60,10 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug
if err := pluginConfig.Validate(); err != nil {
return err
}
if err := pluginConfig.Complete(); err != nil {
if err := pluginConfig.Complete(log); err != nil {
return err
}
return nil
}
@@ -80,7 +82,13 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
path, _ := url.Parse(rawPath)
apiName := getOpenAiApiName(path.Path)
providerConfig := pluginConfig.GetProviderConfig()
if apiName == "" && !providerConfig.IsOriginal() {
if providerConfig.IsOriginal() {
if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
apiName = handler.GetApiName(path.Path)
}
}
if apiName == "" {
log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path)
// _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path)
log.Debugf("[onHttpRequestHeader] no send response")
@@ -89,8 +97,11 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
ctx.SetContext(ctxKeyApiName, apiName)
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()
// Set the apiToken for the current request.
providerConfig.SetApiTokenInUse(ctx, log)
hasRequestBody := wrapper.HasRequestBody()
action, err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
@@ -102,6 +113,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
}
return action
}
_ = util.SendResponse(500, "ai-proxy.proc_req_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request headers: %v", err))
return types.ActionContinue
}
@@ -156,15 +168,24 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
log.Debugf("[onHttpResponseHeaders] provider=%s", activeProvider.GetProviderType())
providerConfig := pluginConfig.GetProviderConfig()
apiTokenInUse := providerConfig.GetApiTokenInUse(ctx)
status, err := proxywasm.GetHttpResponseHeader(":status")
if err != nil || status != "200" {
if err != nil {
log.Errorf("unable to load :status header from response: %v", err)
}
ctx.DontReadResponseBody()
providerConfig.OnRequestFailed(ctx, apiTokenInUse, log)
return types.ActionContinue
}
// Reset ctxApiTokenRequestFailureCount if the request is successful,
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)
if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
action, err := handler.OnResponseHeaders(ctx, apiName, log)
@@ -233,16 +254,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
return types.ActionContinue
}
func getOpenAiApiName(path string) provider.ApiName {
if strings.HasSuffix(path, "/v1/chat/completions") {
return provider.ApiNameChatCompletion
}
if strings.HasSuffix(path, "/v1/embeddings") {
return provider.ApiNameEmbeddings
}
return ""
}
func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) {
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
@@ -252,3 +263,13 @@ func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) {
(*ctx).BufferResponseBody()
}
}
// getOpenAiApiName maps an OpenAI-style request path to the internal ApiName.
// Paths that match neither known suffix yield the empty ApiName.
func getOpenAiApiName(path string) provider.ApiName {
	switch {
	case strings.HasSuffix(path, "/v1/chat/completions"):
		return provider.ApiNameChatCompletion
	case strings.HasSuffix(path, "/v1/embeddings"):
		return provider.ApiNameEmbeddings
	default:
		return ""
	}
}

View File

@@ -1,14 +1,12 @@
package provider
import (
"encoding/json"
"errors"
"fmt"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
// ai360Provider is the provider for 360 OpenAI service.
@@ -46,10 +44,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(ai360Domain)
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken())
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
@@ -58,47 +53,12 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
if apiName == ApiNameChatCompletion {
return m.onChatCompletionRequestBody(ctx, body, log)
}
if apiName == ApiNameEmbeddings {
return m.onEmbeddingsRequestBody(ctx, body, log)
}
return types.ActionContinue, errUnsupportedApiName
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *ai360Provider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
// 映射模型
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
func (m *ai360Provider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
request := &embeddingsRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in embeddings request")
}
// 映射模型
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
// TransformRequestHeaders routes the request to the 360 AI service: it
// rewrites the Host header to the 360 domain and injects the apiToken chosen
// for this request. Accept-Encoding and Content-Length are dropped because the
// body may be rewritten later and must not be compressed or length-pinned.
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	util.OverwriteRequestHostHeader(headers, ai360Domain)
	// Bug fix: the value of the Authorization header was being set to
	// "Authorization <token>", duplicating the header name inside the value.
	// The pre-refactor behavior (ReplaceHttpRequestHeader("Authorization",
	// token)) sent the raw token, so send the raw token here as well.
	util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
	headers.Del("Accept-Encoding")
	headers.Del("Content-Length")
}

View File

@@ -3,16 +3,15 @@ package provider
import (
"errors"
"fmt"
"net/http"
"net/url"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
// azureProvider is the provider for Azure OpenAI service.
type azureProviderInitializer struct {
}
@@ -55,47 +54,23 @@ func (m *azureProvider) GetProviderType() string {
}
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = util.OverwriteRequestPath(m.serviceUrl.RequestURI())
_ = util.OverwriteRequestHost(m.serviceUrl.Host)
_ = proxywasm.ReplaceHttpRequestHeader("api-key", m.config.apiTokens[0])
if apiName == ApiNameChatCompletion {
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
} else {
ctx.DontReadRequestBody()
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
// We don't need to process the request body for other APIs.
return types.ActionContinue, nil
return types.ActionContinue, errUnsupportedApiName
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
if m.contextCache == nil {
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
return types.ActionContinue, nil
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.azure.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.azure.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
// TransformRequestHeaders routes the request to the configured Azure OpenAI
// endpoint (path and host taken from the parsed service URL). Azure
// authenticates with a dedicated "api-key" request header rather than the
// standard Authorization header.
func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI())
	util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
	// Bug fix: the token was previously written into the Authorization header
	// as "api-key <token>". The pre-refactor code used
	// ReplaceHttpRequestHeader("api-key", token), i.e. a separate api-key
	// header carrying the raw token — restore that behavior.
	headers.Set("api-key", m.config.GetApiTokenInUse(ctx))
	headers.Del("Content-Length")
}

View File

@@ -2,11 +2,10 @@ package provider
import (
"errors"
"fmt"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -47,10 +46,7 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(baichuanChatCompletionPath)
_ = util.OverwriteRequestHost(baichuanDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -58,28 +54,12 @@ func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.baichuan.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.baichuan.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
// TransformRequestHeaders points the request at the Baichuan chat-completion
// endpoint and injects the apiToken chosen for this request as a Bearer
// credential.
func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	token := m.config.GetApiTokenInUse(ctx)
	util.OverwriteRequestPathHeader(headers, baichuanChatCompletionPath)
	util.OverwriteRequestHostHeader(headers, baichuanDomain)
	util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+token)
	// The body may be rewritten later, so the original length must not be kept.
	headers.Del("Content-Length")
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
@@ -16,7 +17,8 @@ import (
// baiduProvider is the provider for baidu ernie bot service.
const (
baiduDomain = "aip.baidubce.com"
baiduDomain = "aip.baidubce.com"
baiduChatCompletionPath = "/chat"
)
var baiduModelToPathSuffixMap = map[string]string{
@@ -60,98 +62,35 @@ func (b *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(baiduDomain)
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
b.config.handleRequestHeaders(b, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
// TransformRequestHeaders rewrites the Host header to the Baidu endpoint and
// strips headers that would conflict with a later body rewrite.
func (b *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	util.OverwriteRequestHostHeader(headers, baiduDomain)
	for _, name := range []string{"Accept-Encoding", "Content-Length"} {
		headers.Del(name)
	}
}
func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
// 使用文心一言接口协议
if b.config.protocol == protocolOriginal {
request := &baiduTextGenRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("request model is empty")
}
// 根据模型重写requestPath
path := b.getRequestPath(request.Model)
_ = util.OverwriteRequestPath(path)
return b.config.handleRequestBody(b, b.contextCache, ctx, apiName, body, log)
}
if b.config.context == nil {
return types.ActionContinue, nil
}
err := b.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.baidu.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
b.setSystemContent(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.baidu.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
func (b *baiduProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
err := b.config.parseRequestAndMapModel(ctx, request, body, log)
if err != nil {
return nil, err
}
path := b.getRequestPath(ctx, request.Model)
util.OverwriteRequestPathHeader(headers, path)
// 映射模型重写requestPath
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, b.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
path := b.getRequestPath(mappedModel)
_ = util.OverwriteRequestPath(path)
if b.config.context == nil {
baiduRequest := b.baiduTextGenRequest(request)
return types.ActionContinue, replaceJsonRequestBody(baiduRequest, log)
}
err := b.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.baidu.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
baiduRequest := b.baiduTextGenRequest(request)
if err := replaceJsonRequestBody(baiduRequest, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.baidu.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace Request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
baiduRequest := b.baiduTextGenRequest(request)
return json.Marshal(baiduRequest)
}
func (b *baiduProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -226,13 +165,13 @@ type baiduTextGenRequest struct {
UserId string `json:"user_id,omitempty"`
}
func (b *baiduProvider) getRequestPath(baiduModel string) string {
func (b *baiduProvider) getRequestPath(ctx wrapper.HttpContext, baiduModel string) string {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix, ok := baiduModelToPathSuffixMap[baiduModel]
if !ok {
suffix = baiduModel
}
return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetRandomToken())
return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetApiTokenInUse(ctx))
}
func (b *baiduProvider) setSystemContent(request *baiduTextGenRequest, content string) {
@@ -339,3 +278,10 @@ func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, resp
func (b *baiduProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
// GetApiName infers the API type from the request path; used when the plugin
// forwards the provider's original protocol. Unknown paths yield "".
func (b *baiduProvider) GetApiName(path string) ApiName {
	if !strings.Contains(path, baiduChatCompletionPath) {
		return ""
	}
	return ApiNameChatCompletion
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
@@ -105,102 +106,39 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return types.ActionContinue, nil
}
_ = util.OverwriteRequestPath(claudeChatCompletionPath)
_ = util.OverwriteRequestHost(claudeDomain)
_ = proxywasm.ReplaceHttpRequestHeader("x-api-key", c.config.GetRandomToken())
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath)
util.OverwriteRequestHostHeader(headers, claudeDomain)
headers.Add("x-api-key", c.config.GetApiTokenInUse(ctx))
if c.config.claudeVersion == "" {
c.config.claudeVersion = defaultVersion
}
_ = proxywasm.AddHttpRequestHeader("anthropic-version", c.config.claudeVersion)
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
return types.ActionContinue, nil
headers.Add("anthropic-version", c.config.claudeVersion)
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
}
// use original protocol
if c.config.protocol == protocolOriginal {
if c.config.context == nil {
return types.ActionContinue, nil
}
request := &claudeTextGenRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
err := c.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.claude.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.claude.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
// use openai protocol
func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
}
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, c.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
streaming := request.Stream
if streaming {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
}
if c.config.context == nil {
claudeRequest := c.buildClaudeTextGenRequest(request)
return types.ActionContinue, replaceJsonRequestBody(claudeRequest, log)
}
err := c.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.claude.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
claudeRequest := c.buildClaudeTextGenRequest(request)
if err := replaceJsonRequestBody(claudeRequest, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.claude.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
claudeRequest := c.buildClaudeTextGenRequest(request)
return json.Marshal(claudeRequest)
}
func (c *claudeProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -369,3 +307,25 @@ func createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextG
func (c *claudeProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
// insertHttpContextMessage prepends the context-file content to the request's
// System prompt (joined with a newline when a System prompt already exists)
// and returns the re-marshalled body. onlyOneSystemBeforeFile is accepted for
// interface compatibility and is not used by the Claude provider.
func (c *claudeProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) {
	request := &claudeTextGenRequest{}
	if err := json.Unmarshal(body, request); err != nil {
		return nil, fmt.Errorf("unable to unmarshal request: %v", err)
	}
	if request.System != "" {
		content = content + "\n" + request.System
	}
	request.System = content
	return json.Marshal(request)
}
// GetApiName infers the API type from the request path; used when the plugin
// forwards the provider's original protocol. Unknown paths yield "".
func (c *claudeProvider) GetApiName(path string) ApiName {
	if !strings.Contains(path, claudeChatCompletionPath) {
		return ""
	}
	return ApiNameChatCompletion
}

View File

@@ -2,12 +2,11 @@ package provider
import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -47,13 +46,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
_ = util.OverwriteRequestHost(cloudflareDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + c.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
c.config.handleRequestHeaders(c, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -61,49 +54,13 @@ func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiN
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, c.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
streaming := request.Stream
if streaming {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
}
if c.contextCache == nil {
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.cloudflare.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
return types.ActionContinue, nil
}
err := c.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.cloudflare.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.cloudflare.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
}
// TransformRequestHeaders targets the Cloudflare Workers AI endpoint for the
// configured account and injects the apiToken chosen for this request as a
// Bearer credential.
func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	accountPath := strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1)
	util.OverwriteRequestPathHeader(headers, accountPath)
	util.OverwriteRequestHostHeader(headers, cloudflareDomain)
	util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx))
	headers.Del("Accept-Encoding")
	headers.Del("Content-Length")
}

View File

@@ -3,17 +3,16 @@ package provider
import (
"encoding/json"
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"strings"
)
const (
cohereDomain = "api.cohere.com"
chatCompletionPath = "/v1/chat"
cohereDomain = "api.cohere.com"
cohereChatCompletionPath = "/v1/chat"
)
type cohereProviderInitializer struct{}
@@ -27,12 +26,14 @@ func (m *cohereProviderInitializer) ValidateConfig(config ProviderConfig) error
func (m *cohereProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &cohereProvider{
config: config,
config: config,
contextCache: createContextCache(&config),
}, nil
}
type cohereProvider struct {
config ProviderConfig
config ProviderConfig
contextCache *contextCache
}
type cohereTextGenRequest struct {
@@ -57,10 +58,7 @@ func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(cohereDomain)
_ = util.OverwriteRequestPath(chatCompletionPath)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -68,30 +66,7 @@ func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.config.protocol == protocolOriginal {
request := &cohereTextGenRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
return m.handleRequestBody(log, request)
}
origin := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, origin); err != nil {
return types.ActionContinue, err
}
request := m.buildCohereRequest(origin)
return m.handleRequestBody(log, request)
}
func (m *cohereProvider) handleRequestBody(log wrapper.Log, request interface{}) (types.Action, error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
err := replaceJsonRequestBody(request, log)
if err != nil {
_ = util.SendResponse(500, "ai-proxy.cohere.proxy_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohereTextGenRequest {
@@ -112,3 +87,27 @@ func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohe
PresencePenalty: origin.PresencePenalty,
}
}
// TransformRequestHeaders rewrites the outbound request for Cohere: chat
// completion path, Cohere host, and Bearer auth using the apiToken selected
// for this request.
func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	util.OverwriteRequestPathHeader(headers, cohereChatCompletionPath)
	util.OverwriteRequestHostHeader(headers, cohereDomain)
	util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
	// Content-Length would be stale once the body is rewritten.
	headers.Del("Content-Length")
}
// TransformRequestBody converts an OpenAI-style chat completion request into
// the Cohere chat request format, applying the configured model mapping first.
func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
	request := &chatCompletionRequest{}
	if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
		return nil, err
	}
	cohereRequest := m.buildCohereRequest(request)
	return json.Marshal(cohereRequest)
}
// GetApiName maps a request path to the API it serves; only the Cohere chat
// completion endpoint is recognized.
func (m *cohereProvider) GetApiName(path string) ApiName {
	if !strings.Contains(path, cohereChatCompletionPath) {
		return ""
	}
	return ApiNameChatCompletion
}

View File

@@ -1,12 +1,15 @@
package provider
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/tidwall/gjson"
)
@@ -57,6 +60,10 @@ type contextCache struct {
content string
}
// ContextInserter lets a provider customize how the loaded context-file
// content is inserted into the request body (see insertContext); providers
// that do not implement it fall back to defaultInsertHttpContextMessage.
type ContextInserter interface {
	insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error)
}
func (c *contextCache) GetContent(callback func(string, error), log wrapper.Log) error {
if callback == nil {
return errors.New("callback is nil")
@@ -98,3 +105,79 @@ func createContextCache(providerConfig *ProviderConfig) *contextCache {
timeout: providerConfig.timeout,
}
}
// GetContextFromFile loads the configured context file (from the in-memory
// cache when already fetched, otherwise via an async HTTP GET) and inserts its
// content into the request body as a system message via insertContext, which
// also resumes the paused request.
func (c *contextCache) GetContextFromFile(ctx wrapper.HttpContext, provider Provider, body []byte, log wrapper.Log) error {
	// get context will overwrite the original request host and path
	// save the original request host and path in case they are needed for apiToken health check
	ctx.SetContext(ctxRequestHost, wrapper.GetRequestHost())
	ctx.SetContext(ctxRequestPath, wrapper.GetRequestPath())
	if c.loaded {
		log.Debugf("context file loaded from cache")
		insertContext(provider, c.content, nil, body, log)
		return nil
	}
	log.Infof("loading context file from %s", c.fileUrl.String())
	return c.client.Get(c.fileUrl.Path, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		if statusCode != http.StatusOK {
			// Report the failure through insertContext so the request is resumed.
			insertContext(provider, "", fmt.Errorf("failed to load context file, status: %d", statusCode), nil, log)
			return
		}
		// Cache the content so subsequent requests skip the HTTP fetch.
		c.content = string(responseBody)
		c.loaded = true
		log.Debugf("content: %s", c.content)
		insertContext(provider, c.content, nil, body, log)
	}, c.timeout)
}
// insertContext inserts content into body as a system message and writes the
// result back as the request body. It always resumes the paused HTTP request
// via the deferred ResumeHttpRequest, even after a failure.
// NOTE(review): on a load error it sends a 500 but still falls through and
// inserts with empty content — confirm this best-effort behavior is intended.
func insertContext(provider Provider, content string, err error, body []byte, log wrapper.Log) {
	defer func() {
		_ = proxywasm.ResumeHttpRequest()
	}()
	typ := provider.GetProviderType()
	if err != nil {
		log.Errorf("failed to load context file: %v", err)
		_ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.load_ctx_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
	}
	// Providers may customize insertion; otherwise use the default strategy.
	if inserter, ok := provider.(ContextInserter); ok {
		body, err = inserter.insertHttpContextMessage(body, content, false)
	} else {
		body, err = defaultInsertHttpContextMessage(body, content)
	}
	if err != nil {
		_ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to insert context message: %v", err))
	}
	if err := replaceHttpJsonRequestBody(body, log); err != nil {
		_ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
	}
}
// defaultInsertHttpContextMessage inserts content as a system message into the
// chat completion request encoded in body, placing it after any leading
// system messages (i.e. right before the first non-system message), and
// returns the re-marshaled body.
//
// Fix: the previous version could not distinguish "first message is
// non-system" from "all messages are system" (both left the insert index at
// 0), so a request consisting only of system messages had the context
// prepended before them instead of appended after them.
func defaultInsertHttpContextMessage(body []byte, content string) ([]byte, error) {
	request := &chatCompletionRequest{}
	if err := json.Unmarshal(body, request); err != nil {
		return nil, fmt.Errorf("unable to unmarshal request: %v", err)
	}
	fileMessage := chatMessage{
		Role:    roleSystem,
		Content: content,
	}
	// Default to appending at the end; move forward to the first non-system
	// message if one exists.
	insertIndex := len(request.Messages)
	for i, message := range request.Messages {
		if message.Role != roleSystem {
			insertIndex = i
			break
		}
	}
	// Build a fresh slice to avoid aliasing surprises with the original backing array.
	messages := make([]chatMessage, 0, len(request.Messages)+1)
	messages = append(messages, request.Messages[:insertIndex]...)
	messages = append(messages, fileMessage)
	messages = append(messages, request.Messages[insertIndex:]...)
	request.Messages = messages
	return json.Marshal(request)
}

View File

@@ -4,6 +4,8 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
@@ -78,49 +80,38 @@ func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(deeplChatCompletionPath)
_ = util.OverwriteRequestAuthorization("DeepL-Auth-Key " + d.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
d.config.handleRequestHeaders(d, ctx, apiName, log)
return types.HeaderStopIteration, nil
}
// TransformRequestHeaders rewrites the outbound request for DeepL: path and
// DeepL-Auth-Key authorization. The Host header is set later from the model
// (see overwriteRequestHost, called during body transformation).
func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath)
	util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx))
	// Content-Length would be stale once the body is rewritten; Accept-Encoding
	// removal presumably avoids compressed responses — confirm.
	headers.Del("Content-Length")
	headers.Del("Accept-Encoding")
}
func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if d.config.protocol == protocolOriginal {
request := &deeplRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if err := d.overwriteRequestHost(request.Model); err != nil {
return types.ActionContinue, err
}
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
return types.ActionContinue, replaceJsonRequestBody(request, log)
} else {
originRequest := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, originRequest); err != nil {
return types.ActionContinue, err
}
if err := d.overwriteRequestHost(originRequest.Model); err != nil {
return types.ActionContinue, err
}
ctx.SetContext(ctxKeyFinalRequestModel, originRequest.Model)
deeplRequest := &deeplRequest{
Text: make([]string, 0),
TargetLang: d.config.targetLang,
}
for _, msg := range originRequest.Messages {
if msg.Role == roleSystem {
deeplRequest.Context = msg.StringContent()
} else {
deeplRequest.Text = append(deeplRequest.Text, msg.StringContent())
}
}
return types.ActionContinue, replaceJsonRequestBody(deeplRequest, log)
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log)
}
// TransformRequestBodyHeaders converts an OpenAI-style chat completion request
// into a DeepL translation request, records the final model in the HTTP
// context, and rewrites the outbound Host header based on the model
// ("Pro" or "Free").
func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
	request := &chatCompletionRequest{}
	if err := decodeChatCompletionRequest(body, request); err != nil {
		return nil, err
	}
	// Remember the final model so later phases (e.g. response handling) can read it.
	ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
	// The model decides the target host (deeplHostPro vs deeplHostFree).
	if err := d.overwriteRequestHost(headers, request.Model); err != nil {
		return nil, err
	}
	// Fix: the local was previously (mis)named baiduRequest — this is a DeepL
	// request; named deeplReq to avoid shadowing the deeplRequest type.
	deeplReq := d.deeplTextGenRequest(request)
	return json.Marshal(deeplReq)
}
func (d *deeplProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -164,13 +155,35 @@ func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplRespo
}
}
func (d *deeplProvider) overwriteRequestHost(model string) error {
func (d *deeplProvider) overwriteRequestHost(headers http.Header, model string) error {
if model == "Pro" {
_ = util.OverwriteRequestHost(deeplHostPro)
util.OverwriteRequestHostHeader(headers, deeplHostPro)
} else if model == "Free" {
_ = util.OverwriteRequestHost(deeplHostFree)
util.OverwriteRequestHostHeader(headers, deeplHostFree)
} else {
return errors.New(`deepl model should be "Free" or "Pro"`)
}
return nil
}
// deeplTextGenRequest builds a DeepL translation request from an OpenAI-style
// chat completion request: system messages become the translation context and
// every other message is appended as a text to translate.
func (d *deeplProvider) deeplTextGenRequest(request *chatCompletionRequest) *deeplRequest {
	result := &deeplRequest{
		Text:       make([]string, 0),
		TargetLang: d.config.targetLang,
	}
	for _, message := range request.Messages {
		content := message.StringContent()
		if message.Role == roleSystem {
			result.Context = content
			continue
		}
		result.Text = append(result.Text, content)
	}
	return result
}
// GetApiName maps a request path to the API it serves; only the DeepL chat
// completion endpoint is recognized.
func (d *deeplProvider) GetApiName(path string) ApiName {
	if !strings.Contains(path, deeplChatCompletionPath) {
		return ""
	}
	return ApiNameChatCompletion
}

View File

@@ -2,12 +2,10 @@ package provider
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
)
// deepseekProvider is the provider for deepseek Ai service.
@@ -47,10 +45,7 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(deepseekChatCompletionPath)
_ = util.OverwriteRequestHost(deepseekDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -58,28 +53,12 @@ func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.deepseek.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.deepseek.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
// TransformRequestHeaders rewrites the outbound request for DeepSeek: chat
// completion path, DeepSeek host, and Bearer auth using the apiToken selected
// for this request.
func (m *deepseekProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	util.OverwriteRequestPathHeader(headers, deepseekChatCompletionPath)
	util.OverwriteRequestHostHeader(headers, deepseekDomain)
	util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
	// Content-Length would be stale once the body is rewritten.
	headers.Del("Content-Length")
}

View File

@@ -2,12 +2,11 @@ package provider
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"strings"
)
const (
@@ -41,17 +40,10 @@ func (m *doubaoProvider) GetProviderType() string {
}
func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = util.OverwriteRequestHost(doubaoDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
if m.config.protocol == protocolOriginal {
ctx.DontReadRequestBody()
return types.ActionContinue, nil
}
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(doubaoChatCompletionPath)
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -59,44 +51,19 @@ func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
mappedModel := getMappedModel(model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
if m.contextCache != nil {
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.doubao.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.doubao.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
} else {
return types.ActionContinue, err
}
} else {
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.doubao.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
return types.ActionContinue, err
}
_ = proxywasm.ResumeHttpRequest()
return types.ActionPause, nil
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
// TransformRequestHeaders rewrites the outbound request for Doubao: chat
// completion path, Doubao host, and Bearer auth using the apiToken selected
// for this request.
func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	util.OverwriteRequestPathHeader(headers, doubaoChatCompletionPath)
	util.OverwriteRequestHostHeader(headers, doubaoDomain)
	util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
	// Content-Length would be stale once the body is rewritten.
	headers.Del("Content-Length")
}
// GetApiName maps a request path to the API it serves; only the Doubao chat
// completion endpoint is recognized.
func (m *doubaoProvider) GetApiName(path string) ApiName {
	if !strings.Contains(path, doubaoChatCompletionPath) {
		return ""
	}
	return ApiNameChatCompletion
}

View File

@@ -0,0 +1,594 @@
package provider
import (
"encoding/json"
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/google/uuid"
"math/rand"
"net/http"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
)
// failover holds the apiToken failover configuration plus the per-provider
// shared-data/context key names derived in ProviderConfig.initVariable.
type failover struct {
	// enabled indicates whether the apiToken failover mechanism is on.
	enabled bool `required:"true" yaml:"enabled" json:"enabled"`
	// failureThreshold is the number of consecutive request failures that marks a token unavailable.
	failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"`
	// successThreshold is the number of successful health checks needed to restore a token.
	successThreshold int64 `required:"false" yaml:"successThreshold" json:"successThreshold"`
	// healthCheckInterval is the health check interval in milliseconds.
	healthCheckInterval int64 `required:"false" yaml:"healthCheckInterval" json:"healthCheckInterval"`
	// healthCheckTimeout is the health check timeout in milliseconds.
	healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
	// healthCheckModel is the model name sent in health check requests.
	healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"`
	// ctxApiTokenInUse keys the apiToken used by the current request.
	ctxApiTokenInUse string
	// ctxApiTokenRequestFailureCount keys the per-token request failure counts (key: apiToken, value: failure count).
	ctxApiTokenRequestFailureCount string
	// ctxApiTokenRequestSuccessCount keys the per-token health check success counts (key: apiToken, value: success count).
	ctxApiTokenRequestSuccessCount string
	// ctxApiTokens keys the list of all available apiTokens.
	ctxApiTokens string
	// ctxUnavailableApiTokens keys the list of all unavailable apiTokens.
	ctxUnavailableApiTokens string
	// ctxHealthCheckEndpoint keys the request cluster, host and path used to build health check requests.
	ctxHealthCheckEndpoint string
	// ctxVmLease keys the health-check leader-election lease; only the Wasm VM holding it performs health checks.
	ctxVmLease string
}
// Lease is the leader-election record stored in shared data: the VM that owns
// the lease and when it was last renewed (Unix seconds).
type Lease struct {
	VMID      string `json:"vmID"`
	Timestamp int64  `json:"timestamp"`
}
// HealthCheckEndpoint records where health check requests are sent: the
// upstream cluster plus the host and path saved from a real request.
type HealthCheckEndpoint struct {
	Host    string `json:"host"`
	Path    string `json:"path"`
	Cluster string `json:"cluster"`
}
const (
	// casMaxRetries bounds the retry loops around CAS-guarded shared-data writes.
	casMaxRetries = 10
	// Operation names for modifyApiToken / modifyApiTokenRequestCount.
	addApiTokenOperation               = "addApiToken"
	removeApiTokenOperation            = "removeApiToken"
	addApiTokenRequestCountOperation   = "addApiTokenRequestCount"
	resetApiTokenRequestCountOperation = "resetApiTokenRequestCount"
	// HTTP-context keys for the original request host/path, saved before they
	// are overwritten so health checks can reuse them (see GetContextFromFile).
	ctxRequestHost = "requestHost"
	ctxRequestPath = "requestPath"
)
var (
	// healthCheckClient is the HTTP client used to probe unavailable apiTokens.
	healthCheckClient wrapper.HttpClient
)
// FromJson populates the failover config from its JSON representation,
// applying defaults for any numeric field left unset (zero):
// failureThreshold=3, successThreshold=1, healthCheckInterval=5000ms,
// healthCheckTimeout=5000ms.
func (f *failover) FromJson(json gjson.Result) {
	intOrDefault := func(key string, fallback int64) int64 {
		if v := json.Get(key).Int(); v != 0 {
			return v
		}
		return fallback
	}
	f.enabled = json.Get("enabled").Bool()
	f.failureThreshold = intOrDefault("failureThreshold", 3)
	f.successThreshold = intOrDefault("successThreshold", 1)
	f.healthCheckInterval = intOrDefault("healthCheckInterval", 5000)
	f.healthCheckTimeout = intOrDefault("healthCheckTimeout", 5000)
	f.healthCheckModel = json.Get("healthCheckModel").String()
}
// Validate checks the failover configuration; a health check model is
// mandatory because it is embedded in every probe request.
func (f *failover) Validate() error {
	if f.healthCheckModel != "" {
		return nil
	}
	return errors.New("missing healthCheckModel in failover config")
}
// initVariable derives the per-provider shared-data and context key names so
// that different providers' failover state does not collide.
func (c *ProviderConfig) initVariable() {
	// Set provider name as prefix to differentiate shared data
	provider := c.GetType()
	c.failover.ctxApiTokenInUse = provider + "-apiTokenInUse"
	c.failover.ctxApiTokenRequestFailureCount = provider + "-apiTokenRequestFailureCount"
	c.failover.ctxApiTokenRequestSuccessCount = provider + "-apiTokenRequestSuccessCount"
	c.failover.ctxApiTokens = provider + "-apiTokens"
	c.failover.ctxUnavailableApiTokens = provider + "-unavailableApiTokens"
	c.failover.ctxHealthCheckEndpoint = provider + "-requestHostAndPath"
	c.failover.ctxVmLease = provider + "-vmLease"
}
// parseConfig is a no-op config parser; it only exists to satisfy the plugin
// context created in createHttpContext for health check requests.
func parseConfig(json gjson.Result, config *any, log wrapper.Log) error {
	return nil
}
// SetApiTokensFailover initializes failover for this provider: it resets all
// shared data (configuration may have changed), seeds the available-token
// list, and registers a periodic tick that health-checks every unavailable
// token. Only the Wasm VM currently holding the lease runs the checks.
func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Provider) error {
	c.initVariable()
	// Reset shared data in case plugin configuration is updated
	log.Debugf("ai-proxy plugin configuration is updated, reset shared data")
	c.resetSharedData()
	if c.isFailoverEnabled() {
		log.Debugf("ai-proxy plugin failover is enabled")
		vmID := generateVMID()
		err := c.initApiTokens()
		if err != nil {
			return fmt.Errorf("failed to init apiTokens: %v", err)
		}
		wrapper.RegisteTickFunc(c.failover.healthCheckInterval, func() {
			// Only the Wasm VM that successfully acquires the lease will perform health check
			if c.isFailoverEnabled() && c.tryAcquireOrRenewLease(vmID, log) {
				log.Debugf("Successfully acquired or renewed lease for %v: %v", vmID, c.GetType())
				unavailableTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens)
				if err != nil {
					log.Errorf("Failed to get unavailable tokens: %v", err)
					return
				}
				if len(unavailableTokens) > 0 {
					for _, apiToken := range unavailableTokens {
						log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", "))
						// Rebuild the probe request from the endpoint info saved when the
						// token was marked unavailable (see setHealthCheckEndpoint).
						healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody(log)
						healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{
							Host:    healthCheckEndpoint.Host,
							Cluster: healthCheckEndpoint.Cluster,
						})
						// Synthetic HTTP context carrying the token under test, so the
						// provider's Transform* hooks sign the probe with that token.
						ctx := createHttpContext()
						ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
						modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body, log)
						if err != nil {
							log.Errorf("Failed to transform request headers and body: %v", err)
						}
						// The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion
						err = healthCheckClient.Post(healthCheckEndpoint.Path, modifiedHeaders, modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
							if statusCode == 200 {
								c.handleAvailableApiToken(apiToken, log)
							}
						}, uint32(c.failover.healthCheckTimeout))
						if err != nil {
							log.Errorf("Failed to perform health check request: %v", err)
						}
					}
				}
			}
		})
	}
	return nil
}
// transformRequestHeadersAndBody runs the active provider's request-transform
// hooks over a header slice and body and returns the transformed pair. It is
// used to make health check probes look like real provider requests.
func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, headers [][2]string, body []byte, log wrapper.Log) ([][2]string, []byte, error) {
	originalHeaders := util.SliceToHeader(headers)
	if handler, ok := activeProvider.(TransformRequestHeadersHandler); ok {
		handler.TransformRequestHeaders(ctx, ApiNameChatCompletion, originalHeaders, log)
	}
	var err error
	if handler, ok := activeProvider.(TransformRequestBodyHandler); ok {
		body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log)
	} else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok {
		// Snapshot the live request headers and restore them afterwards, since
		// this handler may mutate them as a side effect.
		// NOTE(review): the restore re-applies the pre-transform snapshot — confirm intended.
		headers := util.GetOriginalHttpHeaders()
		body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, originalHeaders, log)
		util.ReplaceOriginalHttpHeaders(headers)
	} else {
		body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log)
	}
	if err != nil {
		return nil, nil, fmt.Errorf("failed to transform request body: %v", err)
	}
	modifiedHeaders := util.HeaderToSlice(originalHeaders)
	return modifiedHeaders, body, nil
}
// createHttpContext builds a synthetic HTTP context (random plugin/context
// IDs) so provider Transform* hooks can be driven outside of a real request,
// i.e. when building health check probes.
func createHttpContext() *wrapper.CommonHttpCtx[any] {
	setParseConfig := wrapper.ParseConfigBy[any](parseConfig)
	vmCtx := wrapper.NewCommonVmCtx[any]("health-check", setParseConfig)
	pluginCtx := vmCtx.NewPluginContext(rand.Uint32())
	ctx := pluginCtx.NewHttpContext(rand.Uint32()).(*wrapper.CommonHttpCtx[any])
	return ctx
}
// generateRequestHeadersAndBody reconstructs the saved health check endpoint
// from shared data and builds a minimal chat completion request body using the
// configured health check model.
// NOTE(review): shared-data read/unmarshal failures are logged but execution
// proceeds with a zero-valued endpoint — confirm this is intended.
func (c *ProviderConfig) generateRequestHeadersAndBody(log wrapper.Log) (HealthCheckEndpoint, [][2]string, []byte) {
	data, _, err := proxywasm.GetSharedData(c.failover.ctxHealthCheckEndpoint)
	if err != nil {
		log.Errorf("Failed to get request host and path: %v", err)
	}
	var healthCheckEndpoint HealthCheckEndpoint
	err = json.Unmarshal(data, &healthCheckEndpoint)
	if err != nil {
		log.Errorf("Failed to unmarshal request host and path: %v", err)
	}
	headers := [][2]string{
		{"content-type", "application/json"},
	}
	body := []byte(fmt.Sprintf(`{
		"model": "%s",
		"messages": [
			{
				"role": "user",
				"content": "who are you?"
			}
		]
	}`, c.failover.healthCheckModel))
	return healthCheckEndpoint, headers, body
}
// tryAcquireOrRenewLease implements health-check leader election; it returns
// true when this VM (vmID) now holds the lease. The lease lives in shared
// data and is (re)acquired when absent, expired (older than 60s), or already
// owned by this VM. Writes go through CAS, so if a concurrent VM wins the
// race, setLease fails and this VM reports false.
func (c *ProviderConfig) tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool {
	now := time.Now().Unix()
	data, cas, err := proxywasm.GetSharedData(c.failover.ctxVmLease)
	if err != nil {
		if errors.Is(err, types.ErrorStatusNotFound) {
			// No lease yet: try to create it.
			return c.setLease(vmID, now, cas, log)
		}
		log.Errorf("Failed to get lease: %v", err)
		return false
	}
	if data == nil {
		return c.setLease(vmID, now, cas, log)
	}
	var lease Lease
	if err = json.Unmarshal(data, &lease); err != nil {
		log.Errorf("Failed to unmarshal lease data: %v", err)
		return false
	}
	// If vmID is itself, try to renew the lease directly
	// If the lease is expired (60s), try to acquire the lease
	if lease.VMID == vmID || now-lease.Timestamp > 60 {
		// Fix: the previous version mutated the local lease copy here before
		// calling setLease, which builds its own record — dead code, removed.
		return c.setLease(vmID, now, cas, log)
	}
	return false
}
// setLease writes a lease record for vmID at timestamp into shared data using
// the given CAS value; it returns false when marshaling or the CAS-guarded
// write fails (e.g. another VM updated the lease concurrently).
func (c *ProviderConfig) setLease(vmID string, timestamp int64, cas uint32, log wrapper.Log) bool {
	lease := Lease{
		VMID:      vmID,
		Timestamp: timestamp,
	}
	leaseByte, err := json.Marshal(lease)
	if err != nil {
		log.Errorf("Failed to marshal lease data: %v", err)
		return false
	}
	if err := proxywasm.SetSharedData(c.failover.ctxVmLease, leaseByte, cas); err != nil {
		log.Errorf("Failed to set or renew lease: %v", err)
		return false
	}
	return true
}
// generateVMID returns a random unique identifier for this Wasm VM, used in
// health-check leader election.
func generateVMID() string {
	return uuid.New().String()
}
// When number of request successes exceeds the threshold during health check,
// add the apiToken back to the available list and remove it from the unavailable list
func (c *ProviderConfig) handleAvailableApiToken(apiToken string, log wrapper.Log) {
	successApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount)
	if err != nil {
		log.Errorf("Failed to get successApiTokenRequestCount: %v", err)
		return
	}
	// Count the current successful probe on top of the stored count.
	successCount := successApiTokenRequestCount[apiToken] + 1
	if successCount >= c.failover.successThreshold {
		log.Infof("apiToken %s is available now, add it back to the apiTokens list", apiToken)
		removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log)
		addApiToken(c.failover.ctxApiTokens, apiToken, log)
		// Clear the counter so a future recovery starts from zero.
		resetApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log)
	} else {
		log.Debugf("apiToken %s is still unavailable, the number of health check passed: %d, continue to health check...", apiToken, successCount)
		addApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log)
	}
}
// When number of request failures exceeds the threshold,
// remove the apiToken from the available list and add it to the unavailable list
func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiToken string, log wrapper.Log) {
	failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount)
	if err != nil {
		log.Errorf("Failed to get failureApiTokenRequestCount: %v", err)
		return
	}
	availableTokens, _, err := getApiTokens(c.failover.ctxApiTokens)
	if err != nil {
		log.Errorf("Failed to get available apiToken: %v", err)
		return
	}
	// unavailable apiToken has been removed from the available list
	if !containsElement(availableTokens, apiToken) {
		return
	}
	// Count the current failure on top of the stored count.
	failureCount := failureApiTokenRequestCount[apiToken] + 1
	if failureCount >= c.failover.failureThreshold {
		log.Infof("apiToken %s is unavailable now, remove it from apiTokens list", apiToken)
		removeApiToken(c.failover.ctxApiTokens, apiToken, log)
		addApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log)
		resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log)
		// Set the request host and path to shared data in case they are needed in apiToken health check
		c.setHealthCheckEndpoint(ctx, log)
	} else {
		log.Debugf("apiToken %s is still available as it has not reached the failure threshold, the number of failed request: %d", apiToken, failureCount)
		addApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log)
	}
}
// addApiToken adds apiToken to the list stored under key (no-op if already present).
func addApiToken(key, apiToken string, log wrapper.Log) {
	modifyApiToken(key, apiToken, addApiTokenOperation, log)
}

// removeApiToken removes apiToken from the list stored under key (no-op if absent).
func removeApiToken(key, apiToken string, log wrapper.Log) {
	modifyApiToken(key, apiToken, removeApiTokenOperation, log)
}
// modifyApiToken adds or removes apiToken in the token list stored under key,
// retrying up to casMaxRetries times when a CAS conflict indicates a
// concurrent writer.
func modifyApiToken(key, apiToken, op string, log wrapper.Log) {
	for attempt := 1; attempt <= casMaxRetries; attempt++ {
		apiTokens, cas, err := getApiTokens(key)
		if err != nil {
			log.Errorf("Failed to get %s: %v", key, err)
			continue
		}
		// Nothing to do if the desired state already holds.
		exists := containsElement(apiTokens, apiToken)
		if op == addApiTokenOperation && exists {
			log.Debugf("%s already exists in %s", apiToken, key)
			return
		} else if op == removeApiTokenOperation && !exists {
			log.Debugf("%s does not exist in %s", apiToken, key)
			return
		}
		if op == addApiTokenOperation {
			apiTokens = append(apiTokens, apiToken)
		} else {
			apiTokens = removeElement(apiTokens, apiToken)
		}
		// CAS-guarded write: retry on mismatch, give up on any other error.
		if err := setApiTokens(key, apiTokens, cas); err == nil {
			log.Debugf("Successfully updated %s in %s", apiToken, key)
			return
		} else if !errors.Is(err, types.ErrorStatusCasMismatch) {
			log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
			return
		}
		log.Errorf("CAS mismatch when setting %s, retrying...", key)
	}
}
// getApiTokens reads the JSON-encoded token list stored under key in shared
// data. A missing key or nil payload is treated as an empty list, not an
// error. The returned CAS value is needed for a subsequent guarded write.
func getApiTokens(key string) ([]string, uint32, error) {
	data, cas, err := proxywasm.GetSharedData(key)
	switch {
	case errors.Is(err, types.ErrorStatusNotFound):
		return []string{}, cas, nil
	case err != nil:
		return nil, 0, err
	case data == nil:
		return []string{}, cas, nil
	}
	var apiTokens []string
	if err := json.Unmarshal(data, &apiTokens); err != nil {
		return nil, 0, fmt.Errorf("failed to unmarshal tokens: %v", err)
	}
	return apiTokens, cas, nil
}
// setApiTokens JSON-encodes apiTokens and stores them under key in shared
// data, using cas for optimistic concurrency control.
func setApiTokens(key string, apiTokens []string, cas uint32) error {
	payload, err := json.Marshal(apiTokens)
	if err != nil {
		return fmt.Errorf("failed to marshal tokens: %v", err)
	}
	return proxywasm.SetSharedData(key, payload, cas)
}
// removeElement returns slice with every occurrence of s removed. The backing
// array is reused, so the input slice is mutated in place.
func removeElement(slice []string, s string) []string {
	kept := slice[:0]
	for _, item := range slice {
		if item != s {
			kept = append(kept, item)
		}
	}
	return kept
}
// containsElement reports whether s is present in slice.
func containsElement(slice []string, s string) bool {
	found := false
	for _, candidate := range slice {
		if candidate == s {
			found = true
			break
		}
	}
	return found
}
// getApiTokenRequestCount reads the per-token counter map stored under key in
// shared data. A missing key or nil payload yields an empty map, not an
// error. The returned CAS value is needed for a subsequent guarded write.
func getApiTokenRequestCount(key string) (map[string]int64, uint32, error) {
	data, cas, err := proxywasm.GetSharedData(key)
	if errors.Is(err, types.ErrorStatusNotFound) {
		return map[string]int64{}, cas, nil
	}
	if err != nil {
		return nil, 0, err
	}
	if data == nil {
		return map[string]int64{}, cas, nil
	}
	counts := map[string]int64{}
	if err := json.Unmarshal(data, &counts); err != nil {
		return nil, 0, err
	}
	return counts, cas, nil
}
// addApiTokenRequestCount increments the counter for apiToken under key.
func addApiTokenRequestCount(key, apiToken string, log wrapper.Log) {
	modifyApiTokenRequestCount(key, apiToken, addApiTokenRequestCountOperation, log)
}

// resetApiTokenRequestCount deletes the counter for apiToken under key.
func resetApiTokenRequestCount(key, apiToken string, log wrapper.Log) {
	modifyApiTokenRequestCount(key, apiToken, resetApiTokenRequestCountOperation, log)
}
// ResetApiTokenRequestFailureCount clears the failure counter for the token
// used by the current request; called when the request succeeds so that
// intermittent failures do not accumulate toward the failover threshold.
func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string, log wrapper.Log) {
	if c.isFailoverEnabled() {
		failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount)
		if err != nil {
			// Best-effort: log and fall through (a nil map read below is safe).
			log.Errorf("failed to get failureApiTokenRequestCount: %v", err)
		}
		if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok {
			log.Infof("reset apiToken %s request failure count", apiTokenInUse)
			resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse, log)
		}
	}
}
// modifyApiTokenRequestCount applies op to apiToken's counter in the
// shared-data map stored under key: resetApiTokenRequestCountOperation
// deletes the entry, any other op increments it. The read-modify-write is
// retried with CAS up to casMaxRetries times to tolerate concurrent writers.
func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log) {
	for attempt := 1; attempt <= casMaxRetries; attempt++ {
		apiTokenRequestCount, cas, err := getApiTokenRequestCount(key)
		if err != nil {
			log.Errorf("Failed to get %s: %v", key, err)
			continue
		}
		if op == resetApiTokenRequestCountOperation {
			delete(apiTokenRequestCount, apiToken)
		} else {
			apiTokenRequestCount[apiToken]++
		}
		apiTokenRequestCountByte, err := json.Marshal(apiTokenRequestCount)
		if err != nil {
			// Do not write garbage/nil on a marshal failure; retrying with
			// the same data would fail identically, so give up.
			log.Errorf("failed to marshal apiTokenRequestCount: %v", err)
			return
		}
		if err := proxywasm.SetSharedData(key, apiTokenRequestCountByte, cas); err == nil {
			log.Debugf("Successfully updated the count of %s in %s", apiToken, key)
			return
		} else if !errors.Is(err, types.ErrorStatusCasMismatch) {
			log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
			return
		}
		log.Errorf("CAS mismatch when setting %s, retrying...", key)
	}
	// All CAS attempts were exhausted; surface the failure instead of
	// returning silently.
	log.Errorf("Failed to update %s for %s after %d attempts", key, apiToken, casMaxRetries)
}
// initApiTokens seeds the shared-data token list with the statically
// configured apiTokens. CAS 0 forces the write regardless of prior state.
func (c *ProviderConfig) initApiTokens() error {
	return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, 0)
}
// GetGlobalRandomToken returns a random token from the globally shared list
// of currently available apiTokens (maintained via shared data by the
// failover mechanism), or "" when the list cannot be read or is empty.
func (c *ProviderConfig) GetGlobalRandomToken(log wrapper.Log) string {
	apiTokens, _, err := getApiTokens(c.failover.ctxApiTokens)
	if err != nil {
		// Bug fix: previously this error was overwritten by the next call's
		// error, so a failed read of the available list could be masked.
		return ""
	}
	// The unavailable list is fetched for debug logging only; selection is
	// made purely from the available list, so a read failure here is not
	// fatal.
	unavailableApiTokens, _, unavailableErr := getApiTokens(c.failover.ctxUnavailableApiTokens)
	if unavailableErr != nil {
		log.Errorf("failed to get unavailable apiTokens: %v", unavailableErr)
	}
	log.Debugf("apiTokens: %v, unavailableApiTokens: %v", apiTokens, unavailableApiTokens)
	switch count := len(apiTokens); count {
	case 0:
		return ""
	case 1:
		return apiTokens[0]
	default:
		return apiTokens[rand.Intn(count)]
	}
}
// isFailoverEnabled reports whether apiToken failover is enabled in this
// provider's configuration.
func (c *ProviderConfig) isFailoverEnabled() bool {
	return c.failover.enabled
}
// resetSharedData clears every failover-related shared-data entry so that a
// fresh configuration starts from a clean slate. Errors are deliberately
// ignored: a missing entry is already "cleared".
func (c *ProviderConfig) resetSharedData() {
	keys := []string{
		c.failover.ctxVmLease,
		c.failover.ctxApiTokens,
		c.failover.ctxUnavailableApiTokens,
		c.failover.ctxApiTokenRequestSuccessCount,
		c.failover.ctxApiTokenRequestFailureCount,
	}
	for _, key := range keys {
		_ = proxywasm.SetSharedData(key, nil, 0)
	}
}
// OnRequestFailed records a request failure for apiTokenInUse when failover
// is enabled, so that repeatedly failing tokens can be marked unavailable.
// It is a no-op when failover is disabled.
func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) {
	if c.isFailoverEnabled() {
		c.handleUnavailableApiToken(ctx, apiTokenInUse, log)
	}
}
// GetApiTokenInUse returns the apiToken selected for the current request by
// SetApiTokenInUse, or "" if no token was stored in the request context.
func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
	// Comma-ok form: a bare type assertion would panic when the context
	// value is missing (nil) instead of a string.
	token, _ := ctx.GetContext(c.failover.ctxApiTokenInUse).(string)
	return token
}
// SetApiTokenInUse chooses the apiToken for the current request and stores
// it in the request context for later retrieval via GetApiTokenInUse.
// With failover enabled only tokens from the shared available list are
// candidates; otherwise a token is drawn from the full configured list.
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
	var chosen string
	switch {
	case c.isFailoverEnabled():
		// If apiToken failover is enabled, only use an available apiToken.
		chosen = c.GetGlobalRandomToken(log)
	default:
		chosen = c.GetRandomToken()
	}
	log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", chosen)
	ctx.SetContext(c.failover.ctxApiTokenInUse, chosen)
}
// setHealthCheckEndpoint records the host, path and cluster of the current
// request in shared data so the failover health checker can later send an
// equivalent probe request for this provider.
func (c *ProviderConfig) setHealthCheckEndpoint(ctx wrapper.HttpContext, log wrapper.Log) {
	cluster, err := proxywasm.GetProperty([]string{"cluster_name"})
	if err != nil {
		// Best effort: an empty cluster is recorded rather than aborting.
		log.Errorf("Failed to get cluster_name: %v", err)
	}
	host := wrapper.GetRequestHost()
	if host == "" {
		// Fall back to the value cached on the context. Comma-ok avoids the
		// panic a bare assertion would cause when the entry was never set.
		host, _ = ctx.GetContext(ctxRequestHost).(string)
	}
	path := wrapper.GetRequestPath()
	if path == "" {
		path, _ = ctx.GetContext(ctxRequestPath).(string)
	}
	healthCheckEndpoint := HealthCheckEndpoint{
		Host:    host,
		Path:    path,
		Cluster: string(cluster),
	}
	healthCheckEndpointByte, err := json.Marshal(healthCheckEndpoint)
	if err != nil {
		// Do not overwrite a previously stored endpoint with bad data.
		log.Errorf("Failed to marshal request host and path: %v", err)
		return
	}
	if err := proxywasm.SetSharedData(c.failover.ctxHealthCheckEndpoint, healthCheckEndpointByte, 0); err != nil {
		log.Errorf("Failed to set request host and path: %v", err)
	}
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
@@ -17,8 +18,11 @@ import (
// geminiProvider is the provider for google gemini/gemini flash service.
const (
geminiApiKeyHeader = "x-goog-api-key"
geminiDomain = "generativelanguage.googleapis.com"
geminiApiKeyHeader = "x-goog-api-key"
geminiDomain = "generativelanguage.googleapis.com"
geminiChatCompletionPath = "generateContent"
geminiChatCompletionStreamPath = "streamGenerateContent?alt=sse"
geminiEmbeddingPath = "batchEmbedContents"
)
type geminiProviderInitializer struct {
@@ -51,157 +55,56 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
_ = proxywasm.ReplaceHttpRequestHeader(geminiApiKeyHeader, g.config.GetRandomToken())
_ = util.OverwriteRequestHost(geminiDomain)
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
g.config.handleRequestHeaders(g, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, geminiDomain)
headers.Add(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
if apiName == ApiNameChatCompletion {
return g.onChatCompletionRequestBody(ctx, body, log)
} else if apiName == ApiNameEmbeddings {
return g.onEmbeddingsRequestBody(ctx, body, log)
return g.onChatCompletionRequestBody(ctx, body, headers, log)
} else {
return g.onEmbeddingsRequestBody(ctx, body, headers, log)
}
return types.ActionContinue, errUnsupportedApiName
}
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
// 使用gemini接口协议
if g.config.protocol == protocolOriginal {
request := &geminiChatRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("request model is empty")
}
// 根据模型重写requestPath
path := g.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
_ = util.OverwriteRequestPath(path)
// 移除多余的model和stream字段
request = &geminiChatRequest{
Contents: request.Contents,
SafetySettings: request.SafetySettings,
GenerationConfig: request.GenerationConfig,
Tools: request.Tools,
}
if g.config.context == nil {
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
err := g.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.gemini.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
g.setSystemContent(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.gemini.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
err := g.config.parseRequestAndMapModel(ctx, request, body, log)
if err != nil {
return nil, err
}
path := g.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
util.OverwriteRequestPathHeader(headers, path)
// 映射模型重写requestPath
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, g.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
path := g.getRequestPath(ApiNameChatCompletion, mappedModel, request.Stream)
_ = util.OverwriteRequestPath(path)
if g.config.context == nil {
geminiRequest := g.buildGeminiChatRequest(request)
return types.ActionContinue, replaceJsonRequestBody(geminiRequest, log)
}
err := g.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.gemini.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
geminiRequest := g.buildGeminiChatRequest(request)
if err := replaceJsonRequestBody(geminiRequest, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.gemini.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
geminiRequest := g.buildGeminiChatRequest(request)
return json.Marshal(geminiRequest)
}
func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
// 使用gemini接口协议
if g.config.protocol == protocolOriginal {
request := &geminiBatchEmbeddingRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("request model is empty")
}
// 根据模型重写requestPath
path := g.getRequestPath(ApiNameEmbeddings, request.Model, false)
_ = util.OverwriteRequestPath(path)
// 移除多余的model字段
request = &geminiBatchEmbeddingRequest{
Requests: request.Requests,
}
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &embeddingsRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
if err := g.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
}
// 映射模型重写requestPath
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in embeddings request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, g.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
path := g.getRequestPath(ApiNameEmbeddings, mappedModel, false)
_ = util.OverwriteRequestPath(path)
path := g.getRequestPath(ApiNameEmbeddings, request.Model, false)
util.OverwriteRequestPathHeader(headers, path)
geminiRequest := g.buildBatchEmbeddingRequest(request)
return types.ActionContinue, replaceJsonRequestBody(geminiRequest, log)
return json.Marshal(geminiRequest)
}
func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -285,11 +188,11 @@ func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body
func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string {
action := ""
if apiName == ApiNameEmbeddings {
action = "batchEmbedContents"
action = geminiEmbeddingPath
} else if stream {
action = "streamGenerateContent?alt=sse"
action = geminiChatCompletionStreamPath
} else {
action = "generateContent"
action = geminiChatCompletionPath
}
return fmt.Sprintf("/v1/models/%s:%s", geminiModel, action)
}
@@ -605,3 +508,13 @@ func (g *geminiProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, gemini
func (g *geminiProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
func (g *geminiProvider) GetApiName(path string) ApiName {
if strings.Contains(path, geminiChatCompletionPath) || strings.Contains(path, geminiChatCompletionStreamPath) {
return ApiNameChatCompletion
}
if strings.Contains(path, geminiEmbeddingPath) {
return ApiNameEmbeddings
}
return ""
}

View File

@@ -1,14 +1,12 @@
package provider
import (
"encoding/json"
"errors"
"fmt"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"strings"
)
// githubProvider is the provider for GitHub OpenAI service.
@@ -48,16 +46,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(githubDomain)
if apiName == ApiNameChatCompletion {
_ = util.OverwriteRequestPath(githubCompletionPath)
}
if apiName == ApiNameEmbeddings {
_ = util.OverwriteRequestPath(githubEmbeddingPath)
}
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken())
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
@@ -66,47 +55,28 @@ func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, githubDomain)
if apiName == ApiNameChatCompletion {
return m.onChatCompletionRequestBody(ctx, body, log)
util.OverwriteRequestPathHeader(headers, githubCompletionPath)
}
if apiName == ApiNameEmbeddings {
return m.onEmbeddingsRequestBody(ctx, body, log)
util.OverwriteRequestPathHeader(headers, githubEmbeddingPath)
}
return types.ActionContinue, errUnsupportedApiName
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}
func (m *githubProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
func (m *githubProvider) GetApiName(path string) ApiName {
if strings.Contains(path, githubCompletionPath) {
return ApiNameChatCompletion
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
if strings.Contains(path, githubEmbeddingPath) {
return ApiNameEmbeddings
}
// 映射模型
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
func (m *githubProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
request := &embeddingsRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in embeddings request")
}
// 映射模型
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
return ""
}

View File

@@ -2,11 +2,11 @@ package provider
import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -18,14 +18,14 @@ const (
type groqProviderInitializer struct{}
func (m *groqProviderInitializer) ValidateConfig(config ProviderConfig) error {
func (g *groqProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}
func (m *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
func (g *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &groqProvider{
config: config,
contextCache: createContextCache(&config),
@@ -37,47 +37,35 @@ type groqProvider struct {
contextCache *contextCache
}
func (m *groqProvider) GetProviderType() string {
func (g *groqProvider) GetProviderType() string {
return providerTypeGroq
}
func (m *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(groqChatCompletionPath)
_ = util.OverwriteRequestHost(groqDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
g.config.handleRequestHeaders(g, ctx, apiName, log)
return types.ActionContinue, nil
}
func (m *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.groq.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.groq.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}
func (g *groqProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, groqChatCompletionPath)
util.OverwriteRequestHostHeader(headers, groqDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
func (g *groqProvider) GetApiName(path string) ApiName {
if strings.Contains(path, groqChatCompletionPath) {
return ApiNameChatCompletion
}
return ""
}

View File

@@ -8,6 +8,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
@@ -114,26 +115,27 @@ func (m *hunyuanProvider) GetProviderType() string {
}
func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
// log.Debugf("hunyuanProvider.OnRequestHeaders called! hunyunSecretKey/id is: %s/%s", m.config.hunyuanAuthKey, m.config.hunyuanAuthId)
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(hunyuanDomain)
_ = util.OverwriteRequestPath(hunyuanRequestPath)
// 添加hunyuan需要的自定义字段
_ = proxywasm.ReplaceHttpRequestHeader(actionKey, hunyuanChatCompletionTCAction)
_ = proxywasm.ReplaceHttpRequestHeader(versionKey, versionValue)
// 删除一些字段
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, hunyuanDomain)
util.OverwriteRequestPathHeader(headers, hunyuanRequestPath)
// 添加 hunyuan 需要的自定义字段
headers.Add(actionKey, hunyuanChatCompletionTCAction)
headers.Add(versionKey, versionValue)
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}
// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法
func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
@@ -142,7 +144,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
// 为header添加时间戳字段 因为需要根据body进行签名时依赖时间戳故于body处理部分创建时间戳
var timestamp int64 = time.Now().Unix()
_ = proxywasm.ReplaceHttpRequestHeader(timestampKey, fmt.Sprintf("%d", timestamp))
// log.Debugf("#debug nash5# OnRequestBody set timestamp header: ", timestamp)
// 使用混元本身接口的协议
if m.config.protocol == protocolOriginal {
@@ -198,7 +199,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
// log.Debugf("#debug nash5# OnRequestBody call hunyuan api using openai's api!")
model := request.Model
if model == "" {
@@ -235,18 +235,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
string(body),
)
_ = util.OverwriteRequestAuthorization(authorizedValueNew)
// log.Debugf("#debug nash5# OnRequestBody done, body is: ", string(body))
// // 打印所有的headers
// headers, err2 := proxywasm.GetHttpRequestHeaders()
// if err2 != nil {
// log.Errorf("failed to get request headers: %v", err2)
// } else {
// // 迭代并打印所有请求头
// for _, header := range headers {
// log.Infof("#debug nash5# inB Request header - %s: %s", header[0], header[1])
// }
// }
return types.ActionContinue, replaceJsonRequestBody(hunyuanRequest, log)
}
@@ -277,6 +265,32 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
return types.ActionContinue, err
}
// hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用
func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
err := m.config.parseRequestAndMapModel(ctx, request, body, log)
if err != nil {
return nil, err
}
hunyuanRequest := m.buildHunyuanTextGenerationRequest(request)
var timestamp int64 = time.Now().Unix()
_ = proxywasm.ReplaceHttpRequestHeader(timestampKey, fmt.Sprintf("%d", timestamp))
// 根据确定好的payload进行签名
body, _ = json.Marshal(hunyuanRequest)
authorizedValueNew := GetTC3Authorizationcode(
m.config.hunyuanAuthId,
m.config.hunyuanAuthKey,
timestamp,
hunyuanDomain,
hunyuanChatCompletionTCAction,
string(body),
)
util.OverwriteRequestAuthorizationHeader(headers, authorizedValueNew)
return json.Marshal(hunyuanRequest)
}
func (m *hunyuanProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
@@ -561,3 +575,7 @@ func GetTC3Authorizationcode(secretId string, secretKey string, timestamp int64,
// fmt.Println(curl)
return authorization
}
func (m *hunyuanProvider) GetApiName(path string) ApiName {
return ApiNameChatCompletion
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
@@ -78,14 +79,17 @@ func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(minimaxDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, minimaxDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
@@ -107,51 +111,16 @@ func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
return m.handleRequestBodyByChatCompletionPro(body, log)
} else {
// 使用ChatCompletion v2接口
return m.handleRequestBodyByChatCompletionV2(body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
}
func (m *minimaxProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
return m.handleRequestBodyByChatCompletionV2(body, headers, log)
}
// handleRequestBodyByChatCompletionPro 使用ChatCompletion Pro接口处理请求体
func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log wrapper.Log) (types.Action, error) {
// 使用minimax接口协议
if m.config.protocol == protocolOriginal {
request := &minimaxChatCompletionV2Request{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("request model is empty")
}
// 根据模型重写requestPath
if m.config.minimaxGroupId == "" {
return types.ActionContinue, errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when use %s model ", request.Model))
}
_ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))
if m.config.context == nil {
return types.ActionContinue, nil
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
m.setBotSettings(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
@@ -174,6 +143,9 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
// 由于 minimaxChatCompletionV2格式和 OpenAI 一致)和 minimaxChatCompletionPro格式和 OpenAI 不一致)中 insertHttpContextMessage 的逻辑不同,无法做到同一个 provider 统一
// 因此对于 minimaxChatCompletionPro 需要手动处理 context 消息
// minimaxChatCompletionV2 交给默认的 defaultInsertHttpContextMessage 方法插入 context 消息
minimaxRequest := m.buildMinimaxChatCompletionV2Request(request, content)
if err := replaceJsonRequestBody(minimaxRequest, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace Request body: %v", err))
@@ -186,37 +158,17 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
}
// handleRequestBodyByChatCompletionV2 使用ChatCompletion v2接口处理请求体
func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, log wrapper.Log) (types.Action, error) {
func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
return nil, err
}
// 映射模型重写requestPath
request.Model = getMappedModel(request.Model, m.config.modelMapping, log)
_ = util.OverwriteRequestPath(minimaxChatCompletionV2Path)
util.OverwriteRequestPathHeader(headers, minimaxChatCompletionV2Path)
if m.contextCache == nil {
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return body, nil
}
func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -474,3 +426,10 @@ func (m *minimaxProvider) responseV2ToOpenAI(response *minimaxChatCompletionV2Re
func (m *minimaxProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
func (m *minimaxProvider) GetApiName(path string) ApiName {
if strings.Contains(path, minimaxChatCompletionV2Path) || strings.Contains(path, minimaxChatCompletionProPath) {
return ApiNameChatCompletion
}
return ""
}

View File

@@ -2,12 +2,10 @@ package provider
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
)
const (
@@ -43,9 +41,7 @@ func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(mistralDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -53,28 +49,11 @@ func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.mistral.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.mistral.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
// TransformRequestHeaders rewrites the outgoing headers for Mistral: it points
// Host at the Mistral API domain, installs the API token currently selected
// for this request as a bearer credential, and drops Content-Length because
// the body may be rewritten later, invalidating the original length.
func (m *mistralProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	token := m.config.GetApiTokenInUse(ctx)
	util.OverwriteRequestHostHeader(headers, mistralDomain)
	util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+token)
	headers.Del("Content-Length")
}

View File

@@ -3,13 +3,12 @@ package provider
import (
"errors"
"fmt"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"net/http"
)
// moonshotProvider is the provider for Moonshot AI service.
@@ -58,33 +57,29 @@ func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(moonshotChatCompletionPath)
_ = util.OverwriteRequestHost(moonshotDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
// TransformRequestHeaders rewrites the outgoing headers for Moonshot: the
// request is redirected to the Moonshot chat-completion endpoint, the bearer
// token selected for this request is installed, and Content-Length is dropped
// since the body may be rewritten afterwards.
func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	token := m.config.GetApiTokenInUse(ctx)
	util.OverwriteRequestHostHeader(headers, moonshotDomain)
	util.OverwriteRequestPathHeader(headers, moonshotChatCompletionPath)
	util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+token)
	headers.Del("Content-Length")
}
// moonshot 有自己获取 context 的配置moonshotFileId因此无法复用 handleRequestBody 方法
// moonshot 的 body 没有修改无须实现TransformRequestBody使用默认的 defaultTransformRequestBody 方法
func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return types.ActionContinue, err
}
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
mappedModel := getMappedModel(model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
if m.config.moonshotFileId == "" && m.contextCache == nil {
return types.ActionContinue, replaceJsonRequestBody(request, log)
}

View File

@@ -3,11 +3,10 @@ package provider
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
)
// ollamaProvider is the provider for Ollama service.
@@ -53,10 +52,7 @@ func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(ollamaChatCompletionPath)
_ = util.OverwriteRequestHost(m.serviceDomain)
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -64,51 +60,11 @@ func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.config.modelMapping == nil && m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
mappedModel := getMappedModel(model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
if m.contextCache != nil {
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.ollama.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.ollama.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
} else {
return types.ActionContinue, err
}
} else {
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.ollama.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
return types.ActionContinue, err
}
_ = proxywasm.ResumeHttpRequest()
return types.ActionPause, nil
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
// TransformRequestHeaders redirects the request to the configured Ollama
// service (host comes from provider config, path is the fixed chat-completion
// endpoint) and drops Content-Length since the body may be rewritten later.
// Ollama requires no Authorization header.
func (m *ollamaProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	util.OverwriteRequestHostHeader(headers, m.serviceDomain)
	util.OverwriteRequestPathHeader(headers, ollamaChatCompletionPath)
	headers.Del("Content-Length")
}

View File

@@ -1,12 +1,13 @@
package provider
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -57,27 +58,31 @@ func (m *openaiProvider) GetProviderType() string {
}
func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
if m.customPath == "" {
switch apiName {
case ApiNameChatCompletion:
_ = util.OverwriteRequestPath(defaultOpenaiChatCompletionPath)
util.OverwriteRequestPathHeader(headers, defaultOpenaiChatCompletionPath)
case ApiNameEmbeddings:
ctx.DontReadRequestBody()
_ = util.OverwriteRequestPath(defaultOpenaiEmbeddingsPath)
util.OverwriteRequestPathHeader(headers, defaultOpenaiEmbeddingsPath)
}
} else {
_ = util.OverwriteRequestPath(m.customPath)
util.OverwriteRequestPathHeader(headers, m.customPath)
}
if m.customDomain == "" {
_ = util.OverwriteRequestHost(defaultOpenaiDomain)
util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain)
} else {
_ = util.OverwriteRequestHost(m.customDomain)
util.OverwriteRequestHostHeader(headers, m.customDomain)
}
if len(m.config.apiTokens) > 0 {
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
}
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
return types.ActionContinue, nil
headers.Del("Content-Length")
}
func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -85,9 +90,13 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
// We don't need to process the request body for other APIs.
return types.ActionContinue, nil
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
return nil, err
}
if m.config.responseJsonSchema != nil {
log.Debugf("[ai-proxy] set response format to %s", m.config.responseJsonSchema)
@@ -101,27 +110,5 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
request.StreamOptions.IncludeUsage = true
}
}
if m.contextCache == nil {
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
return types.ActionContinue, nil
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.openai.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.openai.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return json.Marshal(request)
}

View File

@@ -1,14 +1,17 @@
package provider
import (
"encoding/json"
"errors"
"math/rand"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)
type ApiName string
@@ -110,14 +113,32 @@ type Provider interface {
GetProviderType() string
}
// ApiNameHandler resolves which logical API a raw request path belongs to.
type ApiNameHandler interface {
	GetApiName(path string) ApiName
}
// RequestHeadersHandler is implemented by providers that need to react to the
// request headers phase before the body is processed.
type RequestHeadersHandler interface {
	OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
}
// TransformRequestHeadersHandler is implemented by providers that rewrite the
// outgoing request headers (host, path, authorization, ...) in place.
type TransformRequestHeadersHandler interface {
	TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
}
// RequestBodyHandler is implemented by providers that need to react to the
// request body phase.
type RequestBodyHandler interface {
	OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
}
// TransformRequestBodyHandler is implemented by providers that rewrite the
// request body only; the transformed body is returned rather than applied.
type TransformRequestBodyHandler interface {
	TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
}
// TransformRequestBodyHeadersHandler allows to transform request headers based on the request body.
// Some providers (e.g. baidu, gemini) transform request headers (e.g., path) based on the request body (e.g., model).
type TransformRequestBodyHeadersHandler interface {
	TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
}
// ResponseHeadersHandler is implemented by providers that need to react to the
// response headers phase.
type ResponseHeadersHandler interface {
	OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
}
@@ -143,6 +164,9 @@ type ProviderConfig struct {
// @Title zh-CN 请求超时
// @Description zh-CN 请求AI服务的超时时间单位为毫秒。默认值为120000即2分钟
timeout uint32 `required:"false" yaml:"timeout" json:"timeout"`
// @Title zh-CN apiToken 故障切换
// @Description zh-CN 当 apiToken 不可用时移出 apiTokens 列表,对移除的 apiToken 进行健康检查,当重新可用后加回 apiTokens 列表
failover *failover `required:"false" yaml:"failover" json:"failover"`
// @Title zh-CN 基于OpenAI协议的自定义后端URL
// @Description zh-CN 仅适用于支持 openai 协议的服务。
openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"`
@@ -289,6 +313,14 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
}
}
}
failoverJson := json.Get("failover")
c.failover = &failover{
enabled: false,
}
if failoverJson.Exists() {
c.failover.FromJson(failoverJson)
}
}
func (c *ProviderConfig) Validate() error {
@@ -304,6 +336,12 @@ func (c *ProviderConfig) Validate() error {
}
}
if c.failover.enabled {
if err := c.failover.Validate(); err != nil {
return err
}
}
if c.typ == "" {
return errors.New("missing type in provider config")
}
@@ -355,6 +393,60 @@ func CreateProvider(pc ProviderConfig) (Provider, error) {
return initializer.CreateProvider(pc)
}
// parseRequestAndMapModel decodes body into the given request struct and then
// rewrites its model name through the configured model mapping. For chat
// completion requests that ask for streaming, the Accept header is forced to
// text/event-stream. Only *chatCompletionRequest and *embeddingsRequest are
// supported; anything else is an error.
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte, log wrapper.Log) error {
	switch r := request.(type) {
	case *chatCompletionRequest:
		err := decodeChatCompletionRequest(body, r)
		if err != nil {
			return err
		}
		if r.Stream {
			// Streamed completions are delivered as server-sent events.
			_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
		}
		return c.setRequestModel(ctx, r, log)
	case *embeddingsRequest:
		err := decodeEmbeddingsRequest(body, r)
		if err != nil {
			return err
		}
		return c.setRequestModel(ctx, r, log)
	default:
		return errors.New("unsupported request type")
	}
}
// setRequestModel locates the Model field of a supported request type and
// rewrites it in place through the configured model mapping.
func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}, log wrapper.Log) error {
	switch r := request.(type) {
	case *chatCompletionRequest:
		return c.mapModel(ctx, &r.Model, log)
	case *embeddingsRequest:
		return c.mapModel(ctx, &r.Model, log)
	default:
		return errors.New("unsupported request type")
	}
}
// mapModel validates the model name pointed to by model, records the original
// name in the request context, rewrites the name via the configured mapping
// table, and records the final name as well. Both an empty input name and an
// empty mapping result are errors.
func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string, log wrapper.Log) error {
	original := *model
	if original == "" {
		return errors.New("missing model in request")
	}
	ctx.SetContext(ctxKeyOriginalRequestModel, original)
	mapped := getMappedModel(original, c.modelMapping, log)
	if mapped == "" {
		return errors.New("model becomes empty after applying the configured mapping")
	}
	*model = mapped
	ctx.SetContext(ctxKeyFinalRequestModel, mapped)
	return nil
}
func getMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
mappedModel := doGetMappedModel(model, modelMapping, log)
if len(mappedModel) != 0 {
@@ -391,3 +483,62 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
return ""
}
// handleRequestBody drives the shared request-body pipeline: it lets the
// provider transform the body (and optionally the headers), then either kicks
// off asynchronous context-file injection for chat completions — pausing the
// filter chain — or writes the transformed body back onto the request.
// When the original (non-OpenAI) protocol is configured, the body passes
// through untouched.
func (c *ProviderConfig) handleRequestBody(
	provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log,
) (types.Action, error) {
	// use original protocol
	if c.protocol == protocolOriginal {
		return types.ActionContinue, nil
	}
	// use openai protocol: prefer the body-only hook, then the body+headers
	// hook, and fall back to the default transformation.
	var err error
	switch handler := provider.(type) {
	case TransformRequestBodyHandler:
		body, err = handler.TransformRequestBody(ctx, apiName, body, log)
	case TransformRequestBodyHeadersHandler:
		headers := util.GetOriginalHttpHeaders()
		body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log)
		util.ReplaceOriginalHttpHeaders(headers)
	default:
		body, err = c.defaultTransformRequestBody(ctx, apiName, body, log)
	}
	if err != nil {
		return types.ActionContinue, err
	}
	// Context injection only applies to chat completions with a configured context.
	if apiName == ApiNameChatCompletion && c.context != nil {
		if err = contextCache.GetContextFromFile(ctx, provider, body, log); err != nil {
			return types.ActionContinue, err
		}
		// Pause while the context file is fetched asynchronously.
		return types.ActionPause, nil
	}
	return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
}
// handleRequestHeaders gives the provider a chance to rewrite the request
// headers. Providers that do not implement TransformRequestHeadersHandler are
// left untouched.
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) {
	handler, ok := provider.(TransformRequestHeadersHandler)
	if !ok {
		return
	}
	headers := util.GetOriginalHttpHeaders()
	handler.TransformRequestHeaders(ctx, apiName, headers, log)
	util.ReplaceOriginalHttpHeaders(headers)
}
// defaultTransformRequestBody is the fallback body transformation used when a
// provider supplies no transform hook of its own: it decodes the body into the
// matching OpenAI-style request struct, applies model mapping, and re-encodes
// the result as JSON.
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
	var request interface{} = &embeddingsRequest{}
	if apiName == ApiNameChatCompletion {
		request = &chatCompletionRequest{}
	}
	if err := c.parseRequestAndMapModel(ctx, request, body, log); err != nil {
		return nil, err
	}
	return json.Marshal(request)
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"math"
"net/http"
"reflect"
"strings"
"time"
@@ -58,35 +59,50 @@ func (m *qwenProviderInitializer) CreateProvider(config ProviderConfig) (Provide
}
type qwenProvider struct {
config ProviderConfig
config ProviderConfig
contextCache *contextCache
}
// TransformRequestHeaders points the request at the Qwen (DashScope) endpoint,
// installs the bearer token selected for this request, and chooses the path:
// the single compatible-mode path when enabled, otherwise the per-API path.
// Accept-Encoding and Content-Length are dropped since the body and encoding
// may change downstream.
func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	util.OverwriteRequestHostHeader(headers, qwenDomain)
	token := m.config.GetApiTokenInUse(ctx)
	util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+token)
	switch {
	case m.config.qwenEnableCompatible:
		// OpenAI-compatible mode uses one path for every API.
		util.OverwriteRequestPathHeader(headers, qwenCompatiblePath)
	case apiName == ApiNameChatCompletion:
		util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath)
	case apiName == ApiNameEmbeddings:
		util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath)
	}
	headers.Del("Accept-Encoding")
	headers.Del("Content-Length")
}
// TransformRequestBodyHeaders dispatches body transformation by API kind:
// chat completions may also rewrite headers (e.g. switching to the multimodal
// path), while embeddings rewrite only the body.
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
	if apiName == ApiNameChatCompletion {
		return m.onChatCompletionRequestBody(ctx, body, headers, log)
	}
	return m.onEmbeddingsRequestBody(ctx, body, log)
}
// GetProviderType returns the canonical provider type identifier for Qwen.
func (m *qwenProvider) GetProviderType() string {
	return providerTypeQwen
}
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = util.OverwriteRequestHost(qwenDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
if m.config.protocol == protocolOriginal {
ctx.DontReadRequestBody()
return types.ActionContinue, nil
} else if m.config.qwenEnableCompatible {
_ = util.OverwriteRequestPath(qwenCompatiblePath)
} else if apiName == ApiNameChatCompletion {
_ = util.OverwriteRequestPath(qwenChatCompletionPath)
} else if apiName == ApiNameEmbeddings {
_ = util.OverwriteRequestPath(qwenTextEmbeddingPath)
} else {
return types.ActionContinue, errUnsupportedApiName
}
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
@@ -121,65 +137,23 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
}
return types.ActionContinue, nil
}
if apiName == ApiNameChatCompletion {
return m.onChatCompletionRequestBody(ctx, body, log)
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
if apiName == ApiNameEmbeddings {
return m.onEmbeddingsRequestBody(ctx, body, log)
}
return types.ActionContinue, errUnsupportedApiName
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
if m.config.protocol == protocolOriginal {
if m.config.context == nil {
return types.ActionContinue, nil
}
request := &qwenTextGenRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.qwen.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
m.insertContextMessage(request, content, false)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.qwen.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
err := m.config.parseRequestAndMapModel(ctx, request, body, log)
if err != nil {
return nil, err
}
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
// Use the qwen multimodal model generation API
if strings.HasPrefix(request.Model, qwenVlModelPrefixName) {
_ = util.OverwriteRequestPath(qwenMultimodalGenerationPath)
util.OverwriteRequestPathHeader(headers, qwenMultimodalGenerationPath)
}
streaming := request.Stream
@@ -191,62 +165,20 @@ func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body
_ = proxywasm.RemoveHttpRequestHeader("X-DashScope-SSE")
}
if m.config.context == nil {
qwenRequest := m.buildQwenTextGenerationRequest(request, streaming)
if streaming {
ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput)
}
return types.ActionContinue, replaceJsonRequestBody(qwenRequest, log)
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.qwen.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
qwenRequest := m.buildQwenTextGenerationRequest(request, streaming)
if streaming {
ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput)
}
if err := replaceJsonRequestBody(qwenRequest, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.qwen.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.buildQwenTextGenerationRequest(ctx, request, streaming)
}
func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
request := &embeddingsRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
}
log.Debugf("=== embeddings request: %v", request)
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in the request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
if qwenRequest, err := m.buildQwenTextEmbeddingRequest(request); err == nil {
return types.ActionContinue, replaceJsonRequestBody(qwenRequest, log)
} else {
return types.ActionContinue, err
qwenRequest, err := m.buildQwenTextEmbeddingRequest(request)
if err != nil {
return nil, err
}
return json.Marshal(qwenRequest)
}
func (m *qwenProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -375,7 +307,7 @@ func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []
return types.ActionContinue, replaceJsonResponseBody(response, log)
}
func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletionRequest, streaming bool) *qwenTextGenRequest {
func (m *qwenProvider) buildQwenTextGenerationRequest(ctx wrapper.HttpContext, origRequest *chatCompletionRequest, streaming bool) ([]byte, error) {
messages := make([]qwenMessage, 0, len(origRequest.Messages))
for i := range origRequest.Messages {
messages = append(messages, chatMessage2QwenMessage(origRequest.Messages[i]))
@@ -397,6 +329,11 @@ func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletio
Tools: origRequest.Tools,
},
}
if streaming {
ctx.SetContext(ctxKeyIncrementalStreaming, request.Parameters.IncrementalOutput)
}
if len(m.config.qwenFileIds) != 0 && origRequest.Model == qwenLongModelName {
builder := strings.Builder{}
for _, fileId := range m.config.qwenFileIds {
@@ -406,13 +343,15 @@ func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletio
builder.WriteString("fileid://")
builder.WriteString(fileId)
}
contextMessageId := m.insertContextMessage(request, builder.String(), true)
if contextMessageId == 0 {
// The context message cannot come first. We need to add another dummy system message before it.
request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...)
body, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("unable to marshal request: %v", err)
}
return m.insertHttpContextMessage(body, builder.String(), true)
}
return request
return json.Marshal(request)
}
func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse) *chatCompletionResponse {
@@ -569,7 +508,12 @@ func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuild
return nil
}
func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content string, onlyOneSystemBeforeFile bool) int {
func (m *qwenProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) {
request := &qwenTextGenRequest{}
if err := json.Unmarshal(body, request); err != nil {
return nil, fmt.Errorf("unable to unmarshal request: %v", err)
}
fileMessage := qwenMessage{
Role: roleSystem,
Content: content,
@@ -586,10 +530,8 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content
}
if firstNonSystemMessageIndex == 0 {
request.Input.Messages = append([]qwenMessage{fileMessage}, request.Input.Messages...)
return 0
} else if !onlyOneSystemBeforeFile {
request.Input.Messages = append(request.Input.Messages[:firstNonSystemMessageIndex], append([]qwenMessage{fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)...)
return firstNonSystemMessageIndex
} else {
builder := strings.Builder{}
for _, message := range request.Input.Messages[:firstNonSystemMessageIndex] {
@@ -599,8 +541,15 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content
builder.WriteString(message.StringContent())
}
request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: builder.String()}, fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)
return 1
firstNonSystemMessageIndex = 1
}
if firstNonSystemMessageIndex == 0 {
// The context message cannot come first. We need to add another dummy system message before it.
request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...)
}
return json.Marshal(request)
}
func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) {
@@ -804,3 +753,16 @@ func chatMessage2QwenMessage(chatMessage chatMessage) qwenMessage {
}
}
}
// GetApiName maps a request path to the logical API it serves: any of the
// chat-completion, multimodal-generation, or compatible-mode paths resolve to
// ApiNameChatCompletion, the text-embedding path to ApiNameEmbeddings, and
// anything else to the empty ApiName.
func (m *qwenProvider) GetApiName(path string) ApiName {
	switch {
	case strings.Contains(path, qwenChatCompletionPath),
		strings.Contains(path, qwenMultimodalGenerationPath),
		strings.Contains(path, qwenCompatiblePath):
		return ApiNameChatCompletion
	case strings.Contains(path, qwenTextEmbeddingPath):
		return ApiNameEmbeddings
	default:
		return ""
	}
}

View File

@@ -3,7 +3,6 @@ package provider
import (
"encoding/json"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)
@@ -18,6 +17,13 @@ func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) er
return nil
}
// decodeEmbeddingsRequest parses a JSON embeddings request body into request,
// wrapping any JSON error with a uniform message.
func decodeEmbeddingsRequest(body []byte, request *embeddingsRequest) error {
	err := json.Unmarshal(body, request)
	if err != nil {
		return fmt.Errorf("unable to unmarshal request: %v", err)
	}
	return nil
}
func replaceJsonRequestBody(request interface{}, log wrapper.Log) error {
body, err := json.Marshal(request)
if err != nil {
@@ -31,6 +37,15 @@ func replaceJsonRequestBody(request interface{}, log wrapper.Log) error {
return err
}
// replaceHttpJsonRequestBody swaps the in-flight HTTP request body for body,
// logging the new payload at debug level first.
func replaceHttpJsonRequestBody(body []byte, log wrapper.Log) error {
	log.Debugf("request body: %s", string(body))
	if err := proxywasm.ReplaceHttpRequestBody(body); err != nil {
		return fmt.Errorf("unable to replace the original request body: %v", err)
	}
	return nil
}
func insertContextMessage(request *chatCompletionRequest, content string) {
fileMessage := chatMessage{
Role: roleSystem,

View File

@@ -2,8 +2,8 @@ package provider
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
@@ -71,11 +71,7 @@ func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(sparkHost)
_ = util.OverwriteRequestPath(sparkChatCompletionPath)
_ = util.OverwriteRequestAuthorization("Bearer " + p.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
p.config.handleRequestHeaders(p, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -83,36 +79,7 @@ func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
// 使用Spark协议
if p.config.protocol == protocolOriginal {
request := &sparkRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("request model is empty")
}
// 目前星火在模型名称错误时也会调用generalv3这里还是按照输入的模型名称设置响应里的模型名称
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
return types.ActionContinue, replaceJsonRequestBody(request, log)
} else {
// 使用openai协议
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
// 映射模型
mappedModel := getMappedModel(request.Model, p.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log)
}
func (p *sparkProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -205,3 +172,11 @@ func (p *sparkProvider) streamResponseSpark2OpenAI(ctx wrapper.HttpContext, resp
// appendResponse appends responseBody to responseBuilder as one stream item:
// the streamDataItemKey prefix, a space, the body, then a blank line.
func (p *sparkProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
	fmt.Fprintf(responseBuilder, "%s %s\n\n", streamDataItemKey, responseBody)
}
// TransformRequestHeaders rewrites the outgoing request headers to target the
// Spark chat-completion endpoint: path, host, and a Bearer Authorization
// header built from the API token currently in use for this request context
// (GetApiTokenInUse — presumably the token selected by the failover logic;
// confirm against the provider config implementation).
func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath)
	util.OverwriteRequestHostHeader(headers, sparkHost)
	util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx))
	// Accept-Encoding and Content-Length are dropped — NOTE(review): presumably
	// because the plugin rewrites bodies and must avoid compressed responses /
	// stale lengths; confirm the proxy recomputes Content-Length.
	headers.Del("Accept-Encoding")
	headers.Del("Content-Length")
}

View File

@@ -2,12 +2,10 @@ package provider
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
)
const (
@@ -45,10 +43,7 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(stepfunChatCompletionPath)
_ = util.OverwriteRequestHost(stepfunDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -56,28 +51,12 @@ func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.stepfun.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.stepfun.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
// TransformRequestHeaders rewrites the outgoing request headers to target the
// Stepfun chat-completion endpoint, setting path, host, and a Bearer
// Authorization header from the API token currently in use for this context.
func (m *stepfunProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	util.OverwriteRequestPathHeader(headers, stepfunChatCompletionPath)
	util.OverwriteRequestHostHeader(headers, stepfunDomain)
	util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
	// Content-Length is removed — NOTE(review): presumably because the body may
	// be rewritten later; confirm the proxy recomputes it.
	headers.Del("Content-Length")
}

View File

@@ -2,11 +2,10 @@ package provider
import (
"errors"
"fmt"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -45,10 +44,7 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(yiChatCompletionPath)
_ = util.OverwriteRequestHost(yiDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -56,28 +52,12 @@ func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, bod
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.yi.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.yi.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
// TransformRequestHeaders rewrites the outgoing request headers to target the
// 01.AI (Yi) chat-completion endpoint, setting path, host, and a Bearer
// Authorization header from the API token currently in use for this context.
func (m *yiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	util.OverwriteRequestPathHeader(headers, yiChatCompletionPath)
	util.OverwriteRequestHostHeader(headers, yiDomain)
	util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
	// Content-Length is removed — NOTE(review): presumably because the body may
	// be rewritten later; confirm the proxy recomputes it.
	headers.Del("Content-Length")
}

View File

@@ -2,11 +2,11 @@ package provider
import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -44,10 +44,7 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(zhipuAiChatCompletionPath)
_ = util.OverwriteRequestHost(zhipuAiDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -55,28 +52,19 @@ func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.zhihupai.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.zhihupai.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
// TransformRequestHeaders rewrites the outgoing request headers to target the
// Zhipu AI chat-completion endpoint, setting path, host, and a Bearer
// Authorization header from the API token currently in use for this context.
func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
	util.OverwriteRequestPathHeader(headers, zhipuAiChatCompletionPath)
	util.OverwriteRequestHostHeader(headers, zhipuAiDomain)
	util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
	// Content-Length is removed — NOTE(review): presumably because the body may
	// be rewritten later; confirm the proxy recomputes it.
	headers.Del("Content-Length")
}
// GetApiName maps a request path onto the API it addresses. Only the chat
// completion path is recognized; any other path yields the empty name.
func (m *zhipuAiProvider) GetApiName(path string) ApiName {
	if !strings.Contains(path, zhipuAiChatCompletionPath) {
		return ""
	}
	return ApiNameChatCompletion
}

View File

@@ -1,6 +1,10 @@
package util
import "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
import (
"net/http"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)
const (
HeaderContentType = "Content-Type"
@@ -21,13 +25,6 @@ func CreateHeaders(kvs ...string) [][2]string {
return headers
}
func OverwriteRequestHost(host string) error {
if originHost, err := proxywasm.GetHttpRequestHeader(":authority"); err == nil {
_ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-HOST", originHost)
}
return proxywasm.ReplaceHttpRequestHeader(":authority", host)
}
func OverwriteRequestPath(path string) error {
if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil {
_ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-PATH", originPath)
@@ -43,3 +40,56 @@ func OverwriteRequestAuthorization(credential string) error {
}
return proxywasm.ReplaceHttpRequestHeader("Authorization", credential)
}
// OverwriteRequestHostHeader points the request at host by rewriting the
// :authority pseudo-header. When the current :authority can be read from the
// proxy, its value is stashed in X-ENVOY-ORIGINAL-HOST first.
func OverwriteRequestHostHeader(headers http.Header, host string) {
	originHost, err := proxywasm.GetHttpRequestHeader(":authority")
	if err == nil {
		headers.Set("X-ENVOY-ORIGINAL-HOST", originHost)
	}
	headers.Set(":authority", host)
}
// OverwriteRequestPathHeader rewrites the :path pseudo-header to path. When
// the current :path can be read from the proxy, its value is stashed in
// X-ENVOY-ORIGINAL-PATH first.
func OverwriteRequestPathHeader(headers http.Header, path string) {
	originPath, err := proxywasm.GetHttpRequestHeader(":path")
	if err == nil {
		headers.Set("X-ENVOY-ORIGINAL-PATH", originPath)
	}
	headers.Set(":path", path)
}
func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) {
if exist := headers.Get("X-HI-ORIGINAL-AUTH"); exist == "" {
if originAuth := headers.Get("Authorization"); originAuth != "" {
headers.Set("X-HI-ORIGINAL-AUTH", originAuth)
}
}
headers.Set("Authorization", credential)
}
func HeaderToSlice(header http.Header) [][2]string {
slice := make([][2]string, 0, len(header))
for key, values := range header {
for _, value := range values {
slice = append(slice, [2]string{key, value})
}
}
return slice
}
// SliceToHeader builds an http.Header from [key, value] pairs. Header.Add is
// used, so keys are canonicalized and repeated keys accumulate their values in
// slice order.
func SliceToHeader(slice [][2]string) http.Header {
	header := http.Header{}
	for _, kv := range slice {
		header.Add(kv[0], kv[1])
	}
	return header
}
// GetOriginalHttpHeaders reads the current request headers from the proxy-wasm
// host and returns them as an http.Header.
// The host-call error is deliberately discarded: on failure the nil slice
// converts to an empty header map — NOTE(review): callers cannot distinguish
// "no headers" from a failed read.
func GetOriginalHttpHeaders() http.Header {
	originalHeaders, _ := proxywasm.GetHttpRequestHeaders()
	return SliceToHeader(originalHeaders)
}
// ReplaceOriginalHttpHeaders writes headers back to the proxy-wasm host,
// replacing the request's entire header set. The host-call error is
// intentionally discarded (best-effort replacement).
func ReplaceOriginalHttpHeaders(headers http.Header) {
	modifiedHeaders := HeaderToSlice(headers)
	_ = proxywasm.ReplaceHttpRequestHeaders(modifiedHeaders)
}

View File

@@ -45,6 +45,19 @@ func (c RouteCluster) HostName() string {
return GetRequestHost()
}
// TargetCluster identifies an upstream by an explicit cluster name plus the
// host to present on requests sent to it.
type TargetCluster struct {
	Host    string // host value to use for the request
	Cluster string // upstream cluster name
}

// ClusterName returns the configured cluster name.
func (c TargetCluster) ClusterName() string {
	return c.Cluster
}

// HostName returns the configured host.
func (c TargetCluster) HostName() string {
	return c.Host
}
type K8sCluster struct {
ServiceName string
Namespace string