mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
optimize plugin sdk (#1930)
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -145,19 +146,19 @@ type Provider interface {
|
||||
}
|
||||
|
||||
type RequestHeadersHandler interface {
|
||||
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error
|
||||
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error
|
||||
}
|
||||
|
||||
type RequestBodyHandler interface {
|
||||
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
|
||||
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error)
|
||||
}
|
||||
|
||||
type StreamingResponseBodyHandler interface {
|
||||
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
|
||||
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error)
|
||||
}
|
||||
|
||||
type StreamingEventHandler interface {
|
||||
OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error)
|
||||
OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent) ([]StreamEvent, error)
|
||||
}
|
||||
|
||||
type ApiNameHandler interface {
|
||||
@@ -165,25 +166,25 @@ type ApiNameHandler interface {
|
||||
}
|
||||
|
||||
type TransformRequestHeadersHandler interface {
|
||||
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
|
||||
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header)
|
||||
}
|
||||
|
||||
type TransformRequestBodyHandler interface {
|
||||
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
|
||||
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// TransformRequestBodyHeadersHandler allows to transform request headers based on the request body.
|
||||
// Some providers (e.g. gemini) transform request headers (e.g., path) based on the request body (e.g., model).
|
||||
type TransformRequestBodyHeadersHandler interface {
|
||||
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
|
||||
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error)
|
||||
}
|
||||
|
||||
type TransformResponseHeadersHandler interface {
|
||||
TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
|
||||
TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header)
|
||||
}
|
||||
|
||||
type TransformResponseBodyHandler interface {
|
||||
TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
|
||||
TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
@@ -496,7 +497,7 @@ func CreateProvider(pc ProviderConfig) (Provider, error) {
|
||||
return initializer.CreateProvider(pc)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte, log wrapper.Log) error {
|
||||
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte) error {
|
||||
switch req := request.(type) {
|
||||
case *chatCompletionRequest:
|
||||
if err := decodeChatCompletionRequest(body, req); err != nil {
|
||||
@@ -511,18 +512,18 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
|
||||
ctx.SetContext(ctxKeyIsStreaming, false)
|
||||
}
|
||||
|
||||
return c.setRequestModel(ctx, req, log)
|
||||
return c.setRequestModel(ctx, req)
|
||||
case *embeddingsRequest:
|
||||
if err := decodeEmbeddingsRequest(body, req); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.setRequestModel(ctx, req, log)
|
||||
return c.setRequestModel(ctx, req)
|
||||
default:
|
||||
return errors.New("unsupported request type")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}, log wrapper.Log) error {
|
||||
func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}) error {
|
||||
var model *string
|
||||
|
||||
switch req := request.(type) {
|
||||
@@ -534,16 +535,16 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf
|
||||
return errors.New("unsupported request type")
|
||||
}
|
||||
|
||||
return c.mapModel(ctx, model, log)
|
||||
return c.mapModel(ctx, model)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string, log wrapper.Log) error {
|
||||
func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string) error {
|
||||
if *model == "" {
|
||||
return errors.New("missing model in request")
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, *model)
|
||||
|
||||
mappedModel := getMappedModel(*model, c.modelMapping, log)
|
||||
mappedModel := getMappedModel(*model, c.modelMapping)
|
||||
if mappedModel == "" {
|
||||
return errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
@@ -553,15 +554,15 @@ func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string, log wr
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
|
||||
mappedModel := doGetMappedModel(model, modelMapping, log)
|
||||
func getMappedModel(model string, modelMapping map[string]string) string {
|
||||
mappedModel := doGetMappedModel(model, modelMapping)
|
||||
if len(mappedModel) != 0 {
|
||||
return mappedModel
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
|
||||
func doGetMappedModel(model string, modelMapping map[string]string) string {
|
||||
if len(modelMapping) == 0 {
|
||||
return ""
|
||||
}
|
||||
@@ -590,7 +591,7 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
|
||||
return ""
|
||||
}
|
||||
|
||||
func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) []StreamEvent {
|
||||
func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte) []StreamEvent {
|
||||
body := chunk
|
||||
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
|
||||
body = append(bufferedStreamingBody, chunk...)
|
||||
@@ -679,8 +680,7 @@ func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) handleRequestBody(
|
||||
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log,
|
||||
) (types.Action, error) {
|
||||
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
// use original protocol
|
||||
if c.IsOriginal() {
|
||||
return types.ActionContinue, nil
|
||||
@@ -689,13 +689,13 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
// use openai protocol
|
||||
var err error
|
||||
if handler, ok := provider.(TransformRequestBodyHandler); ok {
|
||||
body, err = handler.TransformRequestBody(ctx, apiName, body, log)
|
||||
body, err = handler.TransformRequestBody(ctx, apiName, body)
|
||||
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
|
||||
headers := util.GetOriginalRequestHeaders()
|
||||
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log)
|
||||
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
} else {
|
||||
body, err = c.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||
body, err = c.defaultTransformRequestBody(ctx, apiName, body)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -704,28 +704,28 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
|
||||
if apiName == ApiNameChatCompletion {
|
||||
if c.context == nil {
|
||||
return types.ActionContinue, replaceRequestBody(body, log)
|
||||
return types.ActionContinue, replaceRequestBody(body)
|
||||
}
|
||||
err = contextCache.GetContextFromFile(ctx, provider, body, log)
|
||||
err = contextCache.GetContextFromFile(ctx, provider, body)
|
||||
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
return types.ActionContinue, replaceRequestBody(body, log)
|
||||
return types.ActionContinue, replaceRequestBody(body)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) {
|
||||
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName) {
|
||||
headers := util.GetOriginalRequestHeaders()
|
||||
if handler, ok := provider.(TransformRequestHeadersHandler); ok {
|
||||
handler.TransformRequestHeaders(ctx, apiName, headers, log)
|
||||
handler.TransformRequestHeaders(ctx, apiName, headers)
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
}
|
||||
}
|
||||
|
||||
// defaultTransformRequestBody 默认的请求体转换方法,只做模型映射,用slog替换模型名称,不用序列化和反序列化,提高性能
|
||||
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
stream := gjson.GetBytes(body, "stream").Bool()
|
||||
@@ -738,7 +738,7 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap
|
||||
}
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping, log))
|
||||
return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping))
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {
|
||||
|
||||
Reference in New Issue
Block a user