mirror of
https://github.com/alibaba/higress.git
synced 2026-06-25 18:25:10 +08:00
fix: Refactor MCP Server into MCP Session and MCP Server (#2120)
This commit is contained in:
@@ -61,6 +61,8 @@ type SSEServer struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
// Additional Config parameters for the real MCP server implementation
|
||||
Config map[string]interface{} `json:"config,omitempty"`
|
||||
// The domain list of the SSE server
|
||||
DomainList []string `json:"domain_list,omitempty"`
|
||||
}
|
||||
|
||||
// MatchRule defines a rule for matching requests
|
||||
@@ -179,9 +181,10 @@ func deepCopyMcpServer(mcp *McpServer) (*McpServer, error) {
|
||||
newMcp.Servers = make([]*SSEServer, len(mcp.Servers))
|
||||
for i, server := range mcp.Servers {
|
||||
newServer := &SSEServer{
|
||||
Name: server.Name,
|
||||
Path: server.Path,
|
||||
Type: server.Type,
|
||||
Name: server.Name,
|
||||
Path: server.Path,
|
||||
Type: server.Type,
|
||||
DomainList: server.DomainList,
|
||||
}
|
||||
if server.Config != nil {
|
||||
newServer.Config = make(map[string]interface{})
|
||||
@@ -294,73 +297,88 @@ func (m *McpServerController) ConstructEnvoyFilters() ([]*config.Config, error)
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
mcpStruct := m.constructMcpServerStruct(mcpServer)
|
||||
if mcpStruct == "" {
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
config := &config.Config{
|
||||
Meta: config.Meta{
|
||||
GroupVersionKind: gvk.EnvoyFilter,
|
||||
Name: higressMcpServerEnvoyFilterName,
|
||||
Namespace: namespace,
|
||||
},
|
||||
Spec: &networking.EnvoyFilter{
|
||||
ConfigPatches: []*networking.EnvoyFilter_EnvoyConfigObjectPatch{
|
||||
{
|
||||
ApplyTo: networking.EnvoyFilter_HTTP_FILTER,
|
||||
Match: &networking.EnvoyFilter_EnvoyConfigObjectMatch{
|
||||
Context: networking.EnvoyFilter_GATEWAY,
|
||||
ObjectTypes: &networking.EnvoyFilter_EnvoyConfigObjectMatch_Listener{
|
||||
Listener: &networking.EnvoyFilter_ListenerMatch{
|
||||
FilterChain: &networking.EnvoyFilter_ListenerMatch_FilterChainMatch{
|
||||
Filter: &networking.EnvoyFilter_ListenerMatch_FilterMatch{
|
||||
Name: "envoy.filters.network.http_connection_manager",
|
||||
SubFilter: &networking.EnvoyFilter_ListenerMatch_SubFilterMatch{
|
||||
Name: "envoy.filters.http.cors",
|
||||
// mcp-session envoy filter
|
||||
mcpSessionStruct := m.constructMcpSessionStruct(mcpServer)
|
||||
if mcpSessionStruct != "" {
|
||||
sessionConfig := &config.Config{
|
||||
Meta: config.Meta{
|
||||
GroupVersionKind: gvk.EnvoyFilter,
|
||||
Name: higressMcpServerEnvoyFilterName,
|
||||
Namespace: namespace,
|
||||
},
|
||||
Spec: &networking.EnvoyFilter{
|
||||
ConfigPatches: []*networking.EnvoyFilter_EnvoyConfigObjectPatch{
|
||||
{
|
||||
ApplyTo: networking.EnvoyFilter_HTTP_FILTER,
|
||||
Match: &networking.EnvoyFilter_EnvoyConfigObjectMatch{
|
||||
Context: networking.EnvoyFilter_GATEWAY,
|
||||
ObjectTypes: &networking.EnvoyFilter_EnvoyConfigObjectMatch_Listener{
|
||||
Listener: &networking.EnvoyFilter_ListenerMatch{
|
||||
FilterChain: &networking.EnvoyFilter_ListenerMatch_FilterChainMatch{
|
||||
Filter: &networking.EnvoyFilter_ListenerMatch_FilterMatch{
|
||||
Name: "envoy.filters.network.http_connection_manager",
|
||||
SubFilter: &networking.EnvoyFilter_ListenerMatch_SubFilterMatch{
|
||||
Name: "envoy.filters.http.cors",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Patch: &networking.EnvoyFilter_Patch{
|
||||
Operation: networking.EnvoyFilter_Patch_INSERT_AFTER,
|
||||
Value: util.BuildPatchStruct(mcpStruct),
|
||||
Patch: &networking.EnvoyFilter_Patch{
|
||||
Operation: networking.EnvoyFilter_Patch_INSERT_AFTER,
|
||||
Value: util.BuildPatchStruct(mcpSessionStruct),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
configs = append(configs, sessionConfig)
|
||||
}
|
||||
|
||||
// mcp-server envoy filter
|
||||
mcpServerStruct := m.constructMcpServerStruct(mcpServer)
|
||||
if mcpServerStruct != "" {
|
||||
serverConfig := &config.Config{
|
||||
Meta: config.Meta{
|
||||
GroupVersionKind: gvk.EnvoyFilter,
|
||||
Name: higressMcpServerEnvoyFilterName + "-server",
|
||||
Namespace: namespace,
|
||||
},
|
||||
Spec: &networking.EnvoyFilter{
|
||||
ConfigPatches: []*networking.EnvoyFilter_EnvoyConfigObjectPatch{
|
||||
{
|
||||
ApplyTo: networking.EnvoyFilter_HTTP_FILTER,
|
||||
Match: &networking.EnvoyFilter_EnvoyConfigObjectMatch{
|
||||
Context: networking.EnvoyFilter_GATEWAY,
|
||||
ObjectTypes: &networking.EnvoyFilter_EnvoyConfigObjectMatch_Listener{
|
||||
Listener: &networking.EnvoyFilter_ListenerMatch{
|
||||
FilterChain: &networking.EnvoyFilter_ListenerMatch_FilterChainMatch{
|
||||
Filter: &networking.EnvoyFilter_ListenerMatch_FilterMatch{
|
||||
Name: "envoy.filters.network.http_connection_manager",
|
||||
SubFilter: &networking.EnvoyFilter_ListenerMatch_SubFilterMatch{
|
||||
Name: "envoy.filters.http.router",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Patch: &networking.EnvoyFilter_Patch{
|
||||
Operation: networking.EnvoyFilter_Patch_INSERT_BEFORE,
|
||||
Value: util.BuildPatchStruct(mcpServerStruct),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
configs = append(configs, serverConfig)
|
||||
}
|
||||
|
||||
configs = append(configs, config)
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
func (m *McpServerController) constructMcpServerStruct(mcp *McpServer) string {
|
||||
// Build servers configuration
|
||||
servers := "[]"
|
||||
if len(mcp.Servers) > 0 {
|
||||
serverConfigs := make([]string, len(mcp.Servers))
|
||||
for i, server := range mcp.Servers {
|
||||
serverConfig := fmt.Sprintf(`{
|
||||
"name": "%s",
|
||||
"path": "%s",
|
||||
"type": "%s"`,
|
||||
server.Name, server.Path, server.Type)
|
||||
|
||||
if len(server.Config) > 0 {
|
||||
config, _ := json.Marshal(server.Config)
|
||||
serverConfig += fmt.Sprintf(`,
|
||||
"config": %s`, string(config))
|
||||
}
|
||||
|
||||
serverConfig += "}"
|
||||
serverConfigs[i] = serverConfig
|
||||
}
|
||||
servers = fmt.Sprintf("[%s]", strings.Join(serverConfigs, ","))
|
||||
}
|
||||
|
||||
func (m *McpServerController) constructMcpSessionStruct(mcp *McpServer) string {
|
||||
// Build match_list configuration
|
||||
matchList := "[]"
|
||||
if len(mcp.MatchList) > 0 {
|
||||
@@ -375,7 +393,7 @@ func (m *McpServerController) constructMcpServerStruct(mcp *McpServer) string {
|
||||
matchList = fmt.Sprintf("[%s]", strings.Join(matchConfigs, ","))
|
||||
}
|
||||
|
||||
// 构建 Redis 配置
|
||||
// Build redis configuration
|
||||
redisConfig := "null"
|
||||
if mcp.Redis != nil {
|
||||
redisConfig = fmt.Sprintf(`{
|
||||
@@ -386,7 +404,7 @@ func (m *McpServerController) constructMcpServerStruct(mcp *McpServer) string {
|
||||
}`, mcp.Redis.Address, mcp.Redis.Username, mcp.Redis.Password, mcp.Redis.DB)
|
||||
}
|
||||
|
||||
// 构建限流配置
|
||||
// Build rate limit configuration
|
||||
rateLimitConfig := "null"
|
||||
if mcp.Ratelimit != nil {
|
||||
whiteList := "[]"
|
||||
@@ -417,7 +435,6 @@ func (m *McpServerController) constructMcpServerStruct(mcp *McpServer) string {
|
||||
"rate_limit": %s,
|
||||
"sse_path_suffix": "%s",
|
||||
"match_list": %s,
|
||||
"servers": %s,
|
||||
"enable_user_level_server": %t
|
||||
}
|
||||
}
|
||||
@@ -428,6 +445,53 @@ func (m *McpServerController) constructMcpServerStruct(mcp *McpServer) string {
|
||||
rateLimitConfig,
|
||||
mcp.SsePathSuffix,
|
||||
matchList,
|
||||
servers,
|
||||
mcp.EnableUserLevelServer)
|
||||
}
|
||||
|
||||
func (m *McpServerController) constructMcpServerStruct(mcp *McpServer) string {
|
||||
// Build servers configuration
|
||||
servers := "[]"
|
||||
if len(mcp.Servers) > 0 {
|
||||
serverConfigs := make([]string, len(mcp.Servers))
|
||||
for i, server := range mcp.Servers {
|
||||
serverConfig := fmt.Sprintf(`{
|
||||
"name": "%s",
|
||||
"path": "%s",
|
||||
"type": "%s"`,
|
||||
server.Name, server.Path, server.Type)
|
||||
if len(server.DomainList) > 0 {
|
||||
domainList := fmt.Sprintf(`["%s"]`, strings.Join(server.DomainList, `","`))
|
||||
serverConfig += fmt.Sprintf(`,
|
||||
"domain_list": %s`, domainList)
|
||||
}
|
||||
if len(server.Config) > 0 {
|
||||
config, _ := json.Marshal(server.Config)
|
||||
serverConfig += fmt.Sprintf(`,
|
||||
"config": %s`, string(config))
|
||||
}
|
||||
serverConfig += "}"
|
||||
serverConfigs[i] = serverConfig
|
||||
}
|
||||
servers = fmt.Sprintf("[%s]", strings.Join(serverConfigs, ","))
|
||||
}
|
||||
|
||||
// Build complete configuration structure
|
||||
return fmt.Sprintf(`{
|
||||
"name": "envoy.filters.http.golang",
|
||||
"typed_config": {
|
||||
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
|
||||
"type_url": "type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config",
|
||||
"value": {
|
||||
"library_id": "mcp-server",
|
||||
"library_path": "/var/lib/istio/envoy/golang-filter.so",
|
||||
"plugin_name": "mcp-server",
|
||||
"plugin_config": {
|
||||
"@type": "type.googleapis.com/xds.type.v3.TypedStruct",
|
||||
"value": {
|
||||
"servers": %s
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`, servers)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
package configmap
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
@@ -422,3 +423,311 @@ func TestMcpServerController_AddOrUpdateHigressConfig(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMcpServerController_ValidHigressConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
higressConfig *HigressConfig
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "nil config",
|
||||
higressConfig: nil,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "nil mcp server",
|
||||
higressConfig: &HigressConfig{
|
||||
McpServer: nil,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid config",
|
||||
higressConfig: &HigressConfig{
|
||||
McpServer: &McpServer{
|
||||
Enable: true,
|
||||
Redis: &RedisConfig{
|
||||
Address: "localhost:6379",
|
||||
},
|
||||
MatchList: []*MatchRule{},
|
||||
Servers: []*SSEServer{},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid config - user level server without redis",
|
||||
higressConfig: &HigressConfig{
|
||||
McpServer: &McpServer{
|
||||
Enable: true,
|
||||
EnableUserLevelServer: true,
|
||||
Redis: nil,
|
||||
MatchList: []*MatchRule{},
|
||||
Servers: []*SSEServer{},
|
||||
},
|
||||
},
|
||||
wantErr: errors.New("redis config cannot be empty when user level server is enabled"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := NewMcpServerController("test-namespace")
|
||||
err := m.ValidHigressConfig(tt.higressConfig)
|
||||
assert.Equal(t, tt.wantErr, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMcpServerController_ConstructEnvoyFilters(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mcpServer *McpServer
|
||||
wantConfigs int
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "nil mcp server",
|
||||
mcpServer: nil,
|
||||
wantConfigs: 0,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "disabled mcp server",
|
||||
mcpServer: &McpServer{
|
||||
Enable: false,
|
||||
},
|
||||
wantConfigs: 0,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid mcp server with redis",
|
||||
mcpServer: &McpServer{
|
||||
Enable: true,
|
||||
Redis: &RedisConfig{
|
||||
Address: "localhost:6379",
|
||||
},
|
||||
MatchList: []*MatchRule{},
|
||||
Servers: []*SSEServer{},
|
||||
},
|
||||
wantConfigs: 2, // Both session and server filters
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := NewMcpServerController("test-namespace")
|
||||
m.mcpServer.Store(tt.mcpServer)
|
||||
configs, err := m.ConstructEnvoyFilters()
|
||||
assert.Equal(t, tt.wantErr, err)
|
||||
assert.Equal(t, tt.wantConfigs, len(configs))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMcpServerController_constructMcpSessionStruct(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mcp *McpServer
|
||||
wantJSON string
|
||||
}{
|
||||
{
|
||||
name: "minimal config",
|
||||
mcp: &McpServer{
|
||||
Enable: true,
|
||||
Redis: &RedisConfig{
|
||||
Address: "localhost:6379",
|
||||
},
|
||||
MatchList: []*MatchRule{},
|
||||
Servers: []*SSEServer{},
|
||||
},
|
||||
wantJSON: `{
|
||||
"name": "envoy.filters.http.golang",
|
||||
"typed_config": {
|
||||
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
|
||||
"type_url": "type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config",
|
||||
"value": {
|
||||
"library_id": "mcp-session",
|
||||
"library_path": "/var/lib/istio/envoy/golang-filter.so",
|
||||
"plugin_name": "mcp-session",
|
||||
"plugin_config": {
|
||||
"@type": "type.googleapis.com/xds.type.v3.TypedStruct",
|
||||
"value": {
|
||||
"redis": {
|
||||
"address": "localhost:6379",
|
||||
"username": "",
|
||||
"password": "",
|
||||
"db": 0
|
||||
},
|
||||
"rate_limit": null,
|
||||
"sse_path_suffix": "",
|
||||
"match_list": [],
|
||||
"enable_user_level_server": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "full config",
|
||||
mcp: &McpServer{
|
||||
Enable: true,
|
||||
Redis: &RedisConfig{
|
||||
Address: "localhost:6379",
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
DB: 1,
|
||||
},
|
||||
SsePathSuffix: "/sse",
|
||||
MatchList: []*MatchRule{
|
||||
{
|
||||
MatchRuleDomain: "*",
|
||||
MatchRulePath: "/test",
|
||||
MatchRuleType: "exact",
|
||||
},
|
||||
},
|
||||
EnableUserLevelServer: true,
|
||||
Ratelimit: &MCPRatelimitConfig{
|
||||
Limit: 100,
|
||||
Window: 3600,
|
||||
WhiteList: []string{"user1", "user2"},
|
||||
},
|
||||
},
|
||||
wantJSON: `{
|
||||
"name": "envoy.filters.http.golang",
|
||||
"typed_config": {
|
||||
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
|
||||
"type_url": "type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config",
|
||||
"value": {
|
||||
"library_id": "mcp-session",
|
||||
"library_path": "/var/lib/istio/envoy/golang-filter.so",
|
||||
"plugin_name": "mcp-session",
|
||||
"plugin_config": {
|
||||
"@type": "type.googleapis.com/xds.type.v3.TypedStruct",
|
||||
"value": {
|
||||
"redis": {
|
||||
"address": "localhost:6379",
|
||||
"username": "user",
|
||||
"password": "pass",
|
||||
"db": 1
|
||||
},
|
||||
"rate_limit": {
|
||||
"limit": 100,
|
||||
"window": 3600,
|
||||
"white_list": ["user1","user2"]
|
||||
},
|
||||
"sse_path_suffix": "/sse",
|
||||
"match_list": [{
|
||||
"match_rule_domain": "*",
|
||||
"match_rule_path": "/test",
|
||||
"match_rule_type": "exact"
|
||||
}],
|
||||
"enable_user_level_server": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := NewMcpServerController("test-namespace")
|
||||
got := m.constructMcpSessionStruct(tt.mcp)
|
||||
// Normalize JSON strings for comparison
|
||||
var gotJSON, wantJSON interface{}
|
||||
json.Unmarshal([]byte(got), &gotJSON)
|
||||
json.Unmarshal([]byte(tt.wantJSON), &wantJSON)
|
||||
assert.Equal(t, wantJSON, gotJSON)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMcpServerController_constructMcpServerStruct(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mcp *McpServer
|
||||
wantJSON string
|
||||
}{
|
||||
{
|
||||
name: "no servers",
|
||||
mcp: &McpServer{
|
||||
Servers: []*SSEServer{},
|
||||
},
|
||||
wantJSON: `{
|
||||
"name": "envoy.filters.http.golang",
|
||||
"typed_config": {
|
||||
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
|
||||
"type_url": "type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config",
|
||||
"value": {
|
||||
"library_id": "mcp-server",
|
||||
"library_path": "/var/lib/istio/envoy/golang-filter.so",
|
||||
"plugin_name": "mcp-server",
|
||||
"plugin_config": {
|
||||
"@type": "type.googleapis.com/xds.type.v3.TypedStruct",
|
||||
"value": {
|
||||
"servers": []
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "with servers",
|
||||
mcp: &McpServer{
|
||||
Servers: []*SSEServer{
|
||||
{
|
||||
Name: "test-server",
|
||||
Path: "/test",
|
||||
Type: "test",
|
||||
Config: map[string]interface{}{
|
||||
"key": "value",
|
||||
},
|
||||
DomainList: []string{"example.com"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantJSON: `{
|
||||
"name": "envoy.filters.http.golang",
|
||||
"typed_config": {
|
||||
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
|
||||
"type_url": "type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config",
|
||||
"value": {
|
||||
"library_id": "mcp-server",
|
||||
"library_path": "/var/lib/istio/envoy/golang-filter.so",
|
||||
"plugin_name": "mcp-server",
|
||||
"plugin_config": {
|
||||
"@type": "type.googleapis.com/xds.type.v3.TypedStruct",
|
||||
"value": {
|
||||
"servers": [{
|
||||
"name": "test-server",
|
||||
"path": "/test",
|
||||
"type": "test",
|
||||
"domain_list": ["example.com"],
|
||||
"config": {"key":"value"}
|
||||
}]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := NewMcpServerController("test-namespace")
|
||||
got := m.constructMcpServerStruct(tt.mcp)
|
||||
// Normalize JSON strings for comparison
|
||||
var gotJSON, wantJSON interface{}
|
||||
json.Unmarshal([]byte(got), &gotJSON)
|
||||
json.Unmarshal([]byte(tt.wantJSON), &wantJSON)
|
||||
assert.Equal(t, wantJSON, gotJSON)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ WORKDIR /workspace
|
||||
|
||||
COPY . .
|
||||
|
||||
WORKDIR /workspace/$GO_FILTER_NAME
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN go mod tidy
|
||||
RUN if [ "$GOARCH" = "arm64" ]; then \
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
GO_FILTER_NAME ?= mcp-server
|
||||
GO_FILTER_NAME ?= golang-filter
|
||||
GOPROXY := $(shell go env GOPROXY)
|
||||
GOARCH ?= amd64
|
||||
|
||||
@@ -8,5 +8,5 @@ build:
|
||||
--build-arg GO_FILTER_NAME=${GO_FILTER_NAME} \
|
||||
--build-arg GOARCH=${GOARCH} \
|
||||
-t ${GO_FILTER_NAME} \
|
||||
--output ./${GO_FILTER_NAME} \
|
||||
--output . \
|
||||
.
|
||||
@@ -28,7 +28,7 @@ http_filters:
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config
|
||||
library_id: my-go-filter
|
||||
library_path: "./my-go-filter.so"
|
||||
library_path: "./go-filter.so"
|
||||
plugin_name: my-go-filter
|
||||
plugin_config:
|
||||
"@type": type.googleapis.com/xds.type.v3.TypedStruct
|
||||
@@ -43,5 +43,5 @@ http_filters:
|
||||
使用以下命令可以快速构建 golang filter 插件:
|
||||
|
||||
```bash
|
||||
GO_FILTER_NAME=mcp-server make build
|
||||
make build
|
||||
```
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
module github.com/alibaba/higress/plugins/golang-filter/mcp-server
|
||||
module github.com/alibaba/higress/plugins/golang-filter
|
||||
|
||||
go 1.23
|
||||
|
||||
25
plugins/golang-filter/main.go
Normal file
25
plugins/golang-filter/main.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
mcp_server "github.com/alibaba/higress/plugins/golang-filter/mcp-server"
|
||||
mcp_session "github.com/alibaba/higress/plugins/golang-filter/mcp-session"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
envoyHttp "github.com/envoyproxy/envoy/contrib/golang/filters/http/source/go/pkg/http"
|
||||
)
|
||||
|
||||
func init() {
|
||||
envoyHttp.RegisterHttpFilterFactoryAndConfigParser(mcp_session.Name, mcp_session.FilterFactory, &mcp_session.Parser{})
|
||||
envoyHttp.RegisterHttpFilterFactoryAndConfigParser(mcp_server.Name, mcp_server.FilterFactory, &mcp_server.Parser{})
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
api.LogErrorf("PProf server recovered from panic: %v", r)
|
||||
}
|
||||
}()
|
||||
api.LogError(http.ListenAndServe("localhost:6060", nil).Error())
|
||||
}()
|
||||
}
|
||||
|
||||
func main() {}
|
||||
@@ -1,64 +1,39 @@
|
||||
package main
|
||||
package mcp_server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
|
||||
xds "github.com/cncf/xds/go/xds/type/v3"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/handler"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/registry/nacos"
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/gorm"
|
||||
mcp_session "github.com/alibaba/higress/plugins/golang-filter/mcp-session"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
xds "github.com/cncf/xds/go/xds/type/v3"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
envoyHttp "github.com/envoyproxy/envoy/contrib/golang/filters/http/source/go/pkg/http"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
const Name = "mcp-session"
|
||||
const Name = "mcp-server"
|
||||
const Version = "1.0.0"
|
||||
const DefaultServerName = "defaultServer"
|
||||
const ConfigPathSuffix = "/config"
|
||||
|
||||
func init() {
|
||||
envoyHttp.RegisterHttpFilterFactoryAndConfigParser(Name, filterFactory, &parser{})
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
api.LogErrorf("PProf server recovered from panic: %v", r)
|
||||
}
|
||||
}()
|
||||
api.LogError(http.ListenAndServe("localhost:6060", nil).Error())
|
||||
}()
|
||||
type SSEServerWrapper struct {
|
||||
BaseServer *common.SSEServer
|
||||
DomainList []string
|
||||
}
|
||||
|
||||
type config struct {
|
||||
ssePathSuffix string
|
||||
redisClient *internal.RedisClient
|
||||
servers []*internal.SSEServer
|
||||
defaultServer *internal.SSEServer
|
||||
matchList []internal.MatchRule
|
||||
enableUserLevelServer bool
|
||||
rateLimitConfig *handler.MCPRatelimitConfig
|
||||
servers []*SSEServerWrapper
|
||||
}
|
||||
|
||||
func (c *config) Destroy() {
|
||||
if c.redisClient != nil {
|
||||
api.LogDebug("Closing Redis client")
|
||||
c.redisClient.Close()
|
||||
}
|
||||
for _, server := range c.servers {
|
||||
server.Close()
|
||||
server.BaseServer.Close()
|
||||
}
|
||||
}
|
||||
|
||||
type parser struct {
|
||||
type Parser struct {
|
||||
}
|
||||
|
||||
// Parse the filter configuration
|
||||
func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (interface{}, error) {
|
||||
func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (interface{}, error) {
|
||||
configStruct := &xds.TypedStruct{}
|
||||
if err := any.UnmarshalTo(configStruct); err != nil {
|
||||
return nil, err
|
||||
@@ -66,82 +41,9 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
||||
v := configStruct.Value
|
||||
|
||||
conf := &config{
|
||||
matchList: make([]internal.MatchRule, 0),
|
||||
servers: make([]*internal.SSEServer, 0),
|
||||
servers: make([]*SSEServerWrapper, 0),
|
||||
}
|
||||
|
||||
// Parse match_list if exists
|
||||
if matchList, ok := v.AsMap()["match_list"].([]interface{}); ok {
|
||||
for _, item := range matchList {
|
||||
if ruleMap, ok := item.(map[string]interface{}); ok {
|
||||
rule := internal.MatchRule{}
|
||||
if domain, ok := ruleMap["match_rule_domain"].(string); ok {
|
||||
rule.MatchRuleDomain = domain
|
||||
}
|
||||
if path, ok := ruleMap["match_rule_path"].(string); ok {
|
||||
rule.MatchRulePath = path
|
||||
}
|
||||
if ruleType, ok := ruleMap["match_rule_type"].(string); ok {
|
||||
rule.MatchRuleType = internal.RuleType(ruleType)
|
||||
}
|
||||
conf.matchList = append(conf.matchList, rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Redis configuration is optional
|
||||
if redisConfigMap, ok := v.AsMap()["redis"].(map[string]interface{}); ok {
|
||||
redisConfig, err := internal.ParseRedisConfig(redisConfigMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse redis config: %w", err)
|
||||
}
|
||||
|
||||
redisClient, err := internal.NewRedisClient(redisConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err)
|
||||
}
|
||||
conf.redisClient = redisClient
|
||||
api.LogDebug("Redis client initialized")
|
||||
} else {
|
||||
api.LogDebug("Redis configuration not provided, running without Redis")
|
||||
}
|
||||
|
||||
enableUserLevelServer, ok := v.AsMap()["enable_user_level_server"].(bool)
|
||||
if !ok {
|
||||
enableUserLevelServer = false
|
||||
if conf.redisClient == nil {
|
||||
return nil, fmt.Errorf("redis configuration is not provided, enable_user_level_server is true")
|
||||
}
|
||||
}
|
||||
conf.enableUserLevelServer = enableUserLevelServer
|
||||
|
||||
if rateLimit, ok := v.AsMap()["rate_limit"].(map[string]interface{}); ok {
|
||||
rateLimitConfig := &handler.MCPRatelimitConfig{}
|
||||
if limit, ok := rateLimit["limit"].(float64); ok {
|
||||
rateLimitConfig.Limit = int(limit)
|
||||
}
|
||||
if window, ok := rateLimit["window"].(float64); ok {
|
||||
rateLimitConfig.Window = int(window)
|
||||
}
|
||||
if whiteList, ok := rateLimit["white_list"].([]interface{}); ok {
|
||||
for _, item := range whiteList {
|
||||
if uid, ok := item.(string); ok {
|
||||
rateLimitConfig.Whitelist = append(rateLimitConfig.Whitelist, uid)
|
||||
}
|
||||
}
|
||||
}
|
||||
if errorText, ok := rateLimit["error_text"].(string); ok {
|
||||
rateLimitConfig.ErrorText = errorText
|
||||
}
|
||||
conf.rateLimitConfig = rateLimitConfig
|
||||
}
|
||||
|
||||
ssePathSuffix, ok := v.AsMap()["sse_path_suffix"].(string)
|
||||
if !ok || ssePathSuffix == "" {
|
||||
return nil, fmt.Errorf("sse path suffix is not set or empty")
|
||||
}
|
||||
conf.ssePathSuffix = ssePathSuffix
|
||||
|
||||
serverConfigs, ok := v.AsMap()["servers"].([]interface{})
|
||||
if !ok {
|
||||
api.LogDebug("No servers are configured")
|
||||
@@ -153,19 +55,33 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server config must be an object")
|
||||
}
|
||||
|
||||
serverType, ok := serverConfigMap["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server type is not set")
|
||||
}
|
||||
|
||||
serverPath, ok := serverConfigMap["path"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server %s path is not set", serverType)
|
||||
}
|
||||
|
||||
serverDomainList := []string{}
|
||||
if domainList, ok := serverConfigMap["domain_list"].([]interface{}); ok {
|
||||
for _, domain := range domainList {
|
||||
if domainStr, ok := domain.(string); ok {
|
||||
serverDomainList = append(serverDomainList, domainStr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
serverDomainList = []string{"*"}
|
||||
}
|
||||
|
||||
serverName, ok := serverConfigMap["name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server %s name is not set", serverType)
|
||||
}
|
||||
server := internal.GlobalRegistry.GetServer(serverType)
|
||||
server := common.GlobalRegistry.GetServer(serverType)
|
||||
|
||||
if server == nil {
|
||||
return nil, fmt.Errorf("server %s is not registered", serverType)
|
||||
@@ -186,50 +102,37 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
||||
return nil, fmt.Errorf("failed to initialize DBServer: %w", err)
|
||||
}
|
||||
|
||||
conf.servers = append(conf.servers, internal.NewSSEServer(serverInstance,
|
||||
internal.WithRedisClient(conf.redisClient),
|
||||
internal.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, ssePathSuffix)),
|
||||
internal.WithMessageEndpoint(serverPath)))
|
||||
conf.servers = append(conf.servers, &SSEServerWrapper{
|
||||
BaseServer: common.NewSSEServer(serverInstance,
|
||||
common.WithRedisClient(common.GlobalRedisClient),
|
||||
common.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, mcp_session.GlobalSSEPathSuffix)),
|
||||
common.WithMessageEndpoint(serverPath)),
|
||||
DomainList: serverDomainList,
|
||||
})
|
||||
api.LogDebug(fmt.Sprintf("Registered MCP Server: %s", serverType))
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (p *parser) Merge(parent interface{}, child interface{}) interface{} {
|
||||
func (p *Parser) Merge(parent interface{}, child interface{}) interface{} {
|
||||
parentConfig := parent.(*config)
|
||||
childConfig := child.(*config)
|
||||
|
||||
newConfig := *parentConfig
|
||||
if childConfig.redisClient != nil {
|
||||
newConfig.redisClient = childConfig.redisClient
|
||||
}
|
||||
if childConfig.ssePathSuffix != "" {
|
||||
newConfig.ssePathSuffix = childConfig.ssePathSuffix
|
||||
}
|
||||
if childConfig.servers != nil {
|
||||
newConfig.servers = childConfig.servers
|
||||
}
|
||||
if childConfig.defaultServer != nil {
|
||||
newConfig.defaultServer = childConfig.defaultServer
|
||||
}
|
||||
if childConfig.matchList != nil {
|
||||
newConfig.matchList = childConfig.matchList
|
||||
}
|
||||
return &newConfig
|
||||
}
|
||||
|
||||
func filterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.StreamFilter {
|
||||
func FilterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.StreamFilter {
|
||||
conf, ok := c.(*config)
|
||||
if !ok {
|
||||
panic("unexpected config type")
|
||||
}
|
||||
return &filter{
|
||||
callbacks: callbacks,
|
||||
config: conf,
|
||||
stopChan: make(chan struct{}),
|
||||
mcpConfigHandler: handler.NewMCPConfigHandler(conf.redisClient, callbacks),
|
||||
mcpRatelimitHandler: handler.NewMCPRatelimitHandler(conf.redisClient, callbacks, conf.rateLimitConfig),
|
||||
config: conf,
|
||||
callbacks: callbacks,
|
||||
}
|
||||
}
|
||||
|
||||
func main() {}
|
||||
|
||||
@@ -1,104 +1,41 @@
|
||||
package main
|
||||
package mcp_server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/handler"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
RedisNotEnabledResponseBody = "Redis is not enabled, SSE connection is not supported"
|
||||
)
|
||||
|
||||
// The callbacks in the filter, like `DecodeHeaders`, can be implemented on demand.
|
||||
// Because api.PassThroughStreamFilter provides a default implementation.
|
||||
type filter struct {
|
||||
api.PassThroughStreamFilter
|
||||
|
||||
callbacks api.FilterCallbackHandler
|
||||
path string
|
||||
config *config
|
||||
stopChan chan struct{}
|
||||
|
||||
req *http.Request
|
||||
serverName string
|
||||
message bool
|
||||
proxyURL *url.URL
|
||||
skip bool
|
||||
|
||||
userLevelConfig bool
|
||||
mcpConfigHandler *handler.MCPConfigHandler
|
||||
ratelimit bool
|
||||
mcpRatelimitHandler *handler.MCPRatelimitHandler
|
||||
config *config
|
||||
req *http.Request
|
||||
message bool
|
||||
path string
|
||||
}
|
||||
|
||||
type RequestURL struct {
|
||||
method string
|
||||
scheme string
|
||||
host string
|
||||
path string
|
||||
baseURL string
|
||||
parsedURL *url.URL
|
||||
internalIP bool
|
||||
}
|
||||
|
||||
func NewRequestURL(header api.RequestHeaderMap) *RequestURL {
|
||||
method, _ := header.Get(":method")
|
||||
scheme, _ := header.Get(":scheme")
|
||||
host, _ := header.Get(":authority")
|
||||
path, _ := header.Get(":path")
|
||||
internalIP, _ := header.Get("x-envoy-internal")
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
parsedURL, _ := url.Parse(path)
|
||||
api.LogDebugf("RequestURL: method=%s, scheme=%s, host=%s, path=%s", method, scheme, host, path)
|
||||
return &RequestURL{method: method, scheme: scheme, host: host, path: path, baseURL: baseURL, parsedURL: parsedURL, internalIP: internalIP == "true"}
|
||||
}
|
||||
|
||||
// Callbacks which are called in request path
|
||||
// The endStream is true if the request doesn't have body
|
||||
func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.StatusType {
|
||||
url := NewRequestURL(header)
|
||||
f.path = url.parsedURL.Path
|
||||
|
||||
// Check if request matches any rule in match_list
|
||||
if !internal.IsMatch(f.config.matchList, url.host, f.path) {
|
||||
f.skip = true
|
||||
api.LogDebugf("Request does not match any rule in match_list: %s", url.parsedURL.String())
|
||||
return api.Continue
|
||||
}
|
||||
url := common.NewRequestURL(header)
|
||||
f.path = url.ParsedURL.Path
|
||||
|
||||
for _, server := range f.config.servers {
|
||||
if f.path == server.GetSSEEndpoint() {
|
||||
if url.method != http.MethodGet {
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||
} else {
|
||||
f.serverName = server.GetServerName()
|
||||
body := "SSE connection create"
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
|
||||
}
|
||||
api.LogDebugf("%s SSE connection started", server.GetServerName())
|
||||
return api.LocalReply
|
||||
} else if f.path == server.GetMessageEndpoint() {
|
||||
if url.method != http.MethodPost {
|
||||
if common.MatchDomainList(url.ParsedURL.Host, server.DomainList) && url.ParsedURL.Path == server.BaseServer.GetMessageEndpoint() {
|
||||
if url.Method != http.MethodPost {
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
// Create a new http.Request object
|
||||
f.req = &http.Request{
|
||||
Method: url.method,
|
||||
URL: url.parsedURL,
|
||||
Method: url.Method,
|
||||
URL: url.ParsedURL,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
api.LogDebugf("Message request: %v", url.parsedURL)
|
||||
api.LogDebugf("Message request: %v", url.ParsedURL)
|
||||
// Copy headers from api.RequestHeaderMap to http.Header
|
||||
header.Range(func(key, value string) bool {
|
||||
f.req.Header.Add(key, value)
|
||||
@@ -113,209 +50,33 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
|
||||
}
|
||||
}
|
||||
|
||||
f.req = &http.Request{
|
||||
Method: url.method,
|
||||
URL: url.parsedURL,
|
||||
}
|
||||
|
||||
if strings.HasSuffix(f.path, ConfigPathSuffix) && f.config.enableUserLevelServer {
|
||||
if !url.internalIP {
|
||||
api.LogWarnf("Access denied: non-internal IP address %s", url.parsedURL.String())
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
if strings.HasSuffix(f.path, ConfigPathSuffix) && url.method == http.MethodGet {
|
||||
api.LogDebugf("Handling config request: %s", f.path)
|
||||
f.mcpConfigHandler.HandleConfigRequest(f.req, []byte{})
|
||||
return api.LocalReply
|
||||
}
|
||||
f.userLevelConfig = true
|
||||
if endStream {
|
||||
return api.Continue
|
||||
} else {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(url.parsedURL.Path, f.config.ssePathSuffix) {
|
||||
f.proxyURL = url.parsedURL
|
||||
if f.config.enableUserLevelServer {
|
||||
parts := strings.Split(url.parsedURL.Path, "/")
|
||||
if len(parts) >= 3 {
|
||||
serverName := parts[1]
|
||||
uid := parts[2]
|
||||
// Get encoded config
|
||||
encodedConfig, _ := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
|
||||
if encodedConfig != "" {
|
||||
header.Set("x-higress-mcpserver-config", encodedConfig)
|
||||
api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid)
|
||||
}
|
||||
}
|
||||
f.ratelimit = true
|
||||
}
|
||||
if endStream {
|
||||
return api.Continue
|
||||
} else {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
}
|
||||
|
||||
if url.method != http.MethodGet {
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||
} else {
|
||||
f.config.defaultServer = internal.NewSSEServer(internal.NewMCPServer(DefaultServerName, Version),
|
||||
internal.WithSSEEndpoint(f.config.ssePathSuffix),
|
||||
internal.WithMessageEndpoint(strings.TrimSuffix(url.parsedURL.Path, f.config.ssePathSuffix)),
|
||||
internal.WithRedisClient(f.config.redisClient))
|
||||
f.serverName = f.config.defaultServer.GetServerName()
|
||||
body := "SSE connection create"
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
|
||||
}
|
||||
return api.LocalReply
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
// DecodeData might be called multiple times during handling the request body.
|
||||
// The endStream is true when handling the last piece of the body.
|
||||
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||
if f.skip {
|
||||
return api.Continue
|
||||
}
|
||||
if !endStream {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
if f.message {
|
||||
for _, server := range f.config.servers {
|
||||
if f.path == server.GetMessageEndpoint() {
|
||||
if f.path == server.BaseServer.GetMessageEndpoint() {
|
||||
// Create a response recorder to capture the response
|
||||
recorder := httptest.NewRecorder()
|
||||
// Call the handleMessage method of SSEServer with complete body
|
||||
httpStatus := server.HandleMessage(recorder, f.req, buffer.Bytes())
|
||||
httpStatus := server.BaseServer.HandleMessage(recorder, f.req, buffer.Bytes())
|
||||
f.message = false
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(httpStatus, recorder.Body.String(), recorder.Header(), 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
}
|
||||
} else if f.userLevelConfig {
|
||||
// Handle config POST request
|
||||
api.LogDebugf("Handling config request: %s", f.path)
|
||||
f.mcpConfigHandler.HandleConfigRequest(f.req, buffer.Bytes())
|
||||
return api.LocalReply
|
||||
} else if f.ratelimit {
|
||||
if checkJSONRPCMethod(buffer.Bytes(), "tools/list") {
|
||||
api.LogDebugf("Not a tools call request, skipping ratelimit")
|
||||
return api.Continue
|
||||
}
|
||||
parts := strings.Split(f.req.URL.Path, "/")
|
||||
if len(parts) < 3 {
|
||||
api.LogWarnf("Access denied: no valid uid found")
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
serverName := parts[1]
|
||||
uid := parts[2]
|
||||
encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
|
||||
if err != nil {
|
||||
api.LogWarnf("Access denied: no valid config found for uid %s", uid)
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
} else if encodedConfig == "" && checkJSONRPCMethod(buffer.Bytes(), "tools/call") {
|
||||
api.LogDebugf("Empty config found for %s:%s", serverName, uid)
|
||||
if !f.mcpRatelimitHandler.HandleRatelimit(f.req, buffer.Bytes()) {
|
||||
return api.LocalReply
|
||||
}
|
||||
}
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
// Callbacks which are called in response path
|
||||
// The endStream is true if the response doesn't have body
|
||||
func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api.StatusType {
|
||||
if f.skip {
|
||||
return api.Continue
|
||||
}
|
||||
if f.serverName != "" {
|
||||
if f.config.redisClient != nil {
|
||||
header.Set("Content-Type", "text/event-stream")
|
||||
header.Set("Cache-Control", "no-cache")
|
||||
header.Set("Connection", "keep-alive")
|
||||
header.Set("Access-Control-Allow-Origin", "*")
|
||||
header.Del("Content-Length")
|
||||
} else {
|
||||
header.Set("Content-Length", strconv.Itoa(len(RedisNotEnabledResponseBody)))
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
// EncodeData might be called multiple times during handling the response body.
|
||||
// The endStream is true when handling the last piece of the body.
|
||||
func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||
if f.skip {
|
||||
return api.Continue
|
||||
}
|
||||
if !endStream {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
if f.proxyURL != nil && f.config.redisClient != nil {
|
||||
sessionID := f.proxyURL.Query().Get("sessionId")
|
||||
if sessionID != "" {
|
||||
channel := internal.GetSSEChannelName(sessionID)
|
||||
eventData := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String())
|
||||
publishErr := f.config.redisClient.Publish(channel, eventData)
|
||||
if publishErr != nil {
|
||||
api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if f.serverName != "" {
|
||||
if f.config.redisClient != nil {
|
||||
// handle specific server
|
||||
for _, server := range f.config.servers {
|
||||
if f.serverName == server.GetServerName() {
|
||||
buffer.Reset()
|
||||
server.HandleSSE(f.callbacks, f.stopChan)
|
||||
return api.Running
|
||||
}
|
||||
}
|
||||
// handle default server
|
||||
if f.serverName == f.config.defaultServer.GetServerName() {
|
||||
buffer.Reset()
|
||||
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
|
||||
return api.Running
|
||||
}
|
||||
return api.Continue
|
||||
} else {
|
||||
buffer.SetString(RedisNotEnabledResponseBody)
|
||||
return api.Continue
|
||||
}
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
// OnDestroy stops the goroutine
|
||||
func (f *filter) OnDestroy(reason api.DestroyReason) {
|
||||
api.LogDebugf("OnDestroy: reason=%v", reason)
|
||||
if f.serverName != "" && f.stopChan != nil {
|
||||
select {
|
||||
case <-f.stopChan:
|
||||
return
|
||||
default:
|
||||
api.LogDebug("Stopping SSE connection")
|
||||
close(f.stopChan)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check if the request is a tools/call request
|
||||
func checkJSONRPCMethod(body []byte, method string) bool {
|
||||
var request mcp.CallToolRequest
|
||||
if err := json.Unmarshal(body, &request); err != nil {
|
||||
api.LogWarnf("Failed to unmarshal request body: %v, not a JSON RPC request", err)
|
||||
return true
|
||||
}
|
||||
|
||||
return request.Method == method
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/registry"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/nacos-group/nacos-sdk-go/v2/clients"
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
internal.GlobalRegistry.RegisterServer("nacos-mcp-registry", &NacosConfig{})
|
||||
common.GlobalRegistry.RegisterServer("nacos-mcp-registry", &NacosConfig{})
|
||||
}
|
||||
|
||||
type NacosConfig struct {
|
||||
@@ -28,7 +28,7 @@ type NacosConfig struct {
|
||||
}
|
||||
|
||||
type McpServerToolsChangeListener struct {
|
||||
mcpServer *internal.MCPServer
|
||||
mcpServer *common.MCPServer
|
||||
}
|
||||
|
||||
func (l *McpServerToolsChangeListener) OnToolChanged(reg registry.McpServerRegistry) {
|
||||
@@ -137,8 +137,8 @@ func (c *NacosConfig) ParseConfig(config map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *NacosConfig) NewServer(serverName string) (*internal.MCPServer, error) {
|
||||
mcpServer := internal.NewMCPServer(
|
||||
func (c *NacosConfig) NewServer(serverName string) (*common.MCPServer, error) {
|
||||
mcpServer := common.NewMCPServer(
|
||||
serverName,
|
||||
"1.0.0",
|
||||
)
|
||||
@@ -170,11 +170,11 @@ func (c *NacosConfig) NewServer(serverName string) (*internal.MCPServer, error)
|
||||
return mcpServer, nil
|
||||
}
|
||||
|
||||
func resetToolsToMcpServer(mcpServer *internal.MCPServer, reg registry.McpServerRegistry) {
|
||||
wrappedTools := []internal.ServerTool{}
|
||||
func resetToolsToMcpServer(mcpServer *common.MCPServer, reg registry.McpServerRegistry) {
|
||||
wrappedTools := []common.ServerTool{}
|
||||
tools := reg.ListToolsDesciption()
|
||||
for _, tool := range tools {
|
||||
wrappedTools = append(wrappedTools, internal.ServerTool{
|
||||
wrappedTools = append(wrappedTools, common.ServerTool{
|
||||
Tool: mcp.NewToolWithRawSchema(tool.Name, tool.Description, tool.InputSchema),
|
||||
Handler: registry.HandleRegistryToolsCall(reg),
|
||||
})
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
@@ -204,7 +204,7 @@ func CommonRemoteCall(reg McpServerRegistry, toolName string, parameters map[str
|
||||
return remoteHandle.HandleToolCall(ctx, parameters)
|
||||
}
|
||||
|
||||
func HandleRegistryToolsCall(reg McpServerRegistry) internal.ToolHandlerFunc {
|
||||
func HandleRegistryToolsCall(reg McpServerRegistry) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
arguments := request.Params.Arguments
|
||||
return CommonRemoteCall(reg, request.Params.Name, arguments)
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
const Version = "1.0.0"
|
||||
|
||||
func init() {
|
||||
internal.GlobalRegistry.RegisterServer("database", &DBConfig{})
|
||||
common.GlobalRegistry.RegisterServer("database", &DBConfig{})
|
||||
}
|
||||
|
||||
type DBConfig struct {
|
||||
@@ -41,11 +41,11 @@ func (c *DBConfig) ParseConfig(config map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *DBConfig) NewServer(serverName string) (*internal.MCPServer, error) {
|
||||
mcpServer := internal.NewMCPServer(
|
||||
func (c *DBConfig) NewServer(serverName string) (*common.MCPServer, error) {
|
||||
mcpServer := common.NewMCPServer(
|
||||
serverName,
|
||||
Version,
|
||||
internal.WithInstructions(fmt.Sprintf("This is a %s database server", c.dbType)),
|
||||
common.WithInstructions(fmt.Sprintf("This is a %s database server", c.dbType)),
|
||||
)
|
||||
|
||||
dbClient := NewDBClient(c.dsn, c.dbType, mcpServer.GetDestoryChannel())
|
||||
|
||||
@@ -5,12 +5,12 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// HandleQueryTool handles SQL query execution
|
||||
func HandleQueryTool(dbClient *DBClient) internal.ToolHandlerFunc {
|
||||
func HandleQueryTool(dbClient *DBClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
arguments := request.Params.Arguments
|
||||
message, ok := arguments["sql"].(string)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package internal
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
@@ -1,4 +1,4 @@
|
||||
package internal
|
||||
package common
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
@@ -23,6 +23,27 @@ type MatchRule struct {
|
||||
MatchRuleType RuleType `json:"match_rule_type"` // Type of match rule
|
||||
}
|
||||
|
||||
// ParseMatchList parses the match list from the config
|
||||
func ParseMatchList(matchListConfig []interface{}) []MatchRule {
|
||||
matchList := make([]MatchRule, 0)
|
||||
for _, item := range matchListConfig {
|
||||
if ruleMap, ok := item.(map[string]interface{}); ok {
|
||||
rule := MatchRule{}
|
||||
if domain, ok := ruleMap["match_rule_domain"].(string); ok {
|
||||
rule.MatchRuleDomain = domain
|
||||
}
|
||||
if path, ok := ruleMap["match_rule_path"].(string); ok {
|
||||
rule.MatchRulePath = path
|
||||
}
|
||||
if ruleType, ok := ruleMap["match_rule_type"].(string); ok {
|
||||
rule.MatchRuleType = RuleType(ruleType)
|
||||
}
|
||||
matchList = append(matchList, rule)
|
||||
}
|
||||
}
|
||||
return matchList
|
||||
}
|
||||
|
||||
// convertWildcardToRegex converts wildcard pattern to regex pattern
|
||||
func convertWildcardToRegex(pattern string) string {
|
||||
pattern = regexp.QuoteMeta(pattern)
|
||||
@@ -87,3 +108,13 @@ func IsMatch(rules []MatchRule, host, path string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MatchDomainList checks if the domain matches any of the domains in the list
|
||||
func MatchDomainList(domain string, domainList []string) bool {
|
||||
for _, d := range domainList {
|
||||
if matchDomain(domain, d) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package internal
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
var GlobalRedisClient *RedisClient
|
||||
|
||||
type RedisConfig struct {
|
||||
address string
|
||||
username string
|
||||
@@ -249,6 +251,18 @@ func (r *RedisClient) Get(key string) (string, error) {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Expire sets the expiration time for a key
|
||||
func (r *RedisClient) Expire(key string, expiration time.Duration) error {
|
||||
ok, err := r.client.Expire(r.ctx, key, expiration).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set expiration for key: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
return fmt.Errorf("key does not exist")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the Redis client and stops the keepalive goroutine
|
||||
func (r *RedisClient) Close() error {
|
||||
r.cancel()
|
||||
@@ -1,4 +1,4 @@
|
||||
package internal
|
||||
package common
|
||||
|
||||
var GlobalRegistry = NewServerRegistry()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package internal
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -243,6 +243,7 @@ func (s *MCPServer) HandleMessage(
|
||||
message json.RawMessage,
|
||||
) mcp.JSONRPCMessage {
|
||||
// Add server to context
|
||||
|
||||
ctx = context.WithValue(ctx, serverKey{}, s)
|
||||
|
||||
var baseMessage struct {
|
||||
@@ -1,4 +1,4 @@
|
||||
package internal
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -210,15 +210,7 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
|
||||
var status int
|
||||
// Only send response if there is one (not for notifications)
|
||||
if response != nil {
|
||||
eventData, _ := json.Marshal(response)
|
||||
|
||||
if sessionID != "" && s.redisClient != nil {
|
||||
channel := GetSSEChannelName(sessionID)
|
||||
publishErr := s.redisClient.Publish(channel, fmt.Sprintf("event: message\ndata: %s\n\n", eventData))
|
||||
|
||||
if publishErr != nil {
|
||||
api.LogErrorf("Failed to publish message to Redis: %v", publishErr)
|
||||
}
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
status = http.StatusAccepted
|
||||
} else {
|
||||
30
plugins/golang-filter/mcp-session/common/utils.go
Normal file
30
plugins/golang-filter/mcp-session/common/utils.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
)
|
||||
|
||||
type RequestURL struct {
|
||||
Method string
|
||||
Scheme string
|
||||
Host string
|
||||
Path string
|
||||
BaseURL string
|
||||
ParsedURL *url.URL
|
||||
InternalIP bool
|
||||
}
|
||||
|
||||
func NewRequestURL(header api.RequestHeaderMap) *RequestURL {
|
||||
method, _ := header.Get(":method")
|
||||
scheme, _ := header.Get(":scheme")
|
||||
host, _ := header.Get(":authority")
|
||||
path, _ := header.Get(":path")
|
||||
internalIP, _ := header.Get("x-envoy-internal")
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
parsedURL, _ := url.Parse(path)
|
||||
api.LogDebugf("RequestURL: method=%s, scheme=%s, host=%s, path=%s", method, scheme, host, path)
|
||||
return &RequestURL{Method: method, Scheme: scheme, Host: host, Path: path, BaseURL: baseURL, ParsedURL: parsedURL, InternalIP: internalIP == "true"}
|
||||
}
|
||||
143
plugins/golang-filter/mcp-session/config.go
Normal file
143
plugins/golang-filter/mcp-session/config.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package mcp_session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
_ "net/http/pprof"
|
||||
|
||||
xds "github.com/cncf/xds/go/xds/type/v3"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/handler"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
)
|
||||
|
||||
const Name = "mcp-session"
|
||||
const Version = "1.0.0"
|
||||
const ConfigPathSuffix = "/config"
|
||||
const DefaultServerName = "higress-mcp-server"
|
||||
|
||||
var GlobalSSEPathSuffix = "/sse"
|
||||
|
||||
type config struct {
|
||||
matchList []common.MatchRule
|
||||
enableUserLevelServer bool
|
||||
rateLimitConfig *handler.MCPRatelimitConfig
|
||||
defaultServer *common.SSEServer
|
||||
}
|
||||
|
||||
func (c *config) Destroy() {
|
||||
if common.GlobalRedisClient != nil {
|
||||
api.LogDebug("Closing Redis client")
|
||||
common.GlobalRedisClient.Close()
|
||||
}
|
||||
}
|
||||
|
||||
type Parser struct {
|
||||
}
|
||||
|
||||
// Parse the filter configuration
|
||||
func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (interface{}, error) {
|
||||
configStruct := &xds.TypedStruct{}
|
||||
if err := any.UnmarshalTo(configStruct); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v := configStruct.Value
|
||||
|
||||
conf := &config{
|
||||
matchList: make([]common.MatchRule, 0),
|
||||
}
|
||||
|
||||
// Parse match_list if exists
|
||||
if matchList, ok := v.AsMap()["match_list"].([]interface{}); ok {
|
||||
conf.matchList = common.ParseMatchList(matchList)
|
||||
}
|
||||
|
||||
// Redis configuration is optional
|
||||
if redisConfigMap, ok := v.AsMap()["redis"].(map[string]interface{}); ok {
|
||||
redisConfig, err := common.ParseRedisConfig(redisConfigMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse redis config: %w", err)
|
||||
}
|
||||
|
||||
redisClient, err := common.NewRedisClient(redisConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err)
|
||||
}
|
||||
common.GlobalRedisClient = redisClient
|
||||
api.LogDebug("Redis client initialized")
|
||||
} else {
|
||||
api.LogDebug("Redis configuration not provided, running without Redis")
|
||||
}
|
||||
|
||||
enableUserLevelServer, ok := v.AsMap()["enable_user_level_server"].(bool)
|
||||
if !ok {
|
||||
enableUserLevelServer = false
|
||||
if common.GlobalRedisClient == nil {
|
||||
return nil, fmt.Errorf("redis configuration is not provided, enable_user_level_server is true")
|
||||
}
|
||||
}
|
||||
conf.enableUserLevelServer = enableUserLevelServer
|
||||
|
||||
if rateLimit, ok := v.AsMap()["rate_limit"].(map[string]interface{}); ok {
|
||||
rateLimitConfig := &handler.MCPRatelimitConfig{}
|
||||
if limit, ok := rateLimit["limit"].(float64); ok {
|
||||
rateLimitConfig.Limit = int(limit)
|
||||
}
|
||||
if window, ok := rateLimit["window"].(float64); ok {
|
||||
rateLimitConfig.Window = int(window)
|
||||
}
|
||||
if whiteList, ok := rateLimit["white_list"].([]interface{}); ok {
|
||||
for _, item := range whiteList {
|
||||
if uid, ok := item.(string); ok {
|
||||
rateLimitConfig.Whitelist = append(rateLimitConfig.Whitelist, uid)
|
||||
}
|
||||
}
|
||||
}
|
||||
if errorText, ok := rateLimit["error_text"].(string); ok {
|
||||
rateLimitConfig.ErrorText = errorText
|
||||
}
|
||||
conf.rateLimitConfig = rateLimitConfig
|
||||
}
|
||||
|
||||
ssePathSuffix, ok := v.AsMap()["sse_path_suffix"].(string)
|
||||
if !ok || ssePathSuffix == "" {
|
||||
return nil, fmt.Errorf("sse path suffix is not set or empty")
|
||||
}
|
||||
GlobalSSEPathSuffix = ssePathSuffix
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (p *Parser) Merge(parent interface{}, child interface{}) interface{} {
|
||||
parentConfig := parent.(*config)
|
||||
childConfig := child.(*config)
|
||||
|
||||
newConfig := *parentConfig
|
||||
if childConfig.matchList != nil {
|
||||
newConfig.matchList = childConfig.matchList
|
||||
}
|
||||
newConfig.enableUserLevelServer = childConfig.enableUserLevelServer
|
||||
if childConfig.rateLimitConfig != nil {
|
||||
newConfig.rateLimitConfig = childConfig.rateLimitConfig
|
||||
}
|
||||
if childConfig.defaultServer != nil {
|
||||
newConfig.defaultServer = childConfig.defaultServer
|
||||
}
|
||||
return &newConfig
|
||||
}
|
||||
|
||||
func FilterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.StreamFilter {
|
||||
conf, ok := c.(*config)
|
||||
if !ok {
|
||||
panic("unexpected config type")
|
||||
}
|
||||
return &filter{
|
||||
callbacks: callbacks,
|
||||
config: conf,
|
||||
stopChan: make(chan struct{}),
|
||||
mcpConfigHandler: handler.NewMCPConfigHandler(common.GlobalRedisClient, callbacks),
|
||||
mcpRatelimitHandler: handler.NewMCPRatelimitHandler(common.GlobalRedisClient, callbacks, conf.rateLimitConfig),
|
||||
}
|
||||
}
|
||||
237
plugins/golang-filter/mcp-session/filter.go
Normal file
237
plugins/golang-filter/mcp-session/filter.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package mcp_session
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/handler"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
RedisNotEnabledResponseBody = "Redis is not enabled, SSE connection is not supported"
|
||||
)
|
||||
|
||||
// The callbacks in the filter, like `DecodeHeaders`, can be implemented on demand.
|
||||
// Because api.PassThroughStreamFilter provides a default implementation.
|
||||
type filter struct {
|
||||
api.PassThroughStreamFilter
|
||||
|
||||
callbacks api.FilterCallbackHandler
|
||||
path string
|
||||
config *config
|
||||
stopChan chan struct{}
|
||||
|
||||
req *http.Request
|
||||
serverName string
|
||||
proxyURL *url.URL
|
||||
skip bool
|
||||
|
||||
userLevelConfig bool
|
||||
mcpConfigHandler *handler.MCPConfigHandler
|
||||
ratelimit bool
|
||||
mcpRatelimitHandler *handler.MCPRatelimitHandler
|
||||
}
|
||||
|
||||
// Callbacks which are called in request path
|
||||
// The endStream is true if the request doesn't have body
|
||||
func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.StatusType {
|
||||
url := common.NewRequestURL(header)
|
||||
f.path = url.ParsedURL.Path
|
||||
|
||||
// Check if request matches any rule in match_list
|
||||
if !common.IsMatch(f.config.matchList, url.Host, f.path) {
|
||||
f.skip = true
|
||||
api.LogDebugf("Request does not match any rule in match_list: %s", url.ParsedURL.String())
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
f.req = &http.Request{
|
||||
Method: url.Method,
|
||||
URL: url.ParsedURL,
|
||||
}
|
||||
|
||||
if strings.HasSuffix(f.path, ConfigPathSuffix) && f.config.enableUserLevelServer {
|
||||
if !url.InternalIP {
|
||||
api.LogWarnf("Access denied: non-Internal IP address %s", url.ParsedURL.String())
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
if strings.HasSuffix(f.path, ConfigPathSuffix) && url.Method == http.MethodGet {
|
||||
api.LogDebugf("Handling config request: %s", f.path)
|
||||
f.mcpConfigHandler.HandleConfigRequest(f.req, []byte{})
|
||||
return api.LocalReply
|
||||
}
|
||||
f.userLevelConfig = true
|
||||
if endStream {
|
||||
return api.Continue
|
||||
} else {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix) {
|
||||
f.proxyURL = url.ParsedURL
|
||||
if f.config.enableUserLevelServer {
|
||||
parts := strings.Split(url.ParsedURL.Path, "/")
|
||||
if len(parts) >= 3 {
|
||||
serverName := parts[1]
|
||||
uid := parts[2]
|
||||
// Get encoded config
|
||||
encodedConfig, _ := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
|
||||
if encodedConfig != "" {
|
||||
header.Set("x-higress-mcpserver-config", encodedConfig)
|
||||
api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid)
|
||||
}
|
||||
}
|
||||
f.ratelimit = true
|
||||
}
|
||||
if endStream {
|
||||
return api.Continue
|
||||
} else {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
}
|
||||
|
||||
if url.Method != http.MethodGet {
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||
} else {
|
||||
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
|
||||
common.WithSSEEndpoint(GlobalSSEPathSuffix),
|
||||
common.WithMessageEndpoint(strings.TrimSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix)),
|
||||
common.WithRedisClient(common.GlobalRedisClient))
|
||||
f.serverName = f.config.defaultServer.GetServerName()
|
||||
body := "SSE connection create"
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
|
||||
}
|
||||
return api.LocalReply
|
||||
}
|
||||
|
||||
// DecodeData might be called multiple times during handling the request body.
|
||||
// The endStream is true when handling the last piece of the body.
|
||||
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||
if f.skip {
|
||||
return api.Continue
|
||||
}
|
||||
if !endStream {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
if f.userLevelConfig {
|
||||
// Handle config POST request
|
||||
api.LogDebugf("Handling config request: %s", f.path)
|
||||
f.mcpConfigHandler.HandleConfigRequest(f.req, buffer.Bytes())
|
||||
return api.LocalReply
|
||||
} else if f.ratelimit {
|
||||
if checkJSONRPCMethod(buffer.Bytes(), "tools/list") {
|
||||
api.LogDebugf("Not a tools call request, skipping ratelimit")
|
||||
return api.Continue
|
||||
}
|
||||
parts := strings.Split(f.req.URL.Path, "/")
|
||||
if len(parts) < 3 {
|
||||
api.LogWarnf("Access denied: no valid uid found")
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
}
|
||||
serverName := parts[1]
|
||||
uid := parts[2]
|
||||
encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
|
||||
if err != nil {
|
||||
api.LogWarnf("Access denied: no valid config found for uid %s", uid)
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||
return api.LocalReply
|
||||
} else if encodedConfig == "" && checkJSONRPCMethod(buffer.Bytes(), "tools/call") {
|
||||
api.LogDebugf("Empty config found for %s:%s", serverName, uid)
|
||||
if !f.mcpRatelimitHandler.HandleRatelimit(f.req, buffer.Bytes()) {
|
||||
return api.LocalReply
|
||||
}
|
||||
}
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
// Callbacks which are called in response path
|
||||
// The endStream is true if the response doesn't have body
|
||||
func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api.StatusType {
|
||||
if f.skip {
|
||||
return api.Continue
|
||||
}
|
||||
if f.serverName != "" {
|
||||
if common.GlobalRedisClient != nil {
|
||||
header.Set("Content-Type", "text/event-stream")
|
||||
header.Set("Cache-Control", "no-cache")
|
||||
header.Set("Connection", "keep-alive")
|
||||
header.Set("Access-Control-Allow-Origin", "*")
|
||||
header.Del("Content-Length")
|
||||
} else {
|
||||
header.Set("Content-Length", strconv.Itoa(len(RedisNotEnabledResponseBody)))
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
// EncodeData might be called multiple times during handling the response body.
|
||||
// The endStream is true when handling the last piece of the body.
|
||||
func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||
if f.skip {
|
||||
return api.Continue
|
||||
}
|
||||
if !endStream {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
if f.proxyURL != nil && common.GlobalRedisClient != nil {
|
||||
sessionID := f.proxyURL.Query().Get("sessionId")
|
||||
if sessionID != "" {
|
||||
channel := common.GetSSEChannelName(sessionID)
|
||||
eventData := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String())
|
||||
publishErr := common.GlobalRedisClient.Publish(channel, eventData)
|
||||
if publishErr != nil {
|
||||
api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if f.serverName != "" {
|
||||
if common.GlobalRedisClient != nil {
|
||||
// handle default server
|
||||
buffer.Reset()
|
||||
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
|
||||
return api.Running
|
||||
} else {
|
||||
buffer.SetString(RedisNotEnabledResponseBody)
|
||||
return api.Continue
|
||||
}
|
||||
}
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
// OnDestroy stops the goroutine
|
||||
func (f *filter) OnDestroy(reason api.DestroyReason) {
|
||||
api.LogDebugf("OnDestroy: reason=%v", reason)
|
||||
if f.serverName != "" && f.stopChan != nil {
|
||||
select {
|
||||
case <-f.stopChan:
|
||||
return
|
||||
default:
|
||||
api.LogDebug("Stopping SSE connection")
|
||||
close(f.stopChan)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check if the request is a tools/call request
|
||||
func checkJSONRPCMethod(body []byte, method string) bool {
|
||||
var request mcp.CallToolRequest
|
||||
if err := json.Unmarshal(body, &request); err != nil {
|
||||
api.LogWarnf("Failed to unmarshal request body: %v, not a JSON RPC request", err)
|
||||
return true
|
||||
}
|
||||
|
||||
return request.Method == method
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ type MCPConfigHandler struct {
|
||||
}
|
||||
|
||||
// NewMCPConfigHandler creates a new instance of MCP configuration handler
|
||||
func NewMCPConfigHandler(redisClient *internal.RedisClient, callbacks api.FilterCallbackHandler) *MCPConfigHandler {
|
||||
func NewMCPConfigHandler(redisClient *common.RedisClient, callbacks api.FilterCallbackHandler) *MCPConfigHandler {
|
||||
return &MCPConfigHandler{
|
||||
configStore: NewRedisConfigStore(redisClient),
|
||||
callbacks: callbacks,
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -36,11 +36,11 @@ type ConfigStore interface {
|
||||
|
||||
// RedisConfigStore implements configuration storage using Redis
|
||||
type RedisConfigStore struct {
|
||||
redisClient *internal.RedisClient
|
||||
redisClient *common.RedisClient
|
||||
}
|
||||
|
||||
// NewRedisConfigStore creates a new instance of Redis configuration storage
|
||||
func NewRedisConfigStore(redisClient *internal.RedisClient) ConfigStore {
|
||||
func NewRedisConfigStore(redisClient *common.RedisClient) ConfigStore {
|
||||
return &RedisConfigStore{
|
||||
redisClient: redisClient,
|
||||
}
|
||||
@@ -101,5 +101,11 @@ func (s *RedisConfigStore) GetConfig(serverName string, uid string) (map[string]
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Refresh TTL
|
||||
if err := s.redisClient.Expire(key, configExpiry); err != nil {
|
||||
// Log error but don't fail the request
|
||||
fmt.Printf("Failed to refresh TTL for key %s: %v\n", key, err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
@@ -8,13 +8,13 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
type MCPRatelimitHandler struct {
|
||||
redisClient *internal.RedisClient
|
||||
redisClient *common.RedisClient
|
||||
callbacks api.FilterCallbackHandler
|
||||
limit int // Maximum requests allowed per window
|
||||
window int // Time window in seconds
|
||||
@@ -31,7 +31,7 @@ type MCPRatelimitConfig struct {
|
||||
}
|
||||
|
||||
// NewMCPRatelimitHandler creates a new rate limit handler
|
||||
func NewMCPRatelimitHandler(redisClient *internal.RedisClient, callbacks api.FilterCallbackHandler, conf *MCPRatelimitConfig) *MCPRatelimitHandler {
|
||||
func NewMCPRatelimitHandler(redisClient *common.RedisClient, callbacks api.FilterCallbackHandler, conf *MCPRatelimitConfig) *MCPRatelimitHandler {
|
||||
if conf == nil {
|
||||
conf = &MCPRatelimitConfig{
|
||||
Limit: 100,
|
||||
@@ -16,24 +16,16 @@
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
INNER_GO_FILTER_NAME=${GO_FILTER_NAME-""}
|
||||
OUTPUT_PACKAGE_DIR=${OUTPUT_PACKAGE_DIR:-"../../external/package/"}
|
||||
|
||||
cd ./plugins/golang-filter
|
||||
if [ ! -n "$INNER_GO_FILTER_NAME" ]; then
|
||||
GO_FILTERS_DIR=$(pwd)
|
||||
echo "🚀 Build all Go Filters under folder of $GO_FILTERS_DIR"
|
||||
for file in `ls $GO_FILTERS_DIR`
|
||||
do
|
||||
if [ -d $GO_FILTERS_DIR/$file ]; then
|
||||
name=${file##*/}
|
||||
echo "🚀 Build Go Filter: $name"
|
||||
GO_FILTER_NAME=${name} GOARCH=${TARGET_ARCH} make build
|
||||
cp ${GO_FILTERS_DIR}/${file}/golang-filter_${TARGET_ARCH}.so ${OUTPUT_PACKAGE_DIR}
|
||||
fi
|
||||
done
|
||||
else
|
||||
echo "🚀 Build Go Filter: $INNER_GO_FILTER_NAME"
|
||||
GO_FILTER_NAME=${INNER_GO_FILTER_NAME} make build
|
||||
fi
|
||||
cd plugins/golang-filter
|
||||
|
||||
GO_FILTERS_DIR=$(pwd)
|
||||
|
||||
echo "🚀 Build Go Filter"
|
||||
|
||||
GOARCH=${TARGET_ARCH} make build
|
||||
|
||||
cp ${GO_FILTERS_DIR}/golang-filter_${TARGET_ARCH}.so ${OUTPUT_PACKAGE_DIR}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user