mirror of
https://github.com/alibaba/higress.git
synced 2026-03-08 10:40:48 +08:00
fix: support mcp server database reconnect and fix tool/list method denied (#2074)
This commit is contained in:
@@ -49,6 +49,9 @@ func (c *config) Destroy() {
|
||||
api.LogDebug("Closing Redis client")
|
||||
c.redisClient.Close()
|
||||
}
|
||||
for _, server := range c.servers {
|
||||
server.Close()
|
||||
}
|
||||
}
|
||||
|
||||
type parser struct {
|
||||
@@ -127,6 +130,9 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
||||
}
|
||||
}
|
||||
}
|
||||
if errorText, ok := rateLimit["error_text"].(string); ok {
|
||||
rateLimitConfig.ErrorText = errorText
|
||||
}
|
||||
conf.rateLimitConfig = rateLimitConfig
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/handler"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -35,6 +37,7 @@ type filter struct {
|
||||
|
||||
userLevelConfig bool
|
||||
mcpConfigHandler *handler.MCPConfigHandler
|
||||
ratelimit bool
|
||||
mcpRatelimitHandler *handler.MCPRatelimitHandler
|
||||
}
|
||||
|
||||
@@ -110,6 +113,11 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
|
||||
}
|
||||
}
|
||||
|
||||
f.req = &http.Request{
|
||||
Method: url.method,
|
||||
URL: url.parsedURL,
|
||||
}
|
||||
|
||||
if strings.HasSuffix(f.path, ConfigPathSuffix) && f.config.enableUserLevelServer {
|
||||
if !url.internalIP {
|
||||
api.LogWarnf("Access denied: non-internal IP address %s", url.parsedURL.String())
|
||||
@@ -118,13 +126,9 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
|
||||
}
|
||||
if strings.HasSuffix(f.path, ConfigPathSuffix) && url.method == http.MethodGet {
|
||||
api.LogDebugf("Handling config request: %s", f.path)
|
||||
f.mcpConfigHandler.HandleConfigRequest(f.path, url.method, []byte{})
|
||||
f.mcpConfigHandler.HandleConfigRequest(f.req, []byte{})
|
||||
return api.LocalReply
|
||||
}
|
||||
f.req = &http.Request{
|
||||
Method: url.method,
|
||||
URL: url.parsedURL,
|
||||
}
|
||||
f.userLevelConfig = true
|
||||
if endStream {
|
||||
return api.Continue
|
||||
@@ -137,30 +141,23 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
|
||||
f.proxyURL = url.parsedURL
|
||||
if f.config.enableUserLevelServer {
|
||||
parts := strings.Split(url.parsedURL.Path, "/")
|
||||
if len(parts) < 3 {
|
||||
api.LogDebugf("Access denied: missing uid in path %s", url.parsedURL.Path)
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "Access denied: missing uid", nil, 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
serverName := parts[1]
|
||||
uid := parts[2]
|
||||
// Get encoded config
|
||||
encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
|
||||
if err != nil {
|
||||
api.LogWarnf("Access denied: no valid config found for uid %s", uid)
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
} else if encodedConfig != "" {
|
||||
header.Set("x-higress-mcpserver-config", encodedConfig)
|
||||
api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid)
|
||||
} else {
|
||||
api.LogDebugf("Empty config found for %s:%s", serverName, uid)
|
||||
if !f.mcpRatelimitHandler.HandleRatelimit(url.parsedURL.Path, url.method, []byte{}) {
|
||||
return api.LocalReply
|
||||
if len(parts) >= 3 {
|
||||
serverName := parts[1]
|
||||
uid := parts[2]
|
||||
// Get encoded config
|
||||
encodedConfig, _ := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
|
||||
if encodedConfig != "" {
|
||||
header.Set("x-higress-mcpserver-config", encodedConfig)
|
||||
api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid)
|
||||
}
|
||||
}
|
||||
f.ratelimit = true
|
||||
}
|
||||
if endStream {
|
||||
return api.Continue
|
||||
} else {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
if url.method != http.MethodGet {
|
||||
@@ -183,26 +180,50 @@ func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.Statu
|
||||
if f.skip {
|
||||
return api.Continue
|
||||
}
|
||||
if !endStream {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
if f.message {
|
||||
if endStream {
|
||||
for _, server := range f.config.servers {
|
||||
if f.path == server.GetMessageEndpoint() {
|
||||
// Create a response recorder to capture the response
|
||||
recorder := httptest.NewRecorder()
|
||||
// Call the handleMessage method of SSEServer with complete body
|
||||
server.HandleMessage(recorder, f.req, buffer.Bytes())
|
||||
f.message = false
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(recorder.Code, recorder.Body.String(), recorder.Header(), 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
for _, server := range f.config.servers {
|
||||
if f.path == server.GetMessageEndpoint() {
|
||||
// Create a response recorder to capture the response
|
||||
recorder := httptest.NewRecorder()
|
||||
// Call the handleMessage method of SSEServer with complete body
|
||||
httpStatus := server.HandleMessage(recorder, f.req, buffer.Bytes())
|
||||
f.message = false
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(httpStatus, recorder.Body.String(), recorder.Header(), 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
}
|
||||
return api.StopAndBuffer
|
||||
} else if f.userLevelConfig {
|
||||
// Handle config POST request
|
||||
api.LogDebugf("Handling config request: %s", f.path)
|
||||
f.mcpConfigHandler.HandleConfigRequest(f.path, f.req.Method, buffer.Bytes())
|
||||
f.mcpConfigHandler.HandleConfigRequest(f.req, buffer.Bytes())
|
||||
return api.LocalReply
|
||||
} else if f.ratelimit {
|
||||
if checkJSONRPCMethod(buffer.Bytes(), "tools/list") {
|
||||
api.LogDebugf("Not a tools call request, skipping ratelimit")
|
||||
return api.Continue
|
||||
}
|
||||
parts := strings.Split(f.req.URL.Path, "/")
|
||||
if len(parts) < 3 {
|
||||
api.LogWarnf("Access denied: no valid uid found")
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
serverName := parts[1]
|
||||
uid := parts[2]
|
||||
encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
|
||||
if err != nil {
|
||||
api.LogWarnf("Access denied: no valid config found for uid %s", uid)
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
} else if encodedConfig == "" && checkJSONRPCMethod(buffer.Bytes(), "tools/call") {
|
||||
api.LogDebugf("Empty config found for %s:%s", serverName, uid)
|
||||
if !f.mcpRatelimitHandler.HandleRatelimit(f.req, buffer.Bytes()) {
|
||||
return api.LocalReply
|
||||
}
|
||||
}
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
@@ -287,3 +308,14 @@ func (f *filter) OnDestroy(reason api.DestroyReason) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check if the request is a tools/call request
|
||||
func checkJSONRPCMethod(body []byte, method string) bool {
|
||||
var request mcp.CallToolRequest
|
||||
if err := json.Unmarshal(body, &request); err != nil {
|
||||
api.LogWarnf("Failed to unmarshal request body: %v, not a JSON RPC request", err)
|
||||
return true
|
||||
}
|
||||
|
||||
return request.Method == method
|
||||
}
|
||||
|
||||
@@ -26,14 +26,14 @@ func NewMCPConfigHandler(redisClient *internal.RedisClient, callbacks api.Filter
|
||||
}
|
||||
|
||||
// HandleConfigRequest processes configuration requests
|
||||
func (h *MCPConfigHandler) HandleConfigRequest(path string, method string, body []byte) bool {
|
||||
func (h *MCPConfigHandler) HandleConfigRequest(req *http.Request, body []byte) bool {
|
||||
// Check if it's a configuration request
|
||||
if !strings.HasSuffix(path, "/config") {
|
||||
if !strings.HasSuffix(req.URL.Path, "/config") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Extract serverName and uid from path
|
||||
pathParts := strings.Split(strings.TrimSuffix(path, "/config"), "/")
|
||||
pathParts := strings.Split(strings.TrimSuffix(req.URL.Path, "/config"), "/")
|
||||
if len(pathParts) < 2 {
|
||||
h.sendErrorResponse(http.StatusBadRequest, "INVALID_PATH", "Invalid path format")
|
||||
return true
|
||||
@@ -41,7 +41,7 @@ func (h *MCPConfigHandler) HandleConfigRequest(path string, method string, body
|
||||
uid := pathParts[len(pathParts)-1]
|
||||
serverName := pathParts[len(pathParts)-2]
|
||||
|
||||
switch method {
|
||||
switch req.Method {
|
||||
case http.MethodGet:
|
||||
return h.handleGetConfig(serverName, uid)
|
||||
case http.MethodPost:
|
||||
@@ -70,10 +70,13 @@ func (h *MCPConfigHandler) handleGetConfig(serverName string, uid string) bool {
|
||||
}
|
||||
|
||||
responseBytes, _ := json.Marshal(response)
|
||||
headers := map[string][]string{
|
||||
"Content-Type": {"application/json"},
|
||||
}
|
||||
h.callbacks.DecoderFilterCallbacks().SendLocalReply(
|
||||
http.StatusOK,
|
||||
string(responseBytes),
|
||||
nil, 0, "",
|
||||
headers, 0, "",
|
||||
)
|
||||
return true
|
||||
}
|
||||
@@ -103,10 +106,13 @@ func (h *MCPConfigHandler) handleStoreConfig(serverName string, uid string, body
|
||||
}
|
||||
|
||||
responseBytes, _ := json.Marshal(response)
|
||||
headers := map[string][]string{
|
||||
"Content-Type": {"application/json"},
|
||||
}
|
||||
h.callbacks.DecoderFilterCallbacks().SendLocalReply(
|
||||
http.StatusOK,
|
||||
string(responseBytes),
|
||||
nil, 0, "",
|
||||
headers, 0, "",
|
||||
)
|
||||
return true
|
||||
}
|
||||
@@ -124,10 +130,13 @@ func (h *MCPConfigHandler) sendErrorResponse(status int, code string, message st
|
||||
},
|
||||
}
|
||||
responseBytes, _ := json.Marshal(response)
|
||||
headers := map[string][]string{
|
||||
"Content-Type": {"application/json"},
|
||||
}
|
||||
h.callbacks.DecoderFilterCallbacks().SendLocalReply(
|
||||
status,
|
||||
string(responseBytes),
|
||||
nil, 0, "",
|
||||
headers, 0, "",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
type MCPRatelimitHandler struct {
|
||||
@@ -17,6 +19,7 @@ type MCPRatelimitHandler struct {
|
||||
limit int // Maximum requests allowed per window
|
||||
window int // Time window in seconds
|
||||
whitelist []string // Whitelist of UIDs that bypass rate limiting
|
||||
errorText string // Error text to be displayed
|
||||
}
|
||||
|
||||
// MCPRatelimitConfig is the configuration for the rate limit handler
|
||||
@@ -24,6 +27,7 @@ type MCPRatelimitConfig struct {
|
||||
Limit int `json:"limit"`
|
||||
Window int `json:"window"`
|
||||
Whitelist []string `json:"white_list"` // List of UIDs that bypass rate limiting
|
||||
ErrorText string `json:"error_text"` // Error text to be displayed
|
||||
}
|
||||
|
||||
// NewMCPRatelimitHandler creates a new rate limit handler
|
||||
@@ -33,6 +37,7 @@ func NewMCPRatelimitHandler(redisClient *internal.RedisClient, callbacks api.Fil
|
||||
Limit: 100,
|
||||
Window: int(24 * time.Hour / time.Second), // 24 hours in seconds
|
||||
Whitelist: []string{},
|
||||
ErrorText: "API rate limit exceeded",
|
||||
}
|
||||
}
|
||||
return &MCPRatelimitHandler{
|
||||
@@ -41,6 +46,7 @@ func NewMCPRatelimitHandler(redisClient *internal.RedisClient, callbacks api.Fil
|
||||
limit: conf.Limit,
|
||||
window: conf.Window,
|
||||
whitelist: conf.Whitelist,
|
||||
errorText: conf.ErrorText,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,8 +68,9 @@ type LimitContext struct {
|
||||
Reset int // Time until reset in seconds
|
||||
}
|
||||
|
||||
func (h *MCPRatelimitHandler) HandleRatelimit(path string, method string, body []byte) bool {
|
||||
parts := strings.Split(path, "/")
|
||||
// TODO: needs to be refactored, rate limit should be registered as a request hook in MCP server
|
||||
func (h *MCPRatelimitHandler) HandleRatelimit(req *http.Request, body []byte) bool {
|
||||
parts := strings.Split(req.URL.Path, "/")
|
||||
if len(parts) < 3 {
|
||||
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return false
|
||||
@@ -106,13 +113,58 @@ func (h *MCPRatelimitHandler) HandleRatelimit(path string, method string, body [
|
||||
}
|
||||
|
||||
if context.Remaining < 0 {
|
||||
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusTooManyRequests, "", nil, 0, "")
|
||||
// Create error response content
|
||||
errorContent := []mcp.TextContent{
|
||||
{
|
||||
Type: "text",
|
||||
Text: h.errorText,
|
||||
},
|
||||
}
|
||||
// Create response result
|
||||
result := map[string]interface{}{
|
||||
"content": errorContent,
|
||||
"isError": true,
|
||||
}
|
||||
// Create JSON-RPC response
|
||||
id := getJSONPRCID(body)
|
||||
response := mcp.JSONRPCResponse{
|
||||
JSONRPC: mcp.JSONRPC_VERSION,
|
||||
ID: id,
|
||||
Result: result,
|
||||
}
|
||||
// Convert response to JSON
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to marshal JSON response: %v", err)
|
||||
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusInternalServerError, "", nil, 0, "")
|
||||
return false
|
||||
}
|
||||
// Send JSON-RPC response
|
||||
sessionID := req.URL.Query().Get("sessionId")
|
||||
if sessionID != "" {
|
||||
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusAccepted, string(jsonResponse), nil, 0, "")
|
||||
} else {
|
||||
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, string(jsonResponse), nil, 0, "")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func getJSONPRCID(body []byte) mcp.RequestId {
|
||||
baseMessage := struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Method string `json:"method"`
|
||||
ID interface{} `json:"id,omitempty"`
|
||||
}{}
|
||||
if err := json.Unmarshal(body, &baseMessage); err != nil {
|
||||
api.LogWarnf("Failed to unmarshal request body: %v, not a JSON RPC request", err)
|
||||
return ""
|
||||
}
|
||||
return baseMessage.ID
|
||||
}
|
||||
|
||||
// parseRedisValue converts the value from Redis to an int
|
||||
func parseRedisValue(value interface{}) int {
|
||||
switch v := value.(type) {
|
||||
|
||||
@@ -78,6 +78,7 @@ type MCPServer struct {
|
||||
clientMu sync.Mutex // Separate mutex for client context
|
||||
currentClient NotificationContext
|
||||
initialized atomic.Bool // Use atomic for the initialized flag
|
||||
destory chan struct{}
|
||||
}
|
||||
|
||||
// serverKey is the context key for storing the server instance
|
||||
@@ -226,6 +227,7 @@ func NewMCPServer(
|
||||
prompts: nil,
|
||||
logging: false,
|
||||
},
|
||||
destory: make(chan struct{}),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
@@ -826,6 +828,14 @@ func (s *MCPServer) handleNotification(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MCPServer) Close() {
|
||||
close(s.destory)
|
||||
}
|
||||
|
||||
func (s *MCPServer) GetDestoryChannel() chan struct{} {
|
||||
return s.destory
|
||||
}
|
||||
|
||||
func createResponse(id interface{}, result interface{}) mcp.JSONRPCMessage {
|
||||
return mcp.JSONRPCResponse{
|
||||
JSONRPC: mcp.JSONRPC_VERSION,
|
||||
|
||||
@@ -179,10 +179,10 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct
|
||||
|
||||
// handleMessage processes incoming JSON-RPC messages from clients and sends responses
|
||||
// back through both the SSE connection and HTTP response.
|
||||
func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body json.RawMessage) {
|
||||
func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body json.RawMessage) int {
|
||||
if r.Method != http.MethodPost {
|
||||
s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, fmt.Sprintf("Method %s not allowed", r.Method))
|
||||
return
|
||||
return http.StatusBadRequest
|
||||
}
|
||||
|
||||
sessionID := r.URL.Query().Get("sessionId")
|
||||
@@ -207,7 +207,7 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
|
||||
|
||||
// Process message through MCPServer
|
||||
response := s.server.HandleMessage(ctx, body)
|
||||
|
||||
var status int
|
||||
// Only send response if there is one (not for notifications)
|
||||
if response != nil {
|
||||
eventData, _ := json.Marshal(response)
|
||||
@@ -219,15 +219,22 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
|
||||
if publishErr != nil {
|
||||
api.LogErrorf("Failed to publish message to Redis: %v", publishErr)
|
||||
}
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
status = http.StatusAccepted
|
||||
} else {
|
||||
// support streamable http
|
||||
w.WriteHeader(http.StatusOK)
|
||||
status = http.StatusOK
|
||||
}
|
||||
// Send HTTP response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
} else {
|
||||
// For notifications, just send 202 Accepted with no body
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
status = http.StatusAccepted
|
||||
}
|
||||
return status
|
||||
}
|
||||
|
||||
// writeJSONRPCError writes a JSON-RPC error response with the given error details.
|
||||
@@ -242,3 +249,7 @@ func (s *SSEServer) writeJSONRPCError(
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
func (s *SSEServer) Close() {
|
||||
s.server.Close()
|
||||
}
|
||||
|
||||
@@ -1,47 +1,148 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"gorm.io/driver/clickhouse"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// DBClient is a struct to handle PostgreSQL connections and operations
|
||||
// DBClient is a struct to handle database connections and operations
|
||||
type DBClient struct {
|
||||
db *gorm.DB
|
||||
db *gorm.DB
|
||||
dsn string
|
||||
dbType string
|
||||
reconnect chan struct{}
|
||||
stop chan struct{}
|
||||
panicCount int32 // Add panic counter
|
||||
}
|
||||
|
||||
// NewDBClient creates a new DBClient instance and establishes a connection to the PostgreSQL database
|
||||
func NewDBClient(dsn string, dbType string) (*DBClient, error) {
|
||||
var db *gorm.DB
|
||||
var err error
|
||||
if dbType == "postgres" {
|
||||
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
} else if dbType == "clickhouse" {
|
||||
db, err = gorm.Open(clickhouse.Open(dsn), &gorm.Config{})
|
||||
} else if dbType == "mysql" {
|
||||
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
||||
} else if dbType == "sqlite" {
|
||||
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
} else {
|
||||
return nil, fmt.Errorf("unsupported database type %s", dbType)
|
||||
}
|
||||
// Connect to the database
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
// NewDBClient creates a new DBClient instance and establishes a connection to the database
|
||||
func NewDBClient(dsn string, dbType string, stop chan struct{}) *DBClient {
|
||||
client := &DBClient{
|
||||
dsn: dsn,
|
||||
dbType: dbType,
|
||||
reconnect: make(chan struct{}, 1),
|
||||
stop: stop,
|
||||
}
|
||||
|
||||
return &DBClient{db: db}, nil
|
||||
// Start reconnection goroutine
|
||||
go client.reconnectLoop()
|
||||
|
||||
// Try initial connection
|
||||
if err := client.connect(); err != nil {
|
||||
api.LogErrorf("Initial database connection failed: %v", err)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func (c *DBClient) connect() error {
|
||||
var db *gorm.DB
|
||||
var err error
|
||||
gormConfig := gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
}
|
||||
|
||||
switch c.dbType {
|
||||
case "postgres":
|
||||
db, err = gorm.Open(postgres.Open(c.dsn), &gormConfig)
|
||||
case "clickhouse":
|
||||
db, err = gorm.Open(clickhouse.Open(c.dsn), &gormConfig)
|
||||
case "mysql":
|
||||
db, err = gorm.Open(mysql.Open(c.dsn), &gormConfig)
|
||||
case "sqlite":
|
||||
db, err = gorm.Open(sqlite.Open(c.dsn), &gormConfig)
|
||||
default:
|
||||
return fmt.Errorf("unsupported database type %s", c.dbType)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
c.db = db
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *DBClient) reconnectLoop() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
api.LogErrorf("Recovered from panic in reconnectLoop: %v", r)
|
||||
|
||||
// Increment panic counter
|
||||
atomic.AddInt32(&c.panicCount, 1)
|
||||
|
||||
// If panic count exceeds threshold, stop trying to reconnect
|
||||
if atomic.LoadInt32(&c.panicCount) > 3 {
|
||||
api.LogErrorf("Too many panics in reconnectLoop, stopping reconnection attempts")
|
||||
return
|
||||
}
|
||||
|
||||
// Wait for a while before restarting
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Restart the reconnect loop
|
||||
go c.reconnectLoop()
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(30 * time.Second) // Try to reconnect every 30 seconds
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.stop:
|
||||
api.LogInfof("Database %s connection closed", c.dbType)
|
||||
return
|
||||
case <-ticker.C:
|
||||
if c.db == nil || c.Ping() != nil {
|
||||
if err := c.connect(); err != nil {
|
||||
api.LogErrorf("Database reconnection failed: %v", err)
|
||||
} else {
|
||||
api.LogInfof("Database reconnected successfully")
|
||||
// Reset panic count on successful connection
|
||||
atomic.StoreInt32(&c.panicCount, 0)
|
||||
}
|
||||
}
|
||||
case <-c.reconnect:
|
||||
if err := c.connect(); err != nil {
|
||||
api.LogErrorf("Database reconnection failed: %v", err)
|
||||
} else {
|
||||
api.LogInfof("Database reconnected successfully")
|
||||
// Reset panic count on successful connection
|
||||
atomic.StoreInt32(&c.panicCount, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteSQL executes a raw SQL query and returns the result as a slice of maps
|
||||
func (c *DBClient) ExecuteSQL(query string, args ...interface{}) ([]map[string]interface{}, error) {
|
||||
if c.db == nil {
|
||||
// Trigger reconnection
|
||||
select {
|
||||
case c.reconnect <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil, fmt.Errorf("database is not connected, attempting to reconnect")
|
||||
}
|
||||
|
||||
rows, err := c.db.Raw(query, args...).Rows()
|
||||
if err != nil {
|
||||
// If execution fails, connection might be lost, trigger reconnection
|
||||
select {
|
||||
case c.reconnect <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
@@ -88,3 +189,21 @@ func (c *DBClient) ExecuteSQL(query string, args ...interface{}) ([]map[string]i
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (c *DBClient) Ping() error {
|
||||
if c.db == nil {
|
||||
return fmt.Errorf("database connection is nil")
|
||||
}
|
||||
|
||||
// Use context to set timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Try to ping the database
|
||||
sqlDB, err := c.db.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get underlying *sql.DB: %v", err)
|
||||
}
|
||||
|
||||
return sqlDB.PingContext(ctx)
|
||||
}
|
||||
|
||||
@@ -16,8 +16,9 @@ func init() {
|
||||
}
|
||||
|
||||
type DBConfig struct {
|
||||
dbType string
|
||||
dsn string
|
||||
dbType string
|
||||
dsn string
|
||||
description string
|
||||
}
|
||||
|
||||
func (c *DBConfig) ParseConfig(config map[string]any) error {
|
||||
@@ -33,6 +34,10 @@ func (c *DBConfig) ParseConfig(config map[string]any) error {
|
||||
}
|
||||
c.dbType = dbType
|
||||
api.LogDebugf("DBConfig ParseConfig: %+v", config)
|
||||
c.description, ok = config["description"].(string)
|
||||
if !ok {
|
||||
c.description = ""
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -43,14 +48,10 @@ func (c *DBConfig) NewServer(serverName string) (*internal.MCPServer, error) {
|
||||
internal.WithInstructions(fmt.Sprintf("This is a %s database server", c.dbType)),
|
||||
)
|
||||
|
||||
dbClient, err := NewDBClient(c.dsn, c.dbType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize DBClient: %w", err)
|
||||
}
|
||||
|
||||
dbClient := NewDBClient(c.dsn, c.dbType, mcpServer.GetDestoryChannel())
|
||||
// Add query tool
|
||||
mcpServer.AddTool(
|
||||
mcp.NewToolWithRawSchema("query", fmt.Sprintf("Run a read-only SQL query in database %s", c.dbType), GetQueryToolSchema()),
|
||||
mcp.NewToolWithRawSchema("query", fmt.Sprintf("Run a read-only SQL query in database %s. Database description: %s", c.dbType, c.description), GetQueryToolSchema()),
|
||||
HandleQueryTool(dbClient),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user