From ab73f21017294c9af4a20d83a7006eb67671b189 Mon Sep 17 00:00:00 2001 From: Jingze <52855280+Jing-ze@users.noreply.github.com> Date: Tue, 29 Apr 2025 14:27:48 +0800 Subject: [PATCH] fix: make mcp server redis client config based (#2145) Co-authored-by: daijingze_mac <18373118@buaa.edu.cn> --- plugins/golang-filter/mcp-server/config.go | 1 - .../golang-filter/mcp-session/common/redis.go | 11 +++++------ .../golang-filter/mcp-session/common/sse.go | 2 +- plugins/golang-filter/mcp-session/config.go | 18 ++++++++++-------- plugins/golang-filter/mcp-session/filter.go | 10 +++++----- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/plugins/golang-filter/mcp-server/config.go b/plugins/golang-filter/mcp-server/config.go index 7f0400c90..5be134191 100644 --- a/plugins/golang-filter/mcp-server/config.go +++ b/plugins/golang-filter/mcp-server/config.go @@ -104,7 +104,6 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int conf.servers = append(conf.servers, &SSEServerWrapper{ BaseServer: common.NewSSEServer(serverInstance, - common.WithRedisClient(common.GlobalRedisClient), common.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, mcp_session.GlobalSSEPathSuffix)), common.WithMessageEndpoint(serverPath)), DomainList: serverDomainList, diff --git a/plugins/golang-filter/mcp-session/common/redis.go b/plugins/golang-filter/mcp-session/common/redis.go index 4bb743b08..777efe6da 100644 --- a/plugins/golang-filter/mcp-session/common/redis.go +++ b/plugins/golang-filter/mcp-session/common/redis.go @@ -9,8 +9,6 @@ import ( "github.com/go-redis/redis/v8" ) -var GlobalRedisClient *RedisClient - type RedisConfig struct { address string username string @@ -74,9 +72,10 @@ func NewRedisClient(config *RedisConfig) (*RedisClient, error) { // Ping the Redis server to check the connection pong, err := client.Ping(context.Background()).Result() if err != nil { - return nil, fmt.Errorf("failed to connect to Redis: %w", err) + api.LogErrorf("Failed to connect to Redis: %v", err) + } else { + api.LogDebugf("Connected to Redis: %s", pong) } - api.LogDebugf("Connected to Redis: %s", pong) ctx, cancel := context.WithCancel(context.Background()) @@ -85,7 +84,7 @@ func NewRedisClient(config *RedisConfig) (*RedisClient, error) { crypto, err = NewCrypto(config.secret) if err != nil { cancel() - return nil, err + api.LogWarnf("Failed to initialize redis crypto: %v", err) } } @@ -105,7 +104,7 @@ func NewRedisClient(config *RedisConfig) (*RedisClient, error) { // keepAlive periodically checks Redis connection and attempts to reconnect if needed func (r *RedisClient) keepAlive() { - ticker := time.NewTicker(30 * time.Second) + ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() for { diff --git a/plugins/golang-filter/mcp-session/common/sse.go b/plugins/golang-filter/mcp-session/common/sse.go index f1f2b7c06..11fee2458 100644 --- a/plugins/golang-filter/mcp-session/common/sse.go +++ b/plugins/golang-filter/mcp-session/common/sse.go @@ -210,7 +210,7 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j var status int // Only send response if there is one (not for notifications) if response != nil { - if sessionID != "" && s.redisClient != nil { + if sessionID != ""{ w.WriteHeader(http.StatusAccepted) status = http.StatusAccepted } else { diff --git a/plugins/golang-filter/mcp-session/config.go b/plugins/golang-filter/mcp-session/config.go index b82798038..6bb9fcf30 100644 --- a/plugins/golang-filter/mcp-session/config.go +++ b/plugins/golang-filter/mcp-session/config.go @@ -25,12 +25,13 @@ type config struct { enableUserLevelServer bool rateLimitConfig *handler.MCPRatelimitConfig defaultServer *common.SSEServer + redisClient *common.RedisClient } func (c *config) Destroy() { - if common.GlobalRedisClient != nil { + if c.redisClient != nil { api.LogDebug("Closing Redis client") - common.GlobalRedisClient.Close() + c.redisClient.Close() } } @@ -63,10 +64,11 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int redisClient, err := common.NewRedisClient(redisConfig) if err != nil { - return nil, fmt.Errorf("failed to initialize RedisClient: %w", err) + api.LogErrorf("Failed to initialize Redis client: %w", err) + } else { + api.LogDebug("Redis client initialized") } - common.GlobalRedisClient = redisClient - api.LogDebug("Redis client initialized") + conf.redisClient = redisClient } else { api.LogDebug("Redis configuration not provided, running without Redis") } @@ -74,7 +76,7 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int enableUserLevelServer, ok := v.AsMap()["enable_user_level_server"].(bool) if !ok { enableUserLevelServer = false - if common.GlobalRedisClient == nil { + if conf.redisClient == nil { return nil, fmt.Errorf("redis configuration is not provided, enable_user_level_server is true") } } @@ -137,7 +139,7 @@ func FilterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.Strea callbacks: callbacks, config: conf, stopChan: make(chan struct{}), - mcpConfigHandler: handler.NewMCPConfigHandler(common.GlobalRedisClient, callbacks), - mcpRatelimitHandler: handler.NewMCPRatelimitHandler(common.GlobalRedisClient, callbacks, conf.rateLimitConfig), + mcpConfigHandler: handler.NewMCPConfigHandler(conf.redisClient, callbacks), + mcpRatelimitHandler: handler.NewMCPRatelimitHandler(conf.redisClient, callbacks, conf.rateLimitConfig), } } diff --git a/plugins/golang-filter/mcp-session/filter.go b/plugins/golang-filter/mcp-session/filter.go index d22498ab6..acc812539 100644 --- a/plugins/golang-filter/mcp-session/filter.go +++ b/plugins/golang-filter/mcp-session/filter.go @@ -108,7 +108,7 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api. f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version), common.WithSSEEndpoint(GlobalSSEPathSuffix), common.WithMessageEndpoint(strings.TrimSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix)), - common.WithRedisClient(common.GlobalRedisClient)) + common.WithRedisClient(f.config.redisClient)) f.serverName = f.config.defaultServer.GetServerName() body := "SSE connection create" f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "") @@ -165,7 +165,7 @@ func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api return api.Continue } if f.serverName != "" { - if common.GlobalRedisClient != nil { + if f.config.redisClient != nil { header.Set("Content-Type", "text/event-stream") header.Set("Cache-Control", "no-cache") header.Set("Connection", "keep-alive") @@ -188,12 +188,12 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu if !endStream { return api.StopAndBuffer } - if f.proxyURL != nil && common.GlobalRedisClient != nil { + if f.proxyURL != nil && f.config.redisClient != nil { sessionID := f.proxyURL.Query().Get("sessionId") if sessionID != "" { channel := common.GetSSEChannelName(sessionID) eventData := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String()) - publishErr := common.GlobalRedisClient.Publish(channel, eventData) + publishErr := f.config.redisClient.Publish(channel, eventData) if publishErr != nil { api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr) } @@ -201,7 +201,7 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu } if f.serverName != "" { - if common.GlobalRedisClient != nil { + if f.config.redisClient != nil { // handle default server buffer.Reset() f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)