Files
higress/plugins/wasm-go/extensions/sni-misdirect/main.go

96 lines
2.8 KiB
Go

package main
import (
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
// main registers the "sni-misdirect" plugin with the Higress wasm-go wrapper
// and routes request-header processing to onHttpRequestHeaders.
func main() {
	const pluginName = "sni-misdirect"
	onHeaders := wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders)
	wrapper.SetCtx(pluginName, onHeaders)
}
// Config is the configuration value handed to onHttpRequestHeaders.
// The plugin has no tunable options, so the struct has no fields.
type Config struct{}
// onHttpRequestHeaders answers 421 Misdirected Request when the request's
// :authority (port stripped) is not covered by the SNI of the TLS connection
// it arrived on.
//
// The request is passed through unchanged (ActionContinue) when:
//   - the protocol is HTTP/1.x,
//   - the scheme is not https,
//   - the content-type starts with "application/grpc",
//   - a required property/header cannot be read (the failure is logged), or
//   - the host matches the SNI (see sniMatchesHost).
//
// Otherwise a local 421 reply is sent and ActionPause is returned. The
// response detail distinguishes wildcard from non-wildcard SNI mismatches.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log wrapper.Log) types.Action {
	// No need to check HTTP/1.0 and HTTP/1.1.
	protocol, err := proxywasm.GetProperty([]string{"request", "protocol"})
	if err != nil {
		log.Errorf("failed to get request protocol: %v", err)
		return types.ActionContinue
	}
	if strings.HasPrefix(string(protocol), "HTTP/1") {
		return types.ActionContinue
	}

	// Only TLS (https) traffic has an SNI to compare against.
	if ctx.Scheme() != "https" {
		return types.ActionContinue
	}

	// gRPC is exempt. A missing content-type header is reported by the host as
	// types.ErrorStatusNotFound; that is normal (e.g. for GET requests) and must
	// NOT bypass the check, so only other errors abort. The SDK returns the
	// sentinel value itself, so a direct comparison is sufficient.
	contentType, err := proxywasm.GetHttpRequestHeader("content-type")
	if err != nil && err != types.ErrorStatusNotFound {
		log.Errorf("failed to get request content-type: %v", err)
		return types.ActionContinue
	}
	if strings.HasPrefix(contentType, "application/grpc") {
		return types.ActionContinue
	}

	sni, err := proxywasm.GetProperty([]string{"connection", "requested_server_name"})
	if err != nil {
		log.Errorf("failed to get requested_server_name: %v", err)
		return types.ActionContinue
	}

	host, err := proxywasm.GetHttpRequestHeader(":authority")
	if err != nil {
		log.Errorf("failed to get request authority: %v", err)
		return types.ActionContinue
	}
	host = stripPortFromHost(host)

	serverName := string(sni)
	if sniMatchesHost(serverName, host) {
		return types.ActionContinue
	}

	detail := "sni-misdirect.mismatched.non_wildcard"
	if strings.HasPrefix(serverName, "*.") {
		detail = "sni-misdirect.mismatched.wildcard"
	}
	if err := proxywasm.SendHttpResponseWithDetail(http.StatusMisdirectedRequest, detail, nil, []byte("Misdirected Request"), -1); err != nil {
		log.Errorf("failed to send misdirected response: %v", err)
	}
	return types.ActionPause
}

// sniMatchesHost reports whether host (already stripped of any port) is
// covered by serverName. It matches in either of two ways:
//   - serverName equals host, ignoring ASCII case (DNS names are
//     case-insensitive, RFC 4343).
//   - serverName has the form "*.suffix", and host ends with ".suffix",
//     ignoring case, with at least one character before it.
//
// The suffix is anchored at the end of host, so "a.example.com.attacker.net"
// does not match "*.example.com". Any number of labels may precede the suffix.
func sniMatchesHost(serverName, host string) bool {
	if strings.EqualFold(serverName, host) {
		return true
	}
	if !strings.HasPrefix(serverName, "*.") {
		return false
	}
	suffix := strings.ToLower(serverName[1:]) // keeps the leading '.'
	lowerHost := strings.ToLower(host)
	return len(lowerHost) > len(suffix) && strings.HasSuffix(lowerHost, suffix)
}
// stripPortFromHost removes a trailing ":port" from an authority/Host value.
//
// A ':' that appears before the last ']' belongs to a bracketed IPv6 literal
// (RFC 3986 section 3.2.2), so the value is returned unchanged in that case,
// e.g. "[::1]" stays "[::1]" while "[::1]:8080" becomes "[::1]".
//
// The port is not validated: everything from the last ':' onward is dropped.
// This includes a bare trailing ':' ("example.com:" -> "example.com").
// Values without any ':' are returned as-is.
func stripPortFromHost(requestHost string) string {
	colon := strings.LastIndex(requestHost, ":")
	if colon < 0 {
		return requestHost
	}
	if bracket := strings.LastIndex(requestHost, "]"); bracket > colon {
		// The last ':' sits inside "[...]": there is no port to strip.
		return requestHost
	}
	return requestHost[:colon]
}