fix: optimize host pattern matching and fix SSE newline bug (#2899)

This commit is contained in:
Jingze
2025-09-21 14:34:51 +08:00
committed by GitHub
parent de8a9c539b
commit f1345f9973
6 changed files with 34 additions and 30 deletions

View File

@@ -19,8 +19,8 @@ const (
) )
type SSEServerWrapper struct { type SSEServerWrapper struct {
BaseServer *common.SSEServer BaseServer *common.SSEServer
DomainList []string HostMatchers []common.HostMatcher // Pre-parsed host matchers for efficient matching
} }
type config struct { type config struct {
@@ -68,15 +68,18 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
return nil, fmt.Errorf("server %s path is not set", serverType) return nil, fmt.Errorf("server %s path is not set", serverType)
} }
serverDomainList := []string{} // Parse domain list directly into HostMatchers for efficient matching
var hostMatchers []common.HostMatcher
if domainList, ok := serverConfigMap["domain_list"].([]interface{}); ok { if domainList, ok := serverConfigMap["domain_list"].([]interface{}); ok {
hostMatchers = make([]common.HostMatcher, 0, len(domainList))
for _, domain := range domainList { for _, domain := range domainList {
if domainStr, ok := domain.(string); ok { if domainStr, ok := domain.(string); ok {
serverDomainList = append(serverDomainList, domainStr) hostMatchers = append(hostMatchers, common.ParseHostPattern(domainStr))
} }
} }
} else { } else {
serverDomainList = []string{"*"} // Default to match all domains
hostMatchers = []common.HostMatcher{common.ParseHostPattern("*")}
} }
serverName, ok := serverConfigMap["name"].(string) serverName, ok := serverConfigMap["name"].(string)
@@ -108,7 +111,7 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
BaseServer: common.NewSSEServer(serverInstance, BaseServer: common.NewSSEServer(serverInstance,
common.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, mcp_session.GlobalSSEPathSuffix)), common.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, mcp_session.GlobalSSEPathSuffix)),
common.WithMessageEndpoint(serverPath)), common.WithMessageEndpoint(serverPath)),
DomainList: serverDomainList, HostMatchers: hostMatchers,
}) })
api.LogDebug(fmt.Sprintf("Registered MCP Server: %s", serverType)) api.LogDebug(fmt.Sprintf("Registered MCP Server: %s", serverType))
} }

View File

