mirror of
https://github.com/alibaba/higress.git
synced 2026-03-05 17:10:55 +08:00
feat: support config store and redis configuration optional in mcp server (#2035)
This commit is contained in:
@@ -235,8 +235,7 @@ clean-gateway: clean-istio
|
||||
rm -rf external/proxy
|
||||
rm -rf external/go-control-plane
|
||||
rm -rf external/package/envoy.tar.gz
|
||||
rm -rf external/package/mcp-server_amd64.so
|
||||
rm -rf external/package/mcp-server_arm64.so
|
||||
rm -rf external/package/*.so
|
||||
|
||||
clean-env:
|
||||
rm -rf out/
|
||||
|
||||
@@ -41,6 +41,16 @@ type RedisConfig struct {
|
||||
DB int `json:"db,omitempty"`
|
||||
}
|
||||
|
||||
// MCPRatelimitConfig defines the configuration for rate limit
|
||||
type MCPRatelimitConfig struct {
|
||||
// The limit of the rate limit
|
||||
Limit int64 `json:"limit,omitempty"`
|
||||
// The window of the rate limit
|
||||
Window int64 `json:"window,omitempty"`
|
||||
// The white list of the rate limit
|
||||
WhiteList []string `json:"white_list,omitempty"`
|
||||
}
|
||||
|
||||
// SSEServer defines the configuration for Server-Sent Events (SSE) server
|
||||
type SSEServer struct {
|
||||
// The name of the SSE server
|
||||
@@ -75,13 +85,18 @@ type McpServer struct {
|
||||
Servers []*SSEServer `json:"servers,omitempty"`
|
||||
// List of match rules for filtering requests
|
||||
MatchList []*MatchRule `json:"match_list,omitempty"`
|
||||
// Flag to control whether user level server is enabled
|
||||
EnableUserLevelServer bool `json:"enable_user_level_server,omitempty"`
|
||||
// Rate limit config for MCP server
|
||||
Ratelimit *MCPRatelimitConfig `json:"rate_limit,omitempty"`
|
||||
}
|
||||
|
||||
func NewDefaultMcpServer() *McpServer {
|
||||
return &McpServer{
|
||||
Enable: false,
|
||||
Servers: make([]*SSEServer, 0),
|
||||
MatchList: make([]*MatchRule, 0),
|
||||
Enable: false,
|
||||
Servers: make([]*SSEServer, 0),
|
||||
MatchList: make([]*MatchRule, 0),
|
||||
EnableUserLevelServer: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,8 +109,8 @@ func validMcpServer(m *McpServer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if m.Enable && m.Redis == nil {
|
||||
return errors.New("redis config cannot be empty when mcp server is enabled")
|
||||
if m.EnableUserLevelServer && m.Redis == nil {
|
||||
return errors.New("redis config cannot be empty when user level server is enabled")
|
||||
}
|
||||
|
||||
// Validate match rule types
|
||||
@@ -149,9 +164,17 @@ func deepCopyMcpServer(mcp *McpServer) (*McpServer, error) {
|
||||
DB: mcp.Redis.DB,
|
||||
}
|
||||
}
|
||||
|
||||
if mcp.Ratelimit != nil {
|
||||
newMcp.Ratelimit = &MCPRatelimitConfig{
|
||||
Limit: mcp.Ratelimit.Limit,
|
||||
Window: mcp.Ratelimit.Window,
|
||||
WhiteList: mcp.Ratelimit.WhiteList,
|
||||
}
|
||||
}
|
||||
newMcp.SsePathSuffix = mcp.SsePathSuffix
|
||||
|
||||
newMcp.EnableUserLevelServer = mcp.EnableUserLevelServer
|
||||
|
||||
if len(mcp.Servers) > 0 {
|
||||
newMcp.Servers = make([]*SSEServer, len(mcp.Servers))
|
||||
for i, server := range mcp.Servers {
|
||||
@@ -352,40 +375,59 @@ func (m *McpServerController) constructMcpServerStruct(mcp *McpServer) string {
|
||||
matchList = fmt.Sprintf("[%s]", strings.Join(matchConfigs, ","))
|
||||
}
|
||||
|
||||
// 构建 Redis 配置
|
||||
redisConfig := "null"
|
||||
if mcp.Redis != nil {
|
||||
redisConfig = fmt.Sprintf(`{
|
||||
"address": "%s",
|
||||
"username": "%s",
|
||||
"password": "%s",
|
||||
"db": %d
|
||||
}`, mcp.Redis.Address, mcp.Redis.Username, mcp.Redis.Password, mcp.Redis.DB)
|
||||
}
|
||||
|
||||
// 构建限流配置
|
||||
rateLimitConfig := "null"
|
||||
if mcp.Ratelimit != nil {
|
||||
whiteList := "[]"
|
||||
if len(mcp.Ratelimit.WhiteList) > 0 {
|
||||
whiteList = fmt.Sprintf(`["%s"]`, strings.Join(mcp.Ratelimit.WhiteList, `","`))
|
||||
}
|
||||
rateLimitConfig = fmt.Sprintf(`{
|
||||
"limit": %d,
|
||||
"window": %d,
|
||||
"white_list": %s
|
||||
}`, mcp.Ratelimit.Limit, mcp.Ratelimit.Window, whiteList)
|
||||
}
|
||||
|
||||
// Build complete configuration structure
|
||||
structFmt := `{
|
||||
return fmt.Sprintf(`{
|
||||
"name": "envoy.filters.http.golang",
|
||||
"typed_config": {
|
||||
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
|
||||
"type_url": "type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config",
|
||||
"value": {
|
||||
"library_id": "mcp-server",
|
||||
"library_path": "/var/lib/istio/envoy/mcp-server.so",
|
||||
"plugin_name": "mcp-server",
|
||||
"library_id": "mcp-session",
|
||||
"library_path": "/var/lib/istio/envoy/golang-filter.so",
|
||||
"plugin_name": "mcp-session",
|
||||
"plugin_config": {
|
||||
"@type": "type.googleapis.com/xds.type.v3.TypedStruct",
|
||||
"value": {
|
||||
"redis": {
|
||||
"address": "%s",
|
||||
"username": "%s",
|
||||
"password": "%s",
|
||||
"db": %d
|
||||
},
|
||||
"redis": %s,
|
||||
"rate_limit": %s,
|
||||
"sse_path_suffix": "%s",
|
||||
"match_list": %s,
|
||||
"servers": %s
|
||||
"servers": %s,
|
||||
"enable_user_level_server": %t
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
return fmt.Sprintf(structFmt,
|
||||
mcp.Redis.Address,
|
||||
mcp.Redis.Username,
|
||||
mcp.Redis.Password,
|
||||
mcp.Redis.DB,
|
||||
}`,
|
||||
redisConfig,
|
||||
rateLimitConfig,
|
||||
mcp.SsePathSuffix,
|
||||
matchList,
|
||||
servers)
|
||||
servers,
|
||||
mcp.EnableUserLevelServer)
|
||||
}
|
||||
|
||||
@@ -45,17 +45,30 @@ func Test_validMcpServer(t *testing.T) {
|
||||
{
|
||||
name: "enabled but no redis config",
|
||||
mcp: &McpServer{
|
||||
Enable: true,
|
||||
Redis: nil,
|
||||
MatchList: []*MatchRule{},
|
||||
Servers: []*SSEServer{},
|
||||
Enable: true,
|
||||
EnableUserLevelServer: false,
|
||||
Redis: nil,
|
||||
MatchList: []*MatchRule{},
|
||||
Servers: []*SSEServer{},
|
||||
},
|
||||
wantErr: errors.New("redis config cannot be empty when mcp server is enabled"),
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "enabled with user level server but no redis config",
|
||||
mcp: &McpServer{
|
||||
Enable: true,
|
||||
EnableUserLevelServer: true,
|
||||
Redis: nil,
|
||||
MatchList: []*MatchRule{},
|
||||
Servers: []*SSEServer{},
|
||||
},
|
||||
wantErr: errors.New("redis config cannot be empty when user level server is enabled"),
|
||||
},
|
||||
{
|
||||
name: "valid config with redis",
|
||||
mcp: &McpServer{
|
||||
Enable: true,
|
||||
Enable: true,
|
||||
EnableUserLevelServer: true,
|
||||
Redis: &RedisConfig{
|
||||
Address: "localhost:6379",
|
||||
Username: "default",
|
||||
|
||||
@@ -36,4 +36,4 @@ RUN if [ "$GOARCH" = "arm64" ]; then \
|
||||
FROM scratch AS output
|
||||
ARG GO_FILTER_NAME
|
||||
ARG GOARCH
|
||||
COPY --from=golang-base /${GO_FILTER_NAME}.so ${GO_FILTER_NAME}_${GOARCH}.so
|
||||
COPY --from=golang-base /${GO_FILTER_NAME}.so golang-filter_${GOARCH}.so
|
||||
@@ -3,9 +3,13 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
|
||||
xds "github.com/cncf/xds/go/xds/type/v3"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/handler"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/registry/nacos"
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/gorm"
|
||||
@@ -13,20 +17,31 @@ import (
|
||||
envoyHttp "github.com/envoyproxy/envoy/contrib/golang/filters/http/source/go/pkg/http"
|
||||
)
|
||||
|
||||
const Name = "mcp-server"
|
||||
const Name = "mcp-session"
|
||||
const Version = "1.0.0"
|
||||
const DefaultServerName = "defaultServer"
|
||||
const ConfigPathSuffix = "/config"
|
||||
|
||||
func init() {
|
||||
envoyHttp.RegisterHttpFilterFactoryAndConfigParser(Name, filterFactory, &parser{})
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
api.LogErrorf("PProf server recovered from panic: %v", r)
|
||||
}
|
||||
}()
|
||||
api.LogError(http.ListenAndServe("localhost:6060", nil).Error())
|
||||
}()
|
||||
}
|
||||
|
||||
type config struct {
|
||||
ssePathSuffix string
|
||||
redisClient *internal.RedisClient
|
||||
servers []*internal.SSEServer
|
||||
defaultServer *internal.SSEServer
|
||||
matchList []internal.MatchRule
|
||||
ssePathSuffix string
|
||||
redisClient *internal.RedisClient
|
||||
servers []*internal.SSEServer
|
||||
defaultServer *internal.SSEServer
|
||||
matchList []internal.MatchRule
|
||||
enableUserLevelServer bool
|
||||
rateLimitConfig *handler.MCPRatelimitConfig
|
||||
}
|
||||
|
||||
func (c *config) Destroy() {
|
||||
@@ -71,22 +86,50 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
||||
}
|
||||
}
|
||||
|
||||
redisConfigMap, ok := v.AsMap()["redis"].(map[string]interface{})
|
||||
// Redis configuration is optional
|
||||
if redisConfigMap, ok := v.AsMap()["redis"].(map[string]interface{}); ok {
|
||||
redisConfig, err := internal.ParseRedisConfig(redisConfigMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse redis config: %w", err)
|
||||
}
|
||||
|
||||
redisClient, err := internal.NewRedisClient(redisConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err)
|
||||
}
|
||||
conf.redisClient = redisClient
|
||||
api.LogDebug("Redis client initialized")
|
||||
} else {
|
||||
api.LogDebug("Redis configuration not provided, running without Redis")
|
||||
}
|
||||
|
||||
enableUserLevelServer, ok := v.AsMap()["enable_user_level_server"].(bool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("redis config is not set")
|
||||
enableUserLevelServer = false
|
||||
if conf.redisClient == nil {
|
||||
return nil, fmt.Errorf("redis configuration is not provided, enable_user_level_server is true")
|
||||
}
|
||||
}
|
||||
conf.enableUserLevelServer = enableUserLevelServer
|
||||
|
||||
redisConfig, err := internal.ParseRedisConfig(redisConfigMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse redis config: %w", err)
|
||||
if rateLimit, ok := v.AsMap()["rate_limit"].(map[string]interface{}); ok {
|
||||
rateLimitConfig := &handler.MCPRatelimitConfig{}
|
||||
if limit, ok := rateLimit["limit"].(float64); ok {
|
||||
rateLimitConfig.Limit = int(limit)
|
||||
}
|
||||
if window, ok := rateLimit["window"].(float64); ok {
|
||||
rateLimitConfig.Window = int(window)
|
||||
}
|
||||
if whiteList, ok := rateLimit["white_list"].([]interface{}); ok {
|
||||
for _, item := range whiteList {
|
||||
if uid, ok := item.(string); ok {
|
||||
rateLimitConfig.Whitelist = append(rateLimitConfig.Whitelist, uid)
|
||||
}
|
||||
}
|
||||
}
|
||||
conf.rateLimitConfig = rateLimitConfig
|
||||
}
|
||||
|
||||
redisClient, err := internal.NewRedisClient(redisConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err)
|
||||
}
|
||||
conf.redisClient = redisClient
|
||||
|
||||
ssePathSuffix, ok := v.AsMap()["sse_path_suffix"].(string)
|
||||
if !ok || ssePathSuffix == "" {
|
||||
return nil, fmt.Errorf("sse path suffix is not set or empty")
|
||||
@@ -127,7 +170,7 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
||||
}
|
||||
api.LogDebug(fmt.Sprintf("Server config: %+v", serverConfig))
|
||||
|
||||
err = server.ParseConfig(serverConfig)
|
||||
err := server.ParseConfig(serverConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse server config: %w", err)
|
||||
}
|
||||
@@ -138,7 +181,7 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
||||
}
|
||||
|
||||
conf.servers = append(conf.servers, internal.NewSSEServer(serverInstance,
|
||||
internal.WithRedisClient(redisClient),
|
||||
internal.WithRedisClient(conf.redisClient),
|
||||
internal.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, ssePathSuffix)),
|
||||
internal.WithMessageEndpoint(serverPath)))
|
||||
api.LogDebug(fmt.Sprintf("Registered MCP Server: %s", serverType))
|
||||
@@ -158,11 +201,14 @@ func (p *parser) Merge(parent interface{}, child interface{}) interface{} {
|
||||
newConfig.ssePathSuffix = childConfig.ssePathSuffix
|
||||
}
|
||||
if childConfig.servers != nil {
|
||||
newConfig.servers = append(newConfig.servers, childConfig.servers...)
|
||||
newConfig.servers = childConfig.servers
|
||||
}
|
||||
if childConfig.defaultServer != nil {
|
||||
newConfig.defaultServer = childConfig.defaultServer
|
||||
}
|
||||
if childConfig.matchList != nil {
|
||||
newConfig.matchList = childConfig.matchList
|
||||
}
|
||||
return &newConfig
|
||||
}
|
||||
|
||||
@@ -172,9 +218,11 @@ func filterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.Strea
|
||||
panic("unexpected config type")
|
||||
}
|
||||
return &filter{
|
||||
callbacks: callbacks,
|
||||
config: conf,
|
||||
stopChan: make(chan struct{}),
|
||||
callbacks: callbacks,
|
||||
config: conf,
|
||||
stopChan: make(chan struct{}),
|
||||
mcpConfigHandler: handler.NewMCPConfigHandler(conf.redisClient, callbacks),
|
||||
mcpRatelimitHandler: handler.NewMCPRatelimitHandler(conf.redisClient, callbacks, conf.rateLimitConfig),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,12 +5,18 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/handler"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
)
|
||||
|
||||
const (
|
||||
RedisNotEnabledResponseBody = "Redis is not enabled, SSE connection is not supported"
|
||||
)
|
||||
|
||||
// The callbacks in the filter, like `DecodeHeaders`, can be implemented on demand.
|
||||
// Because api.PassThroughStreamFilter provides a default implementation.
|
||||
type filter struct {
|
||||
@@ -26,15 +32,20 @@ type filter struct {
|
||||
message bool
|
||||
proxyURL *url.URL
|
||||
skip bool
|
||||
|
||||
userLevelConfig bool
|
||||
mcpConfigHandler *handler.MCPConfigHandler
|
||||
mcpRatelimitHandler *handler.MCPRatelimitHandler
|
||||
}
|
||||
|
||||
type RequestURL struct {
|
||||
method string
|
||||
scheme string
|
||||
host string
|
||||
path string
|
||||
baseURL string
|
||||
parsedURL *url.URL
|
||||
method string
|
||||
scheme string
|
||||
host string
|
||||
path string
|
||||
baseURL string
|
||||
parsedURL *url.URL
|
||||
internalIP bool
|
||||
}
|
||||
|
||||
func NewRequestURL(header api.RequestHeaderMap) *RequestURL {
|
||||
@@ -42,10 +53,11 @@ func NewRequestURL(header api.RequestHeaderMap) *RequestURL {
|
||||
scheme, _ := header.Get(":scheme")
|
||||
host, _ := header.Get(":authority")
|
||||
path, _ := header.Get(":path")
|
||||
internalIP, _ := header.Get("x-envoy-internal")
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
parsedURL, _ := url.Parse(path)
|
||||
api.LogDebugf("RequestURL: method=%s, scheme=%s, host=%s, path=%s", method, scheme, host, path)
|
||||
return &RequestURL{method: method, scheme: scheme, host: host, path: path, baseURL: baseURL, parsedURL: parsedURL}
|
||||
return &RequestURL{method: method, scheme: scheme, host: host, path: path, baseURL: baseURL, parsedURL: parsedURL, internalIP: internalIP == "true"}
|
||||
}
|
||||
|
||||
// Callbacks which are called in request path
|
||||
@@ -71,11 +83,11 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
|
||||
}
|
||||
api.LogDebugf("%s SSE connection started", server.GetServerName())
|
||||
server.SetBaseURL(url.baseURL)
|
||||
return api.LocalReply
|
||||
} else if f.path == server.GetMessageEndpoint() {
|
||||
if url.method != http.MethodPost {
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
// Create a new http.Request object
|
||||
f.req = &http.Request{
|
||||
@@ -97,8 +109,57 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasSuffix(f.path, ConfigPathSuffix) && f.config.enableUserLevelServer {
|
||||
if !url.internalIP {
|
||||
api.LogWarnf("Access denied: non-internal IP address %s", url.parsedURL.String())
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
if strings.HasSuffix(f.path, ConfigPathSuffix) && url.method == http.MethodGet {
|
||||
api.LogDebugf("Handling config request: %s", f.path)
|
||||
f.mcpConfigHandler.HandleConfigRequest(f.path, url.method, []byte{})
|
||||
return api.LocalReply
|
||||
}
|
||||
f.req = &http.Request{
|
||||
Method: url.method,
|
||||
URL: url.parsedURL,
|
||||
}
|
||||
f.userLevelConfig = true
|
||||
if endStream {
|
||||
return api.Continue
|
||||
} else {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(url.parsedURL.Path, f.config.ssePathSuffix) {
|
||||
f.proxyURL = url.parsedURL
|
||||
if f.config.enableUserLevelServer {
|
||||
parts := strings.Split(url.parsedURL.Path, "/")
|
||||
if len(parts) < 3 {
|
||||
api.LogDebugf("Access denied: missing uid in path %s", url.parsedURL.Path)
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "Access denied: missing uid", nil, 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
serverName := parts[1]
|
||||
uid := parts[2]
|
||||
// Get encoded config
|
||||
encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
|
||||
if err != nil {
|
||||
api.LogWarnf("Access denied: no valid config found for uid %s", uid)
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
} else if encodedConfig != "" {
|
||||
header.Set("x-higress-mcpserver-config", encodedConfig)
|
||||
api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid)
|
||||
} else {
|
||||
api.LogDebugf("Empty config found for %s:%s", serverName, uid)
|
||||
if !f.mcpRatelimitHandler.HandleRatelimit(url.parsedURL.Path, url.method, []byte{}) {
|
||||
return api.LocalReply
|
||||
}
|
||||
}
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
@@ -112,7 +173,6 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
|
||||
f.serverName = f.config.defaultServer.GetServerName()
|
||||
body := "SSE connection create"
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
|
||||
f.config.defaultServer.SetBaseURL(url.baseURL)
|
||||
}
|
||||
return api.LocalReply
|
||||
}
|
||||
@@ -138,6 +198,11 @@ func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.Statu
|
||||
}
|
||||
}
|
||||
return api.StopAndBuffer
|
||||
} else if f.userLevelConfig {
|
||||
// Handle config POST request
|
||||
api.LogDebugf("Handling config request: %s", f.path)
|
||||
f.mcpConfigHandler.HandleConfigRequest(f.path, f.req.Method, buffer.Bytes())
|
||||
return api.LocalReply
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
@@ -149,11 +214,15 @@ func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api
|
||||
return api.Continue
|
||||
}
|
||||
if f.serverName != "" {
|
||||
header.Set("Content-Type", "text/event-stream")
|
||||
header.Set("Cache-Control", "no-cache")
|
||||
header.Set("Connection", "keep-alive")
|
||||
header.Set("Access-Control-Allow-Origin", "*")
|
||||
header.Del("Content-Length")
|
||||
if f.config.redisClient != nil {
|
||||
header.Set("Content-Type", "text/event-stream")
|
||||
header.Set("Cache-Control", "no-cache")
|
||||
header.Set("Connection", "keep-alive")
|
||||
header.Set("Access-Control-Allow-Origin", "*")
|
||||
header.Del("Content-Length")
|
||||
} else {
|
||||
header.Set("Content-Length", strconv.Itoa(len(RedisNotEnabledResponseBody)))
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
return api.Continue
|
||||
@@ -168,7 +237,7 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu
|
||||
if !endStream {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
if f.proxyURL != nil {
|
||||
if f.proxyURL != nil && f.config.redisClient != nil {
|
||||
sessionID := f.proxyURL.Query().Get("sessionId")
|
||||
if sessionID != "" {
|
||||
channel := internal.GetSSEChannelName(sessionID)
|
||||
@@ -181,21 +250,26 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu
|
||||
}
|
||||
|
||||
if f.serverName != "" {
|
||||
// handle specific server
|
||||
for _, server := range f.config.servers {
|
||||
if f.serverName == server.GetServerName() {
|
||||
if f.config.redisClient != nil {
|
||||
// handle specific server
|
||||
for _, server := range f.config.servers {
|
||||
if f.serverName == server.GetServerName() {
|
||||
buffer.Reset()
|
||||
server.HandleSSE(f.callbacks, f.stopChan)
|
||||
return api.Running
|
||||
}
|
||||
}
|
||||
// handle default server
|
||||
if f.serverName == f.config.defaultServer.GetServerName() {
|
||||
buffer.Reset()
|
||||
server.HandleSSE(f.callbacks, f.stopChan)
|
||||
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
|
||||
return api.Running
|
||||
}
|
||||
return api.Continue
|
||||
} else {
|
||||
buffer.SetString(RedisNotEnabledResponseBody)
|
||||
return api.Continue
|
||||
}
|
||||
// handle default server
|
||||
if f.serverName == f.config.defaultServer.GetServerName() {
|
||||
buffer.Reset()
|
||||
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
|
||||
return api.Running
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
@@ -136,12 +136,8 @@ github.com/deckarep/golang-set v1.7.1 h1:SCQV0S6gTtp6itiFrTqI+pfmJ4LN85S1YzhDf9r
|
||||
github.com/deckarep/golang-set v1.7.1/go.mod h1:93vsz/8Wt4joVM7c2AVqh+YRMiUSc14yDtF28KmMOgQ=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/envoyproxy/envoy v1.32.3 h1:eftH199KwYfyBTtm4reeEzsWTqraACEaTQ6efl31v0I=
|
||||
github.com/envoyproxy/envoy v1.32.3/go.mod h1:KGS+IUehDX1mSIdqodPTWskKOo7bZMLLy3GHxvOKcJk=
|
||||
github.com/envoyproxy/envoy v1.33.1-0.20250325161043-11ab50a29d99 h1:jih/Ieb7BFgVCStgvY5fXQ3mI9ByOt4wfwUF0d7qmqI=
|
||||
github.com/envoyproxy/envoy v1.33.1-0.20250325161043-11ab50a29d99/go.mod h1:x7d0dNbE0xGuDBUkBg19VGCgnPQ+lJ2k8lDzDzKExow=
|
||||
github.com/envoyproxy/envoy v1.33.2 h1:k3ChySbVo4HejvbDRxkgRroUnj6TZZpXPJJ0UGaZkXs=
|
||||
github.com/envoyproxy/envoy v1.33.2/go.mod h1:faFqv1XeNGX/ph6Zto5Culdcpk4Klxp730Q6XhWarV4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||
@@ -285,6 +281,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/luoxiner/nacos-sdk-go/v2 v2.2.9-40 h1:nzRTBplC0riQqQwEHZThw5H4/TH5LgWTQTm6A7t1lpY=
|
||||
github.com/luoxiner/nacos-sdk-go/v2 v2.2.9-40/go.mod h1:9FKXl6FqOiVmm72i8kADtbeK71egyG9y3uRDBg41tpQ=
|
||||
github.com/mark3labs/mcp-go v0.12.0 h1:Pue1Tdwqcz77GHq18uzgmLT3wmeDUxXUSAqSwhGLhVo=
|
||||
github.com/mark3labs/mcp-go v0.12.0/go.mod h1:cjMlBU0cv/cj9kjlgmRhoJ5JREdS7YX83xeIG9Ko/jE=
|
||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||
@@ -302,8 +300,6 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY
|
||||
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc=
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/nacos-group/nacos-sdk-go/v2 v2.2.9 h1:etzCMnB9EBeSKfaDIOe8zH4HO/8fycpc6s0AmXCrmAw=
|
||||
github.com/nacos-group/nacos-sdk-go/v2 v2.2.9/go.mod h1:9FKXl6FqOiVmm72i8kADtbeK71egyG9y3uRDBg41tpQ=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
|
||||
153
plugins/golang-filter/mcp-server/handler/config_handler.go
Normal file
153
plugins/golang-filter/mcp-server/handler/config_handler.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
)
|
||||
|
||||
// MCPConfigHandler handles configuration requests for MCP server
|
||||
type MCPConfigHandler struct {
|
||||
configStore ConfigStore
|
||||
callbacks api.FilterCallbackHandler
|
||||
}
|
||||
|
||||
// NewMCPConfigHandler creates a new instance of MCP configuration handler
|
||||
func NewMCPConfigHandler(redisClient *internal.RedisClient, callbacks api.FilterCallbackHandler) *MCPConfigHandler {
|
||||
return &MCPConfigHandler{
|
||||
configStore: NewRedisConfigStore(redisClient),
|
||||
callbacks: callbacks,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleConfigRequest processes configuration requests
|
||||
func (h *MCPConfigHandler) HandleConfigRequest(path string, method string, body []byte) bool {
|
||||
// Check if it's a configuration request
|
||||
if !strings.HasSuffix(path, "/config") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Extract serverName and uid from path
|
||||
pathParts := strings.Split(strings.TrimSuffix(path, "/config"), "/")
|
||||
if len(pathParts) < 2 {
|
||||
h.sendErrorResponse(http.StatusBadRequest, "INVALID_PATH", "Invalid path format")
|
||||
return true
|
||||
}
|
||||
uid := pathParts[len(pathParts)-1]
|
||||
serverName := pathParts[len(pathParts)-2]
|
||||
|
||||
switch method {
|
||||
case http.MethodGet:
|
||||
return h.handleGetConfig(serverName, uid)
|
||||
case http.MethodPost:
|
||||
return h.handleStoreConfig(serverName, uid, body)
|
||||
default:
|
||||
h.sendErrorResponse(http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "Method not allowed")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// handleGetConfig handles configuration retrieval requests
|
||||
func (h *MCPConfigHandler) handleGetConfig(serverName string, uid string) bool {
|
||||
config, err := h.configStore.GetConfig(serverName, uid)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to get config for server %s, uid %s: %v", serverName, uid, err)
|
||||
h.sendErrorResponse(http.StatusInternalServerError, "CONFIG_ERROR", fmt.Sprintf("Failed to get configuration: %s", err.Error()))
|
||||
return true
|
||||
}
|
||||
|
||||
response := struct {
|
||||
Success bool `json:"success"`
|
||||
Config map[string]string `json:"config"`
|
||||
}{
|
||||
Success: true,
|
||||
Config: config,
|
||||
}
|
||||
|
||||
responseBytes, _ := json.Marshal(response)
|
||||
h.callbacks.DecoderFilterCallbacks().SendLocalReply(
|
||||
http.StatusOK,
|
||||
string(responseBytes),
|
||||
nil, 0, "",
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
// handleStoreConfig handles configuration storage requests
|
||||
func (h *MCPConfigHandler) handleStoreConfig(serverName string, uid string, body []byte) bool {
|
||||
// Parse request body
|
||||
var requestBody struct {
|
||||
Config map[string]string `json:"config"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &requestBody); err != nil {
|
||||
api.LogErrorf("Invalid request format for server %s, uid %s: %v", serverName, uid, err)
|
||||
h.sendErrorResponse(http.StatusBadRequest, "INVALID_REQUEST", fmt.Sprintf("Invalid request format: %s", err.Error()))
|
||||
return true
|
||||
}
|
||||
|
||||
if requestBody.Config == nil {
|
||||
h.sendErrorResponse(http.StatusBadRequest, "INVALID_REQUEST", "Config cannot be null")
|
||||
return true
|
||||
}
|
||||
|
||||
response, err := h.configStore.StoreConfig(serverName, uid, requestBody.Config)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to store config for server %s, uid %s: %v", serverName, uid, err)
|
||||
h.sendErrorResponse(http.StatusInternalServerError, "CONFIG_ERROR", fmt.Sprintf("Failed to store configuration: %s", err.Error()))
|
||||
return true
|
||||
}
|
||||
|
||||
responseBytes, _ := json.Marshal(response)
|
||||
h.callbacks.DecoderFilterCallbacks().SendLocalReply(
|
||||
http.StatusOK,
|
||||
string(responseBytes),
|
||||
nil, 0, "",
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
// sendErrorResponse sends an error response with the specified status, code and message
|
||||
func (h *MCPConfigHandler) sendErrorResponse(status int, code string, message string) {
|
||||
response := &ConfigResponse{
|
||||
Success: false,
|
||||
Error: &struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}{
|
||||
Code: code,
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
responseBytes, _ := json.Marshal(response)
|
||||
h.callbacks.DecoderFilterCallbacks().SendLocalReply(
|
||||
status,
|
||||
string(responseBytes),
|
||||
nil, 0, "",
|
||||
)
|
||||
}
|
||||
|
||||
// GetEncodedConfig retrieves and encodes the configuration for a given server and uid
|
||||
func (h *MCPConfigHandler) GetEncodedConfig(serverName string, uid string) (string, error) {
|
||||
conf, err := h.configStore.GetConfig(serverName, uid)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get config: %w", err)
|
||||
}
|
||||
|
||||
// Check if config exists and is not empty
|
||||
if len(conf) > 0 {
|
||||
// Convert config map to JSON string
|
||||
configBytes, err := json.Marshal(conf)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
// Encode JSON string to base64
|
||||
return base64.StdEncoding.EncodeToString(configBytes), nil
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
105
plugins/golang-filter/mcp-server/handler/config_store.go
Normal file
105
plugins/golang-filter/mcp-server/handler/config_store.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
)
|
||||
|
||||
const (
|
||||
configExpiry = 7 * 24 * time.Hour
|
||||
)
|
||||
|
||||
// GetConfigStoreKey returns the Redis channel name for the given session ID
|
||||
func GetConfigStoreKey(serverName string, uid string) string {
|
||||
return fmt.Sprintf("mcp-server-config:%s:%s", serverName, uid)
|
||||
}
|
||||
|
||||
// ConfigResponse represents the response structure for configuration operations
|
||||
type ConfigResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Error *struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ConfigStore defines the interface for configuration storage operations
|
||||
type ConfigStore interface {
|
||||
// StoreConfig stores user configuration
|
||||
StoreConfig(serverName string, uid string, config map[string]string) (*ConfigResponse, error)
|
||||
// GetConfig retrieves user configuration
|
||||
GetConfig(serverName string, uid string) (map[string]string, error)
|
||||
}
|
||||
|
||||
// RedisConfigStore implements configuration storage using Redis
|
||||
type RedisConfigStore struct {
|
||||
redisClient *internal.RedisClient
|
||||
}
|
||||
|
||||
// NewRedisConfigStore creates a new instance of Redis configuration storage
|
||||
func NewRedisConfigStore(redisClient *internal.RedisClient) ConfigStore {
|
||||
return &RedisConfigStore{
|
||||
redisClient: redisClient,
|
||||
}
|
||||
}
|
||||
|
||||
// StoreConfig stores configuration in Redis
|
||||
func (s *RedisConfigStore) StoreConfig(serverName string, uid string, config map[string]string) (*ConfigResponse, error) {
|
||||
key := GetConfigStoreKey(serverName, uid)
|
||||
|
||||
// Convert config to JSON
|
||||
configBytes, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return &ConfigResponse{
|
||||
Success: false,
|
||||
Error: &struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}{
|
||||
Code: "MARSHAL_ERROR",
|
||||
Message: "Failed to marshal configuration",
|
||||
},
|
||||
}, err
|
||||
}
|
||||
|
||||
// Store in Redis with expiry
|
||||
err = s.redisClient.Set(key, string(configBytes), configExpiry)
|
||||
if err != nil {
|
||||
return &ConfigResponse{
|
||||
Success: false,
|
||||
Error: &struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}{
|
||||
Code: "REDIS_ERROR",
|
||||
Message: "Failed to store configuration in Redis",
|
||||
},
|
||||
}, err
|
||||
}
|
||||
|
||||
return &ConfigResponse{
|
||||
Success: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetConfig retrieves configuration from Redis
|
||||
func (s *RedisConfigStore) GetConfig(serverName string, uid string) (map[string]string, error) {
|
||||
key := GetConfigStoreKey(serverName, uid)
|
||||
|
||||
// Get from Redis
|
||||
value, err := s.redisClient.Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse JSON
|
||||
var config map[string]string
|
||||
if err := json.Unmarshal([]byte(value), &config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
129
plugins/golang-filter/mcp-server/handler/rate_limit_handler.go
Normal file
129
plugins/golang-filter/mcp-server/handler/rate_limit_handler.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
)
|
||||
|
||||
type MCPRatelimitHandler struct {
|
||||
redisClient *internal.RedisClient
|
||||
callbacks api.FilterCallbackHandler
|
||||
limit int // Maximum requests allowed per window
|
||||
window int // Time window in seconds
|
||||
whitelist []string // Whitelist of UIDs that bypass rate limiting
|
||||
}
|
||||
|
||||
// MCPRatelimitConfig is the configuration for the rate limit handler
|
||||
type MCPRatelimitConfig struct {
|
||||
Limit int `json:"limit"`
|
||||
Window int `json:"window"`
|
||||
Whitelist []string `json:"white_list"` // List of UIDs that bypass rate limiting
|
||||
}
|
||||
|
||||
// NewMCPRatelimitHandler creates a new rate limit handler
|
||||
func NewMCPRatelimitHandler(redisClient *internal.RedisClient, callbacks api.FilterCallbackHandler, conf *MCPRatelimitConfig) *MCPRatelimitHandler {
|
||||
if conf == nil {
|
||||
conf = &MCPRatelimitConfig{
|
||||
Limit: 100,
|
||||
Window: int(24 * time.Hour / time.Second), // 24 hours in seconds
|
||||
Whitelist: []string{},
|
||||
}
|
||||
}
|
||||
return &MCPRatelimitHandler{
|
||||
redisClient: redisClient,
|
||||
callbacks: callbacks,
|
||||
limit: conf.Limit,
|
||||
window: conf.Window,
|
||||
whitelist: conf.Whitelist,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// Lua script for rate limiting
|
||||
LimitScript = `
|
||||
local ttl = redis.call('ttl', KEYS[1])
|
||||
if ttl < 0 then
|
||||
redis.call('set', KEYS[1], ARGV[1] - 1, 'EX', ARGV[2])
|
||||
return {ARGV[1], ARGV[1] - 1, ARGV[2]}
|
||||
end
|
||||
return {ARGV[1], redis.call('incrby', KEYS[1], -1), ttl}
|
||||
`
|
||||
)
|
||||
|
||||
type LimitContext struct {
|
||||
Count int // Current request count
|
||||
Remaining int // Remaining requests allowed
|
||||
Reset int // Time until reset in seconds
|
||||
}
|
||||
|
||||
func (h *MCPRatelimitHandler) HandleRatelimit(path string, method string, body []byte) bool {
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) < 3 {
|
||||
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return false
|
||||
}
|
||||
serverName := parts[1]
|
||||
uid := parts[2]
|
||||
|
||||
// Check if the UID is in whitelist
|
||||
for _, whitelistedUID := range h.whitelist {
|
||||
if whitelistedUID == uid {
|
||||
return true // Bypass rate limiting for whitelisted UIDs
|
||||
}
|
||||
}
|
||||
|
||||
// Build rate limit key using serverName, uid, window and limit
|
||||
limitKey := fmt.Sprintf("mcp-server-limit:%s:%s:%d:%d", serverName, uid, h.window, h.limit)
|
||||
keys := []string{limitKey}
|
||||
|
||||
args := []interface{}{h.limit, h.window}
|
||||
|
||||
result, err := h.redisClient.Eval(LimitScript, 1, keys, args)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to check rate limit: %v", err)
|
||||
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusInternalServerError, "", nil, 0, "")
|
||||
return false
|
||||
}
|
||||
|
||||
// Process response
|
||||
resultArray, ok := result.([]interface{})
|
||||
if !ok || len(resultArray) != 3 {
|
||||
api.LogErrorf("Invalid response format: %v", result)
|
||||
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusInternalServerError, "", nil, 0, "")
|
||||
return false
|
||||
}
|
||||
|
||||
context := LimitContext{
|
||||
Count: parseRedisValue(resultArray[0]),
|
||||
Remaining: parseRedisValue(resultArray[1]),
|
||||
Reset: parseRedisValue(resultArray[2]),
|
||||
}
|
||||
|
||||
if context.Remaining < 0 {
|
||||
h.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusTooManyRequests, "", nil, 0, "")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// parseRedisValue converts the value from Redis to an int
|
||||
func parseRedisValue(value interface{}) int {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case string:
|
||||
if i, err := strconv.Atoi(v); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
76
plugins/golang-filter/mcp-server/internal/crypto.go
Normal file
76
plugins/golang-filter/mcp-server/internal/crypto.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Crypto handles encryption and decryption operations using AES-GCM
|
||||
type Crypto struct {
|
||||
gcm cipher.AEAD
|
||||
}
|
||||
|
||||
func NewCrypto(secret string) (*Crypto, error) {
|
||||
if secret == "" {
|
||||
return nil, fmt.Errorf("secret cannot be empty")
|
||||
}
|
||||
|
||||
// Generate a 32-byte key using SHA-256
|
||||
hash := sha256.Sum256([]byte(secret))
|
||||
block, err := aes.NewCipher(hash[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %v", err)
|
||||
}
|
||||
|
||||
// Create GCM mode
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM: %v", err)
|
||||
}
|
||||
|
||||
return &Crypto{gcm: gcm}, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts the plaintext data using AES-GCM
|
||||
func (c *Crypto) Encrypt(plaintext []byte) (string, error) {
|
||||
// Generate random nonce
|
||||
nonce := make([]byte, c.gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("failed to generate nonce: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt and authenticate data
|
||||
ciphertext := c.gcm.Seal(nonce, nonce, plaintext, nil)
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts the encrypted string using AES-GCM
|
||||
func (c *Crypto) Decrypt(encryptedStr string) ([]byte, error) {
|
||||
// Decode base64
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encryptedStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid encrypted data format")
|
||||
}
|
||||
|
||||
// Check if the ciphertext is too short
|
||||
if len(ciphertext) < c.gcm.NonceSize() {
|
||||
return nil, fmt.Errorf("invalid encrypted data length")
|
||||
}
|
||||
|
||||
// Extract nonce and ciphertext
|
||||
nonce := ciphertext[:c.gcm.NonceSize()]
|
||||
ciphertext = ciphertext[c.gcm.NonceSize():]
|
||||
|
||||
// Decrypt and verify data
|
||||
plaintext, err := c.gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decryption failed")
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
@@ -10,35 +10,42 @@ import (
|
||||
)
|
||||
|
||||
type RedisConfig struct {
|
||||
Address string
|
||||
Username string
|
||||
Password string
|
||||
DB int
|
||||
address string
|
||||
username string
|
||||
password string
|
||||
db int
|
||||
secret string // Encryption key
|
||||
}
|
||||
|
||||
func ParseRedisConfig(config map[string]any) (*RedisConfig, error) {
|
||||
// ParseRedisConfig parses Redis configuration from a map
|
||||
func ParseRedisConfig(config map[string]interface{}) (*RedisConfig, error) {
|
||||
c := &RedisConfig{}
|
||||
|
||||
// address is required
|
||||
addr, ok := config["address"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("address is required and must be a string")
|
||||
if addr, ok := config["address"].(string); ok && addr != "" {
|
||||
c.address = addr
|
||||
} else {
|
||||
return nil, fmt.Errorf("address is required and must be a non-empty string")
|
||||
}
|
||||
c.Address = addr
|
||||
|
||||
// username is optional
|
||||
if username, ok := config["username"].(string); ok {
|
||||
c.Username = username
|
||||
c.username = username
|
||||
}
|
||||
|
||||
// password is optional
|
||||
if password, ok := config["password"].(string); ok {
|
||||
c.Password = password
|
||||
c.password = password
|
||||
}
|
||||
|
||||
// db is optional, default to 0
|
||||
if db, ok := config["db"].(int); ok {
|
||||
c.DB = db
|
||||
c.db = db
|
||||
}
|
||||
|
||||
// secret is optional
|
||||
if secret, ok := config["secret"].(string); ok {
|
||||
c.secret = secret
|
||||
}
|
||||
|
||||
return c, nil
|
||||
@@ -50,15 +57,16 @@ type RedisClient struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
config *RedisConfig
|
||||
crypto *Crypto
|
||||
}
|
||||
|
||||
// NewRedisClient creates a new RedisClient instance and establishes a connection to the Redis server
|
||||
func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: config.Address,
|
||||
Username: config.Username,
|
||||
Password: config.Password,
|
||||
DB: config.DB,
|
||||
Addr: config.address,
|
||||
Username: config.username,
|
||||
Password: config.password,
|
||||
DB: config.db,
|
||||
})
|
||||
|
||||
// Ping the Redis server to check the connection
|
||||
@@ -69,11 +77,22 @@ func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
|
||||
api.LogDebugf("Connected to Redis: %s", pong)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var crypto *Crypto
|
||||
if config.secret != "" {
|
||||
crypto, err = NewCrypto(config.secret)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
redisClient := &RedisClient{
|
||||
client: client,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
config: config,
|
||||
crypto: crypto,
|
||||
}
|
||||
|
||||
// Start keep-alive check
|
||||
@@ -117,10 +136,10 @@ func (r *RedisClient) reconnect() error {
|
||||
|
||||
// Create new client
|
||||
r.client = redis.NewClient(&redis.Options{
|
||||
Addr: r.config.Address,
|
||||
Username: r.config.Username,
|
||||
Password: r.config.Password,
|
||||
DB: r.config.DB,
|
||||
Addr: r.config.address,
|
||||
Username: r.config.username,
|
||||
Password: r.config.password,
|
||||
DB: r.config.db,
|
||||
})
|
||||
|
||||
// Test the new connection
|
||||
@@ -150,6 +169,12 @@ func (r *RedisClient) Subscribe(channel string, stopChan chan struct{}, callback
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
api.LogErrorf("Redis Subscribe recovered from panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
pubsub.Close()
|
||||
api.LogDebugf("Closed subscription to channel %s", channel)
|
||||
@@ -184,7 +209,19 @@ func (r *RedisClient) Subscribe(channel string, stopChan chan struct{}, callback
|
||||
|
||||
// Set sets the value of a key in Redis
|
||||
func (r *RedisClient) Set(key string, value string, expiration time.Duration) error {
|
||||
err := r.client.Set(r.ctx, key, value, expiration).Err()
|
||||
var finalValue string
|
||||
if r.crypto != nil {
|
||||
// Encrypt the data
|
||||
encryptedValue, err := r.crypto.Encrypt([]byte(value))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt value: %w", err)
|
||||
}
|
||||
finalValue = encryptedValue
|
||||
} else {
|
||||
finalValue = value
|
||||
}
|
||||
|
||||
err := r.client.Set(r.ctx, key, finalValue, expiration).Err()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set key: %w", err)
|
||||
}
|
||||
@@ -193,13 +230,23 @@ func (r *RedisClient) Set(key string, value string, expiration time.Duration) er
|
||||
|
||||
// Get retrieves the value of a key from Redis
|
||||
func (r *RedisClient) Get(key string) (string, error) {
|
||||
val, err := r.client.Get(r.ctx, key).Result()
|
||||
value, err := r.client.Get(r.ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return "", fmt.Errorf("key does not exist")
|
||||
} else if err != nil {
|
||||
return "", fmt.Errorf("failed to get key: %w", err)
|
||||
}
|
||||
return val, nil
|
||||
|
||||
if r.crypto != nil {
|
||||
// Decrypt the data
|
||||
decryptedValue, err := r.crypto.Decrypt(value)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decrypt value: %w", err)
|
||||
}
|
||||
return string(decryptedValue), nil
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Close closes the Redis client and stops the keepalive goroutine
|
||||
@@ -207,3 +254,13 @@ func (r *RedisClient) Close() error {
|
||||
r.cancel()
|
||||
return r.client.Close()
|
||||
}
|
||||
|
||||
// Eval executes a Lua script
|
||||
func (r *RedisClient) Eval(script string, numKeys int, keys []string, args []interface{}) (interface{}, error) {
|
||||
result, err := r.client.Eval(r.ctx, script, keys, args...).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute Lua script: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -419,6 +419,16 @@ func (s *MCPServer) HandleMessage(
|
||||
)
|
||||
}
|
||||
return s.handleToolCall(ctx, baseMessage.ID, request)
|
||||
case "":
|
||||
var response mcp.JSONRPCResponse
|
||||
if err := json.Unmarshal(message, &response); err != nil {
|
||||
return createErrorResponse(
|
||||
baseMessage.ID,
|
||||
mcp.INVALID_REQUEST,
|
||||
"Invalid message format",
|
||||
)
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return createErrorResponse(
|
||||
baseMessage.ID,
|
||||
|
||||
@@ -28,10 +28,6 @@ type SSEServer struct {
|
||||
redisClient *RedisClient // Redis client for pub/sub
|
||||
}
|
||||
|
||||
func (s *SSEServer) SetBaseURL(baseURL string) {
|
||||
s.baseURL = baseURL
|
||||
}
|
||||
|
||||
func (s *SSEServer) GetMessageEndpoint() string {
|
||||
return s.messageEndpoint
|
||||
}
|
||||
@@ -148,6 +144,12 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct
|
||||
|
||||
// Start health check handler
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
api.LogErrorf("Health check handler recovered from panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -158,7 +160,15 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct
|
||||
case <-ticker.C:
|
||||
// Send health check message
|
||||
currentTime := time.Now().Format(time.RFC3339)
|
||||
healthCheckEvent := fmt.Sprintf(": ping - %s\n\n", currentTime)
|
||||
pingRequest := mcp.JSONRPCRequest{
|
||||
JSONRPC: mcp.JSONRPC_VERSION,
|
||||
ID: currentTime,
|
||||
Request: mcp.Request{
|
||||
Method: "ping",
|
||||
},
|
||||
}
|
||||
pingData, _ := json.Marshal(pingRequest)
|
||||
healthCheckEvent := fmt.Sprintf("event: message\ndata: %s\n\n", pingData)
|
||||
if err := s.redisClient.Publish(channel, healthCheckEvent); err != nil {
|
||||
api.LogErrorf("Failed to send health check: %v", err)
|
||||
}
|
||||
@@ -202,7 +212,7 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
|
||||
if response != nil {
|
||||
eventData, _ := json.Marshal(response)
|
||||
|
||||
if sessionID != "" {
|
||||
if sessionID != "" && s.redisClient != nil {
|
||||
channel := GetSSEChannelName(sessionID)
|
||||
publishErr := s.redisClient.Publish(channel, fmt.Sprintf("event: message\ndata: %s\n\n", eventData))
|
||||
|
||||
|
||||
@@ -154,6 +154,12 @@ func (c *NacosConfig) NewServer(serverName string) (*internal.MCPServer, error)
|
||||
nacosRegistry.RegisterToolChangeEventListener(&listener)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
api.LogErrorf("NacosToolsListRefresh recovered from panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
if nacosRegistry.refreshToolsList() {
|
||||
resetToolsToMcpServer(mcpServer, nacosRegistry)
|
||||
|
||||
@@ -29,7 +29,7 @@ if [ ! -n "$INNER_GO_FILTER_NAME" ]; then
|
||||
name=${file##*/}
|
||||
echo "🚀 Build Go Filter: $name"
|
||||
GO_FILTER_NAME=${name} GOARCH=${TARGET_ARCH} make build
|
||||
cp ${GO_FILTERS_DIR}/${file}/${name}_${TARGET_ARCH}.so ${OUTPUT_PACKAGE_DIR}
|
||||
cp ${GO_FILTERS_DIR}/${file}/golang-filter_${TARGET_ARCH}.so ${OUTPUT_PACKAGE_DIR}
|
||||
fi
|
||||
done
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user