fix: support mcp server database reconnect and fix tool/list method denied (#2074)

This commit is contained in:
Jingze
2025-04-18 11:19:56 +08:00
committed by GitHub
parent 7f9ae38e51
commit 1834d4acef
8 changed files with 322 additions and 82 deletions

View File

@@ -49,6 +49,9 @@ func (c *config) Destroy() {
api.LogDebug("Closing Redis client")
c.redisClient.Close()
}
for _, server := range c.servers {
server.Close()
}
}
type parser struct {
@@ -127,6 +130,9 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
}
}
}
if errorText, ok := rateLimit["error_text"].(string); ok {
rateLimitConfig.ErrorText = errorText
}
conf.rateLimitConfig = rateLimitConfig
}

View File

@@ -1,6 +1,7 @@
package main
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
@@ -11,6 +12,7 @@ import (
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/handler"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"github.com/mark3labs/mcp-go/mcp"
)
const (
@@ -35,6 +37,7 @@ type filter struct {
userLevelConfig bool
mcpConfigHandler *handler.MCPConfigHandler
ratelimit bool
mcpRatelimitHandler *handler.MCPRatelimitHandler
}
@@ -110,6 +113,11 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
}
}
f.req = &http.Request{
Method: url.method,
URL: url.parsedURL,
}
if strings.HasSuffix(f.path, ConfigPathSuffix) && f.config.enableUserLevelServer {
if !url.internalIP {
api.LogWarnf("Access denied: non-internal IP address %s", url.parsedURL.String())
@@ -118,13 +126,9 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
}
if strings.HasSuffix(f.path, ConfigPathSuffix) && url.method == http.MethodGet {
api.LogDebugf("Handling config request: %s", f.path)
f.mcpConfigHandler.HandleConfigRequest(f.path, url.method, []byte{})
f.mcpConfigHandler.HandleConfigRequest(f.req, []byte{})
return api.LocalReply
}
f.req = &http.Request{
Method: url.method,
URL: url.parsedURL,
}
f.userLevelConfig = true
if endStream {
return api.Continue
@@ -137,30 +141,23 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
f.proxyURL = url.parsedURL
if f.config.enableUserLevelServer {
parts := strings.Split(url.parsedURL.Path, "/")
if len(parts) < 3 {
api.LogDebugf("Access denied: missing uid in path %s", url.parsedURL.Path)
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "Access denied: missing uid", nil, 0, "")
return api.LocalReply
}
serverName := parts[1]
uid := parts[2]
// Get encoded config
encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
if err != nil {
api.LogWarnf("Access denied: no valid config found for uid %s", uid)
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
return api.LocalReply
} else if encodedConfig != "" {
header.Set("x-higress-mcpserver-config", encodedConfig)
api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid)
} else {
api.LogDebugf("Empty config found for %s:%s", serverName, uid)
if !f.mcpRatelimitHandler.HandleRatelimit(url.parsedURL.Path, url.method, []byte{}) {
return api.LocalReply
if len(parts) >= 3 {
serverName := parts[1]
uid := parts[2]
// Get encoded config
encodedConfig, _ := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
if encodedConfig != "" {
header.Set("x-higress-mcpserver-config", encodedConfig)
api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid)
}
}
f.ratelimit = true
}
if endStream {
return api.Continue
} else {
return api.StopAndBuffer
}
return api.Continue
}
if url.method != http.MethodGet {
@@ -183,26 +180,50 @@ func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.Statu
if f.skip {
return api.Continue
}
if !endStream {
return api.StopAndBuffer
}
if f.message {
if endStream {
for _, server := range f.config.servers {
if f.path == server.GetMessageEndpoint() {
// Create a response recorder to capture the response
recorder := httptest.NewRecorder()
// Call the handleMessage method of SSEServer with complete body
server.HandleMessage(recorder, f.req, buffer.Bytes())
f.message = false
f.callbacks.DecoderFilterCallbacks().SendLocalReply(recorder.Code, recorder.Body.String(), recorder.Header(), 0, "")
return api.LocalReply
}
for _, server := range f.config.servers {
if f.path == server.GetMessageEndpoint() {
// Create a response recorder to capture the response
recorder := httptest.NewRecorder()
// Call the handleMessage method of SSEServer with complete body
httpStatus := server.HandleMessage(recorder, f.req, buffer.Bytes())
f.message = false
f.callbacks.DecoderFilterCallbacks().SendLocalReply(httpStatus, recorder.Body.String(), recorder.Header(), 0, "")
return api.LocalReply
}
}
return api.StopAndBuffer
} else if f.userLevelConfig {
// Handle config POST request
api.LogDebugf("Handling config request: %s", f.path)
f.mcpConfigHandler.HandleConfigRequest(f.path, f.req.Method, buffer.Bytes())
f.mcpConfigHandler.HandleConfigRequest(f.req, buffer.Bytes())
return api.LocalReply
} else if f.ratelimit {
if checkJSONRPCMethod(buffer.Bytes(), "tools/list") {
api.LogDebugf("Not a tools call request, skipping ratelimit")
return api.Continue
}
parts := strings.Split(f.req.URL.Path, "/")
if len(parts) < 3 {
api.LogWarnf("Access denied: no valid uid found")
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
return api.LocalReply
}
serverName := parts[1]
uid := parts[2]
encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
if err != nil {
api.LogWarnf("Access denied: no valid config found for uid %s", uid)
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
return api.LocalReply
} else if encodedConfig == "" && checkJSONRPCMethod(buffer.Bytes(), "tools/call") {
api.LogDebugf("Empty config found for %s:%s", serverName, uid)
if !f.mcpRatelimitHandler.HandleRatelimit(f.req, buffer.Bytes()) {
return api.LocalReply
}
}
}
return api.Continue
}
@@ -287,3 +308,14 @@ func (f *filter) OnDestroy(reason api.DestroyReason) {
}
}
}
// check if the request is a tools/call request
func checkJSONRPCMethod(body []byte, method string) bool {
var request mcp.CallToolRequest
if err := json.Unmarshal(body, &request); err != nil {
api.LogWarnf("Failed to unmarshal request body: %v, not a JSON RPC request", err)
return true
}
return request.Method == method
}