@@ -30,7 +30,7 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
f.host = url.Host f.host = url.Host
for _, server := range f.config.servers { for _, server := range f.config.servers {
if common.MatchDomainList(f.host, server.DomainList) && strings.HasPrefix(f.path, server.BaseServer.GetMessageEndpoint()) { if common.MatchDomainWithMatchers(f.host, server.HostMatchers) && strings.HasPrefix(f.path, server.BaseServer.GetMessageEndpoint()) {
if url.Method != http.MethodPost { if url.Method != http.MethodPost {
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "") f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
return api.LocalReply return api.LocalReply
@@ -62,7 +62,7 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType { func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
if f.message { if f.message {
for _, server := range f.config.servers { for _, server := range f.config.servers {
if common.MatchDomainList(f.host, server.DomainList) && strings.HasPrefix(f.path, server.BaseServer.GetMessageEndpoint()) { if common.MatchDomainWithMatchers(f.host, server.HostMatchers) && strings.HasPrefix(f.path, server.BaseServer.GetMessageEndpoint()) {
if !endStream { if !endStream {
return api.StopAndBuffer return api.StopAndBuffer
} }

View File

@@ -46,6 +46,7 @@ type MatchRule struct {
UpstreamType UpstreamType `json:"upstream_type"` // Type of upstream(s) matched by the rule UpstreamType UpstreamType `json:"upstream_type"` // Type of upstream(s) matched by the rule
EnablePathRewrite bool `json:"enable_path_rewrite"` // Enable request path rewrite for matched routes EnablePathRewrite bool `json:"enable_path_rewrite"` // Enable request path rewrite for matched routes
PathRewritePrefix string `json:"path_rewrite_prefix"` // Prefix the request path would be rewritten to. PathRewritePrefix string `json:"path_rewrite_prefix"` // Prefix the request path would be rewritten to.
HostMatcher HostMatcher // Host matcher for efficient matching
} }
// ParseMatchList parses the match list from the config // ParseMatchList parses the match list from the config
@@ -91,6 +92,9 @@ func ParseMatchList(matchListConfig []interface{}) []MatchRule {
rule.PathRewritePrefix = "/" + rule.PathRewritePrefix rule.PathRewritePrefix = "/" + rule.PathRewritePrefix
} }
} }
rule.HostMatcher = ParseHostPattern(rule.MatchRuleDomain)
matchList = append(matchList, rule) matchList = append(matchList, rule)
} }
} }
@@ -115,8 +119,8 @@ func stripPortFromHost(reqHost string) string {
return reqHost return reqHost
} }
// parseHostPattern parses a host pattern and returns a HostMatcher // ParseHostPattern parses a host pattern and returns a HostMatcher
func parseHostPattern(pattern string) HostMatcher { func ParseHostPattern(pattern string) HostMatcher {
var hostMatcher HostMatcher var hostMatcher HostMatcher
if strings.HasPrefix(pattern, "*") { if strings.HasPrefix(pattern, "*") {
hostMatcher.matchType = HostSuffix hostMatcher.matchType = HostSuffix
@@ -157,18 +161,11 @@ func matchPattern(pattern string, target string, ruleType RuleType) bool {
} }
} }
// matchDomain checks if the domain matches the pattern using HostMatcher approach // matchDomainWithMatcher checks if the domain matches using a pre-parsed HostMatcher
func matchDomain(domain string, pattern string) bool { func matchDomainWithMatcher(domain string, hostMatcher HostMatcher) bool {
if pattern == "" || pattern == "*" {
return true
}
// Strip port from domain // Strip port from domain
domain = stripPortFromHost(domain) domain = stripPortFromHost(domain)
// Parse the pattern into a HostMatcher
hostMatcher := parseHostPattern(pattern)
// Perform matching based on match type // Perform matching based on match type
switch hostMatcher.matchType { switch hostMatcher.matchType {
case HostSuffix: case HostSuffix:
@@ -184,7 +181,7 @@ func matchDomain(domain string, pattern string) bool {
// matchDomainAndPath checks if both domain and path match the rule // matchDomainAndPath checks if both domain and path match the rule
func matchDomainAndPath(domain, path string, rule MatchRule) bool { func matchDomainAndPath(domain, path string, rule MatchRule) bool {
return matchDomain(domain, rule.MatchRuleDomain) && return matchDomainWithMatcher(domain, rule.HostMatcher) &&
matchPattern(rule.MatchRulePath, path, rule.MatchRuleType) matchPattern(rule.MatchRulePath, path, rule.MatchRuleType)
} }
@@ -204,9 +201,9 @@ func IsMatch(rules []MatchRule, host, path string) (bool, MatchRule) {
} }
// MatchDomainList checks if the domain matches any of the domains in the list // MatchDomainList checks if the domain matches any of the domains in the list
func MatchDomainList(domain string, domainList []string) bool { func MatchDomainWithMatchers(domain string, hostMatchers []HostMatcher) bool {
for _, d := range domainList { for _, hostMatcher := range hostMatchers {
if matchDomain(domain, d) { if matchDomainWithMatcher(domain, hostMatcher) {
return true return true
} }
} }

View File

@@ -223,7 +223,12 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
} }
// Send HTTP response // Send HTTP response
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response) jsonData, err := json.Marshal(response)
if err != nil {
api.LogErrorf("Failed to marshal SSE Message response: %v", err)
status = http.StatusInternalServerError
}
w.Write(jsonData)
} else { } else {
// For notifications, just send 202 Accepted with no body // For notifications, just send 202 Accepted with no body
w.WriteHeader(http.StatusAccepted) w.WriteHeader(http.StatusAccepted)

View File

@@ -12,7 +12,6 @@ type RequestURL struct {
Scheme string Scheme string
Host string Host string
Path string Path string
BaseURL string
ParsedURL *url.URL ParsedURL *url.URL
InternalIP bool InternalIP bool
} }
@@ -23,12 +22,12 @@ func NewRequestURL(header api.RequestHeaderMap) *RequestURL {
host, _ := header.Get(":authority") host, _ := header.Get(":authority")
path, _ := header.Get(":path") path, _ := header.Get(":path")
internalIP, _ := header.Get("x-envoy-internal") internalIP, _ := header.Get("x-envoy-internal")
baseURL := fmt.Sprintf("%s://%s", scheme, host) fullURL := fmt.Sprintf("%s://%s%s", scheme, host, path)
parsedURL, err := url.Parse(path) parsedURL, err := url.Parse(fullURL)
if err != nil { if err != nil {
api.LogWarnf("url parse path:%s failed:%s", path, err) api.LogWarnf("url parse fullURL:%s failed:%s", fullURL, err)
return nil return nil
} }
api.LogDebugf("RequestURL: method=%s, scheme=%s, host=%s, path=%s", method, scheme, host, path) api.LogDebugf("RequestURL: method=%s, scheme=%s, host=%s, path=%s", method, scheme, host, path)
return &RequestURL{Method: method, Scheme: scheme, Host: host, Path: path, BaseURL: baseURL, ParsedURL: parsedURL, InternalIP: internalIP == "true"} return &RequestURL{Method: method, Scheme: scheme, Host: host, Path: path, ParsedURL: parsedURL, InternalIP: internalIP == "true"}
} }

View File

@@ -264,7 +264,7 @@ func (f *filter) encodeDataFromRestUpstream(buffer api.BufferInstance, endStream
sessionID := f.proxyURL.Query().Get("sessionId") sessionID := f.proxyURL.Query().Get("sessionId")
if sessionID != "" { if sessionID != "" {
channel := common.GetSSEChannelName(sessionID) channel := common.GetSSEChannelName(sessionID)
eventData := fmt.Sprintf("event: message\ndata: %s\n\n", strings.TrimSuffix(buffer.String(), "\n")) eventData := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String())
publishErr := f.config.redisClient.Publish(channel, eventData) publishErr := f.config.redisClient.Publish(channel, eventData)
if publishErr != nil { if publishErr != nil {
api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr) api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr)