mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 12:47:28 +08:00
fix: Refactor MCP Server into MCP Session and MCP Server (#2120)
This commit is contained in:
@@ -61,6 +61,8 @@ type SSEServer struct {
|
|||||||
Type string `json:"type,omitempty"`
|
Type string `json:"type,omitempty"`
|
||||||
// Additional Config parameters for the real MCP server implementation
|
// Additional Config parameters for the real MCP server implementation
|
||||||
Config map[string]interface{} `json:"config,omitempty"`
|
Config map[string]interface{} `json:"config,omitempty"`
|
||||||
|
// The domain list of the SSE server
|
||||||
|
DomainList []string `json:"domain_list,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// MatchRule defines a rule for matching requests
|
// MatchRule defines a rule for matching requests
|
||||||
@@ -179,9 +181,10 @@ func deepCopyMcpServer(mcp *McpServer) (*McpServer, error) {
|
|||||||
newMcp.Servers = make([]*SSEServer, len(mcp.Servers))
|
newMcp.Servers = make([]*SSEServer, len(mcp.Servers))
|
||||||
for i, server := range mcp.Servers {
|
for i, server := range mcp.Servers {
|
||||||
newServer := &SSEServer{
|
newServer := &SSEServer{
|
||||||
Name: server.Name,
|
Name: server.Name,
|
||||||
Path: server.Path,
|
Path: server.Path,
|
||||||
Type: server.Type,
|
Type: server.Type,
|
||||||
|
DomainList: server.DomainList,
|
||||||
}
|
}
|
||||||
if server.Config != nil {
|
if server.Config != nil {
|
||||||
newServer.Config = make(map[string]interface{})
|
newServer.Config = make(map[string]interface{})
|
||||||
@@ -294,73 +297,88 @@ func (m *McpServerController) ConstructEnvoyFilters() ([]*config.Config, error)
|
|||||||
return configs, nil
|
return configs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
mcpStruct := m.constructMcpServerStruct(mcpServer)
|
// mcp-session envoy filter
|
||||||
if mcpStruct == "" {
|
mcpSessionStruct := m.constructMcpSessionStruct(mcpServer)
|
||||||
return configs, nil
|
if mcpSessionStruct != "" {
|
||||||
}
|
sessionConfig := &config.Config{
|
||||||
|
Meta: config.Meta{
|
||||||
config := &config.Config{
|
GroupVersionKind: gvk.EnvoyFilter,
|
||||||
Meta: config.Meta{
|
Name: higressMcpServerEnvoyFilterName,
|
||||||
GroupVersionKind: gvk.EnvoyFilter,
|
Namespace: namespace,
|
||||||
Name: higressMcpServerEnvoyFilterName,
|
},
|
||||||
Namespace: namespace,
|
Spec: &networking.EnvoyFilter{
|
||||||
},
|
ConfigPatches: []*networking.EnvoyFilter_EnvoyConfigObjectPatch{
|
||||||
Spec: &networking.EnvoyFilter{
|
{
|
||||||
ConfigPatches: []*networking.EnvoyFilter_EnvoyConfigObjectPatch{
|
ApplyTo: networking.EnvoyFilter_HTTP_FILTER,
|
||||||
{
|
Match: &networking.EnvoyFilter_EnvoyConfigObjectMatch{
|
||||||
ApplyTo: networking.EnvoyFilter_HTTP_FILTER,
|
Context: networking.EnvoyFilter_GATEWAY,
|
||||||
Match: &networking.EnvoyFilter_EnvoyConfigObjectMatch{
|
ObjectTypes: &networking.EnvoyFilter_EnvoyConfigObjectMatch_Listener{
|
||||||
Context: networking.EnvoyFilter_GATEWAY,
|
Listener: &networking.EnvoyFilter_ListenerMatch{
|
||||||
ObjectTypes: &networking.EnvoyFilter_EnvoyConfigObjectMatch_Listener{
|
FilterChain: &networking.EnvoyFilter_ListenerMatch_FilterChainMatch{
|
||||||
Listener: &networking.EnvoyFilter_ListenerMatch{
|
Filter: &networking.EnvoyFilter_ListenerMatch_FilterMatch{
|
||||||
FilterChain: &networking.EnvoyFilter_ListenerMatch_FilterChainMatch{
|
Name: "envoy.filters.network.http_connection_manager",
|
||||||
Filter: &networking.EnvoyFilter_ListenerMatch_FilterMatch{
|
SubFilter: &networking.EnvoyFilter_ListenerMatch_SubFilterMatch{
|
||||||
Name: "envoy.filters.network.http_connection_manager",
|
Name: "envoy.filters.http.cors",
|
||||||
SubFilter: &networking.EnvoyFilter_ListenerMatch_SubFilterMatch{
|
},
|
||||||
Name: "envoy.filters.http.cors",
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
Patch: &networking.EnvoyFilter_Patch{
|
||||||
Patch: &networking.EnvoyFilter_Patch{
|
Operation: networking.EnvoyFilter_Patch_INSERT_AFTER,
|
||||||
Operation: networking.EnvoyFilter_Patch_INSERT_AFTER,
|
Value: util.BuildPatchStruct(mcpSessionStruct),
|
||||||
Value: util.BuildPatchStruct(mcpStruct),
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
}
|
||||||
|
configs = append(configs, sessionConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// mcp-server envoy filter
|
||||||
|
mcpServerStruct := m.constructMcpServerStruct(mcpServer)
|
||||||
|
if mcpServerStruct != "" {
|
||||||
|
serverConfig := &config.Config{
|
||||||
|
Meta: config.Meta{
|
||||||
|
GroupVersionKind: gvk.EnvoyFilter,
|
||||||
|
Name: higressMcpServerEnvoyFilterName + "-server",
|
||||||
|
Namespace: namespace,
|
||||||
|
},
|
||||||
|
Spec: &networking.EnvoyFilter{
|
||||||
|
ConfigPatches: []*networking.EnvoyFilter_EnvoyConfigObjectPatch{
|
||||||
|
{
|
||||||
|
ApplyTo: networking.EnvoyFilter_HTTP_FILTER,
|
||||||
|
Match: &networking.EnvoyFilter_EnvoyConfigObjectMatch{
|
||||||
|
Context: networking.EnvoyFilter_GATEWAY,
|
||||||
|
ObjectTypes: &networking.EnvoyFilter_EnvoyConfigObjectMatch_Listener{
|
||||||
|
Listener: &networking.EnvoyFilter_ListenerMatch{
|
||||||
|
FilterChain: &networking.EnvoyFilter_ListenerMatch_FilterChainMatch{
|
||||||
|
Filter: &networking.EnvoyFilter_ListenerMatch_FilterMatch{
|
||||||
|
Name: "envoy.filters.network.http_connection_manager",
|
||||||
|
SubFilter: &networking.EnvoyFilter_ListenerMatch_SubFilterMatch{
|
||||||
|
Name: "envoy.filters.http.router",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Patch: &networking.EnvoyFilter_Patch{
|
||||||
|
Operation: networking.EnvoyFilter_Patch_INSERT_BEFORE,
|
||||||
|
Value: util.BuildPatchStruct(mcpServerStruct),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
configs = append(configs, serverConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
configs = append(configs, config)
|
|
||||||
return configs, nil
|
return configs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *McpServerController) constructMcpServerStruct(mcp *McpServer) string {
|
func (m *McpServerController) constructMcpSessionStruct(mcp *McpServer) string {
|
||||||
// Build servers configuration
|
|
||||||
servers := "[]"
|
|
||||||
if len(mcp.Servers) > 0 {
|
|
||||||
serverConfigs := make([]string, len(mcp.Servers))
|
|
||||||
for i, server := range mcp.Servers {
|
|
||||||
serverConfig := fmt.Sprintf(`{
|
|
||||||
"name": "%s",
|
|
||||||
"path": "%s",
|
|
||||||
"type": "%s"`,
|
|
||||||
server.Name, server.Path, server.Type)
|
|
||||||
|
|
||||||
if len(server.Config) > 0 {
|
|
||||||
config, _ := json.Marshal(server.Config)
|
|
||||||
serverConfig += fmt.Sprintf(`,
|
|
||||||
"config": %s`, string(config))
|
|
||||||
}
|
|
||||||
|
|
||||||
serverConfig += "}"
|
|
||||||
serverConfigs[i] = serverConfig
|
|
||||||
}
|
|
||||||
servers = fmt.Sprintf("[%s]", strings.Join(serverConfigs, ","))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build match_list configuration
|
// Build match_list configuration
|
||||||
matchList := "[]"
|
matchList := "[]"
|
||||||
if len(mcp.MatchList) > 0 {
|
if len(mcp.MatchList) > 0 {
|
||||||
@@ -375,7 +393,7 @@ func (m *McpServerController) constructMcpServerStruct(mcp *McpServer) string {
|
|||||||
matchList = fmt.Sprintf("[%s]", strings.Join(matchConfigs, ","))
|
matchList = fmt.Sprintf("[%s]", strings.Join(matchConfigs, ","))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建 Redis 配置
|
// Build redis configuration
|
||||||
redisConfig := "null"
|
redisConfig := "null"
|
||||||
if mcp.Redis != nil {
|
if mcp.Redis != nil {
|
||||||
redisConfig = fmt.Sprintf(`{
|
redisConfig = fmt.Sprintf(`{
|
||||||
@@ -386,7 +404,7 @@ func (m *McpServerController) constructMcpServerStruct(mcp *McpServer) string {
|
|||||||
}`, mcp.Redis.Address, mcp.Redis.Username, mcp.Redis.Password, mcp.Redis.DB)
|
}`, mcp.Redis.Address, mcp.Redis.Username, mcp.Redis.Password, mcp.Redis.DB)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建限流配置
|
// Build rate limit configuration
|
||||||
rateLimitConfig := "null"
|
rateLimitConfig := "null"
|
||||||
if mcp.Ratelimit != nil {
|
if mcp.Ratelimit != nil {
|
||||||
whiteList := "[]"
|
whiteList := "[]"
|
||||||
@@ -417,7 +435,6 @@ func (m *McpServerController) constructMcpServerStruct(mcp *McpServer) string {
|
|||||||
"rate_limit": %s,
|
"rate_limit": %s,
|
||||||
"sse_path_suffix": "%s",
|
"sse_path_suffix": "%s",
|
||||||
"match_list": %s,
|
"match_list": %s,
|
||||||
"servers": %s,
|
|
||||||
"enable_user_level_server": %t
|
"enable_user_level_server": %t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -428,6 +445,53 @@ func (m *McpServerController) constructMcpServerStruct(mcp *McpServer) string {
|
|||||||
rateLimitConfig,
|
rateLimitConfig,
|
||||||
mcp.SsePathSuffix,
|
mcp.SsePathSuffix,
|
||||||
matchList,
|
matchList,
|
||||||
servers,
|
|
||||||
mcp.EnableUserLevelServer)
|
mcp.EnableUserLevelServer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *McpServerController) constructMcpServerStruct(mcp *McpServer) string {
|
||||||
|
// Build servers configuration
|
||||||
|
servers := "[]"
|
||||||
|
if len(mcp.Servers) > 0 {
|
||||||
|
serverConfigs := make([]string, len(mcp.Servers))
|
||||||
|
for i, server := range mcp.Servers {
|
||||||
|
serverConfig := fmt.Sprintf(`{
|
||||||
|
"name": "%s",
|
||||||
|
"path": "%s",
|
||||||
|
"type": "%s"`,
|
||||||
|
server.Name, server.Path, server.Type)
|
||||||
|
if len(server.DomainList) > 0 {
|
||||||
|
domainList := fmt.Sprintf(`["%s"]`, strings.Join(server.DomainList, `","`))
|
||||||
|
serverConfig += fmt.Sprintf(`,
|
||||||
|
"domain_list": %s`, domainList)
|
||||||
|
}
|
||||||
|
if len(server.Config) > 0 {
|
||||||
|
config, _ := json.Marshal(server.Config)
|
||||||
|
serverConfig += fmt.Sprintf(`,
|
||||||
|
"config": %s`, string(config))
|
||||||
|
}
|
||||||
|
serverConfig += "}"
|
||||||
|
serverConfigs[i] = serverConfig
|
||||||
|
}
|
||||||
|
servers = fmt.Sprintf("[%s]", strings.Join(serverConfigs, ","))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build complete configuration structure
|
||||||
|
return fmt.Sprintf(`{
|
||||||
|
"name": "envoy.filters.http.golang",
|
||||||
|
"typed_config": {
|
||||||
|
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
|
||||||
|
"type_url": "type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config",
|
||||||
|
"value": {
|
||||||
|
"library_id": "mcp-server",
|
||||||
|
"library_path": "/var/lib/istio/envoy/golang-filter.so",
|
||||||
|
"plugin_name": "mcp-server",
|
||||||
|
"plugin_config": {
|
||||||
|
"@type": "type.googleapis.com/xds.type.v3.TypedStruct",
|
||||||
|
"value": {
|
||||||
|
"servers": %s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`, servers)
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
package configmap
|
package configmap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -422,3 +423,311 @@ func TestMcpServerController_AddOrUpdateHigressConfig(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMcpServerController_ValidHigressConfig(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
higressConfig *HigressConfig
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil config",
|
||||||
|
higressConfig: nil,
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil mcp server",
|
||||||
|
higressConfig: &HigressConfig{
|
||||||
|
McpServer: nil,
|
||||||
|
},
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid config",
|
||||||
|
higressConfig: &HigressConfig{
|
||||||
|
McpServer: &McpServer{
|
||||||
|
Enable: true,
|
||||||
|
Redis: &RedisConfig{
|
||||||
|
Address: "localhost:6379",
|
||||||
|
},
|
||||||
|
MatchList: []*MatchRule{},
|
||||||
|
Servers: []*SSEServer{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid config - user level server without redis",
|
||||||
|
higressConfig: &HigressConfig{
|
||||||
|
McpServer: &McpServer{
|
||||||
|
Enable: true,
|
||||||
|
EnableUserLevelServer: true,
|
||||||
|
Redis: nil,
|
||||||
|
MatchList: []*MatchRule{},
|
||||||
|
Servers: []*SSEServer{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: errors.New("redis config cannot be empty when user level server is enabled"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m := NewMcpServerController("test-namespace")
|
||||||
|
err := m.ValidHigressConfig(tt.higressConfig)
|
||||||
|
assert.Equal(t, tt.wantErr, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcpServerController_ConstructEnvoyFilters(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
mcpServer *McpServer
|
||||||
|
wantConfigs int
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil mcp server",
|
||||||
|
mcpServer: nil,
|
||||||
|
wantConfigs: 0,
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disabled mcp server",
|
||||||
|
mcpServer: &McpServer{
|
||||||
|
Enable: false,
|
||||||
|
},
|
||||||
|
wantConfigs: 0,
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid mcp server with redis",
|
||||||
|
mcpServer: &McpServer{
|
||||||
|
Enable: true,
|
||||||
|
Redis: &RedisConfig{
|
||||||
|
Address: "localhost:6379",
|
||||||
|
},
|
||||||
|
MatchList: []*MatchRule{},
|
||||||
|
Servers: []*SSEServer{},
|
||||||
|
},
|
||||||
|
wantConfigs: 2, // Both session and server filters
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m := NewMcpServerController("test-namespace")
|
||||||
|
m.mcpServer.Store(tt.mcpServer)
|
||||||
|
configs, err := m.ConstructEnvoyFilters()
|
||||||
|
assert.Equal(t, tt.wantErr, err)
|
||||||
|
assert.Equal(t, tt.wantConfigs, len(configs))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcpServerController_constructMcpSessionStruct(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
mcp *McpServer
|
||||||
|
wantJSON string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "minimal config",
|
||||||
|
mcp: &McpServer{
|
||||||
|
Enable: true,
|
||||||
|
Redis: &RedisConfig{
|
||||||
|
Address: "localhost:6379",
|
||||||
|
},
|
||||||
|
MatchList: []*MatchRule{},
|
||||||
|
Servers: []*SSEServer{},
|
||||||
|
},
|
||||||
|
wantJSON: `{
|
||||||
|
"name": "envoy.filters.http.golang",
|
||||||
|
"typed_config": {
|
||||||
|
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
|
||||||
|
"type_url": "type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config",
|
||||||
|
"value": {
|
||||||
|
"library_id": "mcp-session",
|
||||||
|
"library_path": "/var/lib/istio/envoy/golang-filter.so",
|
||||||
|
"plugin_name": "mcp-session",
|
||||||
|
"plugin_config": {
|
||||||
|
"@type": "type.googleapis.com/xds.type.v3.TypedStruct",
|
||||||
|
"value": {
|
||||||
|
"redis": {
|
||||||
|
"address": "localhost:6379",
|
||||||
|
"username": "",
|
||||||
|
"password": "",
|
||||||
|
"db": 0
|
||||||
|
},
|
||||||
|
"rate_limit": null,
|
||||||
|
"sse_path_suffix": "",
|
||||||
|
"match_list": [],
|
||||||
|
"enable_user_level_server": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "full config",
|
||||||
|
mcp: &McpServer{
|
||||||
|
Enable: true,
|
||||||
|
Redis: &RedisConfig{
|
||||||
|
Address: "localhost:6379",
|
||||||
|
Username: "user",
|
||||||
|
Password: "pass",
|
||||||
|
DB: 1,
|
||||||
|
},
|
||||||
|
SsePathSuffix: "/sse",
|
||||||
|
MatchList: []*MatchRule{
|
||||||
|
{
|
||||||
|
MatchRuleDomain: "*",
|
||||||
|
MatchRulePath: "/test",
|
||||||
|
MatchRuleType: "exact",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
EnableUserLevelServer: true,
|
||||||
|
Ratelimit: &MCPRatelimitConfig{
|
||||||
|
Limit: 100,
|
||||||
|
Window: 3600,
|
||||||
|
WhiteList: []string{"user1", "user2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantJSON: `{
|
||||||
|
"name": "envoy.filters.http.golang",
|
||||||
|
"typed_config": {
|
||||||
|
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
|
||||||
|
"type_url": "type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config",
|
||||||
|
"value": {
|
||||||
|
"library_id": "mcp-session",
|
||||||
|
"library_path": "/var/lib/istio/envoy/golang-filter.so",
|
||||||
|
"plugin_name": "mcp-session",
|
||||||
|
"plugin_config": {
|
||||||
|
"@type": "type.googleapis.com/xds.type.v3.TypedStruct",
|
||||||
|
"value": {
|
||||||
|
"redis": {
|
||||||
|
"address": "localhost:6379",
|
||||||
|
"username": "user",
|
||||||
|
"password": "pass",
|
||||||
|
"db": 1
|
||||||
|
},
|
||||||
|
"rate_limit": {
|
||||||
|
"limit": 100,
|
||||||
|
"window": 3600,
|
||||||
|
"white_list": ["user1","user2"]
|
||||||
|
},
|
||||||
|
"sse_path_suffix": "/sse",
|
||||||
|
"match_list": [{
|
||||||
|
"match_rule_domain": "*",
|
||||||
|
"match_rule_path": "/test",
|
||||||
|
"match_rule_type": "exact"
|
||||||
|
}],
|
||||||
|
"enable_user_level_server": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m := NewMcpServerController("test-namespace")
|
||||||
|
got := m.constructMcpSessionStruct(tt.mcp)
|
||||||
|
// Normalize JSON strings for comparison
|
||||||
|
var gotJSON, wantJSON interface{}
|
||||||
|
json.Unmarshal([]byte(got), &gotJSON)
|
||||||
|
json.Unmarshal([]byte(tt.wantJSON), &wantJSON)
|
||||||
|
assert.Equal(t, wantJSON, gotJSON)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcpServerController_constructMcpServerStruct(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
mcp *McpServer
|
||||||
|
wantJSON string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no servers",
|
||||||
|
mcp: &McpServer{
|
||||||
|
Servers: []*SSEServer{},
|
||||||
|
},
|
||||||
|
wantJSON: `{
|
||||||
|
"name": "envoy.filters.http.golang",
|
||||||
|
"typed_config": {
|
||||||
|
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
|
||||||
|
"type_url": "type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config",
|
||||||
|
"value": {
|
||||||
|
"library_id": "mcp-server",
|
||||||
|
"library_path": "/var/lib/istio/envoy/golang-filter.so",
|
||||||
|
"plugin_name": "mcp-server",
|
||||||
|
"plugin_config": {
|
||||||
|
"@type": "type.googleapis.com/xds.type.v3.TypedStruct",
|
||||||
|
"value": {
|
||||||
|
"servers": []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with servers",
|
||||||
|
mcp: &McpServer{
|
||||||
|
Servers: []*SSEServer{
|
||||||
|
{
|
||||||
|
Name: "test-server",
|
||||||
|
Path: "/test",
|
||||||
|
Type: "test",
|
||||||
|
Config: map[string]interface{}{
|
||||||
|
"key": "value",
|
||||||
|
},
|
||||||
|
DomainList: []string{"example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantJSON: `{
|
||||||
|
"name": "envoy.filters.http.golang",
|
||||||
|
"typed_config": {
|
||||||
|
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
|
||||||
|
"type_url": "type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config",
|
||||||
|
"value": {
|
||||||
|
"library_id": "mcp-server",
|
||||||
|
"library_path": "/var/lib/istio/envoy/golang-filter.so",
|
||||||
|
"plugin_name": "mcp-server",
|
||||||
|
"plugin_config": {
|
||||||
|
"@type": "type.googleapis.com/xds.type.v3.TypedStruct",
|
||||||
|
"value": {
|
||||||
|
"servers": [{
|
||||||
|
"name": "test-server",
|
||||||
|
"path": "/test",
|
||||||
|
"type": "test",
|
||||||
|
"domain_list": ["example.com"],
|
||||||
|
"config": {"key":"value"}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m := NewMcpServerController("test-namespace")
|
||||||
|
got := m.constructMcpServerStruct(tt.mcp)
|
||||||
|
// Normalize JSON strings for comparison
|
||||||
|
var gotJSON, wantJSON interface{}
|
||||||
|
json.Unmarshal([]byte(got), &gotJSON)
|
||||||
|
json.Unmarshal([]byte(tt.wantJSON), &wantJSON)
|
||||||
|
assert.Equal(t, wantJSON, gotJSON)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ WORKDIR /workspace
|
|||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
WORKDIR /workspace/$GO_FILTER_NAME
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN go mod tidy
|
RUN go mod tidy
|
||||||
RUN if [ "$GOARCH" = "arm64" ]; then \
|
RUN if [ "$GOARCH" = "arm64" ]; then \
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
GO_FILTER_NAME ?= mcp-server
|
GO_FILTER_NAME ?= golang-filter
|
||||||
GOPROXY := $(shell go env GOPROXY)
|
GOPROXY := $(shell go env GOPROXY)
|
||||||
GOARCH ?= amd64
|
GOARCH ?= amd64
|
||||||
|
|
||||||
@@ -8,5 +8,5 @@ build:
|
|||||||
--build-arg GO_FILTER_NAME=${GO_FILTER_NAME} \
|
--build-arg GO_FILTER_NAME=${GO_FILTER_NAME} \
|
||||||
--build-arg GOARCH=${GOARCH} \
|
--build-arg GOARCH=${GOARCH} \
|
||||||
-t ${GO_FILTER_NAME} \
|
-t ${GO_FILTER_NAME} \
|
||||||
--output ./${GO_FILTER_NAME} \
|
--output . \
|
||||||
.
|
.
|
||||||
@@ -28,7 +28,7 @@ http_filters:
|
|||||||
typed_config:
|
typed_config:
|
||||||
"@type": type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config
|
"@type": type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config
|
||||||
library_id: my-go-filter
|
library_id: my-go-filter
|
||||||
library_path: "./my-go-filter.so"
|
library_path: "./go-filter.so"
|
||||||
plugin_name: my-go-filter
|
plugin_name: my-go-filter
|
||||||
plugin_config:
|
plugin_config:
|
||||||
"@type": type.googleapis.com/xds.type.v3.TypedStruct
|
"@type": type.googleapis.com/xds.type.v3.TypedStruct
|
||||||
@@ -43,5 +43,5 @@ http_filters:
|
|||||||
使用以下命令可以快速构建 golang filter 插件:
|
使用以下命令可以快速构建 golang filter 插件:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
GO_FILTER_NAME=mcp-server make build
|
make build
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
module github.com/alibaba/higress/plugins/golang-filter/mcp-server
|
module github.com/alibaba/higress/plugins/golang-filter
|
||||||
|
|
||||||
go 1.23
|
go 1.23
|
||||||
|
|
||||||
25
plugins/golang-filter/main.go
Normal file
25
plugins/golang-filter/main.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
mcp_server "github.com/alibaba/higress/plugins/golang-filter/mcp-server"
|
||||||
|
mcp_session "github.com/alibaba/higress/plugins/golang-filter/mcp-session"
|
||||||
|
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||||
|
envoyHttp "github.com/envoyproxy/envoy/contrib/golang/filters/http/source/go/pkg/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
envoyHttp.RegisterHttpFilterFactoryAndConfigParser(mcp_session.Name, mcp_session.FilterFactory, &mcp_session.Parser{})
|
||||||
|
envoyHttp.RegisterHttpFilterFactoryAndConfigParser(mcp_server.Name, mcp_server.FilterFactory, &mcp_server.Parser{})
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
api.LogErrorf("PProf server recovered from panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
api.LogError(http.ListenAndServe("localhost:6060", nil).Error())
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {}
|
||||||
@@ -1,64 +1,39 @@
|
|||||||
package main
|
package mcp_server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"net/http"
|
|
||||||
_ "net/http/pprof"
|
|
||||||
|
|
||||||
xds "github.com/cncf/xds/go/xds/type/v3"
|
|
||||||
"google.golang.org/protobuf/types/known/anypb"
|
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/handler"
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
|
||||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/registry/nacos"
|
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/registry/nacos"
|
||||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/gorm"
|
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/gorm"
|
||||||
|
mcp_session "github.com/alibaba/higress/plugins/golang-filter/mcp-session"
|
||||||
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||||
|
xds "github.com/cncf/xds/go/xds/type/v3"
|
||||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||||
envoyHttp "github.com/envoyproxy/envoy/contrib/golang/filters/http/source/go/pkg/http"
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
)
|
)
|
||||||
|
|
||||||
const Name = "mcp-session"
|
const Name = "mcp-server"
|
||||||
const Version = "1.0.0"
|
const Version = "1.0.0"
|
||||||
const DefaultServerName = "defaultServer"
|
|
||||||
const ConfigPathSuffix = "/config"
|
|
||||||
|
|
||||||
func init() {
|
type SSEServerWrapper struct {
|
||||||
envoyHttp.RegisterHttpFilterFactoryAndConfigParser(Name, filterFactory, &parser{})
|
BaseServer *common.SSEServer
|
||||||
go func() {
|
DomainList []string
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
api.LogErrorf("PProf server recovered from panic: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
api.LogError(http.ListenAndServe("localhost:6060", nil).Error())
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type config struct {
|
type config struct {
|
||||||
ssePathSuffix string
|
servers []*SSEServerWrapper
|
||||||
redisClient *internal.RedisClient
|
|
||||||
servers []*internal.SSEServer
|
|
||||||
defaultServer *internal.SSEServer
|
|
||||||
matchList []internal.MatchRule
|
|
||||||
enableUserLevelServer bool
|
|
||||||
rateLimitConfig *handler.MCPRatelimitConfig
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *config) Destroy() {
|
func (c *config) Destroy() {
|
||||||
if c.redisClient != nil {
|
|
||||||
api.LogDebug("Closing Redis client")
|
|
||||||
c.redisClient.Close()
|
|
||||||
}
|
|
||||||
for _, server := range c.servers {
|
for _, server := range c.servers {
|
||||||
server.Close()
|
server.BaseServer.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type parser struct {
|
type Parser struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the filter configuration
|
func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (interface{}, error) {
|
||||||
func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (interface{}, error) {
|
|
||||||
configStruct := &xds.TypedStruct{}
|
configStruct := &xds.TypedStruct{}
|
||||||
if err := any.UnmarshalTo(configStruct); err != nil {
|
if err := any.UnmarshalTo(configStruct); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -66,82 +41,9 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
|||||||
v := configStruct.Value
|
v := configStruct.Value
|
||||||
|
|
||||||
conf := &config{
|
conf := &config{
|
||||||
matchList: make([]internal.MatchRule, 0),
|
servers: make([]*SSEServerWrapper, 0),
|
||||||
servers: make([]*internal.SSEServer, 0),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse match_list if exists
|
|
||||||
if matchList, ok := v.AsMap()["match_list"].([]interface{}); ok {
|
|
||||||
for _, item := range matchList {
|
|
||||||
if ruleMap, ok := item.(map[string]interface{}); ok {
|
|
||||||
rule := internal.MatchRule{}
|
|
||||||
if domain, ok := ruleMap["match_rule_domain"].(string); ok {
|
|
||||||
rule.MatchRuleDomain = domain
|
|
||||||
}
|
|
||||||
if path, ok := ruleMap["match_rule_path"].(string); ok {
|
|
||||||
rule.MatchRulePath = path
|
|
||||||
}
|
|
||||||
if ruleType, ok := ruleMap["match_rule_type"].(string); ok {
|
|
||||||
rule.MatchRuleType = internal.RuleType(ruleType)
|
|
||||||
}
|
|
||||||
conf.matchList = append(conf.matchList, rule)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Redis configuration is optional
|
|
||||||
if redisConfigMap, ok := v.AsMap()["redis"].(map[string]interface{}); ok {
|
|
||||||
redisConfig, err := internal.ParseRedisConfig(redisConfigMap)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse redis config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
redisClient, err := internal.NewRedisClient(redisConfig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err)
|
|
||||||
}
|
|
||||||
conf.redisClient = redisClient
|
|
||||||
api.LogDebug("Redis client initialized")
|
|
||||||
} else {
|
|
||||||
api.LogDebug("Redis configuration not provided, running without Redis")
|
|
||||||
}
|
|
||||||
|
|
||||||
enableUserLevelServer, ok := v.AsMap()["enable_user_level_server"].(bool)
|
|
||||||
if !ok {
|
|
||||||
enableUserLevelServer = false
|
|
||||||
if conf.redisClient == nil {
|
|
||||||
return nil, fmt.Errorf("redis configuration is not provided, enable_user_level_server is true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conf.enableUserLevelServer = enableUserLevelServer
|
|
||||||
|
|
||||||
if rateLimit, ok := v.AsMap()["rate_limit"].(map[string]interface{}); ok {
|
|
||||||
rateLimitConfig := &handler.MCPRatelimitConfig{}
|
|
||||||
if limit, ok := rateLimit["limit"].(float64); ok {
|
|
||||||
rateLimitConfig.Limit = int(limit)
|
|
||||||
}
|
|
||||||
if window, ok := rateLimit["window"].(float64); ok {
|
|
||||||
rateLimitConfig.Window = int(window)
|
|
||||||
}
|
|
||||||
if whiteList, ok := rateLimit["white_list"].([]interface{}); ok {
|
|
||||||
for _, item := range whiteList {
|
|
||||||
if uid, ok := item.(string); ok {
|
|
||||||
rateLimitConfig.Whitelist = append(rateLimitConfig.Whitelist, uid)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if errorText, ok := rateLimit["error_text"].(string); ok {
|
|
||||||
rateLimitConfig.ErrorText = errorText
|
|
||||||
}
|
|
||||||
conf.rateLimitConfig = rateLimitConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
ssePathSuffix, ok := v.AsMap()["sse_path_suffix"].(string)
|
|
||||||
if !ok || ssePathSuffix == "" {
|
|
||||||
return nil, fmt.Errorf("sse path suffix is not set or empty")
|
|
||||||
}
|
|
||||||
conf.ssePathSuffix = ssePathSuffix
|
|
||||||
|
|
||||||
serverConfigs, ok := v.AsMap()["servers"].([]interface{})
|
serverConfigs, ok := v.AsMap()["servers"].([]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
api.LogDebug("No servers are configured")
|
api.LogDebug("No servers are configured")
|
||||||
@@ -153,19 +55,33 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("server config must be an object")
|
return nil, fmt.Errorf("server config must be an object")
|
||||||
}
|
}
|
||||||
|
|
||||||
serverType, ok := serverConfigMap["type"].(string)
|
serverType, ok := serverConfigMap["type"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("server type is not set")
|
return nil, fmt.Errorf("server type is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
serverPath, ok := serverConfigMap["path"].(string)
|
serverPath, ok := serverConfigMap["path"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("server %s path is not set", serverType)
|
return nil, fmt.Errorf("server %s path is not set", serverType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
serverDomainList := []string{}
|
||||||
|
if domainList, ok := serverConfigMap["domain_list"].([]interface{}); ok {
|
||||||
|
for _, domain := range domainList {
|
||||||
|
if domainStr, ok := domain.(string); ok {
|
||||||
|
serverDomainList = append(serverDomainList, domainStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
serverDomainList = []string{"*"}
|
||||||
|
}
|
||||||
|
|
||||||
serverName, ok := serverConfigMap["name"].(string)
|
serverName, ok := serverConfigMap["name"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("server %s name is not set", serverType)
|
return nil, fmt.Errorf("server %s name is not set", serverType)
|
||||||
}
|
}
|
||||||
server := internal.GlobalRegistry.GetServer(serverType)
|
server := common.GlobalRegistry.GetServer(serverType)
|
||||||
|
|
||||||
if server == nil {
|
if server == nil {
|
||||||
return nil, fmt.Errorf("server %s is not registered", serverType)
|
return nil, fmt.Errorf("server %s is not registered", serverType)
|
||||||
@@ -186,50 +102,37 @@ func (p *parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
|||||||
return nil, fmt.Errorf("failed to initialize DBServer: %w", err)
|
return nil, fmt.Errorf("failed to initialize DBServer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conf.servers = append(conf.servers, internal.NewSSEServer(serverInstance,
|
conf.servers = append(conf.servers, &SSEServerWrapper{
|
||||||
internal.WithRedisClient(conf.redisClient),
|
BaseServer: common.NewSSEServer(serverInstance,
|
||||||
internal.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, ssePathSuffix)),
|
common.WithRedisClient(common.GlobalRedisClient),
|
||||||
internal.WithMessageEndpoint(serverPath)))
|
common.WithSSEEndpoint(fmt.Sprintf("%s%s", serverPath, mcp_session.GlobalSSEPathSuffix)),
|
||||||
|
common.WithMessageEndpoint(serverPath)),
|
||||||
|
DomainList: serverDomainList,
|
||||||
|
})
|
||||||
api.LogDebug(fmt.Sprintf("Registered MCP Server: %s", serverType))
|
api.LogDebug(fmt.Sprintf("Registered MCP Server: %s", serverType))
|
||||||
}
|
}
|
||||||
|
|
||||||
return conf, nil
|
return conf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *parser) Merge(parent interface{}, child interface{}) interface{} {
|
func (p *Parser) Merge(parent interface{}, child interface{}) interface{} {
|
||||||
parentConfig := parent.(*config)
|
parentConfig := parent.(*config)
|
||||||
childConfig := child.(*config)
|
childConfig := child.(*config)
|
||||||
|
|
||||||
newConfig := *parentConfig
|
newConfig := *parentConfig
|
||||||
if childConfig.redisClient != nil {
|
|
||||||
newConfig.redisClient = childConfig.redisClient
|
|
||||||
}
|
|
||||||
if childConfig.ssePathSuffix != "" {
|
|
||||||
newConfig.ssePathSuffix = childConfig.ssePathSuffix
|
|
||||||
}
|
|
||||||
if childConfig.servers != nil {
|
if childConfig.servers != nil {
|
||||||
newConfig.servers = childConfig.servers
|
newConfig.servers = childConfig.servers
|
||||||
}
|
}
|
||||||
if childConfig.defaultServer != nil {
|
|
||||||
newConfig.defaultServer = childConfig.defaultServer
|
|
||||||
}
|
|
||||||
if childConfig.matchList != nil {
|
|
||||||
newConfig.matchList = childConfig.matchList
|
|
||||||
}
|
|
||||||
return &newConfig
|
return &newConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func filterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.StreamFilter {
|
func FilterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.StreamFilter {
|
||||||
conf, ok := c.(*config)
|
conf, ok := c.(*config)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("unexpected config type")
|
panic("unexpected config type")
|
||||||
}
|
}
|
||||||
return &filter{
|
return &filter{
|
||||||
callbacks: callbacks,
|
config: conf,
|
||||||
config: conf,
|
callbacks: callbacks,
|
||||||
stopChan: make(chan struct{}),
|
|
||||||
mcpConfigHandler: handler.NewMCPConfigHandler(conf.redisClient, callbacks),
|
|
||||||
mcpRatelimitHandler: handler.NewMCPRatelimitHandler(conf.redisClient, callbacks, conf.rateLimitConfig),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {}
|
|
||||||
|
|||||||
@@ -1,104 +1,41 @@
|
|||||||
package main
|
package mcp_server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/handler"
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
|
||||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||||
"github.com/mark3labs/mcp-go/mcp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
RedisNotEnabledResponseBody = "Redis is not enabled, SSE connection is not supported"
|
|
||||||
)
|
|
||||||
|
|
||||||
// The callbacks in the filter, like `DecodeHeaders`, can be implemented on demand.
|
|
||||||
// Because api.PassThroughStreamFilter provides a default implementation.
|
|
||||||
type filter struct {
|
type filter struct {
|
||||||
api.PassThroughStreamFilter
|
api.PassThroughStreamFilter
|
||||||
|
|
||||||
callbacks api.FilterCallbackHandler
|
callbacks api.FilterCallbackHandler
|
||||||
path string
|
|
||||||
config *config
|
|
||||||
stopChan chan struct{}
|
|
||||||
|
|
||||||
req *http.Request
|
config *config
|
||||||
serverName string
|
req *http.Request
|
||||||
message bool
|
message bool
|
||||||
proxyURL *url.URL
|
path string
|
||||||
skip bool
|
|
||||||
|
|
||||||
userLevelConfig bool
|
|
||||||
mcpConfigHandler *handler.MCPConfigHandler
|
|
||||||
ratelimit bool
|
|
||||||
mcpRatelimitHandler *handler.MCPRatelimitHandler
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type RequestURL struct {
|
|
||||||
method string
|
|
||||||
scheme string
|
|
||||||
host string
|
|
||||||
path string
|
|
||||||
baseURL string
|
|
||||||
parsedURL *url.URL
|
|
||||||
internalIP bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRequestURL(header api.RequestHeaderMap) *RequestURL {
|
|
||||||
method, _ := header.Get(":method")
|
|
||||||
scheme, _ := header.Get(":scheme")
|
|
||||||
host, _ := header.Get(":authority")
|
|
||||||
path, _ := header.Get(":path")
|
|
||||||
internalIP, _ := header.Get("x-envoy-internal")
|
|
||||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
|
||||||
parsedURL, _ := url.Parse(path)
|
|
||||||
api.LogDebugf("RequestURL: method=%s, scheme=%s, host=%s, path=%s", method, scheme, host, path)
|
|
||||||
return &RequestURL{method: method, scheme: scheme, host: host, path: path, baseURL: baseURL, parsedURL: parsedURL, internalIP: internalIP == "true"}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Callbacks which are called in request path
|
|
||||||
// The endStream is true if the request doesn't have body
|
|
||||||
func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.StatusType {
|
func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.StatusType {
|
||||||
url := NewRequestURL(header)
|
url := common.NewRequestURL(header)
|
||||||
f.path = url.parsedURL.Path
|
f.path = url.ParsedURL.Path
|
||||||
|
|
||||||
// Check if request matches any rule in match_list
|
|
||||||
if !internal.IsMatch(f.config.matchList, url.host, f.path) {
|
|
||||||
f.skip = true
|
|
||||||
api.LogDebugf("Request does not match any rule in match_list: %s", url.parsedURL.String())
|
|
||||||
return api.Continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, server := range f.config.servers {
|
for _, server := range f.config.servers {
|
||||||
if f.path == server.GetSSEEndpoint() {
|
if common.MatchDomainList(url.ParsedURL.Host, server.DomainList) && url.ParsedURL.Path == server.BaseServer.GetMessageEndpoint() {
|
||||||
if url.method != http.MethodGet {
|
if url.Method != http.MethodPost {
|
||||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
|
||||||
} else {
|
|
||||||
f.serverName = server.GetServerName()
|
|
||||||
body := "SSE connection create"
|
|
||||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
|
|
||||||
}
|
|
||||||
api.LogDebugf("%s SSE connection started", server.GetServerName())
|
|
||||||
return api.LocalReply
|
|
||||||
} else if f.path == server.GetMessageEndpoint() {
|
|
||||||
if url.method != http.MethodPost {
|
|
||||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||||
return api.LocalReply
|
return api.LocalReply
|
||||||
}
|
}
|
||||||
// Create a new http.Request object
|
// Create a new http.Request object
|
||||||
f.req = &http.Request{
|
f.req = &http.Request{
|
||||||
Method: url.method,
|
Method: url.Method,
|
||||||
URL: url.parsedURL,
|
URL: url.ParsedURL,
|
||||||
Header: make(http.Header),
|
Header: make(http.Header),
|
||||||
}
|
}
|
||||||
api.LogDebugf("Message request: %v", url.parsedURL)
|
api.LogDebugf("Message request: %v", url.ParsedURL)
|
||||||
// Copy headers from api.RequestHeaderMap to http.Header
|
// Copy headers from api.RequestHeaderMap to http.Header
|
||||||
header.Range(func(key, value string) bool {
|
header.Range(func(key, value string) bool {
|
||||||
f.req.Header.Add(key, value)
|
f.req.Header.Add(key, value)
|
||||||
@@ -113,209 +50,33 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
f.req = &http.Request{
|
return api.Continue
|
||||||
Method: url.method,
|
|
||||||
URL: url.parsedURL,
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasSuffix(f.path, ConfigPathSuffix) && f.config.enableUserLevelServer {
|
|
||||||
if !url.internalIP {
|
|
||||||
api.LogWarnf("Access denied: non-internal IP address %s", url.parsedURL.String())
|
|
||||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
|
||||||
return api.LocalReply
|
|
||||||
}
|
|
||||||
if strings.HasSuffix(f.path, ConfigPathSuffix) && url.method == http.MethodGet {
|
|
||||||
api.LogDebugf("Handling config request: %s", f.path)
|
|
||||||
f.mcpConfigHandler.HandleConfigRequest(f.req, []byte{})
|
|
||||||
return api.LocalReply
|
|
||||||
}
|
|
||||||
f.userLevelConfig = true
|
|
||||||
if endStream {
|
|
||||||
return api.Continue
|
|
||||||
} else {
|
|
||||||
return api.StopAndBuffer
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.HasSuffix(url.parsedURL.Path, f.config.ssePathSuffix) {
|
|
||||||
f.proxyURL = url.parsedURL
|
|
||||||
if f.config.enableUserLevelServer {
|
|
||||||
parts := strings.Split(url.parsedURL.Path, "/")
|
|
||||||
if len(parts) >= 3 {
|
|
||||||
serverName := parts[1]
|
|
||||||
uid := parts[2]
|
|
||||||
// Get encoded config
|
|
||||||
encodedConfig, _ := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
|
|
||||||
if encodedConfig != "" {
|
|
||||||
header.Set("x-higress-mcpserver-config", encodedConfig)
|
|
||||||
api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
f.ratelimit = true
|
|
||||||
}
|
|
||||||
if endStream {
|
|
||||||
return api.Continue
|
|
||||||
} else {
|
|
||||||
return api.StopAndBuffer
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if url.method != http.MethodGet {
|
|
||||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
|
||||||
} else {
|
|
||||||
f.config.defaultServer = internal.NewSSEServer(internal.NewMCPServer(DefaultServerName, Version),
|
|
||||||
internal.WithSSEEndpoint(f.config.ssePathSuffix),
|
|
||||||
internal.WithMessageEndpoint(strings.TrimSuffix(url.parsedURL.Path, f.config.ssePathSuffix)),
|
|
||||||
internal.WithRedisClient(f.config.redisClient))
|
|
||||||
f.serverName = f.config.defaultServer.GetServerName()
|
|
||||||
body := "SSE connection create"
|
|
||||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
|
|
||||||
}
|
|
||||||
return api.LocalReply
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecodeData might be called multiple times during handling the request body.
|
|
||||||
// The endStream is true when handling the last piece of the body.
|
|
||||||
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||||
if f.skip {
|
|
||||||
return api.Continue
|
|
||||||
}
|
|
||||||
if !endStream {
|
if !endStream {
|
||||||
return api.StopAndBuffer
|
return api.StopAndBuffer
|
||||||
}
|
}
|
||||||
if f.message {
|
if f.message {
|
||||||
for _, server := range f.config.servers {
|
for _, server := range f.config.servers {
|
||||||
if f.path == server.GetMessageEndpoint() {
|
if f.path == server.BaseServer.GetMessageEndpoint() {
|
||||||
// Create a response recorder to capture the response
|
// Create a response recorder to capture the response
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
// Call the handleMessage method of SSEServer with complete body
|
// Call the handleMessage method of SSEServer with complete body
|
||||||
httpStatus := server.HandleMessage(recorder, f.req, buffer.Bytes())
|
httpStatus := server.BaseServer.HandleMessage(recorder, f.req, buffer.Bytes())
|
||||||
f.message = false
|
f.message = false
|
||||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(httpStatus, recorder.Body.String(), recorder.Header(), 0, "")
|
f.callbacks.DecoderFilterCallbacks().SendLocalReply(httpStatus, recorder.Body.String(), recorder.Header(), 0, "")
|
||||||
return api.LocalReply
|
return api.LocalReply
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if f.userLevelConfig {
|
|
||||||
// Handle config POST request
|
|
||||||
api.LogDebugf("Handling config request: %s", f.path)
|
|
||||||
f.mcpConfigHandler.HandleConfigRequest(f.req, buffer.Bytes())
|
|
||||||
return api.LocalReply
|
|
||||||
} else if f.ratelimit {
|
|
||||||
if checkJSONRPCMethod(buffer.Bytes(), "tools/list") {
|
|
||||||
api.LogDebugf("Not a tools call request, skipping ratelimit")
|
|
||||||
return api.Continue
|
|
||||||
}
|
|
||||||
parts := strings.Split(f.req.URL.Path, "/")
|
|
||||||
if len(parts) < 3 {
|
|
||||||
api.LogWarnf("Access denied: no valid uid found")
|
|
||||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
|
||||||
return api.LocalReply
|
|
||||||
}
|
|
||||||
serverName := parts[1]
|
|
||||||
uid := parts[2]
|
|
||||||
encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
|
|
||||||
if err != nil {
|
|
||||||
api.LogWarnf("Access denied: no valid config found for uid %s", uid)
|
|
||||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
|
||||||
return api.LocalReply
|
|
||||||
} else if encodedConfig == "" && checkJSONRPCMethod(buffer.Bytes(), "tools/call") {
|
|
||||||
api.LogDebugf("Empty config found for %s:%s", serverName, uid)
|
|
||||||
if !f.mcpRatelimitHandler.HandleRatelimit(f.req, buffer.Bytes()) {
|
|
||||||
return api.LocalReply
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return api.Continue
|
return api.Continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Callbacks which are called in response path
|
|
||||||
// The endStream is true if the response doesn't have body
|
|
||||||
func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api.StatusType {
|
func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api.StatusType {
|
||||||
if f.skip {
|
|
||||||
return api.Continue
|
|
||||||
}
|
|
||||||
if f.serverName != "" {
|
|
||||||
if f.config.redisClient != nil {
|
|
||||||
header.Set("Content-Type", "text/event-stream")
|
|
||||||
header.Set("Cache-Control", "no-cache")
|
|
||||||
header.Set("Connection", "keep-alive")
|
|
||||||
header.Set("Access-Control-Allow-Origin", "*")
|
|
||||||
header.Del("Content-Length")
|
|
||||||
} else {
|
|
||||||
header.Set("Content-Length", strconv.Itoa(len(RedisNotEnabledResponseBody)))
|
|
||||||
}
|
|
||||||
return api.Continue
|
|
||||||
}
|
|
||||||
return api.Continue
|
return api.Continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// EncodeData might be called multiple times during handling the response body.
|
|
||||||
// The endStream is true when handling the last piece of the body.
|
|
||||||
func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||||
if f.skip {
|
|
||||||
return api.Continue
|
|
||||||
}
|
|
||||||
if !endStream {
|
|
||||||
return api.StopAndBuffer
|
|
||||||
}
|
|
||||||
if f.proxyURL != nil && f.config.redisClient != nil {
|
|
||||||
sessionID := f.proxyURL.Query().Get("sessionId")
|
|
||||||
if sessionID != "" {
|
|
||||||
channel := internal.GetSSEChannelName(sessionID)
|
|
||||||
eventData := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String())
|
|
||||||
publishErr := f.config.redisClient.Publish(channel, eventData)
|
|
||||||
if publishErr != nil {
|
|
||||||
api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if f.serverName != "" {
|
|
||||||
if f.config.redisClient != nil {
|
|
||||||
// handle specific server
|
|
||||||
for _, server := range f.config.servers {
|
|
||||||
if f.serverName == server.GetServerName() {
|
|
||||||
buffer.Reset()
|
|
||||||
server.HandleSSE(f.callbacks, f.stopChan)
|
|
||||||
return api.Running
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// handle default server
|
|
||||||
if f.serverName == f.config.defaultServer.GetServerName() {
|
|
||||||
buffer.Reset()
|
|
||||||
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
|
|
||||||
return api.Running
|
|
||||||
}
|
|
||||||
return api.Continue
|
|
||||||
} else {
|
|
||||||
buffer.SetString(RedisNotEnabledResponseBody)
|
|
||||||
return api.Continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return api.Continue
|
return api.Continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnDestroy stops the goroutine
|
|
||||||
func (f *filter) OnDestroy(reason api.DestroyReason) {
|
|
||||||
api.LogDebugf("OnDestroy: reason=%v", reason)
|
|
||||||
if f.serverName != "" && f.stopChan != nil {
|
|
||||||
select {
|
|
||||||
case <-f.stopChan:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
api.LogDebug("Stopping SSE connection")
|
|
||||||
close(f.stopChan)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if the request is a tools/call request
|
|
||||||
func checkJSONRPCMethod(body []byte, method string) bool {
|
|
||||||
var request mcp.CallToolRequest
|
|
||||||
if err := json.Unmarshal(body, &request); err != nil {
|
|
||||||
api.LogWarnf("Failed to unmarshal request body: %v, not a JSON RPC request", err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return request.Method == method
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/registry"
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/registry"
|
||||||
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||||
"github.com/mark3labs/mcp-go/mcp"
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
"github.com/nacos-group/nacos-sdk-go/v2/clients"
|
"github.com/nacos-group/nacos-sdk-go/v2/clients"
|
||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
internal.GlobalRegistry.RegisterServer("nacos-mcp-registry", &NacosConfig{})
|
common.GlobalRegistry.RegisterServer("nacos-mcp-registry", &NacosConfig{})
|
||||||
}
|
}
|
||||||
|
|
||||||
type NacosConfig struct {
|
type NacosConfig struct {
|
||||||
@@ -28,7 +28,7 @@ type NacosConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type McpServerToolsChangeListener struct {
|
type McpServerToolsChangeListener struct {
|
||||||
mcpServer *internal.MCPServer
|
mcpServer *common.MCPServer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *McpServerToolsChangeListener) OnToolChanged(reg registry.McpServerRegistry) {
|
func (l *McpServerToolsChangeListener) OnToolChanged(reg registry.McpServerRegistry) {
|
||||||
@@ -137,8 +137,8 @@ func (c *NacosConfig) ParseConfig(config map[string]any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *NacosConfig) NewServer(serverName string) (*internal.MCPServer, error) {
|
func (c *NacosConfig) NewServer(serverName string) (*common.MCPServer, error) {
|
||||||
mcpServer := internal.NewMCPServer(
|
mcpServer := common.NewMCPServer(
|
||||||
serverName,
|
serverName,
|
||||||
"1.0.0",
|
"1.0.0",
|
||||||
)
|
)
|
||||||
@@ -170,11 +170,11 @@ func (c *NacosConfig) NewServer(serverName string) (*internal.MCPServer, error)
|
|||||||
return mcpServer, nil
|
return mcpServer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func resetToolsToMcpServer(mcpServer *internal.MCPServer, reg registry.McpServerRegistry) {
|
func resetToolsToMcpServer(mcpServer *common.MCPServer, reg registry.McpServerRegistry) {
|
||||||
wrappedTools := []internal.ServerTool{}
|
wrappedTools := []common.ServerTool{}
|
||||||
tools := reg.ListToolsDesciption()
|
tools := reg.ListToolsDesciption()
|
||||||
for _, tool := range tools {
|
for _, tool := range tools {
|
||||||
wrappedTools = append(wrappedTools, internal.ServerTool{
|
wrappedTools = append(wrappedTools, common.ServerTool{
|
||||||
Tool: mcp.NewToolWithRawSchema(tool.Name, tool.Description, tool.InputSchema),
|
Tool: mcp.NewToolWithRawSchema(tool.Name, tool.Description, tool.InputSchema),
|
||||||
Handler: registry.HandleRegistryToolsCall(reg),
|
Handler: registry.HandleRegistryToolsCall(reg),
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||||
"github.com/mark3labs/mcp-go/mcp"
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -204,7 +204,7 @@ func CommonRemoteCall(reg McpServerRegistry, toolName string, parameters map[str
|
|||||||
return remoteHandle.HandleToolCall(ctx, parameters)
|
return remoteHandle.HandleToolCall(ctx, parameters)
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleRegistryToolsCall(reg McpServerRegistry) internal.ToolHandlerFunc {
|
func HandleRegistryToolsCall(reg McpServerRegistry) common.ToolHandlerFunc {
|
||||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
arguments := request.Params.Arguments
|
arguments := request.Params.Arguments
|
||||||
return CommonRemoteCall(reg, request.Params.Name, arguments)
|
return CommonRemoteCall(reg, request.Params.Name, arguments)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||||
"github.com/mark3labs/mcp-go/mcp"
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
)
|
)
|
||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
const Version = "1.0.0"
|
const Version = "1.0.0"
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
internal.GlobalRegistry.RegisterServer("database", &DBConfig{})
|
common.GlobalRegistry.RegisterServer("database", &DBConfig{})
|
||||||
}
|
}
|
||||||
|
|
||||||
type DBConfig struct {
|
type DBConfig struct {
|
||||||
@@ -41,11 +41,11 @@ func (c *DBConfig) ParseConfig(config map[string]any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *DBConfig) NewServer(serverName string) (*internal.MCPServer, error) {
|
func (c *DBConfig) NewServer(serverName string) (*common.MCPServer, error) {
|
||||||
mcpServer := internal.NewMCPServer(
|
mcpServer := common.NewMCPServer(
|
||||||
serverName,
|
serverName,
|
||||||
Version,
|
Version,
|
||||||
internal.WithInstructions(fmt.Sprintf("This is a %s database server", c.dbType)),
|
common.WithInstructions(fmt.Sprintf("This is a %s database server", c.dbType)),
|
||||||
)
|
)
|
||||||
|
|
||||||
dbClient := NewDBClient(c.dsn, c.dbType, mcpServer.GetDestoryChannel())
|
dbClient := NewDBClient(c.dsn, c.dbType, mcpServer.GetDestoryChannel())
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||||
"github.com/mark3labs/mcp-go/mcp"
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HandleQueryTool handles SQL query execution
|
// HandleQueryTool handles SQL query execution
|
||||||
func HandleQueryTool(dbClient *DBClient) internal.ToolHandlerFunc {
|
func HandleQueryTool(dbClient *DBClient) common.ToolHandlerFunc {
|
||||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
arguments := request.Params.Arguments
|
arguments := request.Params.Arguments
|
||||||
message, ok := arguments["sql"].(string)
|
message, ok := arguments["sql"].(string)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package internal
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/aes"
|
"crypto/aes"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package internal
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -23,6 +23,27 @@ type MatchRule struct {
|
|||||||
MatchRuleType RuleType `json:"match_rule_type"` // Type of match rule
|
MatchRuleType RuleType `json:"match_rule_type"` // Type of match rule
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParseMatchList parses the match list from the config
|
||||||
|
func ParseMatchList(matchListConfig []interface{}) []MatchRule {
|
||||||
|
matchList := make([]MatchRule, 0)
|
||||||
|
for _, item := range matchListConfig {
|
||||||
|
if ruleMap, ok := item.(map[string]interface{}); ok {
|
||||||
|
rule := MatchRule{}
|
||||||
|
if domain, ok := ruleMap["match_rule_domain"].(string); ok {
|
||||||
|
rule.MatchRuleDomain = domain
|
||||||
|
}
|
||||||
|
if path, ok := ruleMap["match_rule_path"].(string); ok {
|
||||||
|
rule.MatchRulePath = path
|
||||||
|
}
|
||||||
|
if ruleType, ok := ruleMap["match_rule_type"].(string); ok {
|
||||||
|
rule.MatchRuleType = RuleType(ruleType)
|
||||||
|
}
|
||||||
|
matchList = append(matchList, rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return matchList
|
||||||
|
}
|
||||||
|
|
||||||
// convertWildcardToRegex converts wildcard pattern to regex pattern
|
// convertWildcardToRegex converts wildcard pattern to regex pattern
|
||||||
func convertWildcardToRegex(pattern string) string {
|
func convertWildcardToRegex(pattern string) string {
|
||||||
pattern = regexp.QuoteMeta(pattern)
|
pattern = regexp.QuoteMeta(pattern)
|
||||||
@@ -87,3 +108,13 @@ func IsMatch(rules []MatchRule, host, path string) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MatchDomainList checks if the domain matches any of the domains in the list
|
||||||
|
func MatchDomainList(domain string, domainList []string) bool {
|
||||||
|
for _, d := range domainList {
|
||||||
|
if matchDomain(domain, d) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package internal
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var GlobalRedisClient *RedisClient
|
||||||
|
|
||||||
type RedisConfig struct {
|
type RedisConfig struct {
|
||||||
address string
|
address string
|
||||||
username string
|
username string
|
||||||
@@ -249,6 +251,18 @@ func (r *RedisClient) Get(key string) (string, error) {
|
|||||||
return value, nil
|
return value, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Expire sets the expiration time for a key
|
||||||
|
func (r *RedisClient) Expire(key string, expiration time.Duration) error {
|
||||||
|
ok, err := r.client.Expire(r.ctx, key, expiration).Result()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set expiration for key: %w", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("key does not exist")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Close closes the Redis client and stops the keepalive goroutine
|
// Close closes the Redis client and stops the keepalive goroutine
|
||||||
func (r *RedisClient) Close() error {
|
func (r *RedisClient) Close() error {
|
||||||
r.cancel()
|
r.cancel()
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package internal
|
package common
|
||||||
|
|
||||||
var GlobalRegistry = NewServerRegistry()
|
var GlobalRegistry = NewServerRegistry()
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package internal
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -243,6 +243,7 @@ func (s *MCPServer) HandleMessage(
|
|||||||
message json.RawMessage,
|
message json.RawMessage,
|
||||||
) mcp.JSONRPCMessage {
|
) mcp.JSONRPCMessage {
|
||||||
// Add server to context
|
// Add server to context
|
||||||
|
|
||||||
ctx = context.WithValue(ctx, serverKey{}, s)
|
ctx = context.WithValue(ctx, serverKey{}, s)
|
||||||
|
|
||||||
var baseMessage struct {
|
var baseMessage struct {
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package internal
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -210,15 +210,7 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
|
|||||||
var status int
|
var status int
|
||||||
// Only send response if there is one (not for notifications)
|
// Only send response if there is one (not for notifications)
|
||||||
if response != nil {
|
if response != nil {
|
||||||
eventData, _ := json.Marshal(response)
|
|
||||||
|
|
||||||
if sessionID != "" && s.redisClient != nil {
|
if sessionID != "" && s.redisClient != nil {
|
||||||
channel := GetSSEChannelName(sessionID)
|
|
||||||
publishErr := s.redisClient.Publish(channel, fmt.Sprintf("event: message\ndata: %s\n\n", eventData))
|
|
||||||
|
|
||||||
if publishErr != nil {
|
|
||||||
api.LogErrorf("Failed to publish message to Redis: %v", publishErr)
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusAccepted)
|
w.WriteHeader(http.StatusAccepted)
|
||||||
status = http.StatusAccepted
|
status = http.StatusAccepted
|
||||||
} else {
|
} else {
|
||||||
30
plugins/golang-filter/mcp-session/common/utils.go
Normal file
30
plugins/golang-filter/mcp-session/common/utils.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestURL struct {
|
||||||
|
Method string
|
||||||
|
Scheme string
|
||||||
|
Host string
|
||||||
|
Path string
|
||||||
|
BaseURL string
|
||||||
|
ParsedURL *url.URL
|
||||||
|
InternalIP bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRequestURL(header api.RequestHeaderMap) *RequestURL {
|
||||||
|
method, _ := header.Get(":method")
|
||||||
|
scheme, _ := header.Get(":scheme")
|
||||||
|
host, _ := header.Get(":authority")
|
||||||
|
path, _ := header.Get(":path")
|
||||||
|
internalIP, _ := header.Get("x-envoy-internal")
|
||||||
|
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||||
|
parsedURL, _ := url.Parse(path)
|
||||||
|
api.LogDebugf("RequestURL: method=%s, scheme=%s, host=%s, path=%s", method, scheme, host, path)
|
||||||
|
return &RequestURL{Method: method, Scheme: scheme, Host: host, Path: path, BaseURL: baseURL, ParsedURL: parsedURL, InternalIP: internalIP == "true"}
|
||||||
|
}
|
||||||
143
plugins/golang-filter/mcp-session/config.go
Normal file
143
plugins/golang-filter/mcp-session/config.go
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
package mcp_session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
_ "net/http/pprof"
|
||||||
|
|
||||||
|
xds "github.com/cncf/xds/go/xds/type/v3"
|
||||||
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
|
|
||||||
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||||
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/handler"
|
||||||
|
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
const Name = "mcp-session"
|
||||||
|
const Version = "1.0.0"
|
||||||
|
const ConfigPathSuffix = "/config"
|
||||||
|
const DefaultServerName = "higress-mcp-server"
|
||||||
|
|
||||||
|
var GlobalSSEPathSuffix = "/sse"
|
||||||
|
|
||||||
|
type config struct {
|
||||||
|
matchList []common.MatchRule
|
||||||
|
enableUserLevelServer bool
|
||||||
|
rateLimitConfig *handler.MCPRatelimitConfig
|
||||||
|
defaultServer *common.SSEServer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *config) Destroy() {
|
||||||
|
if common.GlobalRedisClient != nil {
|
||||||
|
api.LogDebug("Closing Redis client")
|
||||||
|
common.GlobalRedisClient.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Parser struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the filter configuration
|
||||||
|
func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (interface{}, error) {
|
||||||
|
configStruct := &xds.TypedStruct{}
|
||||||
|
if err := any.UnmarshalTo(configStruct); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
v := configStruct.Value
|
||||||
|
|
||||||
|
conf := &config{
|
||||||
|
matchList: make([]common.MatchRule, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse match_list if exists
|
||||||
|
if matchList, ok := v.AsMap()["match_list"].([]interface{}); ok {
|
||||||
|
conf.matchList = common.ParseMatchList(matchList)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redis configuration is optional
|
||||||
|
if redisConfigMap, ok := v.AsMap()["redis"].(map[string]interface{}); ok {
|
||||||
|
redisConfig, err := common.ParseRedisConfig(redisConfigMap)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse redis config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
redisClient, err := common.NewRedisClient(redisConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize RedisClient: %w", err)
|
||||||
|
}
|
||||||
|
common.GlobalRedisClient = redisClient
|
||||||
|
api.LogDebug("Redis client initialized")
|
||||||
|
} else {
|
||||||
|
api.LogDebug("Redis configuration not provided, running without Redis")
|
||||||
|
}
|
||||||
|
|
||||||
|
enableUserLevelServer, ok := v.AsMap()["enable_user_level_server"].(bool)
|
||||||
|
if !ok {
|
||||||
|
enableUserLevelServer = false
|
||||||
|
if common.GlobalRedisClient == nil {
|
||||||
|
return nil, fmt.Errorf("redis configuration is not provided, enable_user_level_server is true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conf.enableUserLevelServer = enableUserLevelServer
|
||||||
|
|
||||||
|
if rateLimit, ok := v.AsMap()["rate_limit"].(map[string]interface{}); ok {
|
||||||
|
rateLimitConfig := &handler.MCPRatelimitConfig{}
|
||||||
|
if limit, ok := rateLimit["limit"].(float64); ok {
|
||||||
|
rateLimitConfig.Limit = int(limit)
|
||||||
|
}
|
||||||
|
if window, ok := rateLimit["window"].(float64); ok {
|
||||||
|
rateLimitConfig.Window = int(window)
|
||||||
|
}
|
||||||
|
if whiteList, ok := rateLimit["white_list"].([]interface{}); ok {
|
||||||
|
for _, item := range whiteList {
|
||||||
|
if uid, ok := item.(string); ok {
|
||||||
|
rateLimitConfig.Whitelist = append(rateLimitConfig.Whitelist, uid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if errorText, ok := rateLimit["error_text"].(string); ok {
|
||||||
|
rateLimitConfig.ErrorText = errorText
|
||||||
|
}
|
||||||
|
conf.rateLimitConfig = rateLimitConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
ssePathSuffix, ok := v.AsMap()["sse_path_suffix"].(string)
|
||||||
|
if !ok || ssePathSuffix == "" {
|
||||||
|
return nil, fmt.Errorf("sse path suffix is not set or empty")
|
||||||
|
}
|
||||||
|
GlobalSSEPathSuffix = ssePathSuffix
|
||||||
|
|
||||||
|
return conf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) Merge(parent interface{}, child interface{}) interface{} {
|
||||||
|
parentConfig := parent.(*config)
|
||||||
|
childConfig := child.(*config)
|
||||||
|
|
||||||
|
newConfig := *parentConfig
|
||||||
|
if childConfig.matchList != nil {
|
||||||
|
newConfig.matchList = childConfig.matchList
|
||||||
|
}
|
||||||
|
newConfig.enableUserLevelServer = childConfig.enableUserLevelServer
|
||||||
|
if childConfig.rateLimitConfig != nil {
|
||||||
|
newConfig.rateLimitConfig = childConfig.rateLimitConfig
|
||||||
|
}
|
||||||
|
if childConfig.defaultServer != nil {
|
||||||
|
newConfig.defaultServer = childConfig.defaultServer
|
||||||
|
}
|
||||||
|
return &newConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func FilterFactory(c interface{}, callbacks api.FilterCallbackHandler) api.StreamFilter {
|
||||||
|
conf, ok := c.(*config)
|
||||||
|
if !ok {
|
||||||
|
panic("unexpected config type")
|
||||||
|
}
|
||||||
|
return &filter{
|
||||||
|
callbacks: callbacks,
|
||||||
|
config: conf,
|
||||||
|
stopChan: make(chan struct{}),
|
||||||
|
mcpConfigHandler: handler.NewMCPConfigHandler(common.GlobalRedisClient, callbacks),
|
||||||
|
mcpRatelimitHandler: handler.NewMCPRatelimitHandler(common.GlobalRedisClient, callbacks, conf.rateLimitConfig),
|
||||||
|
}
|
||||||
|
}
|
||||||
237
plugins/golang-filter/mcp-session/filter.go
Normal file
237
plugins/golang-filter/mcp-session/filter.go
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
package mcp_session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||||
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/handler"
|
||||||
|
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||||
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
RedisNotEnabledResponseBody = "Redis is not enabled, SSE connection is not supported"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The callbacks in the filter, like `DecodeHeaders`, can be implemented on demand.
|
||||||
|
// Because api.PassThroughStreamFilter provides a default implementation.
|
||||||
|
type filter struct {
|
||||||
|
api.PassThroughStreamFilter
|
||||||
|
|
||||||
|
callbacks api.FilterCallbackHandler
|
||||||
|
path string
|
||||||
|
config *config
|
||||||
|
stopChan chan struct{}
|
||||||
|
|
||||||
|
req *http.Request
|
||||||
|
serverName string
|
||||||
|
proxyURL *url.URL
|
||||||
|
skip bool
|
||||||
|
|
||||||
|
userLevelConfig bool
|
||||||
|
mcpConfigHandler *handler.MCPConfigHandler
|
||||||
|
ratelimit bool
|
||||||
|
mcpRatelimitHandler *handler.MCPRatelimitHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
// Callbacks which are called in request path
|
||||||
|
// The endStream is true if the request doesn't have body
|
||||||
|
func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.StatusType {
|
||||||
|
url := common.NewRequestURL(header)
|
||||||
|
f.path = url.ParsedURL.Path
|
||||||
|
|
||||||
|
// Check if request matches any rule in match_list
|
||||||
|
if !common.IsMatch(f.config.matchList, url.Host, f.path) {
|
||||||
|
f.skip = true
|
||||||
|
api.LogDebugf("Request does not match any rule in match_list: %s", url.ParsedURL.String())
|
||||||
|
return api.Continue
|
||||||
|
}
|
||||||
|
|
||||||
|
f.req = &http.Request{
|
||||||
|
Method: url.Method,
|
||||||
|
URL: url.ParsedURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasSuffix(f.path, ConfigPathSuffix) && f.config.enableUserLevelServer {
|
||||||
|
if !url.InternalIP {
|
||||||
|
api.LogWarnf("Access denied: non-Internal IP address %s", url.ParsedURL.String())
|
||||||
|
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||||
|
return api.LocalReply
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(f.path, ConfigPathSuffix) && url.Method == http.MethodGet {
|
||||||
|
api.LogDebugf("Handling config request: %s", f.path)
|
||||||
|
f.mcpConfigHandler.HandleConfigRequest(f.req, []byte{})
|
||||||
|
return api.LocalReply
|
||||||
|
}
|
||||||
|
f.userLevelConfig = true
|
||||||
|
if endStream {
|
||||||
|
return api.Continue
|
||||||
|
} else {
|
||||||
|
return api.StopAndBuffer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix) {
|
||||||
|
f.proxyURL = url.ParsedURL
|
||||||
|
if f.config.enableUserLevelServer {
|
||||||
|
parts := strings.Split(url.ParsedURL.Path, "/")
|
||||||
|
if len(parts) >= 3 {
|
||||||
|
serverName := parts[1]
|
||||||
|
uid := parts[2]
|
||||||
|
// Get encoded config
|
||||||
|
encodedConfig, _ := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
|
||||||
|
if encodedConfig != "" {
|
||||||
|
header.Set("x-higress-mcpserver-config", encodedConfig)
|
||||||
|
api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.ratelimit = true
|
||||||
|
}
|
||||||
|
if endStream {
|
||||||
|
return api.Continue
|
||||||
|
} else {
|
||||||
|
return api.StopAndBuffer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if url.Method != http.MethodGet {
|
||||||
|
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||||
|
} else {
|
||||||
|
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
|
||||||
|
common.WithSSEEndpoint(GlobalSSEPathSuffix),
|
||||||
|
common.WithMessageEndpoint(strings.TrimSuffix(url.ParsedURL.Path, GlobalSSEPathSuffix)),
|
||||||
|
common.WithRedisClient(common.GlobalRedisClient))
|
||||||
|
f.serverName = f.config.defaultServer.GetServerName()
|
||||||
|
body := "SSE connection create"
|
||||||
|
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
|
||||||
|
}
|
||||||
|
return api.LocalReply
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeData might be called multiple times during handling the request body.
|
||||||
|
// The endStream is true when handling the last piece of the body.
|
||||||
|
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||||
|
if f.skip {
|
||||||
|
return api.Continue
|
||||||
|
}
|
||||||
|
if !endStream {
|
||||||
|
return api.StopAndBuffer
|
||||||
|
}
|
||||||
|
if f.userLevelConfig {
|
||||||
|
// Handle config POST request
|
||||||
|
api.LogDebugf("Handling config request: %s", f.path)
|
||||||
|
f.mcpConfigHandler.HandleConfigRequest(f.req, buffer.Bytes())
|
||||||
|
return api.LocalReply
|
||||||
|
} else if f.ratelimit {
|
||||||
|
if checkJSONRPCMethod(buffer.Bytes(), "tools/list") {
|
||||||
|
api.LogDebugf("Not a tools call request, skipping ratelimit")
|
||||||
|
return api.Continue
|
||||||
|
}
|
||||||
|
parts := strings.Split(f.req.URL.Path, "/")
|
||||||
|
if len(parts) < 3 {
|
||||||
|
api.LogWarnf("Access denied: no valid uid found")
|
||||||
|
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||||
|
return api.LocalReply
|
||||||
|
}
|
||||||
|
serverName := parts[1]
|
||||||
|
uid := parts[2]
|
||||||
|
encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
|
||||||
|
if err != nil {
|
||||||
|
api.LogWarnf("Access denied: no valid config found for uid %s", uid)
|
||||||
|
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
|
||||||
|
return api.LocalReply
|
||||||
|
} else if encodedConfig == "" && checkJSONRPCMethod(buffer.Bytes(), "tools/call") {
|
||||||
|
api.LogDebugf("Empty config found for %s:%s", serverName, uid)
|
||||||
|
if !f.mcpRatelimitHandler.HandleRatelimit(f.req, buffer.Bytes()) {
|
||||||
|
return api.LocalReply
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return api.Continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Callbacks which are called in response path
|
||||||
|
// The endStream is true if the response doesn't have body
|
||||||
|
func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api.StatusType {
|
||||||
|
if f.skip {
|
||||||
|
return api.Continue
|
||||||
|
}
|
||||||
|
if f.serverName != "" {
|
||||||
|
if common.GlobalRedisClient != nil {
|
||||||
|
header.Set("Content-Type", "text/event-stream")
|
||||||
|
header.Set("Cache-Control", "no-cache")
|
||||||
|
header.Set("Connection", "keep-alive")
|
||||||
|
header.Set("Access-Control-Allow-Origin", "*")
|
||||||
|
header.Del("Content-Length")
|
||||||
|
} else {
|
||||||
|
header.Set("Content-Length", strconv.Itoa(len(RedisNotEnabledResponseBody)))
|
||||||
|
}
|
||||||
|
return api.Continue
|
||||||
|
}
|
||||||
|
return api.Continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeData might be called multiple times during handling the response body.
|
||||||
|
// The endStream is true when handling the last piece of the body.
|
||||||
|
func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||||
|
if f.skip {
|
||||||
|
return api.Continue
|
||||||
|
}
|
||||||
|
if !endStream {
|
||||||
|
return api.StopAndBuffer
|
||||||
|
}
|
||||||
|
if f.proxyURL != nil && common.GlobalRedisClient != nil {
|
||||||
|
sessionID := f.proxyURL.Query().Get("sessionId")
|
||||||
|
if sessionID != "" {
|
||||||
|
channel := common.GetSSEChannelName(sessionID)
|
||||||
|
eventData := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String())
|
||||||
|
publishErr := common.GlobalRedisClient.Publish(channel, eventData)
|
||||||
|
if publishErr != nil {
|
||||||
|
api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", publishErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.serverName != "" {
|
||||||
|
if common.GlobalRedisClient != nil {
|
||||||
|
// handle default server
|
||||||
|
buffer.Reset()
|
||||||
|
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
|
||||||
|
return api.Running
|
||||||
|
} else {
|
||||||
|
buffer.SetString(RedisNotEnabledResponseBody)
|
||||||
|
return api.Continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return api.Continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnDestroy stops the goroutine
|
||||||
|
func (f *filter) OnDestroy(reason api.DestroyReason) {
|
||||||
|
api.LogDebugf("OnDestroy: reason=%v", reason)
|
||||||
|
if f.serverName != "" && f.stopChan != nil {
|
||||||
|
select {
|
||||||
|
case <-f.stopChan:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
api.LogDebug("Stopping SSE connection")
|
||||||
|
close(f.stopChan)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if the request is a tools/call request
|
||||||
|
func checkJSONRPCMethod(body []byte, method string) bool {
|
||||||
|
var request mcp.CallToolRequest
|
||||||
|
if err := json.Unmarshal(body, &request); err != nil {
|
||||||
|
api.LogWarnf("Failed to unmarshal request body: %v, not a JSON RPC request", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return request.Method == method
|
||||||
|
}
|
||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -18,7 +18,7 @@ type MCPConfigHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewMCPConfigHandler creates a new instance of MCP configuration handler
|
// NewMCPConfigHandler creates a new instance of MCP configuration handler
|
||||||
func NewMCPConfigHandler(redisClient *internal.RedisClient, callbacks api.FilterCallbackHandler) *MCPConfigHandler {
|
func NewMCPConfigHandler(redisClient *common.RedisClient, callbacks api.FilterCallbackHandler) *MCPConfigHandler {
|
||||||
return &MCPConfigHandler{
|
return &MCPConfigHandler{
|
||||||
configStore: NewRedisConfigStore(redisClient),
|
configStore: NewRedisConfigStore(redisClient),
|
||||||
callbacks: callbacks,
|
callbacks: callbacks,
|
||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -36,11 +36,11 @@ type ConfigStore interface {
|
|||||||
|
|
||||||
// RedisConfigStore implements configuration storage using Redis
|
// RedisConfigStore implements configuration storage using Redis
|
||||||
type RedisConfigStore struct {
|
type RedisConfigStore struct {
|
||||||
redisClient *internal.RedisClient
|
redisClient *common.RedisClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRedisConfigStore creates a new instance of Redis configuration storage
|
// NewRedisConfigStore creates a new instance of Redis configuration storage
|
||||||
func NewRedisConfigStore(redisClient *internal.RedisClient) ConfigStore {
|
func NewRedisConfigStore(redisClient *common.RedisClient) ConfigStore {
|
||||||
return &RedisConfigStore{
|
return &RedisConfigStore{
|
||||||
redisClient: redisClient,
|
redisClient: redisClient,
|
||||||
}
|
}
|
||||||
@@ -101,5 +101,11 @@ func (s *RedisConfigStore) GetConfig(serverName string, uid string) (map[string]
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Refresh TTL
|
||||||
|
if err := s.redisClient.Expire(key, configExpiry); err != nil {
|
||||||
|
// Log error but don't fail the request
|
||||||
|
fmt.Printf("Failed to refresh TTL for key %s: %v\n", key, err)
|
||||||
|
}
|
||||||
|
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
@@ -8,13 +8,13 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/internal"
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||||
"github.com/mark3labs/mcp-go/mcp"
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MCPRatelimitHandler struct {
|
type MCPRatelimitHandler struct {
|
||||||
redisClient *internal.RedisClient
|
redisClient *common.RedisClient
|
||||||
callbacks api.FilterCallbackHandler
|
callbacks api.FilterCallbackHandler
|
||||||
limit int // Maximum requests allowed per window
|
limit int // Maximum requests allowed per window
|
||||||
window int // Time window in seconds
|
window int // Time window in seconds
|
||||||
@@ -31,7 +31,7 @@ type MCPRatelimitConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewMCPRatelimitHandler creates a new rate limit handler
|
// NewMCPRatelimitHandler creates a new rate limit handler
|
||||||
func NewMCPRatelimitHandler(redisClient *internal.RedisClient, callbacks api.FilterCallbackHandler, conf *MCPRatelimitConfig) *MCPRatelimitHandler {
|
func NewMCPRatelimitHandler(redisClient *common.RedisClient, callbacks api.FilterCallbackHandler, conf *MCPRatelimitConfig) *MCPRatelimitHandler {
|
||||||
if conf == nil {
|
if conf == nil {
|
||||||
conf = &MCPRatelimitConfig{
|
conf = &MCPRatelimitConfig{
|
||||||
Limit: 100,
|
Limit: 100,
|
||||||
@@ -16,24 +16,16 @@
|
|||||||
|
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
INNER_GO_FILTER_NAME=${GO_FILTER_NAME-""}
|
|
||||||
OUTPUT_PACKAGE_DIR=${OUTPUT_PACKAGE_DIR:-"../../external/package/"}
|
OUTPUT_PACKAGE_DIR=${OUTPUT_PACKAGE_DIR:-"../../external/package/"}
|
||||||
|
|
||||||
cd ./plugins/golang-filter
|
cd plugins/golang-filter
|
||||||
if [ ! -n "$INNER_GO_FILTER_NAME" ]; then
|
|
||||||
GO_FILTERS_DIR=$(pwd)
|
GO_FILTERS_DIR=$(pwd)
|
||||||
echo "🚀 Build all Go Filters under folder of $GO_FILTERS_DIR"
|
|
||||||
for file in `ls $GO_FILTERS_DIR`
|
echo "🚀 Build Go Filter"
|
||||||
do
|
|
||||||
if [ -d $GO_FILTERS_DIR/$file ]; then
|
GOARCH=${TARGET_ARCH} make build
|
||||||
name=${file##*/}
|
|
||||||
echo "🚀 Build Go Filter: $name"
|
cp ${GO_FILTERS_DIR}/golang-filter_${TARGET_ARCH}.so ${OUTPUT_PACKAGE_DIR}
|
||||||
GO_FILTER_NAME=${name} GOARCH=${TARGET_ARCH} make build
|
|
||||||
cp ${GO_FILTERS_DIR}/${file}/golang-filter_${TARGET_ARCH}.so ${OUTPUT_PACKAGE_DIR}
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
else
|
|
||||||
echo "🚀 Build Go Filter: $INNER_GO_FILTER_NAME"
|
|
||||||
GO_FILTER_NAME=${INNER_GO_FILTER_NAME} make build
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user