mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
fix: Refactor MCP Server into MCP Session and MCP Server (#2120)
This commit is contained in:
@@ -24,7 +24,7 @@ WORKDIR /workspace
|
||||
|
||||
COPY . .
|
||||
|
||||
WORKDIR /workspace/$GO_FILTER_NAME
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN go mod tidy
|
||||
RUN if [ "$GOARCH" = "arm64" ]; then \
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
GO_FILTER_NAME ?= mcp-server
|
||||
GO_FILTER_NAME ?= golang-filter
|
||||
GOPROXY := $(shell go env GOPROXY)
|
||||
GOARCH ?= amd64
|
||||
|
||||
@@ -8,5 +8,5 @@ build:
|
||||
--build-arg GO_FILTER_NAME=${GO_FILTER_NAME} \
|
||||
--build-arg GOARCH=${GOARCH} \
|
||||
-t ${GO_FILTER_NAME} \
|
||||
--output ./${GO_FILTER_NAME} \
|
||||
--output . \
|
||||
.
|
||||
@@ -28,7 +28,7 @@ http_filters:
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config
|
||||
library_id: my-go-filter
|
||||
library_path: "./my-go-filter.so"
|
||||
library_path: "./go-filter.so"
|
||||
plugin_name: my-go-filter
|
||||
plugin_config:
|
||||
"@type": type.googleapis.com/xds.type.v3.TypedStruct
|
||||
@@ -43,5 +43,5 @@ http_filters:
|
||||
使用以下命令可以快速构建 golang filter 插件:
|
||||
|
||||
```bash
|
||||
GO_FILTER_NAME=mcp-server make build
|
||||
make build
|
||||
```
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
module github.com/alibaba/higress/plugins/golang-filter/mcp-server
|
||||
module github.com/alibaba/higress/plugins/golang-filter
|
||||
|
||||
go 1.23
|
||||
|
||||
25
plugins/golang-filter/main.go
Normal file
25
plugins/golang-filter/main.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
mcp_server "github.com/alibaba/higress/plugins/golang-filter/mcp-server"
|
||||
mcp_session "github.com/alibaba/higress/plugins/golang-filter/mcp-session"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
envoyHttp "github.com/envoyproxy/envoy/contrib/golang/filters/http/source/go/pkg/http"
|
||||
)
|
||||
|
||||
func init() {
|
||||
envoyHttp.RegisterHttpFilterFactoryAndConfigParser(mcp_session.Name, mcp_session.FilterFactory, &mcp_session.Parser{})
|
||||
envoyHttp.RegisterHttpFilterFactoryAndConfigParser(mcp_server.Name, mcp_server.FilterFactory, &mcp_server.Parser{})
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
api.LogErrorf("PProf server recovered from panic: %v", r)
|
||||
}
|
||||
}()
|
||||
api.LogError(http.ListenAndServe("localhost:6060", nil).Error())
|
||||
}()
|
||||
}
|
||||
|
||||
func main() {}
|
||||
@@ -1,64 +1,39 @@
|
||||
package main
|
||||
package mcp_server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
|
||||
xds "github.com/cncf/xds/go/xds/type/v3"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/handler"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/registry/nacos"
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/gorm"
|
||||
mcp_session "github.com/alibaba/higress/plugins/golang-filter/mcp-session"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
xds "github.com/cncf/xds/go/xds/type/v3"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
envoyHttp "github.com/envoyproxy/envoy/contrib/golang/filters/http/source/go/pkg/http"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
const Name = "mcp-session"
|
||||
const Name = "mcp-server"
|
||||
const Version = "1.0.0"
|
||||
const DefaultServerName = "defaultServer"
|
||||
const ConfigPathSuffix = "/config"
|
||||
|
||||
func init() {
|
||||
envoyHttp.RegisterHttpFilterFactoryAndConfigParser(Name, filterFactory, &parser{})
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
api.LogErrorf("PProf server recovered from panic: %v", r)
|
||||
}
|
||||
}()
|
||||
api.LogError(http.ListenAndServe("localhost:6060", nil).Error())
|
||||
}()
|
||||
type SSEServerWrapper struct {
|
||||
BaseServer *common.SSEServer
|
||||
DomainList []string
|
||||
}
|
||||
|
||||
type config struct {
|
||||
ssePathSuffix string
|
||||
redisClient *internal.RedisClient
|
||||
servers []*internal.SSEServer
|
||||
defaultServer *internal.SSEServer
|
||||
matchList []internal.MatchRule
|
||||
enableUserLevelServer bool
|
||||
rateLimitConfig *handler.MCPRatelimitConfig
|
||||
servers []*SSEServerWrapper
|
||||
}
|
||||
|
||||
func (c *config) Destroy() {
|
||||
if c.redisClient != nil {
|
||||
api.LogDebug("Closing Redis client")
|
||||
c.redisClient.Close()
|
||||
}
|
||||
for _, server := range c.servers {
|
||||
server.Close()
|
||||
server.BaseServer.Close()
|
||||
}
|
||||
}
|
||||
|
||||
type parser struct {
|
||||
type Parser struct {
|
||||
}
|
||||
|
||||
// Parse the filter configuration
|
||||
func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (interface{}, error) {
|
||||
func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (interface{}, error) {
|
||||
configStruct := &xds.TypedStruct{}
|
||||
if err := any.UnmarshalTo(configStruct); err != nil {
|
||||
return nil, err
|
||||
@@ -66,82 +41,9 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
||||
v := configStruct.Value
|
||||
|
||||
conf := &config{
|
||||
matchList: make([]internal.MatchRule, 0),
|
||||
servers: make([]*internal.SSEServer, 0),
|
||||
servers: make([]*SSEServerWrapper, 0),
|
||||
}
|
||||
|
||||
// Parse match_list if exists
|
||||
if matchList, ok := v.AsMap()["match_list"].([]interface{}); ok {
|
||||
for _, item := range matchList {
|
||||
if ruleMap, ok := item.(map[string]interface{}); ok {
|
||||
rule := internal.MatchRule{}
|
||||
if domain, ok := ruleMap["match_rule_domain"].(string); ok {
|
||||
rule.MatchRuleDomain = domain
|
||||
}
|
||||
if path, ok := ruleMap["match_rule_path"].(string); ok {
|
||||
rule.MatchRulePath = path
|
||||
}
|
||||
if ruleType, ok := ruleMap["match_rule_type"].(string); ok {
|
||||
rule.MatchRuleType = internal.RuleType(ruleType)
|
||||
}
|
||||
conf.matchList = append(conf.matchList, rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Redis configuration is optional
|
||||
if redisConfigMap, ok := v.AsMap()["redis"].(map[string]interface{}); ok {
|
||||
redisConfig, err := internal.ParseRedisConfig(redisConfigMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse redis config: %w", err)
|
||||
}
|
||||
|
||||
redisClient, err := internal.NewRedisClient(redisConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err)
|
||||
}
|
||||
conf.redisClient = redisClient
|
||||
api.LogDebug("Redis client initialized")
|
||||
} else {
|
||||
api.LogDebug("Redis configuration not provided, running without Redis")
|
||||
}
|
||||
|
||||
enableUserLevelServer, ok := v.AsMap()["enable_user_level_server"].(bool)
|
||||
if !ok {
|
||||
enableUserLevelServer = false
|
||||
if conf.redisClient == nil {
|
||||
return nil, fmt.Errorf("redis configuration is not provided, enable_user_level_server is true")
|
||||
}
|
||||
}
|
||||
conf.enableUserLevelServer = enableUserLevelServer
|
||||
|
||||
if rateLimit, ok := v.AsMap()["rate_limit"].(map[string]interface{}); ok {
|
||||
rateLimitConfig := &handler.MCPRatelimitConfig{}
|
||||
if limit, ok := rateLimit["limit"].(float64); ok {
|
||||
rateLimitConfig.Limit = int(limit)
|
||||
}
|
||||
if window, ok := rateLimit["window"].(float64); ok {
|
||||
rateLimitConfig.Window = int(window)
|
||||
}
|
||||
if whiteList, ok := rateLimit["white_list"].([]interface{}); ok {
|
||||
for _, item := range whiteList {
|
||||
if uid, ok := item.(string); ok {
|
||||
rateLimitConfig.Whitelist = append(rateLimitConfig.Whitelist, uid)
|
||||
}
|
||||
}
|
||||
}
|
||||
if errorText, ok := rateLimit["error_text"].(string); ok {
|
||||
rateLimitConfig.ErrorText = errorText
|
||||
}
|
||||
conf.rateLimitConfig = rateLimitConfig
|
||||
}
|
||||
|
||||
ssePathSuffix, ok := v.AsMap()["sse_path_suffix"].(string)
|
||||
if !ok || ssePathSuffix == "" {
|
||||
return nil, fmt.Errorf("sse path suffix is not set or empty")
|
||||
}
|
||||
conf.ssePathSuffix = ssePathSuffix
|
||||
|
||||
serverConfigs, ok := v.AsMap()["servers"].([]interface{})
|
||||
if !ok {
|
||||
api.LogDebug("No servers are configured")
|
||||
@@ -153,19 +55,33 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server config must be an object")
|
||||
}
|
||||
|
||||
serverType, ok := serverConfigMap["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server type is not set")
|
||||
}
|
||||
|
||||
serverPath, ok := serverConfigMap["path"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server %s path is not set", serverType)
|
||||
}
|
||||
|
||||
serverDomainList := []string{}
|
||||
if domainList, ok := serverConfigMap["domain_list"].([]interface{}); ok {
|
||||
for _, domain := range domainList {
|
||||
if domainStr, ok := domain.(string); ok {
|
||||
serverDomainList = append(serverDomainList, domainStr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
serverDomainList = []string{"*"}
|
||||
}
|
||||
|
||||
serverName, ok := serverConfigMap["name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server %s name is not set", serverType)
|
||||
}
|
||||
server := internal.GlobalRegistry.GetServer(serverType)
|
||||
server := common.GlobalRegistry.GetServer(serverType)
|
||||
|
||||
if server == nil {
|
||||
return nil, fmt.Errorf("server %s is not registered", serverType)
|
||||
@@ -186,50 +102,37 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
||||
return nil, fmt.Errorf("failed to initialize DBServer: %w", err)
|
||||
}
|
||||
|
||||
conf.servers = append(conf.servers, internal.NewSSEServer(serverInstance,
|
||||
internal.WithRedisClient(conf.redisClient),
|
||||
internal.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, ssePathSuffix)),
|
||||
internal.WithMessageEndpoint(serverPath)))
|
||||
conf.servers = append(conf.servers, &SSEServerWrapper{
|
||||
BaseServer: common.NewSSEServer(serverInstance,
|
||||
common.WithRedisClient(common.GlobalRedisClient),
|
||||
common.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, mcp_session.GlobalSSEPathSuffix)),
|
||||
common.WithMessageEndpoint(serverPath)),
|
||||
DomainList: serverDomainList,
|
||||
})
|
||||
api.LogDebug(fmt.Sprintf("Registered MCP Server: %s", serverType))
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (p *parser) Merge(parent interface{}, child interface{}) interface{} {
|
||||
func (p *Parser) Merge(parent interface{}, child interface{}) interface{} {
|
||||
parentConfig := parent.(*config)
|
||||
childConfig := child.(*config)
|
||||
|
||||
newConfig := *parentConfig
|
||||
if childConfig.redisClient != nil {
|
||||
newConfig.redisClient = childConfig.redisClient
|
||||
}
|
||||
if childConfig.ssePathSuffix != "" {
|
||||
newConfig.ssePathSuffix = childConfig.ssePathSuffix
|
||||
}
|
||||
if childConfig.servers != nil {
|
||||
newConfig.servers = childConfig.servers
|
||||
}
|
||||
if childConfig.defaultServer != nil {
|
||||
newConfig.defaultServer = childConfig.defaultServer
|
||||
}
|
||||
if childConfig.matchList != nil {
|
||||
newConfig.matchList = childConfig.matchList
|
||||
}
|
||||
return &newConfig
|
||||
}
|
||||
|
||||
func filterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.StreamFilter {
|
||||
func FilterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.StreamFilter {
|
||||
conf, ok := c.(*config)
|
||||
if !ok {
|
||||
panic("unexpected config type")
|
||||
}
|
||||
return &filter{
|
||||
callbacks: callbacks,
|
||||
config: conf,
|
||||
stopChan: make(chan struct{}),
|
||||
mcpConfigHandler: handler.NewMCPConfigHandler(conf.redisClient, callbacks),
|
||||
mcpRatelimitHandler: handler.NewMCPRatelimitHandler(conf.redisClient, callbacks, conf.rateLimitConfig),
|
||||
config: conf,
|
||||
callbacks: callbacks,
|
||||
}
|
||||
}
|
||||
|
||||
func main() {}
|
||||
|
||||
@@ -1,104 +1,41 @@
|
||||
package main
|
||||
package mcp_server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/handler"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
RedisNotEnabledResponseBody = "Redis is not enabled, SSE connection is not supported"
|
||||
)
|
||||
|
||||
// The callbacks in the filter, like `DecodeHeaders`, can be implemented on demand.
|
||||
// Because api.PassThroughStreamFilter provides a default implementation.
|
||||
type filter struct {
|
||||
api.PassThroughStreamFilter
|
||||
|
||||
callbacks api.FilterCallbackHandler
|
||||
path string
|
||||
config *config
|
||||
stopChan chan struct{}
|
||||
|
||||
req *http.Request
|
||||
serverName string
|
||||
message bool
|
||||
proxyURL *url.URL
|
||||
skip bool
|
||||
|
||||
userLevelConfig bool
|
||||
mcpConfigHandler *handler.MCPConfigHandler
|
||||
ratelimit bool
|
||||
mcpRatelimitHandler *handler.MCPRatelimitHandler
|
||||
config *config
|
||||
req *http.Request
|
||||
message bool
|
||||
path string
|
||||
}
|
||||
|
||||
type RequestURL struct {
|
||||
method string
|
||||
scheme string
|
||||
host string
|
||||
path string
|
||||
baseURL string
|
||||
parsedURL *url.URL
|
||||
internalIP bool
|
||||
}
|
||||
|
||||
func NewRequestURL(header api.RequestHeaderMap) *RequestURL {
|
||||
method, _ := header.Get(":method")
|
||||
scheme, _ := header.Get(":scheme")
|
||||
host, _ := header.Get(":authority")
|
||||
path, _ := header.Get(":path")
|
||||
internalIP, _ := header.Get("x-envoy-internal")
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
parsedURL, _ := url.Parse(path)
|
||||
api.LogDebugf("RequestURL: method=%s, scheme=%s, host=%s, path=%s", method, scheme, host, path)
|
||||
return &RequestURL{method: method, scheme: scheme, host: host, path: path, baseURL: baseURL, parsedURL: parsedURL, internalIP: internalIP == "true"}
|
||||
}
|
||||
|
||||
// Callbacks which are called in request path
|
||||
// The endStream is true if the request doesn't have body
|
||||
func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.StatusType {
|
||||
url := NewRequestURL(header)
|
||||
f.path = url.parsedURL.Path
|
||||
|
||||
// Check if request matches any rule in match_list
|
||||
if !internal.IsMatch(f.config.matchList, url.host, f.path) {
|
||||
f.skip = true
|
||||
api.LogDebugf("Request does not match any rule in match_list: %s", url.parsedURL.String())
|
||||
return api.Continue
|
||||
}
|
||||
url := common.NewRequestURL(header)
|
||||
f.path = url.ParsedURL.Path
|
||||
|
||||
for _, server := range f.config.servers {
|
||||
if f.path == server.GetSSEEndpoint() {
|
||||
if url.method != http.MethodGet {
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||
} else {
|
||||
f.serverName = server.GetServerName()
|
||||
body := "SSE connection create"
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
|
||||
}
|
||||
api.LogDebugf("%s SSE connection started", server.GetServerName())
|
||||
return api.LocalReply
|
||||
} else if f.path == server.GetMessageEndpoint() {
|
||||
if url.method != http.MethodPost {
|
||||
if common.MatchDomainList(url.ParsedURL.Host, server.DomainList) && url.ParsedURL.Path == server.BaseServer.GetMessageEndpoint() {
|
||||
if url.Method != http.MethodPost {
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
// Create a new http.Request object
|
||||
f.req = &http.Request{
|
||||
Method: url.method,
|
||||
URL: url.parsedURL,
|
||||
Method: url.Method,
|
||||
URL: url.ParsedURL,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
api.LogDebugf("Message request: %v", url.parsedURL)
|
||||
api.LogDebugf("Message request: %v", url.ParsedURL)
|
||||
// Copy headers from api.RequestHeaderMap to http.Header
|
||||
header.Range(func(key, value string) bool {
|
||||
f.req.Header.Add(key, value)
|
||||
@@ -113,209 +50,33 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
|
||||
}
|
||||
}
|
||||
|
||||
f.req = &http.Request{
|
||||
Method: url.method,
|
||||
URL: url.parsedURL,
|
||||
}
|
||||
|
||||
if strings.HasSuffix(f.path, ConfigPathSuffix) && f.config.enableUserLevelServer {
|
||||
if !url.internalIP {
|
||||
api.LogWarnf("Access denied: non-internal IP address %s", url.parsedURL.String())
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
if strings.HasSuffix(f.path, ConfigPathSuffix) && url.method == http.MethodGet {
|
||||
api.LogDebugf("Handling config request: %s", f.path)
|
||||
f.mcpConfigHandler.HandleConfigRequest(f.req, []byte{})
|
||||
return api.LocalReply
|
||||
}
|
||||
f.userLevelConfig = true
|
||||
if endStream {
|
||||
return api.Continue
|
||||
} else {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(url.parsedURL.Path, f.config.ssePathSuffix) {
|
||||
f.proxyURL = url.parsedURL
|
||||
if f.config.enableUserLevelServer {
|
||||
parts := strings.Split(url.parsedURL.Path, "/")
|
||||
if len(parts) >= 3 {
|
||||
serverName := parts[1]
|
||||
uid := parts[2]
|
||||
// Get encoded config
|
||||
encodedConfig, _ := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
|
||||
if encodedConfig != "" {
|
||||
header.Set("x-higress-mcpserver-config", encodedConfig)
|
||||
api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid)
|
||||
}
|
||||
}
|
||||
f.ratelimit = true
|
||||
}
|
||||
if endStream {
|
||||
return api.Continue
|
||||
} else {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
}
|
||||
|
||||
if url.method != http.MethodGet {
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||
} else {
|
||||
f.config.defaultServer = internal.NewSSEServer(internal.NewMCPServer(DefaultServerName, Version),
|
||||
internal.WithSSEEndpoint(f.config.ssePathSuffix),
|
||||
internal.WithMessageEndpoint(strings.TrimSuffix(url.parsedURL.Path, f.config.ssePathSuffix)),
|
||||
internal.WithRedisClient(f.config.redisClient))
|
||||
f.serverName = f.config.defaultServer.GetServerName()
|
||||
body := "SSE connection create"
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
|
||||
}
|
||||
return api.LocalReply
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
// DecodeData might be called multiple times during handling the request body.
|
||||
// The endStream is true when handling the last piece of the body.
|
||||
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||
if f.skip {
|
||||
return api.Continue
|
||||
}
|
||||
if !endStream {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
if f.message {
|
||||
for _, server := range f.config.servers {
|
||||
if f.path == server.GetMessageEndpoint() {
|
||||
if f.path == server.BaseServer.GetMessageEndpoint() {
|
||||
// Create a response recorder to capture the response
|
||||
recorder := httptest.NewRecorder()
|
||||
// Call the handleMessage method of SSEServer with complete body
|
||||
httpStatus := server.HandleMessage(recorder, f.req, buffer.Bytes())
|
||||
httpStatus := server.BaseServer.HandleMessage(recorder, f.req, buffer.Bytes())
|
||||
f.message = false
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(httpStatus, recorder.Body.String(), recorder.Header(), 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
}
|
||||
} else if f.userLevelConfig {
|
||||
// Handle config POST request
|
||||
api.LogDebugf("Handling config request: %s", f.path)
|
||||
f.mcpConfigHandler.HandleConfigRequest(f.req, buffer.Bytes())
|
||||
return api.LocalReply
|
||||
} else if f.ratelimit {
|
||||
if checkJSONRPCMethod(buffer.Bytes(), "tools/list") {
|
||||
api.LogDebugf("Not a tools call request, skipping ratelimit")
|
||||
return api.Continue
|
||||
}
|
||||
parts := strings.Split(f.req.URL.Path, "/")
|
||||
if len(parts) < 3 {
|
||||
api.LogWarnf("Access denied: no valid uid found")
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
serverName := parts[1]
|
||||
uid := parts[2]
|
||||
encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
|
||||
if err != nil {
|
||||
api.LogWarnf("Access denied: no valid config found for uid %s", uid)
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
} else if encodedConfig == "" && checkJSONRPCMethod(buffer.Bytes(), "tools/call") {
|
||||
api.LogDebugf("Empty config found for %s:%s", serverName, uid)
|
||||
if !f.mcpRatelimitHandler.HandleRatelimit(f.req, buffer.Bytes()) {
|
||||
return api.LocalReply
|
||||
}
|
||||
}
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
// Callbacks which are called in response path
|
||||
// The endStream is true if the response doesn't have body
|
||||
func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api.StatusType {
|
||||
if f.skip {
|
||||
return api.Continue
|
||||
}
|
||||
if f.serverName != "" {
|
||||
if f.config.redisClient != nil {
|
||||
header.Set("Content-Type", "text/event-stream")
|
||||
header.Set("Cache-Control", "no-cache")
|
||||
header.Set("Connection", "keep-alive")
|
||||
header.Set("Access-Control-Allow-Origin", "*")
|
||||
header.Del("Content-Length")
|
||||
} else {
|
||||
header.Set("Content-Length", strconv.Itoa(len(RedisNotEnabledResponseBody)))
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
// EncodeData might be called multiple times during handling the response body.
|
||||
// The endStream is true when handling the last piece of the body.
|
||||
func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||
if f.skip {
|
||||
return api.Continue
|
||||
}
|
||||
if !endStream {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
if f.proxyURL != nil && f.config.redisClient != nil {
|
||||
sessionID := f.proxyURL.Query().Get("sessionId")
|
||||
if sessionID != "" {
|
||||
channel := internal.GetSSEChannelName(sessionID)
|
||||
eventData := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String())
|
||||
publishErr := f.config.redisClient.Publish(channel, eventData)
|
||||
if publishErr != nil {
|
||||
api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if f.serverName != "" {
|
||||
if f.config.redisClient != nil {
|
||||
// handle specific server
|
||||
for _, server := range f.config.servers {
|
||||
if f.serverName == server.GetServerName() {
|
||||
buffer.Reset()
|
||||
server.HandleSSE(f.callbacks, f.stopChan)
|
||||
return api.Running
|
||||
}
|
||||
}
|
||||
// handle default server
|
||||
if f.serverName == f.config.defaultServer.GetServerName() {
|
||||
buffer.Reset()
|
||||
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
|
||||
return api.Running
|
||||
}
|
||||
return api.Continue
|
||||
} else {
|
||||
buffer.SetString(RedisNotEnabledResponseBody)
|
||||
return api.Continue
|
||||
}
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
// OnDestroy stops the goroutine
|
||||
func (f *filter) OnDestroy(reason api.DestroyReason) {
|
||||
api.LogDebugf("OnDestroy: reason=%v", reason)
|
||||
if f.serverName != "" && f.stopChan != nil {
|
||||
select {
|
||||
case <-f.stopChan:
|
||||
return
|
||||
default:
|
||||
api.LogDebug("Stopping SSE connection")
|
||||
close(f.stopChan)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check if the request is a tools/call request
|
||||
func checkJSONRPCMethod(body []byte, method string) bool {
|
||||
var request mcp.CallToolRequest
|
||||
if err := json.Unmarshal(body, &request); err != nil {
|
||||
api.LogWarnf("Failed to unmarshal request body: %v, not a JSON RPC request", err)
|
||||
return true
|
||||
}
|
||||
|
||||
return request.Method == method
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/registry"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/nacos-group/nacos-sdk-go/v2/clients"
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
internal.GlobalRegistry.RegisterServer("nacos-mcp-registry", &NacosConfig{})
|
||||
common.GlobalRegistry.RegisterServer("nacos-mcp-registry", &NacosConfig{})
|
||||
}
|
||||
|
||||
type NacosConfig struct {
|
||||
@@ -28,7 +28,7 @@ type NacosConfig struct {
|
||||
}
|
||||
|
||||
type McpServerToolsChangeListener struct {
|
||||
mcpServer *internal.MCPServer
|
||||
mcpServer *common.MCPServer
|
||||
}
|
||||
|
||||
func (l *McpServerToolsChangeListener) OnToolChanged(reg registry.McpServerRegistry) {
|
||||
@@ -137,8 +137,8 @@ func (c *NacosConfig) ParseConfig(config map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *NacosConfig) NewServer(serverName string) (*internal.MCPServer, error) {
|
||||
mcpServer := internal.NewMCPServer(
|
||||
func (c *NacosConfig) NewServer(serverName string) (*common.MCPServer, error) {
|
||||
mcpServer := common.NewMCPServer(
|
||||
serverName,
|
||||
"1.0.0",
|
||||
)
|
||||
@@ -170,11 +170,11 @@ func (c *NacosConfig) NewServer(serverName string) (*internal.MCPServer, error)
|
||||
return mcpServer, nil
|
||||
}
|
||||
|
||||
func resetToolsToMcpServer(mcpServer *internal.MCPServer, reg registry.McpServerRegistry) {
|
||||
wrappedTools := []internal.ServerTool{}
|
||||
func resetToolsToMcpServer(mcpServer *common.MCPServer, reg registry.McpServerRegistry) {
|
||||
wrappedTools := []common.ServerTool{}
|
||||
tools := reg.ListToolsDesciption()
|
||||
for _, tool := range tools {
|
||||
wrappedTools = append(wrappedTools, internal.ServerTool{
|
||||
wrappedTools = append(wrappedTools, common.ServerTool{
|
||||
Tool: mcp.NewToolWithRawSchema(tool.Name, tool.Description, tool.InputSchema),
|
||||
Handler: registry.HandleRegistryToolsCall(reg),
|
||||
})
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
@@ -204,7 +204,7 @@ func CommonRemoteCall(reg McpServerRegistry, toolName string, parameters map[str
|
||||
return remoteHandle.HandleToolCall(ctx, parameters)
|
||||
}
|
||||
|
||||
func HandleRegistryToolsCall(reg McpServerRegistry) internal.ToolHandlerFunc {
|
||||
func HandleRegistryToolsCall(reg McpServerRegistry) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
arguments := request.Params.Arguments
|
||||
return CommonRemoteCall(reg, request.Params.Name, arguments)
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
const Version = "1.0.0"
|
||||
|
||||
func init() {
|
||||
internal.GlobalRegistry.RegisterServer("database", &DBConfig{})
|
||||
common.GlobalRegistry.RegisterServer("database", &DBConfig{})
|
||||
}
|
||||
|
||||
type DBConfig struct {
|
||||
@@ -41,11 +41,11 @@ func (c *DBConfig) ParseConfig(config map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *DBConfig) NewServer(serverName string) (*internal.MCPServer, error) {
|
||||
mcpServer := internal.NewMCPServer(
|
||||
func (c *DBConfig) NewServer(serverName string) (*common.MCPServer, error) {
|
||||
mcpServer := common.NewMCPServer(
|
||||
serverName,
|
||||
Version,
|
||||
internal.WithInstructions(fmt.Sprintf("This is a %s database server", c.dbType)),
|
||||
common.WithInstructions(fmt.Sprintf("This is a %s database server", c.dbType)),
|
||||
)
|
||||
|
||||
dbClient := NewDBClient(c.dsn, c.dbType, mcpServer.GetDestoryChannel())
|
||||
|
||||
@@ -5,12 +5,12 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// HandleQueryTool handles SQL query execution
|
||||
func HandleQueryTool(dbClient *DBClient) internal.ToolHandlerFunc {
|
||||
func HandleQueryTool(dbClient *DBClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
arguments := request.Params.Arguments
|
||||
message, ok := arguments["sql"].(string)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package internal
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
@@ -1,4 +1,4 @@
|
||||
package internal
|
||||
package common
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
@@ -23,6 +23,27 @@ type MatchRule struct {
|
||||
MatchRuleType RuleType `json:"match_rule_type"` // Type of match rule
|
||||
}
|
||||
|
||||
// ParseMatchList parses the match list from the config
|
||||
func ParseMatchList(matchListConfig []interface{}) []MatchRule {
|
||||
matchList := make([]MatchRule, 0)
|
||||
for _, item := range matchListConfig {
|
||||
if ruleMap, ok := item.(map[string]interface{}); ok {
|
||||
rule := MatchRule{}
|
||||
if domain, ok := ruleMap["match_rule_domain"].(string); ok {
|
||||
rule.MatchRuleDomain = domain
|
||||
}
|
||||
if path, ok := ruleMap["match_rule_path"].(string); ok {
|
||||
rule.MatchRulePath = path
|
||||
}
|
||||
if ruleType, ok := ruleMap["match_rule_type"].(string); ok {
|
||||
rule.MatchRuleType = RuleType(ruleType)
|
||||
}
|
||||
matchList = append(matchList, rule)
|
||||
}
|
||||
}
|
||||
return matchList
|
||||
}
|
||||
|
||||
// convertWildcardToRegex converts wildcard pattern to regex pattern
|
||||
func convertWildcardToRegex(pattern string) string {
|
||||
pattern = regexp.QuoteMeta(pattern)
|
||||
@@ -87,3 +108,13 @@ func IsMatch(rules []MatchRule, host, path string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MatchDomainList checks if the domain matches any of the domains in the list
|
||||
func MatchDomainList(domain string, domainList []string) bool {
|
||||
for _, d := range domainList {
|
||||
if matchDomain(domain, d) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package internal
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
var GlobalRedisClient *RedisClient
|
||||
|
||||
type RedisConfig struct {
|
||||
address string
|
||||
username string
|
||||
@@ -249,6 +251,18 @@ func (r *RedisClient) Get(key string) (string, error) {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Expire sets the expiration time for a key
|
||||
func (r *RedisClient) Expire(key string, expiration time.Duration) error {
|
||||
ok, err := r.client.Expire(r.ctx, key, expiration).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set expiration for key: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
return fmt.Errorf("key does not exist")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the Redis client and stops the keepalive goroutine
|
||||
func (r *RedisClient) Close() error {
|
||||
r.cancel()
|
||||
@@ -1,4 +1,4 @@
|
||||
package internal
|
||||
package common
|
||||
|
||||
var GlobalRegistry = NewServerRegistry()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package internal
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -243,6 +243,7 @@ func (s *MCPServer) HandleMessage(
|
||||
message json.RawMessage,
|
||||
) mcp.JSONRPCMessage {
|
||||
// Add server to context
|
||||
|
||||
ctx = context.WithValue(ctx, serverKey{}, s)
|
||||
|
||||
var baseMessage struct {
|
||||
@@ -1,4 +1,4 @@
|
||||
package internal
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -210,15 +210,7 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
|
||||
var status int
|
||||
// Only send response if there is one (not for notifications)
|
||||
if response != nil {
|
||||
eventData, _ := json.Marshal(response)
|
||||
|
||||
if sessionID != "" && s.redisClient != nil {
|
||||
channel := GetSSEChannelName(sessionID)
|
||||
publishErr := s.redisClient.Publish(channel, fmt.Sprintf("event: message\ndata: %s\n\n", eventData))
|
||||
|
||||
if publishErr != nil {
|
||||
api.LogErrorf("Failed to publish message to Redis: %v", publishErr)
|
||||
}
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
status = http.StatusAccepted
|
||||
} else {
|
||||
30
plugins/golang-filter/mcp-session/common/utils.go
Normal file
30
plugins/golang-filter/mcp-session/common/utils.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
)
|
||||
|
||||
type RequestURL struct {
|
||||
Method string
|
||||
Scheme string
|
||||
Host string
|
||||
Path string
|
||||
BaseURL string
|
||||
ParsedURL *url.URL
|
||||
InternalIP bool
|
||||
}
|
||||
|
||||
func NewRequestURL(header api.RequestHeaderMap) *RequestURL {
|
||||
method, _ := header.Get(":method")
|
||||
scheme, _ := header.Get(":scheme")
|
||||
host, _ := header.Get(":authority")
|
||||
path, _ := header.Get(":path")
|
||||
internalIP, _ := header.Get("x-envoy-internal")
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
parsedURL, _ := url.Parse(path)
|
||||
api.LogDebugf("RequestURL: method=%s, scheme=%s, host=%s, path=%s", method, scheme, host, path)
|
||||
return &RequestURL{Method: method, Scheme: scheme, Host: host, Path: path, BaseURL: baseURL, ParsedURL: parsedURL, InternalIP: internalIP == "true"}
|
||||
}
|
||||
143
plugins/golang-filter/mcp-session/config.go
Normal file
143
plugins/golang-filter/mcp-session/config.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package mcp_session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
_ "net/http/pprof"
|
||||
|
||||
xds "github.com/cncf/xds/go/xds/type/v3"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/handler"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
)
|
||||
|
||||
const Name = "mcp-session"
|
||||
const Version = "1.0.0"
|
||||
const ConfigPathSuffix = "/config"
|
||||
const DefaultServerName = "higress-mcp-server"
|
||||
|
||||
var GlobalSSEPathSuffix = "/sse"
|
||||
|
||||
type config struct {
|
||||
matchList []common.MatchRule
|
||||
enableUserLevelServer bool
|
||||
rateLimitConfig *handler.MCPRatelimitConfig
|
||||
defaultServer *common.SSEServer
|
||||
}
|
||||
|
||||
func (c *config) Destroy() {
|
||||
if common.GlobalRedisClient != nil {
|
||||
api.LogDebug("Closing Redis client")
|
||||
common.GlobalRedisClient.Close()
|
||||
}
|
||||
}
|
||||
|
||||
type Parser struct {
|
||||
}
|
||||
|
||||
// Parse the filter configuration
|
||||
func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (interface{}, error) {
|
||||
configStruct := &xds.TypedStruct{}
|
||||
if err := any.UnmarshalTo(configStruct); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v := configStruct.Value
|
||||
|
||||
conf := &config{
|
||||
matchList: make([]common.MatchRule, 0),
|
||||
}
|
||||
|
||||
// Parse match_list if exists
|
||||
if matchList, ok := v.AsMap()["match_list"].([]interface{}); ok {
|
||||
conf.matchList = common.ParseMatchList(matchList)
|
||||
}
|
||||
|
||||
// Redis configuration is optional
|
||||
if redisConfigMap, ok := v.AsMap()["redis"].(map[string]interface{}); ok {
|
||||
redisConfig, err := common.ParseRedisConfig(redisConfigMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse redis config: %w", err)
|
||||
}
|
||||
|
||||
redisClient, err := common.NewRedisClient(redisConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err)
|
||||
}
|
||||
common.GlobalRedisClient = redisClient
|
||||
api.LogDebug("Redis client initialized")
|
||||
} else {
|
||||
api.LogDebug("Redis configuration not provided, running without Redis")
|
||||
}
|
||||
|
||||
enableUserLevelServer, ok := v.AsMap()["enable_user_level_server"].(bool)
|
||||
if !ok {
|
||||
enableUserLevelServer = false
|
||||
if common.GlobalRedisClient == nil {
|
||||
return nil, fmt.Errorf("redis configuration is not provided, enable_user_level_server is true")
|
||||
}
|
||||
}
|
||||
conf.enableUserLevelServer = enableUserLevelServer
|
||||
|
||||
if rateLimit, ok := v.AsMap()["rate_limit"].(map[string]interface{}); ok {
|
||||
rateLimitConfig := &handler.MCPRatelimitConfig{}
|
||||
if limit, ok := rateLimit["limit"].(float64); ok {
|
||||
rateLimitConfig.Limit = int(limit)
|
||||
}
|
||||
if window, ok := rateLimit["window"].(float64); ok {
|
||||
rateLimitConfig.Window = int(window)
|
||||
}
|
||||
if whiteList, ok := rateLimit["white_list"].([]interface{}); ok {
|
||||
for _, item := range whiteList {
|
||||
if uid, ok := item.(string); ok {
|
||||
rateLimitConfig.Whitelist = append(rateLimitConfig.Whitelist, uid)
|
||||
}
|
||||
}
|
||||
}
|
||||
if errorText, ok := rateLimit["error_text"].(string); ok {
|
||||
rateLimitConfig.ErrorText = errorText
|
||||
}
|
||||
conf.rateLimitConfig = rateLimitConfig
|
||||
}
|
||||
|
||||
ssePathSuffix, ok := v.AsMap()["sse_path_suffix"].(string)
|
||||
if !ok || ssePathSuffix == "" {
|
||||
return nil, fmt.Errorf("sse path suffix is not set or empty")
|
||||
}
|
||||
GlobalSSEPathSuffix = ssePathSuffix
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (p *Parser) Merge(parent interface{}, child interface{}) interface{} {
|
||||
parentConfig := parent.(*config)
|
||||
childConfig := child.(*config)
|
||||
|
||||
newConfig := *parentConfig
|
||||
if childConfig.matchList != nil {
|
||||
newConfig.matchList = childConfig.matchList
|
||||
}
|
||||
newConfig.enableUserLevelServer = childConfig.enableUserLevelServer
|
||||
if childConfig.rateLimitConfig != nil {
|
||||
newConfig.rateLimitConfig = childConfig.rateLimitConfig
|
||||
}
|
||||
if childConfig.defaultServer != nil {
|
||||
newConfig.defaultServer = childConfig.defaultServer
|
||||
}
|
||||
return &newConfig
|
||||
}
|
||||
|
||||
func FilterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.StreamFilter {
|
||||
conf, ok := c.(*config)
|
||||
if !ok {
|
||||
panic("unexpected config type")
|
||||
}
|
||||
return &filter{
|
||||
callbacks: callbacks,
|
||||
config: conf,
|
||||
stopChan: make(chan struct{}),
|
||||
mcpConfigHandler: handler.NewMCPConfigHandler(common.GlobalRedisClient, callbacks),
|
||||
mcpRatelimitHandler: handler.NewMCPRatelimitHandler(common.GlobalRedisClient, callbacks, conf.rateLimitConfig),
|
||||
}
|
||||
}
|
||||
237
plugins/golang-filter/mcp-session/filter.go
Normal file
237
plugins/golang-filter/mcp-session/filter.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package mcp_session
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/handler"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
RedisNotEnabledResponseBody = "Redis is not enabled, SSE connection is not supported"
|
||||
)
|
||||
|
||||
// The callbacks in the filter, like `DecodeHeaders`, can be implemented on demand.
|
||||
// Because api.PassThroughStreamFilter provides a default implementation.
|
||||
type filter struct {
|
||||
api.PassThroughStreamFilter
|
||||
|
||||
callbacks api.FilterCallbackHandler
|
||||
path string
|
||||
config *config
|
||||
stopChan chan struct{}
|
||||
|
||||
req *http.Request
|
||||
serverName string
|
||||
proxyURL *url.URL
|
||||
skip bool
|
||||
|
||||
userLevelConfig bool
|
||||
mcpConfigHandler *handler.MCPConfigHandler
|
||||
ratelimit bool
|
||||
mcpRatelimitHandler *handler.MCPRatelimitHandler
|
||||
}
|
||||
|
||||
// Callbacks which are called in request path
|
||||
// The endStream is true if the request doesn't have body
|
||||
func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.StatusType {
|
||||
url := common.NewRequestURL(header)
|
||||
f.path = url.ParsedURL.Path
|
||||
|
||||
// Check if request matches any rule in match_list
|
||||
if !common.IsMatch(f.config.matchList, url.Host, f.path) {
|
||||
f.skip = true
|
||||
api.LogDebugf("Request does not match any rule in match_list: %s", url.ParsedURL.String())
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
f.req = &http.Request{
|
||||
Method: url.Method,
|
||||
URL: url.ParsedURL,
|
||||
}
|
||||
|
||||
if strings.HasSuffix(f.path, ConfigPathSuffix) && f.config.enableUserLevelServer {
|
||||
if !url.InternalIP {
|
||||
api.LogWarnf("Access denied: non-Internal IP address %s", url.ParsedURL.String())
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
if strings.HasSuffix(f.path, ConfigPathSuffix) && url.Method == http.MethodGet {
|
||||
api.LogDebugf("Handling config request: %s", f.path)
|
||||
f.mcpConfigHandler.HandleConfigRequest(f.req, []byte{})
|
||||
return api.LocalReply
|
||||
}
|
||||
f.userLevelConfig = true
|
||||
if endStream {
|
||||
return api.Continue
|
||||
} else {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix) {
|
||||
f.proxyURL = url.ParsedURL
|
||||
if f.config.enableUserLevelServer {
|
||||
parts := strings.Split(url.ParsedURL.Path, "/")
|
||||
if len(parts) >= 3 {
|
||||
serverName := parts[1]
|
||||
uid := parts[2]
|
||||
// Get encoded config
|
||||
encodedConfig, _ := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
|
||||
if encodedConfig != "" {
|
||||
header.Set("x-higress-mcpserver-config", encodedConfig)
|
||||
api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid)
|
||||
}
|
||||
}
|
||||
f.ratelimit = true
|
||||
}
|
||||
if endStream {
|
||||
return api.Continue
|
||||
} else {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
}
|
||||
|
||||
if url.Method != http.MethodGet {
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||
} else {
|
||||
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
|
||||
common.WithSSEEndpoint(GlobalSSEPathSuffix),
|
||||
common.WithMessageEndpoint(strings.TrimSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix)),
|
||||
common.WithRedisClient(common.GlobalRedisClient))
|
||||
f.serverName = f.config.defaultServer.GetServerName()
|
||||
body := "SSE connection create"
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
|
||||
}
|
||||
return api.LocalReply
|
||||
}
|
||||
|
||||
// DecodeData might be called multiple times during handling the request body.
|
||||
// The endStream is true when handling the last piece of the body.
|
||||
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||
if f.skip {
|
||||
return api.Continue
|
||||
}
|
||||
if !endStream {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
if f.userLevelConfig {
|
||||
// Handle config POST request
|
||||
api.LogDebugf("Handling config request: %s", f.path)
|
||||
f.mcpConfigHandler.HandleConfigRequest(f.req, buffer.Bytes())
|
||||
return api.LocalReply
|
||||
} else if f.ratelimit {
|
||||
if checkJSONRPCMethod(buffer.Bytes(), "tools/list") {
|
||||
api.LogDebugf("Not a tools call request, skipping ratelimit")
|
||||
return api.Continue
|
||||
}
|
||||
parts := strings.Split(f.req.URL.Path, "/")
|
||||
if len(parts) < 3 {
|
||||
api.LogWarnf("Access denied: no valid uid found")
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
serverName := parts[1]
|
||||
uid := parts[2]
|
||||
encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
|
||||
if err != nil {
|
||||
api.LogWarnf("Access denied: no valid config found for uid %s", uid)
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
} else if encodedConfig == "" && checkJSONRPCMethod(buffer.Bytes(), "tools/call") {
|
||||
api.LogDebugf("Empty config found for %s:%s", serverName, uid)
|
||||
if !f.mcpRatelimitHandler.HandleRatelimit(f.req, buffer.Bytes()) {
|
||||
return api.LocalReply
|
||||
}
|
||||
}
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
// Callbacks which are called in response path
|
||||
// The endStream is true if the response doesn't have body
|
||||
func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api.StatusType {
|
||||
if f.skip {
|
||||
return api.Continue
|
||||
}
|
||||
if f.serverName != "" {
|
||||
if common.GlobalRedisClient != nil {
|
||||
header.Set("Content-Type", "text/event-stream")
|
||||
header.Set("Cache-Control", "no-cache")
|
||||
header.Set("Connection", "keep-alive")
|
||||
header.Set("Access-Control-Allow-Origin", "*")
|
||||
header.Del("Content-Length")
|
||||
} else {
|
||||
header.Set("Content-Length", strconv.Itoa(len(RedisNotEnabledResponseBody)))
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
// EncodeData might be called multiple times during handling the response body.
|
||||
// The endStream is true when handling the last piece of the body.
|
||||
func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||
if f.skip {
|
||||
return api.Continue
|
||||
}
|
||||
if !endStream {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
if f.proxyURL != nil && common.GlobalRedisClient != nil {
|
||||
sessionID := f.proxyURL.Query().Get("sessionId")
|
||||
if sessionID != "" {
|
||||
channel := common.GetSSEChannelName(sessionID)
|
||||
eventData := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String())
|
||||
publishErr := common.GlobalRedisClient.Publish(channel, eventData)
|
||||
if publishErr != nil {
|
||||
api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if f.serverName != "" {
|
||||
if common.GlobalRedisClient != nil {
|
||||
// handle default server
|
||||
buffer.Reset()
|
||||
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
|
||||
return api.Running
|
||||
} else {
|
||||
buffer.SetString(RedisNotEnabledResponseBody)
|
||||
return api.Continue
|
||||
}
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
// OnDestroy stops the goroutine
|
||||
func (f *filter) OnDestroy(reason api.DestroyReason) {
|
||||
api.LogDebugf("OnDestroy: reason=%v", reason)
|
||||
if f.serverName != "" && f.stopChan != nil {
|
||||
select {
|
||||
case <-f.stopChan:
|
||||
return
|
||||
default:
|
||||
api.LogDebug("Stopping SSE connection")
|
||||
close(f.stopChan)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check if the request is a tools/call request
|
||||
func checkJSONRPCMethod(body []byte, method string) bool {
|
||||
var request mcp.CallToolRequest
|
||||
if err := json.Unmarshal(body, &request); err != nil {
|
||||
api.LogWarnf("Failed to unmarshal request body: %v, not a JSON RPC request", err)
|
||||
return true
|
||||
}
|
||||
|
||||
return request.Method == method
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ type MCPConfigHandler struct {
|
||||
}
|
||||
|
||||
// NewMCPConfigHandler creates a new instance of MCP configuration handler
|
||||
func NewMCPConfigHandler(redisClient *internal.RedisClient, callbacks api.FilterCallbackHandler) *MCPConfigHandler {
|
||||
func NewMCPConfigHandler(redisClient *common.RedisClient, callbacks api.FilterCallbackHandler) *MCPConfigHandler {
|
||||
return &MCPConfigHandler{
|
||||
configStore: NewRedisConfigStore(redisClient),
|
||||
callbacks: callbacks,
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -36,11 +36,11 @@ type ConfigStore interface {
|
||||
|
||||
// RedisConfigStore implements configuration storage using Redis
|
||||
type RedisConfigStore struct {
|
||||
redisClient *internal.RedisClient
|
||||
redisClient *common.RedisClient
|
||||
}
|
||||
|
||||
// NewRedisConfigStore creates a new instance of Redis configuration storage
|
||||
func NewRedisConfigStore(redisClient *internal.RedisClient) ConfigStore {
|
||||
func NewRedisConfigStore(redisClient *common.RedisClient) ConfigStore {
|
||||
return &RedisConfigStore{
|
||||
redisClient: redisClient,
|
||||
}
|
||||
@@ -101,5 +101,11 @@ func (s *RedisConfigStore) GetConfig(serverName string, uid string) (map[string]
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Refresh TTL
|
||||
if err := s.redisClient.Expire(key, configExpiry); err != nil {
|
||||
// Log error but don't fail the request
|
||||
fmt.Printf("Failed to refresh TTL for key %s: %v\n", key, err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
@@ -8,13 +8,13 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
type MCPRatelimitHandler struct {
|
||||
redisClient *internal.RedisClient
|
||||
redisClient *common.RedisClient
|
||||
callbacks api.FilterCallbackHandler
|
||||
limit int // Maximum requests allowed per window
|
||||
window int // Time window in seconds
|
||||
@@ -31,7 +31,7 @@ type MCPRatelimitConfig struct {
|
||||
}
|
||||
|
||||
// NewMCPRatelimitHandler creates a new rate limit handler
|
||||
func NewMCPRatelimitHandler(redisClient *internal.RedisClient, callbacks api.FilterCallbackHandler, conf *MCPRatelimitConfig) *MCPRatelimitHandler {
|
||||
func NewMCPRatelimitHandler(redisClient *common.RedisClient, callbacks api.FilterCallbackHandler, conf *MCPRatelimitConfig) *MCPRatelimitHandler {
|
||||
if conf == nil {
|
||||
conf = &MCPRatelimitConfig{
|
||||
Limit: 100,
|
||||
Reference in New Issue
Block a user