diff --git a/plugins/wasm-go/extensions/hello-world/main.go b/plugins/wasm-go/extensions/hello-world/main.go index 7becbb27c..ac201e62b 100644 --- a/plugins/wasm-go/extensions/hello-world/main.go +++ b/plugins/wasm-go/extensions/hello-world/main.go @@ -31,7 +31,7 @@ func main() { type HelloWorldConfig struct { } -func onHttpRequestHeaders(ctx *wrapper.CommonHttpCtx[HelloWorldConfig], config HelloWorldConfig, needBody *bool, log wrapper.LogWrapper) types.Action { +func onHttpRequestHeaders(ctx wrapper.HttpContext, config HelloWorldConfig, log wrapper.Log) types.Action { err := proxywasm.AddHttpRequestHeader("hello", "world") if err != nil { log.Critical("failed to set request header") diff --git a/plugins/wasm-go/extensions/http-call/main.go b/plugins/wasm-go/extensions/http-call/main.go index 315d72bc4..ec668c2f5 100644 --- a/plugins/wasm-go/extensions/http-call/main.go +++ b/plugins/wasm-go/extensions/http-call/main.go @@ -41,7 +41,7 @@ type HttpCallConfig struct { tokenHeader string } -func parseConfig(json gjson.Result, config *HttpCallConfig, log wrapper.LogWrapper) error { +func parseConfig(json gjson.Result, config *HttpCallConfig, log wrapper.Log) error { config.bodyHeader = json.Get("bodyHeader").String() if config.bodyHeader == "" { return errors.New("missing bodyHeader in config") @@ -96,7 +96,7 @@ func parseConfig(json gjson.Result, config *HttpCallConfig, log wrapper.LogWrapp } } -func onHttpRequestHeaders(ctx *wrapper.CommonHttpCtx[HttpCallConfig], config HttpCallConfig, needBody *bool, log wrapper.LogWrapper) types.Action { +func onHttpRequestHeaders(ctx wrapper.HttpContext, config HttpCallConfig, log wrapper.Log) types.Action { config.client.Get(config.requestPath, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) { defer proxywasm.ResumeHttpRequest() diff --git a/plugins/wasm-go/extensions/request-block/main.go b/plugins/wasm-go/extensions/request-block/main.go index 1d3097aa9..a36828411 100644 --- a/plugins/wasm-go/extensions/request-block/main.go +++ b/plugins/wasm-go/extensions/request-block/main.go @@ -44,7 +44,7 @@ type RequestBlockConfig struct { blockBodys []string } -func parseConfig(json gjson.Result, config *RequestBlockConfig, log wrapper.LogWrapper) error { +func parseConfig(json gjson.Result, config *RequestBlockConfig, log wrapper.Log) error { code := json.Get("blocked_code").Int() if code != 0 && code > 100 && code < 600 { config.blockedCode = uint32(code) @@ -93,7 +93,7 @@ func parseConfig(json gjson.Result, config *RequestBlockConfig, log wrapper.LogW return nil } -func onHttpRequestHeaders(ctx *wrapper.CommonHttpCtx[RequestBlockConfig], config RequestBlockConfig, needBody *bool, log wrapper.LogWrapper) types.Action { +func onHttpRequestHeaders(ctx wrapper.HttpContext, config RequestBlockConfig, log wrapper.Log) types.Action { if len(config.blockUrls) > 0 { requestUrl, err := proxywasm.GetHttpRequestHeader(":path") if err != nil { @@ -132,12 +132,12 @@ func onHttpRequestHeaders(ctx *wrapper.CommonHttpCtx[RequestBlockConfig], config } } if len(config.blockBodys) == 0 { - *needBody = false + ctx.DontReadRequestBody() } return types.ActionContinue } -func onHttpRequestBody(ctx *wrapper.CommonHttpCtx[RequestBlockConfig], config RequestBlockConfig, body []byte, log wrapper.LogWrapper) types.Action { +func onHttpRequestBody(ctx wrapper.HttpContext, config RequestBlockConfig, body []byte, log wrapper.Log) types.Action { bodyStr := string(body) if !config.caseSensitive { bodyStr = strings.ToLower(bodyStr) diff --git a/plugins/wasm-go/pkg/wrapper/log_wrapper.go b/plugins/wasm-go/pkg/wrapper/log_wrapper.go index a7c02333a..0a0b7139c 100644 --- a/plugins/wasm-go/pkg/wrapper/log_wrapper.go +++ b/plugins/wasm-go/pkg/wrapper/log_wrapper.go @@ -31,11 +31,11 @@ const ( LogLevelCritical ) -type LogWrapper struct { +type Log struct { pluginName string } -func (l LogWrapper) log(level LogLevel, msg string) { +func (l Log) log(level LogLevel, msg string) { msg = fmt.Sprintf("[%s] %s", l.pluginName, msg) switch level { case LogLevelTrace: @@ -53,7 +53,7 @@ func (l LogWrapper) log(level LogLevel, msg string) { } } -func (l LogWrapper) logFormat(level LogLevel, format string, args ...interface{}) { +func (l Log) logFormat(level LogLevel, format string, args ...interface{}) { format = fmt.Sprintf("[%s] %s", l.pluginName, format) switch level { case LogLevelTrace: @@ -71,50 +71,50 @@ func (l LogWrapper) logFormat(level LogLevel, format string, args ...interface{} } } -func (l LogWrapper) Trace(msg string) { +func (l Log) Trace(msg string) { l.log(LogLevelTrace, msg) } -func (l LogWrapper) Tracef(format string, args ...interface{}) { +func (l Log) Tracef(format string, args ...interface{}) { l.logFormat(LogLevelTrace, format, args...) } -func (l LogWrapper) Debug(msg string) { +func (l Log) Debug(msg string) { l.log(LogLevelDebug, msg) } -func (l LogWrapper) Debugf(format string, args ...interface{}) { +func (l Log) Debugf(format string, args ...interface{}) { l.logFormat(LogLevelDebug, format, args...) } -func (l LogWrapper) Info(msg string) { +func (l Log) Info(msg string) { l.log(LogLevelInfo, msg) } -func (l LogWrapper) Infof(format string, args ...interface{}) { +func (l Log) Infof(format string, args ...interface{}) { l.logFormat(LogLevelInfo, format, args...) } -func (l LogWrapper) Warn(msg string) { +func (l Log) Warn(msg string) { l.log(LogLevelWarn, msg) } -func (l LogWrapper) Warnf(format string, args ...interface{}) { +func (l Log) Warnf(format string, args ...interface{}) { l.logFormat(LogLevelWarn, format, args...) } -func (l LogWrapper) Error(msg string) { +func (l Log) Error(msg string) { l.log(LogLevelError, msg) } -func (l LogWrapper) Errorf(format string, args ...interface{}) { +func (l Log) Errorf(format string, args ...interface{}) { l.logFormat(LogLevelError, format, args...) } -func (l LogWrapper) Critical(msg string) { +func (l Log) Critical(msg string) { l.log(LogLevelCritical, msg) } -func (l LogWrapper) Criticalf(format string, args ...interface{}) { +func (l Log) Criticalf(format string, args ...interface{}) { l.logFormat(LogLevelCritical, format, args...) } diff --git a/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go b/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go index 560363853..ae324acaf 100644 --- a/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go +++ b/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go @@ -24,14 +24,27 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/pkg/matcher" ) -type ParseConfigFunc[PluginConfig any] func(json gjson.Result, config *PluginConfig, log LogWrapper) error -type onHttpHeadersFunc[PluginConfig any] func(context *CommonHttpCtx[PluginConfig], config PluginConfig, needBody *bool, log LogWrapper) types.Action -type onHttpBodyFunc[PluginConfig any] func(context *CommonHttpCtx[PluginConfig], config PluginConfig, body []byte, log LogWrapper) types.Action +type HttpContext interface { + Scheme() string + Host() string + Path() string + Method() string + SetContext(key string, value interface{}) + GetContext(key string) interface{} + // If the onHttpRequestBody handle is not set, the request body will not be read by default + DontReadRequestBody() + // If the onHttpResponseBody handle is not set, the request body will not be read by default + DontReadResponseBody() +} + +type ParseConfigFunc[PluginConfig any] func(json gjson.Result, config *PluginConfig, log Log) error +type onHttpHeadersFunc[PluginConfig any] func(context HttpContext, config PluginConfig, log Log) types.Action +type onHttpBodyFunc[PluginConfig any] func(context HttpContext, config PluginConfig, body []byte, log Log) types.Action type CommonVmCtx[PluginConfig any] struct { types.DefaultVMContext pluginName string - log LogWrapper + log Log hasCustomConfig bool parseConfig ParseConfigFunc[PluginConfig] onHttpRequestHeaders onHttpHeadersFunc[PluginConfig] @@ -76,14 +89,14 @@ func ProcessResponseBodyBy[PluginConfig any](f onHttpBodyFunc[PluginConfig]) Set } } -func parseEmptyPluginConfig[PluginConfig any](gjson.Result, *PluginConfig, LogWrapper) error { +func parseEmptyPluginConfig[PluginConfig any](gjson.Result, *PluginConfig, Log) error { return nil } func NewCommonVmCtx[PluginConfig any](pluginName string, setFuncs ...SetPluginFunc[PluginConfig]) *CommonVmCtx[PluginConfig] { ctx := &CommonVmCtx[PluginConfig]{ pluginName: pluginName, - log: LogWrapper{pluginName}, + log: Log{pluginName}, hasCustomConfig: true, } for _, set := range setFuncs { @@ -179,6 +192,34 @@ func (ctx *CommonHttpCtx[PluginConfig]) GetContext(key string) interface{} { return ctx.userContext[key] } +func (ctx *CommonHttpCtx[PluginConfig]) Scheme() string { + proxywasm.SetEffectiveContext(ctx.contextID) + return GetRequestScheme() +} + +func (ctx *CommonHttpCtx[PluginConfig]) Host() string { + proxywasm.SetEffectiveContext(ctx.contextID) + return GetRequestHost() +} + +func (ctx *CommonHttpCtx[PluginConfig]) Path() string { + proxywasm.SetEffectiveContext(ctx.contextID) + return GetRequestPath() +} + +func (ctx *CommonHttpCtx[PluginConfig]) Method() string { + proxywasm.SetEffectiveContext(ctx.contextID) + return GetRequestMethod() +} + +func (ctx *CommonHttpCtx[PluginConfig]) DontReadRequestBody() { + ctx.needRequestBody = false +} + +func (ctx *CommonHttpCtx[PluginConfig]) DontReadResponseBody() { + ctx.needResponseBody = false +} + func (ctx *CommonHttpCtx[PluginConfig]) OnHttpRequestHeaders(numHeaders int, endOfStream bool) types.Action { config, err := ctx.plugin.GetMatchConfig() if err != nil { @@ -192,8 +233,7 @@ func (ctx *CommonHttpCtx[PluginConfig]) OnHttpRequestHeaders(numHeaders int, end if ctx.plugin.vm.onHttpRequestHeaders == nil { return types.ActionContinue } - return ctx.plugin.vm.onHttpRequestHeaders(ctx, *config, - &ctx.needRequestBody, ctx.plugin.vm.log) + return ctx.plugin.vm.onHttpRequestHeaders(ctx, *config, ctx.plugin.vm.log) } func (ctx *CommonHttpCtx[PluginConfig]) OnHttpRequestBody(bodySize int, endOfStream bool) types.Action { @@ -225,8 +265,7 @@ func (ctx *CommonHttpCtx[PluginConfig]) OnHttpResponseHeaders(numHeaders int, en if ctx.plugin.vm.onHttpResponseHeaders == nil { return types.ActionContinue } - return ctx.plugin.vm.onHttpResponseHeaders(ctx, *ctx.config, - &ctx.needResponseBody, ctx.plugin.vm.log) + return ctx.plugin.vm.onHttpResponseHeaders(ctx, *ctx.config, ctx.plugin.vm.log) } func (ctx *CommonHttpCtx[PluginConfig]) OnHttpResponseBody(bodySize int, endOfStream bool) types.Action {