fix: make mcp server redis client config based (#2145)

Co-authored-by: daijingze_mac <18373118@buaa.edu.cn>
This commit is contained in:
Jingze
2025-04-29 14:27:48 +08:00
committed by GitHub
parent 806563298b
commit ab73f21017
5 changed files with 21 additions and 21 deletions

View File

@@ -104,7 +104,6 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
conf.servers = append(conf.servers, &SSEServerWrapper{
BaseServer: common.NewSSEServer(serverInstance,
common.WithRedisClient(common.GlobalRedisClient),
common.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, mcp_session.GlobalSSEPathSuffix)),
common.WithMessageEndpoint(serverPath)),
DomainList: serverDomainList,

View File

@@ -9,8 +9,6 @@ import (
"github.com/go-redis/redis/v8"
)
var GlobalRedisClient *RedisClient
type RedisConfig struct {
address string
username string
@@ -74,9 +72,10 @@ func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
// Ping the Redis server to check the connection
pong, err := client.Ping(context.Background()).Result()
if err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
api.LogErrorf("Failed to connect to Redis: %v", err)
} else {
api.LogDebugf("Connected to Redis: %s", pong)
}
api.LogDebugf("Connected to Redis: %s", pong)
ctx, cancel := context.WithCancel(context.Background())
@@ -85,7 +84,7 @@ func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
crypto, err = NewCrypto(config.secret)
if err != nil {
cancel()
return nil, err
api.LogWarnf("Failed to initialize redis crypto: %v", err)
}
}
@@ -105,7 +104,7 @@ func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
// keepAlive periodically checks Redis connection and attempts to reconnect if needed
func (r *RedisClient) keepAlive() {
ticker := time.NewTicker(30 * time.Second)
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {

View File

@@ -210,7 +210,7 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
var status int
// Only send response if there is one (not for notifications)
if response != nil {
if sessionID != "" && s.redisClient != nil {
if sessionID != ""{
w.WriteHeader(http.StatusAccepted)
status = http.StatusAccepted
} else {

View File

@@ -25,12 +25,13 @@ type config struct {
enableUserLevelServer bool
rateLimitConfig *handler.MCPRatelimitConfig
defaultServer *common.SSEServer
redisClient *common.RedisClient
}
func (c *config) Destroy() {
if common.GlobalRedisClient != nil {
if c.redisClient != nil {
api.LogDebug("Closing Redis client")
common.GlobalRedisClient.Close()
c.redisClient.Close()
}
}
@@ -63,10 +64,11 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
redisClient, err := common.NewRedisClient(redisConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err)
api.LogErrorf("Failed to initialize Redis client: %w", err)
} else {
api.LogDebug("Redis client initialized")
}
common.GlobalRedisClient = redisClient
api.LogDebug("Redis client initialized")
conf.redisClient = redisClient
} else {
api.LogDebug("Redis configuration not provided, running without Redis")
}
@@ -74,7 +76,7 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
enableUserLevelServer, ok := v.AsMap()["enable_user_level_server"].(bool)
if !ok {
enableUserLevelServer = false
if common.GlobalRedisClient == nil {
if conf.redisClient == nil {
return nil, fmt.Errorf("redis configuration is not provided, enable_user_level_server is true")
}
}
@@ -137,7 +139,7 @@ func FilterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.Strea
callbacks: callbacks,
config: conf,
stopChan: make(chan struct{}),
mcpConfigHandler: handler.NewMCPConfigHandler(common.GlobalRedisClient, callbacks),
mcpRatelimitHandler: handler.NewMCPRatelimitHandler(common.GlobalRedisClient, callbacks, conf.rateLimitConfig),
mcpConfigHandler: handler.NewMCPConfigHandler(conf.redisClient, callbacks),
mcpRatelimitHandler: handler.NewMCPRatelimitHandler(conf.redisClient, callbacks, conf.rateLimitConfig),
}
}

View File

@@ -108,7 +108,7 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
common.WithSSEEndpoint(GlobalSSEPathSuffix),
common.WithMessageEndpoint(strings.TrimSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix)),
common.WithRedisClient(common.GlobalRedisClient))
common.WithRedisClient(f.config.redisClient))
f.serverName = f.config.defaultServer.GetServerName()
body := "SSE connection create"
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
@@ -165,7 +165,7 @@ func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api
return api.Continue
}
if f.serverName != "" {
if common.GlobalRedisClient != nil {
if f.config.redisClient != nil {
header.Set("Content-Type", "text/event-stream")
header.Set("Cache-Control", "no-cache")
header.Set("Connection", "keep-alive")
@@ -188,12 +188,12 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu
if !endStream {
return api.StopAndBuffer
}
if f.proxyURL != nil && common.GlobalRedisClient != nil {
if f.proxyURL != nil && f.config.redisClient != nil {
sessionID := f.proxyURL.Query().Get("sessionId")
if sessionID != "" {
channel := common.GetSSEChannelName(sessionID)
eventData := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String())
publishErr := common.GlobalRedisClient.Publish(channel, eventData)
publishErr := f.config.redisClient.Publish(channel, eventData)
if publishErr != nil {
api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr)
}
@@ -201,7 +201,7 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu
}
if f.serverName != "" {
if common.GlobalRedisClient != nil {
if f.config.redisClient != nil {
// handle default server
buffer.Reset()
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)