From ab014cf912086ddac434258772db8baefc762bcc Mon Sep 17 00:00:00 2001 From: Kent Dong Date: Fri, 9 May 2025 14:28:42 +0800 Subject: [PATCH] feat: Add SSE direct proxy support to mcp-session filter (#2157) --- pkg/ingress/kube/configmap/mcp_server.go | 53 +++- pkg/ingress/kube/configmap/mcp_server_test.go | 96 +++++- .../golang-filter/mcp-session/common/match.go | 54 +++- plugins/golang-filter/mcp-session/filter.go | 300 ++++++++++++++++-- 4 files changed, 450 insertions(+), 53 deletions(-) diff --git a/pkg/ingress/kube/configmap/mcp_server.go b/pkg/ingress/kube/configmap/mcp_server.go index dc77efc2b..6b942a0c5 100644 --- a/pkg/ingress/kube/configmap/mcp_server.go +++ b/pkg/ingress/kube/configmap/mcp_server.go @@ -56,7 +56,7 @@ type MCPRatelimitConfig struct { type SSEServer struct { // The name of the SSE server Name string `json:"name,omitempty"` - // The path where the SSE server will be mounted, the full path is (PATH + SsePathSuffix) + // The path where the SSE server will be mounted, the full path is (PATH + SSEPathSuffix) Path string `json:"path,omitempty"` // The type of the SSE server Type string `json:"type,omitempty"` @@ -74,6 +74,12 @@ type MatchRule struct { MatchRulePath string `json:"match_rule_path,omitempty"` // Type of match rule: exact, prefix, suffix, contains, regex MatchRuleType string `json:"match_rule_type,omitempty"` + // Type of upstream(s) matched by the rule: rest (default), sse + UpstreamType string `json:"upstream_type"` + // Enable request path rewrite for matched routes + EnablePathRewrite bool `json:"enable_path_rewrite"` + // Prefix the request path would be rewritten to. + PathRewritePrefix string `json:"path_rewrite_prefix"` } // McpServer defines the configuration for MCP (Model Context Protocol) server @@ -83,7 +89,7 @@ type McpServer struct { // Redis Config for MCP server Redis *RedisConfig `json:"redis,omitempty"` // The suffix to be appended to SSE paths, default is "/sse" - SsePathSuffix string `json:"sse_path_suffix,omitempty"` + SSEPathSuffix string `json:"sse_path_suffix,omitempty"` // List of SSE servers Configs Servers []*SSEServer `json:"servers,omitempty"` // List of match rules for filtering requests @@ -118,21 +124,32 @@ func validMcpServer(m *McpServer) error { // Validate match rule types if m.MatchList != nil { - validTypes := map[string]bool{ + validMatchRuleTypes := map[string]bool{ "exact": true, "prefix": true, "suffix": true, "contains": true, "regex": true, } + validUpstreamTypes := map[string]bool{ + "rest": true, + "sse": true, + "streamable": true, + } for _, rule := range m.MatchList { if rule.MatchRuleType == "" { return errors.New("match_rule_type cannot be empty, must be one of: exact, prefix, suffix, contains, regex") } - if !validTypes[rule.MatchRuleType] { + if !validMatchRuleTypes[rule.MatchRuleType] { return fmt.Errorf("invalid match_rule_type: %s, must be one of: exact, prefix, suffix, contains, regex", rule.MatchRuleType) } + if rule.UpstreamType != "" && !validUpstreamTypes[rule.UpstreamType] { + return fmt.Errorf("invalid upstream_type: %s, must be one of: rest, sse, streamable", rule.UpstreamType) + } + if rule.EnablePathRewrite && rule.UpstreamType != "sse" { + return errors.New("path rewrite is only supported for SSE upstream type") + } } } @@ -174,7 +191,7 @@ func deepCopyMcpServer(mcp *McpServer) (*McpServer, error) { WhiteList: mcp.Ratelimit.WhiteList, } } - newMcp.SsePathSuffix = mcp.SsePathSuffix + newMcp.SSEPathSuffix = mcp.SSEPathSuffix newMcp.EnableUserLevelServer = mcp.EnableUserLevelServer @@ -201,9 +218,12 @@ func deepCopyMcpServer(mcp *McpServer) (*McpServer, error) { newMcp.MatchList = make([]*MatchRule, len(mcp.MatchList)) for i, rule := range mcp.MatchList { newMcp.MatchList[i] = &MatchRule{ - MatchRuleDomain: rule.MatchRuleDomain, - MatchRulePath: rule.MatchRulePath, - MatchRuleType: rule.MatchRuleType, + MatchRuleDomain: rule.MatchRuleDomain, + MatchRulePath: rule.MatchRulePath, + MatchRuleType: rule.MatchRuleType, + UpstreamType: rule.UpstreamType, + EnablePathRewrite: rule.EnablePathRewrite, + PathRewritePrefix: rule.PathRewritePrefix, } } } @@ -216,7 +236,7 @@ type McpServerController struct { mcpServer atomic.Value Name string eventHandler ItemEventHandler - reconclier *reconcile.Reconciler + reconciler *reconcile.Reconciler } func NewMcpServerController(namespace string) *McpServerController { @@ -291,7 +311,7 @@ func (m *McpServerController) RegisterItemEventHandler(eventHandler ItemEventHan } func (m *McpServerController) RegisterMcpReconciler(reconciler *reconcile.Reconciler) { - m.reconclier = reconciler + m.reconciler = reconciler } func (m *McpServerController) ConstructEnvoyFilters() ([]*config.Config, error) { @@ -393,13 +413,16 @@ func (m *McpServerController) constructMcpSessionStruct(mcp *McpServer) string { matchConfigs = append(matchConfigs, fmt.Sprintf(`{ "match_rule_domain": "%s", "match_rule_path": "%s", - "match_rule_type": "%s" - }`, rule.MatchRuleDomain, rule.MatchRulePath, rule.MatchRuleType)) + "match_rule_type": "%s", + "upstream_type": "%s", + "enable_path_rewrite": %t, + "path_rewrite_prefix": "%s" + }`, rule.MatchRuleDomain, rule.MatchRulePath, rule.MatchRuleType, rule.UpstreamType, rule.EnablePathRewrite, rule.PathRewritePrefix)) } } - if m.reconclier != nil { - vsFromMcp := m.reconclier.GetAllConfigs(gvk.VirtualService) + if m.reconciler != nil { + vsFromMcp := m.reconciler.GetAllConfigs(gvk.VirtualService) for _, c := range vsFromMcp { vs := c.Spec.(*networking.VirtualService) var host string @@ -468,7 +491,7 @@ func (m *McpServerController) constructMcpSessionStruct(mcp *McpServer) string { }`, redisConfig, rateLimitConfig, - mcp.SsePathSuffix, + mcp.SSEPathSuffix, matchList, mcp.EnableUserLevelServer) } diff --git a/pkg/ingress/kube/configmap/mcp_server_test.go b/pkg/ingress/kube/configmap/mcp_server_test.go index 045c652cf..a2ab35fdd 100644 --- a/pkg/ingress/kube/configmap/mcp_server_test.go +++ b/pkg/ingress/kube/configmap/mcp_server_test.go @@ -54,6 +54,61 @@ func Test_validMcpServer(t *testing.T) { }, wantErr: nil, }, + { + name: "enabled but bad match_rule_type", + mcp: &McpServer{ + Enable: true, + EnableUserLevelServer: false, + Redis: nil, + MatchList: []*MatchRule{ + { + MatchRuleDomain: "*", + MatchRulePath: "/mcp", + MatchRuleType: "bad-type", + }, + }, + Servers: []*SSEServer{}, + }, + wantErr: errors.New("invalid match_rule_type: bad-type, must be one of: exact, prefix, suffix, contains, regex"), + }, + { + name: "enabled but bad upstream_type", + mcp: &McpServer{ + Enable: true, + EnableUserLevelServer: false, + Redis: nil, + MatchList: []*MatchRule{ + { + MatchRuleDomain: "*", + MatchRulePath: "/mcp", + MatchRuleType: "prefix", + UpstreamType: "bad-type", + }, + }, + Servers: []*SSEServer{}, + }, + wantErr: errors.New("invalid upstream_type: bad-type, must be one of: rest, sse, streamable"), + }, + { + name: "enabled but path rewrite with unsupported upstream type", + mcp: &McpServer{ + Enable: true, + EnableUserLevelServer: false, + Redis: nil, + MatchList: []*MatchRule{ + { + MatchRuleDomain: "*", + MatchRulePath: "/mcp", + MatchRuleType: "prefix", + UpstreamType: "rest", + EnablePathRewrite: true, + PathRewritePrefix: "/", + }, + }, + Servers: []*SSEServer{}, + }, + wantErr: errors.New("path rewrite is only supported for SSE upstream type"), + }, { name: "enabled with user level server but no redis config", mcp: &McpServer{ @@ -76,7 +131,7 @@ func Test_validMcpServer(t *testing.T) { Password: "password", DB: 0, }, - SsePathSuffix: "/sse", + SSEPathSuffix: "/sse", MatchList: []*MatchRule{ { MatchRuleDomain: "*", @@ -238,7 +293,7 @@ func Test_deepCopyMcpServer(t *testing.T) { Password: "password", DB: 0, }, - SsePathSuffix: "/sse", + SSEPathSuffix: "/sse", MatchList: []*MatchRule{ { MatchRuleDomain: "*", @@ -265,7 +320,7 @@ func Test_deepCopyMcpServer(t *testing.T) { Password: "password", DB: 0, }, - SsePathSuffix: "/sse", + SSEPathSuffix: "/sse", MatchList: []*MatchRule{ { MatchRuleDomain: "*", @@ -581,13 +636,27 @@ func TestMcpServerController_constructMcpSessionStruct(t *testing.T) { Password: "pass", DB: 1, }, - SsePathSuffix: "/sse", + SSEPathSuffix: "/sse", MatchList: []*MatchRule{ { MatchRuleDomain: "*", MatchRulePath: "/test", MatchRuleType: "exact", }, + { + MatchRuleDomain: "*", + MatchRulePath: "/sse-test-1", + MatchRuleType: "prefix", + UpstreamType: "sse", + }, + { + MatchRuleDomain: "*", + MatchRulePath: "/sse-test-2", + MatchRuleType: "prefix", + UpstreamType: "sse", + EnablePathRewrite: true, + PathRewritePrefix: "/mcp", + }, }, EnableUserLevelServer: true, Ratelimit: &MCPRatelimitConfig{ @@ -623,7 +692,24 @@ func TestMcpServerController_constructMcpSessionStruct(t *testing.T) { "match_list": [{ "match_rule_domain": "*", "match_rule_path": "/test", - "match_rule_type": "exact" + "match_rule_type": "exact", + "upstream_type": "", + "enable_path_rewrite": false, + "path_rewrite_prefix": "" + },{ + "match_rule_domain": "*", + "match_rule_path": "/sse-test-1", + "match_rule_type": "prefix", + "upstream_type": "sse", + "enable_path_rewrite": false, + "path_rewrite_prefix": "" + },{ + "match_rule_domain": "*", + "match_rule_path": "/sse-test-2", + "match_rule_type": "prefix", + "upstream_type": "sse", + "enable_path_rewrite": true, + "path_rewrite_prefix": "/mcp" }], "enable_user_level_server": true } diff --git a/plugins/golang-filter/mcp-session/common/match.go b/plugins/golang-filter/mcp-session/common/match.go index c7945f39b..bd5b250b5 100644 --- a/plugins/golang-filter/mcp-session/common/match.go +++ b/plugins/golang-filter/mcp-session/common/match.go @@ -3,24 +3,36 @@ package common import ( "regexp" "strings" + + "github.com/envoyproxy/envoy/contrib/golang/common/go/api" ) // RuleType defines the type of matching rule type RuleType string +// UpstreamType defines the type of matching rule +type UpstreamType string + const ( ExactMatch RuleType = "exact" PrefixMatch RuleType = "prefix" SuffixMatch RuleType = "suffix" ContainsMatch RuleType = "contains" RegexMatch RuleType = "regex" + + RestUpstream UpstreamType = "rest" + SSEUpstream UpstreamType = "sse" + StreamableUpstream UpstreamType = "streamable" ) // MatchRule defines the structure for a matching rule type MatchRule struct { - MatchRuleDomain string `json:"match_rule_domain"` // Domain pattern, supports wildcards - MatchRulePath string `json:"match_rule_path"` // Path pattern to match - MatchRuleType RuleType `json:"match_rule_type"` // Type of match rule + MatchRuleDomain string `json:"match_rule_domain"` // Domain pattern, supports wildcards + MatchRulePath string `json:"match_rule_path"` // Path pattern to match + MatchRuleType RuleType `json:"match_rule_type"` // Type of match rule + UpstreamType UpstreamType `json:"upstream_type"` // Type of upstream(s) matched by the rule + EnablePathRewrite bool `json:"enable_path_rewrite"` // Enable request path rewrite for matched routes + PathRewritePrefix string `json:"path_rewrite_prefix"` // Prefix the request path would be rewritten to. } // ParseMatchList parses the match list from the config @@ -38,6 +50,34 @@ func ParseMatchList(matchListConfig []interface{}) []MatchRule { if ruleType, ok := ruleMap["match_rule_type"].(string); ok { rule.MatchRuleType = RuleType(ruleType) } + if upstreamType, ok := ruleMap["upstream_type"].(string); ok { + rule.UpstreamType = UpstreamType(upstreamType) + } + if len(rule.UpstreamType) == 0 { + rule.UpstreamType = RestUpstream + } else { + switch rule.UpstreamType { + case RestUpstream, SSEUpstream, StreamableUpstream: + break + default: + api.LogWarnf("Unknown upstream type: %s", rule.UpstreamType) + } + } + if enablePathRewrite, ok := ruleMap["enable_path_rewrite"].(bool); ok { + rule.EnablePathRewrite = enablePathRewrite + } + if pathRewritePrefix, ok := ruleMap["path_rewrite_prefix"].(string); ok { + rule.PathRewritePrefix = pathRewritePrefix + } + if rule.EnablePathRewrite { + if rule.UpstreamType != SSEUpstream { + api.LogWarnf("Path rewrite is only supported for SSE upstream type") + } else if rule.MatchRuleType != PrefixMatch { + api.LogWarnf("Path rewrite is only supported for prefix match type") + } else if !strings.HasPrefix(rule.PathRewritePrefix, "/") { + rule.PathRewritePrefix = "/" + rule.PathRewritePrefix + } + } matchList = append(matchList, rule) } } @@ -96,17 +136,17 @@ func matchDomainAndPath(domain, path string, rule MatchRule) bool { // IsMatch checks if the request matches any rule in the rule list // Returns true if no rules are specified -func IsMatch(rules []MatchRule, host, path string) bool { +func IsMatch(rules []MatchRule, host, path string) (bool, MatchRule) { if len(rules) == 0 { - return true + return true, MatchRule{} } for _, rule := range rules { if matchDomainAndPath(host, path, rule) { - return true + return true, rule } } - return false + return false, MatchRule{} } // MatchDomainList checks if the domain matches any of the domains in the list diff --git a/plugins/golang-filter/mcp-session/filter.go b/plugins/golang-filter/mcp-session/filter.go index acc812539..e0f15d651 100644 --- a/plugins/golang-filter/mcp-session/filter.go +++ b/plugins/golang-filter/mcp-session/filter.go @@ -2,6 +2,7 @@ package mcp_session import ( "encoding/json" + "errors" "fmt" "net/http" "net/url" @@ -28,10 +29,14 @@ type filter struct { config *config stopChan chan struct{} - req *http.Request - serverName string - proxyURL *url.URL - neepProcess bool + req *http.Request + serverName string + proxyURL *url.URL + matchedRule common.MatchRule + needProcess bool + skipRequestBody bool + skipResponseBody bool + cachedResponseBody []byte userLevelConfig bool mcpConfigHandler *handler.MCPConfigHandler @@ -42,31 +47,33 @@ type filter struct { // Callbacks which are called in request path // The endStream is true if the request doesn't have body func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.StatusType { - url := common.NewRequestURL(header) - if url == nil { + requestUrl := common.NewRequestURL(header) + if requestUrl == nil { return api.Continue } - f.path = url.ParsedURL.Path + f.path = requestUrl.ParsedURL.Path // Check if request matches any rule in match_list - if !common.IsMatch(f.config.matchList, url.Host, f.path) { - api.LogDebugf("Request does not match any rule in match_list: %s", url.ParsedURL.String()) + matched, matchedRule := common.IsMatch(f.config.matchList, requestUrl.Host, f.path) + if !matched { + api.LogDebugf("Request does not match any rule in match_list: %s", requestUrl.ParsedURL.String()) return api.Continue } - f.neepProcess = true + f.needProcess = true + f.matchedRule = matchedRule f.req = &http.Request{ - Method: url.Method, - URL: url.ParsedURL, + Method: requestUrl.Method, + URL: requestUrl.ParsedURL, } if strings.HasSuffix(f.path, ConfigPathSuffix) && f.config.enableUserLevelServer { - if !url.InternalIP { - api.LogWarnf("Access denied: non-Internal IP address %s", url.ParsedURL.String()) + if !requestUrl.InternalIP { + api.LogWarnf("Access denied: non-Internal IP address %s", requestUrl.ParsedURL.String()) f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "") return api.LocalReply } - if strings.HasSuffix(f.path, ConfigPathSuffix) && url.Method == http.MethodGet { + if strings.HasSuffix(f.path, ConfigPathSuffix) && requestUrl.Method == http.MethodGet { api.LogDebugf("Handling config request: %s", f.path) f.mcpConfigHandler.HandleConfigRequest(f.req, []byte{}) return api.LocalReply @@ -79,10 +86,27 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api. } } - if !strings.HasSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix) { - f.proxyURL = url.ParsedURL + return f.processMcpRequestHeaders(header, endStream) +} + +func (f *filter) processMcpRequestHeaders(header api.RequestHeaderMap, endStream bool) api.StatusType { + switch f.matchedRule.UpstreamType { + case common.RestUpstream, common.StreamableUpstream: + return f.processMcpRequestHeadersForRestUpstream(header, endStream) + case common.SSEUpstream: + return f.processMcpRequestHeadersForSSEUpstream(header, endStream) + } + f.needProcess = false + return api.Continue +} + +func (f *filter) processMcpRequestHeadersForRestUpstream(header api.RequestHeaderMap, endStream bool) api.StatusType { + method := f.req.Method + requestUrl := f.req.URL + if !strings.HasSuffix(requestUrl.Path, GlobalSSEPathSuffix) { + f.proxyURL = requestUrl if f.config.enableUserLevelServer { - parts := strings.Split(url.ParsedURL.Path, "/") + parts := strings.Split(requestUrl.Path, "/") if len(parts) >= 3 { serverName := parts[1] uid := parts[2] @@ -102,12 +126,12 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api. } } - if url.Method != http.MethodGet { + if method != http.MethodGet { f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "") } else { f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version), common.WithSSEEndpoint(GlobalSSEPathSuffix), - common.WithMessageEndpoint(strings.TrimSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix)), + common.WithMessageEndpoint(strings.TrimSuffix(requestUrl.Path, GlobalSSEPathSuffix)), common.WithRedisClient(f.config.redisClient)) f.serverName = f.config.defaultServer.GetServerName() body := "SSE connection create" @@ -116,10 +140,60 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api. return api.LocalReply } +func (f *filter) processMcpRequestHeadersForSSEUpstream(header api.RequestHeaderMap, endStream bool) api.StatusType { + // We don't need to process the request body for SSE upstream. + f.skipRequestBody = true + f.rewritePathForSSEUpstream(header) + return api.Continue +} + +func (f *filter) rewritePathForSSEUpstream(header api.RequestHeaderMap) { + matchedRule := f.matchedRule + if !matchedRule.EnablePathRewrite || matchedRule.MatchRuleType != common.PrefixMatch { + // No rewrite required, so we don't need to process the response body, either. + f.skipResponseBody = true + return + } + + path := f.req.URL.Path + if !strings.HasPrefix(path, matchedRule.MatchRulePath) { + api.LogWarnf("Unexpected: Path %s does not match the configured prefix %s", path, matchedRule.MatchRulePath) + return + } + + rewrittenPath := path[len(matchedRule.MatchRulePath):] + + if rewrittenPath == "" { + rewrittenPath = matchedRule.PathRewritePrefix + } else { + rewritePrefixHasTrailingSlash := strings.HasSuffix(matchedRule.PathRewritePrefix, "/") + pathSuffixHasLeadingSlash := strings.HasPrefix(rewrittenPath, "/") + if rewritePrefixHasTrailingSlash != pathSuffixHasLeadingSlash { + // One has, the other doesn't have. + rewrittenPath = matchedRule.PathRewritePrefix + rewrittenPath + } else if pathSuffixHasLeadingSlash { + // Both have. + rewrittenPath = matchedRule.PathRewritePrefix + rewrittenPath[1:] + } else { + // Neither have. + rewrittenPath = matchedRule.PathRewritePrefix + "/" + rewrittenPath + } + } + + if f.req.URL.RawQuery != "" { + rewrittenPath = rewrittenPath + "?" + f.req.URL.RawQuery + } + + header.SetPath(rewrittenPath) +} + // DecodeData might be called multiple times during handling the request body. // The endStream is true when handling the last piece of the body. func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType { - if !f.neepProcess { + if !f.needProcess || f.skipRequestBody { + return api.Continue + } + if f.matchedRule.UpstreamType != common.RestUpstream && f.matchedRule.UpstreamType != common.StreamableUpstream { return api.Continue } if !endStream { @@ -158,10 +232,17 @@ func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.Statu return api.Continue } -// Callbacks which are called in response path -// The endStream is true if the response doesn't have body +// EncodeHeaders Callbacks which are called in response path. +// The endStream is true if the response doesn't have body. func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api.StatusType { - if !f.neepProcess { + if !f.needProcess { + return api.Continue + } + if f.matchedRule.UpstreamType != common.RestUpstream && f.matchedRule.UpstreamType != common.StreamableUpstream { + if contentType, ok := header.Get("content-type"); !ok || !strings.HasPrefix(contentType, "text/event-stream") { + api.LogDebugf("Skip response body for non-SSE upstream. Content-Type: %s", contentType) + f.skipResponseBody = true + } return api.Continue } if f.serverName != "" { @@ -182,7 +263,30 @@ func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api // EncodeData might be called multiple times during handling the response body. // The endStream is true when handling the last piece of the body. func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.StatusType { - if !f.neepProcess { + if !f.needProcess || f.skipResponseBody { + return api.Continue + } + + ret := api.Continue + api.LogDebugf("Upstream Type: %s", f.matchedRule.UpstreamType) + switch f.matchedRule.UpstreamType { + case common.RestUpstream, common.StreamableUpstream: + api.LogDebugf("Encoding data from Rest upstream") + ret = f.encodeDataFromRestUpstream(buffer, endStream) + break + case common.SSEUpstream: + api.LogDebugf("Encoding data from SSE upstream") + ret = f.encodeDataFromSSEUpstream(buffer, endStream) + if endStream { + // Always continue as long as the stream has ended. + ret = api.Continue + } + } + return ret +} + +func (f *filter) encodeDataFromRestUpstream(buffer api.BufferInstance, endStream bool) api.StatusType { + if !f.needProcess { return api.Continue } if !endStream { @@ -207,13 +311,157 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan) return api.Running } else { - buffer.SetString(RedisNotEnabledResponseBody) + _ = buffer.SetString(RedisNotEnabledResponseBody) return api.Continue } } return api.Continue } +func (f *filter) encodeDataFromSSEUpstream(buffer api.BufferInstance, endStream bool) api.StatusType { + bufferBytes := buffer.Bytes() + bufferData := string(bufferBytes) + + err, lineBreak := f.findSSELineBreak(bufferData) + if err != nil { + api.LogWarnf("Failed to find line break in SSE data: %v", err) + f.needProcess = false + return api.Continue + } + if lineBreak == "" { + // Have not found any line break. Need to buffer and check again. + return api.StopAndBuffer + } + + api.LogDebugf("Line break sequence: %v", []byte(lineBreak)) + + err, endpointUrl := f.findEndpointUrl(bufferData, lineBreak) + if err != nil { + api.LogWarnf("Failed to find endpoint URL in SSE data: %v", err) + f.needProcess = false + return api.Continue + } + if endpointUrl == "" { + // No endpoint URL found. Need to buffer and check again. + return api.StopAndBuffer + } + + // Remove query string since we don't need to change it. + queryStringIndex := strings.IndexAny(endpointUrl, "?") + if queryStringIndex != -1 { + endpointUrl = endpointUrl[:queryStringIndex] + } + + if changed, newEndpointUrl := f.rewriteEndpointUrl(endpointUrl); changed { + api.LogDebugf("The endpoint URL is changed.\n Old: %s\n New: %s", endpointUrl, newEndpointUrl) + + endpointUrlIndex := strings.Index(bufferData, endpointUrl) + if endpointUrlIndex == -1 { + api.LogWarnf("Something wrong, the previously found endpoint URL %s not found in the SSE data now", endpointUrl) + } else { + bufferData = bufferData[:endpointUrlIndex] + newEndpointUrl + bufferData[endpointUrlIndex+len(endpointUrl):] + _ = buffer.SetString(bufferData) + } + } else { + api.LogDebugf("The endpoint URL %s is not changed", endpointUrl) + } + + f.needProcess = false + return api.Continue +} + +func (f *filter) rewriteEndpointUrl(endpointUrl string) (bool, string) { + if !f.matchedRule.EnablePathRewrite { + return false, "" + } + + if schemeIndex := strings.Index(endpointUrl, "://"); schemeIndex != -1 { + endpointUrl = endpointUrl[schemeIndex+3:] + if slashIndex := strings.Index(endpointUrl, "/"); slashIndex != -1 { + endpointUrl = endpointUrl[slashIndex:] + } else { + endpointUrl = "/" + } + } + + if !strings.HasPrefix(endpointUrl, f.matchedRule.PathRewritePrefix) { + // The endpoint URL does not match the path rewrite prefix. We are unable to rewrite it back. + api.LogWarnf("The endpoint URL %s does not match the path rewrite prefix %s", endpointUrl, f.matchedRule.PathRewritePrefix) + return false, "" + } + + suffix := endpointUrl[len(f.matchedRule.PathRewritePrefix):] + + if len(suffix) == 0 { + endpointUrl = f.matchedRule.MatchRulePath + } else { + matchPathHasTrailingSlash := strings.HasSuffix(f.matchedRule.MatchRulePath, "/") + suffixHasLeadingSlash := strings.HasPrefix(suffix, "/") + if matchPathHasTrailingSlash != suffixHasLeadingSlash { + // One has, the other doesn't have. + endpointUrl = f.matchedRule.MatchRulePath + suffix + } else if matchPathHasTrailingSlash { + // Both have. + endpointUrl = f.matchedRule.MatchRulePath + suffix[1:] + } else { + // Neither have. + endpointUrl = f.matchedRule.MatchRulePath + "/" + suffix + } + } + + return true, endpointUrl +} + +func (f *filter) findSSELineBreak(bufferData string) (error, string) { + // See https://html.spec.whatwg.org/multipage/server-sent-events.html + crIndex := strings.IndexAny(bufferData, "\r") + lfIndex := strings.IndexAny(bufferData, "\n") + if crIndex == -1 && lfIndex == -1 { + // No line break found. + return nil, "" + } + lineBreak := "" + if crIndex != -1 && lfIndex != -1 { + if crIndex+1 != lfIndex { + // Found both line breaks, but they are not adjacent. Skip body processing. + return errors.New("found non-adjacent CR and LF"), "" + } + lineBreak = "\r\n" + } else if crIndex != -1 { + lineBreak = "\r" + } else { + lineBreak = "\n" + } + return nil, lineBreak +} + +func (f *filter) findEndpointUrl(bufferData, lineBreak string) (error, string) { + eventIndex := strings.Index(bufferData, "event:") + if eventIndex == -1 { + return nil, "" + } + bufferData = bufferData[eventIndex:] + eventEndIndex := strings.Index(bufferData, lineBreak) + if eventEndIndex == -1 { + return nil, "" + } + eventName := strings.TrimSpace(bufferData[len("event:"):eventEndIndex]) + if eventName != "endpoint" { + return fmt.Errorf("the initial event [%s] is not an endpoint event. Skip processing", eventName), "" + } + bufferData = bufferData[eventEndIndex+len(lineBreak):] + dataEndIndex := strings.Index(bufferData, lineBreak) + if dataEndIndex == -1 { + // Data received not enough. + return nil, "" + } + eventData := bufferData[:dataEndIndex] + if !strings.HasPrefix(eventData, "data:") { + return fmt.Errorf("an unexpected non-data field found in the event. Skip processing. Field: %s", eventData), "" + } + return nil, strings.TrimSpace(eventData[len("data:"):]) +} + // OnDestroy stops the goroutine func (f *filter) OnDestroy(reason api.DestroyReason) { api.LogDebugf("OnDestroy: reason=%v", reason)