feat: retry failed request (#1590)

This commit is contained in:
Se7en
2024-12-26 18:30:50 +08:00
committed by GitHub
parent 380717ae3d
commit 579c986915
15 changed files with 304 additions and 183 deletions

View File

@@ -41,6 +41,7 @@ description: AI 代理插件配置参考
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 |
`context`的配置字段说明如下:
@@ -78,14 +79,22 @@ custom-setting会遵循如下表格根据`name`和协议来替换对应的字
`failover` 的配置字段说明如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|------------------|--------|------|-------|-----------------------------|
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
| healthCheckModel | string | 必填 | | 健康检测使用的模型 |
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|------------------|--------|-----------------|-------|-----------------------------|
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
| healthCheckModel | string | 启用 failover 时必填 | | 健康检测使用的模型 |
`retryOnFailure` 的配置字段说明如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|------------------|--------|-----------------|-------|-------------|
| enabled | bool | 非必填 | false | 是否启用失败请求重试 |
| maxRetries | int | 非必填 | 1 | 最大重试次数 |
| retryTimeout | int | 非必填 | 5000 | 重试超时时间,单位毫秒 |
### 提供商特有配置

View File

@@ -20,8 +20,6 @@ import (
const (
pluginName = "ai-proxy"
ctxKeyApiName = "apiName"
defaultMaxBodyBytes uint32 = 10 * 1024 * 1024
)
@@ -92,14 +90,13 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
log.Warnf("[onHttpRequestHeader] unsupported path: %s", path.Path)
return types.ActionContinue
}
ctx.SetContext(provider.CtxKeyApiName, apiName)
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()
ctx.SetContext(ctxKeyApiName, apiName)
_, needHandleBody := activeProvider.(provider.ResponseBodyHandler)
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
if needHandleBody || needHandleStreamingBody {
if needHandleStreamingBody {
proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
}
@@ -138,7 +135,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType())
if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body)
if settingErr != nil {
@@ -186,32 +183,25 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
log.Errorf("unable to load :status header from response: %v", err)
}
ctx.DontReadResponseBody()
providerConfig.OnRequestFailed(ctx, apiTokenInUse, log)
return types.ActionContinue
return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, log)
}
// Reset ctxApiTokenRequestFailureCount if the request is successful,
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)
if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
action, err := handler.OnResponseHeaders(ctx, apiName, log)
if err == nil {
checkStream(&ctx, log)
return action
}
util.ErrorHandler("ai-proxy.proc_resp_headers_failed", fmt.Errorf("failed to process response headers: %v", err))
return types.ActionContinue
headers := util.GetOriginalResponseHeaders()
if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok {
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
handler.TransformResponseHeaders(ctx, apiName, headers, log)
} else {
providerConfig.DefaultTransformResponseHeaders(ctx, headers)
}
util.ReplaceResponseHeaders(headers)
checkStream(&ctx, log)
_, needHandleBody := activeProvider.(provider.ResponseBodyHandler)
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
if !needHandleBody && !needHandleStreamingBody {
ctx.DontReadResponseBody()
} else if !needHandleStreamingBody {
if !needHandleStreamingBody {
ctx.BufferResponseBody()
}
@@ -230,7 +220,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
log.Debugf("isLastChunk=%v chunk: %s", isLastChunk, string(chunk))
if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk, log)
if err == nil && modifiedChunk != nil {
return modifiedChunk
@@ -249,16 +239,17 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
}
log.Debugf("[onHttpResponseBody] provider=%s", activeProvider.GetProviderType())
//log.Debugf("response body: %s", string(body))
if handler, ok := activeProvider.(provider.ResponseBodyHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
action, err := handler.OnResponseBody(ctx, apiName, body, log)
if err == nil {
return action
if handler, ok := activeProvider.(provider.TransformResponseBodyHandler); ok {
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
body, err := handler.TransformResponseBody(ctx, apiName, body, log)
if err != nil {
util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
return types.ActionContinue
}
if err = provider.ReplaceResponseBody(body, log); err != nil {
util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
}
util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
return types.ActionContinue
}
return types.ActionContinue
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -139,27 +138,16 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A
return json.Marshal(claudeRequest)
}
func (c *claudeProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
claudeResponse := &claudeTextGenResponse{}
if err := json.Unmarshal(body, claudeResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal claude response: %v", err)
return nil, fmt.Errorf("unable to unmarshal claude response: %v", err)
}
if claudeResponse.Error != nil {
return types.ActionContinue, fmt.Errorf("claude response error, error_type: %s, error_message: %s", claudeResponse.Error.Type, claudeResponse.Error.Message)
return nil, fmt.Errorf("claude response error, error_type: %s, error_message: %s", claudeResponse.Error.Type, claudeResponse.Error.Message)
}
response := c.responseClaude2OpenAI(ctx, claudeResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
}
func (c *claudeProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
// use original protocol, skip OnStreamingResponseBody() and OnResponseBody()
if c.config.protocol == protocolOriginal {
ctx.DontReadResponseBody()
return types.ActionContinue, nil
}
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
return json.Marshal(response)
}
func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {

View File

@@ -151,7 +151,7 @@ func insertContext(provider Provider, content string, err error, body []byte, lo
if err != nil {
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), fmt.Errorf("failed to insert context message: %v", err))
}
if err := replaceHttpJsonRequestBody(body, log); err != nil {
if err := replaceRequestBody(body, log); err != nil {
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), fmt.Errorf("failed to replace request body: %v", err))
}
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -112,18 +111,13 @@ func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, api
return json.Marshal(baiduRequest)
}
func (d *deeplProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}
func (d *deeplProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
deeplResponse := &deeplResponse{}
if err := json.Unmarshal(body, deeplResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal deepl response: %v", err)
return nil, fmt.Errorf("unable to unmarshal deepl response: %v", err)
}
response := d.responseDeepl2OpenAI(ctx, deeplResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
return json.Marshal(response)
}
func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplResponse *deeplResponse) *chatCompletionResponse {

View File

@@ -19,7 +19,7 @@ import (
type failover struct {
// @Title zh-CN 是否启用 apiToken 的 failover 机制
enabled bool `required:"true" yaml:"enabled" json:"enabled"`
enabled bool `required:"false" yaml:"enabled" json:"enabled"`
// @Title zh-CN 触发 failover 连续请求失败的阈值
failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"`
// @Title zh-CN 健康检测的成功阈值
@@ -29,7 +29,7 @@ type failover struct {
// @Title zh-CN 健康检测的超时时间,单位毫秒
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
// @Title zh-CN 健康检测使用的模型
healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"`
healthCheckModel string `required:"false" yaml:"healthCheckModel" json:"healthCheckModel"`
// @Title zh-CN 本次请求使用的 apiToken
ctxApiTokenInUse string
// @Title zh-CN 记录 apiToken 请求失败的次数key 为 apiTokenvalue 为失败次数
@@ -184,9 +184,9 @@ func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext,
if handler, ok := activeProvider.(TransformRequestBodyHandler); ok {
body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log)
} else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok {
headers := util.GetOriginalHttpHeaders()
headers := util.GetOriginalRequestHeaders()
body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, originalHeaders, log)
util.ReplaceOriginalHttpHeaders(headers)
util.ReplaceRequestHeaders(headers)
} else {
body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log)
}
@@ -539,10 +539,15 @@ func (c *ProviderConfig) resetSharedData() {
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0)
}
func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) {
func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) types.Action {
if c.isFailoverEnabled() {
c.handleUnavailableApiToken(ctx, apiTokenInUse, log)
}
if c.isRetryOnFailureEnabled() && ctx.GetContext(ctxKeyIsStreaming) != nil && !ctx.GetContext(ctxKeyIsStreaming).(bool) {
c.retryFailedRequest(activeProvider, ctx, log)
return types.HeaderStopAllIterationAndWatermark
}
return types.ActionContinue
}
func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
@@ -557,7 +562,7 @@ func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.L
} else {
apiToken = c.GetRandomToken()
}
log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiToken)
log.Debugf("Use apiToken %s to send request", apiToken)
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
}

