feat: add config parse in mcp server (#1944)

This commit is contained in:
Jingze
2025-03-24 17:52:16 +08:00
committed by GitHub
parent 9bde0dfb46
commit f5d20b72e0
11 changed files with 382 additions and 260 deletions

View File

@@ -1,4 +1,4 @@
FROM golang:1.23 AS golang-base
FROM golang:1.23-bullseye AS golang-base
ARG GOPROXY
ARG GO_FILTER_NAME

View File

@@ -1,33 +1,28 @@
package main
import (
"errors"
"fmt"
xds "github.com/cncf/xds/go/xds/type/v3"
"github.com/mark3labs/mcp-go/mcp"
"google.golang.org/protobuf/types/known/anypb"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/gorm" // 导入gorm包以执行其init函数
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
envoyHttp "github.com/envoyproxy/envoy/contrib/golang/filters/http/source/go/pkg/http"
"github.com/envoyproxy/envoy/examples/golang-http/simple/internal"
"github.com/envoyproxy/envoy/examples/golang-http/simple/servers/gorm"
)
const Name = "mcp-server"
const SCHEME_PATH = "scheme"
func init() {
envoyHttp.RegisterHttpFilterFactoryAndConfigParser(Name, filterFactory, &parser{})
}
type config struct {
echoBody string
// other fields
dbClient *gorm.DBClient
redisClient *internal.RedisClient
stopChan chan struct{}
SSEServer *internal.SSEServer
ssePathSuffix string
redisClient *internal.RedisClient
stopChan chan struct{}
servers []*internal.SSEServer
}
type parser struct {
@@ -39,34 +34,68 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
if err := any.UnmarshalTo(configStruct); err != nil {
return nil, err
}
v := configStruct.Value
conf := &config{}
dsn, ok := v.AsMap()["dsn"].(string)
if !ok {
return nil, errors.New("missing dsn")
}
dbType, ok := v.AsMap()["dbType"].(string)
if !ok {
return nil, errors.New("missing database type")
}
dbClient, err := gorm.NewDBClient(dsn, dbType)
if err != nil {
return nil, fmt.Errorf("failed to initialize DBClient: %w", err)
}
conf.dbClient = dbClient
conf.stopChan = make(chan struct{})
redisClient, err := internal.NewRedisClient("localhost:6379", conf.stopChan)
redisConfigMap, ok := v.AsMap()["redis"].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("redis config is not set")
}
redisConfig, err := internal.ParseRedisConfig(redisConfigMap)
if err != nil {
return nil, fmt.Errorf("failed to parse redis config: %w", err)
}
redisClient, err := internal.NewRedisClient(redisConfig, conf.stopChan)
if err != nil {
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err)
}
conf.redisClient = redisClient
conf.SSEServer = internal.NewSSEServer(NewServer(conf.dbClient), internal.WithRedisClient(redisClient))
ssePathSuffix, ok := v.AsMap()["sse_path_suffix"].(string)
if !ok {
return nil, fmt.Errorf("sse path suffix is not set")
}
conf.ssePathSuffix = ssePathSuffix
serverConfigs, ok := v.AsMap()["servers"].([]interface{})
if !ok {
api.LogInfo("No servers are configured")
return conf, nil
}
for _, serverConfig := range serverConfigs {
serverConfigMap, ok := serverConfig.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("server config must be an object")
}
serverType, ok := serverConfigMap["type"].(string)
if !ok {
return nil, fmt.Errorf("server type is not set")
}
serverPath, ok := serverConfigMap["path"].(string)
if !ok {
return nil, fmt.Errorf("server %s path is not set", serverType)
}
server := internal.GlobalRegistry.GetServer(serverType)
if server == nil {
return nil, fmt.Errorf("server %s is not registered", serverType)
}
server.ParseConfig(serverConfigMap)
serverInstance, err := server.NewServer()
if err != nil {
return nil, fmt.Errorf("failed to initialize DBServer: %w", err)
}
conf.servers = append(conf.servers, internal.NewSSEServer(serverInstance,
internal.WithRedisClient(redisClient),
internal.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, ssePathSuffix)),
internal.WithMessageEndpoint(serverPath)))
api.LogInfo(fmt.Sprintf("Registered MCP Server: %s", serverType))
}
return conf, nil
}
@@ -75,15 +104,15 @@ func (p *parser) Merge(parent interface{}, child interface{}) interface{} {
childConfig := child.(*config)
newConfig := *parentConfig
if childConfig.echoBody != "" {
newConfig.echoBody = childConfig.echoBody
}
if childConfig.dbClient != nil {
newConfig.dbClient = childConfig.dbClient
}
if childConfig.redisClient != nil {
newConfig.redisClient = childConfig.redisClient
}
if childConfig.ssePathSuffix != "" {
newConfig.ssePathSuffix = childConfig.ssePathSuffix
}
if childConfig.servers != nil {
newConfig.servers = append(newConfig.servers, childConfig.servers...)
}
return &newConfig
}
@@ -98,25 +127,4 @@ func filterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.Strea
}
}
func NewServer(dbClient *gorm.DBClient) *internal.MCPServer {
mcpServer := internal.NewMCPServer(
"mcp-server-envoy-poc",
"1.0.0",
)
// Add query tool
mcpServer.AddTool(
mcp.NewToolWithRawSchema("query", "Run a read-only SQL query in clickhouse database with repository git data", gorm.GetQueryToolSchema()),
gorm.HandleQueryTool(dbClient),
)
api.LogInfo("Added query tool successfully")
// Add favorite files tool
mcpServer.AddTool(
mcp.NewToolWithRawSchema("author_favorite_files", "Favorite files for an author", gorm.GetFavoriteToolSchema()),
gorm.HandleFavoriteTool(dbClient),
)
return mcpServer
}
func main() {}

