diff --git a/plugins/golang-filter/Dockerfile b/plugins/golang-filter/Dockerfile index 1f08417ae..742c9f481 100644 --- a/plugins/golang-filter/Dockerfile +++ b/plugins/golang-filter/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.23 AS golang-base +FROM golang:1.23-bullseye AS golang-base ARG GOPROXY ARG GO_FILTER_NAME diff --git a/plugins/golang-filter/mcp-server/config.go b/plugins/golang-filter/mcp-server/config.go index f8adbc955..9683ef85c 100644 --- a/plugins/golang-filter/mcp-server/config.go +++ b/plugins/golang-filter/mcp-server/config.go @@ -1,33 +1,28 @@ package main import ( - "errors" "fmt" xds "github.com/cncf/xds/go/xds/type/v3" - "github.com/mark3labs/mcp-go/mcp" "google.golang.org/protobuf/types/known/anypb" + "github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal" + _ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/gorm" // 导入gorm包以执行其init函数 "github.com/envoyproxy/envoy/contrib/golang/common/go/api" envoyHttp "github.com/envoyproxy/envoy/contrib/golang/filters/http/source/go/pkg/http" - "github.com/envoyproxy/envoy/examples/golang-http/simple/internal" - "github.com/envoyproxy/envoy/examples/golang-http/simple/servers/gorm" ) const Name = "mcp-server" -const SCHEME_PATH = "scheme" func init() { envoyHttp.RegisterHttpFilterFactoryAndConfigParser(Name, filterFactory, &parser{}) } type config struct { - echoBody string - // other fields - dbClient *gorm.DBClient - redisClient *internal.RedisClient - stopChan chan struct{} - SSEServer *internal.SSEServer + ssePathSuffix string + redisClient *internal.RedisClient + stopChan chan struct{} + servers []*internal.SSEServer } type parser struct { @@ -39,34 +34,68 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int if err := any.UnmarshalTo(configStruct); err != nil { return nil, err } - v := configStruct.Value + conf := &config{} - - dsn, ok := v.AsMap()["dsn"].(string) - if !ok { - return nil, errors.New("missing dsn") - } - - dbType, ok := v.AsMap()["dbType"].(string) - if !ok { - return nil, errors.New("missing database type") - } - - dbClient, err := gorm.NewDBClient(dsn, dbType) - if err != nil { - return nil, fmt.Errorf("failed to initialize DBClient: %w", err) - } - conf.dbClient = dbClient - conf.stopChan = make(chan struct{}) - redisClient, err := internal.NewRedisClient("localhost:6379", conf.stopChan) + + redisConfigMap, ok := v.AsMap()["redis"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("redis config is not set") + } + + redisConfig, err := internal.ParseRedisConfig(redisConfigMap) + if err != nil { + return nil, fmt.Errorf("failed to parse redis config: %w", err) + } + + redisClient, err := internal.NewRedisClient(redisConfig, conf.stopChan) if err != nil { return nil, fmt.Errorf("failed to initialize RedisClient: %w", err) } conf.redisClient = redisClient - conf.SSEServer = internal.NewSSEServer(NewServer(conf.dbClient), internal.WithRedisClient(redisClient)) + ssePathSuffix, ok := v.AsMap()["sse_path_suffix"].(string) + if !ok { + return nil, fmt.Errorf("sse path suffix is not set") + } + conf.ssePathSuffix = ssePathSuffix + + serverConfigs, ok := v.AsMap()["servers"].([]interface{}) + if !ok { + api.LogInfo("No servers are configured") + return conf, nil + } + + for _, serverConfig := range serverConfigs { + serverConfigMap, ok := serverConfig.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("server config must be an object") + } + serverType, ok := serverConfigMap["type"].(string) + if !ok { + return nil, fmt.Errorf("server type is not set") + } + serverPath, ok := serverConfigMap["path"].(string) + if !ok { + return nil, fmt.Errorf("server %s path is not set", serverType) + } + server := internal.GlobalRegistry.GetServer(serverType) + + if server == nil { + return nil, fmt.Errorf("server %s is not registered", serverType) + } + server.ParseConfig(serverConfigMap) + serverInstance, err := server.NewServer() + if err != nil { + return nil, fmt.Errorf("failed to initialize DBServer: %w", err) + } + conf.servers = append(conf.servers, internal.NewSSEServer(serverInstance, + internal.WithRedisClient(redisClient), + internal.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, ssePathSuffix)), + internal.WithMessageEndpoint(serverPath))) + api.LogInfo(fmt.Sprintf("Registered MCP Server: %s", serverType)) + } return conf, nil } @@ -75,15 +104,15 @@ func (p *parser) Merge(parent interface{}, child interface{}) interface{} { childConfig := child.(*config) newConfig := *parentConfig - if childConfig.echoBody != "" { - newConfig.echoBody = childConfig.echoBody - } - if childConfig.dbClient != nil { - newConfig.dbClient = childConfig.dbClient - } if childConfig.redisClient != nil { newConfig.redisClient = childConfig.redisClient } + if childConfig.ssePathSuffix != "" { + newConfig.ssePathSuffix = childConfig.ssePathSuffix + } + if childConfig.servers != nil { + newConfig.servers = append(newConfig.servers, childConfig.servers...) + } return &newConfig } @@ -98,25 +127,4 @@ func filterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.Strea } } -func NewServer(dbClient *gorm.DBClient) *internal.MCPServer { - mcpServer := internal.NewMCPServer( - "mcp-server-envoy-poc", - "1.0.0", - ) - - // Add query tool - mcpServer.AddTool( - mcp.NewToolWithRawSchema("query", "Run a read-only SQL query in clickhouse database with repository git data", gorm.GetQueryToolSchema()), - gorm.HandleQueryTool(dbClient), - ) - api.LogInfo("Added query tool successfully") - - // Add favorite files tool - mcpServer.AddTool( - mcp.NewToolWithRawSchema("author_favorite_files", "Favorite files for an author", gorm.GetFavoriteToolSchema()), - gorm.HandleFavoriteTool(dbClient), - ) - return mcpServer -} - func main() {} diff --git a/plugins/golang-filter/mcp-server/filter.go b/plugins/golang-filter/mcp-server/filter.go index 0762d1241..bc17db2df 100644 --- a/plugins/golang-filter/mcp-server/filter.go +++ b/plugins/golang-filter/mcp-server/filter.go @@ -17,9 +17,10 @@ type filter struct { path string config *config - req *http.Request - sse bool - message bool + req *http.Request + sse bool + message bool + bodyBuffer []byte } // Callbacks which are called in request path @@ -29,38 +30,39 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api. parsedURL, _ := url.Parse(fullPath) f.path = parsedURL.Path method, _ := header.Get(":method") - api.LogInfo(f.path) - if f.path == f.config.SSEServer.SSEEndpoint { - if method != http.MethodGet { - f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "") - } else { - f.sse = true - body := "SSE connection create" - f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "") - } - api.LogInfo("SSE connection started") - return api.LocalReply - } else if f.path == f.config.SSEServer.MessageEndpoint { - if method != http.MethodPost { - f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "") - } - // Create a new http.Request object - f.req = &http.Request{ - Method: method, - URL: parsedURL, - Header: make(http.Header), - } - api.LogInfof("Message request: %v", parsedURL) - // Copy headers from api.RequestHeaderMap to http.Header - header.Range(func(key, value string) bool { - f.req.Header.Add(key, value) - return true - }) - f.message = true - if endStream { - return api.Continue - } else { - return api.StopAndBuffer + for _, server := range f.config.servers { + if f.path == server.GetSSEEndpoint() { + if method != http.MethodGet { + f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "") + } else { + f.sse = true + body := "SSE connection create" + f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "") + } + api.LogInfof("%s SSE connection started", server.GetServerName()) + return api.LocalReply + } else if f.path == server.GetMessageEndpoint() { + if method != http.MethodPost { + f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "") + } + // Create a new http.Request object + f.req = &http.Request{ + Method: method, + URL: parsedURL, + Header: make(http.Header), + } + api.LogDebugf("Message request: %v", parsedURL) + // Copy headers from api.RequestHeaderMap to http.Header + header.Range(func(key, value string) bool { + f.req.Header.Add(key, value) + return true + }) + f.message = true + if endStream { + return api.Continue + } else { + return api.StopAndBuffer + } } } if endStream { @@ -73,18 +75,25 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api. // DecodeData might be called multiple times during handling the request body. // The endStream is true when handling the last piece of the body. func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType { - api.LogInfo("Message DecodeData") - // support suspending & resuming the filter in a background goroutine - api.LogInfof("DecodeData: {%v}", buffer) if f.message { - // Create a response recorder to capture the response - recorder := httptest.NewRecorder() - // Call the handleMessage method of SSEServer - f.config.SSEServer.HandleMessage(recorder, f.req, buffer.Bytes()) - f.message = false - api.LogInfof("Message DecodeData SendLocalReply %v", recorder) - f.callbacks.DecoderFilterCallbacks().SendLocalReply(recorder.Code, recorder.Body.String(), recorder.Header(), 0, "") - return api.LocalReply + f.bodyBuffer = append(f.bodyBuffer, buffer.Bytes()...) + + if endStream { + for _, server := range f.config.servers { + if f.path == server.GetMessageEndpoint() { + // Create a response recorder to capture the response + recorder := httptest.NewRecorder() + // Call the handleMessage method of SSEServer with complete body + server.HandleMessage(recorder, f.req, f.bodyBuffer) + f.message = false + // clear buffer + f.bodyBuffer = nil + f.callbacks.DecoderFilterCallbacks().SendLocalReply(recorder.Code, recorder.Body.String(), recorder.Header(), 0, "") + return api.LocalReply + } + } + } + return api.StopAndBuffer } return api.Continue } @@ -98,23 +107,21 @@ func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api header.Set("Connection", "keep-alive") header.Set("Access-Control-Allow-Origin", "*") header.Del("Content-Length") - api.LogInfo("SSE connection header set") return api.Continue } return api.Continue } -// TODO: 连接多种数据库 -// TODO: 多种存储类型 -// TODO: 数据库多个实例 // EncodeData might be called multiple times during handling the response body. // The endStream is true when handling the last piece of the body. func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.StatusType { - if f.sse { - //TODO: buffer cleanup - f.config.SSEServer.HandleSSE(f.callbacks) - f.sse = false - return api.Running + for _, server := range f.config.servers { + if f.sse { + buffer.Reset() + server.HandleSSE(f.callbacks) + f.sse = false + return api.Running + } } return api.Continue } diff --git a/plugins/golang-filter/mcp-server/go.mod b/plugins/golang-filter/mcp-server/go.mod index 860781592..b6ef3ad71 100644 --- a/plugins/golang-filter/mcp-server/go.mod +++ b/plugins/golang-filter/mcp-server/go.mod @@ -1,16 +1,18 @@ -module github.com/envoyproxy/envoy/examples/golang-http/simple +module github.com/alibaba/higress/plugins/golang-filter/mcp-server go 1.23 require ( - github.com/envoyproxy/envoy v1.33.1-0.20250224062430-6c11eac01993 - google.golang.org/protobuf v1.36.5 github.com/cncf/xds/go v0.0.0-20250121191232-2f005788dc42 + github.com/envoyproxy/envoy v1.33.1-0.20250224062430-6c11eac01993 github.com/go-redis/redis/v8 v8.11.5 github.com/google/uuid v1.6.0 github.com/mark3labs/mcp-go v0.12.0 + google.golang.org/protobuf v1.36.5 gorm.io/driver/clickhouse v0.6.1 + gorm.io/driver/mysql v1.5.7 gorm.io/driver/postgres v1.5.11 + gorm.io/driver/sqlite v1.5.7 gorm.io/gorm v1.25.12 ) @@ -24,6 +26,7 @@ require ( github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect github.com/go-faster/city v1.0.1 // indirect github.com/go-faster/errors v0.7.1 // indirect + github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/hashicorp/go-version v1.6.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect @@ -32,6 +35,7 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/klauspost/compress v1.17.8 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/paulmach/orb v0.11.1 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/pkg/errors v0.9.1 // indirect diff --git a/plugins/golang-filter/mcp-server/go.sum b/plugins/golang-filter/mcp-server/go.sum index 2508afce7..5e1c7cf05 100644 --- a/plugins/golang-filter/mcp-server/go.sum +++ b/plugins/golang-filter/mcp-server/go.sum @@ -27,6 +27,8 @@ github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AY github.com/go-faster/errors v0.7.1/go.mod h1:5ySTjWFiphBs07IKuiL69nxdfd5+fzh1u7FPGZP2quo= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= @@ -64,6 +66,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mark3labs/mcp-go v0.12.0 h1:Pue1Tdwqcz77GHq18uzgmLT3wmeDUxXUSAqSwhGLhVo= github.com/mark3labs/mcp-go v0.12.0/go.mod h1:cjMlBU0cv/cj9kjlgmRhoJ5JREdS7YX83xeIG9Ko/jE= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= @@ -167,7 +171,12 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/clickhouse v0.6.1 h1:t7JMB6sLBXxN8hEO6RdzCbJCwq/jAEVZdwXlmQs1Sd4= gorm.io/driver/clickhouse v0.6.1/go.mod h1:riMYpJcGZ3sJ/OAZZ1rEP1j/Y0H6cByOAnwz7fo2AyM= +gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= +gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= gorm.io/driver/postgres v1.5.11 h1:ubBVAfbKEUld/twyKZ0IYn9rSQh448EdelLYk9Mv314= gorm.io/driver/postgres v1.5.11/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI= +gorm.io/driver/sqlite v1.5.7 h1:8NvsrhP0ifM7LX9G4zPB97NwovUakUxc+2V2uuf3Z1I= +gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= diff --git a/plugins/golang-filter/mcp-server/internal/redis.go b/plugins/golang-filter/mcp-server/internal/redis.go index c62848087..36bdbdc2a 100644 --- a/plugins/golang-filter/mcp-server/internal/redis.go +++ b/plugins/golang-filter/mcp-server/internal/redis.go @@ -3,26 +3,62 @@ package internal import ( "context" "fmt" - "log" "time" "github.com/envoyproxy/envoy/contrib/golang/common/go/api" "github.com/go-redis/redis/v8" ) +type RedisConfig struct { + Address string + Username string + Password string + DB int +} + +func ParseRedisConfig(config map[string]any) (*RedisConfig, error) { + c := &RedisConfig{} + + // address is required + addr, ok := config["address"].(string) + if !ok { + return nil, fmt.Errorf("address is required and must be a string") + } + c.Address = addr + + // username is optional + if username, ok := config["username"].(string); ok { + c.Username = username + } + + // password is optional + if password, ok := config["password"].(string); ok { + c.Password = password + } + + // db is optional, default to 0 + if db, ok := config["db"].(int); ok { + c.DB = db + } + + return c, nil +} + // RedisClient is a struct to handle Redis connections and operations type RedisClient struct { client *redis.Client ctx context.Context stopChan chan struct{} + config *RedisConfig } // NewRedisClient creates a new RedisClient instance and establishes a connection to the Redis server -func NewRedisClient(address string, stopChan chan struct{}) (*RedisClient, error) { +func NewRedisClient(config *RedisConfig, stopChan chan struct{}) (*RedisClient, error) { client := redis.NewClient(&redis.Options{ - Addr: address, - Password: "", // no password set - DB: 0, // use default DB + Addr: config.Address, + Username: config.Username, + Password: config.Password, + DB: config.DB, }) // Ping the Redis server to check the connection @@ -30,17 +66,71 @@ func NewRedisClient(address string, stopChan chan struct{}) (*RedisClient, error if err != nil { return nil, fmt.Errorf("failed to connect to Redis: %w", err) } - log.Printf("Connected to Redis: %s", pong) + api.LogInfof("Connected to Redis: %s", pong) - return &RedisClient{ + redisClient := &RedisClient{ client: client, ctx: context.Background(), stopChan: stopChan, - }, nil + config: config, + } + + // Start keep-alive check + go redisClient.keepAlive() + + return redisClient, nil +} + +// keepAlive periodically checks Redis connection and attempts to reconnect if needed +func (r *RedisClient) keepAlive() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-r.stopChan: + return + case <-ticker.C: + if err := r.checkConnection(); err != nil { + api.LogErrorf("Redis connection check failed: %v", err) + if err := r.reconnect(); err != nil { + api.LogErrorf("Failed to reconnect to Redis: %v", err) + } + } + } + } +} + +// checkConnection verifies if the Redis connection is still alive +func (r *RedisClient) checkConnection() error { + _, err := r.client.Ping(r.ctx).Result() + return err +} + +// reconnect attempts to establish a new connection to Redis +func (r *RedisClient) reconnect() error { + // Close the old client + if err := r.client.Close(); err != nil { + api.LogErrorf("Error closing old Redis connection: %v", err) + } + + // Create new client + r.client = redis.NewClient(&redis.Options{ + Addr: r.config.Address, + Username: r.config.Username, + Password: r.config.Password, + DB: r.config.DB, + }) + + // Test the new connection + if err := r.checkConnection(); err != nil { + return fmt.Errorf("failed to reconnect to Redis: %w", err) + } + + api.LogInfof("Successfully reconnected to Redis") + return nil } -// TODO: redis keep alive check -// TODO: redis pub sub memory limit // Publish publishes a message to a Redis channel func (r *RedisClient) Publish(channel string, message string) error { err := r.client.Publish(r.ctx, channel, message).Err() @@ -59,19 +149,31 @@ func (r *RedisClient) Subscribe(channel string, callback func(message string)) e } go func() { - defer pubsub.Close() + defer func() { + pubsub.Close() + api.LogInfof("Closed subscription to channel %s", channel) + }() + + ch := pubsub.Channel() for { select { case <-r.stopChan: - api.LogDebugf("Stopping subscription to channel %s", channel) + api.LogInfof("Stopping subscription to channel %s", channel) return - default: - msg, err := pubsub.ReceiveMessage(r.ctx) - if err != nil { - log.Printf("Error receiving message: %v", err) + case msg, ok := <-ch: + if !ok { + api.LogInfof("Redis subscription channel closed for %s", channel) return } - callback(msg.Payload) + + func() { + defer func() { + if r := recover(); r != nil { + api.LogErrorf("Recovered from panic in callback: %v", r) + } + }() + callback(msg.Payload) + }() } } }() diff --git a/plugins/golang-filter/mcp-server/internal/registry.go b/plugins/golang-filter/mcp-server/internal/registry.go new file mode 100644 index 000000000..498bebae6 --- /dev/null +++ b/plugins/golang-filter/mcp-server/internal/registry.go @@ -0,0 +1,26 @@ +package internal + +var GlobalRegistry = NewServerRegistry() + +type Server interface { + ParseConfig(config map[string]any) error + NewServer() (*MCPServer, error) +} + +type ServerRegistry struct { + servers map[string]Server +} + +func NewServerRegistry() *ServerRegistry { + return &ServerRegistry{ + servers: make(map[string]Server), + } +} + +func (r *ServerRegistry) RegisterServer(name string, server Server) { + r.servers[name] = server +} + +func (r *ServerRegistry) GetServer(name string) Server { + return r.servers[name] +} diff --git a/plugins/golang-filter/mcp-server/internal/sse.go b/plugins/golang-filter/mcp-server/internal/sse.go index 51e035432..c50fc41ca 100644 --- a/plugins/golang-filter/mcp-server/internal/sse.go +++ b/plugins/golang-filter/mcp-server/internal/sse.go @@ -15,12 +15,24 @@ import ( type SSEServer struct { server *MCPServer baseURL string - MessageEndpoint string - SSEEndpoint string + messageEndpoint string + sseEndpoint string sessions map[string]bool redisClient *RedisClient // Redis client for pub/sub } +func (s *SSEServer) GetMessageEndpoint() string { + return s.messageEndpoint +} + +func (s *SSEServer) GetSSEEndpoint() string { + return s.sseEndpoint +} + +func (s *SSEServer) GetServerName() string { + return s.server.name +} + // Option defines a function type for configuring SSEServer type Option func(*SSEServer) @@ -34,14 +46,14 @@ func WithBaseURL(baseURL string) Option { // WithMessageEndpoint sets the message endpoint path func WithMessageEndpoint(endpoint string) Option { return func(s *SSEServer) { - s.MessageEndpoint = endpoint + s.messageEndpoint = endpoint } } // WithSSEEndpoint sets the SSE endpoint path func WithSSEEndpoint(endpoint string) Option { return func(s *SSEServer) { - s.SSEEndpoint = endpoint + s.sseEndpoint = endpoint } } @@ -55,8 +67,8 @@ func WithRedisClient(redisClient *RedisClient) Option { func NewSSEServer(server *MCPServer, opts ...Option) *SSEServer { s := &SSEServer{ server: server, - SSEEndpoint: "/sse", - MessageEndpoint: "/message", + sseEndpoint: "/sse", + messageEndpoint: "/message", sessions: make(map[string]bool), } @@ -84,7 +96,7 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler) { messageEndpoint := fmt.Sprintf( "%s%s?sessionId=%s", s.baseURL, - s.MessageEndpoint, + s.messageEndpoint, sessionID, ) @@ -135,10 +147,10 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j } sessionID := r.URL.Query().Get("sessionId") - // if sessionID == "" { - // s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId") - // return - // } + if sessionID == "" { + s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId") + return + } // Set the client context in the server before handling the message ctx := s.server.WithContext(r.Context(), NotificationContext{ @@ -146,21 +158,13 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j SessionID: sessionID, }) - //TODO: sessions + //TODO: check session id // _, ok := s.sessions.Load(sessionID) // if !ok { // s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID") // return // } - //TODO - // // Parse message as raw JSON - // var rawMessage json.RawMessage - // if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil { - // s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "Parse error") - // return - // } - // Process message through MCPServer response := s.server.HandleMessage(ctx, body) @@ -198,15 +202,3 @@ func (s *SSEServer) writeJSONRPCError( w.WriteHeader(http.StatusBadRequest) json.NewEncoder(w).Encode(response) } - -// // ServeHTTP implements the http.Handler interface. -// func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { -// switch r.URL.Path { -// case s.sseEndpoint: -// s.handleSSE(w, r) -// case s.messageEndpoint: -// s.handleMessage(w, r) -// default: -// http.NotFound(w, r) -// } -// } diff --git a/plugins/golang-filter/mcp-server/servers/gorm/db.go b/plugins/golang-filter/mcp-server/servers/gorm/db.go index 5b53b509c..db9dc90dc 100644 --- a/plugins/golang-filter/mcp-server/servers/gorm/db.go +++ b/plugins/golang-filter/mcp-server/servers/gorm/db.go @@ -2,13 +2,12 @@ package gorm import ( "fmt" - "log" - "os" "gorm.io/driver/clickhouse" + "gorm.io/driver/mysql" "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" "gorm.io/gorm" - "gorm.io/gorm/logger" ) // DBClient is a struct to handle PostgreSQL connections and operations @@ -18,24 +17,16 @@ type DBClient struct { // NewDBClient creates a new DBClient instance and establishes a connection to the PostgreSQL database func NewDBClient(dsn string, dbType string) (*DBClient, error) { - // Configure GORM logger - newLogger := logger.New( - log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer - logger.Config{ - SlowThreshold: 0, // Slow SQL threshold - LogLevel: logger.Info, // Log level - IgnoreRecordNotFoundError: false, // Ignore ErrRecordNotFound error for logger - Colorful: true, // Disable color - }, - ) var db *gorm.DB var err error if dbType == "postgres" { - db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{ - Logger: newLogger, - }) + db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}) } else if dbType == "clickhouse" { db, err = gorm.Open(clickhouse.Open(dsn), &gorm.Config{}) + } else if dbType == "mysql" { + db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}) + } else if dbType == "sqlite" { + db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{}) } else { return nil, fmt.Errorf("unsupported database type") } diff --git a/plugins/golang-filter/mcp-server/servers/gorm/server.go b/plugins/golang-filter/mcp-server/servers/gorm/server.go new file mode 100644 index 000000000..3d8f50d9d --- /dev/null +++ b/plugins/golang-filter/mcp-server/servers/gorm/server.go @@ -0,0 +1,60 @@ +package gorm + +import ( + "errors" + "fmt" + + "github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal" + "github.com/mark3labs/mcp-go/mcp" +) + +func init() { + internal.GlobalRegistry.RegisterServer("database", &DBConfig{}) +} + +type DBConfig struct { + name string + dbType string + dsn string +} + +func (c *DBConfig) ParseConfig(config map[string]any) error { + name, ok := config["name"].(string) + if !ok { + return errors.New("missing servername") + } + c.name = name + + dsn, ok := config["dsn"].(string) + if !ok { + return errors.New("missing dsn") + } + c.dsn = dsn + + dbType, ok := config["dbType"].(string) + if !ok { + return errors.New("missing database type") + } + c.dbType = dbType + return nil +} + +func (c *DBConfig) NewServer() (*internal.MCPServer, error) { + mcpServer := internal.NewMCPServer( + c.name, + "1.0.0", + ) + + dbClient, err := NewDBClient(c.dsn, c.dbType) + if err != nil { + return nil, fmt.Errorf("failed to initialize DBClient: %w", err) + } + + // Add query tool + mcpServer.AddTool( + mcp.NewToolWithRawSchema("query", "Run a read-only SQL query in clickhouse database with repository git data", GetQueryToolSchema()), + HandleQueryTool(dbClient), + ) + + return mcpServer, nil +} diff --git a/plugins/golang-filter/mcp-server/servers/gorm/tools.go b/plugins/golang-filter/mcp-server/servers/gorm/tools.go index cbeefa25f..3cebe1949 100644 --- a/plugins/golang-filter/mcp-server/servers/gorm/tools.go +++ b/plugins/golang-filter/mcp-server/servers/gorm/tools.go @@ -5,41 +5,10 @@ import ( "encoding/json" "fmt" - "github.com/envoyproxy/envoy/examples/golang-http/simple/internal" + "github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal" "github.com/mark3labs/mcp-go/mcp" ) -const favoriteFilesTemplate = ` -WITH current_files AS ( - SELECT path - FROM ( - SELECT - old_path AS path, - max(time) AS last_time, - 2 AS change_type - FROM git.file_changes - GROUP BY old_path - UNION ALL - SELECT - path, - max(time) AS last_time, - argMax(change_type, time) AS change_type - FROM git.file_changes - GROUP BY path - ) - GROUP BY path - HAVING (argMax(change_type, last_time) != 2) AND (NOT match(path, '(^dbms/)|(^libs/)|(^tests/testflows/)|(^programs/server/store/)')) - ORDER BY path ASC -) -SELECT - path, - count() AS c -FROM git.file_changes -WHERE (author = '%s') AND (path IN (current_files)) -GROUP BY path -ORDER BY c DESC -LIMIT 10` - // HandleQueryTool handles SQL query execution func HandleQueryTool(dbClient *DBClient) internal.ToolHandlerFunc { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -70,37 +39,6 @@ func HandleQueryTool(dbClient *DBClient) internal.ToolHandlerFunc { } } -// HandleFavoriteTool handles author's favorite files query -func HandleFavoriteTool(dbClient *DBClient) internal.ToolHandlerFunc { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - arguments := request.Params.Arguments - author, ok := arguments["author"].(string) - if !ok { - return nil, fmt.Errorf("invalid author argument") - } - query := fmt.Sprintf(favoriteFilesTemplate, author) - - results, err := dbClient.ExecuteSQL(query) - if err != nil { - return nil, fmt.Errorf("failed to execute SQL query: %w", err) - } - - jsonData, err := json.Marshal(results) - if err != nil { - return nil, fmt.Errorf("failed to marshal SQL results: %w", err) - } - - return &mcp.CallToolResult{ - Content: []mcp.Content{ - mcp.TextContent{ - Type: "text", - Text: string(jsonData), - }, - }, - }, nil - } -} - // GetQueryToolSchema returns the schema for query tool func GetQueryToolSchema() json.RawMessage { return json.RawMessage(` @@ -115,18 +53,3 @@ func GetQueryToolSchema() json.RawMessage { } `) } - -// GetFavoriteToolSchema returns the schema for favorite files tool -func GetFavoriteToolSchema() json.RawMessage { - return json.RawMessage(` - { - "type": "object", - "properties": { - "author": { - "type": "string", - "description": "the author name" - } - } - } - `) -}