View File

@@ -105,16 +105,6 @@ func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
return json.Marshal(geminiRequest)
}
func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if g.config.protocol == protocolOriginal {
ctx.DontReadResponseBody()
return types.ActionContinue, nil
}
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
log.Infof("chunk body:%s", string(chunk))
if isLastChunk || len(chunk) == 0 {
@@ -148,39 +138,38 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
return []byte(modifiedResponseChunk), nil
}
func (g *geminiProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName == ApiNameChatCompletion {
return g.onChatCompletionResponseBody(ctx, body, log)
} else if apiName == ApiNameEmbeddings {
} else {
return g.onEmbeddingsResponseBody(ctx, body, log)
}
return types.ActionContinue, errUnsupportedApiName
}
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
geminiResponse := &geminiChatResponse{}
if err := json.Unmarshal(body, geminiResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
return nil, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
}
if geminiResponse.Error != nil {
return types.ActionContinue, fmt.Errorf("gemini chat completion response error, error_code: %d, error_status:%s, error_message: %s",
return nil, fmt.Errorf("gemini chat completion response error, error_code: %d, error_status:%s, error_message: %s",
geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message)
}
response := g.buildChatCompletionResponse(ctx, geminiResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
return json.Marshal(response)
}
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
geminiResponse := &geminiEmbeddingResponse{}
if err := json.Unmarshal(body, geminiResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err)
return nil, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err)
}
if geminiResponse.Error != nil {
return types.ActionContinue, fmt.Errorf("gemini embeddings response error, error_code: %d, error_status:%s, error_message: %s",
return nil, fmt.Errorf("gemini embeddings response error, error_code: %d, error_status:%s, error_message: %s",
geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message)
}
response := g.buildEmbeddingsResponse(ctx, geminiResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
return json.Marshal(response)
}
func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string {

View File

@@ -288,11 +288,6 @@ func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a
return json.Marshal(hunyuanRequest)
}
func (m *hunyuanProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}
func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if m.config.protocol == protocolOriginal {
return chunk, nil
@@ -409,21 +404,14 @@ func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContex
return []byte(openAIChunk.String()), nil
}
func (m *hunyuanProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
log.Debugf("#debug nash5# onRespBody's resp is: %s", string(body))
hunyuanResponse := &hunyuanTextGenResponseNonStreaming{}
if err := json.Unmarshal(body, hunyuanResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal hunyuan response: %v", err)
return nil, fmt.Errorf("unable to unmarshal hunyuan response: %v", err)
}
if m.config.protocol == protocolOriginal {
return types.ActionContinue, replaceJsonResponseBody(hunyuanResponse, log)
}
response := m.buildChatCompletionResponse(ctx, hunyuanResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
return json.Marshal(response)
}
func (m *hunyuanProvider) insertContextMessageIntoHunyuanRequest(request *hunyuanTextGenRequest, content string) {

View File

@@ -144,19 +144,16 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, heade
return sjson.SetBytes(body, "model", mappedModel)
}
func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
// Skip OnStreamingResponseBody() and OnResponseBody() when using original protocol.
// Skip OnStreamingResponseBody() and OnResponseBody() when using original protocol.
func (m *minimaxProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
if m.config.protocol == protocolOriginal {
ctx.DontReadResponseBody()
return types.ActionContinue, nil
}
// Skip OnStreamingResponseBody() and OnResponseBody() when the model corresponds to the chat completion V2 interface.
if minimaxApiTypePro != m.config.minimaxApiType {
ctx.DontReadResponseBody()
return types.ActionContinue, nil
}
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}
// OnStreamingResponseBody handles streaming response chunks from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
@@ -196,16 +193,16 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
}
// OnResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
func (m *minimaxProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *minimaxProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
minimaxResp := &minimaxChatCompletionV2Resp{}
if err := json.Unmarshal(body, minimaxResp); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal minimax response: %v", err)
return nil, fmt.Errorf("unable to unmarshal minimax response: %v", err)
}
if minimaxResp.BaseResp.StatusCode != 0 {
return types.ActionContinue, fmt.Errorf("minimax response error, error_code: %d, error_message: %s", minimaxResp.BaseResp.StatusCode, minimaxResp.BaseResp.StatusMsg)
return nil, fmt.Errorf("minimax response error, error_code: %d, error_message: %s", minimaxResp.BaseResp.StatusCode, minimaxResp.BaseResp.StatusMsg)
}
response := m.responseV2ToOpenAI(minimaxResp)
return types.ActionContinue, replaceJsonResponseBody(response, log)
return json.Marshal(response)
}
// minimaxChatCompletionV2Request represents the structure of a chat completion V2 request.

