mirror of
https://github.com/alibaba/higress.git
synced 2026-03-05 09:00:47 +08:00
433 lines
12 KiB
Go
433 lines
12 KiB
Go
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||
//
|
||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||
// you may not use this file except in compliance with the License.
|
||
// You may obtain a copy of the License at
|
||
//
|
||
// http://www.apache.org/licenses/LICENSE-2.0
|
||
//
|
||
// Unless required by applicable law or agreed to in writing, software
|
||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
// See the License for the specific language governing permissions and
|
||
// limitations under the License.
|
||
|
||
package main
|
||
|
||
import (
|
||
"encoding/json"
|
||
"testing"
|
||
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||
"github.com/higress-group/wasm-go/pkg/test"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
// 测试配置:基本 CORS 配置
|
||
var basicCorsConfig = func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"allow_origins": []string{
|
||
"http://example.com",
|
||
"https://example.com",
|
||
},
|
||
"allow_methods": []string{
|
||
"GET",
|
||
"POST",
|
||
"OPTIONS",
|
||
},
|
||
"allow_headers": []string{
|
||
"Content-Type",
|
||
"Authorization",
|
||
},
|
||
"expose_headers": []string{
|
||
"X-Custom-Header",
|
||
},
|
||
"allow_credentials": false,
|
||
"max_age": 3600,
|
||
})
|
||
return data
|
||
}()
|
||
|
||
// 测试配置:允许所有 Origin 的配置
|
||
var allowAllOriginsConfig = func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"allow_origins": []string{
|
||
"*",
|
||
},
|
||
"allow_methods": []string{
|
||
"*",
|
||
},
|
||
"allow_headers": []string{
|
||
"*",
|
||
},
|
||
"expose_headers": []string{
|
||
"*",
|
||
},
|
||
"allow_credentials": false,
|
||
"max_age": 7200,
|
||
})
|
||
return data
|
||
}()
|
||
|
||
// 测试配置:带模式匹配的配置
|
||
var patternMatchConfig = func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"allow_origin_patterns": []string{
|
||
"http://*.example.com",
|
||
"http://*.example.org:[8080,9090]",
|
||
},
|
||
"allow_methods": []string{
|
||
"GET",
|
||
"POST",
|
||
"PUT",
|
||
"DELETE",
|
||
},
|
||
"allow_headers": []string{
|
||
"Content-Type",
|
||
"Token",
|
||
"Authorization",
|
||
},
|
||
"expose_headers": []string{
|
||
"X-Custom-Header",
|
||
"X-Env-UTM",
|
||
},
|
||
"allow_credentials": true,
|
||
"max_age": 1800,
|
||
})
|
||
return data
|
||
}()
|
||
|
||
// 测试配置:允许凭据的配置
|
||
var credentialsConfig = func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"allow_origin_patterns": []string{
|
||
"*",
|
||
},
|
||
"allow_methods": []string{
|
||
"GET",
|
||
"POST",
|
||
},
|
||
"allow_headers": []string{
|
||
"Content-Type",
|
||
"Authorization",
|
||
},
|
||
"expose_headers": []string{
|
||
"X-Custom-Header",
|
||
},
|
||
"allow_credentials": true,
|
||
"max_age": 86400,
|
||
})
|
||
return data
|
||
}()
|
||
|
||
// 测试配置:默认值配置
|
||
var defaultConfig = func() json.RawMessage {
|
||
data, _ := json.Marshal(map[string]interface{}{})
|
||
return data
|
||
}()
|
||
|
||
// TestParseConfig starts the plugin with each fixture configuration. It
// requires that start-up succeeds and that a non-nil match config is
// produced.
//
// NOTE(review): unlike the other tests in this file it uses test.RunGoTest
// rather than test.RunTest, presumably because GetMatchConfig is only
// available in Go mode; confirm before switching.
func TestParseConfig(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		fixtures := []struct {
			name   string
			config json.RawMessage
		}{
			// Basic CORS configuration with explicit origins.
			{name: "basic cors config", config: basicCorsConfig},
			// Configuration that allows all origins.
			{name: "allow all origins config", config: allowAllOriginsConfig},
			// Configuration using origin patterns.
			{name: "pattern match config", config: patternMatchConfig},
			// Configuration that allows credentials.
			{name: "credentials config", config: credentialsConfig},
			// Empty ("default values") configuration.
			{name: "default config", config: defaultConfig},
		}

		for _, fx := range fixtures {
			t.Run(fx.name, func(t *testing.T) {
				host, status := test.NewTestHost(fx.config)
				defer host.Reset()
				require.Equal(t, types.OnPluginStartStatusOK, status)

				config, err := host.GetMatchConfig()
				require.NoError(t, err)
				require.NotNil(t, config)
			})
		}
	})
}
|
||
|
||
// TestOnHttpRequestHeaders checks the action the plugin returns from the
// request-headers phase. It covers simple, preflight, disallowed-origin,
// wildcard-origin, pattern-matched-origin and non-CORS (no Origin header)
// requests.
func TestOnHttpRequestHeaders(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// buildHeaders returns the pseudo-headers shared by every case,
		// followed by any case-specific headers in the given order.
		buildHeaders := func(method string, extra ...[2]string) [][2]string {
			base := [][2]string{
				{":authority", "example.com"},
				{":path", "/api/test"},
				{":method", method},
			}
			return append(base, extra...)
		}

		scenarios := []struct {
			name    string
			config  json.RawMessage
			headers [][2]string
			want    types.Action
		}{
			{
				// A simple request from an allowed origin continues.
				name:    "simple cors request headers",
				config:  basicCorsConfig,
				headers: buildHeaders("GET", [2]string{"origin", "http://example.com"}),
				want:    types.ActionContinue,
			},
			{
				// A preflight request is expected to pause.
				name:   "preflight request headers",
				config: basicCorsConfig,
				headers: buildHeaders("OPTIONS",
					[2]string{"origin", "http://example.com"},
					[2]string{"access-control-request-method", "POST"},
					[2]string{"access-control-request-headers", "Content-Type, Authorization"},
				),
				want: types.ActionPause,
			},
			{
				// A request from an origin outside the allow list is expected
				// to pause.
				name:    "invalid origin request headers",
				config:  basicCorsConfig,
				headers: buildHeaders("GET", [2]string{"origin", "http://invalid.com"}),
				want:    types.ActionPause,
			},
			{
				// With a wildcard origin, an arbitrary origin continues.
				name:    "allow all origins request headers",
				config:  allowAllOriginsConfig,
				headers: buildHeaders("GET", [2]string{"origin", "http://any-domain.com"}),
				want:    types.ActionContinue,
			},
			{
				// An origin matching a configured pattern continues.
				name:    "pattern match request headers",
				config:  patternMatchConfig,
				headers: buildHeaders("GET", [2]string{"origin", "http://sub.example.com"}),
				want:    types.ActionContinue,
			},
			{
				// A request without an Origin header is not CORS and continues.
				name:    "non-cors request headers",
				config:  basicCorsConfig,
				headers: buildHeaders("GET"),
				want:    types.ActionContinue,
			},
		}

		for _, sc := range scenarios {
			t.Run(sc.name, func(t *testing.T) {
				host, status := test.NewTestHost(sc.config)
				defer host.Reset()
				require.Equal(t, types.OnPluginStartStatusOK, status)

				got := host.CallOnHttpRequestHeaders(sc.headers)
				require.Equal(t, sc.want, got)

				host.CompleteHttp()
			})
		}
	})
}
|
||
|
||
// TestOnHttpResponseHeaders checks which CORS headers the plugin adds in the
// response-headers phase, after a request has gone through the
// request-headers phase.
func TestOnHttpResponseHeaders(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// CORS response headers for an allowed simple request.
		t.Run("cors response headers", func(t *testing.T) {
			host, status := test.NewTestHost(basicCorsConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Process the request headers first. Asserting the request-phase
			// action makes a rejected request fail here instead of surfacing
			// later as a missing response header.
			requestAction := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/api/test"},
				{":method", "GET"},
				{"origin", "http://example.com"},
			})
			require.Equal(t, types.ActionContinue, requestAction)

			// Then process the response headers.
			action := host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			})
			require.Equal(t, types.ActionContinue, action)

			// The CORS response headers must have been added.
			responseHeaders := host.GetResponseHeaders()
			require.True(t, test.HasHeader(responseHeaders, "access-control-allow-origin"))
			require.True(t, test.HasHeader(responseHeaders, "access-control-expose-headers"))

			// For a simple request, Allow-Methods and Allow-Headers are not
			// added; they are only sent for preflight requests.
			require.False(t, test.HasHeader(responseHeaders, "access-control-allow-methods"))
			require.False(t, test.HasHeader(responseHeaders, "access-control-allow-headers"))

			host.CompleteHttp()
		})

		// A non-CORS request (no Origin header) gets no CORS response headers.
		t.Run("non-cors response headers", func(t *testing.T) {
			host, status := test.NewTestHost(basicCorsConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Process request headers without an Origin header.
			requestAction := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/api/test"},
				{":method", "GET"},
			})
			require.Equal(t, types.ActionContinue, requestAction)

			// Then process the response headers.
			action := host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			})
			require.Equal(t, types.ActionContinue, action)

			// No CORS response headers may be added.
			responseHeaders := host.GetResponseHeaders()
			require.False(t, test.HasHeader(responseHeaders, "access-control-allow-origin"))
			require.False(t, test.HasHeader(responseHeaders, "access-control-expose-headers"))

			host.CompleteHttp()
		})

		// Response headers when credentials are allowed.
		t.Run("credentials response headers", func(t *testing.T) {
			host, status := test.NewTestHost(credentialsConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Process the request headers first; the origin matches the "*"
			// origin pattern, so the request is expected to continue.
			requestAction := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/api/test"},
				{":method", "GET"},
				{"origin", "http://any-domain.com"},
			})
			require.Equal(t, types.ActionContinue, requestAction)

			// Then process the response headers.
			action := host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			})
			require.Equal(t, types.ActionContinue, action)

			// The credentials header must be present.
			responseHeaders := host.GetResponseHeaders()
			require.True(t, test.HasHeaderWithValue(responseHeaders, "access-control-allow-credentials", "true"))

			// With credentials allowed, browsers reject a wildcard
			// Access-Control-Allow-Origin, so the request origin must be
			// echoed back verbatim.
			require.True(t, test.HasHeaderWithValue(responseHeaders, "access-control-allow-origin", "http://any-domain.com"))

			host.CompleteHttp()
		})

		// Preflight request handling.
		t.Run("preflight response headers", func(t *testing.T) {
			host, status := test.NewTestHost(basicCorsConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Process the preflight request headers.
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/api/test"},
				{":method", "OPTIONS"},
				{"origin", "http://example.com"},
				{"access-control-request-method", "POST"},
				{"access-control-request-headers", "Content-Type, Authorization"},
			})

			// A preflight request is expected to pause.
			// NOTE(review): despite its name, this subtest inspects no
			// response headers and duplicates "preflight request headers" in
			// TestOnHttpRequestHeaders. Consider asserting the headers of
			// the preflight response (Allow-Methods, Allow-Headers, Max-Age).
			require.Equal(t, types.ActionPause, action)

			host.CompleteHttp()
		})
	})
}
|