fix:fix bug in ext-auth wasm plugin (#1152)

This commit is contained in:
韩贤涛
2024-08-05 11:04:31 +08:00
committed by GitHub
parent cc74c0da93
commit 08c64ed467
8 changed files with 275 additions and 122 deletions

View File

@@ -13,7 +13,13 @@ import (
const (
DefaultStatusOnError uint32 = http.StatusForbidden
DefaultHttpServiceTimeout uint32 = 200
DefaultHttpServiceTimeout uint32 = 1000
DefaultMaxRequestBodyBytes uint32 = 10 * 1024 * 1024
EndpointModeEnvoy = "envoy"
EndpointModeForwardAuth = "forward_auth"
)
type ExtAuthConfig struct {
@@ -24,8 +30,13 @@ type ExtAuthConfig struct {
}
type HttpService struct {
client wrapper.HttpClient
requestMethod string
endpointMode string
client wrapper.HttpClient
// pathPrefix is only used when endpoint_mode is envoy
pathPrefix string
// requestMethod is only used when endpoint_mode is forward_auth
requestMethod string
// path is only used when endpoint_mode is forward_auth
path string
timeout uint32
authorizationRequest AuthorizationRequest
@@ -35,9 +46,10 @@ type HttpService struct {
type AuthorizationRequest struct {
// allowedHeaders In addition to the users supplied matchers,
// Host, Method, Path, Content-Length, and Authorization are automatically included to the list.
allowedHeaders expr.Matcher
headersToAdd map[string]string
withRequestBody bool
allowedHeaders expr.Matcher
headersToAdd map[string]string
withRequestBody bool
maxRequestBodyBytes uint32
}
type AuthorizationResponse struct {
@@ -50,7 +62,7 @@ func parseConfig(json gjson.Result, config *ExtAuthConfig, log wrapper.Log) erro
if !httpServiceConfig.Exists() {
return errors.New("missing http_service in config")
}
err := parseHttpServiceConfig(httpServiceConfig, config)
err := parseHttpServiceConfig(httpServiceConfig, config, log)
if err != nil {
return err
}
@@ -65,20 +77,19 @@ func parseConfig(json gjson.Result, config *ExtAuthConfig, log wrapper.Log) erro
config.failureModeAllowHeaderAdd = failureModeAllowHeaderAdd.Bool()
}
statusOnError := json.Get("status_on_error")
if statusOnError.Exists() {
config.statusOnError = uint32(statusOnError.Uint())
} else {
config.statusOnError = DefaultStatusOnError
statusOnError := uint32(json.Get("status_on_error").Uint())
if statusOnError == 0 {
statusOnError = DefaultStatusOnError
}
config.statusOnError = statusOnError
return nil
}
func parseHttpServiceConfig(json gjson.Result, config *ExtAuthConfig) error {
func parseHttpServiceConfig(json gjson.Result, config *ExtAuthConfig, log wrapper.Log) error {
var httpService HttpService
if err := parseEndpointConfig(json, &httpService); err != nil {
if err := parseEndpointConfig(json, &httpService, log); err != nil {
return err
}
@@ -101,64 +112,63 @@ func parseHttpServiceConfig(json gjson.Result, config *ExtAuthConfig) error {
return nil
}
func parseEndpointConfig(json gjson.Result, httpService *HttpService) error {
func parseEndpointConfig(json gjson.Result, httpService *HttpService, log wrapper.Log) error {
endpointMode := json.Get("endpoint_mode").String()
if endpointMode == "" {
endpointMode = EndpointModeEnvoy
} else if endpointMode != EndpointModeEnvoy && endpointMode != EndpointModeForwardAuth {
return errors.New(fmt.Sprintf("endpoint_mode %s is not supported", endpointMode))
}
httpService.endpointMode = endpointMode
endpointConfig := json.Get("endpoint")
if !endpointConfig.Exists() {
return errors.New("missing endpoint in config")
}
serviceSource := endpointConfig.Get("service_source").String()
serviceName := endpointConfig.Get("service_name").String()
if serviceName == "" {
return errors.New("endpoint service name must not be empty")
}
servicePort := endpointConfig.Get("service_port").Int()
if serviceName == "" || servicePort == 0 {
return errors.New("invalid service config")
}
switch serviceSource {
case "k8s":
namespace := json.Get("namespace").String()
httpService.client = wrapper.NewClusterClient(wrapper.K8sCluster{
ServiceName: serviceName,
Namespace: namespace,
Port: servicePort,
})
return nil
case "nacos":
namespace := json.Get("namespace").String()
httpService.client = wrapper.NewClusterClient(wrapper.NacosCluster{
ServiceName: serviceName,
NamespaceID: namespace,
Port: servicePort,
})
return nil
case "ip":
httpService.client = wrapper.NewClusterClient(wrapper.StaticIpCluster{
ServiceName: serviceName,
Port: servicePort,
})
case "dns":
domain := endpointConfig.Get("domain").String()
httpService.client = wrapper.NewClusterClient(wrapper.DnsCluster{
ServiceName: serviceName,
Port: servicePort,
Domain: domain,
})
default:
return errors.New("unknown service source: " + serviceSource)
if servicePort == 0 {
servicePort = 80
}
requestMethodConfig := endpointConfig.Get("request_method")
if !requestMethodConfig.Exists() {
httpService.requestMethod = http.MethodGet
} else {
httpService.requestMethod = strings.ToUpper(requestMethodConfig.String())
}
httpService.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: serviceName,
Port: servicePort,
})
pathConfig := endpointConfig.Get("path")
if !pathConfig.Exists() {
return errors.New("missing path in config")
}
httpService.path = pathConfig.String()
switch endpointMode {
case EndpointModeEnvoy:
pathPrefixConfig := endpointConfig.Get("path_prefix")
if !pathPrefixConfig.Exists() {
return errors.New("when endpoint_mode is envoy, endpoint path_prefix must not be empty")
}
httpService.pathPrefix = pathPrefixConfig.String()
if endpointConfig.Get("request_method").Exists() || endpointConfig.Get("path").Exists() {
log.Warn("when endpoint_mode is envoy, endpoint request_method and path will be ignored")
}
case EndpointModeForwardAuth:
requestMethodConfig := endpointConfig.Get("request_method")
if !requestMethodConfig.Exists() {
httpService.requestMethod = http.MethodGet
} else {
httpService.requestMethod = strings.ToUpper(requestMethodConfig.String())
}
pathConfig := endpointConfig.Get("path")
if !pathConfig.Exists() {
return errors.New("when endpoint_mode is forward_auth, endpoint path must not be empty")
}
httpService.path = pathConfig.String()
if endpointConfig.Get("path_prefix").Exists() {
log.Warn("when endpoint_mode is forward_auth, endpoint path_prefix will be ignored")
}
}
return nil
}
@@ -167,6 +177,15 @@ func parseAuthorizationRequestConfig(json gjson.Result, httpService *HttpService
if authorizationRequestConfig.Exists() {
var authorizationRequest AuthorizationRequest
allowedHeaders := authorizationRequestConfig.Get("allowed_headers")
if allowedHeaders.Exists() {
result, err := expr.BuildRepeatedStringMatcherIgnoreCase(allowedHeaders.Array())
if err != nil {
return err
}
authorizationRequest.allowedHeaders = result
}
headersToAdd := map[string]string{}
headersToAddConfig := authorizationRequestConfig.Get("headers_to_add")
if headersToAddConfig.Exists() {
@@ -186,14 +205,11 @@ func parseAuthorizationRequestConfig(json gjson.Result, httpService *HttpService
authorizationRequest.withRequestBody = withRequestBody.Bool()
}
allowedHeaders := authorizationRequestConfig.Get("allowed_headers")
if allowedHeaders.Exists() {
result, err := expr.BuildRepeatedStringMatcherIgnoreCase(allowedHeaders.Array())
if err != nil {
return err
}
authorizationRequest.allowedHeaders = result
maxRequestBodyBytes := uint32(authorizationRequestConfig.Get("max_request_body_bytes").Uint())
if maxRequestBodyBytes == 0 {
maxRequestBodyBytes = DefaultMaxRequestBodyBytes
}
authorizationRequest.maxRequestBodyBytes = maxRequestBodyBytes
httpService.authorizationRequest = authorizationRequest
}