mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 12:47:28 +08:00
feat: retry failed request (#1590)
This commit is contained in:
@@ -41,6 +41,7 @@ description: AI 代理插件配置参考
|
|||||||
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
|
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
|
||||||
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
|
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
|
||||||
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
|
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
|
||||||
|
| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 |
|
||||||
|
|
||||||
`context`的配置字段说明如下:
|
`context`的配置字段说明如下:
|
||||||
|
|
||||||
@@ -79,13 +80,21 @@ custom-setting会遵循如下表格,根据`name`和协议来替换对应的字
|
|||||||
`failover` 的配置字段说明如下:
|
`failover` 的配置字段说明如下:
|
||||||
|
|
||||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||||
|------------------|--------|------|-------|-----------------------------|
|
|------------------|--------|-----------------|-------|-----------------------------|
|
||||||
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
|
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
|
||||||
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
|
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
|
||||||
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
|
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
|
||||||
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
|
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
|
||||||
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
|
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
|
||||||
| healthCheckModel | string | 必填 | | 健康检测使用的模型 |
|
| healthCheckModel | string | 启用 failover 时必填 | | 健康检测使用的模型 |
|
||||||
|
|
||||||
|
`retryOnFailure` 的配置字段说明如下:
|
||||||
|
|
||||||
|
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||||
|
|------------------|--------|-----------------|-------|-------------|
|
||||||
|
| enabled | bool | 非必填 | false | 是否启用失败请求重试 |
|
||||||
|
| maxRetries | int | 非必填 | 1 | 最大重试次数 |
|
||||||
|
| retryTimeout | int | 非必填 | 5000 | 重试超时时间,单位毫秒 |
|
||||||
|
|
||||||
### 提供商特有配置
|
### 提供商特有配置
|
||||||
|
|
||||||
|
|||||||
@@ -20,8 +20,6 @@ import (
|
|||||||
const (
|
const (
|
||||||
pluginName = "ai-proxy"
|
pluginName = "ai-proxy"
|
||||||
|
|
||||||
ctxKeyApiName = "apiName"
|
|
||||||
|
|
||||||
defaultMaxBodyBytes uint32 = 10 * 1024 * 1024
|
defaultMaxBodyBytes uint32 = 10 * 1024 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -92,14 +90,13 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
|||||||
log.Warnf("[onHttpRequestHeader] unsupported path: %s", path.Path)
|
log.Warnf("[onHttpRequestHeader] unsupported path: %s", path.Path)
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx.SetContext(provider.CtxKeyApiName, apiName)
|
||||||
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
|
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
|
||||||
ctx.DisableReroute()
|
ctx.DisableReroute()
|
||||||
|
|
||||||
ctx.SetContext(ctxKeyApiName, apiName)
|
|
||||||
|
|
||||||
_, needHandleBody := activeProvider.(provider.ResponseBodyHandler)
|
|
||||||
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
|
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
|
||||||
if needHandleBody || needHandleStreamingBody {
|
if needHandleStreamingBody {
|
||||||
proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,7 +135,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
|||||||
log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType())
|
log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType())
|
||||||
|
|
||||||
if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
|
if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
|
||||||
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||||
|
|
||||||
newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body)
|
newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body)
|
||||||
if settingErr != nil {
|
if settingErr != nil {
|
||||||
@@ -186,32 +183,25 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
|
|||||||
log.Errorf("unable to load :status header from response: %v", err)
|
log.Errorf("unable to load :status header from response: %v", err)
|
||||||
}
|
}
|
||||||
ctx.DontReadResponseBody()
|
ctx.DontReadResponseBody()
|
||||||
providerConfig.OnRequestFailed(ctx, apiTokenInUse, log)
|
return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, log)
|
||||||
|
|
||||||
return types.ActionContinue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset ctxApiTokenRequestFailureCount if the request is successful,
|
// Reset ctxApiTokenRequestFailureCount if the request is successful,
|
||||||
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
|
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
|
||||||
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)
|
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)
|
||||||
|
|
||||||
if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok {
|
headers := util.GetOriginalResponseHeaders()
|
||||||
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
|
if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok {
|
||||||
action, err := handler.OnResponseHeaders(ctx, apiName, log)
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||||
if err == nil {
|
handler.TransformResponseHeaders(ctx, apiName, headers, log)
|
||||||
checkStream(&ctx, log)
|
} else {
|
||||||
return action
|
providerConfig.DefaultTransformResponseHeaders(ctx, headers)
|
||||||
}
|
|
||||||
util.ErrorHandler("ai-proxy.proc_resp_headers_failed", fmt.Errorf("failed to process response headers: %v", err))
|
|
||||||
return types.ActionContinue
|
|
||||||
}
|
}
|
||||||
|
util.ReplaceResponseHeaders(headers)
|
||||||
|
|
||||||
checkStream(&ctx, log)
|
checkStream(&ctx, log)
|
||||||
_, needHandleBody := activeProvider.(provider.ResponseBodyHandler)
|
|
||||||
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
|
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
|
||||||
if !needHandleBody && !needHandleStreamingBody {
|
if !needHandleStreamingBody {
|
||||||
ctx.DontReadResponseBody()
|
|
||||||
} else if !needHandleStreamingBody {
|
|
||||||
ctx.BufferResponseBody()
|
ctx.BufferResponseBody()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,7 +220,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
|||||||
log.Debugf("isLastChunk=%v chunk: %s", isLastChunk, string(chunk))
|
log.Debugf("isLastChunk=%v chunk: %s", isLastChunk, string(chunk))
|
||||||
|
|
||||||
if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok {
|
if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok {
|
||||||
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||||
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk, log)
|
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk, log)
|
||||||
if err == nil && modifiedChunk != nil {
|
if err == nil && modifiedChunk != nil {
|
||||||
return modifiedChunk
|
return modifiedChunk
|
||||||
@@ -249,17 +239,18 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("[onHttpResponseBody] provider=%s", activeProvider.GetProviderType())
|
log.Debugf("[onHttpResponseBody] provider=%s", activeProvider.GetProviderType())
|
||||||
//log.Debugf("response body: %s", string(body))
|
|
||||||
|
|
||||||
if handler, ok := activeProvider.(provider.ResponseBodyHandler); ok {
|
if handler, ok := activeProvider.(provider.TransformResponseBodyHandler); ok {
|
||||||
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||||
action, err := handler.OnResponseBody(ctx, apiName, body, log)
|
body, err := handler.TransformResponseBody(ctx, apiName, body, log)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
return action
|
|
||||||
}
|
|
||||||
util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
|
util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
|
if err = provider.ReplaceResponseBody(body, log); err != nil {
|
||||||
|
util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
|
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -139,27 +138,16 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A
|
|||||||
return json.Marshal(claudeRequest)
|
return json.Marshal(claudeRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *claudeProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||||
claudeResponse := &claudeTextGenResponse{}
|
claudeResponse := &claudeTextGenResponse{}
|
||||||
if err := json.Unmarshal(body, claudeResponse); err != nil {
|
if err := json.Unmarshal(body, claudeResponse); err != nil {
|
||||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal claude response: %v", err)
|
return nil, fmt.Errorf("unable to unmarshal claude response: %v", err)
|
||||||
}
|
}
|
||||||
if claudeResponse.Error != nil {
|
if claudeResponse.Error != nil {
|
||||||
return types.ActionContinue, fmt.Errorf("claude response error, error_type: %s, error_message: %s", claudeResponse.Error.Type, claudeResponse.Error.Message)
|
return nil, fmt.Errorf("claude response error, error_type: %s, error_message: %s", claudeResponse.Error.Type, claudeResponse.Error.Message)
|
||||||
}
|
}
|
||||||
response := c.responseClaude2OpenAI(ctx, claudeResponse)
|
response := c.responseClaude2OpenAI(ctx, claudeResponse)
|
||||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
return json.Marshal(response)
|
||||||
}
|
|
||||||
|
|
||||||
func (c *claudeProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
|
||||||
// use original protocol, skip OnStreamingResponseBody() and OnResponseBody()
|
|
||||||
if c.config.protocol == protocolOriginal {
|
|
||||||
ctx.DontReadResponseBody()
|
|
||||||
return types.ActionContinue, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
|
||||||
return types.ActionContinue, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ func insertContext(provider Provider, content string, err error, body []byte, lo
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), fmt.Errorf("failed to insert context message: %v", err))
|
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), fmt.Errorf("failed to insert context message: %v", err))
|
||||||
}
|
}
|
||||||
if err := replaceHttpJsonRequestBody(body, log); err != nil {
|
if err := replaceRequestBody(body, log); err != nil {
|
||||||
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), fmt.Errorf("failed to replace request body: %v", err))
|
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), fmt.Errorf("failed to replace request body: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
|
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -112,18 +111,13 @@ func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, api
|
|||||||
return json.Marshal(baiduRequest)
|
return json.Marshal(baiduRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *deeplProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
|
||||||
return types.ActionContinue, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *deeplProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
|
||||||
deeplResponse := &deeplResponse{}
|
deeplResponse := &deeplResponse{}
|
||||||
if err := json.Unmarshal(body, deeplResponse); err != nil {
|
if err := json.Unmarshal(body, deeplResponse); err != nil {
|
||||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal deepl response: %v", err)
|
return nil, fmt.Errorf("unable to unmarshal deepl response: %v", err)
|
||||||
}
|
}
|
||||||
response := d.responseDeepl2OpenAI(ctx, deeplResponse)
|
response := d.responseDeepl2OpenAI(ctx, deeplResponse)
|
||||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
return json.Marshal(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplResponse *deeplResponse) *chatCompletionResponse {
|
func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplResponse *deeplResponse) *chatCompletionResponse {
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
|
|
||||||
type failover struct {
|
type failover struct {
|
||||||
// @Title zh-CN 是否启用 apiToken 的 failover 机制
|
// @Title zh-CN 是否启用 apiToken 的 failover 机制
|
||||||
enabled bool `required:"true" yaml:"enabled" json:"enabled"`
|
enabled bool `required:"false" yaml:"enabled" json:"enabled"`
|
||||||
// @Title zh-CN 触发 failover 连续请求失败的阈值
|
// @Title zh-CN 触发 failover 连续请求失败的阈值
|
||||||
failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"`
|
failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"`
|
||||||
// @Title zh-CN 健康检测的成功阈值
|
// @Title zh-CN 健康检测的成功阈值
|
||||||
@@ -29,7 +29,7 @@ type failover struct {
|
|||||||
// @Title zh-CN 健康检测的超时时间,单位毫秒
|
// @Title zh-CN 健康检测的超时时间,单位毫秒
|
||||||
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
|
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
|
||||||
// @Title zh-CN 健康检测使用的模型
|
// @Title zh-CN 健康检测使用的模型
|
||||||
healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"`
|
healthCheckModel string `required:"false" yaml:"healthCheckModel" json:"healthCheckModel"`
|
||||||
// @Title zh-CN 本次请求使用的 apiToken
|
// @Title zh-CN 本次请求使用的 apiToken
|
||||||
ctxApiTokenInUse string
|
ctxApiTokenInUse string
|
||||||
// @Title zh-CN 记录 apiToken 请求失败的次数,key 为 apiToken,value 为失败次数
|
// @Title zh-CN 记录 apiToken 请求失败的次数,key 为 apiToken,value 为失败次数
|
||||||
@@ -184,9 +184,9 @@ func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext,
|
|||||||
if handler, ok := activeProvider.(TransformRequestBodyHandler); ok {
|
if handler, ok := activeProvider.(TransformRequestBodyHandler); ok {
|
||||||
body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log)
|
body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log)
|
||||||
} else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok {
|
} else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok {
|
||||||
headers := util.GetOriginalHttpHeaders()
|
headers := util.GetOriginalRequestHeaders()
|
||||||
body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, originalHeaders, log)
|
body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, originalHeaders, log)
|
||||||
util.ReplaceOriginalHttpHeaders(headers)
|
util.ReplaceRequestHeaders(headers)
|
||||||
} else {
|
} else {
|
||||||
body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log)
|
body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log)
|
||||||
}
|
}
|
||||||
@@ -539,10 +539,15 @@ func (c *ProviderConfig) resetSharedData() {
|
|||||||
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0)
|
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) {
|
func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) types.Action {
|
||||||
if c.isFailoverEnabled() {
|
if c.isFailoverEnabled() {
|
||||||
c.handleUnavailableApiToken(ctx, apiTokenInUse, log)
|
c.handleUnavailableApiToken(ctx, apiTokenInUse, log)
|
||||||
}
|
}
|
||||||
|
if c.isRetryOnFailureEnabled() && ctx.GetContext(ctxKeyIsStreaming) != nil && !ctx.GetContext(ctxKeyIsStreaming).(bool) {
|
||||||
|
c.retryFailedRequest(activeProvider, ctx, log)
|
||||||
|
return types.HeaderStopAllIterationAndWatermark
|
||||||
|
}
|
||||||
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
|
func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
|
||||||
@@ -557,7 +562,7 @@ func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.L
|
|||||||
} else {
|
} else {
|
||||||
apiToken = c.GetRandomToken()
|
apiToken = c.GetRandomToken()
|
||||||
}
|
}
|
||||||
log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiToken)
|
log.Debugf("Use apiToken %s to send request", apiToken)
|
||||||
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
|
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -105,16 +105,6 @@ func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
|
|||||||
return json.Marshal(geminiRequest)
|
return json.Marshal(geminiRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
|
||||||
if g.config.protocol == protocolOriginal {
|
|
||||||
ctx.DontReadResponseBody()
|
|
||||||
return types.ActionContinue, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
|
||||||
return types.ActionContinue, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||||
log.Infof("chunk body:%s", string(chunk))
|
log.Infof("chunk body:%s", string(chunk))
|
||||||
if isLastChunk || len(chunk) == 0 {
|
if isLastChunk || len(chunk) == 0 {
|
||||||
@@ -148,39 +138,38 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
|
|||||||
return []byte(modifiedResponseChunk), nil
|
return []byte(modifiedResponseChunk), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *geminiProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||||
if apiName == ApiNameChatCompletion {
|
if apiName == ApiNameChatCompletion {
|
||||||
return g.onChatCompletionResponseBody(ctx, body, log)
|
return g.onChatCompletionResponseBody(ctx, body, log)
|
||||||
} else if apiName == ApiNameEmbeddings {
|
} else {
|
||||||
return g.onEmbeddingsResponseBody(ctx, body, log)
|
return g.onEmbeddingsResponseBody(ctx, body, log)
|
||||||
}
|
}
|
||||||
return types.ActionContinue, errUnsupportedApiName
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
|
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
|
||||||
geminiResponse := &geminiChatResponse{}
|
geminiResponse := &geminiChatResponse{}
|
||||||
if err := json.Unmarshal(body, geminiResponse); err != nil {
|
if err := json.Unmarshal(body, geminiResponse); err != nil {
|
||||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
|
return nil, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
|
||||||
}
|
}
|
||||||
if geminiResponse.Error != nil {
|
if geminiResponse.Error != nil {
|
||||||
return types.ActionContinue, fmt.Errorf("gemini chat completion response error, error_code: %d, error_status:%s, error_message: %s",
|
return nil, fmt.Errorf("gemini chat completion response error, error_code: %d, error_status:%s, error_message: %s",
|
||||||
geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message)
|
geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message)
|
||||||
}
|
}
|
||||||
response := g.buildChatCompletionResponse(ctx, geminiResponse)
|
response := g.buildChatCompletionResponse(ctx, geminiResponse)
|
||||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
return json.Marshal(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
|
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
|
||||||
geminiResponse := &geminiEmbeddingResponse{}
|
geminiResponse := &geminiEmbeddingResponse{}
|
||||||
if err := json.Unmarshal(body, geminiResponse); err != nil {
|
if err := json.Unmarshal(body, geminiResponse); err != nil {
|
||||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err)
|
return nil, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err)
|
||||||
}
|
}
|
||||||
if geminiResponse.Error != nil {
|
if geminiResponse.Error != nil {
|
||||||
return types.ActionContinue, fmt.Errorf("gemini embeddings response error, error_code: %d, error_status:%s, error_message: %s",
|
return nil, fmt.Errorf("gemini embeddings response error, error_code: %d, error_status:%s, error_message: %s",
|
||||||
geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message)
|
geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message)
|
||||||
}
|
}
|
||||||
response := g.buildEmbeddingsResponse(ctx, geminiResponse)
|
response := g.buildEmbeddingsResponse(ctx, geminiResponse)
|
||||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
return json.Marshal(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string {
|
func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string {
|
||||||
|
|||||||
@@ -288,11 +288,6 @@ func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a
|
|||||||
return json.Marshal(hunyuanRequest)
|
return json.Marshal(hunyuanRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *hunyuanProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
|
||||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
|
||||||
return types.ActionContinue, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||||
if m.config.protocol == protocolOriginal {
|
if m.config.protocol == protocolOriginal {
|
||||||
return chunk, nil
|
return chunk, nil
|
||||||
@@ -409,21 +404,14 @@ func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContex
|
|||||||
return []byte(openAIChunk.String()), nil
|
return []byte(openAIChunk.String()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *hunyuanProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||||
|
|
||||||
log.Debugf("#debug nash5# onRespBody's resp is: %s", string(body))
|
log.Debugf("#debug nash5# onRespBody's resp is: %s", string(body))
|
||||||
hunyuanResponse := &hunyuanTextGenResponseNonStreaming{}
|
hunyuanResponse := &hunyuanTextGenResponseNonStreaming{}
|
||||||
if err := json.Unmarshal(body, hunyuanResponse); err != nil {
|
if err := json.Unmarshal(body, hunyuanResponse); err != nil {
|
||||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal hunyuan response: %v", err)
|
return nil, fmt.Errorf("unable to unmarshal hunyuan response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.config.protocol == protocolOriginal {
|
|
||||||
return types.ActionContinue, replaceJsonResponseBody(hunyuanResponse, log)
|
|
||||||
}
|
|
||||||
|
|
||||||
response := m.buildChatCompletionResponse(ctx, hunyuanResponse)
|
response := m.buildChatCompletionResponse(ctx, hunyuanResponse)
|
||||||
|
return json.Marshal(response)
|
||||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *hunyuanProvider) insertContextMessageIntoHunyuanRequest(request *hunyuanTextGenRequest, content string) {
|
func (m *hunyuanProvider) insertContextMessageIntoHunyuanRequest(request *hunyuanTextGenRequest, content string) {
|
||||||
|
|||||||
@@ -144,19 +144,16 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, heade
|
|||||||
return sjson.SetBytes(body, "model", mappedModel)
|
return sjson.SetBytes(body, "model", mappedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
// Skip OnStreamingResponseBody() and OnResponseBody() when using original protocol.
|
||||||
// Skip OnStreamingResponseBody() and OnResponseBody() when using original protocol.
|
func (m *minimaxProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||||
if m.config.protocol == protocolOriginal {
|
if m.config.protocol == protocolOriginal {
|
||||||
ctx.DontReadResponseBody()
|
ctx.DontReadResponseBody()
|
||||||
return types.ActionContinue, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip OnStreamingResponseBody() and OnResponseBody() when the model corresponds to the chat completion V2 interface.
|
// Skip OnStreamingResponseBody() and OnResponseBody() when the model corresponds to the chat completion V2 interface.
|
||||||
if minimaxApiTypePro != m.config.minimaxApiType {
|
if minimaxApiTypePro != m.config.minimaxApiType {
|
||||||
ctx.DontReadResponseBody()
|
ctx.DontReadResponseBody()
|
||||||
return types.ActionContinue, nil
|
|
||||||
}
|
}
|
||||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
|
||||||
return types.ActionContinue, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnStreamingResponseBody handles streaming response chunks from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
|
// OnStreamingResponseBody handles streaming response chunks from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
|
||||||
@@ -196,16 +193,16 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OnResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
|
// OnResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
|
||||||
func (m *minimaxProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
func (m *minimaxProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||||
minimaxResp := &minimaxChatCompletionV2Resp{}
|
minimaxResp := &minimaxChatCompletionV2Resp{}
|
||||||
if err := json.Unmarshal(body, minimaxResp); err != nil {
|
if err := json.Unmarshal(body, minimaxResp); err != nil {
|
||||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal minimax response: %v", err)
|
return nil, fmt.Errorf("unable to unmarshal minimax response: %v", err)
|
||||||
}
|
}
|
||||||
if minimaxResp.BaseResp.StatusCode != 0 {
|
if minimaxResp.BaseResp.StatusCode != 0 {
|
||||||
return types.ActionContinue, fmt.Errorf("minimax response error, error_code: %d, error_message: %s", minimaxResp.BaseResp.StatusCode, minimaxResp.BaseResp.StatusMsg)
|
return nil, fmt.Errorf("minimax response error, error_code: %d, error_message: %s", minimaxResp.BaseResp.StatusCode, minimaxResp.BaseResp.StatusMsg)
|
||||||
}
|
}
|
||||||
response := m.responseV2ToOpenAI(minimaxResp)
|
response := m.responseV2ToOpenAI(minimaxResp)
|
||||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
return json.Marshal(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
// minimaxChatCompletionV2Request represents the structure of a chat completion V2 request.
|
// minimaxChatCompletionV2Request represents the structure of a chat completion V2 request.
|
||||||
|
|||||||
@@ -59,7 +59,9 @@ const (
|
|||||||
finishReasonLength = "length"
|
finishReasonLength = "length"
|
||||||
|
|
||||||
ctxKeyIncrementalStreaming = "incrementalStreaming"
|
ctxKeyIncrementalStreaming = "incrementalStreaming"
|
||||||
ctxKeyApiName = "apiKey"
|
ctxKeyApiKey = "apiKey"
|
||||||
|
CtxKeyApiName = "apiName"
|
||||||
|
ctxKeyIsStreaming = "isStreaming"
|
||||||
ctxKeyStreamingBody = "streamingBody"
|
ctxKeyStreamingBody = "streamingBody"
|
||||||
ctxKeyOriginalRequestModel = "originalRequestModel"
|
ctxKeyOriginalRequestModel = "originalRequestModel"
|
||||||
ctxKeyFinalRequestModel = "finalRequestModel"
|
ctxKeyFinalRequestModel = "finalRequestModel"
|
||||||
@@ -115,22 +117,26 @@ type Provider interface {
|
|||||||
GetProviderType() string
|
GetProviderType() string
|
||||||
}
|
}
|
||||||
|
|
||||||
type ApiNameHandler interface {
|
|
||||||
GetApiName(path string) ApiName
|
|
||||||
}
|
|
||||||
|
|
||||||
type RequestHeadersHandler interface {
|
type RequestHeadersHandler interface {
|
||||||
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error
|
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type TransformRequestHeadersHandler interface {
|
|
||||||
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
|
|
||||||
}
|
|
||||||
|
|
||||||
type RequestBodyHandler interface {
|
type RequestBodyHandler interface {
|
||||||
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
|
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type StreamingResponseBodyHandler interface {
|
||||||
|
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ApiNameHandler interface {
|
||||||
|
GetApiName(path string) ApiName
|
||||||
|
}
|
||||||
|
|
||||||
|
type TransformRequestHeadersHandler interface {
|
||||||
|
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
|
||||||
|
}
|
||||||
|
|
||||||
type TransformRequestBodyHandler interface {
|
type TransformRequestBodyHandler interface {
|
||||||
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
|
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
|
||||||
}
|
}
|
||||||
@@ -141,16 +147,12 @@ type TransformRequestBodyHeadersHandler interface {
|
|||||||
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
|
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ResponseHeadersHandler interface {
|
type TransformResponseHeadersHandler interface {
|
||||||
OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
|
TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
|
||||||
}
|
}
|
||||||
|
|
||||||
type StreamingResponseBodyHandler interface {
|
type TransformResponseBodyHandler interface {
|
||||||
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
|
TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
|
||||||
}
|
|
||||||
|
|
||||||
type ResponseBodyHandler interface {
|
|
||||||
OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TickFuncHandler allows the provider to execute a function periodically
|
// TickFuncHandler allows the provider to execute a function periodically
|
||||||
@@ -175,6 +177,9 @@ type ProviderConfig struct {
|
|||||||
// @Title zh-CN apiToken 故障切换
|
// @Title zh-CN apiToken 故障切换
|
||||||
// @Description zh-CN 当 apiToken 不可用时移出 apiTokens 列表,对移除的 apiToken 进行健康检查,当重新可用后加回 apiTokens 列表
|
// @Description zh-CN 当 apiToken 不可用时移出 apiTokens 列表,对移除的 apiToken 进行健康检查,当重新可用后加回 apiTokens 列表
|
||||||
failover *failover `required:"false" yaml:"failover" json:"failover"`
|
failover *failover `required:"false" yaml:"failover" json:"failover"`
|
||||||
|
// @Title zh-CN 失败请求重试
|
||||||
|
// @Description zh-CN 对失败的请求立即进行重试
|
||||||
|
retryOnFailure *retryOnFailure `required:"false" yaml:"retryOnFailure" json:"retryOnFailure"`
|
||||||
// @Title zh-CN 基于OpenAI协议的自定义后端URL
|
// @Title zh-CN 基于OpenAI协议的自定义后端URL
|
||||||
// @Description zh-CN 仅适用于支持 openai 协议的服务。
|
// @Description zh-CN 仅适用于支持 openai 协议的服务。
|
||||||
openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"`
|
openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"`
|
||||||
@@ -352,6 +357,14 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
|||||||
c.failover.FromJson(failoverJson)
|
c.failover.FromJson(failoverJson)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
retryOnFailureJson := json.Get("retryOnFailure")
|
||||||
|
c.retryOnFailure = &retryOnFailure{
|
||||||
|
enabled: false,
|
||||||
|
}
|
||||||
|
if retryOnFailureJson.Exists() {
|
||||||
|
c.retryOnFailure.FromJson(retryOnFailureJson)
|
||||||
|
}
|
||||||
|
|
||||||
for _, accessKeyAndSecret := range json.Get("baiduAccessKeyAndSecret").Array() {
|
for _, accessKeyAndSecret := range json.Get("baiduAccessKeyAndSecret").Array() {
|
||||||
c.baiduAccessKeyAndSecret = append(c.baiduAccessKeyAndSecret, accessKeyAndSecret.String())
|
c.baiduAccessKeyAndSecret = append(c.baiduAccessKeyAndSecret, accessKeyAndSecret.String())
|
||||||
}
|
}
|
||||||
@@ -399,10 +412,10 @@ func (c *ProviderConfig) Validate() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) string {
|
func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) string {
|
||||||
ctxApiKey := ctx.GetContext(ctxKeyApiName)
|
ctxApiKey := ctx.GetContext(ctxKeyApiKey)
|
||||||
if ctxApiKey == nil {
|
if ctxApiKey == nil {
|
||||||
ctxApiKey = c.GetRandomToken()
|
ctxApiKey = c.GetRandomToken()
|
||||||
ctx.SetContext(ctxKeyApiName, ctxApiKey)
|
ctx.SetContext(ctxKeyApiKey, ctxApiKey)
|
||||||
}
|
}
|
||||||
return ctxApiKey.(string)
|
return ctxApiKey.(string)
|
||||||
}
|
}
|
||||||
@@ -446,6 +459,9 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
|
|||||||
streaming := req.Stream
|
streaming := req.Stream
|
||||||
if streaming {
|
if streaming {
|
||||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
|
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
|
||||||
|
ctx.SetContext(ctxKeyIsStreaming, true)
|
||||||
|
} else {
|
||||||
|
ctx.SetContext(ctxKeyIsStreaming, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.setRequestModel(ctx, req, log)
|
return c.setRequestModel(ctx, req, log)
|
||||||
@@ -540,9 +556,9 @@ func (c *ProviderConfig) handleRequestBody(
|
|||||||
if handler, ok := provider.(TransformRequestBodyHandler); ok {
|
if handler, ok := provider.(TransformRequestBodyHandler); ok {
|
||||||
body, err = handler.TransformRequestBody(ctx, apiName, body, log)
|
body, err = handler.TransformRequestBody(ctx, apiName, body, log)
|
||||||
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
|
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
|
||||||
headers := util.GetOriginalHttpHeaders()
|
headers := util.GetOriginalRequestHeaders()
|
||||||
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log)
|
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log)
|
||||||
util.ReplaceOriginalHttpHeaders(headers)
|
util.ReplaceRequestHeaders(headers)
|
||||||
} else {
|
} else {
|
||||||
body, err = c.defaultTransformRequestBody(ctx, apiName, body, log)
|
body, err = c.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||||
}
|
}
|
||||||
@@ -551,9 +567,14 @@ func (c *ProviderConfig) handleRequestBody(
|
|||||||
return types.ActionContinue, err
|
return types.ActionContinue, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If retryOnFailure is enabled, save the transformed body to the context in case of retry
|
||||||
|
if c.isRetryOnFailureEnabled() {
|
||||||
|
ctx.SetContext(ctxRequestBody, body)
|
||||||
|
}
|
||||||
|
|
||||||
if apiName == ApiNameChatCompletion {
|
if apiName == ApiNameChatCompletion {
|
||||||
if c.context == nil {
|
if c.context == nil {
|
||||||
return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
|
return types.ActionContinue, replaceRequestBody(body, log)
|
||||||
}
|
}
|
||||||
err = contextCache.GetContextFromFile(ctx, provider, body, log)
|
err = contextCache.GetContextFromFile(ctx, provider, body, log)
|
||||||
|
|
||||||
@@ -562,14 +583,14 @@ func (c *ProviderConfig) handleRequestBody(
|
|||||||
}
|
}
|
||||||
return types.ActionContinue, err
|
return types.ActionContinue, err
|
||||||
}
|
}
|
||||||
return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
|
return types.ActionContinue, replaceRequestBody(body, log)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) {
|
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) {
|
||||||
|
headers := util.GetOriginalRequestHeaders()
|
||||||
if handler, ok := provider.(TransformRequestHeadersHandler); ok {
|
if handler, ok := provider.(TransformRequestHeadersHandler); ok {
|
||||||
originalHeaders := util.GetOriginalHttpHeaders()
|
handler.TransformRequestHeaders(ctx, apiName, headers, log)
|
||||||
handler.TransformRequestHeaders(ctx, apiName, originalHeaders, log)
|
util.ReplaceRequestHeaders(headers)
|
||||||
util.ReplaceOriginalHttpHeaders(originalHeaders)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -585,3 +606,11 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap
|
|||||||
}
|
}
|
||||||
return json.Marshal(request)
|
return json.Marshal(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {
|
||||||
|
if c.protocol == protocolOriginal {
|
||||||
|
ctx.DontReadResponseBody()
|
||||||
|
} else {
|
||||||
|
headers.Del("Content-Length")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -183,16 +183,6 @@ func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []b
|
|||||||
return json.Marshal(qwenRequest)
|
return json.Marshal(qwenRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *qwenProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
|
||||||
if m.config.protocol == protocolOriginal {
|
|
||||||
ctx.DontReadResponseBody()
|
|
||||||
return types.ActionContinue, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
|
||||||
return types.ActionContinue, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||||
if m.config.qwenEnableCompatible || name != ApiNameChatCompletion {
|
if m.config.qwenEnableCompatible || name != ApiNameChatCompletion {
|
||||||
return chunk, nil
|
return chunk, nil
|
||||||
@@ -278,9 +268,9 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
|
|||||||
return []byte(modifiedResponseChunk), nil
|
return []byte(modifiedResponseChunk), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *qwenProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||||
if m.config.qwenEnableCompatible {
|
if m.config.qwenEnableCompatible {
|
||||||
return types.ActionContinue, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
if apiName == ApiNameChatCompletion {
|
if apiName == ApiNameChatCompletion {
|
||||||
return m.onChatCompletionResponseBody(ctx, body, log)
|
return m.onChatCompletionResponseBody(ctx, body, log)
|
||||||
@@ -288,25 +278,25 @@ func (m *qwenProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName,
|
|||||||
if apiName == ApiNameEmbeddings {
|
if apiName == ApiNameEmbeddings {
|
||||||
return m.onEmbeddingsResponseBody(ctx, body, log)
|
return m.onEmbeddingsResponseBody(ctx, body, log)
|
||||||
}
|
}
|
||||||
return types.ActionContinue, errUnsupportedApiName
|
return nil, errUnsupportedApiName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
|
func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
|
||||||
qwenResponse := &qwenTextGenResponse{}
|
qwenResponse := &qwenTextGenResponse{}
|
||||||
if err := json.Unmarshal(body, qwenResponse); err != nil {
|
if err := json.Unmarshal(body, qwenResponse); err != nil {
|
||||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
|
return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
|
||||||
}
|
}
|
||||||
response := m.buildChatCompletionResponse(ctx, qwenResponse)
|
response := m.buildChatCompletionResponse(ctx, qwenResponse)
|
||||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
return json.Marshal(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
|
func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
|
||||||
qwenResponse := &qwenTextEmbeddingResponse{}
|
qwenResponse := &qwenTextEmbeddingResponse{}
|
||||||
if err := json.Unmarshal(body, qwenResponse); err != nil {
|
if err := json.Unmarshal(body, qwenResponse); err != nil {
|
||||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
|
return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
|
||||||
}
|
}
|
||||||
response := m.buildEmbeddingsResponse(ctx, qwenResponse)
|
response := m.buildEmbeddingsResponse(ctx, qwenResponse)
|
||||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
return json.Marshal(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *qwenProvider) buildQwenTextGenerationRequest(ctx wrapper.HttpContext, origRequest *chatCompletionRequest, streaming bool) ([]byte, error) {
|
func (m *qwenProvider) buildQwenTextGenerationRequest(ctx wrapper.HttpContext, origRequest *chatCompletionRequest, streaming bool) ([]byte, error) {
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ func replaceJsonRequestBody(request interface{}, log wrapper.Log) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func replaceHttpJsonRequestBody(body []byte, log wrapper.Log) error {
|
func replaceRequestBody(body []byte, log wrapper.Log) error {
|
||||||
log.Debugf("request body: %s", string(body))
|
log.Debugf("request body: %s", string(body))
|
||||||
err := proxywasm.ReplaceHttpRequestBody(body)
|
err := proxywasm.ReplaceHttpRequestBody(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -65,15 +65,11 @@ func insertContextMessage(request *chatCompletionRequest, content string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func replaceJsonResponseBody(response interface{}, log wrapper.Log) error {
|
func ReplaceResponseBody(body []byte, log wrapper.Log) error {
|
||||||
body, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to marshal response: %v", err)
|
|
||||||
}
|
|
||||||
log.Debugf("response body: %s", string(body))
|
log.Debugf("response body: %s", string(body))
|
||||||
err = proxywasm.ReplaceHttpResponseBody(body)
|
err := proxywasm.ReplaceHttpResponseBody(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to replace the original response body: %v", err)
|
return fmt.Errorf("unable to replace the original response body: %v", err)
|
||||||
}
|
}
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
141
plugins/wasm-go/extensions/ai-proxy/provider/retry.go
Normal file
141
plugins/wasm-go/extensions/ai-proxy/provider/retry.go
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ctxRequestBody = "requestBody"
|
||||||
|
ctxRetryCount = "retryCount"
|
||||||
|
)
|
||||||
|
|
||||||
|
type retryOnFailure struct {
|
||||||
|
// @Title zh-CN 是否启用请求重试
|
||||||
|
enabled bool `required:"false" yaml:"enabled" json:"enabled"`
|
||||||
|
// @Title zh-CN 重试次数
|
||||||
|
maxRetries int64 `required:"false" yaml:"maxRetries" json:"maxRetries"`
|
||||||
|
// @Title zh-CN 重试超时时间
|
||||||
|
retryTimeout int64 `required:"false" yaml:"retryTimeout" json:"retryTimeout"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *retryOnFailure) FromJson(json gjson.Result) {
|
||||||
|
r.enabled = json.Get("enabled").Bool()
|
||||||
|
r.maxRetries = json.Get("maxRetries").Int()
|
||||||
|
if r.maxRetries == 0 {
|
||||||
|
r.maxRetries = 1
|
||||||
|
}
|
||||||
|
r.retryTimeout = json.Get("retryTimeout").Int()
|
||||||
|
if r.retryTimeout == 0 {
|
||||||
|
r.retryTimeout = 5000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ProviderConfig) isRetryOnFailureEnabled() bool {
|
||||||
|
return c.retryOnFailure.enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ProviderConfig) retryFailedRequest(activeProvider Provider, ctx wrapper.HttpContext, log wrapper.Log) {
|
||||||
|
log.Debugf("Retry failed request: provider=%s", activeProvider.GetProviderType())
|
||||||
|
retryClient := createRetryClient(ctx)
|
||||||
|
apiName, _ := ctx.GetContext(CtxKeyApiName).(ApiName)
|
||||||
|
ctx.SetContext(ctxRetryCount, 1)
|
||||||
|
c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, log)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ProviderConfig) transformResponseHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, apiName ApiName, headers http.Header, body []byte, log wrapper.Log) ([][2]string, []byte) {
|
||||||
|
if handler, ok := activeProvider.(TransformResponseHeadersHandler); ok {
|
||||||
|
handler.TransformResponseHeaders(ctx, apiName, headers, log)
|
||||||
|
} else {
|
||||||
|
c.DefaultTransformResponseHeaders(ctx, headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
if handler, ok := activeProvider.(TransformResponseBodyHandler); ok {
|
||||||
|
var err error
|
||||||
|
body, err = handler.TransformResponseBody(ctx, apiName, body, log)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to transform response body: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.HeaderToSlice(headers), body
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ProviderConfig) retryCall(
|
||||||
|
ctx wrapper.HttpContext, log wrapper.Log, activeProvider Provider,
|
||||||
|
apiName ApiName, statusCode int, responseHeaders http.Header, responseBody []byte,
|
||||||
|
retryClient *wrapper.ClusterClient[wrapper.RouteCluster]) {
|
||||||
|
|
||||||
|
retryCount := ctx.GetContext(ctxRetryCount).(int)
|
||||||
|
log.Debugf("Sent retry request: %d/%d", retryCount, c.retryOnFailure.maxRetries)
|
||||||
|
|
||||||
|
if statusCode == 200 {
|
||||||
|
log.Debugf("Retry request succeeded")
|
||||||
|
headers, body := c.transformResponseHeadersAndBody(ctx, activeProvider, apiName, responseHeaders, responseBody, log)
|
||||||
|
proxywasm.SendHttpResponse(200, headers, body, -1)
|
||||||
|
} else {
|
||||||
|
log.Debugf("The retry request still failed, status: %d, responseHeaders: %v, responseBody: %s", statusCode, responseHeaders, string(responseBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
retryCount++
|
||||||
|
if retryCount <= int(c.retryOnFailure.maxRetries) {
|
||||||
|
ctx.SetContext(ctxRetryCount, retryCount)
|
||||||
|
c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, log)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Reached the maximum retry count: %d", c.retryOnFailure.maxRetries)
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ProviderConfig) sendRetryRequest(
|
||||||
|
ctx wrapper.HttpContext, apiName ApiName, activeProvider Provider,
|
||||||
|
retryClient *wrapper.ClusterClient[wrapper.RouteCluster], log wrapper.Log) {
|
||||||
|
|
||||||
|
requestHeaders, requestBody := c.getRetryRequestHeadersAndBody(ctx, activeProvider, apiName, log)
|
||||||
|
path := getRetryPath(ctx)
|
||||||
|
|
||||||
|
err := retryClient.Post(path, util.HeaderToSlice(requestHeaders), requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
c.retryCall(ctx, log, activeProvider, apiName, statusCode, responseHeaders, responseBody, retryClient)
|
||||||
|
}, uint32(c.retryOnFailure.retryTimeout))
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to send retry request: %v", err)
|
||||||
|
proxywasm.ResumeHttpResponse()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createRetryClient(ctx wrapper.HttpContext) *wrapper.ClusterClient[wrapper.RouteCluster] {
|
||||||
|
host := wrapper.GetRequestHost()
|
||||||
|
if host == "" {
|
||||||
|
host = ctx.GetContext(ctxRequestHost).(string)
|
||||||
|
}
|
||||||
|
retryClient := wrapper.NewClusterClient(wrapper.RouteCluster{
|
||||||
|
Host: host,
|
||||||
|
})
|
||||||
|
return retryClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRetryPath(ctx wrapper.HttpContext) string {
|
||||||
|
path := wrapper.GetRequestPath()
|
||||||
|
if path == "" {
|
||||||
|
path = ctx.GetContext(ctxRequestPath).(string)
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ProviderConfig) getRetryRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, apiName ApiName, log wrapper.Log) (http.Header, []byte) {
|
||||||
|
// The retry request may be sent with different apiToken, so the header needs to be regenerated
|
||||||
|
c.SetApiTokenInUse(ctx, log)
|
||||||
|
|
||||||
|
requestHeaders := http.Header{
|
||||||
|
"Content-Type": []string{"application/json"},
|
||||||
|
}
|
||||||
|
if handler, ok := activeProvider.(TransformRequestHeadersHandler); ok {
|
||||||
|
handler.TransformRequestHeaders(ctx, apiName, requestHeaders, log)
|
||||||
|
}
|
||||||
|
requestBody := ctx.GetContext(ctxRequestBody).([]byte)
|
||||||
|
|
||||||
|
return requestHeaders, requestBody
|
||||||
|
}
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
|
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -82,21 +81,16 @@ func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
|||||||
return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log)
|
return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *sparkProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
|
||||||
return types.ActionContinue, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *sparkProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
|
||||||
sparkResponse := &sparkResponse{}
|
sparkResponse := &sparkResponse{}
|
||||||
if err := json.Unmarshal(body, sparkResponse); err != nil {
|
if err := json.Unmarshal(body, sparkResponse); err != nil {
|
||||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal spark response: %v", err)
|
return nil, fmt.Errorf("unable to unmarshal spark response: %v", err)
|
||||||
}
|
}
|
||||||
if sparkResponse.Code != 0 {
|
if sparkResponse.Code != 0 {
|
||||||
return types.ActionContinue, fmt.Errorf("spark response error, error_code: %d, error_message: %s", sparkResponse.Code, sparkResponse.Message)
|
return nil, fmt.Errorf("spark response error, error_code: %d, error_message: %s", sparkResponse.Code, sparkResponse.Message)
|
||||||
}
|
}
|
||||||
response := p.responseSpark2OpenAI(ctx, sparkResponse)
|
response := p.responseSpark2OpenAI(ctx, sparkResponse)
|
||||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
return json.Marshal(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||||
|
|||||||
@@ -86,12 +86,22 @@ func SliceToHeader(slice [][2]string) http.Header {
|
|||||||
return header
|
return header
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetOriginalHttpHeaders() http.Header {
|
func GetOriginalRequestHeaders() http.Header {
|
||||||
originalHeaders, _ := proxywasm.GetHttpRequestHeaders()
|
originalHeaders, _ := proxywasm.GetHttpRequestHeaders()
|
||||||
return SliceToHeader(originalHeaders)
|
return SliceToHeader(originalHeaders)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReplaceOriginalHttpHeaders(headers http.Header) {
|
func GetOriginalResponseHeaders() http.Header {
|
||||||
|
originalHeaders, _ := proxywasm.GetHttpResponseHeaders()
|
||||||
|
return SliceToHeader(originalHeaders)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReplaceRequestHeaders(headers http.Header) {
|
||||||
modifiedHeaders := HeaderToSlice(headers)
|
modifiedHeaders := HeaderToSlice(headers)
|
||||||
_ = proxywasm.ReplaceHttpRequestHeaders(modifiedHeaders)
|
_ = proxywasm.ReplaceHttpRequestHeaders(modifiedHeaders)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ReplaceResponseHeaders(headers http.Header) {
|
||||||
|
modifiedHeaders := HeaderToSlice(headers)
|
||||||
|
_ = proxywasm.ReplaceHttpResponseHeaders(modifiedHeaders)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user