mirror of
https://github.com/alibaba/higress.git
synced 2026-04-21 20:17:29 +08:00
Plugin cors (#349)
This commit is contained in:
457
plugins/wasm-go/extensions/cors/config/cors_config.go
Normal file
457
plugins/wasm-go/extensions/cors/config/cors_config.go
Normal file
@@ -0,0 +1,457 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMatchAll = "*"
|
||||
defaultAllowMethods = "GET, PUT, POST, DELETE, PATCH, OPTIONS"
|
||||
defaultAllAllowMethods = "GET, PUT, POST, DELETE, PATCH, OPTIONS, HEAD, TRACE, CONNECT"
|
||||
defaultAllowHeaders = "DNT,X-CustomHeader,Keep-Alive,User-Agent,X-Requested-With," +
|
||||
"If-Modified-Since,Cache-Control,Content-Type,Authorization"
|
||||
defaultMaxAge = 86400
|
||||
protocolHttpName = "http"
|
||||
protocolHttpPort = "80"
|
||||
protocolHttpsName = "https"
|
||||
protocolHttpsPort = "443"
|
||||
|
||||
HeaderPluginDebug = "X-Cors-Version"
|
||||
HeaderPluginTrace = "X-Cors-Trace"
|
||||
HeaderOrigin = "Origin"
|
||||
HttpMethodOptions = "OPTIONS"
|
||||
|
||||
HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin"
|
||||
HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods"
|
||||
HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers"
|
||||
HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials"
|
||||
HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers"
|
||||
HeaderAccessControlMaxAge = "Access-Control-Max-Age"
|
||||
|
||||
HeaderControlRequestMethod = "Access-Control-Request-Method"
|
||||
HeaderControlRequestHeaders = "Access-Control-Request-Headers"
|
||||
|
||||
HttpContextKey = "CORS"
|
||||
)
|
||||
|
||||
var portsRegex = regexp.MustCompile(`(.*):\[(\*|\d+(,\d+)*)]`)
|
||||
|
||||
type OriginPattern struct {
|
||||
declaredPattern string
|
||||
pattern *regexp.Regexp
|
||||
patternValue string
|
||||
}
|
||||
|
||||
func newOriginPatternFromString(declaredPattern string) OriginPattern {
|
||||
declaredPattern = strings.ToLower(strings.TrimSuffix(declaredPattern, "/"))
|
||||
matches := portsRegex.FindAllStringSubmatch(declaredPattern, -1)
|
||||
portList := ""
|
||||
patternValue := declaredPattern
|
||||
if len(matches) > 0 {
|
||||
patternValue = matches[0][1]
|
||||
portList = matches[0][2]
|
||||
}
|
||||
|
||||
patternValue = "\\Q" + patternValue + "\\E"
|
||||
patternValue = strings.ReplaceAll(patternValue, "*", "\\E.*\\Q")
|
||||
if len(portList) > 0 {
|
||||
if portList == defaultMatchAll {
|
||||
patternValue += "(:\\d+)?"
|
||||
} else {
|
||||
patternValue += ":(" + strings.ReplaceAll(portList, ",", "|") + ")"
|
||||
}
|
||||
}
|
||||
|
||||
return OriginPattern{
|
||||
declaredPattern: declaredPattern,
|
||||
patternValue: patternValue,
|
||||
pattern: regexp.MustCompile(patternValue),
|
||||
}
|
||||
}
|
||||
|
||||
type CorsConfig struct {
|
||||
// allowOrigins A list of origins for which cross-origin requests are allowed.
|
||||
// Be a specific domain, e.g. "https://example.com", or the CORS defined special value "*" for all origins.
|
||||
// Keep in mind however that the CORS spec does not allow "*" when allowCredentials is set to true, using allowOriginPatterns instead
|
||||
// By default, it is set to "*" when allowOriginPatterns is not set too.
|
||||
allowOrigins []string
|
||||
|
||||
// allowOriginPatterns A list of origin patterns for which cross-origin requests are allowed
|
||||
// origins patterns with "*" anywhere in the host name in addition to port
|
||||
// lists Examples:
|
||||
// https://*.example.com -- domains ending with example.com
|
||||
// https://*.example.com:[8080,9090] -- domains ending with example.com on port 8080 or port 9090
|
||||
// https://*.example.com:[*] -- domains ending with example.com on any port, including the default port
|
||||
// The special value "*" allows all origins
|
||||
// By default, it is not set.
|
||||
allowOriginPatterns []OriginPattern
|
||||
|
||||
// allowMethods A list of method for which cross-origin requests are allowed
|
||||
// The special value "*" allows all methods.
|
||||
// By default, it is set to "GET, PUT, POST, DELETE, PATCH, OPTIONS".
|
||||
allowMethods []string
|
||||
|
||||
// allowHeaders A list of headers that a pre-flight request can list as allowed
|
||||
// The special value "*" allows actual requests to send any header
|
||||
// By default, it is set to "DNT,X-CustomHeader,Keep-Alive,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Authorization"
|
||||
allowHeaders []string
|
||||
|
||||
// exposeHeaders A list of response headers an actual response might have and can be exposed.
|
||||
// The special value "*" allows all headers to be exposed for non-credentialed requests.
|
||||
// By default, it is not set
|
||||
exposeHeaders []string
|
||||
|
||||
// allowCredentials Whether user credentials are supported.
|
||||
// By default, it is not set (i.e. user credentials are not supported).
|
||||
allowCredentials bool
|
||||
|
||||
// maxAge Configure how long, in seconds, the response from a pre-flight request can be cached by clients.
|
||||
// By default, it is set to 86400 seconds.
|
||||
maxAge int
|
||||
}
|
||||
|
||||
type HttpCorsContext struct {
|
||||
IsValid bool
|
||||
ValidReason string
|
||||
IsPreFlight bool
|
||||
IsCorsRequest bool
|
||||
AllowOrigin string
|
||||
AllowMethods string
|
||||
AllowHeaders string
|
||||
ExposeHeaders string
|
||||
AllowCredentials bool
|
||||
MaxAge int
|
||||
}
|
||||
|
||||
func (c *CorsConfig) GetVersion() string {
|
||||
return "1.0.0"
|
||||
}
|
||||
|
||||
func (c *CorsConfig) FillDefaultValues() {
|
||||
if len(c.allowOrigins) == 0 && len(c.allowOriginPatterns) == 0 && c.allowCredentials == false {
|
||||
c.allowOrigins = []string{defaultMatchAll}
|
||||
}
|
||||
if len(c.allowHeaders) == 0 {
|
||||
c.allowHeaders = []string{defaultAllowHeaders}
|
||||
}
|
||||
if len(c.allowMethods) == 0 {
|
||||
c.allowMethods = strings.Split(defaultAllowMethods, ",")
|
||||
}
|
||||
if c.maxAge == 0 {
|
||||
c.maxAge = defaultMaxAge
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CorsConfig) AddAllowOrigin(origin string) error {
|
||||
origin = strings.TrimSpace(origin)
|
||||
if len(origin) == 0 {
|
||||
return nil
|
||||
}
|
||||
if origin == defaultMatchAll {
|
||||
if c.allowCredentials == true {
|
||||
return errors.New("can't set origin to * when allowCredentials is true, use AllowOriginPatterns instead")
|
||||
}
|
||||
c.allowOrigins = []string{defaultMatchAll}
|
||||
return nil
|
||||
}
|
||||
c.allowOrigins = append(c.allowOrigins, strings.TrimSuffix(origin, "/"))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CorsConfig) AddAllowHeader(header string) {
|
||||
header = strings.TrimSpace(header)
|
||||
if len(header) == 0 {
|
||||
return
|
||||
}
|
||||
if header == defaultMatchAll {
|
||||
c.allowHeaders = []string{defaultMatchAll}
|
||||
return
|
||||
}
|
||||
c.allowHeaders = append(c.allowHeaders, header)
|
||||
}
|
||||
|
||||
func (c *CorsConfig) AddAllowMethod(method string) {
|
||||
method = strings.TrimSpace(method)
|
||||
if len(method) == 0 {
|
||||
return
|
||||
}
|
||||
if method == defaultMatchAll {
|
||||
c.allowMethods = []string{defaultMatchAll}
|
||||
return
|
||||
}
|
||||
c.allowMethods = append(c.allowMethods, strings.ToUpper(method))
|
||||
}
|
||||
|
||||
func (c *CorsConfig) AddExposeHeader(header string) {
|
||||
header = strings.TrimSpace(header)
|
||||
if len(header) == 0 {
|
||||
return
|
||||
}
|
||||
if header == defaultMatchAll {
|
||||
c.exposeHeaders = []string{defaultMatchAll}
|
||||
return
|
||||
}
|
||||
c.exposeHeaders = append(c.exposeHeaders, header)
|
||||
}
|
||||
|
||||
func (c *CorsConfig) AddAllowOriginPattern(pattern string) {
|
||||
pattern = strings.TrimSpace(pattern)
|
||||
if len(pattern) == 0 {
|
||||
return
|
||||
}
|
||||
originPattern := newOriginPatternFromString(pattern)
|
||||
c.allowOriginPatterns = append(c.allowOriginPatterns, originPattern)
|
||||
}
|
||||
|
||||
func (c *CorsConfig) SetAllowCredentials(allowCredentials bool) error {
|
||||
if allowCredentials && len(c.allowOrigins) > 0 && c.allowOrigins[0] == defaultMatchAll {
|
||||
return errors.New("can't set allowCredentials to true when allowOrigin is *")
|
||||
}
|
||||
c.allowCredentials = allowCredentials
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CorsConfig) SetMaxAge(maxAge int) {
|
||||
if maxAge <= 0 {
|
||||
c.maxAge = defaultMaxAge
|
||||
} else {
|
||||
c.maxAge = maxAge
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CorsConfig) Process(scheme string, host string, method string, headers [][2]string) (HttpCorsContext, error) {
|
||||
scheme = strings.ToLower(strings.TrimSpace(scheme))
|
||||
host = strings.ToLower(strings.TrimSpace(host))
|
||||
method = strings.ToLower(strings.TrimSpace(method))
|
||||
|
||||
// Init httpCorsContext with default values
|
||||
httpCorsContext := HttpCorsContext{IsValid: true, IsPreFlight: false, IsCorsRequest: false, AllowCredentials: false, MaxAge: 0}
|
||||
|
||||
// Get request origin, controlRequestMethod, controlRequestHeaders from http headers
|
||||
origin := ""
|
||||
controlRequestMethod := ""
|
||||
controlRequestHeaders := ""
|
||||
for _, header := range headers {
|
||||
key := header[0]
|
||||
// Get origin
|
||||
if strings.ToLower(key) == strings.ToLower(HeaderOrigin) {
|
||||
origin = strings.TrimSuffix(strings.TrimSpace(header[1]), "/")
|
||||
}
|
||||
// Get control request method & headers
|
||||
if strings.ToLower(key) == strings.ToLower(HeaderControlRequestMethod) {
|
||||
controlRequestMethod = strings.TrimSpace(header[1])
|
||||
}
|
||||
if strings.ToLower(key) == strings.ToLower(HeaderControlRequestHeaders) {
|
||||
controlRequestHeaders = strings.TrimSpace(header[1])
|
||||
}
|
||||
}
|
||||
|
||||
// Parse if request is CORS and pre-flight request.
|
||||
isCorsRequest := c.isCorsRequest(scheme, host, origin)
|
||||
isPreFlight := c.isPreFlight(origin, method, controlRequestMethod)
|
||||
httpCorsContext.IsCorsRequest = isCorsRequest
|
||||
httpCorsContext.IsPreFlight = isPreFlight
|
||||
|
||||
// Skip when it is not CORS request
|
||||
if !isCorsRequest {
|
||||
httpCorsContext.IsValid = true
|
||||
return httpCorsContext, nil
|
||||
}
|
||||
|
||||
// Check origin
|
||||
allowOrigin, originOk := c.checkOrigin(origin)
|
||||
if !originOk {
|
||||
// Reject: origin is not allowed
|
||||
httpCorsContext.IsValid = false
|
||||
httpCorsContext.ValidReason = fmt.Sprintf("origin:%s is not allowed", origin)
|
||||
return httpCorsContext, nil
|
||||
}
|
||||
|
||||
// Check method
|
||||
requestMethod := method
|
||||
if isPreFlight {
|
||||
requestMethod = controlRequestMethod
|
||||
}
|
||||
allowMethods, methodOk := c.checkMethods(requestMethod)
|
||||
if !methodOk {
|
||||
// Reject: method is not allowed
|
||||
httpCorsContext.IsValid = false
|
||||
httpCorsContext.ValidReason = fmt.Sprintf("method:%s is not allowed", requestMethod)
|
||||
return httpCorsContext, nil
|
||||
}
|
||||
|
||||
// Check headers
|
||||
allowHeaders, headerOK := c.checkHeaders(controlRequestHeaders)
|
||||
|
||||
if isPreFlight && !headerOK {
|
||||
// Reject: headers are not allowed
|
||||
httpCorsContext.IsValid = false
|
||||
httpCorsContext.ValidReason = "Reject: headers are not allowed"
|
||||
return httpCorsContext, nil
|
||||
}
|
||||
|
||||
// Store result in httpCorsContext and return it.
|
||||
httpCorsContext.AllowOrigin = allowOrigin
|
||||
if isPreFlight {
|
||||
httpCorsContext.AllowMethods = allowMethods
|
||||
}
|
||||
if isPreFlight && len(allowHeaders) > 0 {
|
||||
httpCorsContext.AllowHeaders = allowHeaders
|
||||
}
|
||||
if isPreFlight && c.maxAge > 0 {
|
||||
httpCorsContext.MaxAge = c.maxAge
|
||||
}
|
||||
if len(c.exposeHeaders) > 0 {
|
||||
httpCorsContext.ExposeHeaders = strings.Join(c.exposeHeaders, ",")
|
||||
}
|
||||
httpCorsContext.AllowCredentials = c.allowCredentials
|
||||
|
||||
return httpCorsContext, nil
|
||||
}
|
||||
|
||||
func (c *CorsConfig) checkOrigin(origin string) (string, bool) {
|
||||
origin = strings.TrimSpace(origin)
|
||||
if len(origin) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
matchOrigin := strings.ToLower(origin)
|
||||
// Check exact match
|
||||
for _, allowOrigin := range c.allowOrigins {
|
||||
if allowOrigin == defaultMatchAll {
|
||||
return origin, true
|
||||
}
|
||||
if strings.ToLower(allowOrigin) == matchOrigin {
|
||||
return origin, true
|
||||
}
|
||||
}
|
||||
|
||||
// Check pattern match
|
||||
for _, allowOriginPattern := range c.allowOriginPatterns {
|
||||
if allowOriginPattern.declaredPattern == defaultMatchAll || allowOriginPattern.pattern.MatchString(matchOrigin) {
|
||||
return origin, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (c *CorsConfig) checkHeaders(requestHeaders string) (string, bool) {
|
||||
if len(c.allowHeaders) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if len(requestHeaders) == 0 {
|
||||
return strings.Join(c.allowHeaders, ","), true
|
||||
}
|
||||
|
||||
// Return all request headers when allowHeaders contains *
|
||||
if c.allowHeaders[0] == defaultMatchAll {
|
||||
return requestHeaders, true
|
||||
}
|
||||
|
||||
checkHeaders := strings.Split(requestHeaders, ",")
|
||||
// Each request header should be existed in allowHeaders configuration
|
||||
for _, h := range checkHeaders {
|
||||
isExist := false
|
||||
for _, allowHeader := range c.allowHeaders {
|
||||
if strings.ToLower(h) == strings.ToLower(allowHeader) {
|
||||
isExist = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isExist {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(c.allowHeaders, ","), true
|
||||
}
|
||||
|
||||
func (c *CorsConfig) checkMethods(requestMethod string) (string, bool) {
|
||||
if len(requestMethod) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Find method existed in allowMethods configuration
|
||||
for _, method := range c.allowMethods {
|
||||
if method == defaultMatchAll {
|
||||
return defaultAllAllowMethods, true
|
||||
}
|
||||
if strings.ToLower(method) == strings.ToLower(requestMethod) {
|
||||
return strings.Join(c.allowMethods, ","), true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (c *CorsConfig) isPreFlight(origin, method, controllerRequestMethod string) bool {
|
||||
return len(origin) > 0 && strings.ToLower(method) == strings.ToLower(HttpMethodOptions) && len(controllerRequestMethod) > 0
|
||||
}
|
||||
|
||||
func (c *CorsConfig) isCorsRequest(scheme, host, origin string) bool {
|
||||
if len(origin) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
url, err := url.Parse(strings.TrimSpace(origin))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check scheme
|
||||
if strings.ToLower(scheme) != strings.ToLower(url.Scheme) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check host and port
|
||||
port := ""
|
||||
originPort := ""
|
||||
originHost := ""
|
||||
host, port = c.getHostAndPort(scheme, host)
|
||||
originHost, originPort = c.getHostAndPort(url.Scheme, url.Host)
|
||||
if host != originHost || port != originPort {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *CorsConfig) getHostAndPort(scheme string, host string) (string, string) {
|
||||
// Get host and port
|
||||
scheme = strings.ToLower(scheme)
|
||||
host = strings.ToLower(host)
|
||||
port := ""
|
||||
hosts := strings.Split(host, ":")
|
||||
if len(hosts) > 1 {
|
||||
host = hosts[0]
|
||||
port = hosts[1]
|
||||
}
|
||||
// Get default port according scheme
|
||||
if len(port) == 0 && scheme == protocolHttpName {
|
||||
port = protocolHttpPort
|
||||
}
|
||||
if len(port) == 0 && scheme == protocolHttpsName {
|
||||
port = protocolHttpsPort
|
||||
}
|
||||
return host, port
|
||||
}
|
||||
408
plugins/wasm-go/extensions/cors/config/cors_config_test.go
Normal file
408
plugins/wasm-go/extensions/cors/config/cors_config_test.go
Normal file
@@ -0,0 +1,408 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCorsConfig_getHostAndPort(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
scheme string
|
||||
host string
|
||||
wantHost string
|
||||
wantPort string
|
||||
}{
|
||||
{
|
||||
name: "http without port",
|
||||
scheme: "http",
|
||||
host: "http.example.com",
|
||||
wantHost: "http.example.com",
|
||||
wantPort: "80",
|
||||
},
|
||||
{
|
||||
name: "https without port",
|
||||
scheme: "https",
|
||||
host: "http.example.com",
|
||||
wantHost: "http.example.com",
|
||||
wantPort: "443",
|
||||
},
|
||||
|
||||
{
|
||||
name: "http with port and case insensitive",
|
||||
scheme: "hTTp",
|
||||
host: "hTTp.Example.com:8080",
|
||||
wantHost: "http.example.com",
|
||||
wantPort: "8080",
|
||||
},
|
||||
|
||||
{
|
||||
name: "https with port and case insensitive",
|
||||
scheme: "hTTps",
|
||||
host: "hTTp.Example.com:8080",
|
||||
wantHost: "http.example.com",
|
||||
wantPort: "8080",
|
||||
},
|
||||
|
||||
{
|
||||
name: "protocal is not http",
|
||||
scheme: "wss",
|
||||
host: "hTTp.Example.com",
|
||||
wantHost: "http.example.com",
|
||||
wantPort: "",
|
||||
},
|
||||
|
||||
{
|
||||
name: "protocal is not http",
|
||||
scheme: "wss",
|
||||
host: "hTTp.Example.com:8080",
|
||||
wantHost: "http.example.com",
|
||||
wantPort: "8080",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &CorsConfig{}
|
||||
host, port := c.getHostAndPort(tt.scheme, tt.host)
|
||||
assert.Equal(t, tt.wantHost, host)
|
||||
assert.Equal(t, tt.wantPort, port)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsConfig_isCorsRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scheme string
|
||||
host string
|
||||
origin string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "blank origin",
|
||||
scheme: "http",
|
||||
host: "httpbin.example.com",
|
||||
origin: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "normal equal case with space and case ",
|
||||
scheme: "http",
|
||||
host: "httpbin.example.com",
|
||||
origin: "http://hTTPbin.Example.com",
|
||||
want: false,
|
||||
},
|
||||
|
||||
{
|
||||
name: "cors request with port diff",
|
||||
scheme: "http",
|
||||
host: "httpbin.example.com",
|
||||
origin: " http://httpbin.example.com:8080 ",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "cors request with scheme diff",
|
||||
scheme: "http",
|
||||
host: "httpbin.example.com",
|
||||
origin: " https://HTTPpbin.Example.com ",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "cors request with host diff",
|
||||
scheme: "http",
|
||||
host: "httpbin.example.com",
|
||||
origin: " http://HTTPpbin.Example.org ",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &CorsConfig{}
|
||||
assert.Equalf(t, tt.want, c.isCorsRequest(tt.scheme, tt.host, tt.origin), "isCorsRequest(%v, %v, %v)", tt.scheme, tt.host, tt.origin)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsConfig_isPreFlight(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
origin string
|
||||
method string
|
||||
controllerRequestMethod string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "blank case",
|
||||
origin: "",
|
||||
method: "",
|
||||
controllerRequestMethod: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "normal case",
|
||||
origin: "http://httpbin.example.com",
|
||||
method: "Options",
|
||||
controllerRequestMethod: "PUT",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "bad case with diff method",
|
||||
origin: "http://httpbin.example.com",
|
||||
method: "GET",
|
||||
controllerRequestMethod: "PUT",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &CorsConfig{}
|
||||
assert.Equalf(t, tt.want, c.isPreFlight(tt.origin, tt.method, tt.controllerRequestMethod), "isPreFlight(%v, %v, %v)", tt.origin, tt.method, tt.controllerRequestMethod)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsConfig_checkMethods(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowMethods []string
|
||||
requestMethod string
|
||||
wantMethods string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "default *",
|
||||
allowMethods: []string{"*"},
|
||||
requestMethod: "GET",
|
||||
wantMethods: defaultAllAllowMethods,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "normal allow case",
|
||||
allowMethods: []string{"GET", "PUT", "HEAD"},
|
||||
requestMethod: "get",
|
||||
wantMethods: "GET,PUT,HEAD",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "forbidden case",
|
||||
allowMethods: []string{"GET", "PUT", "HEAD"},
|
||||
requestMethod: "POST",
|
||||
wantMethods: "",
|
||||
wantOk: false,
|
||||
},
|
||||
|
||||
{
|
||||
name: "blank method",
|
||||
allowMethods: []string{"GET", "PUT", "HEAD"},
|
||||
requestMethod: "",
|
||||
wantMethods: "",
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &CorsConfig{
|
||||
allowMethods: tt.allowMethods,
|
||||
}
|
||||
allowMethods, allowOk := c.checkMethods(tt.requestMethod)
|
||||
assert.Equalf(t, tt.wantMethods, allowMethods, "checkMethods(%v)", tt.requestMethod)
|
||||
assert.Equalf(t, tt.wantOk, allowOk, "checkMethods(%v)", tt.requestMethod)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsConfig_checkHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowHeaders []string
|
||||
requestHeaders string
|
||||
wantHeaders string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "not pre-flight",
|
||||
allowHeaders: []string{"Content-Type", "Authorization"},
|
||||
requestHeaders: "",
|
||||
wantHeaders: "Content-Type,Authorization",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "blank allowheaders case 1",
|
||||
allowHeaders: []string{},
|
||||
requestHeaders: "",
|
||||
wantHeaders: "",
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "blank allowheaders case 2",
|
||||
requestHeaders: "Authorization",
|
||||
wantHeaders: "",
|
||||
wantOk: false,
|
||||
},
|
||||
|
||||
{
|
||||
name: "allowheaders *",
|
||||
allowHeaders: []string{"*"},
|
||||
requestHeaders: "Content-Type,Authorization",
|
||||
wantHeaders: "Content-Type,Authorization",
|
||||
wantOk: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "allowheader values 1",
|
||||
allowHeaders: []string{"Content-Type", "Authorization"},
|
||||
requestHeaders: "Content-Type,Authorization",
|
||||
wantHeaders: "Content-Type,Authorization",
|
||||
wantOk: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "allowheader values 2",
|
||||
allowHeaders: []string{"Content-Type", "Authorization"},
|
||||
requestHeaders: "",
|
||||
wantHeaders: "Content-Type,Authorization",
|
||||
wantOk: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &CorsConfig{
|
||||
allowHeaders: tt.allowHeaders,
|
||||
}
|
||||
allowHeaders, allowOk := c.checkHeaders(tt.requestHeaders)
|
||||
assert.Equalf(t, tt.wantHeaders, allowHeaders, "checkHeaders(%v)", tt.requestHeaders)
|
||||
assert.Equalf(t, tt.wantOk, allowOk, "checkHeaders(%v)", tt.requestHeaders)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsConfig_checkOrigin(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowOrigins []string
|
||||
allowOriginPatterns []OriginPattern
|
||||
origin string
|
||||
wantOrigin string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "allowOrigins *",
|
||||
allowOrigins: []string{defaultMatchAll},
|
||||
allowOriginPatterns: []OriginPattern{},
|
||||
origin: "http://Httpbin.Example.COM",
|
||||
wantOrigin: "http://Httpbin.Example.COM",
|
||||
wantOk: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "allowOrigins exact match case 1",
|
||||
allowOrigins: []string{"http://httpbin.example.com"},
|
||||
allowOriginPatterns: []OriginPattern{},
|
||||
origin: "http://HTTPBin.EXample.COM",
|
||||
wantOrigin: "http://HTTPBin.EXample.COM",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "allowOrigins exact match case 2",
|
||||
allowOrigins: []string{"https://httpbin.example.com"},
|
||||
allowOriginPatterns: []OriginPattern{},
|
||||
origin: "http://HTTPBin.EXample.COM",
|
||||
wantOrigin: "",
|
||||
wantOk: false,
|
||||
},
|
||||
|
||||
{
|
||||
name: "OriginPattern pattern match with *",
|
||||
allowOrigins: []string{},
|
||||
allowOriginPatterns: []OriginPattern{
|
||||
newOriginPatternFromString("*"),
|
||||
},
|
||||
origin: "http://HTTPBin.EXample.COM",
|
||||
wantOrigin: "http://HTTPBin.EXample.COM",
|
||||
wantOk: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "OriginPattern pattern match case with any port",
|
||||
allowOrigins: []string{},
|
||||
allowOriginPatterns: []OriginPattern{
|
||||
newOriginPatternFromString("http://*.example.com:[*]"),
|
||||
},
|
||||
origin: "http://HTTPBin.EXample.COM",
|
||||
wantOrigin: "http://HTTPBin.EXample.COM",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "OriginPattern pattern match case with any port",
|
||||
allowOrigins: []string{},
|
||||
allowOriginPatterns: []OriginPattern{
|
||||
newOriginPatternFromString("http://*.example.com:[*]"),
|
||||
},
|
||||
origin: "http://HTTPBin.EXample.COM:10000",
|
||||
wantOrigin: "http://HTTPBin.EXample.COM:10000",
|
||||
wantOk: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "OriginPattern pattern match case with specail port 1",
|
||||
allowOrigins: []string{},
|
||||
allowOriginPatterns: []OriginPattern{
|
||||
newOriginPatternFromString("http://*.example.com:[8080,9090]"),
|
||||
},
|
||||
origin: "http://HTTPBin.EXample.COM:10000",
|
||||
wantOrigin: "",
|
||||
wantOk: false,
|
||||
},
|
||||
|
||||
{
|
||||
name: "OriginPattern pattern match case with specail port 2",
|
||||
allowOrigins: []string{},
|
||||
allowOriginPatterns: []OriginPattern{
|
||||
newOriginPatternFromString("http://*.example.com:[8080,9090]"),
|
||||
},
|
||||
origin: "http://HTTPBin.EXample.COM:9090",
|
||||
wantOrigin: "http://HTTPBin.EXample.COM:9090",
|
||||
wantOk: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "OriginPattern pattern match case with specail port 3",
|
||||
allowOrigins: []string{},
|
||||
allowOriginPatterns: []OriginPattern{
|
||||
newOriginPatternFromString("http://*.example.com:[8080,9090]"),
|
||||
},
|
||||
origin: "http://HTTPBin.EXample.org:9090",
|
||||
wantOrigin: "",
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &CorsConfig{
|
||||
allowOrigins: tt.allowOrigins,
|
||||
allowOriginPatterns: tt.allowOriginPatterns,
|
||||
}
|
||||
allowOrigin, allowOk := c.checkOrigin(tt.origin)
|
||||
assert.Equalf(t, tt.wantOrigin, allowOrigin, "checkOrigin(%v)", tt.origin)
|
||||
assert.Equalf(t, tt.wantOk, allowOk, "checkOrigin(%v)", tt.origin)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user