View File

@@ -26,14 +26,14 @@ func NewMCPConfigHandler(redisClient *internal.RedisClient, callbacks api.Filter
}
// HandleConfigRequest processes configuration requests
func (h *MCPConfigHandler) HandleConfigRequest(path string, method string, body []byte) bool {
func (h *MCPConfigHandler) HandleConfigRequest(req *http.Request, body []byte) bool {
// Check if it's a configuration request
if !strings.HasSuffix(path, "/config") {
if !strings.HasSuffix(req.URL.Path, "/config") {
return false
}
// Extract serverName and uid from path
pathParts := strings.Split(strings.TrimSuffix(path, "/config"), "/")
pathParts := strings.Split(strings.TrimSuffix(req.URL.Path, "/config"), "/")
if len(pathParts) < 2 {
h.sendErrorResponse(http.StatusBadRequest, "INVALID_PATH", "Invalid path format")
return true
@@ -41,7 +41,7 @@ func (h *MCPConfigHandler) HandleConfigRequest(path string, method string, body
uid := pathParts[len(pathParts)-1]
serverName := pathParts[len(pathParts)-2]
switch method {
switch req.Method {
case http.MethodGet:
return h.handleGetConfig(serverName, uid)
case http.MethodPost:
@@ -70,10 +70,13 @@ func (h *MCPConfigHandler) handleGetConfig(serverName string, uid string) bool {
}
responseBytes, _ := json.Marshal(response)
headers := map[string][]string{
"Content-Type": {"application/json"},
}
h.callbacks.DecoderFilterCallbacks().SendLocalReply(
http.StatusOK,
string(responseBytes),
nil, 0, "",
headers, 0, "",
)
return true
}
@@ -103,10 +106,13 @@ func (h *MCPConfigHandler) handleStoreConfig(serverName string, uid string, body
}
responseBytes, _ := json.Marshal(response)
headers := map[string][]string{
"Content-Type": {"application/json"},
}
h.callbacks.DecoderFilterCallbacks().SendLocalReply(
http.StatusOK,
string(responseBytes),
nil, 0, "",
headers, 0, "",
)
return true
}
@@ -124,10 +130,13 @@ func (h *MCPConfigHandler) sendErrorResponse(status int, code string, message st
},
}
responseBytes, _ := json.Marshal(response)
headers := map[string][]string{
"Content-Type": {"application/json"},
}
h.callbacks.DecoderFilterCallbacks().SendLocalReply(
status,
string(responseBytes),
nil, 0, "",
headers, 0, "",
)
}

View File

@@ -1,6 +1,7 @@
package handler
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
@@ -9,6 +10,7 @@ import (
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"github.com/mark3labs/mcp-go/mcp"
)
type MCPRatelimitHandler struct {
@@ -17,6 +19,7 @@ type MCPRatelimitHandler struct {
limit int // Maximum requests allowed per window
window int // Time window in seconds
whitelist []string // Whitelist of UIDs that bypass rate limiting
errorText string // Error text to be displayed
}
// MCPRatelimitConfig is the configuration for the rate limit handler
@@ -24,6 +27,7 @@ type MCPRatelimitConfig struct {
Limit int `json:"limit"`
Window int `json:"window"`
Whitelist []string `json:"white_list"` // List of UIDs that bypass rate limiting
ErrorText string `json:"error_text"` // Error text to be displayed
}
// NewMCPRatelimitHandler creates a new rate limit handler
@@ -33,6 +37,7 @@ func NewMCPRatelimitHandler(redisClient *internal.RedisClient, callbacks api.Fil
Limit: 100,
Window: int(24 * time.Hour / time.Second), // 24 hours in seconds
Whitelist: []string{},
ErrorText: "API rate limit exceeded",
}
}
return &MCPRatelimitHandler{
@@ -41,6 +46,7 @@ func NewMCPRatelimitHandler(redisClient *internal.RedisClient, callbacks api.Fil
limit: conf.Limit,
window: conf.Window,
whitelist: conf.Whitelist,
errorText: conf.ErrorText,
}
}
@@ -62,8 +68,9 @@ type LimitContext struct {
Reset int // Time until reset in seconds
}
func (h *MCPRatelimitHandler) HandleRatelimit(path string, method string, body []byte) bool {
parts := strings.Split(path, "/")
// TODO: needs to be refactored, rate limit should be registered as a request hook in MCP server
func (h *MCPRatelimitHandler) HandleRatelimit(req *http.Request, body []byte) bool {
parts := strings.Split(req.URL.Path, "/")
if len(parts) < 3 {
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
return false
@@ -106,13 +113,58 @@ func (h *MCPRatelimitHandler) HandleRatelimit(path string, method string, body [
}
if context.Remaining < 0 {
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusTooManyRequests, "", nil, 0, "")
// Create error response content
errorContent := []mcp.TextContent{
{
Type: "text",
Text: h.errorText,
},
}
// Create response result
result := map[string]interface{}{
"content": errorContent,
"isError": true,
}
// Create JSON-RPC response
id := getJSONPRCID(body)
response := mcp.JSONRPCResponse{
JSONRPC: mcp.JSONRPC_VERSION,
ID: id,
Result: result,
}
// Convert response to JSON
jsonResponse, err := json.Marshal(response)
if err != nil {
api.LogErrorf("Failed to marshal JSON response: %v", err)
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusInternalServerError, "", nil, 0, "")
return false
}
// Send JSON-RPC response
sessionID := req.URL.Query().Get("sessionId")
if sessionID != "" {
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusAccepted, string(jsonResponse), nil, 0, "")
} else {
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, string(jsonResponse), nil, 0, "")
}
return false
}
return true
}
func getJSONPRCID(body []byte) mcp.RequestId {
baseMessage := struct {
JSONRPC string `json:"jsonrpc"`
Method string `json:"method"`
ID interface{} `json:"id,omitempty"`
}{}
if err := json.Unmarshal(body, &baseMessage); err != nil {
api.LogWarnf("Failed to unmarshal request body: %v, not a JSON RPC request", err)
return ""
}
return baseMessage.ID
}
// parseRedisValue converts the value from Redis to an int
func parseRedisValue(value interface{}) int {
switch v := value.(type) {

View File

@@ -78,6 +78,7 @@ type MCPServer struct {
clientMu sync.Mutex // Separate mutex for client context
currentClient NotificationContext
initialized atomic.Bool // Use atomic for the initialized flag
destory chan struct{}
}
// serverKey is the context key for storing the server instance
@@ -226,6 +227,7 @@ func NewMCPServer(
prompts: nil,
logging: false,
},
destory: make(chan struct{}),
}
for _, opt := range opts {
@@ -826,6 +828,14 @@ func (s *MCPServer) handleNotification(
return nil
}
func (s *MCPServer) Close() {
close(s.destory)
}
func (s *MCPServer) GetDestoryChannel() chan struct{} {
return s.destory
}
func createResponse(id interface{}, result interface{}) mcp.JSONRPCMessage {
return mcp.JSONRPCResponse{
JSONRPC: mcp.JSONRPC_VERSION,

View File

@@ -179,10 +179,10 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct
// handleMessage processes incoming JSON-RPC messages from clients and sends responses
// back through both the SSE connection and HTTP response.
func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body json.RawMessage) {
func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body json.RawMessage) int {
if r.Method != http.MethodPost {
s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, fmt.Sprintf("Method %s not allowed", r.Method))
return
return http.StatusBadRequest
}
sessionID := r.URL.Query().Get("sessionId")
@@ -207,7 +207,7 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
// Process message through MCPServer
response := s.server.HandleMessage(ctx, body)
var status int
// Only send response if there is one (not for notifications)
if response != nil {
eventData, _ := json.Marshal(response)
@@ -219,15 +219,22 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
if publishErr != nil {
api.LogErrorf("Failed to publish message to Redis: %v", publishErr)
}
w.WriteHeader(http.StatusAccepted)
status = http.StatusAccepted
} else {
// support streamable http
w.WriteHeader(http.StatusOK)
status = http.StatusOK
}
// Send HTTP response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
json.NewEncoder(w).Encode(response)
} else {
// For notifications, just send 202 Accepted with no body
w.WriteHeader(http.StatusAccepted)
status = http.StatusAccepted
}
return status
}
// writeJSONRPCError writes a JSON-RPC error response with the given error details.
@@ -242,3 +249,7 @@ func (s *SSEServer) writeJSONRPCError(
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(response)
}
func (s *SSEServer) Close() {
s.server.Close()
}

View File

@@ -1,47 +1,148 @@
package gorm
import (
"context"
"fmt"
"sync/atomic"
"time"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"gorm.io/driver/clickhouse"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// DBClient is a struct to handle PostgreSQL connections and operations
// DBClient is a struct to handle database connections and operations
type DBClient struct {
db *gorm.DB
db *gorm.DB
dsn string
dbType string
reconnect chan struct{}
stop chan struct{}
panicCount int32 // Add panic counter
}
// NewDBClient creates a new DBClient instance and establishes a connection to the PostgreSQL database
func NewDBClient(dsn string, dbType string) (*DBClient, error) {
var db *gorm.DB
var err error
if dbType == "postgres" {
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
} else if dbType == "clickhouse" {
db, err = gorm.Open(clickhouse.Open(dsn), &gorm.Config{})
} else if dbType == "mysql" {
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
} else if dbType == "sqlite" {
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
} else {
return nil, fmt.Errorf("unsupported database type %s", dbType)
}
// Connect to the database
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
// NewDBClient creates a new DBClient instance and establishes a connection to the database
func NewDBClient(dsn string, dbType string, stop chan struct{}) *DBClient {
client := &DBClient{
dsn: dsn,
dbType: dbType,
reconnect: make(chan struct{}, 1),
stop: stop,
}
return &DBClient{db: db}, nil
// Start reconnection goroutine
go client.reconnectLoop()
// Try initial connection
if err := client.connect(); err != nil {
api.LogErrorf("Initial database connection failed: %v", err)
}
return client
}
func (c *DBClient) connect() error {
var db *gorm.DB
var err error
gormConfig := gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
}
switch c.dbType {
case "postgres":
db, err = gorm.Open(postgres.Open(c.dsn), &gormConfig)
case "clickhouse":
db, err = gorm.Open(clickhouse.Open(c.dsn), &gormConfig)
case "mysql":
db, err = gorm.Open(mysql.Open(c.dsn), &gormConfig)
case "sqlite":
db, err = gorm.Open(sqlite.Open(c.dsn), &gormConfig)
default:
return fmt.Errorf("unsupported database type %s", c.dbType)
}
if err != nil {
return fmt.Errorf("failed to connect to database: %w", err)
}
c.db = db
return nil
}
func (c *DBClient) reconnectLoop() {
defer func() {
if r := recover(); r != nil {
api.LogErrorf("Recovered from panic in reconnectLoop: %v", r)
// Increment panic counter
atomic.AddInt32(&c.panicCount, 1)
// If panic count exceeds threshold, stop trying to reconnect
if atomic.LoadInt32(&c.panicCount) > 3 {
api.LogErrorf("Too many panics in reconnectLoop, stopping reconnection attempts")
return
}
// Wait for a while before restarting
time.Sleep(5 * time.Second)
// Restart the reconnect loop
go c.reconnectLoop()
}
}()
ticker := time.NewTicker(30 * time.Second) // Try to reconnect every 30 seconds
defer ticker.Stop()
for {
select {
case <-c.stop:
api.LogInfof("Database %s connection closed", c.dbType)
return
case <-ticker.C:
if c.db == nil || c.Ping() != nil {
if err := c.connect(); err != nil {
api.LogErrorf("Database reconnection failed: %v", err)
} else {
api.LogInfof("Database reconnected successfully")
// Reset panic count on successful connection
atomic.StoreInt32(&c.panicCount, 0)
}
}
case <-c.reconnect:
if err := c.connect(); err != nil {
api.LogErrorf("Database reconnection failed: %v", err)
} else {
api.LogInfof("Database reconnected successfully")
// Reset panic count on successful connection
atomic.StoreInt32(&c.panicCount, 0)
}
}
}
}
// ExecuteSQL executes a raw SQL query and returns the result as a slice of maps
func (c *DBClient) ExecuteSQL(query string, args ...interface{}) ([]map[string]interface{}, error) {
if c.db == nil {
// Trigger reconnection
select {
case c.reconnect <- struct{}{}:
default:
}
return nil, fmt.Errorf("database is not connected, attempting to reconnect")
}
rows, err := c.db.Raw(query, args...).Rows()
if err != nil {
// If execution fails, connection might be lost, trigger reconnection
select {
case c.reconnect <- struct{}{}:
default:
}
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
}
defer rows.Close()
@@ -88,3 +189,21 @@ func (c *DBClient) ExecuteSQL(query string, args ...interface{}) ([]map[string]i
return results, nil
}
func (c *DBClient) Ping() error {
if c.db == nil {
return fmt.Errorf("database connection is nil")
}
// Use context to set timeout
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// Try to ping the database
sqlDB, err := c.db.DB()
if err != nil {
return fmt.Errorf("failed to get underlying *sql.DB: %v", err)
}
return sqlDB.PingContext(ctx)
}

View File

@@ -16,8 +16,9 @@ func init() {
}
type DBConfig struct {
dbType string
dsn string
dbType string
dsn string
description string
}
func (c *DBConfig) ParseConfig(config map[string]any) error {
@@ -33,6 +34,10 @@ func (c *DBConfig) ParseConfig(config map[string]any) error {
}
c.dbType = dbType
api.LogDebugf("DBConfig ParseConfig: %+v", config)
c.description, ok = config["description"].(string)
if !ok {
c.description = ""
}
return nil
}
@@ -43,14 +48,10 @@ func (c *DBConfig) NewServer(serverName string) (*internal.MCPServer, error) {
internal.WithInstructions(fmt.Sprintf("This is a %s database server", c.dbType)),
)
dbClient, err := NewDBClient(c.dsn, c.dbType)
if err != nil {
return nil, fmt.Errorf("failed to initialize DBClient: %w", err)
}
dbClient := NewDBClient(c.dsn, c.dbType, mcpServer.GetDestoryChannel())
// Add query tool
mcpServer.AddTool(
mcp.NewToolWithRawSchema("query", fmt.Sprintf("Run a read-only SQL query in database %s", c.dbType), GetQueryToolSchema()),
mcp.NewToolWithRawSchema("query", fmt.Sprintf("Run a read-only SQL query in database %s. Database description: %s", c.dbType, c.description), GetQueryToolSchema()),
HandleQueryTool(dbClient),
)