View File

@@ -17,9 +17,10 @@ type filter struct {
path string
config *config
req *http.Request
sse bool
message bool
req *http.Request
sse bool
message bool
bodyBuffer []byte
}
// Callbacks which are called in request path
@@ -29,38 +30,39 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
parsedURL, _ := url.Parse(fullPath)
f.path = parsedURL.Path
method, _ := header.Get(":method")
api.LogInfo(f.path)
if f.path == f.config.SSEServer.SSEEndpoint {
if method != http.MethodGet {
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
} else {
f.sse = true
body := "SSE connection create"
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
}
api.LogInfo("SSE connection started")
return api.LocalReply
} else if f.path == f.config.SSEServer.MessageEndpoint {
if method != http.MethodPost {
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
}
// Create a new http.Request object
f.req = &http.Request{
Method: method,
URL: parsedURL,
Header: make(http.Header),
}
api.LogInfof("Message request: %v", parsedURL)
// Copy headers from api.RequestHeaderMap to http.Header
header.Range(func(key, value string) bool {
f.req.Header.Add(key, value)
return true
})
f.message = true
if endStream {
return api.Continue
} else {
return api.StopAndBuffer
for _, server := range f.config.servers {
if f.path == server.GetSSEEndpoint() {
if method != http.MethodGet {
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
} else {
f.sse = true
body := "SSE connection create"
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
}
api.LogInfof("%s SSE connection started", server.GetServerName())
return api.LocalReply
} else if f.path == server.GetMessageEndpoint() {
if method != http.MethodPost {
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
}
// Create a new http.Request object
f.req = &http.Request{
Method: method,
URL: parsedURL,
Header: make(http.Header),
}
api.LogDebugf("Message request: %v", parsedURL)
// Copy headers from api.RequestHeaderMap to http.Header
header.Range(func(key, value string) bool {
f.req.Header.Add(key, value)
return true
})
f.message = true
if endStream {
return api.Continue
} else {
return api.StopAndBuffer
}
}
}
if endStream {
@@ -73,18 +75,25 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
// DecodeData might be called multiple times during handling the request body.
// The endStream is true when handling the last piece of the body.
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
api.LogInfo("Message DecodeData")
// support suspending & resuming the filter in a background goroutine
api.LogInfof("DecodeData: {%v}", buffer)
if f.message {
// Create a response recorder to capture the response
recorder := httptest.NewRecorder()
// Call the handleMessage method of SSEServer
f.config.SSEServer.HandleMessage(recorder, f.req, buffer.Bytes())
f.message = false
api.LogInfof("Message DecodeData SendLocalReply %v", recorder)
f.callbacks.DecoderFilterCallbacks().SendLocalReply(recorder.Code, recorder.Body.String(), recorder.Header(), 0, "")
return api.LocalReply
f.bodyBuffer = append(f.bodyBuffer, buffer.Bytes()...)
if endStream {
for _, server := range f.config.servers {
if f.path == server.GetMessageEndpoint() {
// Create a response recorder to capture the response
recorder := httptest.NewRecorder()
// Call the handleMessage method of SSEServer with complete body
server.HandleMessage(recorder, f.req, f.bodyBuffer)
f.message = false
// clear buffer
f.bodyBuffer = nil
f.callbacks.DecoderFilterCallbacks().SendLocalReply(recorder.Code, recorder.Body.String(), recorder.Header(), 0, "")
return api.LocalReply
}
}
}
return api.StopAndBuffer
}
return api.Continue
}
@@ -98,23 +107,21 @@ func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api
header.Set("Connection", "keep-alive")
header.Set("Access-Control-Allow-Origin", "*")
header.Del("Content-Length")
api.LogInfo("SSE connection header set")
return api.Continue
}
return api.Continue
}
// TODO: 连接多种数据库
// TODO: 多种存储类型
// TODO: 数据库多个实例
// EncodeData might be called multiple times during handling the response body.
// The endStream is true when handling the last piece of the body.
func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
if f.sse {
//TODO: buffer cleanup
f.config.SSEServer.HandleSSE(f.callbacks)
f.sse = false
return api.Running
for _, server := range f.config.servers {
if f.sse {
buffer.Reset()
server.HandleSSE(f.callbacks)
f.sse = false
return api.Running
}
}
return api.Continue
}

