optimize plugin sdk (#1930)

This commit is contained in:
澄潭
2025-03-22 22:46:37 +08:00
committed by GitHub
parent 1812a6b0a9
commit 45fbc8b084
117 changed files with 1036 additions and 766 deletions

View File

@@ -8,6 +8,7 @@ import (
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
@@ -145,19 +146,19 @@ type Provider interface {
}
type RequestHeadersHandler interface {
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error
}
type RequestBodyHandler interface {
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error)
}
type StreamingResponseBodyHandler interface {
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error)
}
type StreamingEventHandler interface {
OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error)
OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent) ([]StreamEvent, error)
}
type ApiNameHandler interface {
@@ -165,25 +166,25 @@ type ApiNameHandler interface {
}
type TransformRequestHeadersHandler interface {
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header)
}
type TransformRequestBodyHandler interface {
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error)
}
// TransformRequestBodyHeadersHandler allows transforming request headers based on the request body.
// Some providers (e.g., gemini) transform request headers (e.g., path) based on the request body (e.g., model).
type TransformRequestBodyHeadersHandler interface {
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error)
}
type TransformResponseHeadersHandler interface {
TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header)
}
type TransformResponseBodyHandler interface {
TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error)
}
type ProviderConfig struct {
@@ -496,7 +497,7 @@ func CreateProvider(pc ProviderConfig) (Provider, error) {
return initializer.CreateProvider(pc)
}
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte, log wrapper.Log) error {
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte) error {
switch req := request.(type) {
case *chatCompletionRequest:
if err := decodeChatCompletionRequest(body, req); err != nil {
@@ -511,18 +512,18 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
ctx.SetContext(ctxKeyIsStreaming, false)
}
return c.setRequestModel(ctx, req, log)
return c.setRequestModel(ctx, req)
case *embeddingsRequest:
if err := decodeEmbeddingsRequest(body, req); err != nil {
return err
}
return c.setRequestModel(ctx, req, log)
return c.setRequestModel(ctx, req)
default:
return errors.New("unsupported request type")
}
}
func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}, log wrapper.Log) error {
func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}) error {
var model *string
switch req := request.(type) {
@@ -534,16 +535,16 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf
return errors.New("unsupported request type")
}
return c.mapModel(ctx, model, log)
return c.mapModel(ctx, model)
}
func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string, log wrapper.Log) error {
func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string) error {
if *model == "" {
return errors.New("missing model in request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, *model)
mappedModel := getMappedModel(*model, c.modelMapping, log)
mappedModel := getMappedModel(*model, c.modelMapping)
if mappedModel == "" {
return errors.New("model becomes empty after applying the configured mapping")
}
@@ -553,15 +554,15 @@ func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string, log wr
return nil
}
func getMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
mappedModel := doGetMappedModel(model, modelMapping, log)
func getMappedModel(model string, modelMapping map[string]string) string {
mappedModel := doGetMappedModel(model, modelMapping)
if len(mappedModel) != 0 {
return mappedModel
}
return model
}
func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
func doGetMappedModel(model string, modelMapping map[string]string) string {
if len(modelMapping) == 0 {
return ""
}
@@ -590,7 +591,7 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
return ""
}
func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) []StreamEvent {
func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte) []StreamEvent {
body := chunk
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
body = append(bufferedStreamingBody, chunk...)
@@ -679,8 +680,7 @@ func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string)
}
func (c *ProviderConfig) handleRequestBody(
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log,
) (types.Action, error) {
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
// use original protocol
if c.IsOriginal() {
return types.ActionContinue, nil
@@ -689,13 +689,13 @@ func (c *ProviderConfig) handleRequestBody(
// use openai protocol
var err error
if handler, ok := provider.(TransformRequestBodyHandler); ok {
body, err = handler.TransformRequestBody(ctx, apiName, body, log)
body, err = handler.TransformRequestBody(ctx, apiName, body)
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
headers := util.GetOriginalRequestHeaders()
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log)
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers)
util.ReplaceRequestHeaders(headers)
} else {
body, err = c.defaultTransformRequestBody(ctx, apiName, body, log)
body, err = c.defaultTransformRequestBody(ctx, apiName, body)
}
if err != nil {
@@ -704,28 +704,28 @@ func (c *ProviderConfig) handleRequestBody(
if apiName == ApiNameChatCompletion {
if c.context == nil {
return types.ActionContinue, replaceRequestBody(body, log)
return types.ActionContinue, replaceRequestBody(body)
}
err = contextCache.GetContextFromFile(ctx, provider, body, log)
err = contextCache.GetContextFromFile(ctx, provider, body)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
return types.ActionContinue, replaceRequestBody(body, log)
return types.ActionContinue, replaceRequestBody(body)
}
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) {
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName) {
headers := util.GetOriginalRequestHeaders()
if handler, ok := provider.(TransformRequestHeadersHandler); ok {
handler.TransformRequestHeaders(ctx, apiName, headers, log)
handler.TransformRequestHeaders(ctx, apiName, headers)
util.ReplaceRequestHeaders(headers)
}
}
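// defaultTransformRequestBody is the default request-body transformation: it only performs model mapping, replacing the model name with sjson instead of deserializing and re-serializing the body, to improve performance.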
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
switch apiName {
case ApiNameChatCompletion:
stream := gjson.GetBytes(body, "stream").Bool()
@@ -738,7 +738,7 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap
}
model := gjson.GetBytes(body, "model").String()
ctx.SetContext(ctxKeyOriginalRequestModel, model)
return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping, log))
return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping))
}
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {