fix: make mcp server redis client config based (#2145)

Co-authored-by: daijingze_mac <18373118@buaa.edu.cn>
This commit is contained in:
Jingze
2025-04-29 14:27:48 +08:00
committed by GitHub
parent 806563298b
commit ab73f21017
5 changed files with 21 additions and 21 deletions

View File

@@ -104,7 +104,6 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
conf.servers = append(conf.servers, &SSEServerWrapper{ conf.servers = append(conf.servers, &SSEServerWrapper{
BaseServer: common.NewSSEServer(serverInstance, BaseServer: common.NewSSEServer(serverInstance,
common.WithRedisClient(common.GlobalRedisClient),
common.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, mcp_session.GlobalSSEPathSuffix)), common.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, mcp_session.GlobalSSEPathSuffix)),
common.WithMessageEndpoint(serverPath)), common.WithMessageEndpoint(serverPath)),
DomainList: serverDomainList, DomainList: serverDomainList,

View File

@@ -9,8 +9,6 @@ import (
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
) )
var GlobalRedisClient *RedisClient
type RedisConfig struct { type RedisConfig struct {
address string address string
username string username string
@@ -74,9 +72,10 @@ func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
// Ping the Redis server to check the connection // Ping the Redis server to check the connection
pong, err := client.Ping(context.Background()).Result() pong, err := client.Ping(context.Background()).Result()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err) api.LogErrorf("Failed to connect to Redis: %v", err)
} else {
api.LogDebugf("Connected to Redis: %s", pong)
} }
api.LogDebugf("Connected to Redis: %s", pong)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@@ -85,7 +84,7 @@ func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
crypto, err = NewCrypto(config.secret) crypto, err = NewCrypto(config.secret)
if err != nil { if err != nil {
cancel() cancel()
return nil, err api.LogWarnf("Failed to initialize redis crypto: %v", err)
} }
} }
@@ -105,7 +104,7 @@ func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
// keepAlive periodically checks Redis connection and attempts to reconnect if needed // keepAlive periodically checks Redis connection and attempts to reconnect if needed
func (r *RedisClient) keepAlive() { func (r *RedisClient) keepAlive() {
ticker := time.NewTicker(30 * time.Second) ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop() defer ticker.Stop()
for { for {

View File

@@ -210,7 +210,7 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
var status int var status int
// Only send response if there is one (not for notifications) // Only send response if there is one (not for notifications)
if response != nil { if response != nil {
if sessionID != "" && s.redisClient != nil { if sessionID != ""{
w.WriteHeader(http.StatusAccepted) w.WriteHeader(http.StatusAccepted)
status = http.StatusAccepted status = http.StatusAccepted
} else { } else {

View File

@@ -25,12 +25,13 @@ type config struct {
enableUserLevelServer bool enableUserLevelServer bool
rateLimitConfig *handler.MCPRatelimitConfig rateLimitConfig *handler.MCPRatelimitConfig
defaultServer *common.SSEServer defaultServer *common.SSEServer
redisClient *common.RedisClient
} }
func (c *config) Destroy() { func (c *config) Destroy() {
if common.GlobalRedisClient != nil { if c.redisClient != nil {
api.LogDebug("Closing Redis client") api.LogDebug("Closing Redis client")
common.GlobalRedisClient.Close() c.redisClient.Close()
} }
} }
@@ -63,10 +64,11 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
redisClient, err := common.NewRedisClient(redisConfig) redisClient, err := common.NewRedisClient(redisConfig)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err) api.LogErrorf("Failed to initialize Redis client: %w", err)
} else {
api.LogDebug("Redis client initialized")
} }
common.GlobalRedisClient = redisClient conf.redisClient = redisClient
api.LogDebug("Redis client initialized")
} else { } else {
api.LogDebug("Redis configuration not provided, running without Redis") api.LogDebug("Redis configuration not provided, running without Redis")
} }
@@ -74,7 +76,7 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
enableUserLevelServer, ok := v.AsMap()["enable_user_level_server"].(bool) enableUserLevelServer, ok := v.AsMap()["enable_user_level_server"].(bool)
if !ok { if !ok {
enableUserLevelServer = false enableUserLevelServer = false
if common.GlobalRedisClient == nil { if conf.redisClient == nil {
return nil, fmt.Errorf("redis configuration is not provided, enable_user_level_server is true") return nil, fmt.Errorf("redis configuration is not provided, enable_user_level_server is true")
} }
} }
@@ -137,7 +139,7 @@ func FilterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.Strea
callbacks: callbacks, callbacks: callbacks,
config: conf, config: conf,
stopChan: make(chan struct{}), stopChan: make(chan struct{}),
mcpConfigHandler: handler.NewMCPConfigHandler(common.GlobalRedisClient, callbacks), mcpConfigHandler: handler.NewMCPConfigHandler(conf.redisClient, callbacks),
mcpRatelimitHandler: handler.NewMCPRatelimitHandler(common.GlobalRedisClient, callbacks, conf.rateLimitConfig), mcpRatelimitHandler: handler.NewMCPRatelimitHandler(conf.redisClient, callbacks, conf.rateLimitConfig),
} }
} }

View File

@@ -108,7 +108,7 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version), f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
common.WithSSEEndpoint(GlobalSSEPathSuffix), common.WithSSEEndpoint(GlobalSSEPathSuffix),
common.WithMessageEndpoint(strings.TrimSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix)), common.WithMessageEndpoint(strings.TrimSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix)),
common.WithRedisClient(common.GlobalRedisClient)) common.WithRedisClient(f.config.redisClient))
f.serverName = f.config.defaultServer.GetServerName() f.serverName = f.config.defaultServer.GetServerName()
body := "SSE connection create" body := "SSE connection create"
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "") f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
@@ -165,7 +165,7 @@ func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api
return api.Continue return api.Continue
} }
if f.serverName != "" { if f.serverName != "" {
if common.GlobalRedisClient != nil { if f.config.redisClient != nil {
header.Set("Content-Type", "text/event-stream") header.Set("Content-Type", "text/event-stream")
header.Set("Cache-Control", "no-cache") header.Set("Cache-Control", "no-cache")
header.Set("Connection", "keep-alive") header.Set("Connection", "keep-alive")
@@ -188,12 +188,12 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu
if !endStream { if !endStream {
return api.StopAndBuffer return api.StopAndBuffer
} }
if f.proxyURL != nil && common.GlobalRedisClient != nil { if f.proxyURL != nil && f.config.redisClient != nil {
sessionID := f.proxyURL.Query().Get("sessionId") sessionID := f.proxyURL.Query().Get("sessionId")
if sessionID != "" { if sessionID != "" {
channel := common.GetSSEChannelName(sessionID) channel := common.GetSSEChannelName(sessionID)
eventData := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String()) eventData := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String())
publishErr := common.GlobalRedisClient.Publish(channel, eventData) publishErr := f.config.redisClient.Publish(channel, eventData)
if publishErr != nil { if publishErr != nil {
api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr) api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr)
} }
@@ -201,7 +201,7 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu
} }
if f.serverName != "" { if f.serverName != "" {
if common.GlobalRedisClient != nil { if f.config.redisClient != nil {
// handle default server // handle default server
buffer.Reset() buffer.Reset()
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan) f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)