View File

@@ -1,16 +1,18 @@
module github.com/envoyproxy/envoy/examples/golang-http/simple
module github.com/alibaba/higress/plugins/golang-filter/mcp-server
go 1.23
require (
github.com/envoyproxy/envoy v1.33.1-0.20250224062430-6c11eac01993
google.golang.org/protobuf v1.36.5
github.com/cncf/xds/go v0.0.0-20250121191232-2f005788dc42
github.com/envoyproxy/envoy v1.33.1-0.20250224062430-6c11eac01993
github.com/go-redis/redis/v8 v8.11.5
github.com/google/uuid v1.6.0
github.com/mark3labs/mcp-go v0.12.0
google.golang.org/protobuf v1.36.5
gorm.io/driver/clickhouse v0.6.1
gorm.io/driver/mysql v1.5.7
gorm.io/driver/postgres v1.5.11
gorm.io/driver/sqlite v1.5.7
gorm.io/gorm v1.25.12
)
@@ -24,6 +26,7 @@ require (
github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect
github.com/go-faster/city v1.0.1 // indirect
github.com/go-faster/errors v0.7.1 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/hashicorp/go-version v1.6.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
@@ -32,6 +35,7 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/klauspost/compress v1.17.8 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/paulmach/orb v0.11.1 // indirect
github.com/pierrec/lz4/v4 v4.1.21 // indirect
github.com/pkg/errors v0.9.1 // indirect

View File

@@ -27,6 +27,8 @@ github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AY
github.com/go-faster/errors v0.7.1/go.mod h1:5ySTjWFiphBs07IKuiL69nxdfd5+fzh1u7FPGZP2quo=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
@@ -64,6 +66,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mark3labs/mcp-go v0.12.0 h1:Pue1Tdwqcz77GHq18uzgmLT3wmeDUxXUSAqSwhGLhVo=
github.com/mark3labs/mcp-go v0.12.0/go.mod h1:cjMlBU0cv/cj9kjlgmRhoJ5JREdS7YX83xeIG9Ko/jE=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
@@ -167,7 +171,12 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/clickhouse v0.6.1 h1:t7JMB6sLBXxN8hEO6RdzCbJCwq/jAEVZdwXlmQs1Sd4=
gorm.io/driver/clickhouse v0.6.1/go.mod h1:riMYpJcGZ3sJ/OAZZ1rEP1j/Y0H6cByOAnwz7fo2AyM=
gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
gorm.io/driver/postgres v1.5.11 h1:ubBVAfbKEUld/twyKZ0IYn9rSQh448EdelLYk9Mv314=
gorm.io/driver/postgres v1.5.11/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI=
gorm.io/driver/sqlite v1.5.7 h1:8NvsrhP0ifM7LX9G4zPB97NwovUakUxc+2V2uuf3Z1I=
gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=

View File

@@ -3,26 +3,62 @@ package internal
import (
"context"
"fmt"
"log"
"time"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"github.com/go-redis/redis/v8"
)
type RedisConfig struct {
Address string
Username string
Password string
DB int
}
func ParseRedisConfig(config map[string]any) (*RedisConfig, error) {
c := &RedisConfig{}
// address is required
addr, ok := config["address"].(string)
if !ok {
return nil, fmt.Errorf("address is required and must be a string")
}
c.Address = addr
// username is optional
if username, ok := config["username"].(string); ok {
c.Username = username
}
// password is optional
if password, ok := config["password"].(string); ok {
c.Password = password
}
// db is optional, default to 0
if db, ok := config["db"].(int); ok {
c.DB = db
}
return c, nil
}
// RedisClient is a struct to handle Redis connections and operations
type RedisClient struct {
client *redis.Client
ctx context.Context
stopChan chan struct{}
config *RedisConfig
}
// NewRedisClient creates a new RedisClient instance and establishes a connection to the Redis server
func NewRedisClient(address string, stopChan chan struct{}) (*RedisClient, error) {
func NewRedisClient(config *RedisConfig, stopChan chan struct{}) (*RedisClient, error) {
client := redis.NewClient(&redis.Options{
Addr: address,
Password: "", // no password set
DB: 0, // use default DB
Addr: config.Address,
Username: config.Username,
Password: config.Password,
DB: config.DB,
})
// Ping the Redis server to check the connection
@@ -30,17 +66,71 @@ func NewRedisClient(address string, stopChan chan struct{}) (*RedisClient, error
if err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
log.Printf("Connected to Redis: %s", pong)
api.LogInfof("Connected to Redis: %s", pong)
return &RedisClient{
redisClient := &RedisClient{
client: client,
ctx: context.Background(),
stopChan: stopChan,
}, nil
config: config,
}
// Start keep-alive check
go redisClient.keepAlive()
return redisClient, nil
}
// keepAlive periodically checks Redis connection and attempts to reconnect if needed
func (r *RedisClient) keepAlive() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-r.stopChan:
return
case <-ticker.C:
if err := r.checkConnection(); err != nil {
api.LogErrorf("Redis connection check failed: %v", err)
if err := r.reconnect(); err != nil {
api.LogErrorf("Failed to reconnect to Redis: %v", err)
}
}
}
}
}
// checkConnection verifies if the Redis connection is still alive
func (r *RedisClient) checkConnection() error {
_, err := r.client.Ping(r.ctx).Result()
return err
}
// reconnect attempts to establish a new connection to Redis
func (r *RedisClient) reconnect() error {
// Close the old client
if err := r.client.Close(); err != nil {
api.LogErrorf("Error closing old Redis connection: %v", err)
}
// Create new client
r.client = redis.NewClient(&redis.Options{
Addr: r.config.Address,
Username: r.config.Username,
Password: r.config.Password,
DB: r.config.DB,
})
// Test the new connection
if err := r.checkConnection(); err != nil {
return fmt.Errorf("failed to reconnect to Redis: %w", err)
}
api.LogInfof("Successfully reconnected to Redis")
return nil
}
// TODO: redis keep alive check
// TODO: redis pub sub memory limit
// Publish publishes a message to a Redis channel
func (r *RedisClient) Publish(channel string, message string) error {
err := r.client.Publish(r.ctx, channel, message).Err()
@@ -59,19 +149,31 @@ func (r *RedisClient) Subscribe(channel string, callback func(message string)) e
}
go func() {
defer pubsub.Close()
defer func() {
pubsub.Close()
api.LogInfof("Closed subscription to channel %s", channel)
}()
ch := pubsub.Channel()
for {
select {
case <-r.stopChan:
api.LogDebugf("Stopping subscription to channel %s", channel)
api.LogInfof("Stopping subscription to channel %s", channel)
return
default:
msg, err := pubsub.ReceiveMessage(r.ctx)
if err != nil {
log.Printf("Error receiving message: %v", err)
case msg, ok := <-ch:
if !ok {
api.LogInfof("Redis subscription channel closed for %s", channel)
return
}
callback(msg.Payload)
func() {
defer func() {
if r := recover(); r != nil {
api.LogErrorf("Recovered from panic in callback: %v", r)
}
}()
callback(msg.Payload)
}()
}
}
}()

View File

@@ -0,0 +1,26 @@
package internal
var GlobalRegistry = NewServerRegistry()
type Server interface {
ParseConfig(config map[string]any) error
NewServer() (*MCPServer, error)
}
type ServerRegistry struct {
servers map[string]Server
}
func NewServerRegistry() *ServerRegistry {
return &ServerRegistry{
servers: make(map[string]Server),
}
}
func (r *ServerRegistry) RegisterServer(name string, server Server) {
r.servers[name] = server
}
func (r *ServerRegistry) GetServer(name string) Server {
return r.servers[name]
}

View File

@@ -15,12 +15,24 @@ import (
type SSEServer struct {
server *MCPServer
baseURL string
MessageEndpoint string
SSEEndpoint string
messageEndpoint string
sseEndpoint string
sessions map[string]bool
redisClient *RedisClient // Redis client for pub/sub
}
func (s *SSEServer) GetMessageEndpoint() string {
return s.messageEndpoint
}
func (s *SSEServer) GetSSEEndpoint() string {
return s.sseEndpoint
}
func (s *SSEServer) GetServerName() string {
return s.server.name
}
// Option defines a function type for configuring SSEServer
type Option func(*SSEServer)
@@ -34,14 +46,14 @@ func WithBaseURL(baseURL string) Option {
// WithMessageEndpoint sets the message endpoint path
func WithMessageEndpoint(endpoint string) Option {
return func(s *SSEServer) {
s.MessageEndpoint = endpoint
s.messageEndpoint = endpoint
}
}
// WithSSEEndpoint sets the SSE endpoint path
func WithSSEEndpoint(endpoint string) Option {
return func(s *SSEServer) {
s.SSEEndpoint = endpoint
s.sseEndpoint = endpoint
}
}
@@ -55,8 +67,8 @@ func WithRedisClient(redisClient *RedisClient) Option {
func NewSSEServer(server *MCPServer, opts ...Option) *SSEServer {
s := &SSEServer{
server: server,
SSEEndpoint: "/sse",
MessageEndpoint: "/message",
sseEndpoint: "/sse",
messageEndpoint: "/message",
sessions: make(map[string]bool),
}
@@ -84,7 +96,7 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler) {
messageEndpoint := fmt.Sprintf(
"%s%s?sessionId=%s",
s.baseURL,
s.MessageEndpoint,
s.messageEndpoint,
sessionID,
)
@@ -135,10 +147,10 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
}
sessionID := r.URL.Query().Get("sessionId")
// if sessionID == "" {
// s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId")
// return
// }
if sessionID == "" {
s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId")
return
}
// Set the client context in the server before handling the message
ctx := s.server.WithContext(r.Context(), NotificationContext{
@@ -146,21 +158,13 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
SessionID: sessionID,
})
//TODO sessions
//TODO check session id
// _, ok := s.sessions.Load(sessionID)
// if !ok {
// s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID")
// return
// }
//TODO
// // Parse message as raw JSON
// var rawMessage json.RawMessage
// if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil {
// s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "Parse error")
// return
// }
// Process message through MCPServer
response := s.server.HandleMessage(ctx, body)
@@ -198,15 +202,3 @@ func (s *SSEServer) writeJSONRPCError(
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(response)
}
// // ServeHTTP implements the http.Handler interface.
// func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// switch r.URL.Path {
// case s.sseEndpoint:
// s.handleSSE(w, r)
// case s.messageEndpoint:
// s.handleMessage(w, r)
// default:
// http.NotFound(w, r)
// }
// }

View File

@@ -2,13 +2,12 @@ package gorm
import (
"fmt"
"log"
"os"
"gorm.io/driver/clickhouse"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// DBClient is a struct to handle PostgreSQL connections and operations
@@ -18,24 +17,16 @@ type DBClient struct {
// NewDBClient creates a new DBClient instance and establishes a connection to the PostgreSQL database
func NewDBClient(dsn string, dbType string) (*DBClient, error) {
// Configure GORM logger
newLogger := logger.New(
log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
logger.Config{
SlowThreshold: 0, // Slow SQL threshold
LogLevel: logger.Info, // Log level
IgnoreRecordNotFoundError: false, // Ignore ErrRecordNotFound error for logger
Colorful: true, // Disable color
},
)
var db *gorm.DB
var err error
if dbType == "postgres" {
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: newLogger,
})
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
} else if dbType == "clickhouse" {
db, err = gorm.Open(clickhouse.Open(dsn), &gorm.Config{})
} else if dbType == "mysql" {
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
} else if dbType == "sqlite" {
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
} else {
return nil, fmt.Errorf("unsupported database type")
}

View File

@@ -0,0 +1,60 @@
package gorm
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
"github.com/mark3labs/mcp-go/mcp"
)
func init() {
internal.GlobalRegistry.RegisterServer("database", &DBConfig{})
}
type DBConfig struct {
name string
dbType string
dsn string
}
func (c *DBConfig) ParseConfig(config map[string]any) error {
name, ok := config["name"].(string)
if !ok {
return errors.New("missing servername")
}
c.name = name
dsn, ok := config["dsn"].(string)
if !ok {
return errors.New("missing dsn")
}
c.dsn = dsn
dbType, ok := config["dbType"].(string)
if !ok {
return errors.New("missing database type")
}
c.dbType = dbType
return nil
}
func (c *DBConfig) NewServer() (*internal.MCPServer, error) {
mcpServer := internal.NewMCPServer(
c.name,
"1.0.0",
)
dbClient, err := NewDBClient(c.dsn, c.dbType)
if err != nil {
return nil, fmt.Errorf("failed to initialize DBClient: %w", err)
}
// Add query tool
mcpServer.AddTool(
mcp.NewToolWithRawSchema("query", "Run a read-only SQL query in clickhouse database with repository git data", GetQueryToolSchema()),
HandleQueryTool(dbClient),
)
return mcpServer, nil
}

View File

@@ -5,41 +5,10 @@ import (
"encoding/json"
"fmt"
"github.com/envoyproxy/envoy/examples/golang-http/simple/internal"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
"github.com/mark3labs/mcp-go/mcp"
)
const favoriteFilesTemplate = `
WITH current_files AS (
SELECT path
FROM (
SELECT
old_path AS path,
max(time) AS last_time,
2 AS change_type
FROM git.file_changes
GROUP BY old_path
UNION ALL
SELECT
path,
max(time) AS last_time,
argMax(change_type, time) AS change_type
FROM git.file_changes
GROUP BY path
)
GROUP BY path
HAVING (argMax(change_type, last_time) != 2) AND (NOT match(path, '(^dbms/)|(^libs/)|(^tests/testflows/)|(^programs/server/store/)'))
ORDER BY path ASC
)
SELECT
path,
count() AS c
FROM git.file_changes
WHERE (author = '%s') AND (path IN (current_files))
GROUP BY path
ORDER BY c DESC
LIMIT 10`
// HandleQueryTool handles SQL query execution
func HandleQueryTool(dbClient *DBClient) internal.ToolHandlerFunc {
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
@@ -70,37 +39,6 @@ func HandleQueryTool(dbClient *DBClient) internal.ToolHandlerFunc {
}
}
// HandleFavoriteTool handles author's favorite files query
func HandleFavoriteTool(dbClient *DBClient) internal.ToolHandlerFunc {
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
arguments := request.Params.Arguments
author, ok := arguments["author"].(string)
if !ok {
return nil, fmt.Errorf("invalid author argument")
}
query := fmt.Sprintf(favoriteFilesTemplate, author)
results, err := dbClient.ExecuteSQL(query)
if err != nil {
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
}
jsonData, err := json.Marshal(results)
if err != nil {
return nil, fmt.Errorf("failed to marshal SQL results: %w", err)
}
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{
Type: "text",
Text: string(jsonData),
},
},
}, nil
}
}
// GetQueryToolSchema returns the schema for query tool
func GetQueryToolSchema() json.RawMessage {
return json.RawMessage(`
@@ -115,18 +53,3 @@ func GetQueryToolSchema() json.RawMessage {
}
`)
}
// GetFavoriteToolSchema returns the schema for favorite files tool
func GetFavoriteToolSchema() json.RawMessage {
return json.RawMessage(`
{
"type": "object",
"properties": {
"author": {
"type": "string",
"description": "the author name"
}
}
}
`)
}