mirror of
https://github.com/alibaba/higress.git
synced 2026-05-08 04:17:27 +08:00
289 lines
9.2 KiB
Go
289 lines
9.2 KiB
Go
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
||
|
||
package main
|
||
|
||
import (
|
||
"testing"
|
||
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||
"github.com/higress-group/wasm-go/pkg/test"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
func TestOnHttpRequestHeaders(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
// 测试 HTTP/1.1 协议(应该直接通过)
|
||
t.Run("HTTP/1.1 protocol", func(t *testing.T) {
|
||
host, status := test.NewTestHost(nil)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 模拟 HTTP/1.1 请求
|
||
host.SetProperty([]string{"request", "protocol"}, []byte("HTTP/1.1"))
|
||
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":scheme", "http"},
|
||
{":authority", "example.com"},
|
||
{":path", "/test"},
|
||
{":method", "GET"},
|
||
})
|
||
|
||
require.Equal(t, types.ActionContinue, action)
|
||
require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
|
||
|
||
localResponse := host.GetLocalResponse()
|
||
require.Nil(t, localResponse, "HTTP/1.1 request should pass through")
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试 HTTP 协议(非 HTTPS,应该直接通过)
|
||
t.Run("HTTP scheme", func(t *testing.T) {
|
||
host, status := test.NewTestHost(nil)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 模拟 HTTP 请求
|
||
host.SetProperty([]string{"request", "protocol"}, []byte("HTTP/2"))
|
||
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/test"},
|
||
{":method", "GET"},
|
||
})
|
||
|
||
require.Equal(t, types.ActionContinue, action)
|
||
require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
|
||
|
||
localResponse := host.GetLocalResponse()
|
||
require.Nil(t, localResponse, "HTTP request should pass through")
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试 gRPC 请求(应该直接通过)
|
||
t.Run("gRPC request", func(t *testing.T) {
|
||
host, status := test.NewTestHost(nil)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 模拟 gRPC 请求
|
||
host.SetProperty([]string{"request", "protocol"}, []byte("HTTP/2"))
|
||
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":scheme", "https"},
|
||
{":authority", "example.com"},
|
||
{":path", "/test"},
|
||
{":method", "POST"},
|
||
{"content-type", "application/grpc"},
|
||
})
|
||
|
||
require.Equal(t, types.ActionContinue, action)
|
||
require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
|
||
|
||
localResponse := host.GetLocalResponse()
|
||
require.Nil(t, localResponse, "gRPC request should pass through")
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试 SNI 和 Host 匹配的情况(应该通过)
|
||
t.Run("SNI matches Host", func(t *testing.T) {
|
||
host, status := test.NewTestHost(nil)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 模拟 HTTPS 请求,SNI 和 Host 匹配
|
||
host.SetProperty([]string{"request", "protocol"}, []byte("HTTP/2"))
|
||
host.SetProperty([]string{"connection", "requested_server_name"}, []byte("example.com"))
|
||
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":scheme", "https"},
|
||
{":authority", "example.com"},
|
||
{":path", "/test"},
|
||
{":method", "GET"},
|
||
})
|
||
|
||
require.Equal(t, types.ActionContinue, action)
|
||
require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
|
||
|
||
localResponse := host.GetLocalResponse()
|
||
require.Nil(t, localResponse, "Matching SNI and Host should pass through")
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试 SNI 和 Host 不匹配的情况(非通配符,应该被阻止)
|
||
t.Run("SNI mismatches Host non-wildcard", func(t *testing.T) {
|
||
host, status := test.NewTestHost(nil)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 模拟 HTTPS 请求,SNI 和 Host 不匹配
|
||
host.SetProperty([]string{"request", "protocol"}, []byte("HTTP/2"))
|
||
host.SetProperty([]string{"connection", "requested_server_name"}, []byte("evil.com"))
|
||
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":scheme", "https"},
|
||
{":authority", "example.com"},
|
||
{":path", "/test"},
|
||
{":method", "GET"},
|
||
{"content-type", "text/plain"},
|
||
})
|
||
|
||
require.Equal(t, types.ActionPause, action)
|
||
require.Equal(t, types.ActionPause, host.GetHttpStreamAction())
|
||
|
||
localResponse := host.GetLocalResponse()
|
||
require.NotNil(t, localResponse)
|
||
require.Equal(t, uint32(421), localResponse.StatusCode) // 421 Misdirected Request
|
||
require.Equal(t, "Misdirected Request", string(localResponse.Data))
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试通配符 SNI 匹配的情况(应该通过)
|
||
t.Run("Wildcard SNI matches Host", func(t *testing.T) {
|
||
host, status := test.NewTestHost(nil)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 模拟 HTTPS 请求,通配符 SNI 匹配 Host
|
||
host.SetProperty([]string{"request", "protocol"}, []byte("HTTP/2"))
|
||
host.SetProperty([]string{"connection", "requested_server_name"}, []byte("*.example.com"))
|
||
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":scheme", "https"},
|
||
{":authority", "sub.example.com"},
|
||
{":path", "/test"},
|
||
{":method", "GET"},
|
||
{"content-type", "text/plain"},
|
||
})
|
||
|
||
require.Equal(t, types.ActionContinue, action)
|
||
require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
|
||
|
||
localResponse := host.GetLocalResponse()
|
||
require.Nil(t, localResponse, "Wildcard SNI matching Host should pass through")
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试通配符 SNI 不匹配的情况(应该被阻止)
|
||
t.Run("Wildcard SNI mismatches Host", func(t *testing.T) {
|
||
host, status := test.NewTestHost(nil)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 模拟 HTTPS 请求,通配符 SNI 不匹配 Host
|
||
host.SetProperty([]string{"request", "protocol"}, []byte("HTTP/2"))
|
||
host.SetProperty([]string{"connection", "requested_server_name"}, []byte("*.example.com"))
|
||
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":scheme", "https"},
|
||
{":authority", "other.com"},
|
||
{":path", "/test"},
|
||
{":method", "GET"},
|
||
{"content-type", "text/plain"},
|
||
})
|
||
|
||
require.Equal(t, types.ActionPause, action)
|
||
require.Equal(t, types.ActionPause, host.GetHttpStreamAction())
|
||
|
||
localResponse := host.GetLocalResponse()
|
||
require.NotNil(t, localResponse)
|
||
require.Equal(t, uint32(421), localResponse.StatusCode) // 421 Misdirected Request
|
||
require.Equal(t, "Misdirected Request", string(localResponse.Data))
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试带端口的 Host(应该正确处理)
|
||
t.Run("Host with port", func(t *testing.T) {
|
||
host, status := test.NewTestHost(nil)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 模拟 HTTPS 请求,Host 带端口
|
||
host.SetProperty([]string{"request", "protocol"}, []byte("HTTP/2"))
|
||
host.SetProperty([]string{"connection", "requested_server_name"}, []byte("example.com"))
|
||
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":scheme", "https"},
|
||
{":authority", "example.com:443"},
|
||
{":path", "/test"},
|
||
{":method", "GET"},
|
||
{"content-type", "text/plain"},
|
||
})
|
||
|
||
require.Equal(t, types.ActionContinue, action)
|
||
require.Equal(t, types.ActionContinue, host.GetHttpStreamAction())
|
||
|
||
localResponse := host.GetLocalResponse()
|
||
require.Nil(t, localResponse, "Host with port should be handled correctly")
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestStripPortFromHost(t *testing.T) {
|
||
// 测试 stripPortFromHost 函数
|
||
t.Run("host without port", func(t *testing.T) {
|
||
result := stripPortFromHost("example.com")
|
||
require.Equal(t, "example.com", result)
|
||
})
|
||
|
||
t.Run("host with port", func(t *testing.T) {
|
||
result := stripPortFromHost("example.com:8080")
|
||
require.Equal(t, "example.com", result)
|
||
})
|
||
|
||
t.Run("host with multiple colons", func(t *testing.T) {
|
||
result := stripPortFromHost("example.com:8080:9090")
|
||
require.Equal(t, "example.com:8080", result)
|
||
})
|
||
|
||
t.Run("IPv6 host without port", func(t *testing.T) {
|
||
result := stripPortFromHost("[2001:db8::1]")
|
||
require.Equal(t, "[2001:db8::1]", result)
|
||
})
|
||
|
||
t.Run("IPv6 host with port", func(t *testing.T) {
|
||
result := stripPortFromHost("[2001:db8::1]:443")
|
||
require.Equal(t, "[2001:db8::1]", result)
|
||
})
|
||
|
||
t.Run("IPv6 host with port and multiple colons", func(t *testing.T) {
|
||
result := stripPortFromHost("[2001:db8::1]:443:8080")
|
||
require.Equal(t, "[2001:db8::1]:443", result)
|
||
})
|
||
|
||
t.Run("empty host", func(t *testing.T) {
|
||
result := stripPortFromHost("")
|
||
require.Equal(t, "", result)
|
||
})
|
||
|
||
t.Run("host with colon at end", func(t *testing.T) {
|
||
result := stripPortFromHost("example.com:")
|
||
require.Equal(t, "example.com", result)
|
||
})
|
||
|
||
t.Run("IPv6 host with colon at end", func(t *testing.T) {
|
||
result := stripPortFromHost("[2001:db8::1]:")
|
||
require.Equal(t, "[2001:db8::1]", result)
|
||
})
|
||
}
|