feat: Add SSE direct proxy support to mcp-session filter (#2157)

This commit is contained in:
Kent Dong
2025-05-09 14:28:42 +08:00
committed by GitHub
parent 3f67b05fab
commit ab014cf912
4 changed files with 450 additions and 53 deletions

View File

@@ -3,24 +3,36 @@ package common
import (
"regexp"
"strings"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
)
// RuleType defines the type of matching rule
type RuleType string
// UpstreamType defines the type of matching rule
type UpstreamType string
const (
ExactMatch RuleType = "exact"
PrefixMatch RuleType = "prefix"
SuffixMatch RuleType = "suffix"
ContainsMatch RuleType = "contains"
RegexMatch RuleType = "regex"
RestUpstream UpstreamType = "rest"
SSEUpstream UpstreamType = "sse"
StreamableUpstream UpstreamType = "streamable"
)
// MatchRule defines the structure for a matching rule
type MatchRule struct {
MatchRuleDomain string `json:"match_rule_domain"` // Domain pattern, supports wildcards
MatchRulePath string `json:"match_rule_path"` // Path pattern to match
MatchRuleType RuleType `json:"match_rule_type"` // Type of match rule
MatchRuleDomain string `json:"match_rule_domain"` // Domain pattern, supports wildcards
MatchRulePath string `json:"match_rule_path"` // Path pattern to match
MatchRuleType RuleType `json:"match_rule_type"` // Type of match rule
UpstreamType UpstreamType `json:"upstream_type"` // Type of upstream(s) matched by the rule
EnablePathRewrite bool `json:"enable_path_rewrite"` // Enable request path rewrite for matched routes
PathRewritePrefix string `json:"path_rewrite_prefix"` // Prefix the request path would be rewritten to.
}
// ParseMatchList parses the match list from the config
@@ -38,6 +50,34 @@ func ParseMatchList(matchListConfig []interface{}) []MatchRule {
if ruleType, ok := ruleMap["match_rule_type"].(string); ok {
rule.MatchRuleType = RuleType(ruleType)
}
if upstreamType, ok := ruleMap["upstream_type"].(string); ok {
rule.UpstreamType = UpstreamType(upstreamType)
}
if len(rule.UpstreamType) == 0 {
rule.UpstreamType = RestUpstream
} else {
switch rule.UpstreamType {
case RestUpstream, SSEUpstream, StreamableUpstream:
break
default:
api.LogWarnf("Unknown upstream type: %s", rule.UpstreamType)
}
}
if enablePathRewrite, ok := ruleMap["enable_path_rewrite"].(bool); ok {
rule.EnablePathRewrite = enablePathRewrite
}
if pathRewritePrefix, ok := ruleMap["path_rewrite_prefix"].(string); ok {
rule.PathRewritePrefix = pathRewritePrefix
}
if rule.EnablePathRewrite {
if rule.UpstreamType != SSEUpstream {
api.LogWarnf("Path rewrite is only supported for SSE upstream type")
} else if rule.MatchRuleType != PrefixMatch {
api.LogWarnf("Path rewrite is only supported for prefix match type")
} else if !strings.HasPrefix(rule.PathRewritePrefix, "/") {
rule.PathRewritePrefix = "/" + rule.PathRewritePrefix
}
}
matchList = append(matchList, rule)
}
}
@@ -96,17 +136,17 @@ func matchDomainAndPath(domain, path string, rule MatchRule) bool {
// IsMatch checks if the request matches any rule in the rule list
// Returns true if no rules are specified
func IsMatch(rules []MatchRule, host, path string) bool {
func IsMatch(rules []MatchRule, host, path string) (bool, MatchRule) {
if len(rules) == 0 {
return true
return true, MatchRule{}
}
for _, rule := range rules {
if matchDomainAndPath(host, path, rule) {
return true
return true, rule
}
}
return false
return false, MatchRule{}
}
// MatchDomainList checks if the domain matches any of the domains in the list

View File

@@ -2,6 +2,7 @@ package mcp_session
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
@@ -28,10 +29,14 @@ type filter struct {
config *config
stopChan chan struct{}
req *http.Request
serverName string
proxyURL *url.URL
neepProcess bool
req *http.Request
serverName string
proxyURL *url.URL
matchedRule common.MatchRule
needProcess bool
skipRequestBody bool
skipResponseBody bool
cachedResponseBody []byte
userLevelConfig bool
mcpConfigHandler *handler.MCPConfigHandler
@@ -42,31 +47,33 @@ type filter struct {
// Callbacks which are called in request path
// The endStream is true if the request doesn't have body
func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.StatusType {
url := common.NewRequestURL(header)
if url == nil {
requestUrl := common.NewRequestURL(header)
if requestUrl == nil {
return api.Continue
}
f.path = url.ParsedURL.Path
f.path = requestUrl.ParsedURL.Path
// Check if request matches any rule in match_list
if !common.IsMatch(f.config.matchList, url.Host, f.path) {
api.LogDebugf("Request does not match any rule in match_list: %s", url.ParsedURL.String())
matched, matchedRule := common.IsMatch(f.config.matchList, requestUrl.Host, f.path)
if !matched {
api.LogDebugf("Request does not match any rule in match_list: %s", requestUrl.ParsedURL.String())
return api.Continue
}
f.neepProcess = true
f.needProcess = true
f.matchedRule = matchedRule
f.req = &http.Request{
Method: url.Method,
URL: url.ParsedURL,
Method: requestUrl.Method,
URL: requestUrl.ParsedURL,
}
if strings.HasSuffix(f.path, ConfigPathSuffix) && f.config.enableUserLevelServer {
if !url.InternalIP {
api.LogWarnf("Access denied: non-Internal IP address %s", url.ParsedURL.String())
if !requestUrl.InternalIP {
api.LogWarnf("Access denied: non-Internal IP address %s", requestUrl.ParsedURL.String())
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
return api.LocalReply
}
if strings.HasSuffix(f.path, ConfigPathSuffix) && url.Method == http.MethodGet {
if strings.HasSuffix(f.path, ConfigPathSuffix) && requestUrl.Method == http.MethodGet {
api.LogDebugf("Handling config request: %s", f.path)
f.mcpConfigHandler.HandleConfigRequest(f.req, []byte{})
return api.LocalReply
@@ -79,10 +86,27 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
}
}
if !strings.HasSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix) {
f.proxyURL = url.ParsedURL
return f.processMcpRequestHeaders(header, endStream)
}
func (f *filter) processMcpRequestHeaders(header api.RequestHeaderMap, endStream bool) api.StatusType {
switch f.matchedRule.UpstreamType {
case common.RestUpstream, common.StreamableUpstream:
return f.processMcpRequestHeadersForRestUpstream(header, endStream)
case common.SSEUpstream:
return f.processMcpRequestHeadersForSSEUpstream(header, endStream)
}
f.needProcess = false
return api.Continue
}
func (f *filter) processMcpRequestHeadersForRestUpstream(header api.RequestHeaderMap, endStream bool) api.StatusType {
method := f.req.Method
requestUrl := f.req.URL
if !strings.HasSuffix(requestUrl.Path, GlobalSSEPathSuffix) {
f.proxyURL = requestUrl
if f.config.enableUserLevelServer {
parts := strings.Split(url.ParsedURL.Path, "/")
parts := strings.Split(requestUrl.Path, "/")
if len(parts) >= 3 {
serverName := parts[1]
uid := parts[2]
@@ -102,12 +126,12 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
}
}
if url.Method != http.MethodGet {
if method != http.MethodGet {
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
} else {
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
common.WithSSEEndpoint(GlobalSSEPathSuffix),
common.WithMessageEndpoint(strings.TrimSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix)),
common.WithMessageEndpoint(strings.TrimSuffix(requestUrl.Path, GlobalSSEPathSuffix)),
common.WithRedisClient(f.config.redisClient))
f.serverName = f.config.defaultServer.GetServerName()
body := "SSE connection create"
@@ -116,10 +140,60 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
return api.LocalReply
}
func (f *filter) processMcpRequestHeadersForSSEUpstream(header api.RequestHeaderMap, endStream bool) api.StatusType {
// We don't need to process the request body for SSE upstream.
f.skipRequestBody = true
f.rewritePathForSSEUpstream(header)
return api.Continue
}
func (f *filter) rewritePathForSSEUpstream(header api.RequestHeaderMap) {
matchedRule := f.matchedRule
if !matchedRule.EnablePathRewrite || matchedRule.MatchRuleType != common.PrefixMatch {
// No rewrite required, so we don't need to process the response body, either.
f.skipResponseBody = true
return
}
path := f.req.URL.Path
if !strings.HasPrefix(path, matchedRule.MatchRulePath) {
api.LogWarnf("Unexpected: Path %s does not match the configured prefix %s", path, matchedRule.MatchRulePath)
return
}
rewrittenPath := path[len(matchedRule.MatchRulePath):]
if rewrittenPath == "" {
rewrittenPath = matchedRule.PathRewritePrefix
} else {
rewritePrefixHasTrailingSlash := strings.HasSuffix(matchedRule.PathRewritePrefix, "/")
pathSuffixHasLeadingSlash := strings.HasPrefix(rewrittenPath, "/")
if rewritePrefixHasTrailingSlash != pathSuffixHasLeadingSlash {
// One has, the other doesn't have.
rewrittenPath = matchedRule.PathRewritePrefix + rewrittenPath
} else if pathSuffixHasLeadingSlash {
// Both have.
rewrittenPath = matchedRule.PathRewritePrefix + rewrittenPath[1:]
} else {
// Neither have.
rewrittenPath = matchedRule.PathRewritePrefix + "/" + rewrittenPath
}
}
if f.req.URL.RawQuery != "" {
rewrittenPath = rewrittenPath + "?" + f.req.URL.RawQuery
}
header.SetPath(rewrittenPath)
}
// DecodeData might be called multiple times during handling the request body.
// The endStream is true when handling the last piece of the body.
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
if !f.neepProcess {
if !f.needProcess || f.skipRequestBody {
return api.Continue
}
if f.matchedRule.UpstreamType != common.RestUpstream && f.matchedRule.UpstreamType != common.StreamableUpstream {
return api.Continue
}
if !endStream {
@@ -158,10 +232,17 @@ func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.Statu
return api.Continue
}
// Callbacks which are called in response path
// The endStream is true if the response doesn't have body
// EncodeHeaders Callbacks which are called in response path.
// The endStream is true if the response doesn't have body.
func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api.StatusType {
if !f.neepProcess {
if !f.needProcess {
return api.Continue
}
if f.matchedRule.UpstreamType != common.RestUpstream && f.matchedRule.UpstreamType != common.StreamableUpstream {
if contentType, ok := header.Get("content-type"); !ok || !strings.HasPrefix(contentType, "text/event-stream") {
api.LogDebugf("Skip response body for non-SSE upstream. Content-Type: %s", contentType)
f.skipResponseBody = true
}
return api.Continue
}
if f.serverName != "" {
@@ -182,7 +263,30 @@ func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api
// EncodeData might be called multiple times during handling the response body.
// The endStream is true when handling the last piece of the body.
func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
if !f.neepProcess {
if !f.needProcess || f.skipResponseBody {
return api.Continue
}
ret := api.Continue
api.LogDebugf("Upstream Type: %s", f.matchedRule.UpstreamType)
switch f.matchedRule.UpstreamType {
case common.RestUpstream, common.StreamableUpstream:
api.LogDebugf("Encoding data from Rest upstream")
ret = f.encodeDataFromRestUpstream(buffer, endStream)
break
case common.SSEUpstream:
api.LogDebugf("Encoding data from SSE upstream")
ret = f.encodeDataFromSSEUpstream(buffer, endStream)
if endStream {
// Always continue as long as the stream has ended.
ret = api.Continue
}
}
return ret
}
func (f *filter) encodeDataFromRestUpstream(buffer api.BufferInstance, endStream bool) api.StatusType {
if !f.needProcess {
return api.Continue
}
if !endStream {
@@ -207,13 +311,157 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
return api.Running
} else {
buffer.SetString(RedisNotEnabledResponseBody)
_ = buffer.SetString(RedisNotEnabledResponseBody)
return api.Continue
}
}
return api.Continue
}
func (f *filter) encodeDataFromSSEUpstream(buffer api.BufferInstance, endStream bool) api.StatusType {
bufferBytes := buffer.Bytes()
bufferData := string(bufferBytes)
err, lineBreak := f.findSSELineBreak(bufferData)
if err != nil {
api.LogWarnf("Failed to find line break in SSE data: %v", err)
f.needProcess = false
return api.Continue
}
if lineBreak == "" {
// Have not found any line break. Need to buffer and check again.
return api.StopAndBuffer
}
api.LogDebugf("Line break sequence: %v", []byte(lineBreak))
err, endpointUrl := f.findEndpointUrl(bufferData, lineBreak)
if err != nil {
api.LogWarnf("Failed to find endpoint URL in SSE data: %v", err)
f.needProcess = false
return api.Continue
}
if endpointUrl == "" {
// No endpoint URL found. Need to buffer and check again.
return api.StopAndBuffer
}
// Remove query string since we don't need to change it.
queryStringIndex := strings.IndexAny(endpointUrl, "?")
if queryStringIndex != -1 {
endpointUrl = endpointUrl[:queryStringIndex]
}
if changed, newEndpointUrl := f.rewriteEndpointUrl(endpointUrl); changed {
api.LogDebugf("The endpoint URL is changed.\n Old: %s\n New: %s", endpointUrl, newEndpointUrl)
endpointUrlIndex := strings.Index(bufferData, endpointUrl)
if endpointUrlIndex == -1 {
api.LogWarnf("Something wrong, the previously found endpoint URL %s not found in the SSE data now", endpointUrl)
} else {
bufferData = bufferData[:endpointUrlIndex] + newEndpointUrl + bufferData[endpointUrlIndex+len(endpointUrl):]
_ = buffer.SetString(bufferData)
}
} else {
api.LogDebugf("The endpoint URL %s is not changed", endpointUrl)
}
f.needProcess = false
return api.Continue
}
func (f *filter) rewriteEndpointUrl(endpointUrl string) (bool, string) {
if !f.matchedRule.EnablePathRewrite {
return false, ""
}
if schemeIndex := strings.Index(endpointUrl, "://"); schemeIndex != -1 {
endpointUrl = endpointUrl[schemeIndex+3:]
if slashIndex := strings.Index(endpointUrl, "/"); slashIndex != -1 {
endpointUrl = endpointUrl[slashIndex:]
} else {
endpointUrl = "/"
}
}
if !strings.HasPrefix(endpointUrl, f.matchedRule.PathRewritePrefix) {
// The endpoint URL does not match the path rewrite prefix. We are unable to rewrite it back.
api.LogWarnf("The endpoint URL %s does not match the path rewrite prefix %s", endpointUrl, f.matchedRule.PathRewritePrefix)
return false, ""
}
suffix := endpointUrl[len(f.matchedRule.PathRewritePrefix):]
if len(suffix) == 0 {
endpointUrl = f.matchedRule.MatchRulePath
} else {
matchPathHasTrailingSlash := strings.HasSuffix(f.matchedRule.MatchRulePath, "/")
suffixHasLeadingSlash := strings.HasPrefix(suffix, "/")
if matchPathHasTrailingSlash != suffixHasLeadingSlash {
// One has, the other doesn't have.
endpointUrl = f.matchedRule.MatchRulePath + suffix
} else if matchPathHasTrailingSlash {
// Both have.
endpointUrl = f.matchedRule.MatchRulePath + suffix[1:]
} else {
// Neither have.
endpointUrl = f.matchedRule.MatchRulePath + "/" + suffix
}
}
return true, endpointUrl
}
func (f *filter) findSSELineBreak(bufferData string) (error, string) {
// See https://html.spec.whatwg.org/multipage/server-sent-events.html
crIndex := strings.IndexAny(bufferData, "\r")
lfIndex := strings.IndexAny(bufferData, "\n")
if crIndex == -1 && lfIndex == -1 {
// No line break found.
return nil, ""
}
lineBreak := ""
if crIndex != -1 && lfIndex != -1 {
if crIndex+1 != lfIndex {
// Found both line breaks, but they are not adjacent. Skip body processing.
return errors.New("found non-adjacent CR and LF"), ""
}
lineBreak = "\r\n"
} else if crIndex != -1 {
lineBreak = "\r"
} else {
lineBreak = "\n"
}
return nil, lineBreak
}
func (f *filter) findEndpointUrl(bufferData, lineBreak string) (error, string) {
eventIndex := strings.Index(bufferData, "event:")
if eventIndex == -1 {
return nil, ""
}
bufferData = bufferData[eventIndex:]
eventEndIndex := strings.Index(bufferData, lineBreak)
if eventEndIndex == -1 {
return nil, ""
}
eventName := strings.TrimSpace(bufferData[len("event:"):eventEndIndex])
if eventName != "endpoint" {
return fmt.Errorf("the initial event [%s] is not an endpoint event. Skip processing", eventName), ""
}
bufferData = bufferData[eventEndIndex+len(lineBreak):]
dataEndIndex := strings.Index(bufferData, lineBreak)
if dataEndIndex == -1 {
// Data received not enough.
return nil, ""
}
eventData := bufferData[:dataEndIndex]
if !strings.HasPrefix(eventData, "data:") {
return fmt.Errorf("an unexpected non-data field found in the event. Skip processing. Field: %s", eventData), ""
}
return nil, strings.TrimSpace(eventData[len("data:"):])
}
// OnDestroy stops the goroutine
func (f *filter) OnDestroy(reason api.DestroyReason) {
api.LogDebugf("OnDestroy: reason=%v", reason)