From 1834d4acefeb7f4114bb4aa944b9c60a29707907 Mon Sep 17 00:00:00 2001 From: Jingze <52855280+Jing-ze@users.noreply.github.com> Date: Fri, 18 Apr 2025 11:19:56 +0800 Subject: [PATCH] fix: support mcp server database reconnect and fix tool/list method denied (#2074) --- plugins/golang-filter/mcp-server/config.go | 6 + plugins/golang-filter/mcp-server/filter.go | 110 +++++++----- .../mcp-server/handler/config_handler.go | 23 ++- .../mcp-server/handler/rate_limit_handler.go | 58 ++++++- .../mcp-server/internal/server.go | 10 ++ .../golang-filter/mcp-server/internal/sse.go | 19 ++- .../mcp-server/servers/gorm/db.go | 161 +++++++++++++++--- .../mcp-server/servers/gorm/server.go | 17 +- 8 files changed, 322 insertions(+), 82 deletions(-) diff --git a/plugins/golang-filter/mcp-server/config.go b/plugins/golang-filter/mcp-server/config.go index 1c846adac..cc5ad40ba 100644 --- a/plugins/golang-filter/mcp-server/config.go +++ b/plugins/golang-filter/mcp-server/config.go @@ -49,6 +49,9 @@ func (c *config) Destroy() { api.LogDebug("Closing Redis client") c.redisClient.Close() } + for _, server := range c.servers { + server.Close() + } } type parser struct { @@ -127,6 +130,9 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int } } } + if errorText, ok := rateLimit["error_text"].(string); ok { + rateLimitConfig.ErrorText = errorText + } conf.rateLimitConfig = rateLimitConfig } diff --git a/plugins/golang-filter/mcp-server/filter.go b/plugins/golang-filter/mcp-server/filter.go index abe54d730..5fb5ea010 100644 --- a/plugins/golang-filter/mcp-server/filter.go +++ b/plugins/golang-filter/mcp-server/filter.go @@ -1,6 +1,7 @@ package main import ( + "encoding/json" "fmt" "net/http" "net/http/httptest" @@ -11,6 +12,7 @@ import ( "github.com/alibaba/higress/plugins/golang-filter/mcp-server/handler" "github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal" "github.com/envoyproxy/envoy/contrib/golang/common/go/api" + "github.com/mark3labs/mcp-go/mcp" ) const ( @@ -35,6 +37,7 @@ type filter struct { userLevelConfig bool mcpConfigHandler *handler.MCPConfigHandler + ratelimit bool mcpRatelimitHandler *handler.MCPRatelimitHandler } @@ -110,6 +113,11 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api. } } + f.req = &http.Request{ + Method: url.method, + URL: url.parsedURL, + } + if strings.HasSuffix(f.path, ConfigPathSuffix) && f.config.enableUserLevelServer { if !url.internalIP { api.LogWarnf("Access denied: non-internal IP address %s", url.parsedURL.String()) @@ -118,13 +126,9 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api. } if strings.HasSuffix(f.path, ConfigPathSuffix) && url.method == http.MethodGet { api.LogDebugf("Handling config request: %s", f.path) - f.mcpConfigHandler.HandleConfigRequest(f.path, url.method, []byte{}) + f.mcpConfigHandler.HandleConfigRequest(f.req, []byte{}) return api.LocalReply } - f.req = &http.Request{ - Method: url.method, - URL: url.parsedURL, - } f.userLevelConfig = true if endStream { return api.Continue @@ -137,30 +141,23 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api. f.proxyURL = url.parsedURL if f.config.enableUserLevelServer { parts := strings.Split(url.parsedURL.Path, "/") - if len(parts) < 3 { - api.LogDebugf("Access denied: missing uid in path %s", url.parsedURL.Path) - f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "Access denied: missing uid", nil, 0, "") - return api.LocalReply - } - serverName := parts[1] - uid := parts[2] - // Get encoded config - encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid) - if err != nil { - api.LogWarnf("Access denied: no valid config found for uid %s", uid) - f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "") - return api.LocalReply - } else if encodedConfig != "" { - header.Set("x-higress-mcpserver-config", encodedConfig) - api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid) - } else { - api.LogDebugf("Empty config found for %s:%s", serverName, uid) - if !f.mcpRatelimitHandler.HandleRatelimit(url.parsedURL.Path, url.method, []byte{}) { - return api.LocalReply + if len(parts) >= 3 { + serverName := parts[1] + uid := parts[2] + // Get encoded config + encodedConfig, _ := f.mcpConfigHandler.GetEncodedConfig(serverName, uid) + if encodedConfig != "" { + header.Set("x-higress-mcpserver-config", encodedConfig) + api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid) } } + f.ratelimit = true + } + if endStream { + return api.Continue + } else { + return api.StopAndBuffer } - return api.Continue } if url.method != http.MethodGet { @@ -183,26 +180,50 @@ func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.Statu if f.skip { return api.Continue } + if !endStream { + return api.StopAndBuffer + } if f.message { - if endStream { - for _, server := range f.config.servers { - if f.path == server.GetMessageEndpoint() { - // Create a response recorder to capture the response - recorder := httptest.NewRecorder() - // Call the handleMessage method of SSEServer with complete body - server.HandleMessage(recorder, f.req, buffer.Bytes()) - f.message = false - f.callbacks.DecoderFilterCallbacks().SendLocalReply(recorder.Code, recorder.Body.String(), recorder.Header(), 0, "") - return api.LocalReply - } + for _, server := range f.config.servers { + if f.path == server.GetMessageEndpoint() { + // Create a response recorder to capture the response + recorder := httptest.NewRecorder() + // Call the handleMessage method of SSEServer with complete body + httpStatus := server.HandleMessage(recorder, f.req, buffer.Bytes()) + f.message = false + f.callbacks.DecoderFilterCallbacks().SendLocalReply(httpStatus, recorder.Body.String(), recorder.Header(), 0, "") + return api.LocalReply } } - return api.StopAndBuffer } else if f.userLevelConfig { // Handle config POST request api.LogDebugf("Handling config request: %s", f.path) - f.mcpConfigHandler.HandleConfigRequest(f.path, f.req.Method, buffer.Bytes()) + f.mcpConfigHandler.HandleConfigRequest(f.req, buffer.Bytes()) return api.LocalReply + } else if f.ratelimit { + if checkJSONRPCMethod(buffer.Bytes(), "tools/list") { + api.LogDebugf("Not a tools call request, skipping ratelimit") + return api.Continue + } + parts := strings.Split(f.req.URL.Path, "/") + if len(parts) < 3 { + api.LogWarnf("Access denied: no valid uid found") + f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "") + return api.LocalReply + } + serverName := parts[1] + uid := parts[2] + encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid) + if err != nil { + api.LogWarnf("Access denied: no valid config found for uid %s", uid) + f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "") + return api.LocalReply + } else if encodedConfig == "" && checkJSONRPCMethod(buffer.Bytes(), "tools/call") { + api.LogDebugf("Empty config found for %s:%s", serverName, uid) + if !f.mcpRatelimitHandler.HandleRatelimit(f.req, buffer.Bytes()) { + return api.LocalReply + } + } } return api.Continue } @@ -287,3 +308,14 @@ func (f *filter) OnDestroy(reason api.DestroyReason) { } } } + +// check if the request is a tools/call request +func checkJSONRPCMethod(body []byte, method string) bool { + var request mcp.CallToolRequest + if err := json.Unmarshal(body, &request); err != nil { + api.LogWarnf("Failed to unmarshal request body: %v, not a JSON RPC request", err) + return true + } + + return request.Method == method +} diff --git a/plugins/golang-filter/mcp-server/handler/config_handler.go b/plugins/golang-filter/mcp-server/handler/config_handler.go index ceff4fc70..cf88260d7 100644 --- a/plugins/golang-filter/mcp-server/handler/config_handler.go +++ b/plugins/golang-filter/mcp-server/handler/config_handler.go @@ -26,14 +26,14 @@ func NewMCPConfigHandler(redisClient *internal.RedisClient, callbacks api.Filter } // HandleConfigRequest processes configuration requests -func (h *MCPConfigHandler) HandleConfigRequest(path string, method string, body []byte) bool { +func (h *MCPConfigHandler) HandleConfigRequest(req *http.Request, body []byte) bool { // Check if it's a configuration request - if !strings.HasSuffix(path, "/config") { + if !strings.HasSuffix(req.URL.Path, "/config") { return false } // Extract serverName and uid from path - pathParts := strings.Split(strings.TrimSuffix(path, "/config"), "/") + pathParts := strings.Split(strings.TrimSuffix(req.URL.Path, "/config"), "/") if len(pathParts) < 2 { h.sendErrorResponse(http.StatusBadRequest, "INVALID_PATH", "Invalid path format") return true @@ -41,7 +41,7 @@ func (h *MCPConfigHandler) HandleConfigRequest(path string, method string, body uid := pathParts[len(pathParts)-1] serverName := pathParts[len(pathParts)-2] - switch method { + switch req.Method { case http.MethodGet: return h.handleGetConfig(serverName, uid) case http.MethodPost: @@ -70,10 +70,13 @@ func (h *MCPConfigHandler) handleGetConfig(serverName string, uid string) bool { } responseBytes, _ := json.Marshal(response) + headers := map[string][]string{ + "Content-Type": {"application/json"}, + } h.callbacks.DecoderFilterCallbacks().SendLocalReply( http.StatusOK, string(responseBytes), - nil, 0, "", + headers, 0, "", ) return true } @@ -103,10 +106,13 @@ func (h *MCPConfigHandler) handleStoreConfig(serverName string, uid string, body } responseBytes, _ := json.Marshal(response) + headers := map[string][]string{ + "Content-Type": {"application/json"}, + } h.callbacks.DecoderFilterCallbacks().SendLocalReply( http.StatusOK, string(responseBytes), - nil, 0, "", + headers, 0, "", ) return true } @@ -124,10 +130,13 @@ func (h *MCPConfigHandler) sendErrorResponse(status int, code string, message st }, } responseBytes, _ := json.Marshal(response) + headers := map[string][]string{ + "Content-Type": {"application/json"}, + } h.callbacks.DecoderFilterCallbacks().SendLocalReply( status, string(responseBytes), - nil, 0, "", + headers, 0, "", ) } diff --git a/plugins/golang-filter/mcp-server/handler/rate_limit_handler.go b/plugins/golang-filter/mcp-server/handler/rate_limit_handler.go index e08dbc3d6..0c583aea9 100644 --- a/plugins/golang-filter/mcp-server/handler/rate_limit_handler.go +++ b/plugins/golang-filter/mcp-server/handler/rate_limit_handler.go @@ -1,6 +1,7 @@ package handler import ( + "encoding/json" "fmt" "net/http" "strconv" @@ -9,6 +10,7 @@ import ( "github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal" "github.com/envoyproxy/envoy/contrib/golang/common/go/api" + "github.com/mark3labs/mcp-go/mcp" ) type MCPRatelimitHandler struct { @@ -17,6 +19,7 @@ type MCPRatelimitHandler struct { limit int // Maximum requests allowed per window window int // Time window in seconds whitelist []string // Whitelist of UIDs that bypass rate limiting + errorText string // Error text to be displayed } // MCPRatelimitConfig is the configuration for the rate limit handler @@ -24,6 +27,7 @@ type MCPRatelimitConfig struct { Limit int `json:"limit"` Window int `json:"window"` Whitelist []string `json:"white_list"` // List of UIDs that bypass rate limiting + ErrorText string `json:"error_text"` // Error text to be displayed } // NewMCPRatelimitHandler creates a new rate limit handler @@ -33,6 +37,7 @@ func NewMCPRatelimitHandler(redisClient *internal.RedisClient, callbacks api.Fil Limit: 100, Window: int(24 * time.Hour / time.Second), // 24 hours in seconds Whitelist: []string{}, + ErrorText: "API rate limit exceeded", } } return &MCPRatelimitHandler{ @@ -41,6 +46,7 @@ func NewMCPRatelimitHandler(redisClient *internal.RedisClient, callbacks api.Fil limit: conf.Limit, window: conf.Window, whitelist: conf.Whitelist, + errorText: conf.ErrorText, } } @@ -62,8 +68,9 @@ type LimitContext struct { Reset int // Time until reset in seconds } -func (h *MCPRatelimitHandler) HandleRatelimit(path string, method string, body []byte) bool { - parts := strings.Split(path, "/") +// TODO: needs to be refactored, rate limit should be registered as a request hook in MCP server +func (h *MCPRatelimitHandler) HandleRatelimit(req *http.Request, body []byte) bool { + parts := strings.Split(req.URL.Path, "/") if len(parts) < 3 { h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "") return false @@ -106,13 +113,58 @@ func (h *MCPRatelimitHandler) HandleRatelimit(path string, method string, body [ } if context.Remaining < 0 { - h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusTooManyRequests, "", nil, 0, "") + // Create error response content + errorContent := []mcp.TextContent{ + { + Type: "text", + Text: h.errorText, + }, + } + // Create response result + result := map[string]interface{}{ + "content": errorContent, + "isError": true, + } + // Create JSON-RPC response + id := getJSONPRCID(body) + response := mcp.JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Result: result, + } + // Convert response to JSON + jsonResponse, err := json.Marshal(response) + if err != nil { + api.LogErrorf("Failed to marshal JSON response: %v", err) + h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusInternalServerError, "", nil, 0, "") + return false + } + // Send JSON-RPC response + sessionID := req.URL.Query().Get("sessionId") + if sessionID != "" { + h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusAccepted, string(jsonResponse), nil, 0, "") + } else { + h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, string(jsonResponse), nil, 0, "") + } return false } return true } +func getJSONPRCID(body []byte) mcp.RequestId { + baseMessage := struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + ID interface{} `json:"id,omitempty"` + }{} + if err := json.Unmarshal(body, &baseMessage); err != nil { + api.LogWarnf("Failed to unmarshal request body: %v, not a JSON RPC request", err) + return "" + } + return baseMessage.ID +} + // parseRedisValue converts the value from Redis to an int func parseRedisValue(value interface{}) int { switch v := value.(type) { diff --git a/plugins/golang-filter/mcp-server/internal/server.go b/plugins/golang-filter/mcp-server/internal/server.go index 70940c663..6745d1b2a 100644 --- a/plugins/golang-filter/mcp-server/internal/server.go +++ b/plugins/golang-filter/mcp-server/internal/server.go @@ -78,6 +78,7 @@ type MCPServer struct { clientMu sync.Mutex // Separate mutex for client context currentClient NotificationContext initialized atomic.Bool // Use atomic for the initialized flag + destory chan struct{} } // serverKey is the context key for storing the server instance @@ -226,6 +227,7 @@ func NewMCPServer( prompts: nil, logging: false, }, + destory: make(chan struct{}), } for _, opt := range opts { @@ -826,6 +828,14 @@ func (s *MCPServer) handleNotification( return nil } +func (s *MCPServer) Close() { + close(s.destory) +} + +func (s *MCPServer) GetDestoryChannel() chan struct{} { + return s.destory +} + func createResponse(id interface{}, result interface{}) mcp.JSONRPCMessage { return mcp.JSONRPCResponse{ JSONRPC: mcp.JSONRPC_VERSION, diff --git a/plugins/golang-filter/mcp-server/internal/sse.go b/plugins/golang-filter/mcp-server/internal/sse.go index 770a70a03..56ede23b0 100644 --- a/plugins/golang-filter/mcp-server/internal/sse.go +++ b/plugins/golang-filter/mcp-server/internal/sse.go @@ -179,10 +179,10 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct // handleMessage processes incoming JSON-RPC messages from clients and sends responses // back through both the SSE connection and HTTP response. -func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body json.RawMessage) { +func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body json.RawMessage) int { if r.Method != http.MethodPost { s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, fmt.Sprintf("Method %s not allowed", r.Method)) - return + return http.StatusBadRequest } sessionID := r.URL.Query().Get("sessionId") @@ -207,7 +207,7 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j // Process message through MCPServer response := s.server.HandleMessage(ctx, body) - + var status int // Only send response if there is one (not for notifications) if response != nil { eventData, _ := json.Marshal(response) @@ -219,15 +219,22 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j if publishErr != nil { api.LogErrorf("Failed to publish message to Redis: %v", publishErr) } + w.WriteHeader(http.StatusAccepted) + status = http.StatusAccepted + } else { + // support streamable http + w.WriteHeader(http.StatusOK) + status = http.StatusOK } // Send HTTP response w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusAccepted) json.NewEncoder(w).Encode(response) } else { // For notifications, just send 202 Accepted with no body w.WriteHeader(http.StatusAccepted) + status = http.StatusAccepted } + return status } // writeJSONRPCError writes a JSON-RPC error response with the given error details. @@ -242,3 +249,7 @@ func (s *SSEServer) writeJSONRPCError( w.WriteHeader(http.StatusBadRequest) json.NewEncoder(w).Encode(response) } + +func (s *SSEServer) Close() { + s.server.Close() +} diff --git a/plugins/golang-filter/mcp-server/servers/gorm/db.go b/plugins/golang-filter/mcp-server/servers/gorm/db.go index a5795c869..22547e789 100644 --- a/plugins/golang-filter/mcp-server/servers/gorm/db.go +++ b/plugins/golang-filter/mcp-server/servers/gorm/db.go @@ -1,47 +1,148 @@ package gorm import ( + "context" "fmt" + "sync/atomic" + "time" + "github.com/envoyproxy/envoy/contrib/golang/common/go/api" "gorm.io/driver/clickhouse" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" + "gorm.io/gorm/logger" ) -// DBClient is a struct to handle PostgreSQL connections and operations +// DBClient is a struct to handle database connections and operations type DBClient struct { - db *gorm.DB + db *gorm.DB + dsn string + dbType string + reconnect chan struct{} + stop chan struct{} + panicCount int32 // Add panic counter } -// NewDBClient creates a new DBClient instance and establishes a connection to the PostgreSQL database -func NewDBClient(dsn string, dbType string) (*DBClient, error) { - var db *gorm.DB - var err error - if dbType == "postgres" { - db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}) - } else if dbType == "clickhouse" { - db, err = gorm.Open(clickhouse.Open(dsn), &gorm.Config{}) - } else if dbType == "mysql" { - db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}) - } else if dbType == "sqlite" { - db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{}) - } else { - return nil, fmt.Errorf("unsupported database type %s", dbType) - } - // Connect to the database - if err != nil { - return nil, fmt.Errorf("failed to connect to database: %w", err) +// NewDBClient creates a new DBClient instance and establishes a connection to the database +func NewDBClient(dsn string, dbType string, stop chan struct{}) *DBClient { + client := &DBClient{ + dsn: dsn, + dbType: dbType, + reconnect: make(chan struct{}, 1), + stop: stop, } - return &DBClient{db: db}, nil + // Start reconnection goroutine + go client.reconnectLoop() + + // Try initial connection + if err := client.connect(); err != nil { + api.LogErrorf("Initial database connection failed: %v", err) + } + + return client +} + +func (c *DBClient) connect() error { + var db *gorm.DB + var err error + gormConfig := gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + } + + switch c.dbType { + case "postgres": + db, err = gorm.Open(postgres.Open(c.dsn), &gormConfig) + case "clickhouse": + db, err = gorm.Open(clickhouse.Open(c.dsn), &gormConfig) + case "mysql": + db, err = gorm.Open(mysql.Open(c.dsn), &gormConfig) + case "sqlite": + db, err = gorm.Open(sqlite.Open(c.dsn), &gormConfig) + default: + return fmt.Errorf("unsupported database type %s", c.dbType) + } + + if err != nil { + return fmt.Errorf("failed to connect to database: %w", err) + } + + c.db = db + return nil +} + +func (c *DBClient) reconnectLoop() { + defer func() { + if r := recover(); r != nil { + api.LogErrorf("Recovered from panic in reconnectLoop: %v", r) + + // Increment panic counter + atomic.AddInt32(&c.panicCount, 1) + + // If panic count exceeds threshold, stop trying to reconnect + if atomic.LoadInt32(&c.panicCount) > 3 { + api.LogErrorf("Too many panics in reconnectLoop, stopping reconnection attempts") + return + } + + // Wait for a while before restarting + time.Sleep(5 * time.Second) + + // Restart the reconnect loop + go c.reconnectLoop() + } + }() + + ticker := time.NewTicker(30 * time.Second) // Try to reconnect every 30 seconds + defer ticker.Stop() + + for { + select { + case <-c.stop: + api.LogInfof("Database %s connection closed", c.dbType) + return + case <-ticker.C: + if c.db == nil || c.Ping() != nil { + if err := c.connect(); err != nil { + api.LogErrorf("Database reconnection failed: %v", err) + } else { + api.LogInfof("Database reconnected successfully") + // Reset panic count on successful connection + atomic.StoreInt32(&c.panicCount, 0) + } + } + case <-c.reconnect: + if err := c.connect(); err != nil { + api.LogErrorf("Database reconnection failed: %v", err) + } else { + api.LogInfof("Database reconnected successfully") + // Reset panic count on successful connection + atomic.StoreInt32(&c.panicCount, 0) + } + } + } } // ExecuteSQL executes a raw SQL query and returns the result as a slice of maps func (c *DBClient) ExecuteSQL(query string, args ...interface{}) ([]map[string]interface{}, error) { + if c.db == nil { + // Trigger reconnection + select { + case c.reconnect <- struct{}{}: + default: + } + return nil, fmt.Errorf("database is not connected, attempting to reconnect") + } + rows, err := c.db.Raw(query, args...).Rows() if err != nil { + // If execution fails, connection might be lost, trigger reconnection + select { + case c.reconnect <- struct{}{}: + default: + } return nil, fmt.Errorf("failed to execute SQL query: %w", err) } defer rows.Close() @@ -88,3 +189,21 @@ func (c *DBClient) ExecuteSQL(query string, args ...interface{}) ([]map[string]i return results, nil } + +func (c *DBClient) Ping() error { + if c.db == nil { + return fmt.Errorf("database connection is nil") + } + + // Use context to set timeout + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + // Try to ping the database + sqlDB, err := c.db.DB() + if err != nil { + return fmt.Errorf("failed to get underlying *sql.DB: %v", err) + } + + return sqlDB.PingContext(ctx) +} diff --git a/plugins/golang-filter/mcp-server/servers/gorm/server.go b/plugins/golang-filter/mcp-server/servers/gorm/server.go index 612bfa389..3ce9cfb08 100644 --- a/plugins/golang-filter/mcp-server/servers/gorm/server.go +++ b/plugins/golang-filter/mcp-server/servers/gorm/server.go @@ -16,8 +16,9 @@ func init() { } type DBConfig struct { - dbType string - dsn string + dbType string + dsn string + description string } func (c *DBConfig) ParseConfig(config map[string]any) error { @@ -33,6 +34,10 @@ func (c *DBConfig) ParseConfig(config map[string]any) error { } c.dbType = dbType api.LogDebugf("DBConfig ParseConfig: %+v", config) + c.description, ok = config["description"].(string) + if !ok { + c.description = "" + } return nil } @@ -43,14 +48,10 @@ func (c *DBConfig) NewServer(serverName string) (*internal.MCPServer, error) { internal.WithInstructions(fmt.Sprintf("This is a %s database server", c.dbType)), ) - dbClient, err := NewDBClient(c.dsn, c.dbType) - if err != nil { - return nil, fmt.Errorf("failed to initialize DBClient: %w", err) - } - + dbClient := NewDBClient(c.dsn, c.dbType, mcpServer.GetDestoryChannel()) // Add query tool mcpServer.AddTool( - mcp.NewToolWithRawSchema("query", fmt.Sprintf("Run a read-only SQL query in database %s", c.dbType), GetQueryToolSchema()), + mcp.NewToolWithRawSchema("query", fmt.Sprintf("Run a read-only SQL query in database %s. Database description: %s", c.dbType, c.description), GetQueryToolSchema()), HandleQueryTool(dbClient), )