mirror of
https://github.com/alibaba/higress.git
synced 2026-04-21 20:17:29 +08:00
fix: optimize host pattern matching and fix SSE newline bug (#2899)
This commit is contained in:
@@ -19,8 +19,8 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type SSEServerWrapper struct {
|
type SSEServerWrapper struct {
|
||||||
BaseServer *common.SSEServer
|
BaseServer *common.SSEServer
|
||||||
DomainList []string
|
HostMatchers []common.HostMatcher // Pre-parsed host matchers for efficient matching
|
||||||
}
|
}
|
||||||
|
|
||||||
type config struct {
|
type config struct {
|
||||||
@@ -68,15 +68,18 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
|||||||
return nil, fmt.Errorf("server %s path is not set", serverType)
|
return nil, fmt.Errorf("server %s path is not set", serverType)
|
||||||
}
|
}
|
||||||
|
|
||||||
serverDomainList := []string{}
|
// Parse domain list directly into HostMatchers for efficient matching
|
||||||
|
var hostMatchers []common.HostMatcher
|
||||||
if domainList, ok := serverConfigMap["domain_list"].([]interface{}); ok {
|
if domainList, ok := serverConfigMap["domain_list"].([]interface{}); ok {
|
||||||
|
hostMatchers = make([]common.HostMatcher, 0, len(domainList))
|
||||||
for _, domain := range domainList {
|
for _, domain := range domainList {
|
||||||
if domainStr, ok := domain.(string); ok {
|
if domainStr, ok := domain.(string); ok {
|
||||||
serverDomainList = append(serverDomainList, domainStr)
|
hostMatchers = append(hostMatchers, common.ParseHostPattern(domainStr))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
serverDomainList = []string{"*"}
|
// Default to match all domains
|
||||||
|
hostMatchers = []common.HostMatcher{common.ParseHostPattern("*")}
|
||||||
}
|
}
|
||||||
|
|
||||||
serverName, ok := serverConfigMap["name"].(string)
|
serverName, ok := serverConfigMap["name"].(string)
|
||||||
@@ -108,7 +111,7 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
|||||||
BaseServer: common.NewSSEServer(serverInstance,
|
BaseServer: common.NewSSEServer(serverInstance,
|
||||||
common.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, mcp_session.GlobalSSEPathSuffix)),
|
common.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, mcp_session.GlobalSSEPathSuffix)),
|
||||||
common.WithMessageEndpoint(serverPath)),
|
common.WithMessageEndpoint(serverPath)),
|
||||||
DomainList: serverDomainList,
|
HostMatchers: hostMatchers,
|
||||||
})
|
})
|
||||||
api.LogDebug(fmt.Sprintf("Registered MCP Server: %s", serverType))
|
api.LogDebug(fmt.Sprintf("Registered MCP Server: %s", serverType))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
|
|||||||
f.host = url.Host
|
f.host = url.Host
|
||||||
|
|
||||||
for _, server := range f.config.servers {
|
for _, server := range f.config.servers {
|
||||||
if common.MatchDomainList(f.host, server.DomainList) && strings.HasPrefix(f.path, server.BaseServer.GetMessageEndpoint()) {
|
if common.MatchDomainWithMatchers(f.host, server.HostMatchers) && strings.HasPrefix(f.path, server.BaseServer.GetMessageEndpoint()) {
|
||||||
if url.Method != http.MethodPost {
|
if url.Method != http.MethodPost {
|
||||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||||
return api.LocalReply
|
return api.LocalReply
|
||||||
@@ -62,7 +62,7 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
|
|||||||
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||||
if f.message {
|
if f.message {
|
||||||
for _, server := range f.config.servers {
|
for _, server := range f.config.servers {
|
||||||
if common.MatchDomainList(f.host, server.DomainList) && strings.HasPrefix(f.path, server.BaseServer.GetMessageEndpoint()) {
|
if common.MatchDomainWithMatchers(f.host, server.HostMatchers) && strings.HasPrefix(f.path, server.BaseServer.GetMessageEndpoint()) {
|
||||||
if !endStream {
|
if !endStream {
|
||||||
return api.StopAndBuffer
|
return api.StopAndBuffer
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ type MatchRule struct {
|
|||||||
UpstreamType UpstreamType `json:"upstream_type"` // Type of upstream(s) matched by the rule
|
UpstreamType UpstreamType `json:"upstream_type"` // Type of upstream(s) matched by the rule
|
||||||
EnablePathRewrite bool `json:"enable_path_rewrite"` // Enable request path rewrite for matched routes
|
EnablePathRewrite bool `json:"enable_path_rewrite"` // Enable request path rewrite for matched routes
|
||||||
PathRewritePrefix string `json:"path_rewrite_prefix"` // Prefix the request path would be rewritten to.
|
PathRewritePrefix string `json:"path_rewrite_prefix"` // Prefix the request path would be rewritten to.
|
||||||
|
HostMatcher HostMatcher // Host matcher for efficient matching
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseMatchList parses the match list from the config
|
// ParseMatchList parses the match list from the config
|
||||||
@@ -91,6 +92,9 @@ func ParseMatchList(matchListConfig []interface{}) []MatchRule {
|
|||||||
rule.PathRewritePrefix = "/" + rule.PathRewritePrefix
|
rule.PathRewritePrefix = "/" + rule.PathRewritePrefix
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rule.HostMatcher = ParseHostPattern(rule.MatchRuleDomain)
|
||||||
|
|
||||||
matchList = append(matchList, rule)
|
matchList = append(matchList, rule)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -115,8 +119,8 @@ func stripPortFromHost(reqHost string) string {
|
|||||||
return reqHost
|
return reqHost
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseHostPattern parses a host pattern and returns a HostMatcher
|
// ParseHostPattern parses a host pattern and returns a HostMatcher
|
||||||
func parseHostPattern(pattern string) HostMatcher {
|
func ParseHostPattern(pattern string) HostMatcher {
|
||||||
var hostMatcher HostMatcher
|
var hostMatcher HostMatcher
|
||||||
if strings.HasPrefix(pattern, "*") {
|
if strings.HasPrefix(pattern, "*") {
|
||||||
hostMatcher.matchType = HostSuffix
|
hostMatcher.matchType = HostSuffix
|
||||||
@@ -157,18 +161,11 @@ func matchPattern(pattern string, target string, ruleType RuleType) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// matchDomain checks if the domain matches the pattern using HostMatcher approach
|
// matchDomainWithMatcher checks if the domain matches using a pre-parsed HostMatcher
|
||||||
func matchDomain(domain string, pattern string) bool {
|
func matchDomainWithMatcher(domain string, hostMatcher HostMatcher) bool {
|
||||||
if pattern == "" || pattern == "*" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Strip port from domain
|
// Strip port from domain
|
||||||
domain = stripPortFromHost(domain)
|
domain = stripPortFromHost(domain)
|
||||||
|
|
||||||
// Parse the pattern into a HostMatcher
|
|
||||||
hostMatcher := parseHostPattern(pattern)
|
|
||||||
|
|
||||||
// Perform matching based on match type
|
// Perform matching based on match type
|
||||||
switch hostMatcher.matchType {
|
switch hostMatcher.matchType {
|
||||||
case HostSuffix:
|
case HostSuffix:
|
||||||
@@ -184,7 +181,7 @@ func matchDomain(domain string, pattern string) bool {
|
|||||||
|
|
||||||
// matchDomainAndPath checks if both domain and path match the rule
|
// matchDomainAndPath checks if both domain and path match the rule
|
||||||
func matchDomainAndPath(domain, path string, rule MatchRule) bool {
|
func matchDomainAndPath(domain, path string, rule MatchRule) bool {
|
||||||
return matchDomain(domain, rule.MatchRuleDomain) &&
|
return matchDomainWithMatcher(domain, rule.HostMatcher) &&
|
||||||
matchPattern(rule.MatchRulePath, path, rule.MatchRuleType)
|
matchPattern(rule.MatchRulePath, path, rule.MatchRuleType)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -204,9 +201,9 @@ func IsMatch(rules []MatchRule, host, path string) (bool, MatchRule) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MatchDomainList checks if the domain matches any of the domains in the list
|
// MatchDomainList checks if the domain matches any of the domains in the list
|
||||||
func MatchDomainList(domain string, domainList []string) bool {
|
func MatchDomainWithMatchers(domain string, hostMatchers []HostMatcher) bool {
|
||||||
for _, d := range domainList {
|
for _, hostMatcher := range hostMatchers {
|
||||||
if matchDomain(domain, d) {
|
if matchDomainWithMatcher(domain, hostMatcher) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -223,7 +223,12 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
|
|||||||
}
|
}
|
||||||
// Send HTTP response
|
// Send HTTP response
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
json.NewEncoder(w).Encode(response)
|
jsonData, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
api.LogErrorf("Failed to marshal SSE Message response: %v", err)
|
||||||
|
status = http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
w.Write(jsonData)
|
||||||
} else {
|
} else {
|
||||||
// For notifications, just send 202 Accepted with no body
|
// For notifications, just send 202 Accepted with no body
|
||||||
w.WriteHeader(http.StatusAccepted)
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ type RequestURL struct {
|
|||||||
Scheme string
|
Scheme string
|
||||||
Host string
|
Host string
|
||||||
Path string
|
Path string
|
||||||
BaseURL string
|
|
||||||
ParsedURL *url.URL
|
ParsedURL *url.URL
|
||||||
InternalIP bool
|
InternalIP bool
|
||||||
}
|
}
|
||||||
@@ -23,12 +22,12 @@ func NewRequestURL(header api.RequestHeaderMap) *RequestURL {
|
|||||||
host, _ := header.Get(":authority")
|
host, _ := header.Get(":authority")
|
||||||
path, _ := header.Get(":path")
|
path, _ := header.Get(":path")
|
||||||
internalIP, _ := header.Get("x-envoy-internal")
|
internalIP, _ := header.Get("x-envoy-internal")
|
||||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
fullURL := fmt.Sprintf("%s://%s%s", scheme, host, path)
|
||||||
parsedURL, err := url.Parse(path)
|
parsedURL, err := url.Parse(fullURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.LogWarnf("url parse path:%s failed:%s", path, err)
|
api.LogWarnf("url parse fullURL:%s failed:%s", fullURL, err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
api.LogDebugf("RequestURL: method=%s, scheme=%s, host=%s, path=%s", method, scheme, host, path)
|
api.LogDebugf("RequestURL: method=%s, scheme=%s, host=%s, path=%s", method, scheme, host, path)
|
||||||
return &RequestURL{Method: method, Scheme: scheme, Host: host, Path: path, BaseURL: baseURL, ParsedURL: parsedURL, InternalIP: internalIP == "true"}
|
return &RequestURL{Method: method, Scheme: scheme, Host: host, Path: path, ParsedURL: parsedURL, InternalIP: internalIP == "true"}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -264,7 +264,7 @@ func (f *filter) encodeDataFromRestUpstream(buffer api.BufferInstance, endStream
|
|||||||
sessionID := f.proxyURL.Query().Get("sessionId")
|
sessionID := f.proxyURL.Query().Get("sessionId")
|
||||||
if sessionID != "" {
|
if sessionID != "" {
|
||||||
channel := common.GetSSEChannelName(sessionID)
|
channel := common.GetSSEChannelName(sessionID)
|
||||||
eventData := fmt.Sprintf("event: message\ndata: %s\n\n", strings.TrimSuffix(buffer.String(), "\n"))
|
eventData := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String())
|
||||||
publishErr := f.config.redisClient.Publish(channel, eventData)
|
publishErr := f.config.redisClient.Publish(channel, eventData)
|
||||||
if publishErr != nil {
|
if publishErr != nil {
|
||||||
api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr)
|
api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr)
|
||||||
|
|||||||
Reference in New Issue
Block a user