diff --git a/plugins/golang-filter/mcp-server/servers/gorm/db.go b/plugins/golang-filter/mcp-server/servers/gorm/db.go index 22547e789..4ba31614a 100644 --- a/plugins/golang-filter/mcp-server/servers/gorm/db.go +++ b/plugins/golang-filter/mcp-server/servers/gorm/db.go @@ -25,6 +25,14 @@ type DBClient struct { panicCount int32 // Add panic counter } +// supports database types +const ( + MYSQL = "mysql" + POSTGRES = "postgres" + CLICKHOUSE = "clickhouse" + SQLITE = "sqlite" +) + // NewDBClient creates a new DBClient instance and establishes a connection to the database func NewDBClient(dsn string, dbType string, stop chan struct{}) *DBClient { client := &DBClient{ @@ -53,13 +61,13 @@ func (c *DBClient) connect() error { } switch c.dbType { - case "postgres": + case POSTGRES: db, err = gorm.Open(postgres.Open(c.dsn), &gormConfig) - case "clickhouse": + case CLICKHOUSE: db, err = gorm.Open(clickhouse.Open(c.dsn), &gormConfig) - case "mysql": + case MYSQL: db, err = gorm.Open(mysql.Open(c.dsn), &gormConfig) - case "sqlite": + case SQLITE: db, err = gorm.Open(sqlite.Open(c.dsn), &gormConfig) default: return fmt.Errorf("unsupported database type %s", c.dbType) @@ -125,25 +133,166 @@ func (c *DBClient) reconnectLoop() { } } -// ExecuteSQL executes a raw SQL query and returns the result as a slice of maps -func (c *DBClient) ExecuteSQL(query string, args ...interface{}) ([]map[string]interface{}, error) { +func (c *DBClient) reconnectIfDbEmpty() error { if c.db == nil { // Trigger reconnection select { case c.reconnect <- struct{}{}: default: } - return nil, fmt.Errorf("database is not connected, attempting to reconnect") + return fmt.Errorf("database is not connected, attempting to reconnect") } + return nil +} - rows, err := c.db.Raw(query, args...).Rows() +func (c *DBClient) handleSQLError(err error) error { if err != nil { // If execution fails, connection might be lost, trigger reconnection select { case c.reconnect <- struct{}{}: default: } - return nil, fmt.Errorf("failed to execute SQL query: %w", err) + return fmt.Errorf("failed to execute SQL: %w", err) + } + return nil +} + +// DescribeTable Get the structure of a specific table. +func (c *DBClient) DescribeTable(table string) ([]map[string]interface{}, error) { + var sql string + var args []string + switch c.dbType { + case MYSQL: + sql = ` + select + column_name, + column_type, + is_nullable, + column_key, + column_default, + extra, + column_comment + from information_schema.columns + where table_schema = database() and table_name = ? + ` + args = []string{table} + + case POSTGRES: + sql = ` + select + column_name, + data_type as column_type, + is_nullable, + case + when column_default like 'nextval%%' then 'auto_increment' + when column_default is not null then 'default' + else '' + end as column_key, + column_default, + case + when column_default like 'nextval%%' then 'auto_increment' + else '' + end as extra, + col_description((select oid from pg_class where relname = ?), ordinal_position) as column_comment + from information_schema.columns + where table_name = ? + ` + args = []string{table, table} + + case CLICKHOUSE: + sql = ` + select + name as column_name, + type as column_type, + if(is_nullable, 'YES', 'NO') as is_nullable, + default_kind as column_key, + default_expression as column_default, + default_kind as extra, + comment as column_comment + from system.columns + where database = currentDatabase() and table = ? + ` + args = []string{table} + + case SQLITE: + sql = ` + select + name as column_name, + type as column_type, + not (notnull = 1) as is_nullable, + pk as column_key, + dflt_value as column_default, + '' as extra, + '' as column_comment + from pragma_table_info(?) + ` + args = []string{table} + + default: + return nil, fmt.Errorf("unsupported database type: %s", c.dbType) + } + + return c.Query(sql, args) +} + +// ListTables List all tables in the connected database. +func (c *DBClient) ListTables() ([]string, error) { + var sql string + switch c.dbType { + case MYSQL: + sql = "show tables" + case POSTGRES: + sql = "select tablename from pg_tables where schemaname = 'public'" + case CLICKHOUSE: + sql = "select name from system.tables where database = currentDatabase()" + case SQLITE: + sql = "select name from sqlite_master where type='table'" + default: + return nil, fmt.Errorf("unsupported database type: %s", c.dbType) + } + + rows, err := c.db.Raw(sql).Rows() + if err := c.handleSQLError(err); err != nil { + return nil, err + } + defer rows.Close() + + var tables []string + for rows.Next() { + var table string + if err := rows.Scan(&table); err != nil { + return nil, fmt.Errorf("failed to scan table name: %w", err) + } + tables = append(tables, table) + } + + return tables, nil +} + +// Execute executes an INSERT, UPDATE, or DELETE raw SQL and returns the rows affected +func (c *DBClient) Execute(sql string, args ...interface{}) (int64, error) { + if err := c.reconnectIfDbEmpty(); err != nil { + return 0, err + } + + tx := c.db.Exec(sql, args...) + if err := c.handleSQLError(tx.Error); err != nil { + return 0, err + } + defer tx.Commit() + + return tx.RowsAffected, nil +} + +// Query executes a raw SQL query and returns the result as a slice of maps +func (c *DBClient) Query(sql string, args ...interface{}) ([]map[string]interface{}, error) { + if err := c.reconnectIfDbEmpty(); err != nil { + return nil, err + } + + rows, err := c.db.Raw(sql, args...).Rows() + if err := c.handleSQLError(err); err != nil { + return nil, err } defer rows.Close() diff --git a/plugins/golang-filter/mcp-server/servers/gorm/server.go b/plugins/golang-filter/mcp-server/servers/gorm/server.go index 61c94e5aa..cb21122c7 100644 --- a/plugins/golang-filter/mcp-server/servers/gorm/server.go +++ b/plugins/golang-filter/mcp-server/servers/gorm/server.go @@ -49,11 +49,24 @@ func (c *DBConfig) NewServer(serverName string) (*common.MCPServer, error) { ) dbClient := NewDBClient(c.dsn, c.dbType, mcpServer.GetDestoryChannel()) + descriptionSuffix := fmt.Sprintf("in database %s. Database description: %s", c.dbType, c.description) // Add query tool mcpServer.AddTool( - mcp.NewToolWithRawSchema("query", fmt.Sprintf("Run a read-only SQL query in database %s. Database description: %s", c.dbType, c.description), GetQueryToolSchema()), + mcp.NewToolWithRawSchema("query", fmt.Sprintf("Run a read-only SQL query %s", descriptionSuffix), GetQueryToolSchema()), HandleQueryTool(dbClient), ) + mcpServer.AddTool( + mcp.NewToolWithRawSchema("execute", fmt.Sprintf("Execute an insert, update, or delete SQL %s", descriptionSuffix), GetExecuteToolSchema()), + HandleExecuteTool(dbClient), + ) + mcpServer.AddTool( + mcp.NewToolWithRawSchema("list tables", fmt.Sprintf("List all tables %s", descriptionSuffix), GetListTablesToolSchema()), + HandleListTablesTool(dbClient), + ) + mcpServer.AddTool( + mcp.NewToolWithRawSchema("describe table", fmt.Sprintf("Get the structure of a specific table %s", descriptionSuffix), GetDescribeTableToolSchema()), + HandleDescribeTableTool(dbClient), + ) return mcpServer, nil } diff --git a/plugins/golang-filter/mcp-server/servers/gorm/tools.go b/plugins/golang-filter/mcp-server/servers/gorm/tools.go index 4c4ceda12..16b825f2c 100644 --- a/plugins/golang-filter/mcp-server/servers/gorm/tools.go +++ b/plugins/golang-filter/mcp-server/servers/gorm/tools.go @@ -18,27 +18,80 @@ func HandleQueryTool(dbClient *DBClient) common.ToolHandlerFunc { return nil, fmt.Errorf("invalid message argument") } - results, err := dbClient.ExecuteSQL(message) + results, err := dbClient.Query(message) if err != nil { return nil, fmt.Errorf("failed to execute SQL query: %w", err) } - jsonData, err := json.Marshal(results) - if err != nil { - return nil, fmt.Errorf("failed to marshal SQL results: %w", err) + return buildCallToolResult(results) + } +} + +// HandleExecuteTool handles SQL INSERT, UPDATE, or DELETE execution +func HandleExecuteTool(dbClient *DBClient) common.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + arguments := request.Params.Arguments + message, ok := arguments["sql"].(string) + if !ok { + return nil, fmt.Errorf("invalid message argument") } - return &mcp.CallToolResult{ - Content: []mcp.Content{ - mcp.TextContent{ - Type: "text", - Text: string(jsonData), - }, - }, - }, nil + results, err := dbClient.Execute(message) + if err != nil { + return nil, fmt.Errorf("failed to execute SQL query: %w", err) + } + + return buildCallToolResult(results) } } +// HandleListTablesTool handles list all tables +func HandleListTablesTool(dbClient *DBClient) common.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + results, err := dbClient.ListTables() + if err != nil { + return nil, fmt.Errorf("failed to execute SQL query: %w", err) + } + + return buildCallToolResult(results) + } +} + +// HandleDescribeTableTool handles describe table +func HandleDescribeTableTool(dbClient *DBClient) common.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + arguments := request.Params.Arguments + message, ok := arguments["table"].(string) + if !ok { + return nil, fmt.Errorf("invalid message argument") + } + + results, err := dbClient.DescribeTable(message) + if err != nil { + return nil, fmt.Errorf("failed to execute SQL query: %w", err) + } + + return buildCallToolResult(results) + } +} + +// buildCallToolResult builds the call tool result +func buildCallToolResult(results any) (*mcp.CallToolResult, error) { + jsonData, err := json.Marshal(results) + if err != nil { + return nil, fmt.Errorf("failed to marshal SQL results: %w", err) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: string(jsonData), + }, + }, + }, nil +} + // GetQueryToolSchema returns the schema for query tool func GetQueryToolSchema() json.RawMessage { return json.RawMessage(` @@ -53,3 +106,44 @@ func GetQueryToolSchema() json.RawMessage { } `) } + +// GetExecuteToolSchema returns the schema for execute tool +func GetExecuteToolSchema() json.RawMessage { + return json.RawMessage(` + { + "type": "object", + "properties": { + "sql": { + "type": "string", + "description": "The sql to execute" + } + } + } + `) +} + +// GetDescribeTableToolSchema returns the schema for DescribeTable tool +func GetDescribeTableToolSchema() json.RawMessage { + return json.RawMessage(` + { + "type": "object", + "properties": { + "table": { + "type": "string", + "description": "table name" + } + } + } + `) +} + +// GetListTablesToolSchema returns the schema for ListTables tool +func GetListTablesToolSchema() json.RawMessage { + return json.RawMessage(` + { + "type": "object", + "properties": { + } + } + `) +}