mirror of
https://github.com/alibaba/higress.git
synced 2026-03-20 02:07:27 +08:00
fix: make mcp server redis client config based (#2145)
Co-authored-by: daijingze_mac <18373118@buaa.edu.cn>
This commit is contained in:
@@ -104,7 +104,6 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
||||
|
||||
conf.servers = append(conf.servers, &SSEServerWrapper{
|
||||
BaseServer: common.NewSSEServer(serverInstance,
|
||||
common.WithRedisClient(common.GlobalRedisClient),
|
||||
common.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, mcp_session.GlobalSSEPathSuffix)),
|
||||
common.WithMessageEndpoint(serverPath)),
|
||||
DomainList: serverDomainList,
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
var GlobalRedisClient *RedisClient
|
||||
|
||||
type RedisConfig struct {
|
||||
address string
|
||||
username string
|
||||
@@ -74,9 +72,10 @@ func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
|
||||
// Ping the Redis server to check the connection
|
||||
pong, err := client.Ping(context.Background()).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
api.LogErrorf("Failed to connect to Redis: %v", err)
|
||||
} else {
|
||||
api.LogDebugf("Connected to Redis: %s", pong)
|
||||
}
|
||||
api.LogDebugf("Connected to Redis: %s", pong)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
@@ -85,7 +84,7 @@ func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
|
||||
crypto, err = NewCrypto(config.secret)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, err
|
||||
api.LogWarnf("Failed to initialize redis crypto: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,7 +104,7 @@ func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
|
||||
|
||||
// keepAlive periodically checks Redis connection and attempts to reconnect if needed
|
||||
func (r *RedisClient) keepAlive() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
|
||||
@@ -210,7 +210,7 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
|
||||
var status int
|
||||
// Only send response if there is one (not for notifications)
|
||||
if response != nil {
|
||||
if sessionID != "" && s.redisClient != nil {
|
||||
if sessionID != ""{
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
status = http.StatusAccepted
|
||||
} else {
|
||||
|
||||
@@ -25,12 +25,13 @@ type config struct {
|
||||
enableUserLevelServer bool
|
||||
rateLimitConfig *handler.MCPRatelimitConfig
|
||||
defaultServer *common.SSEServer
|
||||
redisClient *common.RedisClient
|
||||
}
|
||||
|
||||
func (c *config) Destroy() {
|
||||
if common.GlobalRedisClient != nil {
|
||||
if c.redisClient != nil {
|
||||
api.LogDebug("Closing Redis client")
|
||||
common.GlobalRedisClient.Close()
|
||||
c.redisClient.Close()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,10 +64,11 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
||||
|
||||
redisClient, err := common.NewRedisClient(redisConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err)
|
||||
api.LogErrorf("Failed to initialize Redis client: %w", err)
|
||||
} else {
|
||||
api.LogDebug("Redis client initialized")
|
||||
}
|
||||
common.GlobalRedisClient = redisClient
|
||||
api.LogDebug("Redis client initialized")
|
||||
conf.redisClient = redisClient
|
||||
} else {
|
||||
api.LogDebug("Redis configuration not provided, running without Redis")
|
||||
}
|
||||
@@ -74,7 +76,7 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
||||
enableUserLevelServer, ok := v.AsMap()["enable_user_level_server"].(bool)
|
||||
if !ok {
|
||||
enableUserLevelServer = false
|
||||
if common.GlobalRedisClient == nil {
|
||||
if conf.redisClient == nil {
|
||||
return nil, fmt.Errorf("redis configuration is not provided, enable_user_level_server is true")
|
||||
}
|
||||
}
|
||||
@@ -137,7 +139,7 @@ func FilterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.Strea
|
||||
callbacks: callbacks,
|
||||
config: conf,
|
||||
stopChan: make(chan struct{}),
|
||||
mcpConfigHandler: handler.NewMCPConfigHandler(common.GlobalRedisClient, callbacks),
|
||||
mcpRatelimitHandler: handler.NewMCPRatelimitHandler(common.GlobalRedisClient, callbacks, conf.rateLimitConfig),
|
||||
mcpConfigHandler: handler.NewMCPConfigHandler(conf.redisClient, callbacks),
|
||||
mcpRatelimitHandler: handler.NewMCPRatelimitHandler(conf.redisClient, callbacks, conf.rateLimitConfig),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,7 +108,7 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
|
||||
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
|
||||
common.WithSSEEndpoint(GlobalSSEPathSuffix),
|
||||
common.WithMessageEndpoint(strings.TrimSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix)),
|
||||
common.WithRedisClient(common.GlobalRedisClient))
|
||||
common.WithRedisClient(f.config.redisClient))
|
||||
f.serverName = f.config.defaultServer.GetServerName()
|
||||
body := "SSE connection create"
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
|
||||
@@ -165,7 +165,7 @@ func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api
|
||||
return api.Continue
|
||||
}
|
||||
if f.serverName != "" {
|
||||
if common.GlobalRedisClient != nil {
|
||||
if f.config.redisClient != nil {
|
||||
header.Set("Content-Type", "text/event-stream")
|
||||
header.Set("Cache-Control", "no-cache")
|
||||
header.Set("Connection", "keep-alive")
|
||||
@@ -188,12 +188,12 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu
|
||||
if !endStream {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
if f.proxyURL != nil && common.GlobalRedisClient != nil {
|
||||
if f.proxyURL != nil && f.config.redisClient != nil {
|
||||
sessionID := f.proxyURL.Query().Get("sessionId")
|
||||
if sessionID != "" {
|
||||
channel := common.GetSSEChannelName(sessionID)
|
||||
eventData := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String())
|
||||
publishErr := common.GlobalRedisClient.Publish(channel, eventData)
|
||||
publishErr := f.config.redisClient.Publish(channel, eventData)
|
||||
if publishErr != nil {
|
||||
api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr)
|
||||
}
|
||||
@@ -201,7 +201,7 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu
|
||||
}
|
||||
|
||||
if f.serverName != "" {
|
||||
if common.GlobalRedisClient != nil {
|
||||
if f.config.redisClient != nil {
|
||||
// handle default server
|
||||
buffer.Reset()
|
||||
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
|
||||
|
||||
Reference in New Issue
Block a user