fix concurrent SSE connections returning wrong endpoint (#3341)

This commit is contained in:
TianHao Zhang
2026-01-19 10:22:50 +08:00
committed by jingze
parent a38be77b9e
commit 24c69fb0b7
2 changed files with 11 additions and 8 deletions

View File

@@ -26,8 +26,8 @@ type config struct {
matchList []common.MatchRule matchList []common.MatchRule
enableUserLevelServer bool enableUserLevelServer bool
rateLimitConfig *handler.MCPRatelimitConfig rateLimitConfig *handler.MCPRatelimitConfig
defaultServer *common.SSEServer
redisClient *common.RedisClient redisClient *common.RedisClient
sharedMCPServer *common.MCPServer // Created once, thread-safe with sync.RWMutex
} }
func (c *config) Destroy() { func (c *config) Destroy() {
@@ -110,6 +110,9 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
} }
GlobalSSEPathSuffix = ssePathSuffix GlobalSSEPathSuffix = ssePathSuffix
// Create shared MCPServer once during config parsing (thread-safe with sync.RWMutex)
conf.sharedMCPServer = common.NewMCPServer(DefaultServerName, Version)
return conf, nil return conf, nil
} }
@@ -125,9 +128,6 @@ func (p *Parser) Merge(parent interface{}, child interface{}) interface{} {
if childConfig.rateLimitConfig != nil { if childConfig.rateLimitConfig != nil {
newConfig.rateLimitConfig = childConfig.rateLimitConfig newConfig.rateLimitConfig = childConfig.rateLimitConfig
} }
if childConfig.defaultServer != nil {
newConfig.defaultServer = childConfig.defaultServer
}
return &newConfig return &newConfig
} }

View File

@@ -37,6 +37,7 @@ type filter struct {
skipRequestBody bool skipRequestBody bool
skipResponseBody bool skipResponseBody bool
cachedResponseBody []byte cachedResponseBody []byte
sseServer *common.SSEServer // SSE server instance for this filter (per-request, not shared)
userLevelConfig bool userLevelConfig bool
mcpConfigHandler *handler.MCPConfigHandler mcpConfigHandler *handler.MCPConfigHandler
@@ -135,11 +136,13 @@ func (f *filter) processMcpRequestHeadersForRestUpstream(header api.RequestHeade
trimmed += "?" + rq trimmed += "?" + rq
} }
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version), // Create SSE server instance for this filter (per-request, not shared)
// MCPServer is shared (thread-safe), but SSEServer must be per-request (contains request-specific messageEndpoint)
f.sseServer = common.NewSSEServer(f.config.sharedMCPServer,
common.WithSSEEndpoint(GlobalSSEPathSuffix), common.WithSSEEndpoint(GlobalSSEPathSuffix),
common.WithMessageEndpoint(trimmed), common.WithMessageEndpoint(trimmed),
common.WithRedisClient(f.config.redisClient)) common.WithRedisClient(f.config.redisClient))
f.serverName = f.config.defaultServer.GetServerName() f.serverName = f.sseServer.GetServerName()
body := "SSE connection create" body := "SSE connection create"
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "") f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
} }
@@ -275,9 +278,9 @@ func (f *filter) encodeDataFromRestUpstream(buffer api.BufferInstance, endStream
if f.serverName != "" { if f.serverName != "" {
if f.config.redisClient != nil { if f.config.redisClient != nil {
// handle default server // handle SSE server for this filter instance
buffer.Reset() buffer.Reset()
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan) f.sseServer.HandleSSE(f.callbacks, f.stopChan)
return api.Running return api.Running
} else { } else {
_ = buffer.SetString(RedisNotEnabledResponseBody) _ = buffer.SetString(RedisNotEnabledResponseBody)