mirror of
https://github.com/alibaba/higress.git
synced 2026-03-08 10:40:48 +08:00
feat: add DB MCP Server execute, list tables, describe table tools (#2506)
Signed-off-by: hongzhouzi <weihongzhou.whz@alibaba-inc.com>
This commit is contained in:
@@ -25,6 +25,14 @@ type DBClient struct {
|
||||
panicCount int32 // Add panic counter
|
||||
}
|
||||
|
||||
// supports database types
|
||||
const (
|
||||
MYSQL = "mysql"
|
||||
POSTGRES = "postgres"
|
||||
CLICKHOUSE = "clickhouse"
|
||||
SQLITE = "sqlite"
|
||||
)
|
||||
|
||||
// NewDBClient creates a new DBClient instance and establishes a connection to the database
|
||||
func NewDBClient(dsn string, dbType string, stop chan struct{}) *DBClient {
|
||||
client := &DBClient{
|
||||
@@ -53,13 +61,13 @@ func (c *DBClient) connect() error {
|
||||
}
|
||||
|
||||
switch c.dbType {
|
||||
case "postgres":
|
||||
case POSTGRES:
|
||||
db, err = gorm.Open(postgres.Open(c.dsn), &gormConfig)
|
||||
case "clickhouse":
|
||||
case CLICKHOUSE:
|
||||
db, err = gorm.Open(clickhouse.Open(c.dsn), &gormConfig)
|
||||
case "mysql":
|
||||
case MYSQL:
|
||||
db, err = gorm.Open(mysql.Open(c.dsn), &gormConfig)
|
||||
case "sqlite":
|
||||
case SQLITE:
|
||||
db, err = gorm.Open(sqlite.Open(c.dsn), &gormConfig)
|
||||
default:
|
||||
return fmt.Errorf("unsupported database type %s", c.dbType)
|
||||
@@ -125,25 +133,166 @@ func (c *DBClient) reconnectLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteSQL executes a raw SQL query and returns the result as a slice of maps
|
||||
func (c *DBClient) ExecuteSQL(query string, args ...interface{}) ([]map[string]interface{}, error) {
|
||||
func (c *DBClient) reconnectIfDbEmpty() error {
|
||||
if c.db == nil {
|
||||
// Trigger reconnection
|
||||
select {
|
||||
case c.reconnect <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil, fmt.Errorf("database is not connected, attempting to reconnect")
|
||||
return fmt.Errorf("database is not connected, attempting to reconnect")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
rows, err := c.db.Raw(query, args...).Rows()
|
||||
func (c *DBClient) handleSQLError(err error) error {
|
||||
if err != nil {
|
||||
// If execution fails, connection might be lost, trigger reconnection
|
||||
select {
|
||||
case c.reconnect <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
|
||||
return fmt.Errorf("failed to execute SQL: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DescribeTable Get the structure of a specific table.
|
||||
func (c *DBClient) DescribeTable(table string) ([]map[string]interface{}, error) {
|
||||
var sql string
|
||||
var args []string
|
||||
switch c.dbType {
|
||||
case MYSQL:
|
||||
sql = `
|
||||
select
|
||||
column_name,
|
||||
column_type,
|
||||
is_nullable,
|
||||
column_key,
|
||||
column_default,
|
||||
extra,
|
||||
column_comment
|
||||
from information_schema.columns
|
||||
where table_schema = database() and table_name = ?
|
||||
`
|
||||
args = []string{table}
|
||||
|
||||
case POSTGRES:
|
||||
sql = `
|
||||
select
|
||||
column_name,
|
||||
data_type as column_type,
|
||||
is_nullable,
|
||||
case
|
||||
when column_default like 'nextval%%' then 'auto_increment'
|
||||
when column_default is not null then 'default'
|
||||
else ''
|
||||
end as column_key,
|
||||
column_default,
|
||||
case
|
||||
when column_default like 'nextval%%' then 'auto_increment'
|
||||
else ''
|
||||
end as extra,
|
||||
col_description((select oid from pg_class where relname = ?), ordinal_position) as column_comment
|
||||
from information_schema.columns
|
||||
where table_name = ?
|
||||
`
|
||||
args = []string{table, table}
|
||||
|
||||
case CLICKHOUSE:
|
||||
sql = `
|
||||
select
|
||||
name as column_name,
|
||||
type as column_type,
|
||||
if(is_nullable, 'YES', 'NO') as is_nullable,
|
||||
default_kind as column_key,
|
||||
default_expression as column_default,
|
||||
default_kind as extra,
|
||||
comment as column_comment
|
||||
from system.columns
|
||||
where database = currentDatabase() and table = ?
|
||||
`
|
||||
args = []string{table}
|
||||
|
||||
case SQLITE:
|
||||
sql = `
|
||||
select
|
||||
name as column_name,
|
||||
type as column_type,
|
||||
not (notnull = 1) as is_nullable,
|
||||
pk as column_key,
|
||||
dflt_value as column_default,
|
||||
'' as extra,
|
||||
'' as column_comment
|
||||
from pragma_table_info(?)
|
||||
`
|
||||
args = []string{table}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type: %s", c.dbType)
|
||||
}
|
||||
|
||||
return c.Query(sql, args)
|
||||
}
|
||||
|
||||
// ListTables List all tables in the connected database.
|
||||
func (c *DBClient) ListTables() ([]string, error) {
|
||||
var sql string
|
||||
switch c.dbType {
|
||||
case MYSQL:
|
||||
sql = "show tables"
|
||||
case POSTGRES:
|
||||
sql = "select tablename from pg_tables where schemaname = 'public'"
|
||||
case CLICKHOUSE:
|
||||
sql = "select name from system.tables where database = currentDatabase()"
|
||||
case SQLITE:
|
||||
sql = "select name from sqlite_master where type='table'"
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type: %s", c.dbType)
|
||||
}
|
||||
|
||||
rows, err := c.db.Raw(sql).Rows()
|
||||
if err := c.handleSQLError(err); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []string
|
||||
for rows.Next() {
|
||||
var table string
|
||||
if err := rows.Scan(&table); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan table name: %w", err)
|
||||
}
|
||||
tables = append(tables, table)
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// Execute executes an INSERT, UPDATE, or DELETE raw SQL and returns the rows affected
|
||||
func (c *DBClient) Execute(sql string, args ...interface{}) (int64, error) {
|
||||
if err := c.reconnectIfDbEmpty(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
tx := c.db.Exec(sql, args...)
|
||||
if err := c.handleSQLError(tx.Error); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer tx.Commit()
|
||||
|
||||
return tx.RowsAffected, nil
|
||||
}
|
||||
|
||||
// Query executes a raw SQL query and returns the result as a slice of maps
|
||||
func (c *DBClient) Query(sql string, args ...interface{}) ([]map[string]interface{}, error) {
|
||||
if err := c.reconnectIfDbEmpty(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := c.db.Raw(sql, args...).Rows()
|
||||
if err := c.handleSQLError(err); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
|
||||
@@ -49,11 +49,24 @@ func (c *DBConfig) NewServer(serverName string) (*common.MCPServer, error) {
|
||||
)
|
||||
|
||||
dbClient := NewDBClient(c.dsn, c.dbType, mcpServer.GetDestoryChannel())
|
||||
descriptionSuffix := fmt.Sprintf("in database %s. Database description: %s", c.dbType, c.description)
|
||||
// Add query tool
|
||||
mcpServer.AddTool(
|
||||
mcp.NewToolWithRawSchema("query", fmt.Sprintf("Run a read-only SQL query in database %s. Database description: %s", c.dbType, c.description), GetQueryToolSchema()),
|
||||
mcp.NewToolWithRawSchema("query", fmt.Sprintf("Run a read-only SQL query %s", descriptionSuffix), GetQueryToolSchema()),
|
||||
HandleQueryTool(dbClient),
|
||||
)
|
||||
mcpServer.AddTool(
|
||||
mcp.NewToolWithRawSchema("execute", fmt.Sprintf("Execute an insert, update, or delete SQL %s", descriptionSuffix), GetExecuteToolSchema()),
|
||||
HandleExecuteTool(dbClient),
|
||||
)
|
||||
mcpServer.AddTool(
|
||||
mcp.NewToolWithRawSchema("list tables", fmt.Sprintf("List all tables %s", descriptionSuffix), GetListTablesToolSchema()),
|
||||
HandleListTablesTool(dbClient),
|
||||
)
|
||||
mcpServer.AddTool(
|
||||
mcp.NewToolWithRawSchema("describe table", fmt.Sprintf("Get the structure of a specific table %s", descriptionSuffix), GetDescribeTableToolSchema()),
|
||||
HandleDescribeTableTool(dbClient),
|
||||
)
|
||||
|
||||
return mcpServer, nil
|
||||
}
|
||||
|
||||
@@ -18,27 +18,80 @@ func HandleQueryTool(dbClient *DBClient) common.ToolHandlerFunc {
|
||||
return nil, fmt.Errorf("invalid message argument")
|
||||
}
|
||||
|
||||
results, err := dbClient.ExecuteSQL(message)
|
||||
results, err := dbClient.Query(message)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(results)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal SQL results: %w", err)
|
||||
return buildCallToolResult(results)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleExecuteTool handles SQL INSERT, UPDATE, or DELETE execution
|
||||
func HandleExecuteTool(dbClient *DBClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
arguments := request.Params.Arguments
|
||||
message, ok := arguments["sql"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid message argument")
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: string(jsonData),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
results, err := dbClient.Execute(message)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
|
||||
}
|
||||
|
||||
return buildCallToolResult(results)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleListTablesTool handles list all tables
|
||||
func HandleListTablesTool(dbClient *DBClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
results, err := dbClient.ListTables()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
|
||||
}
|
||||
|
||||
return buildCallToolResult(results)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleDescribeTableTool handles describe table
|
||||
func HandleDescribeTableTool(dbClient *DBClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
arguments := request.Params.Arguments
|
||||
message, ok := arguments["table"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid message argument")
|
||||
}
|
||||
|
||||
results, err := dbClient.DescribeTable(message)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
|
||||
}
|
||||
|
||||
return buildCallToolResult(results)
|
||||
}
|
||||
}
|
||||
|
||||
// buildCallToolResult builds the call tool result
|
||||
func buildCallToolResult(results any) (*mcp.CallToolResult, error) {
|
||||
jsonData, err := json.Marshal(results)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal SQL results: %w", err)
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: string(jsonData),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetQueryToolSchema returns the schema for query tool
|
||||
func GetQueryToolSchema() json.RawMessage {
|
||||
return json.RawMessage(`
|
||||
@@ -53,3 +106,44 @@ func GetQueryToolSchema() json.RawMessage {
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
// GetExecuteToolSchema returns the schema for execute tool
|
||||
func GetExecuteToolSchema() json.RawMessage {
|
||||
return json.RawMessage(`
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sql": {
|
||||
"type": "string",
|
||||
"description": "The sql to execute"
|
||||
}
|
||||
}
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
// GetDescribeTableToolSchema returns the schema for DescribeTable tool
|
||||
func GetDescribeTableToolSchema() json.RawMessage {
|
||||
return json.RawMessage(`
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"table": {
|
||||
"type": "string",
|
||||
"description": "table name"
|
||||
}
|
||||
}
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
// GetListTablesToolSchema returns the schema for ListTables tool
|
||||
func GetListTablesToolSchema() json.RawMessage {
|
||||
return json.RawMessage(`
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
}
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user