feat: implement apiToken failover mechanism (#1256)

This commit is contained in:
Se7en
2024-11-16 19:03:09 +08:00
committed by GitHub
parent f2a5df3949
commit d24123a55f
33 changed files with 1558 additions and 1231 deletions

View File

@@ -1,14 +1,17 @@
package provider
import (
"encoding/json"
"errors"
"math/rand"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)
type ApiName string
@@ -110,14 +113,32 @@ type Provider interface {
GetProviderType() string
}
type ApiNameHandler interface {
GetApiName(path string) ApiName
}
type RequestHeadersHandler interface {
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
}
type TransformRequestHeadersHandler interface {
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
}
type RequestBodyHandler interface {
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
}
type TransformRequestBodyHandler interface {
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
}
// TransformRequestBodyHeadersHandler allows to transform request headers based on the request body.
// Some providers (e.g. baidu, gemini) transform request headers (e.g., path) based on the request body (e.g., model).
type TransformRequestBodyHeadersHandler interface {
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
}
type ResponseHeadersHandler interface {
OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
}
@@ -143,6 +164,9 @@ type ProviderConfig struct {
// @Title zh-CN 请求超时
// @Description zh-CN 请求AI服务的超时时间单位为毫秒。默认值为120000即2分钟
timeout uint32 `required:"false" yaml:"timeout" json:"timeout"`
// @Title zh-CN apiToken 故障切换
// @Description zh-CN 当 apiToken 不可用时移出 apiTokens 列表,对移除的 apiToken 进行健康检查,当重新可用后加回 apiTokens 列表
failover *failover `required:"false" yaml:"failover" json:"failover"`
// @Title zh-CN 基于OpenAI协议的自定义后端URL
// @Description zh-CN 仅适用于支持 openai 协议的服务。
openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"`
@@ -289,6 +313,14 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
}
}
}
failoverJson := json.Get("failover")
c.failover = &failover{
enabled: false,
}
if failoverJson.Exists() {
c.failover.FromJson(failoverJson)
}
}
func (c *ProviderConfig) Validate() error {
@@ -304,6 +336,12 @@ func (c *ProviderConfig) Validate() error {
}
}
if c.failover.enabled {
if err := c.failover.Validate(); err != nil {
return err
}
}
if c.typ == "" {
return errors.New("missing type in provider config")
}
@@ -355,6 +393,60 @@ func CreateProvider(pc ProviderConfig) (Provider, error) {
return initializer.CreateProvider(pc)
}
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte, log wrapper.Log) error {
switch req := request.(type) {
case *chatCompletionRequest:
if err := decodeChatCompletionRequest(body, req); err != nil {
return err
}
streaming := req.Stream
if streaming {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
}
return c.setRequestModel(ctx, req, log)
case *embeddingsRequest:
if err := decodeEmbeddingsRequest(body, req); err != nil {
return err
}
return c.setRequestModel(ctx, req, log)
default:
return errors.New("unsupported request type")
}
}
func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}, log wrapper.Log) error {
var model *string
switch req := request.(type) {
case *chatCompletionRequest:
model = &req.Model
case *embeddingsRequest:
model = &req.Model
default:
return errors.New("unsupported request type")
}
return c.mapModel(ctx, model, log)
}
func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string, log wrapper.Log) error {
if *model == "" {
return errors.New("missing model in request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, *model)
mappedModel := getMappedModel(*model, c.modelMapping, log)
if mappedModel == "" {
return errors.New("model becomes empty after applying the configured mapping")
}
*model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, *model)
return nil
}
func getMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
mappedModel := doGetMappedModel(model, modelMapping, log)
if len(mappedModel) != 0 {
@@ -391,3 +483,62 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
return ""
}
func (c *ProviderConfig) handleRequestBody(
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log,
) (types.Action, error) {
// use original protocol
if c.protocol == protocolOriginal {
return types.ActionContinue, nil
}
// use openai protocol
var err error
if handler, ok := provider.(TransformRequestBodyHandler); ok {
body, err = handler.TransformRequestBody(ctx, apiName, body, log)
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
headers := util.GetOriginalHttpHeaders()
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log)
util.ReplaceOriginalHttpHeaders(headers)
} else {
body, err = c.defaultTransformRequestBody(ctx, apiName, body, log)
}
if err != nil {
return types.ActionContinue, err
}
if apiName == ApiNameChatCompletion {
if c.context == nil {
return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
}
err = contextCache.GetContextFromFile(ctx, provider, body, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
}
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) {
if handler, ok := provider.(TransformRequestHeadersHandler); ok {
originalHeaders := util.GetOriginalHttpHeaders()
handler.TransformRequestHeaders(ctx, apiName, originalHeaders, log)
util.ReplaceOriginalHttpHeaders(originalHeaders)
}
}
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
var request interface{}
if apiName == ApiNameChatCompletion {
request = &chatCompletionRequest{}
} else {
request = &embeddingsRequest{}
}
if err := c.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
}
return json.Marshal(request)
}