mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
feat: add config parse in mcp server (#1944)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.23 AS golang-base
|
||||
FROM golang:1.23-bullseye AS golang-base
|
||||
|
||||
ARG GOPROXY
|
||||
ARG GO_FILTER_NAME
|
||||
|
||||
@@ -1,33 +1,28 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
xds "github.com/cncf/xds/go/xds/type/v3"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/gorm" // 导入gorm包以执行其init函数
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
envoyHttp "github.com/envoyproxy/envoy/contrib/golang/filters/http/source/go/pkg/http"
|
||||
"github.com/envoyproxy/envoy/examples/golang-http/simple/internal"
|
||||
"github.com/envoyproxy/envoy/examples/golang-http/simple/servers/gorm"
|
||||
)
|
||||
|
||||
const Name = "mcp-server"
|
||||
const SCHEME_PATH = "scheme"
|
||||
|
||||
func init() {
|
||||
envoyHttp.RegisterHttpFilterFactoryAndConfigParser(Name, filterFactory, &parser{})
|
||||
}
|
||||
|
||||
type config struct {
|
||||
echoBody string
|
||||
// other fields
|
||||
dbClient *gorm.DBClient
|
||||
redisClient *internal.RedisClient
|
||||
stopChan chan struct{}
|
||||
SSEServer *internal.SSEServer
|
||||
ssePathSuffix string
|
||||
redisClient *internal.RedisClient
|
||||
stopChan chan struct{}
|
||||
servers []*internal.SSEServer
|
||||
}
|
||||
|
||||
type parser struct {
|
||||
@@ -39,34 +34,68 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
||||
if err := any.UnmarshalTo(configStruct); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v := configStruct.Value
|
||||
|
||||
conf := &config{}
|
||||
|
||||
dsn, ok := v.AsMap()["dsn"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("missing dsn")
|
||||
}
|
||||
|
||||
dbType, ok := v.AsMap()["dbType"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("missing database type")
|
||||
}
|
||||
|
||||
dbClient, err := gorm.NewDBClient(dsn, dbType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize DBClient: %w", err)
|
||||
}
|
||||
conf.dbClient = dbClient
|
||||
|
||||
conf.stopChan = make(chan struct{})
|
||||
redisClient, err := internal.NewRedisClient("localhost:6379", conf.stopChan)
|
||||
|
||||
redisConfigMap, ok := v.AsMap()["redis"].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("redis config is not set")
|
||||
}
|
||||
|
||||
redisConfig, err := internal.ParseRedisConfig(redisConfigMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse redis config: %w", err)
|
||||
}
|
||||
|
||||
redisClient, err := internal.NewRedisClient(redisConfig, conf.stopChan)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err)
|
||||
}
|
||||
conf.redisClient = redisClient
|
||||
|
||||
conf.SSEServer = internal.NewSSEServer(NewServer(conf.dbClient), internal.WithRedisClient(redisClient))
|
||||
ssePathSuffix, ok := v.AsMap()["sse_path_suffix"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("sse path suffix is not set")
|
||||
}
|
||||
conf.ssePathSuffix = ssePathSuffix
|
||||
|
||||
serverConfigs, ok := v.AsMap()["servers"].([]interface{})
|
||||
if !ok {
|
||||
api.LogInfo("No servers are configured")
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
for _, serverConfig := range serverConfigs {
|
||||
serverConfigMap, ok := serverConfig.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server config must be an object")
|
||||
}
|
||||
serverType, ok := serverConfigMap["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server type is not set")
|
||||
}
|
||||
serverPath, ok := serverConfigMap["path"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server %s path is not set", serverType)
|
||||
}
|
||||
server := internal.GlobalRegistry.GetServer(serverType)
|
||||
|
||||
if server == nil {
|
||||
return nil, fmt.Errorf("server %s is not registered", serverType)
|
||||
}
|
||||
server.ParseConfig(serverConfigMap)
|
||||
serverInstance, err := server.NewServer()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize DBServer: %w", err)
|
||||
}
|
||||
conf.servers = append(conf.servers, internal.NewSSEServer(serverInstance,
|
||||
internal.WithRedisClient(redisClient),
|
||||
internal.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, ssePathSuffix)),
|
||||
internal.WithMessageEndpoint(serverPath)))
|
||||
api.LogInfo(fmt.Sprintf("Registered MCP Server: %s", serverType))
|
||||
}
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
@@ -75,15 +104,15 @@ func (p *parser) Merge(parent interface{}, child interface{}) interface{} {
|
||||
childConfig := child.(*config)
|
||||
|
||||
newConfig := *parentConfig
|
||||
if childConfig.echoBody != "" {
|
||||
newConfig.echoBody = childConfig.echoBody
|
||||
}
|
||||
if childConfig.dbClient != nil {
|
||||
newConfig.dbClient = childConfig.dbClient
|
||||
}
|
||||
if childConfig.redisClient != nil {
|
||||
newConfig.redisClient = childConfig.redisClient
|
||||
}
|
||||
if childConfig.ssePathSuffix != "" {
|
||||
newConfig.ssePathSuffix = childConfig.ssePathSuffix
|
||||
}
|
||||
if childConfig.servers != nil {
|
||||
newConfig.servers = append(newConfig.servers, childConfig.servers...)
|
||||
}
|
||||
return &newConfig
|
||||
}
|
||||
|
||||
@@ -98,25 +127,4 @@ func filterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.Strea
|
||||
}
|
||||
}
|
||||
|
||||
func NewServer(dbClient *gorm.DBClient) *internal.MCPServer {
|
||||
mcpServer := internal.NewMCPServer(
|
||||
"mcp-server-envoy-poc",
|
||||
"1.0.0",
|
||||
)
|
||||
|
||||
// Add query tool
|
||||
mcpServer.AddTool(
|
||||
mcp.NewToolWithRawSchema("query", "Run a read-only SQL query in clickhouse database with repository git data", gorm.GetQueryToolSchema()),
|
||||
gorm.HandleQueryTool(dbClient),
|
||||
)
|
||||
api.LogInfo("Added query tool successfully")
|
||||
|
||||
// Add favorite files tool
|
||||
mcpServer.AddTool(
|
||||
mcp.NewToolWithRawSchema("author_favorite_files", "Favorite files for an author", gorm.GetFavoriteToolSchema()),
|
||||
gorm.HandleFavoriteTool(dbClient),
|
||||
)
|
||||
return mcpServer
|
||||
}
|
||||
|
||||
func main() {}
|
||||
|
||||
@@ -17,9 +17,10 @@ type filter struct {
|
||||
path string
|
||||
config *config
|
||||
|
||||
req *http.Request
|
||||
sse bool
|
||||
message bool
|
||||
req *http.Request
|
||||
sse bool
|
||||
message bool
|
||||
bodyBuffer []byte
|
||||
}
|
||||
|
||||
// Callbacks which are called in request path
|
||||
@@ -29,38 +30,39 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
|
||||
parsedURL, _ := url.Parse(fullPath)
|
||||
f.path = parsedURL.Path
|
||||
method, _ := header.Get(":method")
|
||||
api.LogInfo(f.path)
|
||||
if f.path == f.config.SSEServer.SSEEndpoint {
|
||||
if method != http.MethodGet {
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||
} else {
|
||||
f.sse = true
|
||||
body := "SSE connection create"
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
|
||||
}
|
||||
api.LogInfo("SSE connection started")
|
||||
return api.LocalReply
|
||||
} else if f.path == f.config.SSEServer.MessageEndpoint {
|
||||
if method != http.MethodPost {
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||
}
|
||||
// Create a new http.Request object
|
||||
f.req = &http.Request{
|
||||
Method: method,
|
||||
URL: parsedURL,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
api.LogInfof("Message request: %v", parsedURL)
|
||||
// Copy headers from api.RequestHeaderMap to http.Header
|
||||
header.Range(func(key, value string) bool {
|
||||
f.req.Header.Add(key, value)
|
||||
return true
|
||||
})
|
||||
f.message = true
|
||||
if endStream {
|
||||
return api.Continue
|
||||
} else {
|
||||
return api.StopAndBuffer
|
||||
for _, server := range f.config.servers {
|
||||
if f.path == server.GetSSEEndpoint() {
|
||||
if method != http.MethodGet {
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||
} else {
|
||||
f.sse = true
|
||||
body := "SSE connection create"
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
|
||||
}
|
||||
api.LogInfof("%s SSE connection started", server.GetServerName())
|
||||
return api.LocalReply
|
||||
} else if f.path == server.GetMessageEndpoint() {
|
||||
if method != http.MethodPost {
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||
}
|
||||
// Create a new http.Request object
|
||||
f.req = &http.Request{
|
||||
Method: method,
|
||||
URL: parsedURL,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
api.LogDebugf("Message request: %v", parsedURL)
|
||||
// Copy headers from api.RequestHeaderMap to http.Header
|
||||
header.Range(func(key, value string) bool {
|
||||
f.req.Header.Add(key, value)
|
||||
return true
|
||||
})
|
||||
f.message = true
|
||||
if endStream {
|
||||
return api.Continue
|
||||
} else {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
}
|
||||
}
|
||||
if endStream {
|
||||
@@ -73,18 +75,25 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
|
||||
// DecodeData might be called multiple times during handling the request body.
|
||||
// The endStream is true when handling the last piece of the body.
|
||||
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||
api.LogInfo("Message DecodeData")
|
||||
// support suspending & resuming the filter in a background goroutine
|
||||
api.LogInfof("DecodeData: {%v}", buffer)
|
||||
if f.message {
|
||||
// Create a response recorder to capture the response
|
||||
recorder := httptest.NewRecorder()
|
||||
// Call the handleMessage method of SSEServer
|
||||
f.config.SSEServer.HandleMessage(recorder, f.req, buffer.Bytes())
|
||||
f.message = false
|
||||
api.LogInfof("Message DecodeData SendLocalReply %v", recorder)
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(recorder.Code, recorder.Body.String(), recorder.Header(), 0, "")
|
||||
return api.LocalReply
|
||||
f.bodyBuffer = append(f.bodyBuffer, buffer.Bytes()...)
|
||||
|
||||
if endStream {
|
||||
for _, server := range f.config.servers {
|
||||
if f.path == server.GetMessageEndpoint() {
|
||||
// Create a response recorder to capture the response
|
||||
recorder := httptest.NewRecorder()
|
||||
// Call the handleMessage method of SSEServer with complete body
|
||||
server.HandleMessage(recorder, f.req, f.bodyBuffer)
|
||||
f.message = false
|
||||
// clear buffer
|
||||
f.bodyBuffer = nil
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(recorder.Code, recorder.Body.String(), recorder.Header(), 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
}
|
||||
}
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
@@ -98,23 +107,21 @@ func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api
|
||||
header.Set("Connection", "keep-alive")
|
||||
header.Set("Access-Control-Allow-Origin", "*")
|
||||
header.Del("Content-Length")
|
||||
api.LogInfo("SSE connection header set")
|
||||
return api.Continue
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
// TODO: 连接多种数据库
|
||||
// TODO: 多种存储类型
|
||||
// TODO: 数据库多个实例
|
||||
// EncodeData might be called multiple times during handling the response body.
|
||||
// The endStream is true when handling the last piece of the body.
|
||||
func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||
if f.sse {
|
||||
//TODO: buffer cleanup
|
||||
f.config.SSEServer.HandleSSE(f.callbacks)
|
||||
f.sse = false
|
||||
return api.Running
|
||||
for _, server := range f.config.servers {
|
||||
if f.sse {
|
||||
buffer.Reset()
|
||||
server.HandleSSE(f.callbacks)
|
||||
f.sse = false
|
||||
return api.Running
|
||||
}
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
module github.com/envoyproxy/envoy/examples/golang-http/simple
|
||||
module github.com/alibaba/higress/plugins/golang-filter/mcp-server
|
||||
|
||||
go 1.23
|
||||
|
||||
require (
|
||||
github.com/envoyproxy/envoy v1.33.1-0.20250224062430-6c11eac01993
|
||||
google.golang.org/protobuf v1.36.5
|
||||
github.com/cncf/xds/go v0.0.0-20250121191232-2f005788dc42
|
||||
github.com/envoyproxy/envoy v1.33.1-0.20250224062430-6c11eac01993
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/mark3labs/mcp-go v0.12.0
|
||||
google.golang.org/protobuf v1.36.5
|
||||
gorm.io/driver/clickhouse v0.6.1
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
gorm.io/driver/sqlite v1.5.7
|
||||
gorm.io/gorm v1.25.12
|
||||
)
|
||||
|
||||
@@ -24,6 +26,7 @@ require (
|
||||
github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect
|
||||
github.com/go-faster/city v1.0.1 // indirect
|
||||
github.com/go-faster/errors v0.7.1 // indirect
|
||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||
github.com/hashicorp/go-version v1.6.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
@@ -32,6 +35,7 @@ require (
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/klauspost/compress v1.17.8 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
github.com/paulmach/orb v0.11.1 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.21 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
|
||||
@@ -27,6 +27,8 @@ github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AY
|
||||
github.com/go-faster/errors v0.7.1/go.mod h1:5ySTjWFiphBs07IKuiL69nxdfd5+fzh1u7FPGZP2quo=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
@@ -64,6 +66,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/mark3labs/mcp-go v0.12.0 h1:Pue1Tdwqcz77GHq18uzgmLT3wmeDUxXUSAqSwhGLhVo=
|
||||
github.com/mark3labs/mcp-go v0.12.0/go.mod h1:cjMlBU0cv/cj9kjlgmRhoJ5JREdS7YX83xeIG9Ko/jE=
|
||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
@@ -167,7 +171,12 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/clickhouse v0.6.1 h1:t7JMB6sLBXxN8hEO6RdzCbJCwq/jAEVZdwXlmQs1Sd4=
|
||||
gorm.io/driver/clickhouse v0.6.1/go.mod h1:riMYpJcGZ3sJ/OAZZ1rEP1j/Y0H6cByOAnwz7fo2AyM=
|
||||
gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
|
||||
gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
|
||||
gorm.io/driver/postgres v1.5.11 h1:ubBVAfbKEUld/twyKZ0IYn9rSQh448EdelLYk9Mv314=
|
||||
gorm.io/driver/postgres v1.5.11/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI=
|
||||
gorm.io/driver/sqlite v1.5.7 h1:8NvsrhP0ifM7LX9G4zPB97NwovUakUxc+2V2uuf3Z1I=
|
||||
gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
|
||||
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||
|
||||
@@ -3,26 +3,62 @@ package internal
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
type RedisConfig struct {
|
||||
Address string
|
||||
Username string
|
||||
Password string
|
||||
DB int
|
||||
}
|
||||
|
||||
func ParseRedisConfig(config map[string]any) (*RedisConfig, error) {
|
||||
c := &RedisConfig{}
|
||||
|
||||
// address is required
|
||||
addr, ok := config["address"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("address is required and must be a string")
|
||||
}
|
||||
c.Address = addr
|
||||
|
||||
// username is optional
|
||||
if username, ok := config["username"].(string); ok {
|
||||
c.Username = username
|
||||
}
|
||||
|
||||
// password is optional
|
||||
if password, ok := config["password"].(string); ok {
|
||||
c.Password = password
|
||||
}
|
||||
|
||||
// db is optional, default to 0
|
||||
if db, ok := config["db"].(int); ok {
|
||||
c.DB = db
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// RedisClient is a struct to handle Redis connections and operations
|
||||
type RedisClient struct {
|
||||
client *redis.Client
|
||||
ctx context.Context
|
||||
stopChan chan struct{}
|
||||
config *RedisConfig
|
||||
}
|
||||
|
||||
// NewRedisClient creates a new RedisClient instance and establishes a connection to the Redis server
|
||||
func NewRedisClient(address string, stopChan chan struct{}) (*RedisClient, error) {
|
||||
func NewRedisClient(config *RedisConfig, stopChan chan struct{}) (*RedisClient, error) {
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: address,
|
||||
Password: "", // no password set
|
||||
DB: 0, // use default DB
|
||||
Addr: config.Address,
|
||||
Username: config.Username,
|
||||
Password: config.Password,
|
||||
DB: config.DB,
|
||||
})
|
||||
|
||||
// Ping the Redis server to check the connection
|
||||
@@ -30,17 +66,71 @@ func NewRedisClient(address string, stopChan chan struct{}) (*RedisClient, error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
log.Printf("Connected to Redis: %s", pong)
|
||||
api.LogInfof("Connected to Redis: %s", pong)
|
||||
|
||||
return &RedisClient{
|
||||
redisClient := &RedisClient{
|
||||
client: client,
|
||||
ctx: context.Background(),
|
||||
stopChan: stopChan,
|
||||
}, nil
|
||||
config: config,
|
||||
}
|
||||
|
||||
// Start keep-alive check
|
||||
go redisClient.keepAlive()
|
||||
|
||||
return redisClient, nil
|
||||
}
|
||||
|
||||
// keepAlive periodically checks Redis connection and attempts to reconnect if needed
|
||||
func (r *RedisClient) keepAlive() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.checkConnection(); err != nil {
|
||||
api.LogErrorf("Redis connection check failed: %v", err)
|
||||
if err := r.reconnect(); err != nil {
|
||||
api.LogErrorf("Failed to reconnect to Redis: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkConnection verifies if the Redis connection is still alive
|
||||
func (r *RedisClient) checkConnection() error {
|
||||
_, err := r.client.Ping(r.ctx).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
// reconnect attempts to establish a new connection to Redis
|
||||
func (r *RedisClient) reconnect() error {
|
||||
// Close the old client
|
||||
if err := r.client.Close(); err != nil {
|
||||
api.LogErrorf("Error closing old Redis connection: %v", err)
|
||||
}
|
||||
|
||||
// Create new client
|
||||
r.client = redis.NewClient(&redis.Options{
|
||||
Addr: r.config.Address,
|
||||
Username: r.config.Username,
|
||||
Password: r.config.Password,
|
||||
DB: r.config.DB,
|
||||
})
|
||||
|
||||
// Test the new connection
|
||||
if err := r.checkConnection(); err != nil {
|
||||
return fmt.Errorf("failed to reconnect to Redis: %w", err)
|
||||
}
|
||||
|
||||
api.LogInfof("Successfully reconnected to Redis")
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: redis keep alive check
|
||||
// TODO: redis pub sub memory limit
|
||||
// Publish publishes a message to a Redis channel
|
||||
func (r *RedisClient) Publish(channel string, message string) error {
|
||||
err := r.client.Publish(r.ctx, channel, message).Err()
|
||||
@@ -59,19 +149,31 @@ func (r *RedisClient) Subscribe(channel string, callback func(message string)) e
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer pubsub.Close()
|
||||
defer func() {
|
||||
pubsub.Close()
|
||||
api.LogInfof("Closed subscription to channel %s", channel)
|
||||
}()
|
||||
|
||||
ch := pubsub.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-r.stopChan:
|
||||
api.LogDebugf("Stopping subscription to channel %s", channel)
|
||||
api.LogInfof("Stopping subscription to channel %s", channel)
|
||||
return
|
||||
default:
|
||||
msg, err := pubsub.ReceiveMessage(r.ctx)
|
||||
if err != nil {
|
||||
log.Printf("Error receiving message: %v", err)
|
||||
case msg, ok := <-ch:
|
||||
if !ok {
|
||||
api.LogInfof("Redis subscription channel closed for %s", channel)
|
||||
return
|
||||
}
|
||||
callback(msg.Payload)
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
api.LogErrorf("Recovered from panic in callback: %v", r)
|
||||
}
|
||||
}()
|
||||
callback(msg.Payload)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
26
plugins/golang-filter/mcp-server/internal/registry.go
Normal file
26
plugins/golang-filter/mcp-server/internal/registry.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package internal
|
||||
|
||||
var GlobalRegistry = NewServerRegistry()
|
||||
|
||||
type Server interface {
|
||||
ParseConfig(config map[string]any) error
|
||||
NewServer() (*MCPServer, error)
|
||||
}
|
||||
|
||||
type ServerRegistry struct {
|
||||
servers map[string]Server
|
||||
}
|
||||
|
||||
func NewServerRegistry() *ServerRegistry {
|
||||
return &ServerRegistry{
|
||||
servers: make(map[string]Server),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ServerRegistry) RegisterServer(name string, server Server) {
|
||||
r.servers[name] = server
|
||||
}
|
||||
|
||||
func (r *ServerRegistry) GetServer(name string) Server {
|
||||
return r.servers[name]
|
||||
}
|
||||
@@ -15,12 +15,24 @@ import (
|
||||
type SSEServer struct {
|
||||
server *MCPServer
|
||||
baseURL string
|
||||
MessageEndpoint string
|
||||
SSEEndpoint string
|
||||
messageEndpoint string
|
||||
sseEndpoint string
|
||||
sessions map[string]bool
|
||||
redisClient *RedisClient // Redis client for pub/sub
|
||||
}
|
||||
|
||||
func (s *SSEServer) GetMessageEndpoint() string {
|
||||
return s.messageEndpoint
|
||||
}
|
||||
|
||||
func (s *SSEServer) GetSSEEndpoint() string {
|
||||
return s.sseEndpoint
|
||||
}
|
||||
|
||||
func (s *SSEServer) GetServerName() string {
|
||||
return s.server.name
|
||||
}
|
||||
|
||||
// Option defines a function type for configuring SSEServer
|
||||
type Option func(*SSEServer)
|
||||
|
||||
@@ -34,14 +46,14 @@ func WithBaseURL(baseURL string) Option {
|
||||
// WithMessageEndpoint sets the message endpoint path
|
||||
func WithMessageEndpoint(endpoint string) Option {
|
||||
return func(s *SSEServer) {
|
||||
s.MessageEndpoint = endpoint
|
||||
s.messageEndpoint = endpoint
|
||||
}
|
||||
}
|
||||
|
||||
// WithSSEEndpoint sets the SSE endpoint path
|
||||
func WithSSEEndpoint(endpoint string) Option {
|
||||
return func(s *SSEServer) {
|
||||
s.SSEEndpoint = endpoint
|
||||
s.sseEndpoint = endpoint
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,8 +67,8 @@ func WithRedisClient(redisClient *RedisClient) Option {
|
||||
func NewSSEServer(server *MCPServer, opts ...Option) *SSEServer {
|
||||
s := &SSEServer{
|
||||
server: server,
|
||||
SSEEndpoint: "/sse",
|
||||
MessageEndpoint: "/message",
|
||||
sseEndpoint: "/sse",
|
||||
messageEndpoint: "/message",
|
||||
sessions: make(map[string]bool),
|
||||
}
|
||||
|
||||
@@ -84,7 +96,7 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler) {
|
||||
messageEndpoint := fmt.Sprintf(
|
||||
"%s%s?sessionId=%s",
|
||||
s.baseURL,
|
||||
s.MessageEndpoint,
|
||||
s.messageEndpoint,
|
||||
sessionID,
|
||||
)
|
||||
|
||||
@@ -135,10 +147,10 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
|
||||
}
|
||||
|
||||
sessionID := r.URL.Query().Get("sessionId")
|
||||
// if sessionID == "" {
|
||||
// s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId")
|
||||
// return
|
||||
// }
|
||||
if sessionID == "" {
|
||||
s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId")
|
||||
return
|
||||
}
|
||||
|
||||
// Set the client context in the server before handling the message
|
||||
ctx := s.server.WithContext(r.Context(), NotificationContext{
|
||||
@@ -146,21 +158,13 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
|
||||
SessionID: sessionID,
|
||||
})
|
||||
|
||||
//TODO: sessions
|
||||
//TODO: check session id
|
||||
// _, ok := s.sessions.Load(sessionID)
|
||||
// if !ok {
|
||||
// s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID")
|
||||
// return
|
||||
// }
|
||||
|
||||
//TODO
|
||||
// // Parse message as raw JSON
|
||||
// var rawMessage json.RawMessage
|
||||
// if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil {
|
||||
// s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "Parse error")
|
||||
// return
|
||||
// }
|
||||
|
||||
// Process message through MCPServer
|
||||
response := s.server.HandleMessage(ctx, body)
|
||||
|
||||
@@ -198,15 +202,3 @@ func (s *SSEServer) writeJSONRPCError(
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// // ServeHTTP implements the http.Handler interface.
|
||||
// func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// switch r.URL.Path {
|
||||
// case s.sseEndpoint:
|
||||
// s.handleSSE(w, r)
|
||||
// case s.messageEndpoint:
|
||||
// s.handleMessage(w, r)
|
||||
// default:
|
||||
// http.NotFound(w, r)
|
||||
// }
|
||||
// }
|
||||
|
||||
@@ -2,13 +2,12 @@ package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"gorm.io/driver/clickhouse"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// DBClient is a struct to handle PostgreSQL connections and operations
|
||||
@@ -18,24 +17,16 @@ type DBClient struct {
|
||||
|
||||
// NewDBClient creates a new DBClient instance and establishes a connection to the PostgreSQL database
|
||||
func NewDBClient(dsn string, dbType string) (*DBClient, error) {
|
||||
// Configure GORM logger
|
||||
newLogger := logger.New(
|
||||
log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
|
||||
logger.Config{
|
||||
SlowThreshold: 0, // Slow SQL threshold
|
||||
LogLevel: logger.Info, // Log level
|
||||
IgnoreRecordNotFoundError: false, // Ignore ErrRecordNotFound error for logger
|
||||
Colorful: true, // Disable color
|
||||
},
|
||||
)
|
||||
var db *gorm.DB
|
||||
var err error
|
||||
if dbType == "postgres" {
|
||||
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: newLogger,
|
||||
})
|
||||
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
} else if dbType == "clickhouse" {
|
||||
db, err = gorm.Open(clickhouse.Open(dsn), &gorm.Config{})
|
||||
} else if dbType == "mysql" {
|
||||
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
||||
} else if dbType == "sqlite" {
|
||||
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
} else {
|
||||
return nil, fmt.Errorf("unsupported database type")
|
||||
}
|
||||
|
||||
60
plugins/golang-filter/mcp-server/servers/gorm/server.go
Normal file
60
plugins/golang-filter/mcp-server/servers/gorm/server.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
func init() {
|
||||
internal.GlobalRegistry.RegisterServer("database", &DBConfig{})
|
||||
}
|
||||
|
||||
type DBConfig struct {
|
||||
name string
|
||||
dbType string
|
||||
dsn string
|
||||
}
|
||||
|
||||
func (c *DBConfig) ParseConfig(config map[string]any) error {
|
||||
name, ok := config["name"].(string)
|
||||
if !ok {
|
||||
return errors.New("missing servername")
|
||||
}
|
||||
c.name = name
|
||||
|
||||
dsn, ok := config["dsn"].(string)
|
||||
if !ok {
|
||||
return errors.New("missing dsn")
|
||||
}
|
||||
c.dsn = dsn
|
||||
|
||||
dbType, ok := config["dbType"].(string)
|
||||
if !ok {
|
||||
return errors.New("missing database type")
|
||||
}
|
||||
c.dbType = dbType
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *DBConfig) NewServer() (*internal.MCPServer, error) {
|
||||
mcpServer := internal.NewMCPServer(
|
||||
c.name,
|
||||
"1.0.0",
|
||||
)
|
||||
|
||||
dbClient, err := NewDBClient(c.dsn, c.dbType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize DBClient: %w", err)
|
||||
}
|
||||
|
||||
// Add query tool
|
||||
mcpServer.AddTool(
|
||||
mcp.NewToolWithRawSchema("query", "Run a read-only SQL query in clickhouse database with repository git data", GetQueryToolSchema()),
|
||||
HandleQueryTool(dbClient),
|
||||
)
|
||||
|
||||
return mcpServer, nil
|
||||
}
|
||||
@@ -5,41 +5,10 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/envoyproxy/envoy/examples/golang-http/simple/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
const favoriteFilesTemplate = `
|
||||
WITH current_files AS (
|
||||
SELECT path
|
||||
FROM (
|
||||
SELECT
|
||||
old_path AS path,
|
||||
max(time) AS last_time,
|
||||
2 AS change_type
|
||||
FROM git.file_changes
|
||||
GROUP BY old_path
|
||||
UNION ALL
|
||||
SELECT
|
||||
path,
|
||||
max(time) AS last_time,
|
||||
argMax(change_type, time) AS change_type
|
||||
FROM git.file_changes
|
||||
GROUP BY path
|
||||
)
|
||||
GROUP BY path
|
||||
HAVING (argMax(change_type, last_time) != 2) AND (NOT match(path, '(^dbms/)|(^libs/)|(^tests/testflows/)|(^programs/server/store/)'))
|
||||
ORDER BY path ASC
|
||||
)
|
||||
SELECT
|
||||
path,
|
||||
count() AS c
|
||||
FROM git.file_changes
|
||||
WHERE (author = '%s') AND (path IN (current_files))
|
||||
GROUP BY path
|
||||
ORDER BY c DESC
|
||||
LIMIT 10`
|
||||
|
||||
// HandleQueryTool handles SQL query execution
|
||||
func HandleQueryTool(dbClient *DBClient) internal.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
@@ -70,37 +39,6 @@ func HandleQueryTool(dbClient *DBClient) internal.ToolHandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// HandleFavoriteTool handles author's favorite files query
|
||||
func HandleFavoriteTool(dbClient *DBClient) internal.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
arguments := request.Params.Arguments
|
||||
author, ok := arguments["author"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid author argument")
|
||||
}
|
||||
query := fmt.Sprintf(favoriteFilesTemplate, author)
|
||||
|
||||
results, err := dbClient.ExecuteSQL(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(results)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal SQL results: %w", err)
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: string(jsonData),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetQueryToolSchema returns the schema for query tool
|
||||
func GetQueryToolSchema() json.RawMessage {
|
||||
return json.RawMessage(`
|
||||
@@ -115,18 +53,3 @@ func GetQueryToolSchema() json.RawMessage {
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
// GetFavoriteToolSchema returns the schema for favorite files tool
|
||||
func GetFavoriteToolSchema() json.RawMessage {
|
||||
return json.RawMessage(`
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"author": {
|
||||
"type": "string",
|
||||
"description": "the author name"
|
||||
}
|
||||
}
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user