View File

@@ -59,7 +59,9 @@ const (
finishReasonLength = "length"
ctxKeyIncrementalStreaming = "incrementalStreaming"
ctxKeyApiName = "apiKey"
ctxKeyApiKey = "apiKey"
CtxKeyApiName = "apiName"
ctxKeyIsStreaming = "isStreaming"
ctxKeyStreamingBody = "streamingBody"
ctxKeyOriginalRequestModel = "originalRequestModel"
ctxKeyFinalRequestModel = "finalRequestModel"
@@ -115,22 +117,26 @@ type Provider interface {
GetProviderType() string
}
type ApiNameHandler interface {
GetApiName(path string) ApiName
}
type RequestHeadersHandler interface {
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error
}
type TransformRequestHeadersHandler interface {
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
}
type RequestBodyHandler interface {
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
}
type StreamingResponseBodyHandler interface {
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
}
type ApiNameHandler interface {
GetApiName(path string) ApiName
}
type TransformRequestHeadersHandler interface {
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
}
type TransformRequestBodyHandler interface {
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
}
@@ -141,16 +147,12 @@ type TransformRequestBodyHeadersHandler interface {
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
}
type ResponseHeadersHandler interface {
OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
type TransformResponseHeadersHandler interface {
TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
}
type StreamingResponseBodyHandler interface {
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
}
type ResponseBodyHandler interface {
OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
type TransformResponseBodyHandler interface {
TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
}
// TickFuncHandler allows the provider to execute a function periodically
@@ -175,6 +177,9 @@ type ProviderConfig struct {
// @Title zh-CN apiToken 故障切换
// @Description zh-CN 当 apiToken 不可用时移出 apiTokens 列表,对移除的 apiToken 进行健康检查,当重新可用后加回 apiTokens 列表
failover *failover `required:"false" yaml:"failover" json:"failover"`
// @Title zh-CN 失败请求重试
// @Description zh-CN 对失败的请求立即进行重试
retryOnFailure *retryOnFailure `required:"false" yaml:"retryOnFailure" json:"retryOnFailure"`
// @Title zh-CN 基于OpenAI协议的自定义后端URL
// @Description zh-CN 仅适用于支持 openai 协议的服务。
openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"`
@@ -352,6 +357,14 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.failover.FromJson(failoverJson)
}
retryOnFailureJson := json.Get("retryOnFailure")
c.retryOnFailure = &retryOnFailure{
enabled: false,
}
if retryOnFailureJson.Exists() {
c.retryOnFailure.FromJson(retryOnFailureJson)
}
for _, accessKeyAndSecret := range json.Get("baiduAccessKeyAndSecret").Array() {
c.baiduAccessKeyAndSecret = append(c.baiduAccessKeyAndSecret, accessKeyAndSecret.String())
}
@@ -399,10 +412,10 @@ func (c *ProviderConfig) Validate() error {
}
func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) string {
ctxApiKey := ctx.GetContext(ctxKeyApiName)
ctxApiKey := ctx.GetContext(ctxKeyApiKey)
if ctxApiKey == nil {
ctxApiKey = c.GetRandomToken()
ctx.SetContext(ctxKeyApiName, ctxApiKey)
ctx.SetContext(ctxKeyApiKey, ctxApiKey)
}
return ctxApiKey.(string)
}
@@ -446,6 +459,9 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
streaming := req.Stream
if streaming {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
ctx.SetContext(ctxKeyIsStreaming, true)
} else {
ctx.SetContext(ctxKeyIsStreaming, false)
}
return c.setRequestModel(ctx, req, log)
@@ -540,9 +556,9 @@ func (c *ProviderConfig) handleRequestBody(
if handler, ok := provider.(TransformRequestBodyHandler); ok {
body, err = handler.TransformRequestBody(ctx, apiName, body, log)
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
headers := util.GetOriginalHttpHeaders()
headers := util.GetOriginalRequestHeaders()
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log)
util.ReplaceOriginalHttpHeaders(headers)
util.ReplaceRequestHeaders(headers)
} else {
body, err = c.defaultTransformRequestBody(ctx, apiName, body, log)
}
@@ -551,9 +567,14 @@ func (c *ProviderConfig) handleRequestBody(
return types.ActionContinue, err
}
// If retryOnFailure is enabled, save the transformed body to the context in case of retry
if c.isRetryOnFailureEnabled() {
ctx.SetContext(ctxRequestBody, body)
}
if apiName == ApiNameChatCompletion {
if c.context == nil {
return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
return types.ActionContinue, replaceRequestBody(body, log)
}
err = contextCache.GetContextFromFile(ctx, provider, body, log)
@@ -562,14 +583,14 @@ func (c *ProviderConfig) handleRequestBody(
}
return types.ActionContinue, err
}
return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
return types.ActionContinue, replaceRequestBody(body, log)
}
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) {
headers := util.GetOriginalRequestHeaders()
if handler, ok := provider.(TransformRequestHeadersHandler); ok {
originalHeaders := util.GetOriginalHttpHeaders()
handler.TransformRequestHeaders(ctx, apiName, originalHeaders, log)
util.ReplaceOriginalHttpHeaders(originalHeaders)
handler.TransformRequestHeaders(ctx, apiName, headers, log)
util.ReplaceRequestHeaders(headers)
}
}
@@ -585,3 +606,11 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap
}
return json.Marshal(request)
}
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {
if c.protocol == protocolOriginal {
ctx.DontReadResponseBody()
} else {
headers.Del("Content-Length")
}
}

