mirror of
https://github.com/alibaba/higress.git
synced 2026-03-18 09:17:26 +08:00
feat: implement apiToken failover mechanism (#1256)
This commit is contained in:
@@ -31,15 +31,16 @@ description: AI 代理插件配置参考
|
||||
|
||||
`provider`的配置字段说明如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `type` | string | 必填 | - | AI 服务提供商名称 |
|
||||
| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
|
||||
| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 |
|
||||
| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
|
||||
| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
|
||||
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
|
||||
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|------------------| --------------- | -------- | ------ |-----------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `type` | string | 必填 | - | AI 服务提供商名称 |
|
||||
| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
|
||||
| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 |
|
||||
| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
|
||||
| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
|
||||
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
|
||||
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
|
||||
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
|
||||
|
||||
`context`的配置字段说明如下:
|
||||
|
||||
@@ -75,6 +76,16 @@ custom-setting会遵循如下表格,根据`name`和协议来替换对应的字
|
||||
如果启用了raw模式,custom-setting会直接用输入的`name`和`value`去更改请求中的json内容,而不对参数名称做任何限制和修改。
|
||||
对于大多数协议,custom-setting都会在json内容的根路径修改或者填充参数。对于`qwen`协议,ai-proxy会在json的`parameters`子路径下做配置。对于`gemini`协议,则会在`generation_config`子路径下做配置。
|
||||
|
||||
`failover` 的配置字段说明如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|------------------|--------|------|-------|-----------------------------|
|
||||
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
|
||||
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
|
||||
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
|
||||
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
|
||||
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
|
||||
| healthCheckModel | string | 必填 | | 健康检测使用的模型 |
|
||||
|
||||
### 提供商特有配置
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// @Name ai-proxy
|
||||
@@ -75,13 +75,17 @@ func (c *PluginConfig) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PluginConfig) Complete() error {
|
||||
func (c *PluginConfig) Complete(log wrapper.Log) error {
|
||||
if c.activeProviderConfig == nil {
|
||||
c.activeProvider = nil
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
c.activeProvider, err = provider.CreateProvider(*c.activeProviderConfig)
|
||||
|
||||
providerConfig := c.GetProviderConfig()
|
||||
err = providerConfig.SetApiTokensFailover(log, c.activeProvider)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -44,9 +44,10 @@ func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log
|
||||
if err := pluginConfig.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := pluginConfig.Complete(); err != nil {
|
||||
if err := pluginConfig.Complete(log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -59,9 +60,10 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug
|
||||
if err := pluginConfig.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := pluginConfig.Complete(); err != nil {
|
||||
if err := pluginConfig.Complete(log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -80,7 +82,13 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
path, _ := url.Parse(rawPath)
|
||||
apiName := getOpenAiApiName(path.Path)
|
||||
providerConfig := pluginConfig.GetProviderConfig()
|
||||
if apiName == "" && !providerConfig.IsOriginal() {
|
||||
if providerConfig.IsOriginal() {
|
||||
if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
|
||||
apiName = handler.GetApiName(path.Path)
|
||||
}
|
||||
}
|
||||
|
||||
if apiName == "" {
|
||||
log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path)
|
||||
// _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path)
|
||||
log.Debugf("[onHttpRequestHeader] no send response")
|
||||
@@ -89,8 +97,11 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
ctx.SetContext(ctxKeyApiName, apiName)
|
||||
|
||||
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
|
||||
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
|
||||
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
|
||||
ctx.DisableReroute()
|
||||
// Set the apiToken for the current request.
|
||||
providerConfig.SetApiTokenInUse(ctx, log)
|
||||
|
||||
hasRequestBody := wrapper.HasRequestBody()
|
||||
action, err := handler.OnRequestHeaders(ctx, apiName, log)
|
||||
if err == nil {
|
||||
@@ -102,6 +113,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
}
|
||||
return action
|
||||
}
|
||||
|
||||
_ = util.SendResponse(500, "ai-proxy.proc_req_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request headers: %v", err))
|
||||
return types.ActionContinue
|
||||
}
|
||||
@@ -156,15 +168,24 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
|
||||
|
||||
log.Debugf("[onHttpResponseHeaders] provider=%s", activeProvider.GetProviderType())
|
||||
|
||||
providerConfig := pluginConfig.GetProviderConfig()
|
||||
apiTokenInUse := providerConfig.GetApiTokenInUse(ctx)
|
||||
|
||||
status, err := proxywasm.GetHttpResponseHeader(":status")
|
||||
if err != nil || status != "200" {
|
||||
if err != nil {
|
||||
log.Errorf("unable to load :status header from response: %v", err)
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
providerConfig.OnRequestFailed(ctx, apiTokenInUse, log)
|
||||
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
// Reset ctxApiTokenRequestFailureCount if the request is successful,
|
||||
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
|
||||
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)
|
||||
|
||||
if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok {
|
||||
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
|
||||
action, err := handler.OnResponseHeaders(ctx, apiName, log)
|
||||
@@ -233,16 +254,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func getOpenAiApiName(path string) provider.ApiName {
|
||||
if strings.HasSuffix(path, "/v1/chat/completions") {
|
||||
return provider.ApiNameChatCompletion
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/embeddings") {
|
||||
return provider.ApiNameEmbeddings
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) {
|
||||
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
|
||||
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
|
||||
@@ -252,3 +263,13 @@ func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) {
|
||||
(*ctx).BufferResponseBody()
|
||||
}
|
||||
}
|
||||
|
||||
func getOpenAiApiName(path string) provider.ApiName {
|
||||
if strings.HasSuffix(path, "/v1/chat/completions") {
|
||||
return provider.ApiNameChatCompletion
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/embeddings") {
|
||||
return provider.ApiNameEmbeddings
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
// ai360Provider is the provider for 360 OpenAI service.
|
||||
@@ -46,10 +44,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestHost(ai360Domain)
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken())
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
}
|
||||
@@ -58,47 +53,12 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if apiName == ApiNameChatCompletion {
|
||||
return m.onChatCompletionRequestBody(ctx, body, log)
|
||||
}
|
||||
if apiName == ApiNameEmbeddings {
|
||||
return m.onEmbeddingsRequestBody(ctx, body, log)
|
||||
}
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *ai360Provider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
if request.Model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
}
|
||||
// 映射模型
|
||||
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
|
||||
request.Model = mappedModel
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
}
|
||||
|
||||
func (m *ai360Provider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
request := &embeddingsRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
if request.Model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in embeddings request")
|
||||
}
|
||||
// 映射模型
|
||||
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
|
||||
request.Model = mappedModel
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestHostHeader(headers, ai360Domain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Authorization "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
@@ -3,16 +3,15 @@ package provider
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
// azureProvider is the provider for Azure OpenAI service.
|
||||
|
||||
type azureProviderInitializer struct {
|
||||
}
|
||||
|
||||
@@ -55,47 +54,23 @@ func (m *azureProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
_ = util.OverwriteRequestPath(m.serviceUrl.RequestURI())
|
||||
_ = util.OverwriteRequestHost(m.serviceUrl.Host)
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("api-key", m.config.apiTokens[0])
|
||||
if apiName == ApiNameChatCompletion {
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
} else {
|
||||
ctx.DontReadRequestBody()
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
// We don't need to process the request body for other APIs.
|
||||
return types.ActionContinue, nil
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
if m.contextCache == nil {
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.azure.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.azure.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI())
|
||||
util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "api-key "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
@@ -2,11 +2,10 @@ package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
@@ -47,10 +46,7 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestPath(baichuanChatCompletionPath)
|
||||
_ = util.OverwriteRequestHost(baichuanDomain)
|
||||
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
@@ -58,28 +54,12 @@ func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if m.contextCache == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.baichuan.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.baichuan.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, baichuanChatCompletionPath)
|
||||
util.OverwriteRequestHostHeader(headers, baichuanDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -16,7 +17,8 @@ import (
|
||||
// baiduProvider is the provider for baidu ernie bot service.
|
||||
|
||||
const (
|
||||
baiduDomain = "aip.baidubce.com"
|
||||
baiduDomain = "aip.baidubce.com"
|
||||
baiduChatCompletionPath = "/chat"
|
||||
)
|
||||
|
||||
var baiduModelToPathSuffixMap = map[string]string{
|
||||
@@ -60,98 +62,35 @@ func (b *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestHost(baiduDomain)
|
||||
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
|
||||
b.config.handleRequestHeaders(b, ctx, apiName, log)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
}
|
||||
|
||||
func (b *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestHostHeader(headers, baiduDomain)
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
// 使用文心一言接口协议
|
||||
if b.config.protocol == protocolOriginal {
|
||||
request := &baiduTextGenRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
if request.Model == "" {
|
||||
return types.ActionContinue, errors.New("request model is empty")
|
||||
}
|
||||
// 根据模型重写requestPath
|
||||
path := b.getRequestPath(request.Model)
|
||||
_ = util.OverwriteRequestPath(path)
|
||||
return b.config.handleRequestBody(b, b.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
if b.config.context == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
err := b.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.baidu.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
b.setSystemContent(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.baidu.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
func (b *baiduProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
err := b.config.parseRequestAndMapModel(ctx, request, body, log)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path := b.getRequestPath(ctx, request.Model)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
// 映射模型重写requestPath
|
||||
model := request.Model
|
||||
if model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
mappedModel := getMappedModel(model, b.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
request.Model = mappedModel
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
|
||||
path := b.getRequestPath(mappedModel)
|
||||
_ = util.OverwriteRequestPath(path)
|
||||
|
||||
if b.config.context == nil {
|
||||
baiduRequest := b.baiduTextGenRequest(request)
|
||||
return types.ActionContinue, replaceJsonRequestBody(baiduRequest, log)
|
||||
}
|
||||
|
||||
err := b.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.baidu.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
baiduRequest := b.baiduTextGenRequest(request)
|
||||
if err := replaceJsonRequestBody(baiduRequest, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.baidu.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace Request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
baiduRequest := b.baiduTextGenRequest(request)
|
||||
return json.Marshal(baiduRequest)
|
||||
}
|
||||
|
||||
func (b *baiduProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
@@ -226,13 +165,13 @@ type baiduTextGenRequest struct {
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
func (b *baiduProvider) getRequestPath(baiduModel string) string {
|
||||
func (b *baiduProvider) getRequestPath(ctx wrapper.HttpContext, baiduModel string) string {
|
||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
|
||||
suffix, ok := baiduModelToPathSuffixMap[baiduModel]
|
||||
if !ok {
|
||||
suffix = baiduModel
|
||||
}
|
||||
return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetRandomToken())
|
||||
return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetApiTokenInUse(ctx))
|
||||
}
|
||||
|
||||
func (b *baiduProvider) setSystemContent(request *baiduTextGenRequest, content string) {
|
||||
@@ -339,3 +278,10 @@ func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, resp
|
||||
func (b *baiduProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
|
||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||
}
|
||||
|
||||
func (b *baiduProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, baiduChatCompletionPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -105,102 +106,39 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
c.config.handleRequestHeaders(c, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
_ = util.OverwriteRequestPath(claudeChatCompletionPath)
|
||||
_ = util.OverwriteRequestHost(claudeDomain)
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("x-api-key", c.config.GetRandomToken())
|
||||
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath)
|
||||
util.OverwriteRequestHostHeader(headers, claudeDomain)
|
||||
|
||||
headers.Add("x-api-key", c.config.GetApiTokenInUse(ctx))
|
||||
|
||||
if c.config.claudeVersion == "" {
|
||||
c.config.claudeVersion = defaultVersion
|
||||
}
|
||||
_ = proxywasm.AddHttpRequestHeader("anthropic-version", c.config.claudeVersion)
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
|
||||
return types.ActionContinue, nil
|
||||
headers.Add("anthropic-version", c.config.claudeVersion)
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
// use original protocol
|
||||
if c.config.protocol == protocolOriginal {
|
||||
if c.config.context == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
request := &claudeTextGenRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
|
||||
err := c.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.claude.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.claude.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
// use openai protocol
|
||||
func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
model := request.Model
|
||||
if model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
mappedModel := getMappedModel(model, c.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
request.Model = mappedModel
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
|
||||
|
||||
streaming := request.Stream
|
||||
if streaming {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
|
||||
}
|
||||
|
||||
if c.config.context == nil {
|
||||
claudeRequest := c.buildClaudeTextGenRequest(request)
|
||||
return types.ActionContinue, replaceJsonRequestBody(claudeRequest, log)
|
||||
}
|
||||
|
||||
err := c.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.claude.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
claudeRequest := c.buildClaudeTextGenRequest(request)
|
||||
if err := replaceJsonRequestBody(claudeRequest, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.claude.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
claudeRequest := c.buildClaudeTextGenRequest(request)
|
||||
return json.Marshal(claudeRequest)
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
@@ -369,3 +307,25 @@ func createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextG
|
||||
func (c *claudeProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
|
||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||
}
|
||||
|
||||
func (c *claudeProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) {
|
||||
request := &claudeTextGenRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
|
||||
if request.System == "" {
|
||||
request.System = content
|
||||
} else {
|
||||
request.System = content + "\n" + request.System
|
||||
}
|
||||
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
func (c *claudeProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, claudeChatCompletionPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -2,12 +2,11 @@ package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
@@ -47,13 +46,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestPath(strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
|
||||
_ = util.OverwriteRequestHost(cloudflareDomain)
|
||||
_ = util.OverwriteRequestAuthorization("Bearer " + c.config.GetRandomToken())
|
||||
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
|
||||
c.config.handleRequestHeaders(c, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
@@ -61,49 +54,13 @@ func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiN
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
model := request.Model
|
||||
if model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
mappedModel := getMappedModel(model, c.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
request.Model = mappedModel
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
|
||||
|
||||
streaming := request.Stream
|
||||
if streaming {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
|
||||
}
|
||||
|
||||
if c.contextCache == nil {
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.cloudflare.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
err := c.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.cloudflare.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.cloudflare.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
|
||||
util.OverwriteRequestHostHeader(headers, cloudflareDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
@@ -3,17 +3,16 @@ package provider
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
cohereDomain = "api.cohere.com"
|
||||
chatCompletionPath = "/v1/chat"
|
||||
cohereDomain = "api.cohere.com"
|
||||
cohereChatCompletionPath = "/v1/chat"
|
||||
)
|
||||
|
||||
type cohereProviderInitializer struct{}
|
||||
@@ -27,12 +26,14 @@ func (m *cohereProviderInitializer) ValidateConfig(config ProviderConfig) error
|
||||
|
||||
func (m *cohereProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
return &cohereProvider{
|
||||
config: config,
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type cohereProvider struct {
|
||||
config ProviderConfig
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
type cohereTextGenRequest struct {
|
||||
@@ -57,10 +58,7 @@ func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestHost(cohereDomain)
|
||||
_ = util.OverwriteRequestPath(chatCompletionPath)
|
||||
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
@@ -68,30 +66,7 @@ func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if m.config.protocol == protocolOriginal {
|
||||
request := &cohereTextGenRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
return m.handleRequestBody(log, request)
|
||||
}
|
||||
origin := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, origin); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
request := m.buildCohereRequest(origin)
|
||||
return m.handleRequestBody(log, request)
|
||||
}
|
||||
|
||||
func (m *cohereProvider) handleRequestBody(log wrapper.Log, request interface{}) (types.Action, error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
err := replaceJsonRequestBody(request, log)
|
||||
if err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.cohere.proxy_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohereTextGenRequest {
|
||||
@@ -112,3 +87,27 @@ func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohe
|
||||
PresencePenalty: origin.PresencePenalty,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, cohereChatCompletionPath)
|
||||
util.OverwriteRequestHostHeader(headers, cohereDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cohereRequest := m.buildCohereRequest(request)
|
||||
return json.Marshal(cohereRequest)
|
||||
}
|
||||
|
||||
func (m *cohereProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, cohereChatCompletionPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
@@ -57,6 +60,10 @@ type contextCache struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type ContextInserter interface {
|
||||
insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error)
|
||||
}
|
||||
|
||||
func (c *contextCache) GetContent(callback func(string, error), log wrapper.Log) error {
|
||||
if callback == nil {
|
||||
return errors.New("callback is nil")
|
||||
@@ -98,3 +105,79 @@ func createContextCache(providerConfig *ProviderConfig) *contextCache {
|
||||
timeout: providerConfig.timeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *contextCache) GetContextFromFile(ctx wrapper.HttpContext, provider Provider, body []byte, log wrapper.Log) error {
|
||||
// get context will overwrite the original request host and path
|
||||
// save the original request host and path in case they are needed for apiToken health check
|
||||
ctx.SetContext(ctxRequestHost, wrapper.GetRequestHost())
|
||||
ctx.SetContext(ctxRequestPath, wrapper.GetRequestPath())
|
||||
|
||||
if c.loaded {
|
||||
log.Debugf("context file loaded from cache")
|
||||
insertContext(provider, c.content, nil, body, log)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("loading context file from %s", c.fileUrl.String())
|
||||
return c.client.Get(c.fileUrl.Path, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode != http.StatusOK {
|
||||
insertContext(provider, "", fmt.Errorf("failed to load context file, status: %d", statusCode), nil, log)
|
||||
return
|
||||
}
|
||||
c.content = string(responseBody)
|
||||
c.loaded = true
|
||||
log.Debugf("content: %s", c.content)
|
||||
insertContext(provider, c.content, nil, body, log)
|
||||
}, c.timeout)
|
||||
}
|
||||
|
||||
func insertContext(provider Provider, content string, err error, body []byte, log wrapper.Log) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
|
||||
typ := provider.GetProviderType()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.load_ctx_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
|
||||
if inserter, ok := provider.(ContextInserter); ok {
|
||||
body, err = inserter.insertHttpContextMessage(body, content, false)
|
||||
} else {
|
||||
body, err = defaultInsertHttpContextMessage(body, content)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
_ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to insert context message: %v", err))
|
||||
}
|
||||
if err := replaceHttpJsonRequestBody(body, log); err != nil {
|
||||
_ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
func defaultInsertHttpContextMessage(body []byte, content string) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
|
||||
fileMessage := chatMessage{
|
||||
Role: roleSystem,
|
||||
Content: content,
|
||||
}
|
||||
var firstNonSystemMessageIndex int
|
||||
for i, message := range request.Messages {
|
||||
if message.Role != roleSystem {
|
||||
firstNonSystemMessageIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if firstNonSystemMessageIndex == 0 {
|
||||
request.Messages = append([]chatMessage{fileMessage}, request.Messages...)
|
||||
} else {
|
||||
request.Messages = append(request.Messages[:firstNonSystemMessageIndex], append([]chatMessage{fileMessage}, request.Messages[firstNonSystemMessageIndex:]...)...)
|
||||
}
|
||||
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
@@ -78,49 +80,38 @@ func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestPath(deeplChatCompletionPath)
|
||||
_ = util.OverwriteRequestAuthorization("DeepL-Auth-Key " + d.config.GetRandomToken())
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
d.config.handleRequestHeaders(d, ctx, apiName, log)
|
||||
return types.HeaderStopIteration, nil
|
||||
}
|
||||
|
||||
func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
headers.Del("Accept-Encoding")
|
||||
}
|
||||
|
||||
func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if d.config.protocol == protocolOriginal {
|
||||
request := &deeplRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
if err := d.overwriteRequestHost(request.Model); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
} else {
|
||||
originRequest := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, originRequest); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
if err := d.overwriteRequestHost(originRequest.Model); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, originRequest.Model)
|
||||
deeplRequest := &deeplRequest{
|
||||
Text: make([]string, 0),
|
||||
TargetLang: d.config.targetLang,
|
||||
}
|
||||
for _, msg := range originRequest.Messages {
|
||||
if msg.Role == roleSystem {
|
||||
deeplRequest.Context = msg.StringContent()
|
||||
} else {
|
||||
deeplRequest.Text = append(deeplRequest.Text, msg.StringContent())
|
||||
}
|
||||
}
|
||||
return types.ActionContinue, replaceJsonRequestBody(deeplRequest, log)
|
||||
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
|
||||
|
||||
err := d.overwriteRequestHost(headers, request.Model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
baiduRequest := d.deeplTextGenRequest(request)
|
||||
return json.Marshal(baiduRequest)
|
||||
}
|
||||
|
||||
func (d *deeplProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
@@ -164,13 +155,35 @@ func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplRespo
|
||||
}
|
||||
}
|
||||
|
||||
func (d *deeplProvider) overwriteRequestHost(model string) error {
|
||||
func (d *deeplProvider) overwriteRequestHost(headers http.Header, model string) error {
|
||||
if model == "Pro" {
|
||||
_ = util.OverwriteRequestHost(deeplHostPro)
|
||||
util.OverwriteRequestHostHeader(headers, deeplHostPro)
|
||||
} else if model == "Free" {
|
||||
_ = util.OverwriteRequestHost(deeplHostFree)
|
||||
util.OverwriteRequestHostHeader(headers, deeplHostFree)
|
||||
} else {
|
||||
return errors.New(`deepl model should be "Free" or "Pro"`)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *deeplProvider) deeplTextGenRequest(request *chatCompletionRequest) *deeplRequest {
|
||||
deeplRequest := &deeplRequest{
|
||||
Text: make([]string, 0),
|
||||
TargetLang: d.config.targetLang,
|
||||
}
|
||||
for _, msg := range request.Messages {
|
||||
if msg.Role == roleSystem {
|
||||
deeplRequest.Context = msg.StringContent()
|
||||
} else {
|
||||
deeplRequest.Text = append(deeplRequest.Text, msg.StringContent())
|
||||
}
|
||||
}
|
||||
return deeplRequest
|
||||
}
|
||||
|
||||
func (d *deeplProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, deeplChatCompletionPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -2,12 +2,10 @@ package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// deepseekProvider is the provider for deepseek Ai service.
|
||||
@@ -47,10 +45,7 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestPath(deepseekChatCompletionPath)
|
||||
_ = util.OverwriteRequestHost(deepseekDomain)
|
||||
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
@@ -58,28 +53,12 @@ func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if m.contextCache == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.deepseek.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.deepseek.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *deepseekProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, deepseekChatCompletionPath)
|
||||
util.OverwriteRequestHostHeader(headers, deepseekDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
@@ -2,12 +2,11 @@ package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -41,17 +40,10 @@ func (m *doubaoProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
_ = util.OverwriteRequestHost(doubaoDomain)
|
||||
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
if m.config.protocol == protocolOriginal {
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestPath(doubaoChatCompletionPath)
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
@@ -59,44 +51,19 @@ func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
model := request.Model
|
||||
if model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
}
|
||||
mappedModel := getMappedModel(model, m.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
request.Model = mappedModel
|
||||
if m.contextCache != nil {
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.doubao.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.doubao.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
} else {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
} else {
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.doubao.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, doubaoChatCompletionPath)
|
||||
util.OverwriteRequestHostHeader(headers, doubaoDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (m *doubaoProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, doubaoChatCompletionPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
594
plugins/wasm-go/extensions/ai-proxy/provider/failover.go
Normal file
594
plugins/wasm-go/extensions/ai-proxy/provider/failover.go
Normal file
@@ -0,0 +1,594 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/google/uuid"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type failover struct {
|
||||
// @Title zh-CN 是否启用 apiToken 的 failover 机制
|
||||
enabled bool `required:"true" yaml:"enabled" json:"enabled"`
|
||||
// @Title zh-CN 触发 failover 连续请求失败的阈值
|
||||
failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"`
|
||||
// @Title zh-CN 健康检测的成功阈值
|
||||
successThreshold int64 `required:"false" yaml:"successThreshold" json:"successThreshold"`
|
||||
// @Title zh-CN 健康检测的间隔时间,单位毫秒
|
||||
healthCheckInterval int64 `required:"false" yaml:"healthCheckInterval" json:"healthCheckInterval"`
|
||||
// @Title zh-CN 健康检测的超时时间,单位毫秒
|
||||
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
|
||||
// @Title zh-CN 健康检测使用的模型
|
||||
healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"`
|
||||
// @Title zh-CN 本次请求使用的 apiToken
|
||||
ctxApiTokenInUse string
|
||||
// @Title zh-CN 记录 apiToken 请求失败的次数,key 为 apiToken,value 为失败次数
|
||||
ctxApiTokenRequestFailureCount string
|
||||
// @Title zh-CN 记录 apiToken 健康检测成功的次数,key 为 apiToken,value 为成功次数
|
||||
ctxApiTokenRequestSuccessCount string
|
||||
// @Title zh-CN 记录所有可用的 apiToken 列表
|
||||
ctxApiTokens string
|
||||
// @Title zh-CN 记录所有不可用的 apiToken 列表
|
||||
ctxUnavailableApiTokens string
|
||||
// @Title zh-CN 记录请求的 cluster, host 和 path,用于在健康检测时构建请求
|
||||
ctxHealthCheckEndpoint string
|
||||
// @Title zh-CN 健康检测选主,只有选到主的 Wasm VM 才执行健康检测
|
||||
ctxVmLease string
|
||||
}
|
||||
|
||||
type Lease struct {
|
||||
VMID string `json:"vmID"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
type HealthCheckEndpoint struct {
|
||||
Host string `json:"host"`
|
||||
Path string `json:"path"`
|
||||
Cluster string `json:"cluster"`
|
||||
}
|
||||
|
||||
const (
|
||||
casMaxRetries = 10
|
||||
addApiTokenOperation = "addApiToken"
|
||||
removeApiTokenOperation = "removeApiToken"
|
||||
addApiTokenRequestCountOperation = "addApiTokenRequestCount"
|
||||
resetApiTokenRequestCountOperation = "resetApiTokenRequestCount"
|
||||
ctxRequestHost = "requestHost"
|
||||
ctxRequestPath = "requestPath"
|
||||
)
|
||||
|
||||
var (
|
||||
healthCheckClient wrapper.HttpClient
|
||||
)
|
||||
|
||||
func (f *failover) FromJson(json gjson.Result) {
|
||||
f.enabled = json.Get("enabled").Bool()
|
||||
f.failureThreshold = json.Get("failureThreshold").Int()
|
||||
if f.failureThreshold == 0 {
|
||||
f.failureThreshold = 3
|
||||
}
|
||||
f.successThreshold = json.Get("successThreshold").Int()
|
||||
if f.successThreshold == 0 {
|
||||
f.successThreshold = 1
|
||||
}
|
||||
f.healthCheckInterval = json.Get("healthCheckInterval").Int()
|
||||
if f.healthCheckInterval == 0 {
|
||||
f.healthCheckInterval = 5000
|
||||
}
|
||||
f.healthCheckTimeout = json.Get("healthCheckTimeout").Int()
|
||||
if f.healthCheckTimeout == 0 {
|
||||
f.healthCheckTimeout = 5000
|
||||
}
|
||||
f.healthCheckModel = json.Get("healthCheckModel").String()
|
||||
}
|
||||
|
||||
func (f *failover) Validate() error {
|
||||
if f.healthCheckModel == "" {
|
||||
return errors.New("missing healthCheckModel in failover config")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) initVariable() {
|
||||
// Set provider name as prefix to differentiate shared data
|
||||
provider := c.GetType()
|
||||
c.failover.ctxApiTokenInUse = provider + "-apiTokenInUse"
|
||||
c.failover.ctxApiTokenRequestFailureCount = provider + "-apiTokenRequestFailureCount"
|
||||
c.failover.ctxApiTokenRequestSuccessCount = provider + "-apiTokenRequestSuccessCount"
|
||||
c.failover.ctxApiTokens = provider + "-apiTokens"
|
||||
c.failover.ctxUnavailableApiTokens = provider + "-unavailableApiTokens"
|
||||
c.failover.ctxHealthCheckEndpoint = provider + "-requestHostAndPath"
|
||||
c.failover.ctxVmLease = provider + "-vmLease"
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *any, log wrapper.Log) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetApiTokensFailover initializes the apiToken failover state for this
// provider. It always resets failover-related shared data (dropping stale
// state when the plugin configuration is reloaded); when failover is enabled
// it also seeds the shared available-token list from the static config and
// registers a periodic tick that health-checks currently-unavailable tokens
// so they can be restored once they recover. Returns an error only if the
// initial token list cannot be written to shared data.
func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Provider) error {
	c.initVariable()
	// Reset shared data in case plugin configuration is updated
	log.Debugf("ai-proxy plugin configuration is updated, reset shared data")
	c.resetSharedData()

	if c.isFailoverEnabled() {
		log.Debugf("ai-proxy plugin failover is enabled")

		// Each Wasm VM gets a unique ID; the lease below ensures only one VM
		// performs health checks at a time.
		vmID := generateVMID()
		err := c.initApiTokens()

		if err != nil {
			return fmt.Errorf("failed to init apiTokens: %v", err)
		}

		wrapper.RegisteTickFunc(c.failover.healthCheckInterval, func() {
			// Only the Wasm VM that successfully acquires the lease will perform health check
			if c.isFailoverEnabled() && c.tryAcquireOrRenewLease(vmID, log) {
				log.Debugf("Successfully acquired or renewed lease for %v: %v", vmID, c.GetType())
				unavailableTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens)
				if err != nil {
					log.Errorf("Failed to get unavailable tokens: %v", err)
					return
				}
				if len(unavailableTokens) > 0 {
					for _, apiToken := range unavailableTokens {
						log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", "))
						// Probe the upstream cluster recorded when the token was
						// marked unavailable (see setHealthCheckEndpoint).
						healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody(log)
						healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{
							Host:    healthCheckEndpoint.Host,
							Cluster: healthCheckEndpoint.Cluster,
						})

						// Synthetic HTTP context carrying the probed token so the
						// provider's request transformers attach it to the probe.
						ctx := createHttpContext()
						ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)

						modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body, log)
						if err != nil {
							log.Errorf("Failed to transform request headers and body: %v", err)
						}

						// The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion
						err = healthCheckClient.Post(healthCheckEndpoint.Path, modifiedHeaders, modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
							if statusCode == 200 {
								// Probe succeeded; count it toward restoring the token.
								c.handleAvailableApiToken(apiToken, log)
							}
						}, uint32(c.failover.healthCheckTimeout))
						if err != nil {
							log.Errorf("Failed to perform health check request: %v", err)
						}
					}
				}
			}
		})
	}
	return nil
}
|
||||
|
||||
// transformRequestHeadersAndBody runs the active provider's ChatCompletion
// request transformations over the supplied headers and body, so the
// health-check probe looks like a real request to that provider. It returns
// the transformed header slice and body.
//
// NOTE(review): when the provider implements TransformRequestBodyHeadersHandler,
// the current proxied request's headers are snapshotted and restored around the
// call — presumably because that handler mutates the live request headers as a
// side effect; confirm against the handler implementations.
func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, headers [][2]string, body []byte, log wrapper.Log) ([][2]string, []byte, error) {
	originalHeaders := util.SliceToHeader(headers)
	if handler, ok := activeProvider.(TransformRequestHeadersHandler); ok {
		// Header transformation mutates originalHeaders in place.
		handler.TransformRequestHeaders(ctx, ApiNameChatCompletion, originalHeaders, log)
	}

	var err error
	if handler, ok := activeProvider.(TransformRequestBodyHandler); ok {
		body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log)
	} else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok {
		// Save and later restore the live request headers around the call.
		headers := util.GetOriginalHttpHeaders()
		body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, originalHeaders, log)
		util.ReplaceOriginalHttpHeaders(headers)
	} else {
		// Provider has no custom body transformer; apply the default mapping.
		body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log)
	}
	if err != nil {
		return nil, nil, fmt.Errorf("failed to transform request body: %v", err)
	}

	modifiedHeaders := util.HeaderToSlice(originalHeaders)
	return modifiedHeaders, body, nil
}
|
||||
|
||||
func createHttpContext() *wrapper.CommonHttpCtx[any] {
|
||||
setParseConfig := wrapper.ParseConfigBy[any](parseConfig)
|
||||
vmCtx := wrapper.NewCommonVmCtx[any]("health-check", setParseConfig)
|
||||
pluginCtx := vmCtx.NewPluginContext(rand.Uint32())
|
||||
ctx := pluginCtx.NewHttpContext(rand.Uint32()).(*wrapper.CommonHttpCtx[any])
|
||||
return ctx
|
||||
}
|
||||
|
||||
// generateRequestHeadersAndBody loads the health-check endpoint (host, path
// and cluster recorded when a token was marked unavailable) from shared data
// and builds a minimal OpenAI-style chat-completion request used as the probe.
//
// NOTE(review): read/unmarshal failures are only logged; the caller then
// receives a zero-value endpoint — confirm this is acceptable upstream.
func (c *ProviderConfig) generateRequestHeadersAndBody(log wrapper.Log) (HealthCheckEndpoint, [][2]string, []byte) {
	data, _, err := proxywasm.GetSharedData(c.failover.ctxHealthCheckEndpoint)
	if err != nil {
		log.Errorf("Failed to get request host and path: %v", err)
	}
	var healthCheckEndpoint HealthCheckEndpoint
	err = json.Unmarshal(data, &healthCheckEndpoint)
	if err != nil {
		log.Errorf("Failed to unmarshal request host and path: %v", err)
	}

	headers := [][2]string{
		{"content-type", "application/json"},
	}
	// Fixed probe payload; the model name comes from the failover config.
	body := []byte(fmt.Sprintf(`{
		"model": "%s",
		"messages": [
			{
				"role": "user",
				"content": "who are you?"
			}
		]
	}`, c.failover.healthCheckModel))
	return healthCheckEndpoint, headers, body
}
|
||||
|
||||
func (c *ProviderConfig) tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool {
|
||||
now := time.Now().Unix()
|
||||
|
||||
data, cas, err := proxywasm.GetSharedData(c.failover.ctxVmLease)
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrorStatusNotFound) {
|
||||
return c.setLease(vmID, now, cas, log)
|
||||
} else {
|
||||
log.Errorf("Failed to get lease: %v", err)
|
||||
return false
|
||||
}
|
||||
}
|
||||
if data == nil {
|
||||
return c.setLease(vmID, now, cas, log)
|
||||
}
|
||||
|
||||
var lease Lease
|
||||
err = json.Unmarshal(data, &lease)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to unmarshal lease data: %v", err)
|
||||
return false
|
||||
}
|
||||
// If vmID is itself, try to renew the lease directly
|
||||
// If the lease is expired (60s), try to acquire the lease
|
||||
if lease.VMID == vmID || now-lease.Timestamp > 60 {
|
||||
lease.VMID = vmID
|
||||
lease.Timestamp = now
|
||||
return c.setLease(vmID, now, cas, log)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) setLease(vmID string, timestamp int64, cas uint32, log wrapper.Log) bool {
|
||||
lease := Lease{
|
||||
VMID: vmID,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
leaseByte, err := json.Marshal(lease)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal lease data: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := proxywasm.SetSharedData(c.failover.ctxVmLease, leaseByte, cas); err != nil {
|
||||
log.Errorf("Failed to set or renew lease: %v", err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// generateVMID returns a random UUID string used to identify this Wasm VM
// when competing for the health-check lease.
func generateVMID() string {
	return uuid.New().String()
}
|
||||
|
||||
// When number of request successes exceeds the threshold during health check,
|
||||
// add the apiToken back to the available list and remove it from the unavailable list
|
||||
// When number of request successes exceeds the threshold during health check,
// add the apiToken back to the available list and remove it from the unavailable list.
// Below the threshold, only the shared success counter is incremented.
func (c *ProviderConfig) handleAvailableApiToken(apiToken string, log wrapper.Log) {
	successApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount)
	if err != nil {
		log.Errorf("Failed to get successApiTokenRequestCount: %v", err)
		return
	}

	// Count this probe success on top of the persisted count.
	successCount := successApiTokenRequestCount[apiToken] + 1
	if successCount >= c.failover.successThreshold {
		log.Infof("apiToken %s is available now, add it back to the apiTokens list", apiToken)
		removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log)
		addApiToken(c.failover.ctxApiTokens, apiToken, log)
		// Clear the counter so a future outage starts from zero.
		resetApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log)
	} else {
		log.Debugf("apiToken %s is still unavailable, the number of health check passed: %d, continue to health check...", apiToken, successCount)
		addApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log)
	}
}
|
||||
|
||||
// When number of request failures exceeds the threshold,
|
||||
// remove the apiToken from the available list and add it to the unavailable list
|
||||
// When number of request failures exceeds the threshold,
// remove the apiToken from the available list and add it to the unavailable list.
// Below the threshold, only the shared failure counter is incremented.
func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiToken string, log wrapper.Log) {
	failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount)
	if err != nil {
		log.Errorf("Failed to get failureApiTokenRequestCount: %v", err)
		return
	}

	availableTokens, _, err := getApiTokens(c.failover.ctxApiTokens)
	if err != nil {
		log.Errorf("Failed to get available apiToken: %v", err)
		return
	}
	// unavailable apiToken has been removed from the available list
	// (possibly by another VM racing with us), so nothing more to do.
	if !containsElement(availableTokens, apiToken) {
		return
	}

	// Count this failure on top of the persisted count.
	failureCount := failureApiTokenRequestCount[apiToken] + 1
	if failureCount >= c.failover.failureThreshold {
		log.Infof("apiToken %s is unavailable now, remove it from apiTokens list", apiToken)
		removeApiToken(c.failover.ctxApiTokens, apiToken, log)
		addApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log)
		resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log)
		// Set the request host and path to shared data in case they are needed in apiToken health check
		c.setHealthCheckEndpoint(ctx, log)
	} else {
		log.Debugf("apiToken %s is still available as it has not reached the failure threshold, the number of failed request: %d", apiToken, failureCount)
		addApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log)
	}
}
|
||||
|
||||
// addApiToken appends apiToken to the shared token list stored under key
// (no-op if already present).
func addApiToken(key, apiToken string, log wrapper.Log) {
	modifyApiToken(key, apiToken, addApiTokenOperation, log)
}
|
||||
|
||||
// removeApiToken removes apiToken from the shared token list stored under key
// (no-op if absent).
func removeApiToken(key, apiToken string, log wrapper.Log) {
	modifyApiToken(key, apiToken, removeApiTokenOperation, log)
}
|
||||
|
||||
func modifyApiToken(key, apiToken, op string, log wrapper.Log) {
|
||||
for attempt := 1; attempt <= casMaxRetries; attempt++ {
|
||||
apiTokens, cas, err := getApiTokens(key)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get %s: %v", key, err)
|
||||
continue
|
||||
}
|
||||
|
||||
exists := containsElement(apiTokens, apiToken)
|
||||
if op == addApiTokenOperation && exists {
|
||||
log.Debugf("%s already exists in %s", apiToken, key)
|
||||
return
|
||||
} else if op == removeApiTokenOperation && !exists {
|
||||
log.Debugf("%s does not exist in %s", apiToken, key)
|
||||
return
|
||||
}
|
||||
|
||||
if op == addApiTokenOperation {
|
||||
apiTokens = append(apiTokens, apiToken)
|
||||
} else {
|
||||
apiTokens = removeElement(apiTokens, apiToken)
|
||||
}
|
||||
|
||||
if err := setApiTokens(key, apiTokens, cas); err == nil {
|
||||
log.Debugf("Successfully updated %s in %s", apiToken, key)
|
||||
return
|
||||
} else if !errors.Is(err, types.ErrorStatusCasMismatch) {
|
||||
log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Errorf("CAS mismatch when setting %s, retrying...", key)
|
||||
}
|
||||
}
|
||||
|
||||
func getApiTokens(key string) ([]string, uint32, error) {
|
||||
data, cas, err := proxywasm.GetSharedData(key)
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrorStatusNotFound) {
|
||||
return []string{}, cas, nil
|
||||
}
|
||||
return nil, 0, err
|
||||
}
|
||||
if data == nil {
|
||||
return []string{}, cas, nil
|
||||
}
|
||||
|
||||
var apiTokens []string
|
||||
if err = json.Unmarshal(data, &apiTokens); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to unmarshal tokens: %v", err)
|
||||
}
|
||||
|
||||
return apiTokens, cas, nil
|
||||
}
|
||||
|
||||
func setApiTokens(key string, apiTokens []string, cas uint32) error {
|
||||
data, err := json.Marshal(apiTokens)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal tokens: %v", err)
|
||||
}
|
||||
return proxywasm.SetSharedData(key, data, cas)
|
||||
}
|
||||
|
||||
// removeElement returns slice with every occurrence of s removed, preserving
// the order of the remaining elements and reusing the backing array.
func removeElement(slice []string, s string) []string {
	kept := slice[:0]
	for _, item := range slice {
		if item != s {
			kept = append(kept, item)
		}
	}
	return kept
}
|
||||
|
||||
// containsElement reports whether s occurs in slice.
func containsElement(slice []string, s string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == s {
			return true
		}
	}
	return false
}
|
||||
|
||||
func getApiTokenRequestCount(key string) (map[string]int64, uint32, error) {
|
||||
data, cas, err := proxywasm.GetSharedData(key)
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrorStatusNotFound) {
|
||||
return make(map[string]int64), cas, nil
|
||||
}
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
return make(map[string]int64), cas, nil
|
||||
}
|
||||
|
||||
var apiTokens map[string]int64
|
||||
err = json.Unmarshal(data, &apiTokens)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return apiTokens, cas, nil
|
||||
}
|
||||
|
||||
// addApiTokenRequestCount increments the shared per-token counter stored
// under key for apiToken.
func addApiTokenRequestCount(key, apiToken string, log wrapper.Log) {
	modifyApiTokenRequestCount(key, apiToken, addApiTokenRequestCountOperation, log)
}
|
||||
|
||||
// resetApiTokenRequestCount deletes apiToken's entry from the shared counter
// map stored under key.
func resetApiTokenRequestCount(key, apiToken string, log wrapper.Log) {
	modifyApiTokenRequestCount(key, apiToken, resetApiTokenRequestCountOperation, log)
}
|
||||
|
||||
func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string, log wrapper.Log) {
|
||||
if c.isFailoverEnabled() {
|
||||
failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get failureApiTokenRequestCount: %v", err)
|
||||
}
|
||||
if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok {
|
||||
log.Infof("reset apiToken %s request failure count", apiTokenInUse)
|
||||
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse, log)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log) {
|
||||
for attempt := 1; attempt <= casMaxRetries; attempt++ {
|
||||
apiTokenRequestCount, cas, err := getApiTokenRequestCount(key)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get %s: %v", key, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if op == resetApiTokenRequestCountOperation {
|
||||
delete(apiTokenRequestCount, apiToken)
|
||||
} else {
|
||||
apiTokenRequestCount[apiToken]++
|
||||
}
|
||||
|
||||
apiTokenRequestCountByte, err := json.Marshal(apiTokenRequestCount)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal apiTokenRequestCount: %v", err)
|
||||
}
|
||||
|
||||
if err := proxywasm.SetSharedData(key, apiTokenRequestCountByte, cas); err == nil {
|
||||
log.Debugf("Successfully updated the count of %s in %s", apiToken, key)
|
||||
return
|
||||
} else if !errors.Is(err, types.ErrorStatusCasMismatch) {
|
||||
log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Errorf("CAS mismatch when setting %s, retrying...", key)
|
||||
}
|
||||
}
|
||||
|
||||
// initApiTokens seeds the shared available-token list from the statically
// configured apiTokens (CAS 0: unconditional write).
func (c *ProviderConfig) initApiTokens() error {
	return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, 0)
}
|
||||
|
||||
func (c *ProviderConfig) GetGlobalRandomToken(log wrapper.Log) string {
|
||||
apiTokens, _, err := getApiTokens(c.failover.ctxApiTokens)
|
||||
unavailableApiTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens)
|
||||
log.Debugf("apiTokens: %v, unavailableApiTokens: %v", apiTokens, unavailableApiTokens)
|
||||
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
count := len(apiTokens)
|
||||
switch count {
|
||||
case 0:
|
||||
return ""
|
||||
case 1:
|
||||
return apiTokens[0]
|
||||
default:
|
||||
return apiTokens[rand.Intn(count)]
|
||||
}
|
||||
}
|
||||
|
||||
// isFailoverEnabled reports whether apiToken failover is enabled for this
// provider configuration.
func (c *ProviderConfig) isFailoverEnabled() bool {
	return c.failover.enabled
}
|
||||
|
||||
// resetSharedData clears all failover-related shared data (lease, token
// lists, and success/failure counters). Errors are deliberately ignored:
// this is best-effort cleanup on configuration reload.
func (c *ProviderConfig) resetSharedData() {
	_ = proxywasm.SetSharedData(c.failover.ctxVmLease, nil, 0)
	_ = proxywasm.SetSharedData(c.failover.ctxApiTokens, nil, 0)
	_ = proxywasm.SetSharedData(c.failover.ctxUnavailableApiTokens, nil, 0)
	_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestSuccessCount, nil, 0)
	_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0)
}
|
||||
|
||||
// OnRequestFailed records a request failure against the apiToken that was
// used, which may eventually move the token to the unavailable list.
// No-op when failover is disabled.
func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) {
	if c.isFailoverEnabled() {
		c.handleUnavailableApiToken(ctx, apiTokenInUse, log)
	}
}
|
||||
|
||||
func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
|
||||
return ctx.GetContext(c.failover.ctxApiTokenInUse).(string)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
var apiToken string
|
||||
if c.isFailoverEnabled() {
|
||||
// if enable apiToken failover, only use available apiToken
|
||||
apiToken = c.GetGlobalRandomToken(log)
|
||||
} else {
|
||||
apiToken = c.GetRandomToken()
|
||||
}
|
||||
log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiToken)
|
||||
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) setHealthCheckEndpoint(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
cluster, err := proxywasm.GetProperty([]string{"cluster_name"})
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get cluster_name: %v", err)
|
||||
}
|
||||
|
||||
host := wrapper.GetRequestHost()
|
||||
if host == "" {
|
||||
host = ctx.GetContext(ctxRequestHost).(string)
|
||||
}
|
||||
path := wrapper.GetRequestPath()
|
||||
if path == "" {
|
||||
path = ctx.GetContext(ctxRequestPath).(string)
|
||||
}
|
||||
|
||||
healthCheckEndpoint := HealthCheckEndpoint{
|
||||
Host: host,
|
||||
Path: path,
|
||||
Cluster: string(cluster),
|
||||
}
|
||||
|
||||
healthCheckEndpointByte, err := json.Marshal(healthCheckEndpoint)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal request host and path: %v", err)
|
||||
|
||||
}
|
||||
err = proxywasm.SetSharedData(c.failover.ctxHealthCheckEndpoint, healthCheckEndpointByte, 0)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to set request host and path: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -17,8 +18,11 @@ import (
|
||||
// geminiProvider is the provider for google gemini/gemini flash service.
|
||||
|
||||
const (
|
||||
geminiApiKeyHeader = "x-goog-api-key"
|
||||
geminiDomain = "generativelanguage.googleapis.com"
|
||||
geminiApiKeyHeader = "x-goog-api-key"
|
||||
geminiDomain = "generativelanguage.googleapis.com"
|
||||
geminiChatCompletionPath = "generateContent"
|
||||
geminiChatCompletionStreamPath = "streamGenerateContent?alt=sse"
|
||||
geminiEmbeddingPath = "batchEmbedContents"
|
||||
)
|
||||
|
||||
type geminiProviderInitializer struct {
|
||||
@@ -51,157 +55,56 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(geminiApiKeyHeader, g.config.GetRandomToken())
|
||||
_ = util.OverwriteRequestHost(geminiDomain)
|
||||
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
|
||||
g.config.handleRequestHeaders(g, ctx, apiName, log)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
}
|
||||
|
||||
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestHostHeader(headers, geminiDomain)
|
||||
headers.Add(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
if apiName == ApiNameChatCompletion {
|
||||
return g.onChatCompletionRequestBody(ctx, body, log)
|
||||
} else if apiName == ApiNameEmbeddings {
|
||||
return g.onEmbeddingsRequestBody(ctx, body, log)
|
||||
return g.onChatCompletionRequestBody(ctx, body, headers, log)
|
||||
} else {
|
||||
return g.onEmbeddingsRequestBody(ctx, body, headers, log)
|
||||
}
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
|
||||
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
// 使用gemini接口协议
|
||||
if g.config.protocol == protocolOriginal {
|
||||
request := &geminiChatRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
if request.Model == "" {
|
||||
return types.ActionContinue, errors.New("request model is empty")
|
||||
}
|
||||
// 根据模型重写requestPath
|
||||
path := g.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
|
||||
_ = util.OverwriteRequestPath(path)
|
||||
|
||||
// 移除多余的model和stream字段
|
||||
request = &geminiChatRequest{
|
||||
Contents: request.Contents,
|
||||
SafetySettings: request.SafetySettings,
|
||||
GenerationConfig: request.GenerationConfig,
|
||||
Tools: request.Tools,
|
||||
}
|
||||
if g.config.context == nil {
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
}
|
||||
|
||||
err := g.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.gemini.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
g.setSystemContent(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.gemini.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
err := g.config.parseRequestAndMapModel(ctx, request, body, log)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path := g.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
// 映射模型重写requestPath
|
||||
model := request.Model
|
||||
if model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
mappedModel := getMappedModel(model, g.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
request.Model = mappedModel
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
|
||||
path := g.getRequestPath(ApiNameChatCompletion, mappedModel, request.Stream)
|
||||
_ = util.OverwriteRequestPath(path)
|
||||
|
||||
if g.config.context == nil {
|
||||
geminiRequest := g.buildGeminiChatRequest(request)
|
||||
return types.ActionContinue, replaceJsonRequestBody(geminiRequest, log)
|
||||
}
|
||||
|
||||
err := g.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.gemini.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
geminiRequest := g.buildGeminiChatRequest(request)
|
||||
if err := replaceJsonRequestBody(geminiRequest, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.gemini.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
geminiRequest := g.buildGeminiChatRequest(request)
|
||||
return json.Marshal(geminiRequest)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
// 使用gemini接口协议
|
||||
if g.config.protocol == protocolOriginal {
|
||||
request := &geminiBatchEmbeddingRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
if request.Model == "" {
|
||||
return types.ActionContinue, errors.New("request model is empty")
|
||||
}
|
||||
// 根据模型重写requestPath
|
||||
path := g.getRequestPath(ApiNameEmbeddings, request.Model, false)
|
||||
_ = util.OverwriteRequestPath(path)
|
||||
|
||||
// 移除多余的model字段
|
||||
request = &geminiBatchEmbeddingRequest{
|
||||
Requests: request.Requests,
|
||||
}
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
}
|
||||
func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
request := &embeddingsRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
if err := g.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 映射模型重写requestPath
|
||||
model := request.Model
|
||||
if model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in embeddings request")
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
mappedModel := getMappedModel(model, g.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
request.Model = mappedModel
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
|
||||
path := g.getRequestPath(ApiNameEmbeddings, mappedModel, false)
|
||||
_ = util.OverwriteRequestPath(path)
|
||||
path := g.getRequestPath(ApiNameEmbeddings, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
geminiRequest := g.buildBatchEmbeddingRequest(request)
|
||||
return types.ActionContinue, replaceJsonRequestBody(geminiRequest, log)
|
||||
return json.Marshal(geminiRequest)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
@@ -285,11 +188,11 @@ func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body
|
||||
func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string {
|
||||
action := ""
|
||||
if apiName == ApiNameEmbeddings {
|
||||
action = "batchEmbedContents"
|
||||
action = geminiEmbeddingPath
|
||||
} else if stream {
|
||||
action = "streamGenerateContent?alt=sse"
|
||||
action = geminiChatCompletionStreamPath
|
||||
} else {
|
||||
action = "generateContent"
|
||||
action = geminiChatCompletionPath
|
||||
}
|
||||
return fmt.Sprintf("/v1/models/%s:%s", geminiModel, action)
|
||||
}
|
||||
@@ -605,3 +508,13 @@ func (g *geminiProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, gemini
|
||||
func (g *geminiProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
|
||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||
}
|
||||
|
||||
func (g *geminiProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, geminiChatCompletionPath) || strings.Contains(path, geminiChatCompletionStreamPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
if strings.Contains(path, geminiEmbeddingPath) {
|
||||
return ApiNameEmbeddings
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// githubProvider is the provider for GitHub OpenAI service.
|
||||
@@ -48,16 +46,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestHost(githubDomain)
|
||||
if apiName == ApiNameChatCompletion {
|
||||
_ = util.OverwriteRequestPath(githubCompletionPath)
|
||||
}
|
||||
if apiName == ApiNameEmbeddings {
|
||||
_ = util.OverwriteRequestPath(githubEmbeddingPath)
|
||||
}
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken())
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
}
|
||||
@@ -66,47 +55,28 @@ func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestHostHeader(headers, githubDomain)
|
||||
if apiName == ApiNameChatCompletion {
|
||||
return m.onChatCompletionRequestBody(ctx, body, log)
|
||||
util.OverwriteRequestPathHeader(headers, githubCompletionPath)
|
||||
}
|
||||
if apiName == ApiNameEmbeddings {
|
||||
return m.onEmbeddingsRequestBody(ctx, body, log)
|
||||
util.OverwriteRequestPathHeader(headers, githubEmbeddingPath)
|
||||
}
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (m *githubProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
func (m *githubProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, githubCompletionPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
if request.Model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
if strings.Contains(path, githubEmbeddingPath) {
|
||||
return ApiNameEmbeddings
|
||||
}
|
||||
// 映射模型
|
||||
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
|
||||
request.Model = mappedModel
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
}
|
||||
|
||||
func (m *githubProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
request := &embeddingsRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
if request.Model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in embeddings request")
|
||||
}
|
||||
// 映射模型
|
||||
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
|
||||
request.Model = mappedModel
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -2,11 +2,11 @@ package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
@@ -18,14 +18,14 @@ const (
|
||||
|
||||
type groqProviderInitializer struct{}
|
||||
|
||||
func (m *groqProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
func (g *groqProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||
return errors.New("no apiToken found in provider config")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
func (g *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
return &groqProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -37,47 +37,35 @@ type groqProvider struct {
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (m *groqProvider) GetProviderType() string {
|
||||
func (g *groqProvider) GetProviderType() string {
|
||||
return providerTypeGroq
|
||||
}
|
||||
|
||||
func (m *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestPath(groqChatCompletionPath)
|
||||
_ = util.OverwriteRequestHost(groqDomain)
|
||||
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
g.config.handleRequestHeaders(g, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (m *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if m.contextCache == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.groq.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.groq.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (g *groqProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, groqChatCompletionPath)
|
||||
util.OverwriteRequestHostHeader(headers, groqDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (g *groqProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, groqChatCompletionPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -114,26 +115,27 @@ func (m *hunyuanProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
// log.Debugf("hunyuanProvider.OnRequestHeaders called! hunyunSecretKey/id is: %s/%s", m.config.hunyuanAuthKey, m.config.hunyuanAuthId)
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
|
||||
_ = util.OverwriteRequestHost(hunyuanDomain)
|
||||
_ = util.OverwriteRequestPath(hunyuanRequestPath)
|
||||
|
||||
// 添加hunyuan需要的自定义字段
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(actionKey, hunyuanChatCompletionTCAction)
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(versionKey, versionValue)
|
||||
|
||||
// 删除一些字段
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestHostHeader(headers, hunyuanDomain)
|
||||
util.OverwriteRequestPathHeader(headers, hunyuanRequestPath)
|
||||
|
||||
// 添加 hunyuan 需要的自定义字段
|
||||
headers.Add(actionKey, hunyuanChatCompletionTCAction)
|
||||
headers.Add(versionKey, versionValue)
|
||||
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法
|
||||
func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
@@ -142,7 +144,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
// 为header添加时间戳字段 (因为需要根据body进行签名时依赖时间戳,故于body处理部分创建时间戳)
|
||||
var timestamp int64 = time.Now().Unix()
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(timestampKey, fmt.Sprintf("%d", timestamp))
|
||||
// log.Debugf("#debug nash5# OnRequestBody set timestamp header: ", timestamp)
|
||||
|
||||
// 使用混元本身接口的协议
|
||||
if m.config.protocol == protocolOriginal {
|
||||
@@ -198,7 +199,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
// log.Debugf("#debug nash5# OnRequestBody call hunyuan api using openai's api!")
|
||||
|
||||
model := request.Model
|
||||
if model == "" {
|
||||
@@ -235,18 +235,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
string(body),
|
||||
)
|
||||
_ = util.OverwriteRequestAuthorization(authorizedValueNew)
|
||||
// log.Debugf("#debug nash5# OnRequestBody done, body is: ", string(body))
|
||||
|
||||
// // 打印所有的headers
|
||||
// headers, err2 := proxywasm.GetHttpRequestHeaders()
|
||||
// if err2 != nil {
|
||||
// log.Errorf("failed to get request headers: %v", err2)
|
||||
// } else {
|
||||
// // 迭代并打印所有请求头
|
||||
// for _, header := range headers {
|
||||
// log.Infof("#debug nash5# inB Request header - %s: %s", header[0], header[1])
|
||||
// }
|
||||
// }
|
||||
return types.ActionContinue, replaceJsonRequestBody(hunyuanRequest, log)
|
||||
}
|
||||
|
||||
@@ -277,6 +265,32 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
// hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用
|
||||
func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
err := m.config.parseRequestAndMapModel(ctx, request, body, log)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hunyuanRequest := m.buildHunyuanTextGenerationRequest(request)
|
||||
|
||||
var timestamp int64 = time.Now().Unix()
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(timestampKey, fmt.Sprintf("%d", timestamp))
|
||||
// 根据确定好的payload进行签名:
|
||||
body, _ = json.Marshal(hunyuanRequest)
|
||||
authorizedValueNew := GetTC3Authorizationcode(
|
||||
m.config.hunyuanAuthId,
|
||||
m.config.hunyuanAuthKey,
|
||||
timestamp,
|
||||
hunyuanDomain,
|
||||
hunyuanChatCompletionTCAction,
|
||||
string(body),
|
||||
)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, authorizedValueNew)
|
||||
return json.Marshal(hunyuanRequest)
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
||||
return types.ActionContinue, nil
|
||||
@@ -561,3 +575,7 @@ func GetTC3Authorizationcode(secretId string, secretKey string, timestamp int64,
|
||||
// fmt.Println(curl)
|
||||
return authorization
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) GetApiName(path string) ApiName {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
@@ -78,14 +79,17 @@ func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestHost(minimaxDomain)
|
||||
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestHostHeader(headers, minimaxDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
@@ -107,51 +111,16 @@ func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
return m.handleRequestBodyByChatCompletionPro(body, log)
|
||||
} else {
|
||||
// 使用ChatCompletion v2接口
|
||||
return m.handleRequestBodyByChatCompletionV2(body, log)
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
return m.handleRequestBodyByChatCompletionV2(body, headers, log)
|
||||
}
|
||||
|
||||
// handleRequestBodyByChatCompletionPro 使用ChatCompletion Pro接口处理请求体
|
||||
func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log wrapper.Log) (types.Action, error) {
|
||||
// 使用minimax接口协议
|
||||
if m.config.protocol == protocolOriginal {
|
||||
request := &minimaxChatCompletionV2Request{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
if request.Model == "" {
|
||||
return types.ActionContinue, errors.New("request model is empty")
|
||||
}
|
||||
// 根据模型重写requestPath
|
||||
if m.config.minimaxGroupId == "" {
|
||||
return types.ActionContinue, errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when use %s model ", request.Model))
|
||||
}
|
||||
_ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))
|
||||
|
||||
if m.config.context == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
m.setBotSettings(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
@@ -174,6 +143,9 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
// 由于 minimaxChatCompletionV2(格式和 OpenAI 一致)和 minimaxChatCompletionPro(格式和 OpenAI 不一致)中 insertHttpContextMessage 的逻辑不同,无法做到同一个 provider 统一
|
||||
// 因此对于 minimaxChatCompletionPro 需要手动处理 context 消息
|
||||
// minimaxChatCompletionV2 交给默认的 defaultInsertHttpContextMessage 方法插入 context 消息
|
||||
minimaxRequest := m.buildMinimaxChatCompletionV2Request(request, content)
|
||||
if err := replaceJsonRequestBody(minimaxRequest, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace Request body: %v", err))
|
||||
@@ -186,37 +158,17 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
|
||||
}
|
||||
|
||||
// handleRequestBodyByChatCompletionV2 使用ChatCompletion v2接口处理请求体
|
||||
func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 映射模型重写requestPath
|
||||
request.Model = getMappedModel(request.Model, m.config.modelMapping, log)
|
||||
_ = util.OverwriteRequestPath(minimaxChatCompletionV2Path)
|
||||
util.OverwriteRequestPathHeader(headers, minimaxChatCompletionV2Path)
|
||||
|
||||
if m.contextCache == nil {
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
}
|
||||
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
@@ -474,3 +426,10 @@ func (m *minimaxProvider) responseV2ToOpenAI(response *minimaxChatCompletionV2Re
|
||||
func (m *minimaxProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
|
||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, minimaxChatCompletionV2Path) || strings.Contains(path, minimaxChatCompletionProPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -2,12 +2,10 @@ package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -43,9 +41,7 @@ func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestHost(mistralDomain)
|
||||
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
@@ -53,28 +49,11 @@ func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if m.contextCache == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.mistral.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.mistral.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *mistralProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestHostHeader(headers, mistralDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
@@ -3,13 +3,12 @@ package provider
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// moonshotProvider is the provider for Moonshot AI service.
|
||||
@@ -58,33 +57,29 @@ func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestPath(moonshotChatCompletionPath)
|
||||
_ = util.OverwriteRequestHost(moonshotDomain)
|
||||
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, moonshotChatCompletionPath)
|
||||
util.OverwriteRequestHostHeader(headers, moonshotDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
// moonshot 有自己获取 context 的配置(moonshotFileId),因此无法复用 handleRequestBody 方法
|
||||
// moonshot 的 body 没有修改,无须实现TransformRequestBody,使用默认的 defaultTransformRequestBody 方法
|
||||
func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
model := request.Model
|
||||
if model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
}
|
||||
mappedModel := getMappedModel(model, m.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
request.Model = mappedModel
|
||||
|
||||
if m.config.moonshotFileId == "" && m.contextCache == nil {
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
}
|
||||
|
||||
@@ -3,11 +3,10 @@ package provider
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ollamaProvider is the provider for Ollama service.
|
||||
@@ -53,10 +52,7 @@ func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestPath(ollamaChatCompletionPath)
|
||||
_ = util.OverwriteRequestHost(m.serviceDomain)
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
@@ -64,51 +60,11 @@ func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
|
||||
if m.config.modelMapping == nil && m.contextCache == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
model := request.Model
|
||||
if model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
}
|
||||
mappedModel := getMappedModel(model, m.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
request.Model = mappedModel
|
||||
|
||||
if m.contextCache != nil {
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.ollama.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.ollama.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
} else {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
} else {
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.ollama.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *ollamaProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, ollamaChatCompletionPath)
|
||||
util.OverwriteRequestHostHeader(headers, m.serviceDomain)
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
@@ -57,27 +58,31 @@ func (m *openaiProvider) GetProviderType() string {
|
||||
}
|
||||
|
||||
func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
if m.customPath == "" {
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
_ = util.OverwriteRequestPath(defaultOpenaiChatCompletionPath)
|
||||
util.OverwriteRequestPathHeader(headers, defaultOpenaiChatCompletionPath)
|
||||
case ApiNameEmbeddings:
|
||||
ctx.DontReadRequestBody()
|
||||
_ = util.OverwriteRequestPath(defaultOpenaiEmbeddingsPath)
|
||||
util.OverwriteRequestPathHeader(headers, defaultOpenaiEmbeddingsPath)
|
||||
}
|
||||
} else {
|
||||
_ = util.OverwriteRequestPath(m.customPath)
|
||||
util.OverwriteRequestPathHeader(headers, m.customPath)
|
||||
}
|
||||
if m.customDomain == "" {
|
||||
_ = util.OverwriteRequestHost(defaultOpenaiDomain)
|
||||
util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain)
|
||||
} else {
|
||||
_ = util.OverwriteRequestHost(m.customDomain)
|
||||
util.OverwriteRequestHostHeader(headers, m.customDomain)
|
||||
}
|
||||
if len(m.config.apiTokens) > 0 {
|
||||
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
}
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
return types.ActionContinue, nil
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
@@ -85,9 +90,13 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
// We don't need to process the request body for other APIs.
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
return nil, err
|
||||
}
|
||||
if m.config.responseJsonSchema != nil {
|
||||
log.Debugf("[ai-proxy] set response format to %s", m.config.responseJsonSchema)
|
||||
@@ -101,27 +110,5 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
request.StreamOptions.IncludeUsage = true
|
||||
}
|
||||
}
|
||||
if m.contextCache == nil {
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.openai.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.openai.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
type ApiName string
|
||||
@@ -110,14 +113,32 @@ type Provider interface {
|
||||
GetProviderType() string
|
||||
}
|
||||
|
||||
type ApiNameHandler interface {
|
||||
GetApiName(path string) ApiName
|
||||
}
|
||||
|
||||
type RequestHeadersHandler interface {
|
||||
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
|
||||
}
|
||||
|
||||
type TransformRequestHeadersHandler interface {
|
||||
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
|
||||
}
|
||||
|
||||
type RequestBodyHandler interface {
|
||||
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
|
||||
}
|
||||
|
||||
type TransformRequestBodyHandler interface {
|
||||
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
|
||||
}
|
||||
|
||||
// TransformRequestBodyHeadersHandler allows to transform request headers based on the request body.
|
||||
// Some providers (e.g. baidu, gemini) transform request headers (e.g., path) based on the request body (e.g., model).
|
||||
type TransformRequestBodyHeadersHandler interface {
|
||||
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
|
||||
}
|
||||
|
||||
type ResponseHeadersHandler interface {
|
||||
OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
|
||||
}
|
||||
@@ -143,6 +164,9 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN 请求超时
|
||||
// @Description zh-CN 请求AI服务的超时时间,单位为毫秒。默认值为120000,即2分钟
|
||||
timeout uint32 `required:"false" yaml:"timeout" json:"timeout"`
|
||||
// @Title zh-CN apiToken 故障切换
|
||||
// @Description zh-CN 当 apiToken 不可用时移出 apiTokens 列表,对移除的 apiToken 进行健康检查,当重新可用后加回 apiTokens 列表
|
||||
failover *failover `required:"false" yaml:"failover" json:"failover"`
|
||||
// @Title zh-CN 基于OpenAI协议的自定义后端URL
|
||||
// @Description zh-CN 仅适用于支持 openai 协议的服务。
|
||||
openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"`
|
||||
@@ -289,6 +313,14 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
failoverJson := json.Get("failover")
|
||||
c.failover = &failover{
|
||||
enabled: false,
|
||||
}
|
||||
if failoverJson.Exists() {
|
||||
c.failover.FromJson(failoverJson)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) Validate() error {
|
||||
@@ -304,6 +336,12 @@ func (c *ProviderConfig) Validate() error {
|
||||
}
|
||||
}
|
||||
|
||||
if c.failover.enabled {
|
||||
if err := c.failover.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if c.typ == "" {
|
||||
return errors.New("missing type in provider config")
|
||||
}
|
||||
@@ -355,6 +393,60 @@ func CreateProvider(pc ProviderConfig) (Provider, error) {
|
||||
return initializer.CreateProvider(pc)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte, log wrapper.Log) error {
|
||||
switch req := request.(type) {
|
||||
case *chatCompletionRequest:
|
||||
if err := decodeChatCompletionRequest(body, req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
streaming := req.Stream
|
||||
if streaming {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
|
||||
}
|
||||
|
||||
return c.setRequestModel(ctx, req, log)
|
||||
case *embeddingsRequest:
|
||||
if err := decodeEmbeddingsRequest(body, req); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.setRequestModel(ctx, req, log)
|
||||
default:
|
||||
return errors.New("unsupported request type")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}, log wrapper.Log) error {
|
||||
var model *string
|
||||
|
||||
switch req := request.(type) {
|
||||
case *chatCompletionRequest:
|
||||
model = &req.Model
|
||||
case *embeddingsRequest:
|
||||
model = &req.Model
|
||||
default:
|
||||
return errors.New("unsupported request type")
|
||||
}
|
||||
|
||||
return c.mapModel(ctx, model, log)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string, log wrapper.Log) error {
|
||||
if *model == "" {
|
||||
return errors.New("missing model in request")
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, *model)
|
||||
|
||||
mappedModel := getMappedModel(*model, c.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
|
||||
*model = mappedModel
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, *model)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
|
||||
mappedModel := doGetMappedModel(model, modelMapping, log)
|
||||
if len(mappedModel) != 0 {
|
||||
@@ -391,3 +483,62 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) handleRequestBody(
|
||||
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log,
|
||||
) (types.Action, error) {
|
||||
// use original protocol
|
||||
if c.protocol == protocolOriginal {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
// use openai protocol
|
||||
var err error
|
||||
if handler, ok := provider.(TransformRequestBodyHandler); ok {
|
||||
body, err = handler.TransformRequestBody(ctx, apiName, body, log)
|
||||
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
|
||||
headers := util.GetOriginalHttpHeaders()
|
||||
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log)
|
||||
util.ReplaceOriginalHttpHeaders(headers)
|
||||
} else {
|
||||
body, err = c.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
if apiName == ApiNameChatCompletion {
|
||||
if c.context == nil {
|
||||
return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
|
||||
}
|
||||
err = contextCache.GetContextFromFile(ctx, provider, body, log)
|
||||
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) {
|
||||
if handler, ok := provider.(TransformRequestHeadersHandler); ok {
|
||||
originalHeaders := util.GetOriginalHttpHeaders()
|
||||
handler.TransformRequestHeaders(ctx, apiName, originalHeaders, log)
|
||||
util.ReplaceOriginalHttpHeaders(originalHeaders)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
var request interface{}
|
||||
if apiName == ApiNameChatCompletion {
|
||||
request = &chatCompletionRequest{}
|
||||
} else {
|
||||
request = &embeddingsRequest{}
|
||||
}
|
||||
if err := c.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -58,35 +59,50 @@ func (m *qwenProviderInitializer) CreateProvider(config ProviderConfig) (Provide
|
||||
}
|
||||
|
||||
type qwenProvider struct {
|
||||
config ProviderConfig
|
||||
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestHostHeader(headers, qwenDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
|
||||
if m.config.qwenEnableCompatible {
|
||||
util.OverwriteRequestPathHeader(headers, qwenCompatiblePath)
|
||||
} else if apiName == ApiNameChatCompletion {
|
||||
util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath)
|
||||
} else if apiName == ApiNameEmbeddings {
|
||||
util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath)
|
||||
}
|
||||
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
if apiName == ApiNameChatCompletion {
|
||||
return m.onChatCompletionRequestBody(ctx, body, headers, log)
|
||||
} else {
|
||||
return m.onEmbeddingsRequestBody(ctx, body, log)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *qwenProvider) GetProviderType() string {
|
||||
return providerTypeQwen
|
||||
}
|
||||
|
||||
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
_ = util.OverwriteRequestHost(qwenDomain)
|
||||
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
|
||||
if m.config.protocol == protocolOriginal {
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue, nil
|
||||
} else if m.config.qwenEnableCompatible {
|
||||
_ = util.OverwriteRequestPath(qwenCompatiblePath)
|
||||
} else if apiName == ApiNameChatCompletion {
|
||||
_ = util.OverwriteRequestPath(qwenChatCompletionPath)
|
||||
} else if apiName == ApiNameEmbeddings {
|
||||
_ = util.OverwriteRequestPath(qwenTextEmbeddingPath)
|
||||
} else {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
}
|
||||
@@ -121,65 +137,23 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
|
||||
}
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
if apiName == ApiNameChatCompletion {
|
||||
return m.onChatCompletionRequestBody(ctx, body, log)
|
||||
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if apiName == ApiNameEmbeddings {
|
||||
return m.onEmbeddingsRequestBody(ctx, body, log)
|
||||
}
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if m.config.protocol == protocolOriginal {
|
||||
if m.config.context == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
request := &qwenTextGenRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.qwen.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
m.insertContextMessage(request, content, false)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.qwen.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
err := m.config.parseRequestAndMapModel(ctx, request, body, log)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
model := request.Model
|
||||
if model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
mappedModel := getMappedModel(model, m.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
request.Model = mappedModel
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
|
||||
// Use the qwen multimodal model generation API
|
||||
if strings.HasPrefix(request.Model, qwenVlModelPrefixName) {
|
||||
_ = util.OverwriteRequestPath(qwenMultimodalGenerationPath)
|
||||
util.OverwriteRequestPathHeader(headers, qwenMultimodalGenerationPath)
|
||||
}
|
||||
|
||||
streaming := request.Stream
|
||||
@@ -191,62 +165,20 @@ func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body
|
||||
_ = proxywasm.RemoveHttpRequestHeader("X-DashScope-SSE")
|
||||
}
|
||||
|
||||
if m.config.context == nil {
|
||||
qwenRequest := m.buildQwenTextGenerationRequest(request, streaming)
|
||||
if streaming {
|
||||
ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput)
|
||||
}
|
||||
return types.ActionContinue, replaceJsonRequestBody(qwenRequest, log)
|
||||
}
|
||||
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.qwen.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
qwenRequest := m.buildQwenTextGenerationRequest(request, streaming)
|
||||
if streaming {
|
||||
ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput)
|
||||
}
|
||||
if err := replaceJsonRequestBody(qwenRequest, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.qwen.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
return m.buildQwenTextGenerationRequest(ctx, request, streaming)
|
||||
}
|
||||
|
||||
func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
request := &embeddingsRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("=== embeddings request: %v", request)
|
||||
|
||||
model := request.Model
|
||||
if model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in the request")
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
mappedModel := getMappedModel(model, m.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
request.Model = mappedModel
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
|
||||
|
||||
if qwenRequest, err := m.buildQwenTextEmbeddingRequest(request); err == nil {
|
||||
return types.ActionContinue, replaceJsonRequestBody(qwenRequest, log)
|
||||
} else {
|
||||
return types.ActionContinue, err
|
||||
qwenRequest, err := m.buildQwenTextEmbeddingRequest(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(qwenRequest)
|
||||
}
|
||||
|
||||
func (m *qwenProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
@@ -375,7 +307,7 @@ func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []
|
||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
||||
}
|
||||
|
||||
func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletionRequest, streaming bool) *qwenTextGenRequest {
|
||||
func (m *qwenProvider) buildQwenTextGenerationRequest(ctx wrapper.HttpContext, origRequest *chatCompletionRequest, streaming bool) ([]byte, error) {
|
||||
messages := make([]qwenMessage, 0, len(origRequest.Messages))
|
||||
for i := range origRequest.Messages {
|
||||
messages = append(messages, chatMessage2QwenMessage(origRequest.Messages[i]))
|
||||
@@ -397,6 +329,11 @@ func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletio
|
||||
Tools: origRequest.Tools,
|
||||
},
|
||||
}
|
||||
|
||||
if streaming {
|
||||
ctx.SetContext(ctxKeyIncrementalStreaming, request.Parameters.IncrementalOutput)
|
||||
}
|
||||
|
||||
if len(m.config.qwenFileIds) != 0 && origRequest.Model == qwenLongModelName {
|
||||
builder := strings.Builder{}
|
||||
for _, fileId := range m.config.qwenFileIds {
|
||||
@@ -406,13 +343,15 @@ func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletio
|
||||
builder.WriteString("fileid://")
|
||||
builder.WriteString(fileId)
|
||||
}
|
||||
contextMessageId := m.insertContextMessage(request, builder.String(), true)
|
||||
if contextMessageId == 0 {
|
||||
// The context message cannot come first. We need to add another dummy system message before it.
|
||||
request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...)
|
||||
|
||||
body, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to marshal request: %v", err)
|
||||
}
|
||||
|
||||
return m.insertHttpContextMessage(body, builder.String(), true)
|
||||
}
|
||||
return request
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse) *chatCompletionResponse {
|
||||
@@ -569,7 +508,12 @@ func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuild
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content string, onlyOneSystemBeforeFile bool) int {
|
||||
func (m *qwenProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) {
|
||||
request := &qwenTextGenRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
|
||||
fileMessage := qwenMessage{
|
||||
Role: roleSystem,
|
||||
Content: content,
|
||||
@@ -586,10 +530,8 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content
|
||||
}
|
||||
if firstNonSystemMessageIndex == 0 {
|
||||
request.Input.Messages = append([]qwenMessage{fileMessage}, request.Input.Messages...)
|
||||
return 0
|
||||
} else if !onlyOneSystemBeforeFile {
|
||||
request.Input.Messages = append(request.Input.Messages[:firstNonSystemMessageIndex], append([]qwenMessage{fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)...)
|
||||
return firstNonSystemMessageIndex
|
||||
} else {
|
||||
builder := strings.Builder{}
|
||||
for _, message := range request.Input.Messages[:firstNonSystemMessageIndex] {
|
||||
@@ -599,8 +541,15 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content
|
||||
builder.WriteString(message.StringContent())
|
||||
}
|
||||
request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: builder.String()}, fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)
|
||||
return 1
|
||||
firstNonSystemMessageIndex = 1
|
||||
}
|
||||
|
||||
if firstNonSystemMessageIndex == 0 {
|
||||
// The context message cannot come first. We need to add another dummy system message before it.
|
||||
request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...)
|
||||
}
|
||||
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) {
|
||||
@@ -804,3 +753,16 @@ func chatMessage2QwenMessage(chatMessage chatMessage) qwenMessage {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *qwenProvider) GetApiName(path string) ApiName {
|
||||
switch {
|
||||
case strings.Contains(path, qwenChatCompletionPath),
|
||||
strings.Contains(path, qwenMultimodalGenerationPath),
|
||||
strings.Contains(path, qwenCompatiblePath):
|
||||
return ApiNameChatCompletion
|
||||
case strings.Contains(path, qwenTextEmbeddingPath):
|
||||
return ApiNameEmbeddings
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package provider
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
)
|
||||
@@ -18,6 +17,13 @@ func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) er
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeEmbeddingsRequest(body []byte, request *embeddingsRequest) error {
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func replaceJsonRequestBody(request interface{}, log wrapper.Log) error {
|
||||
body, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
@@ -31,6 +37,15 @@ func replaceJsonRequestBody(request interface{}, log wrapper.Log) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func replaceHttpJsonRequestBody(body []byte, log wrapper.Log) error {
|
||||
log.Debugf("request body: %s", string(body))
|
||||
err := proxywasm.ReplaceHttpRequestBody(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to replace the original request body: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func insertContextMessage(request *chatCompletionRequest, content string) {
|
||||
fileMessage := chatMessage{
|
||||
Role: roleSystem,
|
||||
|
||||
@@ -2,8 +2,8 @@ package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -71,11 +71,7 @@ func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestHost(sparkHost)
|
||||
_ = util.OverwriteRequestPath(sparkChatCompletionPath)
|
||||
_ = util.OverwriteRequestAuthorization("Bearer " + p.config.GetRandomToken())
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
p.config.handleRequestHeaders(p, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
@@ -83,36 +79,7 @@ func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
// 使用Spark协议
|
||||
if p.config.protocol == protocolOriginal {
|
||||
request := &sparkRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
if request.Model == "" {
|
||||
return types.ActionContinue, errors.New("request model is empty")
|
||||
}
|
||||
// 目前星火在模型名称错误时,也会调用generalv3,这里还是按照输入的模型名称设置响应里的模型名称
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
} else {
|
||||
// 使用openai协议
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
if request.Model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
}
|
||||
// 映射模型
|
||||
mappedModel := getMappedModel(request.Model, p.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
|
||||
request.Model = mappedModel
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
}
|
||||
return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (p *sparkProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
@@ -205,3 +172,11 @@ func (p *sparkProvider) streamResponseSpark2OpenAI(ctx wrapper.HttpContext, resp
|
||||
func (p *sparkProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
|
||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||
}
|
||||
|
||||
func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath)
|
||||
util.OverwriteRequestHostHeader(headers, sparkHost)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
@@ -2,12 +2,10 @@ package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -45,10 +43,7 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestPath(stepfunChatCompletionPath)
|
||||
_ = util.OverwriteRequestHost(stepfunDomain)
|
||||
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
@@ -56,28 +51,12 @@ func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if m.contextCache == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.stepfun.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.stepfun.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *stepfunProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, stepfunChatCompletionPath)
|
||||
util.OverwriteRequestHostHeader(headers, stepfunDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
@@ -2,11 +2,10 @@ package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
@@ -45,10 +44,7 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName,
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestPath(yiChatCompletionPath)
|
||||
_ = util.OverwriteRequestHost(yiDomain)
|
||||
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
@@ -56,28 +52,12 @@ func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, bod
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if m.contextCache == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.yi.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.yi.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *yiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, yiChatCompletionPath)
|
||||
util.OverwriteRequestHostHeader(headers, yiDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
@@ -2,11 +2,11 @@ package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
@@ -44,10 +44,7 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestPath(zhipuAiChatCompletionPath)
|
||||
_ = util.OverwriteRequestHost(zhipuAiDomain)
|
||||
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
@@ -55,28 +52,19 @@ func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if m.contextCache == nil {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, "ai-proxy.zhihupai.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, "ai-proxy.zhihupai.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, zhipuAiChatCompletionPath)
|
||||
util.OverwriteRequestHostHeader(headers, zhipuAiDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (m *zhipuAiProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, zhipuAiChatCompletionPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package util
|
||||
|
||||
import "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
)
|
||||
|
||||
const (
|
||||
HeaderContentType = "Content-Type"
|
||||
@@ -21,13 +25,6 @@ func CreateHeaders(kvs ...string) [][2]string {
|
||||
return headers
|
||||
}
|
||||
|
||||
func OverwriteRequestHost(host string) error {
|
||||
if originHost, err := proxywasm.GetHttpRequestHeader(":authority"); err == nil {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-HOST", originHost)
|
||||
}
|
||||
return proxywasm.ReplaceHttpRequestHeader(":authority", host)
|
||||
}
|
||||
|
||||
func OverwriteRequestPath(path string) error {
|
||||
if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-PATH", originPath)
|
||||
@@ -43,3 +40,56 @@ func OverwriteRequestAuthorization(credential string) error {
|
||||
}
|
||||
return proxywasm.ReplaceHttpRequestHeader("Authorization", credential)
|
||||
}
|
||||
|
||||
func OverwriteRequestHostHeader(headers http.Header, host string) {
|
||||
if originHost, err := proxywasm.GetHttpRequestHeader(":authority"); err == nil {
|
||||
headers.Set("X-ENVOY-ORIGINAL-HOST", originHost)
|
||||
}
|
||||
headers.Set(":authority", host)
|
||||
}
|
||||
|
||||
func OverwriteRequestPathHeader(headers http.Header, path string) {
|
||||
if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil {
|
||||
headers.Set("X-ENVOY-ORIGINAL-PATH", originPath)
|
||||
}
|
||||
headers.Set(":path", path)
|
||||
}
|
||||
|
||||
func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) {
|
||||
if exist := headers.Get("X-HI-ORIGINAL-AUTH"); exist == "" {
|
||||
if originAuth := headers.Get("Authorization"); originAuth != "" {
|
||||
headers.Set("X-HI-ORIGINAL-AUTH", originAuth)
|
||||
}
|
||||
}
|
||||
headers.Set("Authorization", credential)
|
||||
}
|
||||
|
||||
func HeaderToSlice(header http.Header) [][2]string {
|
||||
slice := make([][2]string, 0, len(header))
|
||||
for key, values := range header {
|
||||
for _, value := range values {
|
||||
slice = append(slice, [2]string{key, value})
|
||||
}
|
||||
}
|
||||
return slice
|
||||
}
|
||||
|
||||
func SliceToHeader(slice [][2]string) http.Header {
|
||||
header := make(http.Header)
|
||||
for _, pair := range slice {
|
||||
key := pair[0]
|
||||
value := pair[1]
|
||||
header.Add(key, value)
|
||||
}
|
||||
return header
|
||||
}
|
||||
|
||||
func GetOriginalHttpHeaders() http.Header {
|
||||
originalHeaders, _ := proxywasm.GetHttpRequestHeaders()
|
||||
return SliceToHeader(originalHeaders)
|
||||
}
|
||||
|
||||
func ReplaceOriginalHttpHeaders(headers http.Header) {
|
||||
modifiedHeaders := HeaderToSlice(headers)
|
||||
_ = proxywasm.ReplaceHttpRequestHeaders(modifiedHeaders)
|
||||
}
|
||||
|
||||
@@ -45,6 +45,19 @@ func (c RouteCluster) HostName() string {
|
||||
return GetRequestHost()
|
||||
}
|
||||
|
||||
type TargetCluster struct {
|
||||
Host string
|
||||
Cluster string
|
||||
}
|
||||
|
||||
func (c TargetCluster) ClusterName() string {
|
||||
return c.Cluster
|
||||
}
|
||||
|
||||
func (c TargetCluster) HostName() string {
|
||||
return c.Host
|
||||
}
|
||||
|
||||
type K8sCluster struct {
|
||||
ServiceName string
|
||||
Namespace string
|
||||
|
||||
Reference in New Issue
Block a user