feat: retry failed request (#1590)

This commit is contained in:
Se7en
2024-12-26 18:30:50 +08:00
committed by GitHub
parent 380717ae3d
commit 579c986915
15 changed files with 304 additions and 183 deletions

View File

@@ -59,7 +59,9 @@ const (
finishReasonLength = "length"
ctxKeyIncrementalStreaming = "incrementalStreaming"
ctxKeyApiName = "apiKey"
ctxKeyApiKey = "apiKey"
CtxKeyApiName = "apiName"
ctxKeyIsStreaming = "isStreaming"
ctxKeyStreamingBody = "streamingBody"
ctxKeyOriginalRequestModel = "originalRequestModel"
ctxKeyFinalRequestModel = "finalRequestModel"
@@ -115,22 +117,26 @@ type Provider interface {
GetProviderType() string
}
type ApiNameHandler interface {
GetApiName(path string) ApiName
}
type RequestHeadersHandler interface {
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error
}
type TransformRequestHeadersHandler interface {
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
}
type RequestBodyHandler interface {
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
}
type StreamingResponseBodyHandler interface {
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
}
type ApiNameHandler interface {
GetApiName(path string) ApiName
}
type TransformRequestHeadersHandler interface {
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
}
type TransformRequestBodyHandler interface {
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
}
@@ -141,16 +147,12 @@ type TransformRequestBodyHeadersHandler interface {
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
}
type ResponseHeadersHandler interface {
OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
type TransformResponseHeadersHandler interface {
TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
}
type StreamingResponseBodyHandler interface {
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
}
type ResponseBodyHandler interface {
OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
type TransformResponseBodyHandler interface {
TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
}
// TickFuncHandler allows the provider to execute a function periodically
@@ -175,6 +177,9 @@ type ProviderConfig struct {
// @Title zh-CN apiToken 故障切换
// @Description zh-CN 当 apiToken 不可用时移出 apiTokens 列表,对移除的 apiToken 进行健康检查,当重新可用后加回 apiTokens 列表
failover *failover `required:"false" yaml:"failover" json:"failover"`
// @Title zh-CN 失败请求重试
// @Description zh-CN 对失败的请求立即进行重试
retryOnFailure *retryOnFailure `required:"false" yaml:"retryOnFailure" json:"retryOnFailure"`
// @Title zh-CN 基于OpenAI协议的自定义后端URL
// @Description zh-CN 仅适用于支持 openai 协议的服务。
openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"`
@@ -352,6 +357,14 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.failover.FromJson(failoverJson)
}
retryOnFailureJson := json.Get("retryOnFailure")
c.retryOnFailure = &retryOnFailure{
enabled: false,
}
if retryOnFailureJson.Exists() {
c.retryOnFailure.FromJson(retryOnFailureJson)
}
for _, accessKeyAndSecret := range json.Get("baiduAccessKeyAndSecret").Array() {
c.baiduAccessKeyAndSecret = append(c.baiduAccessKeyAndSecret, accessKeyAndSecret.String())
}
@@ -399,10 +412,10 @@ func (c *ProviderConfig) Validate() error {
}
func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) string {
ctxApiKey := ctx.GetContext(ctxKeyApiName)
ctxApiKey := ctx.GetContext(ctxKeyApiKey)
if ctxApiKey == nil {
ctxApiKey = c.GetRandomToken()
ctx.SetContext(ctxKeyApiName, ctxApiKey)
ctx.SetContext(ctxKeyApiKey, ctxApiKey)
}
return ctxApiKey.(string)
}
@@ -446,6 +459,9 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
streaming := req.Stream
if streaming {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
ctx.SetContext(ctxKeyIsStreaming, true)
} else {
ctx.SetContext(ctxKeyIsStreaming, false)
}
return c.setRequestModel(ctx, req, log)
@@ -540,9 +556,9 @@ func (c *ProviderConfig) handleRequestBody(
if handler, ok := provider.(TransformRequestBodyHandler); ok {
body, err = handler.TransformRequestBody(ctx, apiName, body, log)
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
headers := util.GetOriginalHttpHeaders()
headers := util.GetOriginalRequestHeaders()
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log)
util.ReplaceOriginalHttpHeaders(headers)
util.ReplaceRequestHeaders(headers)
} else {
body, err = c.defaultTransformRequestBody(ctx, apiName, body, log)
}
@@ -551,9 +567,14 @@ func (c *ProviderConfig) handleRequestBody(
return types.ActionContinue, err
}
// If retryOnFailure is enabled, save the transformed body to the context in case of retry
if c.isRetryOnFailureEnabled() {
ctx.SetContext(ctxRequestBody, body)
}
if apiName == ApiNameChatCompletion {
if c.context == nil {
return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
return types.ActionContinue, replaceRequestBody(body, log)
}
err = contextCache.GetContextFromFile(ctx, provider, body, log)
@@ -562,14 +583,14 @@ func (c *ProviderConfig) handleRequestBody(
}
return types.ActionContinue, err
}
return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
return types.ActionContinue, replaceRequestBody(body, log)
}
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) {
headers := util.GetOriginalRequestHeaders()
if handler, ok := provider.(TransformRequestHeadersHandler); ok {
originalHeaders := util.GetOriginalHttpHeaders()
handler.TransformRequestHeaders(ctx, apiName, originalHeaders, log)
util.ReplaceOriginalHttpHeaders(originalHeaders)
handler.TransformRequestHeaders(ctx, apiName, headers, log)
util.ReplaceRequestHeaders(headers)
}
}
@@ -585,3 +606,11 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap
}
return json.Marshal(request)
}
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {
if c.protocol == protocolOriginal {
ctx.DontReadResponseBody()
} else {
headers.Del("Content-Length")
}
}