feat: support config store and redis configuration optional in mcp server (#2035)

This commit is contained in:
Jingze
2025-04-14 20:52:48 +08:00
committed by GitHub
parent ed925ddf84
commit c7abfb8aff
16 changed files with 837 additions and 119 deletions

View File

@@ -235,8 +235,7 @@ clean-gateway: clean-istio
rm -rf external/proxy rm -rf external/proxy
rm -rf external/go-control-plane rm -rf external/go-control-plane
rm -rf external/package/envoy.tar.gz rm -rf external/package/envoy.tar.gz
rm -rf external/package/mcp-server_amd64.so rm -rf external/package/*.so
rm -rf external/package/mcp-server_arm64.so
clean-env: clean-env:
rm -rf out/ rm -rf out/

View File

@@ -41,6 +41,16 @@ type RedisConfig struct {
DB int `json:"db,omitempty"` DB int `json:"db,omitempty"`
} }
// MCPRatelimitConfig defines the configuration for rate limit
type MCPRatelimitConfig struct {
// The limit of the rate limit
Limit int64 `json:"limit,omitempty"`
// The window of the rate limit
Window int64 `json:"window,omitempty"`
// The white list of the rate limit
WhiteList []string `json:"white_list,omitempty"`
}
// SSEServer defines the configuration for Server-Sent Events (SSE) server // SSEServer defines the configuration for Server-Sent Events (SSE) server
type SSEServer struct { type SSEServer struct {
// The name of the SSE server // The name of the SSE server
@@ -75,13 +85,18 @@ type McpServer struct {
Servers []*SSEServer `json:"servers,omitempty"` Servers []*SSEServer `json:"servers,omitempty"`
// List of match rules for filtering requests // List of match rules for filtering requests
MatchList []*MatchRule `json:"match_list,omitempty"` MatchList []*MatchRule `json:"match_list,omitempty"`
// Flag to control whether user level server is enabled
EnableUserLevelServer bool `json:"enable_user_level_server,omitempty"`
// Rate limit config for MCP server
Ratelimit *MCPRatelimitConfig `json:"rate_limit,omitempty"`
} }
func NewDefaultMcpServer() *McpServer { func NewDefaultMcpServer() *McpServer {
return &McpServer{ return &McpServer{
Enable: false, Enable: false,
Servers: make([]*SSEServer, 0), Servers: make([]*SSEServer, 0),
MatchList: make([]*MatchRule, 0), MatchList: make([]*MatchRule, 0),
EnableUserLevelServer: false,
} }
} }
@@ -94,8 +109,8 @@ func validMcpServer(m *McpServer) error {
return nil return nil
} }
if m.Enable && m.Redis == nil { if m.EnableUserLevelServer && m.Redis == nil {
return errors.New("redis config cannot be empty when mcp server is enabled") return errors.New("redis config cannot be empty when user level server is enabled")
} }
// Validate match rule types // Validate match rule types
@@ -149,9 +164,17 @@ func deepCopyMcpServer(mcp *McpServer) (*McpServer, error) {
DB: mcp.Redis.DB, DB: mcp.Redis.DB,
} }
} }
if mcp.Ratelimit != nil {
newMcp.Ratelimit = &MCPRatelimitConfig{
Limit: mcp.Ratelimit.Limit,
Window: mcp.Ratelimit.Window,
WhiteList: mcp.Ratelimit.WhiteList,
}
}
newMcp.SsePathSuffix = mcp.SsePathSuffix newMcp.SsePathSuffix = mcp.SsePathSuffix
newMcp.EnableUserLevelServer = mcp.EnableUserLevelServer
if len(mcp.Servers) > 0 { if len(mcp.Servers) > 0 {
newMcp.Servers = make([]*SSEServer, len(mcp.Servers)) newMcp.Servers = make([]*SSEServer, len(mcp.Servers))
for i, server := range mcp.Servers { for i, server := range mcp.Servers {
@@ -352,40 +375,59 @@ func (m *McpServerController) constructMcpServerStruct(mcp *McpServer) string {
matchList = fmt.Sprintf("[%s]", strings.Join(matchConfigs, ",")) matchList = fmt.Sprintf("[%s]", strings.Join(matchConfigs, ","))
} }
// 构建 Redis 配置
redisConfig := "null"
if mcp.Redis != nil {
redisConfig = fmt.Sprintf(`{
"address": "%s",
"username": "%s",
"password": "%s",
"db": %d
}`, mcp.Redis.Address, mcp.Redis.Username, mcp.Redis.Password, mcp.Redis.DB)
}
// 构建限流配置
rateLimitConfig := "null"
if mcp.Ratelimit != nil {
whiteList := "[]"
if len(mcp.Ratelimit.WhiteList) > 0 {
whiteList = fmt.Sprintf(`["%s"]`, strings.Join(mcp.Ratelimit.WhiteList, `","`))
}
rateLimitConfig = fmt.Sprintf(`{
"limit": %d,
"window": %d,
"white_list": %s
}`, mcp.Ratelimit.Limit, mcp.Ratelimit.Window, whiteList)
}
// Build complete configuration structure // Build complete configuration structure
structFmt := `{ return fmt.Sprintf(`{
"name": "envoy.filters.http.golang", "name": "envoy.filters.http.golang",
"typed_config": { "typed_config": {
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct", "@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
"type_url": "type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config", "type_url": "type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config",
"value": { "value": {
"library_id": "mcp-server", "library_id": "mcp-session",
"library_path": "/var/lib/istio/envoy/mcp-server.so", "library_path": "/var/lib/istio/envoy/golang-filter.so",
"plugin_name": "mcp-server", "plugin_name": "mcp-session",
"plugin_config": { "plugin_config": {
"@type": "type.googleapis.com/xds.type.v3.TypedStruct", "@type": "type.googleapis.com/xds.type.v3.TypedStruct",
"value": { "value": {
"redis": { "redis": %s,
"address": "%s", "rate_limit": %s,
"username": "%s",
"password": "%s",
"db": %d
},
"sse_path_suffix": "%s", "sse_path_suffix": "%s",
"match_list": %s, "match_list": %s,
"servers": %s "servers": %s,
"enable_user_level_server": %t
} }
} }
} }
} }
}` }`,
redisConfig,
return fmt.Sprintf(structFmt, rateLimitConfig,
mcp.Redis.Address,
mcp.Redis.Username,
mcp.Redis.Password,
mcp.Redis.DB,
mcp.SsePathSuffix, mcp.SsePathSuffix,
matchList, matchList,
servers) servers,
mcp.EnableUserLevelServer)
} }

View File

@@ -45,17 +45,30 @@ func Test_validMcpServer(t *testing.T) {
{ {
name: "enabled but no redis config", name: "enabled but no redis config",
mcp: &McpServer{ mcp: &McpServer{
Enable: true, Enable: true,
Redis: nil, EnableUserLevelServer: false,
MatchList: []*MatchRule{}, Redis: nil,
Servers: []*SSEServer{}, MatchList: []*MatchRule{},
Servers: []*SSEServer{},
}, },
wantErr: errors.New("redis config cannot be empty when mcp server is enabled"), wantErr: nil,
},
{
name: "enabled with user level server but no redis config",
mcp: &McpServer{
Enable: true,
EnableUserLevelServer: true,
Redis: nil,
MatchList: []*MatchRule{},
Servers: []*SSEServer{},
},
wantErr: errors.New("redis config cannot be empty when user level server is enabled"),
}, },
{ {
name: "valid config with redis", name: "valid config with redis",
mcp: &McpServer{ mcp: &McpServer{
Enable: true, Enable: true,
EnableUserLevelServer: true,
Redis: &RedisConfig{ Redis: &RedisConfig{
Address: "localhost:6379", Address: "localhost:6379",
Username: "default", Username: "default",

View File

@@ -36,4 +36,4 @@ RUN if [ "$GOARCH" = "arm64" ]; then \
FROM scratch AS output FROM scratch AS output
ARG GO_FILTER_NAME ARG GO_FILTER_NAME
ARG GOARCH ARG GOARCH
COPY --from=golang-base /${GO_FILTER_NAME}.so ${GO_FILTER_NAME}_${GOARCH}.so COPY --from=golang-base /${GO_FILTER_NAME}.so golang-filter_${GOARCH}.so

View File

@@ -3,9 +3,13 @@ package main
import ( import (
"fmt" "fmt"
"net/http"
_ "net/http/pprof"
xds "github.com/cncf/xds/go/xds/type/v3" xds "github.com/cncf/xds/go/xds/type/v3"
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/handler"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal" "github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/registry/nacos" _ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/registry/nacos"
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/gorm" _ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/gorm"
@@ -13,20 +17,31 @@ import (
envoyHttp "github.com/envoyproxy/envoy/contrib/golang/filters/http/source/go/pkg/http" envoyHttp "github.com/envoyproxy/envoy/contrib/golang/filters/http/source/go/pkg/http"
) )
const Name = "mcp-server" const Name = "mcp-session"
const Version = "1.0.0" const Version = "1.0.0"
const DefaultServerName = "defaultServer" const DefaultServerName = "defaultServer"
const ConfigPathSuffix = "/config"
func init() { func init() {
envoyHttp.RegisterHttpFilterFactoryAndConfigParser(Name, filterFactory, &parser{}) envoyHttp.RegisterHttpFilterFactoryAndConfigParser(Name, filterFactory, &parser{})
go func() {
defer func() {
if r := recover(); r != nil {
api.LogErrorf("PProf server recovered from panic: %v", r)
}
}()
api.LogError(http.ListenAndServe("localhost:6060", nil).Error())
}()
} }
type config struct { type config struct {
ssePathSuffix string ssePathSuffix string
redisClient *internal.RedisClient redisClient *internal.RedisClient
servers []*internal.SSEServer servers []*internal.SSEServer
defaultServer *internal.SSEServer defaultServer *internal.SSEServer
matchList []internal.MatchRule matchList []internal.MatchRule
enableUserLevelServer bool
rateLimitConfig *handler.MCPRatelimitConfig
} }
func (c *config) Destroy() { func (c *config) Destroy() {
@@ -71,22 +86,50 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
} }
} }
redisConfigMap, ok := v.AsMap()["redis"].(map[string]interface{}) // Redis configuration is optional
if redisConfigMap, ok := v.AsMap()["redis"].(map[string]interface{}); ok {
redisConfig, err := internal.ParseRedisConfig(redisConfigMap)
if err != nil {
return nil, fmt.Errorf("failed to parse redis config: %w", err)
}
redisClient, err := internal.NewRedisClient(redisConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err)
}
conf.redisClient = redisClient
api.LogDebug("Redis client initialized")
} else {
api.LogDebug("Redis configuration not provided, running without Redis")
}
enableUserLevelServer, ok := v.AsMap()["enable_user_level_server"].(bool)
if !ok { if !ok {
return nil, fmt.Errorf("redis config is not set") enableUserLevelServer = false
if conf.redisClient == nil {
return nil, fmt.Errorf("redis configuration is not provided, enable_user_level_server is true")
}
} }
conf.enableUserLevelServer = enableUserLevelServer
redisConfig, err := internal.ParseRedisConfig(redisConfigMap) if rateLimit, ok := v.AsMap()["rate_limit"].(map[string]interface{}); ok {
if err != nil { rateLimitConfig := &handler.MCPRatelimitConfig{}
return nil, fmt.Errorf("failed to parse redis config: %w", err) if limit, ok := rateLimit["limit"].(float64); ok {
rateLimitConfig.Limit = int(limit)
}
if window, ok := rateLimit["window"].(float64); ok {
rateLimitConfig.Window = int(window)
}
if whiteList, ok := rateLimit["white_list"].([]interface{}); ok {
for _, item := range whiteList {
if uid, ok := item.(string); ok {
rateLimitConfig.Whitelist = append(rateLimitConfig.Whitelist, uid)
}
}
}
conf.rateLimitConfig = rateLimitConfig
} }
redisClient, err := internal.NewRedisClient(redisConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err)
}
conf.redisClient = redisClient
ssePathSuffix, ok := v.AsMap()["sse_path_suffix"].(string) ssePathSuffix, ok := v.AsMap()["sse_path_suffix"].(string)
if !ok || ssePathSuffix == "" { if !ok || ssePathSuffix == "" {
return nil, fmt.Errorf("sse path suffix is not set or empty") return nil, fmt.Errorf("sse path suffix is not set or empty")
@@ -127,7 +170,7 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
} }
api.LogDebug(fmt.Sprintf("Server config: %+v", serverConfig)) api.LogDebug(fmt.Sprintf("Server config: %+v", serverConfig))
err = server.ParseConfig(serverConfig) err := server.ParseConfig(serverConfig)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse server config: %w", err) return nil, fmt.Errorf("failed to parse server config: %w", err)
} }
@@ -138,7 +181,7 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
} }
conf.servers = append(conf.servers, internal.NewSSEServer(serverInstance, conf.servers = append(conf.servers, internal.NewSSEServer(serverInstance,
internal.WithRedisClient(redisClient), internal.WithRedisClient(conf.redisClient),
internal.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, ssePathSuffix)), internal.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, ssePathSuffix)),
internal.WithMessageEndpoint(serverPath))) internal.WithMessageEndpoint(serverPath)))
api.LogDebug(fmt.Sprintf("Registered MCP Server: %s", serverType)) api.LogDebug(fmt.Sprintf("Registered MCP Server: %s", serverType))
@@ -158,11 +201,14 @@ func (p *parser) Merge(parent interface{}, child interface{}) interface{} {
newConfig.ssePathSuffix = childConfig.ssePathSuffix newConfig.ssePathSuffix = childConfig.ssePathSuffix
} }
if childConfig.servers != nil { if childConfig.servers != nil {
newConfig.servers = append(newConfig.servers, childConfig.servers...) newConfig.servers = childConfig.servers
} }
if childConfig.defaultServer != nil { if childConfig.defaultServer != nil {
newConfig.defaultServer = childConfig.defaultServer newConfig.defaultServer = childConfig.defaultServer
} }
if childConfig.matchList != nil {
newConfig.matchList = childConfig.matchList
}
return &newConfig return &newConfig
} }
@@ -172,9 +218,11 @@ func filterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.Strea
panic("unexpected config type") panic("unexpected config type")
} }
return &filter{ return &filter{
callbacks: callbacks, callbacks: callbacks,
config: conf, config: conf,
stopChan: make(chan struct{}), stopChan: make(chan struct{}),
mcpConfigHandler: handler.NewMCPConfigHandler(conf.redisClient, callbacks),
mcpRatelimitHandler: handler.NewMCPRatelimitHandler(conf.redisClient, callbacks, conf.rateLimitConfig),
} }
} }

View File

@@ -5,12 +5,18 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strconv"
"strings" "strings"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/handler"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal" "github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api" "github.com/envoyproxy/envoy/contrib/golang/common/go/api"
) )
const (
RedisNotEnabledResponseBody = "Redis is not enabled, SSE connection is not supported"
)
// The callbacks in the filter, like `DecodeHeaders`, can be implemented on demand. // The callbacks in the filter, like `DecodeHeaders`, can be implemented on demand.
// Because api.PassThroughStreamFilter provides a default implementation. // Because api.PassThroughStreamFilter provides a default implementation.
type filter struct { type filter struct {
@@ -26,15 +32,20 @@ type filter struct {
message bool message bool
proxyURL *url.URL proxyURL *url.URL
skip bool skip bool
userLevelConfig bool
mcpConfigHandler *handler.MCPConfigHandler
mcpRatelimitHandler *handler.MCPRatelimitHandler
} }
type RequestURL struct { type RequestURL struct {
method string method string
scheme string scheme string
host string host string
path string path string
baseURL string baseURL string
parsedURL *url.URL parsedURL *url.URL
internalIP bool
} }
func NewRequestURL(header api.RequestHeaderMap) *RequestURL { func NewRequestURL(header api.RequestHeaderMap) *RequestURL {
@@ -42,10 +53,11 @@ func NewRequestURL(header api.RequestHeaderMap) *RequestURL {
scheme, _ := header.Get(":scheme") scheme, _ := header.Get(":scheme")
host, _ := header.Get(":authority") host, _ := header.Get(":authority")
path, _ := header.Get(":path") path, _ := header.Get(":path")
internalIP, _ := header.Get("x-envoy-internal")
baseURL := fmt.Sprintf("%s://%s", scheme, host) baseURL := fmt.Sprintf("%s://%s", scheme, host)
parsedURL, _ := url.Parse(path) parsedURL, _ := url.Parse(path)
api.LogDebugf("RequestURL: method=%s, scheme=%s, host=%s, path=%s", method, scheme, host, path) api.LogDebugf("RequestURL: method=%s, scheme=%s, host=%s, path=%s", method, scheme, host, path)
return &RequestURL{method: method, scheme: scheme, host: host, path: path, baseURL: baseURL, parsedURL: parsedURL} return &RequestURL{method: method, scheme: scheme, host: host, path: path, baseURL: baseURL, parsedURL: parsedURL, internalIP: internalIP == "true"}
} }
// Callbacks which are called in request path // Callbacks which are called in request path
@@ -71,11 +83,11 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "") f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
} }
api.LogDebugf("%s SSE connection started", server.GetServerName()) api.LogDebugf("%s SSE connection started", server.GetServerName())
server.SetBaseURL(url.baseURL)
return api.LocalReply return api.LocalReply
} else if f.path == server.GetMessageEndpoint() { } else if f.path == server.GetMessageEndpoint() {
if url.method != http.MethodPost { if url.method != http.MethodPost {
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "") f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
return api.LocalReply
} }
// Create a new http.Request object // Create a new http.Request object
f.req = &http.Request{ f.req = &http.Request{
@@ -97,8 +109,57 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
} }
} }
} }
if strings.HasSuffix(f.path, ConfigPathSuffix) && f.config.enableUserLevelServer {
if !url.internalIP {
api.LogWarnf("Access denied: non-internal IP address %s", url.parsedURL.String())
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
return api.LocalReply
}
if strings.HasSuffix(f.path, ConfigPathSuffix) && url.method == http.MethodGet {
api.LogDebugf("Handling config request: %s", f.path)
f.mcpConfigHandler.HandleConfigRequest(f.path, url.method, []byte{})
return api.LocalReply
}
f.req = &http.Request{
Method: url.method,
URL: url.parsedURL,
}
f.userLevelConfig = true
if endStream {
return api.Continue
} else {
return api.StopAndBuffer
}
}
if !strings.HasSuffix(url.parsedURL.Path, f.config.ssePathSuffix) { if !strings.HasSuffix(url.parsedURL.Path, f.config.ssePathSuffix) {
f.proxyURL = url.parsedURL f.proxyURL = url.parsedURL
if f.config.enableUserLevelServer {
parts := strings.Split(url.parsedURL.Path, "/")
if len(parts) < 3 {
api.LogDebugf("Access denied: missing uid in path %s", url.parsedURL.Path)
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "Access denied: missing uid", nil, 0, "")
return api.LocalReply
}
serverName := parts[1]
uid := parts[2]
// Get encoded config
encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
if err != nil {
api.LogWarnf("Access denied: no valid config found for uid %s", uid)
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
return api.LocalReply
} else if encodedConfig != "" {
header.Set("x-higress-mcpserver-config", encodedConfig)
api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid)
} else {
api.LogDebugf("Empty config found for %s:%s", serverName, uid)
if !f.mcpRatelimitHandler.HandleRatelimit(url.parsedURL.Path, url.method, []byte{}) {
return api.LocalReply
}
}
}
return api.Continue return api.Continue
} }
@@ -112,7 +173,6 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
f.serverName = f.config.defaultServer.GetServerName() f.serverName = f.config.defaultServer.GetServerName()
body := "SSE connection create" body := "SSE connection create"
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "") f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
f.config.defaultServer.SetBaseURL(url.baseURL)
} }
return api.LocalReply return api.LocalReply
} }
@@ -138,6 +198,11 @@ func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.Statu
} }
} }
return api.StopAndBuffer return api.StopAndBuffer
} else if f.userLevelConfig {
// Handle config POST request
api.LogDebugf("Handling config request: %s", f.path)
f.mcpConfigHandler.HandleConfigRequest(f.path, f.req.Method, buffer.Bytes())
return api.LocalReply
} }
return api.Continue return api.Continue
} }
@@ -149,11 +214,15 @@ func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api
return api.Continue return api.Continue
} }
if f.serverName != "" { if f.serverName != "" {
header.Set("Content-Type", "text/event-stream") if f.config.redisClient != nil {
header.Set("Cache-Control", "no-cache") header.Set("Content-Type", "text/event-stream")
header.Set("Connection", "keep-alive") header.Set("Cache-Control", "no-cache")
header.Set("Access-Control-Allow-Origin", "*") header.Set("Connection", "keep-alive")
header.Del("Content-Length") header.Set("Access-Control-Allow-Origin", "*")
header.Del("Content-Length")
} else {
header.Set("Content-Length", strconv.Itoa(len(RedisNotEnabledResponseBody)))
}
return api.Continue return api.Continue
} }
return api.Continue return api.Continue
@@ -168,7 +237,7 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu
if !endStream { if !endStream {
return api.StopAndBuffer return api.StopAndBuffer
} }
if f.proxyURL != nil { if f.proxyURL != nil && f.config.redisClient != nil {
sessionID := f.proxyURL.Query().Get("sessionId") sessionID := f.proxyURL.Query().Get("sessionId")
if sessionID != "" { if sessionID != "" {
channel := internal.GetSSEChannelName(sessionID) channel := internal.GetSSEChannelName(sessionID)
@@ -181,21 +250,26 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu
} }
if f.serverName != "" { if f.serverName != "" {
// handle specific server if f.config.redisClient != nil {
for _, server := range f.config.servers { // handle specific server
if f.serverName == server.GetServerName() { for _, server := range f.config.servers {
if f.serverName == server.GetServerName() {
buffer.Reset()
server.HandleSSE(f.callbacks, f.stopChan)
return api.Running
}
}
// handle default server
if f.serverName == f.config.defaultServer.GetServerName() {
buffer.Reset() buffer.Reset()
server.HandleSSE(f.callbacks, f.stopChan) f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
return api.Running return api.Running
} }
return api.Continue
} else {
buffer.SetString(RedisNotEnabledResponseBody)
return api.Continue
} }
// handle default server
if f.serverName == f.config.defaultServer.GetServerName() {
buffer.Reset()
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
return api.Running
}
return api.Continue
} }
return api.Continue return api.Continue
} }

View File

@@ -136,12 +136,8 @@ github.com/deckarep/golang-set v1.7.1 h1:SCQV0S6gTtp6itiFrTqI+pfmJ4LN85S1YzhDf9r
github.com/deckarep/golang-set v1.7.1/go.mod h1:93vsz/8Wt4joVM7c2AVqh+YRMiUSc14yDtF28KmMOgQ= github.com/deckarep/golang-set v1.7.1/go.mod h1:93vsz/8Wt4joVM7c2AVqh+YRMiUSc14yDtF28KmMOgQ=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/envoyproxy/envoy v1.32.3 h1:eftH199KwYfyBTtm4reeEzsWTqraACEaTQ6efl31v0I=
github.com/envoyproxy/envoy v1.32.3/go.mod h1:KGS+IUehDX1mSIdqodPTWskKOo7bZMLLy3GHxvOKcJk=
github.com/envoyproxy/envoy v1.33.1-0.20250325161043-11ab50a29d99 h1:jih/Ieb7BFgVCStgvY5fXQ3mI9ByOt4wfwUF0d7qmqI= github.com/envoyproxy/envoy v1.33.1-0.20250325161043-11ab50a29d99 h1:jih/Ieb7BFgVCStgvY5fXQ3mI9ByOt4wfwUF0d7qmqI=
github.com/envoyproxy/envoy v1.33.1-0.20250325161043-11ab50a29d99/go.mod h1:x7d0dNbE0xGuDBUkBg19VGCgnPQ+lJ2k8lDzDzKExow= github.com/envoyproxy/envoy v1.33.1-0.20250325161043-11ab50a29d99/go.mod h1:x7d0dNbE0xGuDBUkBg19VGCgnPQ+lJ2k8lDzDzKExow=
github.com/envoyproxy/envoy v1.33.2 h1:k3ChySbVo4HejvbDRxkgRroUnj6TZZpXPJJ0UGaZkXs=
github.com/envoyproxy/envoy v1.33.2/go.mod h1:faFqv1XeNGX/ph6Zto5Culdcpk4Klxp730Q6XhWarV4=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
@@ -285,6 +281,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/luoxiner/nacos-sdk-go/v2 v2.2.9-40 h1:nzRTBplC0riQqQwEHZThw5H4/TH5LgWTQTm6A7t1lpY=
github.com/luoxiner/nacos-sdk-go/v2 v2.2.9-40/go.mod h1:9FKXl6FqOiVmm72i8kADtbeK71egyG9y3uRDBg41tpQ=
github.com/mark3labs/mcp-go v0.12.0 h1:Pue1Tdwqcz77GHq18uzgmLT3wmeDUxXUSAqSwhGLhVo= github.com/mark3labs/mcp-go v0.12.0 h1:Pue1Tdwqcz77GHq18uzgmLT3wmeDUxXUSAqSwhGLhVo=
github.com/mark3labs/mcp-go v0.12.0/go.mod h1:cjMlBU0cv/cj9kjlgmRhoJ5JREdS7YX83xeIG9Ko/jE= github.com/mark3labs/mcp-go v0.12.0/go.mod h1:cjMlBU0cv/cj9kjlgmRhoJ5JREdS7YX83xeIG9Ko/jE=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
@@ -302,8 +300,6 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc=
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/nacos-group/nacos-sdk-go/v2 v2.2.9 h1:etzCMnB9EBeSKfaDIOe8zH4HO/8fycpc6s0AmXCrmAw=
github.com/nacos-group/nacos-sdk-go/v2 v2.2.9/go.mod h1:9FKXl6FqOiVmm72i8kADtbeK71egyG9y3uRDBg41tpQ=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=

View File

@@ -0,0 +1,153 @@
package handler
import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
)
// MCPConfigHandler handles configuration requests for MCP server
type MCPConfigHandler struct {
configStore ConfigStore
callbacks api.FilterCallbackHandler
}
// NewMCPConfigHandler creates a new instance of MCP configuration handler
func NewMCPConfigHandler(redisClient *internal.RedisClient, callbacks api.FilterCallbackHandler) *MCPConfigHandler {
return &MCPConfigHandler{
configStore: NewRedisConfigStore(redisClient),
callbacks: callbacks,
}
}
// HandleConfigRequest processes configuration requests
func (h *MCPConfigHandler) HandleConfigRequest(path string, method string, body []byte) bool {
// Check if it's a configuration request
if !strings.HasSuffix(path, "/config") {
return false
}
// Extract serverName and uid from path
pathParts := strings.Split(strings.TrimSuffix(path, "/config"), "/")
if len(pathParts) < 2 {
h.sendErrorResponse(http.StatusBadRequest, "INVALID_PATH", "Invalid path format")
return true
}
uid := pathParts[len(pathParts)-1]
serverName := pathParts[len(pathParts)-2]
switch method {
case http.MethodGet:
return h.handleGetConfig(serverName, uid)
case http.MethodPost:
return h.handleStoreConfig(serverName, uid, body)
default:
h.sendErrorResponse(http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "Method not allowed")
return true
}
}
// handleGetConfig handles configuration retrieval requests
func (h *MCPConfigHandler) handleGetConfig(serverName string, uid string) bool {
config, err := h.configStore.GetConfig(serverName, uid)
if err != nil {
api.LogErrorf("Failed to get config for server %s, uid %s: %v", serverName, uid, err)
h.sendErrorResponse(http.StatusInternalServerError, "CONFIG_ERROR", fmt.Sprintf("Failed to get configuration: %s", err.Error()))
return true
}
response := struct {
Success bool `json:"success"`
Config map[string]string `json:"config"`
}{
Success: true,
Config: config,
}
responseBytes, _ := json.Marshal(response)
h.callbacks.DecoderFilterCallbacks().SendLocalReply(
http.StatusOK,
string(responseBytes),
nil, 0, "",
)
return true
}
// handleStoreConfig handles configuration storage requests
func (h *MCPConfigHandler) handleStoreConfig(serverName string, uid string, body []byte) bool {
// Parse request body
var requestBody struct {
Config map[string]string `json:"config"`
}
if err := json.Unmarshal(body, &requestBody); err != nil {
api.LogErrorf("Invalid request format for server %s, uid %s: %v", serverName, uid, err)
h.sendErrorResponse(http.StatusBadRequest, "INVALID_REQUEST", fmt.Sprintf("Invalid request format: %s", err.Error()))
return true
}
if requestBody.Config == nil {
h.sendErrorResponse(http.StatusBadRequest, "INVALID_REQUEST", "Config cannot be null")
return true
}
response, err := h.configStore.StoreConfig(serverName, uid, requestBody.Config)
if err != nil {
api.LogErrorf("Failed to store config for server %s, uid %s: %v", serverName, uid, err)
h.sendErrorResponse(http.StatusInternalServerError, "CONFIG_ERROR", fmt.Sprintf("Failed to store configuration: %s", err.Error()))
return true
}
responseBytes, _ := json.Marshal(response)
h.callbacks.DecoderFilterCallbacks().SendLocalReply(
http.StatusOK,
string(responseBytes),
nil, 0, "",
)
return true
}
// sendErrorResponse sends an error response with the specified status, code and message
func (h *MCPConfigHandler) sendErrorResponse(status int, code string, message string) {
response := &ConfigResponse{
Success: false,
Error: &struct {
Code string `json:"code"`
Message string `json:"message"`
}{
Code: code,
Message: message,
},
}
responseBytes, _ := json.Marshal(response)
h.callbacks.DecoderFilterCallbacks().SendLocalReply(
status,
string(responseBytes),
nil, 0, "",
)
}
// GetEncodedConfig retrieves and encodes the configuration for a given server and uid
func (h *MCPConfigHandler) GetEncodedConfig(serverName string, uid string) (string, error) {
conf, err := h.configStore.GetConfig(serverName, uid)
if err != nil {
return "", fmt.Errorf("failed to get config: %w", err)
}
// Check if config exists and is not empty
if len(conf) > 0 {
// Convert config map to JSON string
configBytes, err := json.Marshal(conf)
if err != nil {
return "", fmt.Errorf("failed to marshal config: %w", err)
}
// Encode JSON string to base64
return base64.StdEncoding.EncodeToString(configBytes), nil
}
return "", nil
}

View File

@@ -0,0 +1,105 @@
package handler
import (
"encoding/json"
"fmt"
"time"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
)
const (
configExpiry = 7 * 24 * time.Hour
)
// GetConfigStoreKey returns the Redis channel name for the given session ID
func GetConfigStoreKey(serverName string, uid string) string {
return fmt.Sprintf("mcp-server-config:%s:%s", serverName, uid)
}
// ConfigResponse represents the response structure for configuration operations
type ConfigResponse struct {
Success bool `json:"success"`
Error *struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error,omitempty"`
}
// ConfigStore defines the interface for configuration storage operations
type ConfigStore interface {
// StoreConfig stores user configuration
StoreConfig(serverName string, uid string, config map[string]string) (*ConfigResponse, error)
// GetConfig retrieves user configuration
GetConfig(serverName string, uid string) (map[string]string, error)
}
// RedisConfigStore implements configuration storage using Redis
type RedisConfigStore struct {
redisClient *internal.RedisClient
}
// NewRedisConfigStore creates a new instance of Redis configuration storage
func NewRedisConfigStore(redisClient *internal.RedisClient) ConfigStore {
return &RedisConfigStore{
redisClient: redisClient,
}
}
// StoreConfig stores configuration in Redis
func (s *RedisConfigStore) StoreConfig(serverName string, uid string, config map[string]string) (*ConfigResponse, error) {
key := GetConfigStoreKey(serverName, uid)
// Convert config to JSON
configBytes, err := json.Marshal(config)
if err != nil {
return &ConfigResponse{
Success: false,
Error: &struct {
Code string `json:"code"`
Message string `json:"message"`
}{
Code: "MARSHAL_ERROR",
Message: "Failed to marshal configuration",
},
}, err
}
// Store in Redis with expiry
err = s.redisClient.Set(key, string(configBytes), configExpiry)
if err != nil {
return &ConfigResponse{
Success: false,
Error: &struct {
Code string `json:"code"`
Message string `json:"message"`
}{
Code: "REDIS_ERROR",
Message: "Failed to store configuration in Redis",
},
}, err
}
return &ConfigResponse{
Success: true,
}, nil
}
// GetConfig retrieves configuration from Redis
func (s *RedisConfigStore) GetConfig(serverName string, uid string) (map[string]string, error) {
key := GetConfigStoreKey(serverName, uid)
// Get from Redis
value, err := s.redisClient.Get(key)
if err != nil {
return nil, err
}
// Parse JSON
var config map[string]string
if err := json.Unmarshal([]byte(value), &config); err != nil {
return nil, err
}
return config, nil
}

View File

@@ -0,0 +1,129 @@
package handler
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
)
type MCPRatelimitHandler struct {
redisClient *internal.RedisClient
callbacks api.FilterCallbackHandler
limit int // Maximum requests allowed per window
window int // Time window in seconds
whitelist []string // Whitelist of UIDs that bypass rate limiting
}
// MCPRatelimitConfig is the configuration for the rate limit handler
type MCPRatelimitConfig struct {
Limit int `json:"limit"`
Window int `json:"window"`
Whitelist []string `json:"white_list"` // List of UIDs that bypass rate limiting
}
// NewMCPRatelimitHandler creates a new rate limit handler
func NewMCPRatelimitHandler(redisClient *internal.RedisClient, callbacks api.FilterCallbackHandler, conf *MCPRatelimitConfig) *MCPRatelimitHandler {
if conf == nil {
conf = &MCPRatelimitConfig{
Limit: 100,
Window: int(24 * time.Hour / time.Second), // 24 hours in seconds
Whitelist: []string{},
}
}
return &MCPRatelimitHandler{
redisClient: redisClient,
callbacks: callbacks,
limit: conf.Limit,
window: conf.Window,
whitelist: conf.Whitelist,
}
}
const (
// Lua script for rate limiting
LimitScript = `
local ttl = redis.call('ttl', KEYS[1])
if ttl < 0 then
redis.call('set', KEYS[1], ARGV[1] - 1, 'EX', ARGV[2])
return {ARGV[1], ARGV[1] - 1, ARGV[2]}
end
return {ARGV[1], redis.call('incrby', KEYS[1], -1), ttl}
`
)
type LimitContext struct {
Count int // Current request count
Remaining int // Remaining requests allowed
Reset int // Time until reset in seconds
}
func (h *MCPRatelimitHandler) HandleRatelimit(path string, method string, body []byte) bool {
parts := strings.Split(path, "/")
if len(parts) < 3 {
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
return false
}
serverName := parts[1]
uid := parts[2]
// Check if the UID is in whitelist
for _, whitelistedUID := range h.whitelist {
if whitelistedUID == uid {
return true // Bypass rate limiting for whitelisted UIDs
}
}
// Build rate limit key using serverName, uid, window and limit
limitKey := fmt.Sprintf("mcp-server-limit:%s:%s:%d:%d", serverName, uid, h.window, h.limit)
keys := []string{limitKey}
args := []interface{}{h.limit, h.window}
result, err := h.redisClient.Eval(LimitScript, 1, keys, args)
if err != nil {
api.LogErrorf("Failed to check rate limit: %v", err)
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusInternalServerError, "", nil, 0, "")
return false
}
// Process response
resultArray, ok := result.([]interface{})
if !ok || len(resultArray) != 3 {
api.LogErrorf("Invalid response format: %v", result)
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusInternalServerError, "", nil, 0, "")
return false
}
context := LimitContext{
Count: parseRedisValue(resultArray[0]),
Remaining: parseRedisValue(resultArray[1]),
Reset: parseRedisValue(resultArray[2]),
}
if context.Remaining < 0 {
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusTooManyRequests, "", nil, 0, "")
return false
}
return true
}
// parseRedisValue converts the value from Redis to an int
func parseRedisValue(value interface{}) int {
switch v := value.(type) {
case int:
return v
case int64:
return int(v)
case string:
if i, err := strconv.Atoi(v); err == nil {
return i
}
}
return 0
}

View File

@@ -0,0 +1,76 @@
package internal
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
)
// Crypto handles encryption and decryption operations using AES-GCM
type Crypto struct {
gcm cipher.AEAD
}
func NewCrypto(secret string) (*Crypto, error) {
if secret == "" {
return nil, fmt.Errorf("secret cannot be empty")
}
// Generate a 32-byte key using SHA-256
hash := sha256.Sum256([]byte(secret))
block, err := aes.NewCipher(hash[:])
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %v", err)
}
// Create GCM mode
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM: %v", err)
}
return &Crypto{gcm: gcm}, nil
}
// Encrypt encrypts the plaintext data using AES-GCM
func (c *Crypto) Encrypt(plaintext []byte) (string, error) {
// Generate random nonce
nonce := make([]byte, c.gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("failed to generate nonce: %v", err)
}
// Encrypt and authenticate data
ciphertext := c.gcm.Seal(nonce, nonce, plaintext, nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt decrypts the encrypted string using AES-GCM
func (c *Crypto) Decrypt(encryptedStr string) ([]byte, error) {
// Decode base64
ciphertext, err := base64.StdEncoding.DecodeString(encryptedStr)
if err != nil {
return nil, fmt.Errorf("invalid encrypted data format")
}
// Check if the ciphertext is too short
if len(ciphertext) < c.gcm.NonceSize() {
return nil, fmt.Errorf("invalid encrypted data length")
}
// Extract nonce and ciphertext
nonce := ciphertext[:c.gcm.NonceSize()]
ciphertext = ciphertext[c.gcm.NonceSize():]
// Decrypt and verify data
plaintext, err := c.gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, fmt.Errorf("decryption failed")
}
return plaintext, nil
}

View File

@@ -10,35 +10,42 @@ import (
) )
type RedisConfig struct { type RedisConfig struct {
Address string address string
Username string username string
Password string password string
DB int db int
secret string // Encryption key
} }
func ParseRedisConfig(config map[string]any) (*RedisConfig, error) { // ParseRedisConfig parses Redis configuration from a map
func ParseRedisConfig(config map[string]interface{}) (*RedisConfig, error) {
c := &RedisConfig{} c := &RedisConfig{}
// address is required // address is required
addr, ok := config["address"].(string) if addr, ok := config["address"].(string); ok && addr != "" {
if !ok { c.address = addr
return nil, fmt.Errorf("address is required and must be a string") } else {
return nil, fmt.Errorf("address is required and must be a non-empty string")
} }
c.Address = addr
// username is optional // username is optional
if username, ok := config["username"].(string); ok { if username, ok := config["username"].(string); ok {
c.Username = username c.username = username
} }
// password is optional // password is optional
if password, ok := config["password"].(string); ok { if password, ok := config["password"].(string); ok {
c.Password = password c.password = password
} }
// db is optional, default to 0 // db is optional, default to 0
if db, ok := config["db"].(int); ok { if db, ok := config["db"].(int); ok {
c.DB = db c.db = db
}
// secret is optional
if secret, ok := config["secret"].(string); ok {
c.secret = secret
} }
return c, nil return c, nil
@@ -50,15 +57,16 @@ type RedisClient struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
config *RedisConfig config *RedisConfig
crypto *Crypto
} }
// NewRedisClient creates a new RedisClient instance and establishes a connection to the Redis server // NewRedisClient creates a new RedisClient instance and establishes a connection to the Redis server
func NewRedisClient(config *RedisConfig) (*RedisClient, error) { func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
client := redis.NewClient(&redis.Options{ client := redis.NewClient(&redis.Options{
Addr: config.Address, Addr: config.address,
Username: config.Username, Username: config.username,
Password: config.Password, Password: config.password,
DB: config.DB, DB: config.db,
}) })
// Ping the Redis server to check the connection // Ping the Redis server to check the connection
@@ -69,11 +77,22 @@ func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
api.LogDebugf("Connected to Redis: %s", pong) api.LogDebugf("Connected to Redis: %s", pong)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
var crypto *Crypto
if config.secret != "" {
crypto, err = NewCrypto(config.secret)
if err != nil {
cancel()
return nil, err
}
}
redisClient := &RedisClient{ redisClient := &RedisClient{
client: client, client: client,
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
config: config, config: config,
crypto: crypto,
} }
// Start keep-alive check // Start keep-alive check
@@ -117,10 +136,10 @@ func (r *RedisClient) reconnect() error {
// Create new client // Create new client
r.client = redis.NewClient(&redis.Options{ r.client = redis.NewClient(&redis.Options{
Addr: r.config.Address, Addr: r.config.address,
Username: r.config.Username, Username: r.config.username,
Password: r.config.Password, Password: r.config.password,
DB: r.config.DB, DB: r.config.db,
}) })
// Test the new connection // Test the new connection
@@ -150,6 +169,12 @@ func (r *RedisClient) Subscribe(channel string, stopChan chan struct{}, callback
} }
go func() { go func() {
defer func() {
if r := recover(); r != nil {
api.LogErrorf("Redis Subscribe recovered from panic: %v", r)
}
}()
defer func() { defer func() {
pubsub.Close() pubsub.Close()
api.LogDebugf("Closed subscription to channel %s", channel) api.LogDebugf("Closed subscription to channel %s", channel)
@@ -184,7 +209,19 @@ func (r *RedisClient) Subscribe(channel string, stopChan chan struct{}, callback
// Set sets the value of a key in Redis // Set sets the value of a key in Redis
func (r *RedisClient) Set(key string, value string, expiration time.Duration) error { func (r *RedisClient) Set(key string, value string, expiration time.Duration) error {
err := r.client.Set(r.ctx, key, value, expiration).Err() var finalValue string
if r.crypto != nil {
// Encrypt the data
encryptedValue, err := r.crypto.Encrypt([]byte(value))
if err != nil {
return fmt.Errorf("failed to encrypt value: %w", err)
}
finalValue = encryptedValue
} else {
finalValue = value
}
err := r.client.Set(r.ctx, key, finalValue, expiration).Err()
if err != nil { if err != nil {
return fmt.Errorf("failed to set key: %w", err) return fmt.Errorf("failed to set key: %w", err)
} }
@@ -193,13 +230,23 @@ func (r *RedisClient) Set(key string, value string, expiration time.Duration) er
// Get retrieves the value of a key from Redis // Get retrieves the value of a key from Redis
func (r *RedisClient) Get(key string) (string, error) { func (r *RedisClient) Get(key string) (string, error) {
val, err := r.client.Get(r.ctx, key).Result() value, err := r.client.Get(r.ctx, key).Result()
if err == redis.Nil { if err == redis.Nil {
return "", fmt.Errorf("key does not exist") return "", fmt.Errorf("key does not exist")
} else if err != nil { } else if err != nil {
return "", fmt.Errorf("failed to get key: %w", err) return "", fmt.Errorf("failed to get key: %w", err)
} }
return val, nil
if r.crypto != nil {
// Decrypt the data
decryptedValue, err := r.crypto.Decrypt(value)
if err != nil {
return "", fmt.Errorf("failed to decrypt value: %w", err)
}
return string(decryptedValue), nil
}
return value, nil
} }
// Close closes the Redis client and stops the keepalive goroutine // Close closes the Redis client and stops the keepalive goroutine
@@ -207,3 +254,13 @@ func (r *RedisClient) Close() error {
r.cancel() r.cancel()
return r.client.Close() return r.client.Close()
} }
// Eval executes a Lua script
func (r *RedisClient) Eval(script string, numKeys int, keys []string, args []interface{}) (interface{}, error) {
result, err := r.client.Eval(r.ctx, script, keys, args...).Result()
if err != nil {
return nil, fmt.Errorf("failed to execute Lua script: %w", err)
}
return result, nil
}

View File

@@ -419,6 +419,16 @@ func (s *MCPServer) HandleMessage(
) )
} }
return s.handleToolCall(ctx, baseMessage.ID, request) return s.handleToolCall(ctx, baseMessage.ID, request)
case "":
var response mcp.JSONRPCResponse
if err := json.Unmarshal(message, &response); err != nil {
return createErrorResponse(
baseMessage.ID,
mcp.INVALID_REQUEST,
"Invalid message format",
)
}
return nil
default: default:
return createErrorResponse( return createErrorResponse(
baseMessage.ID, baseMessage.ID,

View File

@@ -28,10 +28,6 @@ type SSEServer struct {
redisClient *RedisClient // Redis client for pub/sub redisClient *RedisClient // Redis client for pub/sub
} }
func (s *SSEServer) SetBaseURL(baseURL string) {
s.baseURL = baseURL
}
func (s *SSEServer) GetMessageEndpoint() string { func (s *SSEServer) GetMessageEndpoint() string {
return s.messageEndpoint return s.messageEndpoint
} }
@@ -148,6 +144,12 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct
// Start health check handler // Start health check handler
go func() { go func() {
defer func() {
if r := recover(); r != nil {
api.LogErrorf("Health check handler recovered from panic: %v", r)
}
}()
ticker := time.NewTicker(5 * time.Second) ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop() defer ticker.Stop()
@@ -158,7 +160,15 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct
case <-ticker.C: case <-ticker.C:
// Send health check message // Send health check message
currentTime := time.Now().Format(time.RFC3339) currentTime := time.Now().Format(time.RFC3339)
healthCheckEvent := fmt.Sprintf(": ping - %s\n\n", currentTime) pingRequest := mcp.JSONRPCRequest{
JSONRPC: mcp.JSONRPC_VERSION,
ID: currentTime,
Request: mcp.Request{
Method: "ping",
},
}
pingData, _ := json.Marshal(pingRequest)
healthCheckEvent := fmt.Sprintf("event: message\ndata: %s\n\n", pingData)
if err := s.redisClient.Publish(channel, healthCheckEvent); err != nil { if err := s.redisClient.Publish(channel, healthCheckEvent); err != nil {
api.LogErrorf("Failed to send health check: %v", err) api.LogErrorf("Failed to send health check: %v", err)
} }
@@ -202,7 +212,7 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
if response != nil { if response != nil {
eventData, _ := json.Marshal(response) eventData, _ := json.Marshal(response)
if sessionID != "" { if sessionID != "" && s.redisClient != nil {
channel := GetSSEChannelName(sessionID) channel := GetSSEChannelName(sessionID)
publishErr := s.redisClient.Publish(channel, fmt.Sprintf("event: message\ndata: %s\n\n", eventData)) publishErr := s.redisClient.Publish(channel, fmt.Sprintf("event: message\ndata: %s\n\n", eventData))

View File

@@ -154,6 +154,12 @@ func (c *NacosConfig) NewServer(serverName string) (*internal.MCPServer, error)
nacosRegistry.RegisterToolChangeEventListener(&listener) nacosRegistry.RegisterToolChangeEventListener(&listener)
go func() { go func() {
defer func() {
if r := recover(); r != nil {
api.LogErrorf("NacosToolsListRefresh recovered from panic: %v", r)
}
}()
for { for {
if nacosRegistry.refreshToolsList() { if nacosRegistry.refreshToolsList() {
resetToolsToMcpServer(mcpServer, nacosRegistry) resetToolsToMcpServer(mcpServer, nacosRegistry)

View File

@@ -29,7 +29,7 @@ if [ ! -n "$INNER_GO_FILTER_NAME" ]; then
name=${file##*/} name=${file##*/}
echo "🚀 Build Go Filter: $name" echo "🚀 Build Go Filter: $name"
GO_FILTER_NAME=${name} GOARCH=${TARGET_ARCH} make build GO_FILTER_NAME=${name} GOARCH=${TARGET_ARCH} make build
cp ${GO_FILTERS_DIR}/${file}/${name}_${TARGET_ARCH}.so ${OUTPUT_PACKAGE_DIR} cp ${GO_FILTERS_DIR}/${file}/golang-filter_${TARGET_ARCH}.so ${OUTPUT_PACKAGE_DIR}
fi fi
done done
else else