Files
higress/plugins/wasm-go/extensions/request-block/main_test.go

563 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"encoding/json"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// testConfig is the shared plugin configuration used by most tests in this
// file. It enables every rule family (keyword URLs, exact URLs, regexp URLs,
// headers and bodies) with case_sensitive=false and a 403 "Access denied"
// response.
var testConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]interface{}{
		"blocked_code":      403,
		"blocked_message":   "Access denied",
		"case_sensitive":    false,
		"block_urls":        []string{"blocked", "forbidden"},
		"block_exact_urls":  []string{"/exact-block", "/admin"},
		"block_regexp_urls": []string{`/api/v\d+/blocked`},
		"block_headers":     []string{"blocked-header", "malicious"},
		"block_bodies":      []string{"blocked-content", "spam"},
	})
	if err != nil {
		// Fail loudly at package init instead of handing a nil/empty config
		// to every test, which would surface as confusing downstream failures.
		panic(err)
	}
	return data
}()
// TestParseConfig verifies that the shared testConfig is parsed into the
// expected RequestBlockConfig field values.
func TestParseConfig(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		host, status := test.NewTestHost(testConfig)
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)

		matched, err := host.GetMatchConfig()
		require.NoError(t, err)
		require.NotNil(t, matched)

		cfg := matched.(*RequestBlockConfig)
		require.Equal(t, uint32(403), cfg.blockedCode)
		require.Equal(t, "Access denied", cfg.blockedMessage)
		require.False(t, cfg.caseSensitive)

		// Every configured rule must be present after parsing.
		for _, keyword := range []string{"blocked", "forbidden"} {
			require.Contains(t, cfg.blockUrls, keyword)
		}
		for _, exact := range []string{"/exact-block", "/admin"} {
			require.Contains(t, cfg.blockExactUrls, exact)
		}
		for _, header := range []string{"blocked-header", "malicious"} {
			require.Contains(t, cfg.blockHeaders, header)
		}
		for _, body := range []string{"blocked-content", "spam"} {
			require.Contains(t, cfg.blockBodies, body)
		}
	})
}
// TestBlockUrlByKeyword checks that a path containing a configured keyword
// ("blocked") is rejected with the configured status code and message.
func TestBlockUrlByKeyword(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		host, status := test.NewTestHost(testConfig)
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)

		reqHeaders := [][2]string{
			{":authority", "test.com"},
			{":path", "/api/blocked/endpoint"},
		}
		require.Equal(t, types.ActionContinue, host.CallOnHttpRequestHeaders(reqHeaders))

		resp := host.GetLocalResponse()
		require.NotNil(t, resp)
		require.Equal(t, uint32(403), resp.StatusCode)
		require.Equal(t, "Access denied", string(resp.Data))

		host.CompleteHttp()
	})
}
// TestBlockUrlByExactMatch checks that a path equal to a configured
// block_exact_urls entry ("/exact-block") is rejected.
func TestBlockUrlByExactMatch(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		host, status := test.NewTestHost(testConfig)
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)

		reqHeaders := [][2]string{
			{":authority", "test.com"},
			{":path", "/exact-block"},
		}
		require.Equal(t, types.ActionContinue, host.CallOnHttpRequestHeaders(reqHeaders))

		resp := host.GetLocalResponse()
		require.NotNil(t, resp)
		require.Equal(t, uint32(403), resp.StatusCode)
		require.Equal(t, "Access denied", string(resp.Data))

		host.CompleteHttp()
	})
}
// TestBlockUrlByRegexp checks that a path matching the configured
// block_regexp_urls pattern (`/api/v\d+/blocked`) is rejected.
func TestBlockUrlByRegexp(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		host, status := test.NewTestHost(testConfig)
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)

		reqHeaders := [][2]string{
			{":authority", "test.com"},
			{":path", "/api/v1/blocked"},
		}
		require.Equal(t, types.ActionContinue, host.CallOnHttpRequestHeaders(reqHeaders))

		resp := host.GetLocalResponse()
		require.NotNil(t, resp)
		require.Equal(t, uint32(403), resp.StatusCode)
		require.Equal(t, "Access denied", string(resp.Data))

		host.CompleteHttp()
	})
}
// TestBlockByHeaders checks that a request carrying a configured header
// keyword ("blocked-header") is rejected even though its path is allowed.
func TestBlockByHeaders(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		host, status := test.NewTestHost(testConfig)
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)

		reqHeaders := [][2]string{
			{":authority", "test.com"},
			{":path", "/api/valid"},
			{"blocked-header", "some-value"},
		}
		require.Equal(t, types.ActionContinue, host.CallOnHttpRequestHeaders(reqHeaders))

		resp := host.GetLocalResponse()
		require.NotNil(t, resp)
		require.Equal(t, uint32(403), resp.StatusCode)
		require.Equal(t, "Access denied", string(resp.Data))

		host.CompleteHttp()
	})
}
// TestBlockByBody checks that a request body containing a configured body
// keyword ("blocked-content") is rejected in the request-body phase.
func TestBlockByBody(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Uses the shared testConfig, which also carries URL and header rules,
		// so the path and headers below are chosen not to trigger those.
		host, status := test.NewTestHost(testConfig)
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)

		// Headers first, to set up the HTTP context.
		action := host.CallOnHttpRequestHeaders([][2]string{
			{":authority", "test.com"},
			{":path", "/api/safe/endpoint"},
		})
		require.Equal(t, types.ActionContinue, action)
		// Guard: the header phase must not already have produced a local
		// response. Otherwise the assertions below would pass without ever
		// exercising the body rule.
		require.Nil(t, host.GetLocalResponse(), "headers phase should not block /api/safe/endpoint")

		// The body contains a configured body keyword and must be blocked.
		action = host.CallOnHttpRequestBody([]byte("This is blocked-content in the body"))
		require.Equal(t, types.ActionContinue, action)
		localResponse := host.GetLocalResponse()
		require.NotNil(t, localResponse)
		require.Equal(t, uint32(403), localResponse.StatusCode)
		require.Equal(t, "Access denied", string(localResponse.Data))
		host.CompleteHttp()
	})
}
// TestAllowValidRequest checks that a request matching none of the rules in
// testConfig passes through without a local response.
func TestAllowValidRequest(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		host, status := test.NewTestHost(testConfig)
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)

		reqHeaders := [][2]string{
			{":authority", "test.com"},
			{":path", "/api/valid/endpoint"},
			{"valid-header", "valid-value"},
		}
		require.Equal(t, types.ActionContinue, host.CallOnHttpRequestHeaders(reqHeaders))
		require.Nil(t, host.GetLocalResponse(), "Valid request should not be blocked")

		host.CompleteHttp()
	})
}
// TestCaseInsensitiveBlocking checks that, with case_sensitive=false in
// testConfig, an upper-case path still matches the lower-case keyword rule
// "blocked" and is rejected with the configured code and message.
func TestCaseInsensitiveBlocking(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		host, status := test.NewTestHost(testConfig)
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)
		action := host.CallOnHttpRequestHeaders([][2]string{
			{":authority", "test.com"},
			{":path", "/API/BLOCKED/ENDPOINT"}, // upper case must still be blocked
		})
		require.Equal(t, types.ActionContinue, action)
		localResponse := host.GetLocalResponse()
		require.NotNil(t, localResponse)
		require.Equal(t, uint32(403), localResponse.StatusCode)
		// Also pin the response body, as every other block test in this file does.
		require.Equal(t, "Access denied", string(localResponse.Data))
		host.CompleteHttp()
	})
}
// TestCustomBlockedCode checks that a non-default blocked_code and
// blocked_message are used in the local response for a keyword match.
func TestCustomBlockedCode(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Marshalling this literal map cannot fail, so the error is ignored.
		raw, _ := json.Marshal(map[string]interface{}{
			"blocked_code":    429,
			"blocked_message": "Too many requests",
			"case_sensitive":  false,
			"block_urls":      []string{"rate-limit"},
		})
		customConfig := json.RawMessage(raw)

		host, status := test.NewTestHost(customConfig)
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)

		reqHeaders := [][2]string{
			{":authority", "test.com"},
			{":path", "/api/rate-limit/test"},
		}
		require.Equal(t, types.ActionContinue, host.CallOnHttpRequestHeaders(reqHeaders))

		resp := host.GetLocalResponse()
		require.NotNil(t, resp)
		require.Equal(t, uint32(429), resp.StatusCode)
		require.Equal(t, "Too many requests", string(resp.Data))

		host.CompleteHttp()
	})
}
// TestParseConfigEdgeCases covers edge cases in configuration parsing.
func TestParseConfigEdgeCases(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		// An invalid blocked_code must fall back to the default 403.
		t.Run("invalid blocked_code", func(t *testing.T) {
			invalidCodeConfig := func() json.RawMessage {
				data, _ := json.Marshal(map[string]interface{}{
					"blocked_code":    999, // invalid HTTP status code
					"blocked_message": "Invalid code",
					"block_urls":      []string{"test"},
				})
				return data
			}()
			host, status := test.NewTestHost(invalidCodeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
			blockConfig := config.(*RequestBlockConfig)
			require.Equal(t, uint32(403), blockConfig.blockedCode) // default applied
		})
		// With case_sensitive=true the rules are stored with their original case.
		t.Run("case sensitive true", func(t *testing.T) {
			caseSensitiveConfig := func() json.RawMessage {
				data, _ := json.Marshal(map[string]interface{}{
					"case_sensitive": true,
					"block_urls":     []string{"BLOCKED"},
					"block_headers":  []string{"BLOCKED-HEADER"},
					"block_bodies":   []string{"BLOCKED-CONTENT"},
				})
				return data
			}()
			host, status := test.NewTestHost(caseSensitiveConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
			blockConfig := config.(*RequestBlockConfig)
			require.True(t, blockConfig.caseSensitive)
			require.Contains(t, blockConfig.blockUrls, "BLOCKED") // upper case kept
			require.Contains(t, blockConfig.blockHeaders, "BLOCKED-HEADER")
			require.Contains(t, blockConfig.blockBodies, "BLOCKED-CONTENT")
		})
		// Empty strings in any rule list must be filtered out.
		t.Run("empty strings handling", func(t *testing.T) {
			emptyStringsConfig := func() json.RawMessage {
				data, _ := json.Marshal(map[string]interface{}{
					"block_urls":        []string{"valid", ""}, // contains an empty string
					"block_exact_urls":  []string{"", "valid"}, // contains an empty string
					"block_regexp_urls": []string{"", "valid"}, // contains an empty string
					"block_headers":     []string{"", "valid"}, // contains an empty string
					"block_bodies":      []string{"valid", ""}, // contains an empty string
				})
				return data
			}()
			host, status := test.NewTestHost(emptyStringsConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
			blockConfig := config.(*RequestBlockConfig)
			// Check every list the config seeded with "", not just the URL
			// lists. An empty header/body rule would match every request.
			require.Contains(t, blockConfig.blockUrls, "valid")
			require.NotContains(t, blockConfig.blockUrls, "")
			require.Contains(t, blockConfig.blockExactUrls, "valid")
			require.NotContains(t, blockConfig.blockExactUrls, "")
			require.Contains(t, blockConfig.blockHeaders, "valid")
			require.NotContains(t, blockConfig.blockHeaders, "")
			require.Contains(t, blockConfig.blockBodies, "valid")
			require.NotContains(t, blockConfig.blockBodies, "")
		})
		// A config without any block rule must make plugin start fail.
		t.Run("no block rules", func(t *testing.T) {
			noRulesConfig := func() json.RawMessage {
				data, _ := json.Marshal(map[string]interface{}{
					"blocked_message": "No rules",
					// intentionally no block rules
				})
				return data
			}()
			host, status := test.NewTestHost(noRulesConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusFailed, status)
		})
	})
}
// TestOnHttpRequestHeadersErrorHandling covers less common paths through
// onHttpRequestHeaders: a request without a :path header, and configs that
// carry only header rules or only body rules. In every case the request
// must pass the header phase without a local response.
func TestOnHttpRequestHeadersErrorHandling(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// A request without :path is meant to exercise the path-lookup failure
		// branch. The request must be let through.
		t.Run("get path failed", func(t *testing.T) {
			host, status := test.NewTestHost(testConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "test.com"},
				// :path header intentionally omitted
			})
			require.Equal(t, types.ActionContinue, action)
			require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
			localResponse := host.GetLocalResponse()
			require.Nil(t, localResponse)
			host.CompleteHttp()
		})
		// Header-only config.
		// NOTE(review): despite its name, this subtest cannot force header
		// retrieval to fail. It checks that a request without the blocked
		// header passes through unblocked.
		t.Run("get headers failed", func(t *testing.T) {
			headerOnlyConfig := func() json.RawMessage {
				data, _ := json.Marshal(map[string]interface{}{
					"blocked_code":    403,
					"blocked_message": "Header blocked",
					"block_headers":   []string{"blocked-header"},
				})
				return data
			}()
			host, status := test.NewTestHost(headerOnlyConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "test.com"},
				{":path", "/api/test"},
			})
			require.Equal(t, types.ActionContinue, action)
			require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
			require.Nil(t, host.GetLocalResponse(), "request without blocked header must not be blocked")
			host.CompleteHttp()
		})
		// Body-only config: the header phase has nothing to match and must not
		// block.
		// NOTE(review): an earlier comment claimed this path calls
		// DontReadRequestBody. With body rules configured the body is
		// presumably still needed, so confirm against main.go. Neither is
		// observable from this test.
		t.Run("only block bodies", func(t *testing.T) {
			bodyOnlyConfig := func() json.RawMessage {
				data, _ := json.Marshal(map[string]interface{}{
					"blocked_code":    403,
					"blocked_message": "Body blocked",
					"block_bodies":    []string{"blocked-content"},
				})
				return data
			}()
			host, status := test.NewTestHost(bodyOnlyConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "test.com"},
				{":path", "/api/test"},
			})
			require.Equal(t, types.ActionContinue, action)
			require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
			require.Nil(t, host.GetLocalResponse(), "body-only rules must not block in the header phase")
			host.CompleteHttp()
		})
	})
}
// TestOnHttpRequestBodyCaseSensitive checks how the case_sensitive option
// affects body matching in onHttpRequestBody.
func TestOnHttpRequestBodyCaseSensitive(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Shared by both case_sensitive=true subtests. The bytes are only read.
		caseSensitiveConfig := func() json.RawMessage {
			data, _ := json.Marshal(map[string]interface{}{
				"case_sensitive":  true,
				"blocked_code":    403,
				"blocked_message": "Body blocked",
				"block_bodies":    []string{"BLOCKED"},
			})
			return data
		}()

		// case_sensitive=true: an exact-case occurrence of the rule is blocked.
		t.Run("case sensitive true", func(t *testing.T) {
			host, status := test.NewTestHost(caseSensitiveConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Run the header phase first.
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "test.com"},
				{":path", "/api/test"},
			})
			require.Equal(t, types.ActionContinue, action)
			// Upper-case content matching the rule must be blocked.
			action = host.CallOnHttpRequestBody([]byte("This contains BLOCKED content"))
			require.Equal(t, types.ActionContinue, action)
			localResponse := host.GetLocalResponse()
			require.NotNil(t, localResponse)
			require.Equal(t, uint32(403), localResponse.StatusCode)
			host.CompleteHttp()
		})
		// case_sensitive=true: a different-case occurrence must NOT be blocked.
		// Without this check the subtest above would also pass if matching
		// were case-insensitive.
		t.Run("case sensitive true ignores other case", func(t *testing.T) {
			host, status := test.NewTestHost(caseSensitiveConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "test.com"},
				{":path", "/api/test"},
			})
			require.Equal(t, types.ActionContinue, action)
			action = host.CallOnHttpRequestBody([]byte("This contains blocked content"))
			require.Equal(t, types.ActionContinue, action)
			require.Nil(t, host.GetLocalResponse(), "case-sensitive rule BLOCKED must not match lower-case body")
			host.CompleteHttp()
		})
		// case_sensitive=false: upper-case content matches the lower-case rule.
		t.Run("case sensitive false", func(t *testing.T) {
			caseInsensitiveConfig := func() json.RawMessage {
				data, _ := json.Marshal(map[string]interface{}{
					"case_sensitive": false,
					"block_bodies":   []string{"blocked"},
				})
				return data
			}()
			host, status := test.NewTestHost(caseInsensitiveConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Run the header phase first.
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "test.com"},
				{":path", "/api/test"},
			})
			require.Equal(t, types.ActionContinue, action)
			// Upper-case content must be blocked because case_sensitive is false.
			action = host.CallOnHttpRequestBody([]byte("This contains BLOCKED content"))
			require.Equal(t, types.ActionContinue, action)
			localResponse := host.GetLocalResponse()
			require.NotNil(t, localResponse)
			require.Equal(t, uint32(403), localResponse.StatusCode)
			host.CompleteHttp()
		})
	})
}
// TestBlockUrlByRegexpEdgeCases exercises block_regexp_urls with a
// multi-segment pattern that should block, and with a path that does not
// satisfy the pattern and should pass through.
func TestBlockUrlByRegexpEdgeCases(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// regexpOnlyConfig builds a case-sensitive config around one regexp
		// rule. A placeholder "dummy" keyword rule is included to satisfy the
		// plugin's config check. Marshalling this literal map cannot fail, so
		// the error is ignored.
		regexpOnlyConfig := func(pattern string) json.RawMessage {
			data, _ := json.Marshal(map[string]interface{}{
				"case_sensitive":    true,
				"blocked_code":      403,
				"blocked_message":   "Blocked by regexp",
				"block_urls":        []string{"dummy"},
				"block_regexp_urls": []string{pattern},
			})
			return data
		}

		// A path matching a pattern with several \d+ segments is blocked.
		t.Run("complex regexp", func(t *testing.T) {
			host, status := test.NewTestHost(regexpOnlyConfig(`/api/v\d+/users/\d+/posts`))
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			reqHeaders := [][2]string{
				{":authority", "test.com"},
				{":path", "/api/v2/users/123/posts"},
			}
			require.Equal(t, types.ActionContinue, host.CallOnHttpRequestHeaders(reqHeaders))

			resp := host.GetLocalResponse()
			require.NotNil(t, resp)
			require.Equal(t, uint32(403), resp.StatusCode)
			// The stream must still end in ActionContinue.
			require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
			host.CompleteHttp()
		})

		// "/api/blocked" lacks the v\d+ segment, so it must pass through.
		t.Run("non-matching regexp", func(t *testing.T) {
			host, status := test.NewTestHost(regexpOnlyConfig(`/api/v\d+/blocked`))
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			reqHeaders := [][2]string{
				{":authority", "test.com"},
				{":path", "/api/blocked"},
			}
			require.Equal(t, types.ActionContinue, host.CallOnHttpRequestHeaders(reqHeaders))

			require.Nil(t, host.GetLocalResponse())
			// The stream must still end in ActionContinue.
			require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
			host.CompleteHttp()
		})
	})
}