fix: Refactor MCP Server into MCP Session and MCP Server (#2120)

This commit is contained in:
Jingze
2025-04-28 13:42:14 +08:00
committed by GitHub
parent e381806ba0
commit c382635e7f
29 changed files with 1025 additions and 517 deletions

View File

@@ -24,7 +24,7 @@ WORKDIR /workspace
COPY . .
WORKDIR /workspace/$GO_FILTER_NAME
WORKDIR /workspace
RUN go mod tidy
RUN if [ "$GOARCH" = "arm64" ]; then \

View File

@@ -1,4 +1,4 @@
GO_FILTER_NAME ?= mcp-server
GO_FILTER_NAME ?= golang-filter
GOPROXY := $(shell go env GOPROXY)
GOARCH ?= amd64
@@ -8,5 +8,5 @@ build:
--build-arg GO_FILTER_NAME=${GO_FILTER_NAME} \
--build-arg GOARCH=${GOARCH} \
-t ${GO_FILTER_NAME} \
--output ./${GO_FILTER_NAME} \
--output . \
.

View File

@@ -28,7 +28,7 @@ http_filters:
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config
library_id: my-go-filter
library_path: "./my-go-filter.so"
library_path: "./go-filter.so"
plugin_name: my-go-filter
plugin_config:
"@type": type.googleapis.com/xds.type.v3.TypedStruct
@@ -43,5 +43,5 @@ http_filters:
使用以下命令可以快速构建 golang filter 插件:
```bash
GO_FILTER_NAME=mcp-server make build
make build
```

View File

@@ -1,4 +1,4 @@
module github.com/alibaba/higress/plugins/golang-filter/mcp-server
module github.com/alibaba/higress/plugins/golang-filter
go 1.23

View File

@@ -0,0 +1,25 @@
package main
import (
"net/http"
mcp_server "github.com/alibaba/higress/plugins/golang-filter/mcp-server"
mcp_session "github.com/alibaba/higress/plugins/golang-filter/mcp-session"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
envoyHttp "github.com/envoyproxy/envoy/contrib/golang/filters/http/source/go/pkg/http"
)
func init() {
envoyHttp.RegisterHttpFilterFactoryAndConfigParser(mcp_session.Name, mcp_session.FilterFactory, &mcp_session.Parser{})
envoyHttp.RegisterHttpFilterFactoryAndConfigParser(mcp_server.Name, mcp_server.FilterFactory, &mcp_server.Parser{})
go func() {
defer func() {
if r := recover(); r != nil {
api.LogErrorf("PProf server recovered from panic: %v", r)
}
}()
api.LogError(http.ListenAndServe("localhost:6060", nil).Error())
}()
}
func main() {}

View File

@@ -1,64 +1,39 @@
package main
package mcp_server
import (
"fmt"
"net/http"
_ "net/http/pprof"
xds "github.com/cncf/xds/go/xds/type/v3"
"google.golang.org/protobuf/types/known/anypb"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/handler"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/registry/nacos"
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/gorm"
mcp_session "github.com/alibaba/higress/plugins/golang-filter/mcp-session"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
xds "github.com/cncf/xds/go/xds/type/v3"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
envoyHttp "github.com/envoyproxy/envoy/contrib/golang/filters/http/source/go/pkg/http"
"google.golang.org/protobuf/types/known/anypb"
)
const Name = "mcp-session"
const Name = "mcp-server"
const Version = "1.0.0"
const DefaultServerName = "defaultServer"
const ConfigPathSuffix = "/config"
func init() {
envoyHttp.RegisterHttpFilterFactoryAndConfigParser(Name, filterFactory, &parser{})
go func() {
defer func() {
if r := recover(); r != nil {
api.LogErrorf("PProf server recovered from panic: %v", r)
}
}()
api.LogError(http.ListenAndServe("localhost:6060", nil).Error())
}()
type SSEServerWrapper struct {
BaseServer *common.SSEServer
DomainList []string
}
type config struct {
ssePathSuffix string
redisClient *internal.RedisClient
servers []*internal.SSEServer
defaultServer *internal.SSEServer
matchList []internal.MatchRule
enableUserLevelServer bool
rateLimitConfig *handler.MCPRatelimitConfig
servers []*SSEServerWrapper
}
func (c *config) Destroy() {
if c.redisClient != nil {
api.LogDebug("Closing Redis client")
c.redisClient.Close()
}
for _, server := range c.servers {
server.Close()
server.BaseServer.Close()
}
}
type parser struct {
type Parser struct {
}
// Parse the filter configuration
func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (interface{}, error) {
func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (interface{}, error) {
configStruct := &xds.TypedStruct{}
if err := any.UnmarshalTo(configStruct); err != nil {
return nil, err
@@ -66,82 +41,9 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
v := configStruct.Value
conf := &config{
matchList: make([]internal.MatchRule, 0),
servers: make([]*internal.SSEServer, 0),
servers: make([]*SSEServerWrapper, 0),
}
// Parse match_list if exists
if matchList, ok := v.AsMap()["match_list"].([]interface{}); ok {
for _, item := range matchList {
if ruleMap, ok := item.(map[string]interface{}); ok {
rule := internal.MatchRule{}
if domain, ok := ruleMap["match_rule_domain"].(string); ok {
rule.MatchRuleDomain = domain
}
if path, ok := ruleMap["match_rule_path"].(string); ok {
rule.MatchRulePath = path
}
if ruleType, ok := ruleMap["match_rule_type"].(string); ok {
rule.MatchRuleType = internal.RuleType(ruleType)
}
conf.matchList = append(conf.matchList, rule)
}
}
}
// Redis configuration is optional
if redisConfigMap, ok := v.AsMap()["redis"].(map[string]interface{}); ok {
redisConfig, err := internal.ParseRedisConfig(redisConfigMap)
if err != nil {
return nil, fmt.Errorf("failed to parse redis config: %w", err)
}
redisClient, err := internal.NewRedisClient(redisConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err)
}
conf.redisClient = redisClient
api.LogDebug("Redis client initialized")
} else {
api.LogDebug("Redis configuration not provided, running without Redis")
}
enableUserLevelServer, ok := v.AsMap()["enable_user_level_server"].(bool)
if !ok {
enableUserLevelServer = false
if conf.redisClient == nil {
return nil, fmt.Errorf("redis configuration is not provided, enable_user_level_server is true")
}
}
conf.enableUserLevelServer = enableUserLevelServer
if rateLimit, ok := v.AsMap()["rate_limit"].(map[string]interface{}); ok {
rateLimitConfig := &handler.MCPRatelimitConfig{}
if limit, ok := rateLimit["limit"].(float64); ok {
rateLimitConfig.Limit = int(limit)
}
if window, ok := rateLimit["window"].(float64); ok {
rateLimitConfig.Window = int(window)
}
if whiteList, ok := rateLimit["white_list"].([]interface{}); ok {
for _, item := range whiteList {
if uid, ok := item.(string); ok {
rateLimitConfig.Whitelist = append(rateLimitConfig.Whitelist, uid)
}
}
}
if errorText, ok := rateLimit["error_text"].(string); ok {
rateLimitConfig.ErrorText = errorText
}
conf.rateLimitConfig = rateLimitConfig
}
ssePathSuffix, ok := v.AsMap()["sse_path_suffix"].(string)
if !ok || ssePathSuffix == "" {
return nil, fmt.Errorf("sse path suffix is not set or empty")
}
conf.ssePathSuffix = ssePathSuffix
serverConfigs, ok := v.AsMap()["servers"].([]interface{})
if !ok {
api.LogDebug("No servers are configured")
@@ -153,19 +55,33 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
if !ok {
return nil, fmt.Errorf("server config must be an object")
}
serverType, ok := serverConfigMap["type"].(string)
if !ok {
return nil, fmt.Errorf("server type is not set")
}
serverPath, ok := serverConfigMap["path"].(string)
if !ok {
return nil, fmt.Errorf("server %s path is not set", serverType)
}
serverDomainList := []string{}
if domainList, ok := serverConfigMap["domain_list"].([]interface{}); ok {
for _, domain := range domainList {
if domainStr, ok := domain.(string); ok {
serverDomainList = append(serverDomainList, domainStr)
}
}
} else {
serverDomainList = []string{"*"}
}
serverName, ok := serverConfigMap["name"].(string)
if !ok {
return nil, fmt.Errorf("server %s name is not set", serverType)
}
server := internal.GlobalRegistry.GetServer(serverType)
server := common.GlobalRegistry.GetServer(serverType)
if server == nil {
return nil, fmt.Errorf("server %s is not registered", serverType)
@@ -186,50 +102,37 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
return nil, fmt.Errorf("failed to initialize DBServer: %w", err)
}
conf.servers = append(conf.servers, internal.NewSSEServer(serverInstance,
internal.WithRedisClient(conf.redisClient),
internal.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, ssePathSuffix)),
internal.WithMessageEndpoint(serverPath)))
conf.servers = append(conf.servers, &SSEServerWrapper{
BaseServer: common.NewSSEServer(serverInstance,
common.WithRedisClient(common.GlobalRedisClient),
common.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, mcp_session.GlobalSSEPathSuffix)),
common.WithMessageEndpoint(serverPath)),
DomainList: serverDomainList,
})
api.LogDebug(fmt.Sprintf("Registered MCP Server: %s", serverType))
}
return conf, nil
}
func (p *parser) Merge(parent interface{}, child interface{}) interface{} {
func (p *Parser) Merge(parent interface{}, child interface{}) interface{} {
parentConfig := parent.(*config)
childConfig := child.(*config)
newConfig := *parentConfig
if childConfig.redisClient != nil {
newConfig.redisClient = childConfig.redisClient
}
if childConfig.ssePathSuffix != "" {
newConfig.ssePathSuffix = childConfig.ssePathSuffix
}
if childConfig.servers != nil {
newConfig.servers = childConfig.servers
}
if childConfig.defaultServer != nil {
newConfig.defaultServer = childConfig.defaultServer
}
if childConfig.matchList != nil {
newConfig.matchList = childConfig.matchList
}
return &newConfig
}
func filterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.StreamFilter {
func FilterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.StreamFilter {
conf, ok := c.(*config)
if !ok {
panic("unexpected config type")
}
return &filter{
callbacks: callbacks,
config: conf,
stopChan: make(chan struct{}),
mcpConfigHandler: handler.NewMCPConfigHandler(conf.redisClient, callbacks),
mcpRatelimitHandler: handler.NewMCPRatelimitHandler(conf.redisClient, callbacks, conf.rateLimitConfig),
config: conf,
callbacks: callbacks,
}
}
func main() {}

View File

@@ -1,104 +1,41 @@
package main
package mcp_server
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/handler"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"github.com/mark3labs/mcp-go/mcp"
)
const (
RedisNotEnabledResponseBody = "Redis is not enabled, SSE connection is not supported"
)
// The callbacks in the filter, like `DecodeHeaders`, can be implemented on demand.
// Because api.PassThroughStreamFilter provides a default implementation.
type filter struct {
api.PassThroughStreamFilter
callbacks api.FilterCallbackHandler
path string
config *config
stopChan chan struct{}
req *http.Request
serverName string
message bool
proxyURL *url.URL
skip bool
userLevelConfig bool
mcpConfigHandler *handler.MCPConfigHandler
ratelimit bool
mcpRatelimitHandler *handler.MCPRatelimitHandler
config *config
req *http.Request
message bool
path string
}
type RequestURL struct {
method string
scheme string
host string
path string
baseURL string
parsedURL *url.URL
internalIP bool
}
func NewRequestURL(header api.RequestHeaderMap) *RequestURL {
method, _ := header.Get(":method")
scheme, _ := header.Get(":scheme")
host, _ := header.Get(":authority")
path, _ := header.Get(":path")
internalIP, _ := header.Get("x-envoy-internal")
baseURL := fmt.Sprintf("%s://%s", scheme, host)
parsedURL, _ := url.Parse(path)
api.LogDebugf("RequestURL: method=%s, scheme=%s, host=%s, path=%s", method, scheme, host, path)
return &RequestURL{method: method, scheme: scheme, host: host, path: path, baseURL: baseURL, parsedURL: parsedURL, internalIP: internalIP == "true"}
}
// Callbacks which are called in request path
// The endStream is true if the request doesn't have body
func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.StatusType {
url := NewRequestURL(header)
f.path = url.parsedURL.Path
// Check if request matches any rule in match_list
if !internal.IsMatch(f.config.matchList, url.host, f.path) {
f.skip = true
api.LogDebugf("Request does not match any rule in match_list: %s", url.parsedURL.String())
return api.Continue
}
url := common.NewRequestURL(header)
f.path = url.ParsedURL.Path
for _, server := range f.config.servers {
if f.path == server.GetSSEEndpoint() {
if url.method != http.MethodGet {
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
} else {
f.serverName = server.GetServerName()
body := "SSE connection create"
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
}
api.LogDebugf("%s SSE connection started", server.GetServerName())
return api.LocalReply
} else if f.path == server.GetMessageEndpoint() {
if url.method != http.MethodPost {
if common.MatchDomainList(url.ParsedURL.Host, server.DomainList) && url.ParsedURL.Path == server.BaseServer.GetMessageEndpoint() {
if url.Method != http.MethodPost {
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
return api.LocalReply
}
// Create a new http.Request object
f.req = &http.Request{
Method: url.method,
URL: url.parsedURL,
Method: url.Method,
URL: url.ParsedURL,
Header: make(http.Header),
}
api.LogDebugf("Message request: %v", url.parsedURL)
api.LogDebugf("Message request: %v", url.ParsedURL)
// Copy headers from api.RequestHeaderMap to http.Header
header.Range(func(key, value string) bool {
f.req.Header.Add(key, value)
@@ -113,209 +50,33 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
}
}
f.req = &http.Request{
Method: url.method,
URL: url.parsedURL,
}
if strings.HasSuffix(f.path, ConfigPathSuffix) && f.config.enableUserLevelServer {
if !url.internalIP {
api.LogWarnf("Access denied: non-internal IP address %s", url.parsedURL.String())
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
return api.LocalReply
}
if strings.HasSuffix(f.path, ConfigPathSuffix) && url.method == http.MethodGet {
api.LogDebugf("Handling config request: %s", f.path)
f.mcpConfigHandler.HandleConfigRequest(f.req, []byte{})
return api.LocalReply
}
f.userLevelConfig = true
if endStream {
return api.Continue
} else {
return api.StopAndBuffer
}
}
if !strings.HasSuffix(url.parsedURL.Path, f.config.ssePathSuffix) {
f.proxyURL = url.parsedURL
if f.config.enableUserLevelServer {
parts := strings.Split(url.parsedURL.Path, "/")
if len(parts) >= 3 {
serverName := parts[1]
uid := parts[2]
// Get encoded config
encodedConfig, _ := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
if encodedConfig != "" {
header.Set("x-higress-mcpserver-config", encodedConfig)
api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid)
}
}
f.ratelimit = true
}
if endStream {
return api.Continue
} else {
return api.StopAndBuffer
}
}
if url.method != http.MethodGet {
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
} else {
f.config.defaultServer = internal.NewSSEServer(internal.NewMCPServer(DefaultServerName, Version),
internal.WithSSEEndpoint(f.config.ssePathSuffix),
internal.WithMessageEndpoint(strings.TrimSuffix(url.parsedURL.Path, f.config.ssePathSuffix)),
internal.WithRedisClient(f.config.redisClient))
f.serverName = f.config.defaultServer.GetServerName()
body := "SSE connection create"
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
}
return api.LocalReply
return api.Continue
}
// DecodeData might be called multiple times during handling the request body.
// The endStream is true when handling the last piece of the body.
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
if f.skip {
return api.Continue
}
if !endStream {
return api.StopAndBuffer
}
if f.message {
for _, server := range f.config.servers {
if f.path == server.GetMessageEndpoint() {
if f.path == server.BaseServer.GetMessageEndpoint() {
// Create a response recorder to capture the response
recorder := httptest.NewRecorder()
// Call the handleMessage method of SSEServer with complete body
httpStatus := server.HandleMessage(recorder, f.req, buffer.Bytes())
httpStatus := server.BaseServer.HandleMessage(recorder, f.req, buffer.Bytes())
f.message = false
f.callbacks.DecoderFilterCallbacks().SendLocalReply(httpStatus, recorder.Body.String(), recorder.Header(), 0, "")
return api.LocalReply
}
}
} else if f.userLevelConfig {
// Handle config POST request
api.LogDebugf("Handling config request: %s", f.path)
f.mcpConfigHandler.HandleConfigRequest(f.req, buffer.Bytes())
return api.LocalReply
} else if f.ratelimit {
if checkJSONRPCMethod(buffer.Bytes(), "tools/list") {
api.LogDebugf("Not a tools call request, skipping ratelimit")
return api.Continue
}
parts := strings.Split(f.req.URL.Path, "/")
if len(parts) < 3 {
api.LogWarnf("Access denied: no valid uid found")
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
return api.LocalReply
}
serverName := parts[1]
uid := parts[2]
encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
if err != nil {
api.LogWarnf("Access denied: no valid config found for uid %s", uid)
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
return api.LocalReply
} else if encodedConfig == "" && checkJSONRPCMethod(buffer.Bytes(), "tools/call") {
api.LogDebugf("Empty config found for %s:%s", serverName, uid)
if !f.mcpRatelimitHandler.HandleRatelimit(f.req, buffer.Bytes()) {
return api.LocalReply
}
}
}
return api.Continue
}
// Callbacks which are called in response path
// The endStream is true if the response doesn't have body
func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api.StatusType {
if f.skip {
return api.Continue
}
if f.serverName != "" {
if f.config.redisClient != nil {
header.Set("Content-Type", "text/event-stream")
header.Set("Cache-Control", "no-cache")
header.Set("Connection", "keep-alive")
header.Set("Access-Control-Allow-Origin", "*")
header.Del("Content-Length")
} else {
header.Set("Content-Length", strconv.Itoa(len(RedisNotEnabledResponseBody)))
}
return api.Continue
}
return api.Continue
}
// EncodeData might be called multiple times during handling the response body.
// The endStream is true when handling the last piece of the body.
func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
if f.skip {
return api.Continue
}
if !endStream {
return api.StopAndBuffer
}
if f.proxyURL != nil && f.config.redisClient != nil {
sessionID := f.proxyURL.Query().Get("sessionId")
if sessionID != "" {
channel := internal.GetSSEChannelName(sessionID)
eventData := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String())
publishErr := f.config.redisClient.Publish(channel, eventData)
if publishErr != nil {
api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr)
}
}
}
if f.serverName != "" {
if f.config.redisClient != nil {
// handle specific server
for _, server := range f.config.servers {
if f.serverName == server.GetServerName() {
buffer.Reset()
server.HandleSSE(f.callbacks, f.stopChan)
return api.Running
}
}
// handle default server
if f.serverName == f.config.defaultServer.GetServerName() {
buffer.Reset()
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
return api.Running
}
return api.Continue
} else {
buffer.SetString(RedisNotEnabledResponseBody)
return api.Continue
}
}
return api.Continue
}
// OnDestroy stops the goroutine
func (f *filter) OnDestroy(reason api.DestroyReason) {
api.LogDebugf("OnDestroy: reason=%v", reason)
if f.serverName != "" && f.stopChan != nil {
select {
case <-f.stopChan:
return
default:
api.LogDebug("Stopping SSE connection")
close(f.stopChan)
}
}
}
// check if the request is a tools/call request
func checkJSONRPCMethod(body []byte, method string) bool {
var request mcp.CallToolRequest
if err := json.Unmarshal(body, &request); err != nil {
api.LogWarnf("Failed to unmarshal request body: %v, not a JSON RPC request", err)
return true
}
return request.Method == method
}

View File

@@ -5,8 +5,8 @@ import (
"fmt"
"time"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/registry"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"github.com/mark3labs/mcp-go/mcp"
"github.com/nacos-group/nacos-sdk-go/v2/clients"
@@ -15,7 +15,7 @@ import (
)
func init() {
internal.GlobalRegistry.RegisterServer("nacos-mcp-registry", &NacosConfig{})
common.GlobalRegistry.RegisterServer("nacos-mcp-registry", &NacosConfig{})
}
type NacosConfig struct {
@@ -28,7 +28,7 @@ type NacosConfig struct {
}
type McpServerToolsChangeListener struct {
mcpServer *internal.MCPServer
mcpServer *common.MCPServer
}
func (l *McpServerToolsChangeListener) OnToolChanged(reg registry.McpServerRegistry) {
@@ -137,8 +137,8 @@ func (c *NacosConfig) ParseConfig(config map[string]any) error {
return nil
}
func (c *NacosConfig) NewServer(serverName string) (*internal.MCPServer, error) {
mcpServer := internal.NewMCPServer(
func (c *NacosConfig) NewServer(serverName string) (*common.MCPServer, error) {
mcpServer := common.NewMCPServer(
serverName,
"1.0.0",
)
@@ -170,11 +170,11 @@ func (c *NacosConfig) NewServer(serverName string) (*internal.MCPServer, error)
return mcpServer, nil
}
func resetToolsToMcpServer(mcpServer *internal.MCPServer, reg registry.McpServerRegistry) {
wrappedTools := []internal.ServerTool{}
func resetToolsToMcpServer(mcpServer *common.MCPServer, reg registry.McpServerRegistry) {
wrappedTools := []common.ServerTool{}
tools := reg.ListToolsDesciption()
for _, tool := range tools {
wrappedTools = append(wrappedTools, internal.ServerTool{
wrappedTools = append(wrappedTools, common.ServerTool{
Tool: mcp.NewToolWithRawSchema(tool.Name, tool.Description, tool.InputSchema),
Handler: registry.HandleRegistryToolsCall(reg),
})

View File

@@ -9,7 +9,7 @@ import (
"net/url"
"strings"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/mark3labs/mcp-go/mcp"
)
@@ -204,7 +204,7 @@ func CommonRemoteCall(reg McpServerRegistry, toolName string, parameters map[str
return remoteHandle.HandleToolCall(ctx, parameters)
}
func HandleRegistryToolsCall(reg McpServerRegistry) internal.ToolHandlerFunc {
func HandleRegistryToolsCall(reg McpServerRegistry) common.ToolHandlerFunc {
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
arguments := request.Params.Arguments
return CommonRemoteCall(reg, request.Params.Name, arguments)

View File

@@ -4,7 +4,7 @@ import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"github.com/mark3labs/mcp-go/mcp"
)
@@ -12,7 +12,7 @@ import (
const Version = "1.0.0"
func init() {
internal.GlobalRegistry.RegisterServer("database", &DBConfig{})
common.GlobalRegistry.RegisterServer("database", &DBConfig{})
}
type DBConfig struct {
@@ -41,11 +41,11 @@ func (c *DBConfig) ParseConfig(config map[string]any) error {
return nil
}
func (c *DBConfig) NewServer(serverName string) (*internal.MCPServer, error) {
mcpServer := internal.NewMCPServer(
func (c *DBConfig) NewServer(serverName string) (*common.MCPServer, error) {
mcpServer := common.NewMCPServer(
serverName,
Version,
internal.WithInstructions(fmt.Sprintf("This is a %s database server", c.dbType)),
common.WithInstructions(fmt.Sprintf("This is a %s database server", c.dbType)),
)
dbClient := NewDBClient(c.dsn, c.dbType, mcpServer.GetDestoryChannel())

View File

@@ -5,12 +5,12 @@ import (
"encoding/json"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/mark3labs/mcp-go/mcp"
)
// HandleQueryTool handles SQL query execution
func HandleQueryTool(dbClient *DBClient) internal.ToolHandlerFunc {
func HandleQueryTool(dbClient *DBClient) common.ToolHandlerFunc {
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
arguments := request.Params.Arguments
message, ok := arguments["sql"].(string)

View File

@@ -1,4 +1,4 @@
package internal
package common
import (
"crypto/aes"

View File

@@ -1,4 +1,4 @@
package internal
package common
import (
"regexp"
@@ -23,6 +23,27 @@ type MatchRule struct {
MatchRuleType RuleType `json:"match_rule_type"` // Type of match rule
}
// ParseMatchList parses the match list from the config
func ParseMatchList(matchListConfig []interface{}) []MatchRule {
matchList := make([]MatchRule, 0)
for _, item := range matchListConfig {
if ruleMap, ok := item.(map[string]interface{}); ok {
rule := MatchRule{}
if domain, ok := ruleMap["match_rule_domain"].(string); ok {
rule.MatchRuleDomain = domain
}
if path, ok := ruleMap["match_rule_path"].(string); ok {
rule.MatchRulePath = path
}
if ruleType, ok := ruleMap["match_rule_type"].(string); ok {
rule.MatchRuleType = RuleType(ruleType)
}
matchList = append(matchList, rule)
}
}
return matchList
}
// convertWildcardToRegex converts wildcard pattern to regex pattern
func convertWildcardToRegex(pattern string) string {
pattern = regexp.QuoteMeta(pattern)
@@ -87,3 +108,13 @@ func IsMatch(rules []MatchRule, host, path string) bool {
}
return false
}
// MatchDomainList checks if the domain matches any of the domains in the list
func MatchDomainList(domain string, domainList []string) bool {
for _, d := range domainList {
if matchDomain(domain, d) {
return true
}
}
return false
}

View File

@@ -1,4 +1,4 @@
package internal
package common
import (
"context"
@@ -9,6 +9,8 @@ import (
"github.com/go-redis/redis/v8"
)
var GlobalRedisClient *RedisClient
type RedisConfig struct {
address string
username string
@@ -249,6 +251,18 @@ func (r *RedisClient) Get(key string) (string, error) {
return value, nil
}
// Expire sets the expiration time for a key
func (r *RedisClient) Expire(key string, expiration time.Duration) error {
ok, err := r.client.Expire(r.ctx, key, expiration).Result()
if err != nil {
return fmt.Errorf("failed to set expiration for key: %w", err)
}
if !ok {
return fmt.Errorf("key does not exist")
}
return nil
}
// Close closes the Redis client and stops the keepalive goroutine
func (r *RedisClient) Close() error {
r.cancel()

View File

@@ -1,4 +1,4 @@
package internal
package common
var GlobalRegistry = NewServerRegistry()

View File

@@ -1,4 +1,4 @@
package internal
package common
import (
"context"
@@ -243,6 +243,7 @@ func (s *MCPServer) HandleMessage(
message json.RawMessage,
) mcp.JSONRPCMessage {
// Add server to context
ctx = context.WithValue(ctx, serverKey{}, s)
var baseMessage struct {

View File

@@ -1,4 +1,4 @@
package internal
package common
import (
"encoding/json"
@@ -210,15 +210,7 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
var status int
// Only send response if there is one (not for notifications)
if response != nil {
eventData, _ := json.Marshal(response)
if sessionID != "" && s.redisClient != nil {
channel := GetSSEChannelName(sessionID)
publishErr := s.redisClient.Publish(channel, fmt.Sprintf("event: message\ndata: %s\n\n", eventData))
if publishErr != nil {
api.LogErrorf("Failed to publish message to Redis: %v", publishErr)
}
w.WriteHeader(http.StatusAccepted)
status = http.StatusAccepted
} else {

View File

@@ -0,0 +1,30 @@
package common
import (
"fmt"
"net/url"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
)
type RequestURL struct {
Method string
Scheme string
Host string
Path string
BaseURL string
ParsedURL *url.URL
InternalIP bool
}
func NewRequestURL(header api.RequestHeaderMap) *RequestURL {
method, _ := header.Get(":method")
scheme, _ := header.Get(":scheme")
host, _ := header.Get(":authority")
path, _ := header.Get(":path")
internalIP, _ := header.Get("x-envoy-internal")
baseURL := fmt.Sprintf("%s://%s", scheme, host)
parsedURL, _ := url.Parse(path)
api.LogDebugf("RequestURL: method=%s, scheme=%s, host=%s, path=%s", method, scheme, host, path)
return &RequestURL{Method: method, Scheme: scheme, Host: host, Path: path, BaseURL: baseURL, ParsedURL: parsedURL, InternalIP: internalIP == "true"}
}

View File

@@ -0,0 +1,143 @@
package mcp_session
import (
"fmt"
_ "net/http/pprof"
xds "github.com/cncf/xds/go/xds/type/v3"
"google.golang.org/protobuf/types/known/anypb"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/handler"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
)
const Name = "mcp-session"
const Version = "1.0.0"
const ConfigPathSuffix = "/config"
const DefaultServerName = "higress-mcp-server"
var GlobalSSEPathSuffix = "/sse"
type config struct {
matchList []common.MatchRule
enableUserLevelServer bool
rateLimitConfig *handler.MCPRatelimitConfig
defaultServer *common.SSEServer
}
func (c *config) Destroy() {
if common.GlobalRedisClient != nil {
api.LogDebug("Closing Redis client")
common.GlobalRedisClient.Close()
}
}
type Parser struct {
}
// Parse the filter configuration
func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (interface{}, error) {
configStruct := &xds.TypedStruct{}
if err := any.UnmarshalTo(configStruct); err != nil {
return nil, err
}
v := configStruct.Value
conf := &config{
matchList: make([]common.MatchRule, 0),
}
// Parse match_list if exists
if matchList, ok := v.AsMap()["match_list"].([]interface{}); ok {
conf.matchList = common.ParseMatchList(matchList)
}
// Redis configuration is optional
if redisConfigMap, ok := v.AsMap()["redis"].(map[string]interface{}); ok {
redisConfig, err := common.ParseRedisConfig(redisConfigMap)
if err != nil {
return nil, fmt.Errorf("failed to parse redis config: %w", err)
}
redisClient, err := common.NewRedisClient(redisConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err)
}
common.GlobalRedisClient = redisClient
api.LogDebug("Redis client initialized")
} else {
api.LogDebug("Redis configuration not provided, running without Redis")
}
enableUserLevelServer, ok := v.AsMap()["enable_user_level_server"].(bool)
if !ok {
enableUserLevelServer = false
if common.GlobalRedisClient == nil {
return nil, fmt.Errorf("redis configuration is not provided, enable_user_level_server is true")
}
}
conf.enableUserLevelServer = enableUserLevelServer
if rateLimit, ok := v.AsMap()["rate_limit"].(map[string]interface{}); ok {
rateLimitConfig := &handler.MCPRatelimitConfig{}
if limit, ok := rateLimit["limit"].(float64); ok {
rateLimitConfig.Limit = int(limit)
}
if window, ok := rateLimit["window"].(float64); ok {
rateLimitConfig.Window = int(window)
}
if whiteList, ok := rateLimit["white_list"].([]interface{}); ok {
for _, item := range whiteList {
if uid, ok := item.(string); ok {
rateLimitConfig.Whitelist = append(rateLimitConfig.Whitelist, uid)
}
}
}
if errorText, ok := rateLimit["error_text"].(string); ok {
rateLimitConfig.ErrorText = errorText
}
conf.rateLimitConfig = rateLimitConfig
}
ssePathSuffix, ok := v.AsMap()["sse_path_suffix"].(string)
if !ok || ssePathSuffix == "" {
return nil, fmt.Errorf("sse path suffix is not set or empty")
}
GlobalSSEPathSuffix = ssePathSuffix
return conf, nil
}
func (p *Parser) Merge(parent interface{}, child interface{}) interface{} {
parentConfig := parent.(*config)
childConfig := child.(*config)
newConfig := *parentConfig
if childConfig.matchList != nil {
newConfig.matchList = childConfig.matchList
}
newConfig.enableUserLevelServer = childConfig.enableUserLevelServer
if childConfig.rateLimitConfig != nil {
newConfig.rateLimitConfig = childConfig.rateLimitConfig
}
if childConfig.defaultServer != nil {
newConfig.defaultServer = childConfig.defaultServer
}
return &newConfig
}
func FilterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.StreamFilter {
conf, ok := c.(*config)
if !ok {
panic("unexpected config type")
}
return &filter{
callbacks: callbacks,
config: conf,
stopChan: make(chan struct{}),
mcpConfigHandler: handler.NewMCPConfigHandler(common.GlobalRedisClient, callbacks),
mcpRatelimitHandler: handler.NewMCPRatelimitHandler(common.GlobalRedisClient, callbacks, conf.rateLimitConfig),
}
}

View File

@@ -0,0 +1,237 @@
package mcp_session
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/handler"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"github.com/mark3labs/mcp-go/mcp"
)
const (
RedisNotEnabledResponseBody = "Redis is not enabled, SSE connection is not supported"
)
// The callbacks in the filter, like `DecodeHeaders`, can be implemented on demand.
// Because api.PassThroughStreamFilter provides a default implementation.
type filter struct {
api.PassThroughStreamFilter
callbacks api.FilterCallbackHandler
path string
config *config
stopChan chan struct{}
req *http.Request
serverName string
proxyURL *url.URL
skip bool
userLevelConfig bool
mcpConfigHandler *handler.MCPConfigHandler
ratelimit bool
mcpRatelimitHandler *handler.MCPRatelimitHandler
}
// Callbacks which are called in request path
// The endStream is true if the request doesn't have body
func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.StatusType {
url := common.NewRequestURL(header)
f.path = url.ParsedURL.Path
// Check if request matches any rule in match_list
if !common.IsMatch(f.config.matchList, url.Host, f.path) {
f.skip = true
api.LogDebugf("Request does not match any rule in match_list: %s", url.ParsedURL.String())
return api.Continue
}
f.req = &http.Request{
Method: url.Method,
URL: url.ParsedURL,
}
if strings.HasSuffix(f.path, ConfigPathSuffix) && f.config.enableUserLevelServer {
if !url.InternalIP {
api.LogWarnf("Access denied: non-Internal IP address %s", url.ParsedURL.String())
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
return api.LocalReply
}
if strings.HasSuffix(f.path, ConfigPathSuffix) && url.Method == http.MethodGet {
api.LogDebugf("Handling config request: %s", f.path)
f.mcpConfigHandler.HandleConfigRequest(f.req, []byte{})
return api.LocalReply
}
f.userLevelConfig = true
if endStream {
return api.Continue
} else {
return api.StopAndBuffer
}
}
if !strings.HasSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix) {
f.proxyURL = url.ParsedURL
if f.config.enableUserLevelServer {
parts := strings.Split(url.ParsedURL.Path, "/")
if len(parts) >= 3 {
serverName := parts[1]
uid := parts[2]
// Get encoded config
encodedConfig, _ := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
if encodedConfig != "" {
header.Set("x-higress-mcpserver-config", encodedConfig)
api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid)
}
}
f.ratelimit = true
}
if endStream {
return api.Continue
} else {
return api.StopAndBuffer
}
}
if url.Method != http.MethodGet {
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
} else {
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
common.WithSSEEndpoint(GlobalSSEPathSuffix),
common.WithMessageEndpoint(strings.TrimSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix)),
common.WithRedisClient(common.GlobalRedisClient))
f.serverName = f.config.defaultServer.GetServerName()
body := "SSE connection create"
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
}
return api.LocalReply
}
// DecodeData might be called multiple times during handling the request body.
// The endStream is true when handling the last piece of the body.
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
if f.skip {
return api.Continue
}
if !endStream {
return api.StopAndBuffer
}
if f.userLevelConfig {
// Handle config POST request
api.LogDebugf("Handling config request: %s", f.path)
f.mcpConfigHandler.HandleConfigRequest(f.req, buffer.Bytes())
return api.LocalReply
} else if f.ratelimit {
if checkJSONRPCMethod(buffer.Bytes(), "tools/list") {
api.LogDebugf("Not a tools call request, skipping ratelimit")
return api.Continue
}
parts := strings.Split(f.req.URL.Path, "/")
if len(parts) < 3 {
api.LogWarnf("Access denied: no valid uid found")
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
return api.LocalReply
}
serverName := parts[1]
uid := parts[2]
encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
if err != nil {
api.LogWarnf("Access denied: no valid config found for uid %s", uid)
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
return api.LocalReply
} else if encodedConfig == "" && checkJSONRPCMethod(buffer.Bytes(), "tools/call") {
api.LogDebugf("Empty config found for %s:%s", serverName, uid)
if !f.mcpRatelimitHandler.HandleRatelimit(f.req, buffer.Bytes()) {
return api.LocalReply
}
}
}
return api.Continue
}
// Callbacks which are called in response path
// The endStream is true if the response doesn't have body
func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api.StatusType {
if f.skip {
return api.Continue
}
if f.serverName != "" {
if common.GlobalRedisClient != nil {
header.Set("Content-Type", "text/event-stream")
header.Set("Cache-Control", "no-cache")
header.Set("Connection", "keep-alive")
header.Set("Access-Control-Allow-Origin", "*")
header.Del("Content-Length")
} else {
header.Set("Content-Length", strconv.Itoa(len(RedisNotEnabledResponseBody)))
}
return api.Continue
}
return api.Continue
}
// EncodeData might be called multiple times during handling the response body.
// The endStream is true when handling the last piece of the body.
func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
if f.skip {
return api.Continue
}
if !endStream {
return api.StopAndBuffer
}
if f.proxyURL != nil && common.GlobalRedisClient != nil {
sessionID := f.proxyURL.Query().Get("sessionId")
if sessionID != "" {
channel := common.GetSSEChannelName(sessionID)
eventData := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String())
publishErr := common.GlobalRedisClient.Publish(channel, eventData)
if publishErr != nil {
api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr)
}
}
}
if f.serverName != "" {
if common.GlobalRedisClient != nil {
// handle default server
buffer.Reset()
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
return api.Running
} else {
buffer.SetString(RedisNotEnabledResponseBody)
return api.Continue
}
}
return api.Continue
}
// OnDestroy stops the goroutine
func (f *filter) OnDestroy(reason api.DestroyReason) {
api.LogDebugf("OnDestroy: reason=%v", reason)
if f.serverName != "" && f.stopChan != nil {
select {
case <-f.stopChan:
return
default:
api.LogDebug("Stopping SSE connection")
close(f.stopChan)
}
}
}
// check if the request is a tools/call request
func checkJSONRPCMethod(body []byte, method string) bool {
var request mcp.CallToolRequest
if err := json.Unmarshal(body, &request); err != nil {
api.LogWarnf("Failed to unmarshal request body: %v, not a JSON RPC request", err)
return true
}
return request.Method == method
}

View File

@@ -7,7 +7,7 @@ import (
"net/http"
"strings"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
)
@@ -18,7 +18,7 @@ type MCPConfigHandler struct {
}
// NewMCPConfigHandler creates a new instance of MCP configuration handler
func NewMCPConfigHandler(redisClient *internal.RedisClient, callbacks api.FilterCallbackHandler) *MCPConfigHandler {
func NewMCPConfigHandler(redisClient *common.RedisClient, callbacks api.FilterCallbackHandler) *MCPConfigHandler {
return &MCPConfigHandler{
configStore: NewRedisConfigStore(redisClient),
callbacks: callbacks,

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"time"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
)
const (
@@ -36,11 +36,11 @@ type ConfigStore interface {
// RedisConfigStore implements configuration storage using Redis
type RedisConfigStore struct {
redisClient *internal.RedisClient
redisClient *common.RedisClient
}
// NewRedisConfigStore creates a new instance of Redis configuration storage
func NewRedisConfigStore(redisClient *internal.RedisClient) ConfigStore {
func NewRedisConfigStore(redisClient *common.RedisClient) ConfigStore {
return &RedisConfigStore{
redisClient: redisClient,
}
@@ -101,5 +101,11 @@ func (s *RedisConfigStore) GetConfig(serverName string, uid string) (map[string]
return nil, err
}
// Refresh TTL
if err := s.redisClient.Expire(key, configExpiry); err != nil {
// Log error but don't fail the request
fmt.Printf("Failed to refresh TTL for key %s: %v\n", key, err)
}
return config, nil
}

View File

@@ -8,13 +8,13 @@ import (
"strings"
"time"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"github.com/mark3labs/mcp-go/mcp"
)
type MCPRatelimitHandler struct {
redisClient *internal.RedisClient
redisClient *common.RedisClient
callbacks api.FilterCallbackHandler
limit int // Maximum requests allowed per window
window int // Time window in seconds
@@ -31,7 +31,7 @@ type MCPRatelimitConfig struct {
}
// NewMCPRatelimitHandler creates a new rate limit handler
func NewMCPRatelimitHandler(redisClient *internal.RedisClient, callbacks api.FilterCallbackHandler, conf *MCPRatelimitConfig) *MCPRatelimitHandler {
func NewMCPRatelimitHandler(redisClient *common.RedisClient, callbacks api.FilterCallbackHandler, conf *MCPRatelimitConfig) *MCPRatelimitHandler {
if conf == nil {
conf = &MCPRatelimitConfig{
Limit: 100,