Files
higress/plugins/wasm-go/extensions/sni-misdirect/main_test.go

289 lines
9.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// TestOnHttpRequestHeaders drives the plugin's request-header hook through a
// fresh test host per case and checks whether the request is passed through
// (ActionContinue, no local response) or rejected with a 421 local response.
//
// The pass-through cases that exercise an early bypass (HTTP/1.x protocol,
// non-https scheme, gRPC content type) deliberately present an SNI that does
// NOT match the Host. That way each case can only pass if its specific bypass
// short-circuits the SNI check, rather than passing merely because no SNI was
// available.
func TestOnHttpRequestHeaders(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		cases := []struct {
			name     string
			protocol string      // value stored in the request.protocol property
			sni      string      // connection.requested_server_name; "" leaves it unset
			headers  [][2]string // request headers handed to OnHttpRequestHeaders
			blocked  bool        // true when a 421 Misdirected Request is expected
			msg      string      // failure message used by pass-through cases
		}{
			{
				// HTTP/1.x requests are passed through even over https with a
				// mismatched SNI.
				name:     "HTTP/1.1 protocol",
				protocol: "HTTP/1.1",
				sni:      "evil.com",
				headers: [][2]string{
					{":scheme", "https"},
					{":authority", "example.com"},
					{":path", "/test"},
					{":method", "GET"},
				},
				msg: "HTTP/1.1 request should pass through",
			},
			{
				// Plain-http (non-https) requests are passed through even with a
				// mismatched SNI.
				name:     "HTTP scheme",
				protocol: "HTTP/2",
				sni:      "evil.com",
				headers: [][2]string{
					{":scheme", "http"},
					{":authority", "example.com"},
					{":path", "/test"},
					{":method", "GET"},
				},
				msg: "HTTP request should pass through",
			},
			{
				// gRPC requests are passed through even with a mismatched SNI.
				name:     "gRPC request",
				protocol: "HTTP/2",
				sni:      "evil.com",
				headers: [][2]string{
					{":scheme", "https"},
					{":authority", "example.com"},
					{":path", "/test"},
					{":method", "POST"},
					{"content-type", "application/grpc"},
				},
				msg: "gRPC request should pass through",
			},
			{
				// HTTPS request whose SNI equals the Host: passed through.
				name:     "SNI matches Host",
				protocol: "HTTP/2",
				sni:      "example.com",
				headers: [][2]string{
					{":scheme", "https"},
					{":authority", "example.com"},
					{":path", "/test"},
					{":method", "GET"},
				},
				msg: "Matching SNI and Host should pass through",
			},
			{
				// HTTPS request whose (non-wildcard) SNI differs from the Host:
				// rejected.
				name:     "SNI mismatches Host non-wildcard",
				protocol: "HTTP/2",
				sni:      "evil.com",
				headers: [][2]string{
					{":scheme", "https"},
					{":authority", "example.com"},
					{":path", "/test"},
					{":method", "GET"},
					{"content-type", "text/plain"},
				},
				blocked: true,
			},
			{
				// Wildcard SNI covering the Host: passed through.
				name:     "Wildcard SNI matches Host",
				protocol: "HTTP/2",
				sni:      "*.example.com",
				headers: [][2]string{
					{":scheme", "https"},
					{":authority", "sub.example.com"},
					{":path", "/test"},
					{":method", "GET"},
					{"content-type", "text/plain"},
				},
				msg: "Wildcard SNI matching Host should pass through",
			},
			{
				// Wildcard SNI not covering the Host: rejected.
				name:     "Wildcard SNI mismatches Host",
				protocol: "HTTP/2",
				sni:      "*.example.com",
				headers: [][2]string{
					{":scheme", "https"},
					{":authority", "other.com"},
					{":path", "/test"},
					{":method", "GET"},
					{"content-type", "text/plain"},
				},
				blocked: true,
			},
			{
				// The port in :authority must be ignored when comparing with SNI.
				name:     "Host with port",
				protocol: "HTTP/2",
				sni:      "example.com",
				headers: [][2]string{
					{":scheme", "https"},
					{":authority", "example.com:443"},
					{":path", "/test"},
					{":method", "GET"},
					{"content-type", "text/plain"},
				},
				msg: "Host with port should be handled correctly",
			},
		}

		for _, tc := range cases {
			t.Run(tc.name, func(t *testing.T) {
				host, status := test.NewTestHost(nil)
				defer host.Reset()
				require.Equal(t, types.OnPluginStartStatusOK, status)

				host.SetProperty([]string{"request", "protocol"}, []byte(tc.protocol))
				if tc.sni != "" {
					host.SetProperty([]string{"connection", "requested_server_name"}, []byte(tc.sni))
				}

				action := host.CallOnHttpRequestHeaders(tc.headers)

				if tc.blocked {
					require.Equal(t, types.ActionPause, action)
					require.Equal(t, types.ActionPause, host.GetHttpStreamAction())
					localResponse := host.GetLocalResponse()
					require.NotNil(t, localResponse)
					require.Equal(t, uint32(421), localResponse.StatusCode) // 421 Misdirected Request
					require.Equal(t, "Misdirected Request", string(localResponse.Data))
				} else {
					require.Equal(t, types.ActionContinue, action)
					require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
					require.Nil(t, host.GetLocalResponse(), tc.msg)
				}

				host.CompleteHttp()
			})
		}
	})
}
// TestStripPortFromHost pins the behavior of stripPortFromHost with a table of
// inputs and expected outputs. As encoded by the expectations below, the text
// after the final colon (and that colon) is dropped, while a bracketed IPv6
// literal with no trailing ":port" is left unchanged.
func TestStripPortFromHost(t *testing.T) {
	cases := []struct {
		name string
		in   string
		want string
	}{
		{name: "host without port", in: "example.com", want: "example.com"},
		{name: "host with port", in: "example.com:8080", want: "example.com"},
		{name: "host with multiple colons", in: "example.com:8080:9090", want: "example.com:8080"},
		{name: "IPv6 host without port", in: "[2001:db8::1]", want: "[2001:db8::1]"},
		{name: "IPv6 host with port", in: "[2001:db8::1]:443", want: "[2001:db8::1]"},
		{name: "IPv6 host with port and multiple colons", in: "[2001:db8::1]:443:8080", want: "[2001:db8::1]:443"},
		{name: "empty host", in: "", want: ""},
		{name: "host with colon at end", in: "example.com:", want: "example.com"},
		{name: "IPv6 host with colon at end", in: "[2001:db8::1]:", want: "[2001:db8::1]"},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got := stripPortFromHost(c.in)
			require.Equal(t, c.want, got)
		})
	}
}