From f1345f9973d94516e784b34165f73b9b2d4e4a69 Mon Sep 17 00:00:00 2001 From: Jingze <52855280+Jing-ze@users.noreply.github.com> Date: Sun, 21 Sep 2025 14:34:51 +0800 Subject: [PATCH] fix: optimize host pattern matching and fix SSE newline bug (#2899) --- plugins/golang-filter/mcp-server/config.go | 15 ++++++----- plugins/golang-filter/mcp-server/filter.go | 4 +-- .../golang-filter/mcp-session/common/match.go | 27 +++++++++---------- .../golang-filter/mcp-session/common/sse.go | 7 ++++- .../golang-filter/mcp-session/common/utils.go | 9 +++---- plugins/golang-filter/mcp-session/filter.go | 2 +- 6 files changed, 34 insertions(+), 30 deletions(-) diff --git a/plugins/golang-filter/mcp-server/config.go b/plugins/golang-filter/mcp-server/config.go index d47f038e9..05cf2d12a 100644 --- a/plugins/golang-filter/mcp-server/config.go +++ b/plugins/golang-filter/mcp-server/config.go @@ -19,8 +19,8 @@ const ( ) type SSEServerWrapper struct { - BaseServer *common.SSEServer - DomainList []string + BaseServer *common.SSEServer + HostMatchers []common.HostMatcher // Pre-parsed host matchers for efficient matching } type config struct { @@ -68,15 +68,18 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int return nil, fmt.Errorf("server %s path is not set", serverType) } - serverDomainList := []string{} + // Parse domain list directly into HostMatchers for efficient matching + var hostMatchers []common.HostMatcher if domainList, ok := serverConfigMap["domain_list"].([]interface{}); ok { + hostMatchers = make([]common.HostMatcher, 0, len(domainList)) for _, domain := range domainList { if domainStr, ok := domain.(string); ok { - serverDomainList = append(serverDomainList, domainStr) + hostMatchers = append(hostMatchers, common.ParseHostPattern(domainStr)) } } } else { - serverDomainList = []string{"*"} + // Default to match all domains + hostMatchers = []common.HostMatcher{common.ParseHostPattern("*")} } serverName, ok := serverConfigMap["name"].(string) @@ -108,7 +111,7 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int BaseServer: common.NewSSEServer(serverInstance, common.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, mcp_session.GlobalSSEPathSuffix)), common.WithMessageEndpoint(serverPath)), - DomainList: serverDomainList, + HostMatchers: hostMatchers, }) api.LogDebug(fmt.Sprintf("Registered MCP Server: %s", serverType)) } diff --git a/plugins/golang-filter/mcp-server/filter.go b/plugins/golang-filter/mcp-server/filter.go index 9251839a2..5d325f556 100644 --- a/plugins/golang-filter/mcp-server/filter.go +++ b/plugins/golang-filter/mcp-server/filter.go @@ -30,7 +30,7 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api. f.host = url.Host for _, server := range f.config.servers { - if common.MatchDomainList(f.host, server.DomainList) && strings.HasPrefix(f.path, server.BaseServer.GetMessageEndpoint()) { + if common.MatchDomainWithMatchers(f.host, server.HostMatchers) && strings.HasPrefix(f.path, server.BaseServer.GetMessageEndpoint()) { if url.Method != http.MethodPost { f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "") return api.LocalReply @@ -62,7 +62,7 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api. func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType { if f.message { for _, server := range f.config.servers { - if common.MatchDomainList(f.host, server.DomainList) && strings.HasPrefix(f.path, server.BaseServer.GetMessageEndpoint()) { + if common.MatchDomainWithMatchers(f.host, server.HostMatchers) && strings.HasPrefix(f.path, server.BaseServer.GetMessageEndpoint()) { if !endStream { return api.StopAndBuffer } diff --git a/plugins/golang-filter/mcp-session/common/match.go b/plugins/golang-filter/mcp-session/common/match.go index 405a139be..fd0b10263 100644 --- a/plugins/golang-filter/mcp-session/common/match.go +++ b/plugins/golang-filter/mcp-session/common/match.go @@ -46,6 +46,7 @@ type MatchRule struct { UpstreamType UpstreamType `json:"upstream_type"` // Type of upstream(s) matched by the rule EnablePathRewrite bool `json:"enable_path_rewrite"` // Enable request path rewrite for matched routes PathRewritePrefix string `json:"path_rewrite_prefix"` // Prefix the request path would be rewritten to. + HostMatcher HostMatcher // Host matcher for efficient matching } // ParseMatchList parses the match list from the config @@ -91,6 +92,9 @@ func ParseMatchList(matchListConfig []interface{}) []MatchRule { rule.PathRewritePrefix = "/" + rule.PathRewritePrefix } } + + rule.HostMatcher = ParseHostPattern(rule.MatchRuleDomain) + matchList = append(matchList, rule) } } @@ -115,8 +119,8 @@ func stripPortFromHost(reqHost string) string { return reqHost } -// parseHostPattern parses a host pattern and returns a HostMatcher -func parseHostPattern(pattern string) HostMatcher { +// ParseHostPattern parses a host pattern and returns a HostMatcher +func ParseHostPattern(pattern string) HostMatcher { var hostMatcher HostMatcher if strings.HasPrefix(pattern, "*") { hostMatcher.matchType = HostSuffix @@ -157,18 +161,11 @@ func matchPattern(pattern string, target string, ruleType RuleType) bool { } } -// matchDomain checks if the domain matches the pattern using HostMatcher approach -func matchDomain(domain string, pattern string) bool { - if pattern == "" || pattern == "*" { - return true - } - +// matchDomainWithMatcher checks if the domain matches using a pre-parsed HostMatcher +func matchDomainWithMatcher(domain string, hostMatcher HostMatcher) bool { // Strip port from domain domain = stripPortFromHost(domain) - // Parse the pattern into a HostMatcher - hostMatcher := parseHostPattern(pattern) - // Perform matching based on match type switch hostMatcher.matchType { case HostSuffix: @@ -184,7 +181,7 @@ func matchDomain(domain string, pattern string) bool { // matchDomainAndPath checks if both domain and path match the rule func matchDomainAndPath(domain, path string, rule MatchRule) bool { - return matchDomain(domain, rule.MatchRuleDomain) && + return matchDomainWithMatcher(domain, rule.HostMatcher) && matchPattern(rule.MatchRulePath, path, rule.MatchRuleType) } @@ -204,9 +201,9 @@ func IsMatch(rules []MatchRule, host, path string) (bool, MatchRule) { } // MatchDomainList checks if the domain matches any of the domains in the list -func MatchDomainList(domain string, domainList []string) bool { - for _, d := range domainList { - if matchDomain(domain, d) { +func MatchDomainWithMatchers(domain string, hostMatchers []HostMatcher) bool { + for _, hostMatcher := range hostMatchers { + if matchDomainWithMatcher(domain, hostMatcher) { return true } } diff --git a/plugins/golang-filter/mcp-session/common/sse.go b/plugins/golang-filter/mcp-session/common/sse.go index 77acc0dfe..f845b156a 100644 --- a/plugins/golang-filter/mcp-session/common/sse.go +++ b/plugins/golang-filter/mcp-session/common/sse.go @@ -223,7 +223,12 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j } // Send HTTP response w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) + jsonData, err := json.Marshal(response) + if err != nil { + api.LogErrorf("Failed to marshal SSE Message response: %v", err) + status = http.StatusInternalServerError + } + w.Write(jsonData) } else { // For notifications, just send 202 Accepted with no body w.WriteHeader(http.StatusAccepted) diff --git a/plugins/golang-filter/mcp-session/common/utils.go b/plugins/golang-filter/mcp-session/common/utils.go index 099dd5c05..f2c2c3cc0 100644 --- a/plugins/golang-filter/mcp-session/common/utils.go +++ b/plugins/golang-filter/mcp-session/common/utils.go @@ -12,7 +12,6 @@ type RequestURL struct { Scheme string Host string Path string - BaseURL string ParsedURL *url.URL InternalIP bool } @@ -23,12 +22,12 @@ func NewRequestURL(header api.RequestHeaderMap) *RequestURL { host, _ := header.Get(":authority") path, _ := header.Get(":path") internalIP, _ := header.Get("x-envoy-internal") - baseURL := fmt.Sprintf("%s://%s", scheme, host) - parsedURL, err := url.Parse(path) + fullURL := fmt.Sprintf("%s://%s%s", scheme, host, path) + parsedURL, err := url.Parse(fullURL) if err != nil { - api.LogWarnf("url parse path:%s failed:%s", path, err) + api.LogWarnf("url parse fullURL:%s failed:%s", fullURL, err) return nil } api.LogDebugf("RequestURL: method=%s, scheme=%s, host=%s, path=%s", method, scheme, host, path) - return &RequestURL{Method: method, Scheme: scheme, Host: host, Path: path, BaseURL: baseURL, ParsedURL: parsedURL, InternalIP: internalIP == "true"} + return &RequestURL{Method: method, Scheme: scheme, Host: host, Path: path, ParsedURL: parsedURL, InternalIP: internalIP == "true"} } diff --git a/plugins/golang-filter/mcp-session/filter.go b/plugins/golang-filter/mcp-session/filter.go index e67da1151..f31fa2404 100644 --- a/plugins/golang-filter/mcp-session/filter.go +++ b/plugins/golang-filter/mcp-session/filter.go @@ -264,7 +264,7 @@ func (f *filter) encodeDataFromRestUpstream(buffer api.BufferInstance, endStream sessionID := f.proxyURL.Query().Get("sessionId") if sessionID != "" { channel := common.GetSSEChannelName(sessionID) - eventData := fmt.Sprintf("event: message\ndata: %s\n\n", strings.TrimSuffix(buffer.String(), "\n")) + eventData := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String()) publishErr := f.config.redisClient.Publish(channel, eventData) if publishErr != nil { api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr)