mirror of
https://github.com/alibaba/higress.git
synced 2026-04-21 20:17:29 +08:00
optimize plugin sdk (#1930)
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -27,16 +28,16 @@ const (
|
||||
func main() {
|
||||
wrapper.SetCtx(
|
||||
pluginName,
|
||||
wrapper.ParseOverrideConfigBy(parseGlobalConfig, parseOverrideRuleConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeader),
|
||||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
||||
wrapper.ProcessStreamingResponseBodyBy(onStreamingResponseBody),
|
||||
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
|
||||
wrapper.ParseOverrideConfig(parseGlobalConfig, parseOverrideRuleConfig),
|
||||
wrapper.ProcessRequestHeaders(onHttpRequestHeader),
|
||||
wrapper.ProcessRequestBody(onHttpRequestBody),
|
||||
wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
|
||||
wrapper.ProcessStreamingResponseBody(onStreamingResponseBody),
|
||||
wrapper.ProcessResponseBody(onHttpResponseBody),
|
||||
)
|
||||
}
|
||||
|
||||
func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log wrapper.Log) error {
|
||||
func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig) error {
|
||||
log.Debugf("loading global config: %s", json.String())
|
||||
|
||||
pluginConfig.FromJson(json)
|
||||
@@ -44,7 +45,7 @@ func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log
|
||||
log.Errorf("global rule config is invalid: %v", err)
|
||||
return err
|
||||
}
|
||||
if err := pluginConfig.Complete(log); err != nil {
|
||||
if err := pluginConfig.Complete(); err != nil {
|
||||
log.Errorf("failed to apply global rule config: %v", err)
|
||||
return err
|
||||
}
|
||||
@@ -52,7 +53,7 @@ func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, pluginConfig *config.PluginConfig, log wrapper.Log) error {
|
||||
func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, pluginConfig *config.PluginConfig) error {
|
||||
log.Debugf("loading override rule config: %s", json.String())
|
||||
|
||||
*pluginConfig = global
|
||||
@@ -62,7 +63,7 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug
|
||||
log.Errorf("overriden rule config is invalid: %v", err)
|
||||
return err
|
||||
}
|
||||
if err := pluginConfig.Complete(log); err != nil {
|
||||
if err := pluginConfig.Complete(); err != nil {
|
||||
log.Errorf("failed to apply overriden rule config: %v", err)
|
||||
return err
|
||||
}
|
||||
@@ -70,7 +71,7 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConfig) types.Action {
|
||||
activeProvider := pluginConfig.GetProvider()
|
||||
|
||||
if activeProvider == nil {
|
||||
@@ -112,15 +113,15 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
|
||||
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
|
||||
// Set the apiToken for the current request.
|
||||
providerConfig.SetApiTokenInUse(ctx, log)
|
||||
providerConfig.SetApiTokenInUse(ctx)
|
||||
// Set available apiTokens of current request in the context, will be used in the retryOnFailure
|
||||
providerConfig.SetAvailableApiTokens(ctx, log)
|
||||
providerConfig.SetAvailableApiTokens(ctx)
|
||||
|
||||
// save the original request host and path in case they are needed for apiToken health check and retry
|
||||
ctx.SetContext(provider.CtxRequestHost, wrapper.GetRequestHost())
|
||||
ctx.SetContext(provider.CtxRequestPath, wrapper.GetRequestPath())
|
||||
|
||||
err := handler.OnRequestHeaders(ctx, apiName, log)
|
||||
err := handler.OnRequestHeaders(ctx, apiName)
|
||||
if err != nil {
|
||||
_ = util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
|
||||
return types.ActionContinue
|
||||
@@ -140,7 +141,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte) types.Action {
|
||||
activeProvider := pluginConfig.GetProvider()
|
||||
|
||||
if activeProvider == nil {
|
||||
@@ -161,11 +162,11 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
||||
log.Errorf("failed to replace request body by custom settings: %v", settingErr)
|
||||
}
|
||||
if providerConfig.IsOpenAIProtocol() {
|
||||
newBody = normalizeOpenAiRequestBody(newBody, log)
|
||||
newBody = normalizeOpenAiRequestBody(newBody)
|
||||
}
|
||||
log.Debugf("[onHttpRequestBody] newBody=%s", newBody)
|
||||
body = newBody
|
||||
action, err := handler.OnRequestBody(ctx, apiName, body, log)
|
||||
action, err := handler.OnRequestBody(ctx, apiName, body)
|
||||
if err == nil {
|
||||
return action
|
||||
}
|
||||
@@ -174,7 +175,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginConfig) types.Action {
|
||||
if !wrapper.IsResponseFromUpstream() {
|
||||
// Response is not coming from the upstream. Let it pass through.
|
||||
ctx.DontReadResponseBody()
|
||||
@@ -201,23 +202,23 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
|
||||
log.Errorf("unable to load :status header from response: %v", err)
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, apiTokens, status, log)
|
||||
return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, apiTokens, status)
|
||||
}
|
||||
|
||||
// Reset ctxApiTokenRequestFailureCount if the request is successful,
|
||||
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
|
||||
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)
|
||||
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse)
|
||||
|
||||
headers := util.GetOriginalResponseHeaders()
|
||||
if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok {
|
||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||
handler.TransformResponseHeaders(ctx, apiName, headers, log)
|
||||
handler.TransformResponseHeaders(ctx, apiName, headers)
|
||||
} else {
|
||||
providerConfig.DefaultTransformResponseHeaders(ctx, headers)
|
||||
}
|
||||
util.ReplaceResponseHeaders(headers)
|
||||
|
||||
checkStream(ctx, log)
|
||||
checkStream(ctx)
|
||||
_, needHandleBody := activeProvider.(provider.TransformResponseBodyHandler)
|
||||
var needHandleStreamingBody bool
|
||||
_, needHandleStreamingBody = activeProvider.(provider.StreamingResponseBodyHandler)
|
||||
@@ -233,7 +234,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
|
||||
func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, chunk []byte, isLastChunk bool) []byte {
|
||||
activeProvider := pluginConfig.GetProvider()
|
||||
|
||||
if activeProvider == nil {
|
||||
@@ -246,7 +247,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
|
||||
if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok {
|
||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk, log)
|
||||
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk)
|
||||
if err == nil && modifiedChunk != nil {
|
||||
return modifiedChunk
|
||||
}
|
||||
@@ -254,7 +255,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
}
|
||||
if handler, ok := activeProvider.(provider.StreamingEventHandler); ok {
|
||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||
events := provider.ExtractStreamingEvents(ctx, chunk, log)
|
||||
events := provider.ExtractStreamingEvents(ctx, chunk)
|
||||
log.Debugf("[onStreamingResponseBody] %d events received", len(events))
|
||||
if len(events) == 0 {
|
||||
// No events are extracted, return the original chunk
|
||||
@@ -269,7 +270,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
continue
|
||||
}
|
||||
|
||||
outputEvents, err := handler.OnStreamingEvent(ctx, apiName, event, log)
|
||||
outputEvents, err := handler.OnStreamingEvent(ctx, apiName, event)
|
||||
if err != nil {
|
||||
log.Errorf("[onStreamingResponseBody] failed to process streaming event: %v\n%s", err, chunk)
|
||||
return chunk
|
||||
@@ -287,7 +288,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
return chunk
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte) types.Action {
|
||||
activeProvider := pluginConfig.GetProvider()
|
||||
|
||||
if activeProvider == nil {
|
||||
@@ -299,19 +300,19 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
|
||||
|
||||
if handler, ok := activeProvider.(provider.TransformResponseBodyHandler); ok {
|
||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||
body, err := handler.TransformResponseBody(ctx, apiName, body, log)
|
||||
body, err := handler.TransformResponseBody(ctx, apiName, body)
|
||||
if err != nil {
|
||||
_ = util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
|
||||
return types.ActionContinue
|
||||
}
|
||||
if err = provider.ReplaceResponseBody(body, log); err != nil {
|
||||
if err = provider.ReplaceResponseBody(body); err != nil {
|
||||
_ = util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
|
||||
}
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func normalizeOpenAiRequestBody(body []byte, log wrapper.Log) []byte {
|
||||
func normalizeOpenAiRequestBody(body []byte) []byte {
|
||||
var err error
|
||||
// Default setting include_usage.
|
||||
if gjson.GetBytes(body, "stream").Bool() {
|
||||
@@ -323,7 +324,7 @@ func normalizeOpenAiRequestBody(body []byte, log wrapper.Log) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
func checkStream(ctx wrapper.HttpContext) {
|
||||
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
|
||||
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user