feat: add DB MCP Server execute, list tables, describe table tools (#2506)

Signed-off-by: hongzhouzi <weihongzhou.whz@alibaba-inc.com>
This commit is contained in:
hongzhouzi
2025-07-02 14:47:49 +08:00
committed by GitHub
parent 44566f5259
commit 783a8db512
3 changed files with 278 additions and 22 deletions

View File

@@ -25,6 +25,14 @@ type DBClient struct {
panicCount int32 // Add panic counter
}
// supports database types
const (
MYSQL = "mysql"
POSTGRES = "postgres"
CLICKHOUSE = "clickhouse"
SQLITE = "sqlite"
)
// NewDBClient creates a new DBClient instance and establishes a connection to the database
func NewDBClient(dsn string, dbType string, stop chan struct{}) *DBClient {
client := &DBClient{
@@ -53,13 +61,13 @@ func (c *DBClient) connect() error {
}
switch c.dbType {
case "postgres":
case POSTGRES:
db, err = gorm.Open(postgres.Open(c.dsn), &gormConfig)
case "clickhouse":
case CLICKHOUSE:
db, err = gorm.Open(clickhouse.Open(c.dsn), &gormConfig)
case "mysql":
case MYSQL:
db, err = gorm.Open(mysql.Open(c.dsn), &gormConfig)
case "sqlite":
case SQLITE:
db, err = gorm.Open(sqlite.Open(c.dsn), &gormConfig)
default:
return fmt.Errorf("unsupported database type %s", c.dbType)
@@ -125,25 +133,166 @@ func (c *DBClient) reconnectLoop() {
}
}
// ExecuteSQL executes a raw SQL query and returns the result as a slice of maps
func (c *DBClient) ExecuteSQL(query string, args ...interface{}) ([]map[string]interface{}, error) {
func (c *DBClient) reconnectIfDbEmpty() error {
if c.db == nil {
// Trigger reconnection
select {
case c.reconnect <- struct{}{}:
default:
}
return nil, fmt.Errorf("database is not connected, attempting to reconnect")
return fmt.Errorf("database is not connected, attempting to reconnect")
}
return nil
}
rows, err := c.db.Raw(query, args...).Rows()
func (c *DBClient) handleSQLError(err error) error {
if err != nil {
// If execution fails, connection might be lost, trigger reconnection
select {
case c.reconnect <- struct{}{}:
default:
}
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
return fmt.Errorf("failed to execute SQL: %w", err)
}
return nil
}
// DescribeTable Get the structure of a specific table.
func (c *DBClient) DescribeTable(table string) ([]map[string]interface{}, error) {
var sql string
var args []string
switch c.dbType {
case MYSQL:
sql = `
select
column_name,
column_type,
is_nullable,
column_key,
column_default,
extra,
column_comment
from information_schema.columns
where table_schema = database() and table_name = ?
`
args = []string{table}
case POSTGRES:
sql = `
select
column_name,
data_type as column_type,
is_nullable,
case
when column_default like 'nextval%%' then 'auto_increment'
when column_default is not null then 'default'
else ''
end as column_key,
column_default,
case
when column_default like 'nextval%%' then 'auto_increment'
else ''
end as extra,
col_description((select oid from pg_class where relname = ?), ordinal_position) as column_comment
from information_schema.columns
where table_name = ?
`
args = []string{table, table}
case CLICKHOUSE:
sql = `
select
name as column_name,
type as column_type,
if(is_nullable, 'YES', 'NO') as is_nullable,
default_kind as column_key,
default_expression as column_default,
default_kind as extra,
comment as column_comment
from system.columns
where database = currentDatabase() and table = ?
`
args = []string{table}
case SQLITE:
sql = `
select
name as column_name,
type as column_type,
not (notnull = 1) as is_nullable,
pk as column_key,
dflt_value as column_default,
'' as extra,
'' as column_comment
from pragma_table_info(?)
`
args = []string{table}
default:
return nil, fmt.Errorf("unsupported database type: %s", c.dbType)
}
return c.Query(sql, args)
}
// ListTables List all tables in the connected database.
func (c *DBClient) ListTables() ([]string, error) {
var sql string
switch c.dbType {
case MYSQL:
sql = "show tables"
case POSTGRES:
sql = "select tablename from pg_tables where schemaname = 'public'"
case CLICKHOUSE:
sql = "select name from system.tables where database = currentDatabase()"
case SQLITE:
sql = "select name from sqlite_master where type='table'"
default:
return nil, fmt.Errorf("unsupported database type: %s", c.dbType)
}
rows, err := c.db.Raw(sql).Rows()
if err := c.handleSQLError(err); err != nil {
return nil, err
}
defer rows.Close()
var tables []string
for rows.Next() {
var table string
if err := rows.Scan(&table); err != nil {
return nil, fmt.Errorf("failed to scan table name: %w", err)
}
tables = append(tables, table)
}
return tables, nil
}
// Execute executes an INSERT, UPDATE, or DELETE raw SQL and returns the rows affected
func (c *DBClient) Execute(sql string, args ...interface{}) (int64, error) {
if err := c.reconnectIfDbEmpty(); err != nil {
return 0, err
}
tx := c.db.Exec(sql, args...)
if err := c.handleSQLError(tx.Error); err != nil {
return 0, err
}
defer tx.Commit()
return tx.RowsAffected, nil
}
// Query executes a raw SQL query and returns the result as a slice of maps
func (c *DBClient) Query(sql string, args ...interface{}) ([]map[string]interface{}, error) {
if err := c.reconnectIfDbEmpty(); err != nil {
return nil, err
}
rows, err := c.db.Raw(sql, args...).Rows()
if err := c.handleSQLError(err); err != nil {
return nil, err
}
defer rows.Close()

View File

@@ -49,11 +49,24 @@ func (c *DBConfig) NewServer(serverName string) (*common.MCPServer, error) {
)
dbClient := NewDBClient(c.dsn, c.dbType, mcpServer.GetDestoryChannel())
descriptionSuffix := fmt.Sprintf("in database %s. Database description: %s", c.dbType, c.description)
// Add query tool
mcpServer.AddTool(
mcp.NewToolWithRawSchema("query", fmt.Sprintf("Run a read-only SQL query in database %s. Database description: %s", c.dbType, c.description), GetQueryToolSchema()),
mcp.NewToolWithRawSchema("query", fmt.Sprintf("Run a read-only SQL query %s", descriptionSuffix), GetQueryToolSchema()),
HandleQueryTool(dbClient),
)
mcpServer.AddTool(
mcp.NewToolWithRawSchema("execute", fmt.Sprintf("Execute an insert, update, or delete SQL %s", descriptionSuffix), GetExecuteToolSchema()),
HandleExecuteTool(dbClient),
)
mcpServer.AddTool(
mcp.NewToolWithRawSchema("list tables", fmt.Sprintf("List all tables %s", descriptionSuffix), GetListTablesToolSchema()),
HandleListTablesTool(dbClient),
)
mcpServer.AddTool(
mcp.NewToolWithRawSchema("describe table", fmt.Sprintf("Get the structure of a specific table %s", descriptionSuffix), GetDescribeTableToolSchema()),
HandleDescribeTableTool(dbClient),
)
return mcpServer, nil
}

View File

@@ -18,27 +18,80 @@ func HandleQueryTool(dbClient *DBClient) common.ToolHandlerFunc {
return nil, fmt.Errorf("invalid message argument")
}
results, err := dbClient.ExecuteSQL(message)
results, err := dbClient.Query(message)
if err != nil {
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
}
jsonData, err := json.Marshal(results)
if err != nil {
return nil, fmt.Errorf("failed to marshal SQL results: %w", err)
return buildCallToolResult(results)
}
}
// HandleExecuteTool handles SQL INSERT, UPDATE, or DELETE execution
func HandleExecuteTool(dbClient *DBClient) common.ToolHandlerFunc {
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
arguments := request.Params.Arguments
message, ok := arguments["sql"].(string)
if !ok {
return nil, fmt.Errorf("invalid message argument")
}
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{
Type: "text",
Text: string(jsonData),
},
},
}, nil
results, err := dbClient.Execute(message)
if err != nil {
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
}
return buildCallToolResult(results)
}
}
// HandleListTablesTool handles list all tables
func HandleListTablesTool(dbClient *DBClient) common.ToolHandlerFunc {
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
results, err := dbClient.ListTables()
if err != nil {
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
}
return buildCallToolResult(results)
}
}
// HandleDescribeTableTool handles describe table
func HandleDescribeTableTool(dbClient *DBClient) common.ToolHandlerFunc {
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
arguments := request.Params.Arguments
message, ok := arguments["table"].(string)
if !ok {
return nil, fmt.Errorf("invalid message argument")
}
results, err := dbClient.DescribeTable(message)
if err != nil {
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
}
return buildCallToolResult(results)
}
}
// buildCallToolResult builds the call tool result
func buildCallToolResult(results any) (*mcp.CallToolResult, error) {
jsonData, err := json.Marshal(results)
if err != nil {
return nil, fmt.Errorf("failed to marshal SQL results: %w", err)
}
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{
Type: "text",
Text: string(jsonData),
},
},
}, nil
}
// GetQueryToolSchema returns the schema for query tool
func GetQueryToolSchema() json.RawMessage {
return json.RawMessage(`
@@ -53,3 +106,44 @@ func GetQueryToolSchema() json.RawMessage {
}
`)
}
// GetExecuteToolSchema returns the schema for execute tool
func GetExecuteToolSchema() json.RawMessage {
return json.RawMessage(`
{
"type": "object",
"properties": {
"sql": {
"type": "string",
"description": "The sql to execute"
}
}
}
`)
}
// GetDescribeTableToolSchema returns the schema for DescribeTable tool
func GetDescribeTableToolSchema() json.RawMessage {
return json.RawMessage(`
{
"type": "object",
"properties": {
"table": {
"type": "string",
"description": "table name"
}
}
}
`)
}
// GetListTablesToolSchema returns the schema for ListTables tool
func GetListTablesToolSchema() json.RawMessage {
return json.RawMessage(`
{
"type": "object",
"properties": {
}
}
`)
}