feat: implement apiToken failover mechanism (#1256)

This commit is contained in:
Se7en
2024-11-16 19:03:09 +08:00
committed by GitHub
parent f2a5df3949
commit d24123a55f
33 changed files with 1558 additions and 1231 deletions

View File

@@ -44,9 +44,10 @@ func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log
if err := pluginConfig.Validate(); err != nil {
return err
}
if err := pluginConfig.Complete(); err != nil {
if err := pluginConfig.Complete(log); err != nil {
return err
}
return nil
}
@@ -59,9 +60,10 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug
if err := pluginConfig.Validate(); err != nil {
return err
}
if err := pluginConfig.Complete(); err != nil {
if err := pluginConfig.Complete(log); err != nil {
return err
}
return nil
}
@@ -80,7 +82,13 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
path, _ := url.Parse(rawPath)
apiName := getOpenAiApiName(path.Path)
providerConfig := pluginConfig.GetProviderConfig()
if apiName == "" && !providerConfig.IsOriginal() {
if providerConfig.IsOriginal() {
if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
apiName = handler.GetApiName(path.Path)
}
}
if apiName == "" {
log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path)
// _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path)
log.Debugf("[onHttpRequestHeader] no send response")
@@ -89,8 +97,11 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
ctx.SetContext(ctxKeyApiName, apiName)
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()
// Set the apiToken for the current request.
providerConfig.SetApiTokenInUse(ctx, log)
hasRequestBody := wrapper.HasRequestBody()
action, err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
@@ -102,6 +113,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
}
return action
}
_ = util.SendResponse(500, "ai-proxy.proc_req_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request headers: %v", err))
return types.ActionContinue
}
@@ -156,15 +168,24 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
log.Debugf("[onHttpResponseHeaders] provider=%s", activeProvider.GetProviderType())
providerConfig := pluginConfig.GetProviderConfig()
apiTokenInUse := providerConfig.GetApiTokenInUse(ctx)
status, err := proxywasm.GetHttpResponseHeader(":status")
if err != nil || status != "200" {
if err != nil {
log.Errorf("unable to load :status header from response: %v", err)
}
ctx.DontReadResponseBody()
providerConfig.OnRequestFailed(ctx, apiTokenInUse, log)
return types.ActionContinue
}
// Reset ctxApiTokenRequestFailureCount if the request is successful,
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)
if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
action, err := handler.OnResponseHeaders(ctx, apiName, log)
@@ -233,16 +254,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
return types.ActionContinue
}
func getOpenAiApiName(path string) provider.ApiName {
if strings.HasSuffix(path, "/v1/chat/completions") {
return provider.ApiNameChatCompletion
}
if strings.HasSuffix(path, "/v1/embeddings") {
return provider.ApiNameEmbeddings
}
return ""
}
func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) {
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
@@ -252,3 +263,13 @@ func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) {
(*ctx).BufferResponseBody()
}
}
func getOpenAiApiName(path string) provider.ApiName {
if strings.HasSuffix(path, "/v1/chat/completions") {
return provider.ApiNameChatCompletion
}
if strings.HasSuffix(path, "/v1/embeddings") {
return provider.ApiNameEmbeddings
}
return ""
}