View File

@@ -183,16 +183,6 @@ func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []b
return json.Marshal(qwenRequest)
}
func (m *qwenProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if m.config.protocol == protocolOriginal {
ctx.DontReadResponseBody()
return types.ActionContinue, nil
}
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}
func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if m.config.qwenEnableCompatible || name != ApiNameChatCompletion {
return chunk, nil
@@ -278,9 +268,9 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
return []byte(modifiedResponseChunk), nil
}
func (m *qwenProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if m.config.qwenEnableCompatible {
return types.ActionContinue, nil
return body, nil
}
if apiName == ApiNameChatCompletion {
return m.onChatCompletionResponseBody(ctx, body, log)
@@ -288,25 +278,25 @@ func (m *qwenProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName == ApiNameEmbeddings {
return m.onEmbeddingsResponseBody(ctx, body, log)
}
return types.ActionContinue, errUnsupportedApiName
return nil, errUnsupportedApiName
}
func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
qwenResponse := &qwenTextGenResponse{}
if err := json.Unmarshal(body, qwenResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
}
response := m.buildChatCompletionResponse(ctx, qwenResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
return json.Marshal(response)
}
func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
qwenResponse := &qwenTextEmbeddingResponse{}
if err := json.Unmarshal(body, qwenResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
}
response := m.buildEmbeddingsResponse(ctx, qwenResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
return json.Marshal(response)
}
func (m *qwenProvider) buildQwenTextGenerationRequest(ctx wrapper.HttpContext, origRequest *chatCompletionRequest, streaming bool) ([]byte, error) {

View File

@@ -37,7 +37,7 @@ func replaceJsonRequestBody(request interface{}, log wrapper.Log) error {
return err
}
func replaceHttpJsonRequestBody(body []byte, log wrapper.Log) error {
func replaceRequestBody(body []byte, log wrapper.Log) error {
log.Debugf("request body: %s", string(body))
err := proxywasm.ReplaceHttpRequestBody(body)
if err != nil {
@@ -65,15 +65,11 @@ func insertContextMessage(request *chatCompletionRequest, content string) {
}
}
func replaceJsonResponseBody(response interface{}, log wrapper.Log) error {
body, err := json.Marshal(response)
if err != nil {
return fmt.Errorf("unable to marshal response: %v", err)
}
func ReplaceResponseBody(body []byte, log wrapper.Log) error {
log.Debugf("response body: %s", string(body))
err = proxywasm.ReplaceHttpResponseBody(body)
err := proxywasm.ReplaceHttpResponseBody(body)
if err != nil {
return fmt.Errorf("unable to replace the original response body: %v", err)
}
return err
return nil
}

View File

@@ -0,0 +1,141 @@
package provider
import (
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/tidwall/gjson"
"net/http"
)
const (
ctxRequestBody = "requestBody"
ctxRetryCount = "retryCount"
)
type retryOnFailure struct {
// @Title zh-CN 是否启用请求重试
enabled bool `required:"false" yaml:"enabled" json:"enabled"`
// @Title zh-CN 重试次数
maxRetries int64 `required:"false" yaml:"maxRetries" json:"maxRetries"`
// @Title zh-CN 重试超时时间
retryTimeout int64 `required:"false" yaml:"retryTimeout" json:"retryTimeout"`
}
func (r *retryOnFailure) FromJson(json gjson.Result) {
r.enabled = json.Get("enabled").Bool()
r.maxRetries = json.Get("maxRetries").Int()
if r.maxRetries == 0 {
r.maxRetries = 1
}
r.retryTimeout = json.Get("retryTimeout").Int()
if r.retryTimeout == 0 {
r.retryTimeout = 5000
}
}
func (c *ProviderConfig) isRetryOnFailureEnabled() bool {
return c.retryOnFailure.enabled
}
func (c *ProviderConfig) retryFailedRequest(activeProvider Provider, ctx wrapper.HttpContext, log wrapper.Log) {
log.Debugf("Retry failed request: provider=%s", activeProvider.GetProviderType())
retryClient := createRetryClient(ctx)
apiName, _ := ctx.GetContext(CtxKeyApiName).(ApiName)
ctx.SetContext(ctxRetryCount, 1)
c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, log)
}
func (c *ProviderConfig) transformResponseHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, apiName ApiName, headers http.Header, body []byte, log wrapper.Log) ([][2]string, []byte) {
if handler, ok := activeProvider.(TransformResponseHeadersHandler); ok {
handler.TransformResponseHeaders(ctx, apiName, headers, log)
} else {
c.DefaultTransformResponseHeaders(ctx, headers)
}
if handler, ok := activeProvider.(TransformResponseBodyHandler); ok {
var err error
body, err = handler.TransformResponseBody(ctx, apiName, body, log)
if err != nil {
log.Errorf("Failed to transform response body: %v", err)
}
}
return util.HeaderToSlice(headers), body
}
func (c *ProviderConfig) retryCall(
ctx wrapper.HttpContext, log wrapper.Log, activeProvider Provider,
apiName ApiName, statusCode int, responseHeaders http.Header, responseBody []byte,
retryClient *wrapper.ClusterClient[wrapper.RouteCluster]) {
retryCount := ctx.GetContext(ctxRetryCount).(int)
log.Debugf("Sent retry request: %d/%d", retryCount, c.retryOnFailure.maxRetries)
if statusCode == 200 {
log.Debugf("Retry request succeeded")
headers, body := c.transformResponseHeadersAndBody(ctx, activeProvider, apiName, responseHeaders, responseBody, log)
proxywasm.SendHttpResponse(200, headers, body, -1)
} else {
log.Debugf("The retry request still failed, status: %d, responseHeaders: %v, responseBody: %s", statusCode, responseHeaders, string(responseBody))
}
retryCount++
if retryCount <= int(c.retryOnFailure.maxRetries) {
ctx.SetContext(ctxRetryCount, retryCount)
c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, log)
} else {
log.Debugf("Reached the maximum retry count: %d", c.retryOnFailure.maxRetries)
proxywasm.ResumeHttpResponse()
}
}
func (c *ProviderConfig) sendRetryRequest(
ctx wrapper.HttpContext, apiName ApiName, activeProvider Provider,
retryClient *wrapper.ClusterClient[wrapper.RouteCluster], log wrapper.Log) {
requestHeaders, requestBody := c.getRetryRequestHeadersAndBody(ctx, activeProvider, apiName, log)
path := getRetryPath(ctx)
err := retryClient.Post(path, util.HeaderToSlice(requestHeaders), requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
c.retryCall(ctx, log, activeProvider, apiName, statusCode, responseHeaders, responseBody, retryClient)
}, uint32(c.retryOnFailure.retryTimeout))
if err != nil {
log.Errorf("Failed to send retry request: %v", err)
proxywasm.ResumeHttpResponse()
}
}
func createRetryClient(ctx wrapper.HttpContext) *wrapper.ClusterClient[wrapper.RouteCluster] {
host := wrapper.GetRequestHost()
if host == "" {
host = ctx.GetContext(ctxRequestHost).(string)
}
retryClient := wrapper.NewClusterClient(wrapper.RouteCluster{
Host: host,
})
return retryClient
}
func getRetryPath(ctx wrapper.HttpContext) string {
path := wrapper.GetRequestPath()
if path == "" {
path = ctx.GetContext(ctxRequestPath).(string)
}
return path
}
func (c *ProviderConfig) getRetryRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, apiName ApiName, log wrapper.Log) (http.Header, []byte) {
// The retry request may be sent with different apiToken, so the header needs to be regenerated
c.SetApiTokenInUse(ctx, log)
requestHeaders := http.Header{
"Content-Type": []string{"application/json"},
}
if handler, ok := activeProvider.(TransformRequestHeadersHandler); ok {
handler.TransformRequestHeaders(ctx, apiName, requestHeaders, log)
}
requestBody := ctx.GetContext(ctxRequestBody).([]byte)
return requestHeaders, requestBody
}

View File

@@ -9,7 +9,6 @@ import (
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -82,21 +81,16 @@ func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log)
}
func (p *sparkProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}
func (p *sparkProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
sparkResponse := &sparkResponse{}
if err := json.Unmarshal(body, sparkResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal spark response: %v", err)
return nil, fmt.Errorf("unable to unmarshal spark response: %v", err)
}
if sparkResponse.Code != 0 {
return types.ActionContinue, fmt.Errorf("spark response error, error_code: %d, error_message: %s", sparkResponse.Code, sparkResponse.Message)
return nil, fmt.Errorf("spark response error, error_code: %d, error_message: %s", sparkResponse.Code, sparkResponse.Message)
}
response := p.responseSpark2OpenAI(ctx, sparkResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
return json.Marshal(response)
}
func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {

View File

@@ -86,12 +86,22 @@ func SliceToHeader(slice [][2]string) http.Header {
return header
}
func GetOriginalHttpHeaders() http.Header {
func GetOriginalRequestHeaders() http.Header {
originalHeaders, _ := proxywasm.GetHttpRequestHeaders()
return SliceToHeader(originalHeaders)
}
func ReplaceOriginalHttpHeaders(headers http.Header) {
func GetOriginalResponseHeaders() http.Header {
originalHeaders, _ := proxywasm.GetHttpResponseHeaders()
return SliceToHeader(originalHeaders)
}
func ReplaceRequestHeaders(headers http.Header) {
modifiedHeaders := HeaderToSlice(headers)
_ = proxywasm.ReplaceHttpRequestHeaders(modifiedHeaders)
}
func ReplaceResponseHeaders(headers http.Header) {
modifiedHeaders := HeaderToSlice(headers)
_ = proxywasm.ReplaceHttpResponseHeaders(modifiedHeaders)
}