diff --git a/plugins/wasm-go/extensions/jsonrpc-converter/go.mod b/plugins/wasm-go/extensions/jsonrpc-converter/go.mod index 5dda4fe1f..d8a9d6ff5 100644 --- a/plugins/wasm-go/extensions/jsonrpc-converter/go.mod +++ b/plugins/wasm-go/extensions/jsonrpc-converter/go.mod @@ -1,10 +1,14 @@ module jsonrpc-converter -go 1.24.3 +go 1.24.1 + +replace github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp require ( - github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 - github.com/higress-group/wasm-go v1.0.4 + github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0 + github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 + github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 + github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.18.0 ) @@ -15,6 +19,7 @@ require ( github.com/Masterminds/sprig/v3 v3.3.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b // indirect github.com/huandu/xstrings v1.5.0 // indirect @@ -22,8 +27,10 @@ require ( github.com/mailru/easyjson v0.7.7 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/shopspring/decimal v1.4.0 // indirect github.com/spf13/cast v1.7.0 // indirect + github.com/tetratelabs/wazero v1.7.2 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/resp v0.1.1 // indirect diff --git a/plugins/wasm-go/extensions/jsonrpc-converter/go.sum b/plugins/wasm-go/extensions/jsonrpc-converter/go.sum index 21b93ff13..2e09ad1d9 100644 --- a/plugins/wasm-go/extensions/jsonrpc-converter/go.sum +++ b/plugins/wasm-go/extensions/jsonrpc-converter/go.sum @@ -20,10 +20,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b h1:rRI9+ThQbe+nw4jUiYEyOFaREkXCMMW9k1X2gy2d6pE= github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= -github.com/higress-group/wasm-go v1.0.4 h1:/GqbzCw4oWqJc8UbKEfF94E3/+4CPZGbzxpKo2L3Ldk= -github.com/higress-group/wasm-go v1.0.4/go.mod h1:B8C6+OlpnyYyZUBEdUXA7tYZYD+uwZTNjfkE5FywA+A= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= +github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 h1:sUuUXZwr50l3W1St7MESlFmxmUAu+QUNNfJXx4P6bas= +github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= @@ -49,6 +49,8 @@ github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w= github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc= +github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= diff --git a/plugins/wasm-go/extensions/jsonrpc-converter/main.go b/plugins/wasm-go/extensions/jsonrpc-converter/main.go index dccd9e2dd..616390b55 100644 --- a/plugins/wasm-go/extensions/jsonrpc-converter/main.go +++ b/plugins/wasm-go/extensions/jsonrpc-converter/main.go @@ -9,8 +9,8 @@ import ( "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/log" - "github.com/higress-group/wasm-go/pkg/mcp" - "github.com/higress-group/wasm-go/pkg/mcp/utils" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" ) diff --git a/plugins/wasm-go/extensions/jsonrpc-converter/main_test.go b/plugins/wasm-go/extensions/jsonrpc-converter/main_test.go index dc5ff7764..506acdae6 100644 --- a/plugins/wasm-go/extensions/jsonrpc-converter/main_test.go +++ b/plugins/wasm-go/extensions/jsonrpc-converter/main_test.go @@ -1,9 +1,15 @@ package main import ( + "encoding/json" "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" ) +// TestTruncateString tests the truncateString function func TestTruncateString(t *testing.T) { tests := []struct { name string @@ -14,6 +20,8 @@ func TestTruncateString(t *testing.T) { {"Short String", "Higress Is an AI-Native API Gateway", 1000, "Higress Is an AI-Native API Gateway"}, {"Exact Length", "Higress Is an AI-Native API Gateway", 35, "Higress Is an AI-Native API Gateway"}, {"Truncated String", "Higress Is an AI-Native API Gateway", 20, "Higress Is...(truncated)...PI Gateway"}, + {"Empty String", "", 10, ""}, + {"Single Char", "A", 10, "A"}, } for _, tt := range tests { @@ -26,3 +34,248 @@ func TestTruncateString(t *testing.T) { }) } } + +// TestIsPreRequestStage tests the isPreRequestStage function +func TestIsPreRequestStage(t *testing.T) { + config := McpConverterConfig{Stage: ProcessRequest} + require.True(t, isPreRequestStage(config)) + + config = McpConverterConfig{Stage: ProcessResponse} + require.False(t, isPreRequestStage(config)) +} + +// TestIsPreResponseStage tests the isPreResponseStage function +func TestIsPreResponseStage(t *testing.T) { + config := McpConverterConfig{Stage: ProcessResponse} + require.True(t, isPreResponseStage(config)) + + config = McpConverterConfig{Stage: ProcessRequest} + require.False(t, isPreResponseStage(config)) +} + +// TestIsMethodAllowed tests the isMethodAllowed function +func TestIsMethodAllowed(t *testing.T) { + config := McpConverterConfig{AllowedMethods: []string{MethodToolList, MethodToolCall}} + + require.True(t, isMethodAllowed(config, MethodToolList)) + require.True(t, isMethodAllowed(config, MethodToolCall)) + require.False(t, isMethodAllowed(config, "invalid/method")) +} + +// TestConstants tests the constant values +func TestConstants(t *testing.T) { + require.Equal(t, "x-envoy-jsonrpc-id", JsonRpcId) + require.Equal(t, "x-envoy-jsonrpc-method", JsonRpcMethod) + require.Equal(t, "x-envoy-jsonrpc-params", JsonRpcParams) + require.Equal(t, "x-envoy-jsonrpc-result", JsonRpcResult) + require.Equal(t, "x-envoy-jsonrpc-error", JsonRpcError) + require.Equal(t, "x-envoy-mcp-tool-name", McpToolName) + require.Equal(t, "x-envoy-mcp-tool-arguments", McpToolArguments) + require.Equal(t, "x-envoy-mcp-tool-response", McpToolResponse) + require.Equal(t, "x-envoy-mcp-tool-error", McpToolError) + require.Equal(t, 4000, DefaultMaxHeaderLength) + require.Equal(t, "tools/list", MethodToolList) + require.Equal(t, "tools/call", MethodToolCall) + require.Equal(t, ProcessStage("request"), ProcessRequest) + require.Equal(t, ProcessStage("response"), ProcessResponse) +} + +// TestMcpConverterConfigDefaults tests config default values +func TestMcpConverterConfigDefaults(t *testing.T) { + config := McpConverterConfig{} + require.Equal(t, 0, config.MaxHeaderLength) + require.Equal(t, ProcessStage(""), config.Stage) + require.Nil(t, config.AllowedMethods) +} + +// TestProcessStage tests ProcessStage type +func TestProcessStage(t *testing.T) { + require.Equal(t, ProcessStage("request"), ProcessRequest) + require.Equal(t, ProcessStage("response"), ProcessResponse) +} + +// TestRemoveJsonRpcHeadersFunction tests removeJsonRpcHeaders function logic +func TestRemoveJsonRpcHeadersFunction(t *testing.T) { + headersToRemove := []string{ + JsonRpcId, + JsonRpcMethod, + JsonRpcParams, + JsonRpcResult, + McpToolName, + McpToolArguments, + McpToolResponse, + McpToolError, + } + require.Len(t, headersToRemove, 8) +} + +// TestTruncateStringLong tests truncation of very long strings +func TestTruncateStringLong(t *testing.T) { + longString := "" + for i := 0; i < 5000; i++ { + longString += "a" + } + config := McpConverterConfig{MaxHeaderLength: 1000} + result := truncateString(longString, config) + require.Contains(t, result, "...(truncated)...") + require.LessOrEqual(t, len(result), 1020) +} + +// TestTruncateStringWithSmallMaxLength tests truncation with small max length +func TestTruncateStringWithSmallMaxLength(t *testing.T) { + config := McpConverterConfig{MaxHeaderLength: 10} + result := truncateString("This is a very long string", config) + require.Contains(t, result, "...(truncated)...") +} + +// TestPluginInit tests plugin initialization +func TestPluginInit(t *testing.T) { + configBytes, _ := json.Marshal(McpConverterConfig{ + Stage: ProcessRequest, + MaxHeaderLength: DefaultMaxHeaderLength, + AllowedMethods: []string{MethodToolList, MethodToolCall}, + }) + + host, status := test.NewTestHost(configBytes) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) +} + +// TestProcessJsonRpcRequest tests processJsonRpcRequest function +func TestProcessJsonRpcRequest(t *testing.T) { + configBytes, _ := json.Marshal(McpConverterConfig{ + Stage: ProcessRequest, + MaxHeaderLength: DefaultMaxHeaderLength, + AllowedMethods: []string{MethodToolList, MethodToolCall}, + }) + + host, status := test.NewTestHost(configBytes) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "mcp-server.example.com"}, + {":method", "POST"}, + {":path", "/mcp"}, + {"content-type", "application/json"}, + }) + + toolsListRequest := `{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list", + "params": {} + }` + action := host.CallOnHttpRequestBody([]byte(toolsListRequest)) + require.Equal(t, types.ActionContinue, action) + + host.CompleteHttp() +} + +// TestProcessToolCallRequest tests processToolCallRequest function +func TestProcessToolCallRequest(t *testing.T) { + configBytes, _ := json.Marshal(McpConverterConfig{ + Stage: ProcessRequest, + MaxHeaderLength: DefaultMaxHeaderLength, + AllowedMethods: []string{MethodToolCall}, + }) + + host, status := test.NewTestHost(configBytes) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "mcp-server.example.com"}, + {":method", "POST"}, + {":path", "/mcp"}, + {"content-type", "application/json"}, + }) + + toolCallRequest := `{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "test_tool", + "arguments": {"arg1": "value1"} + } + }` + action := host.CallOnHttpRequestBody([]byte(toolCallRequest)) + require.Equal(t, types.ActionContinue, action) + + host.CompleteHttp() +} + +// TestProcessJsonRpcResponse tests processJsonRpcResponse function +func TestProcessJsonRpcResponse(t *testing.T) { + configBytes, _ := json.Marshal(McpConverterConfig{ + Stage: ProcessResponse, + MaxHeaderLength: DefaultMaxHeaderLength, + AllowedMethods: []string{MethodToolList, MethodToolCall}, + }) + + host, status := test.NewTestHost(configBytes) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "mcp-server.example.com"}, + {":method", "POST"}, + {":path", "/mcp"}, + {"content-type", "application/json"}, + }) + + responseBody := `{ + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [{"name": "test_tool"}] + } + }` + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }) + host.CallOnHttpResponseBody([]byte(responseBody)) + + host.CompleteHttp() +} + +// TestProcessToolListResponse tests processToolListResponse function +func TestProcessToolListResponse(t *testing.T) { + configBytes, _ := json.Marshal(McpConverterConfig{ + Stage: ProcessResponse, + MaxHeaderLength: DefaultMaxHeaderLength, + AllowedMethods: []string{MethodToolList}, + }) + + host, status := test.NewTestHost(configBytes) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "mcp-server.example.com"}, + {":method", "POST"}, + {":path", "/mcp"}, + {"content-type", "application/json"}, + }) + + responseBody := `{ + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [{"name": "test_tool"}] + } + }` + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }) + host.CallOnHttpResponseBody([]byte(responseBody)) + + host.CompleteHttp() +} diff --git a/plugins/wasm-go/mcp-filters/mcp-router/README.md b/plugins/wasm-go/extensions/mcp-router/README.md similarity index 100% rename from plugins/wasm-go/mcp-filters/mcp-router/README.md rename to plugins/wasm-go/extensions/mcp-router/README.md diff --git a/plugins/wasm-go/mcp-filters/mcp-router/README_ZH.md b/plugins/wasm-go/extensions/mcp-router/README_ZH.md similarity index 100% rename from plugins/wasm-go/mcp-filters/mcp-router/README_ZH.md rename to plugins/wasm-go/extensions/mcp-router/README_ZH.md diff --git a/plugins/wasm-go/mcp-filters/mcp-router/go.mod b/plugins/wasm-go/extensions/mcp-router/go.mod similarity index 80% rename from plugins/wasm-go/mcp-filters/mcp-router/go.mod rename to plugins/wasm-go/extensions/mcp-router/go.mod index 02a41e492..5eebdf1d1 100644 --- a/plugins/wasm-go/mcp-filters/mcp-router/go.mod +++ b/plugins/wasm-go/extensions/mcp-router/go.mod @@ -2,9 +2,12 @@ module mcp-router go 1.24.1 +replace github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp + require ( - github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 - github.com/higress-group/wasm-go v1.0.2-0.20250911113549-cbf1cfcce774 + github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0 + github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 + github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 ) diff --git a/plugins/wasm-go/mcp-filters/mcp-router/go.sum b/plugins/wasm-go/extensions/mcp-router/go.sum similarity index 89% rename from plugins/wasm-go/mcp-filters/mcp-router/go.sum rename to plugins/wasm-go/extensions/mcp-router/go.sum index 76296cc9d..06b699b85 100644 --- a/plugins/wasm-go/mcp-filters/mcp-router/go.sum +++ b/plugins/wasm-go/extensions/mcp-router/go.sum @@ -20,12 +20,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b h1:rRI9+ThQbe+nw4jUiYEyOFaREkXCMMW9k1X2gy2d6pE= github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= -github.com/higress-group/wasm-go v1.0.2-0.20250807064511-eb1cd98e1f57 h1:WhNdnKSDtAQrh4Yil8HAtbl7VW+WC85m7WS8kirnHAA= -github.com/higress-group/wasm-go v1.0.2-0.20250807064511-eb1cd98e1f57/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M= -github.com/higress-group/wasm-go v1.0.2-0.20250911113549-cbf1cfcce774 h1:2wlbNpFJCQNbPBFYgswz7Zvxo9O3L0PH0AJxwiCc5lk= -github.com/higress-group/wasm-go v1.0.2-0.20250911113549-cbf1cfcce774/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= +github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 h1:sUuUXZwr50l3W1St7MESlFmxmUAu+QUNNfJXx4P6bas= +github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= diff --git a/plugins/wasm-go/mcp-filters/mcp-router/main.go b/plugins/wasm-go/extensions/mcp-router/main.go similarity index 97% rename from plugins/wasm-go/mcp-filters/mcp-router/main.go rename to plugins/wasm-go/extensions/mcp-router/main.go index 49003be9c..95e4c3e82 100644 --- a/plugins/wasm-go/mcp-filters/mcp-router/main.go +++ b/plugins/wasm-go/extensions/mcp-router/main.go @@ -22,8 +22,8 @@ import ( "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/log" - "github.com/higress-group/wasm-go/pkg/mcp" - "github.com/higress-group/wasm-go/pkg/mcp/consts" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/consts" "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" "github.com/tidwall/sjson" diff --git a/plugins/wasm-go/mcp-servers/all-in-one/go.mod b/plugins/wasm-go/extensions/mcp-server/go.mod similarity index 86% rename from plugins/wasm-go/mcp-servers/all-in-one/go.mod rename to plugins/wasm-go/extensions/mcp-server/go.mod index e4aeeff16..d94903e97 100644 --- a/plugins/wasm-go/mcp-servers/all-in-one/go.mod +++ b/plugins/wasm-go/extensions/mcp-server/go.mod @@ -1,13 +1,16 @@ -module all-in-one +module mcp-server go 1.24.1 -replace quark-search => ../quark-search - -replace amap-tools => ../amap-tools +replace ( + amap-tools => ../../mcp-servers/amap-tools + github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp + quark-search => ../../mcp-servers/quark-search +) require ( amap-tools v0.0.0-00010101000000-000000000000 + github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 github.com/stretchr/testify v1.9.0 diff --git a/plugins/wasm-go/mcp-servers/all-in-one/go.sum b/plugins/wasm-go/extensions/mcp-server/go.sum similarity index 93% rename from plugins/wasm-go/mcp-servers/all-in-one/go.sum rename to plugins/wasm-go/extensions/mcp-server/go.sum index 8f0302c07..2e09ad1d9 100644 --- a/plugins/wasm-go/mcp-servers/all-in-one/go.sum +++ b/plugins/wasm-go/extensions/mcp-server/go.sum @@ -22,10 +22,6 @@ github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b h1:rR github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= -github.com/higress-group/wasm-go v1.0.9-0.20251223122142-eae11e33a500 h1:4BKKZ3BreIaIGub88nlvzihTK1uJmZYYoQ7r7Xkgb5Q= -github.com/higress-group/wasm-go v1.0.9-0.20251223122142-eae11e33a500/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8= -github.com/higress-group/wasm-go v1.0.10-0.20260115083526-76699a1df2c1 h1:+usoX0B1cwECTA2qf73IaLGyCIMVopIMev5cBWGgEZk= -github.com/higress-group/wasm-go v1.0.10-0.20260115083526-76699a1df2c1/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8= github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 h1:sUuUXZwr50l3W1St7MESlFmxmUAu+QUNNfJXx4P6bas= github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= diff --git a/plugins/wasm-go/mcp-servers/all-in-one/main.go b/plugins/wasm-go/extensions/mcp-server/main.go similarity index 94% rename from plugins/wasm-go/mcp-servers/all-in-one/main.go rename to plugins/wasm-go/extensions/mcp-server/main.go index 5451bb5fe..e4bdfb4b5 100644 --- a/plugins/wasm-go/mcp-servers/all-in-one/main.go +++ b/plugins/wasm-go/extensions/mcp-server/main.go @@ -18,7 +18,7 @@ import ( amap "amap-tools/tools" quark "quark-search/tools" - "github.com/higress-group/wasm-go/pkg/mcp" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp" ) func main() {} diff --git a/plugins/wasm-go/mcp-servers/all-in-one/main_test.go b/plugins/wasm-go/extensions/mcp-server/main_test.go similarity index 100% rename from plugins/wasm-go/mcp-servers/all-in-one/main_test.go rename to plugins/wasm-go/extensions/mcp-server/main_test.go diff --git a/plugins/wasm-go/mcp-filters/Dockerfile b/plugins/wasm-go/mcp-filters/Dockerfile deleted file mode 100644 index eab97b499..000000000 --- a/plugins/wasm-go/mcp-filters/Dockerfile +++ /dev/null @@ -1,14 +0,0 @@ -# Use a minimal base image as we only need to store the wasm file. -FROM scratch - -# Add build argument for the filter name. This will be passed by the Makefile. -ARG FILTER_NAME - -# Copy the compiled WASM binary into the image's root directory. -# The wasm file will be named after the filter. -COPY ${FILTER_NAME}/main.wasm /plugin.wasm - -# Metadata -LABEL org.opencontainers.image.title="${FILTER_NAME}" -LABEL org.opencontainers.image.description="Higress MCP filter - ${FILTER_NAME}" -LABEL org.opencontainers.image.source="https://github.com/alibaba/higress" \ No newline at end of file diff --git a/plugins/wasm-go/mcp-filters/Makefile b/plugins/wasm-go/mcp-filters/Makefile deleted file mode 100644 index ee41eb596..000000000 --- a/plugins/wasm-go/mcp-filters/Makefile +++ /dev/null @@ -1,54 +0,0 @@ -# MCP Filter Makefile - -# Variables -FILTER_NAME ?= mcp-router -REGISTRY ?= higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/ -BUILD_TIME := $(shell date "+%Y%m%d-%H%M%S") -COMMIT_ID := $(shell git rev-parse --short HEAD 2>/dev/null) -IMAGE_TAG = $(if $(strip $(FILTER_VERSION)),${FILTER_VERSION},${BUILD_TIME}-${COMMIT_ID}) -IMG ?= ${REGISTRY}${FILTER_NAME}:${IMAGE_TAG} - -# Default target -.DEFAULT: build - -build: - @echo "Building WASM binary for filter: ${FILTER_NAME}..." - @if [ ! -d "${FILTER_NAME}" ]; then \ - echo "Error: Filter directory '${FILTER_NAME}' not found."; \ - exit 1; \ - fi - cd ${FILTER_NAME} && GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o main.wasm main.go - @echo "" - @echo "Output WASM file: ${FILTER_NAME}/main.wasm" - -# Build Docker image (depends on build target to ensure WASM binary exists) -build-image: build - @echo "Building Docker image for ${FILTER_NAME}..." - docker build -t ${IMG} \ - --build-arg FILTER_NAME=${FILTER_NAME} \ - -f Dockerfile . - @echo "" - @echo "Image: ${IMG}" - -# Build and push Docker image -build-push: build-image - docker push ${IMG} - -# Clean build artifacts -clean: - @echo "Cleaning build artifacts for filter: ${FILTER_NAME}..." - rm -f ${FILTER_NAME}/main.wasm - -# Help -help: - @echo "Available targets:" - @echo " build - Build WASM binary for a specific filter" - @echo " build-image - Build Docker image" - @echo " build-push - Build and push Docker image" - @echo " clean - Remove build artifacts for a specific filter" - @echo "" - @echo "Variables:" - @echo " FILTER_NAME - Name of the MCP filter to build (default: ${FILTER_NAME})" - @echo " REGISTRY - Docker registry (default: ${REGISTRY})" - @echo " FILTER_VERSION - Version tag for the image (default: timestamp-commit)" - @echo " IMG - Full image name (default: ${IMG})" diff --git a/plugins/wasm-go/mcp-servers/README.md b/plugins/wasm-go/mcp-servers/README.md index 66e784954..dff79f8c0 100644 --- a/plugins/wasm-go/mcp-servers/README.md +++ b/plugins/wasm-go/mcp-servers/README.md @@ -80,8 +80,8 @@ import ( "net/http" "my-mcp-server/config" - "github.com/higress-group/wasm-go/pkg/mcp/server" - "github.com/higress-group/wasm-go/pkg/mcp/utils" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" ) // Define your tool structure with input parameters @@ -145,8 +145,8 @@ For better organization, you can create a separate file to load all your tools: package tools import ( - "github.com/higress-group/wasm-go/pkg/mcp" - "github.com/higress-group/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" ) func LoadTools(server *mcp.MCPServer) server.Server { @@ -170,7 +170,7 @@ import ( amap "amap-tools/tools" quark "quark-search/tools" - "github.com/higress-group/wasm-go/pkg/mcp" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp" ) func main() {} @@ -375,7 +375,7 @@ package main import ( "my-mcp-server/tools" - "github.com/higress-group/wasm-go/pkg/mcp" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp" ) func main() {} diff --git a/plugins/wasm-go/mcp-servers/amap-tools/go.mod b/plugins/wasm-go/mcp-servers/amap-tools/go.mod index 89c466607..5b3b54344 100644 --- a/plugins/wasm-go/mcp-servers/amap-tools/go.mod +++ b/plugins/wasm-go/mcp-servers/amap-tools/go.mod @@ -2,9 +2,12 @@ module amap-tools go 1.24.1 +replace github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp + require ( - github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 - github.com/higress-group/wasm-go v1.0.0 + github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0 + github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 + github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 ) require ( @@ -23,6 +26,7 @@ require ( github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/shopspring/decimal v1.4.0 // indirect github.com/spf13/cast v1.7.0 // indirect + github.com/tetratelabs/wazero v1.7.2 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect diff --git a/plugins/wasm-go/mcp-servers/amap-tools/main.go b/plugins/wasm-go/mcp-servers/amap-tools/main.go index d5f46bca4..c0c698c21 100644 --- a/plugins/wasm-go/mcp-servers/amap-tools/main.go +++ b/plugins/wasm-go/mcp-servers/amap-tools/main.go @@ -17,7 +17,7 @@ package main import ( "amap-tools/tools" - "github.com/higress-group/wasm-go/pkg/mcp" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp" ) func main() {} diff --git a/plugins/wasm-go/mcp-servers/amap-tools/tools/load_tools.go b/plugins/wasm-go/mcp-servers/amap-tools/tools/load_tools.go index ffb71e1dc..66ce3b4e5 100644 --- a/plugins/wasm-go/mcp-servers/amap-tools/tools/load_tools.go +++ b/plugins/wasm-go/mcp-servers/amap-tools/tools/load_tools.go @@ -15,8 +15,8 @@ package tools import ( - "github.com/higress-group/wasm-go/pkg/mcp" - "github.com/higress-group/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" ) func LoadTools(server *mcp.MCPServer) server.Server { diff --git a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_around_search.go b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_around_search.go index a477a4a64..a7da9f895 100644 --- a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_around_search.go +++ b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_around_search.go @@ -23,8 +23,8 @@ import ( "amap-tools/config" - "github.com/higress-group/wasm-go/pkg/mcp/server" - "github.com/higress-group/wasm-go/pkg/mcp/utils" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" ) var _ server.Tool = AroundSearchRequest{} diff --git a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_bicycling.go b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_bicycling.go index fd65723ab..1d47e7770 100644 --- a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_bicycling.go +++ b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_bicycling.go @@ -23,8 +23,8 @@ import ( "amap-tools/config" - "github.com/higress-group/wasm-go/pkg/mcp/server" - "github.com/higress-group/wasm-go/pkg/mcp/utils" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" ) var _ server.Tool = BicyclingRequest{} diff --git a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_direction_driving.go b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_direction_driving.go index 12530c11c..defc8f51f 100644 --- a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_direction_driving.go +++ b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_direction_driving.go @@ -23,8 +23,8 @@ import ( "amap-tools/config" - "github.com/higress-group/wasm-go/pkg/mcp/server" - "github.com/higress-group/wasm-go/pkg/mcp/utils" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" ) var _ server.Tool = DrivingRequest{} diff --git a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_direction_transit_integrated.go b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_direction_transit_integrated.go index 65f768c38..feab0c4b0 100644 --- a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_direction_transit_integrated.go +++ b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_direction_transit_integrated.go @@ -23,8 +23,8 @@ import ( "amap-tools/config" - "github.com/higress-group/wasm-go/pkg/mcp/server" - "github.com/higress-group/wasm-go/pkg/mcp/utils" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" ) var _ server.Tool = TransitIntegratedRequest{} diff --git a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_direction_walking.go b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_direction_walking.go index 53a059ffd..349b34616 100644 --- a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_direction_walking.go +++ b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_direction_walking.go @@ -23,8 +23,8 @@ import ( "amap-tools/config" - "github.com/higress-group/wasm-go/pkg/mcp/server" - "github.com/higress-group/wasm-go/pkg/mcp/utils" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" ) var _ server.Tool = WalkingRequest{} diff --git a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_distance.go b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_distance.go index 217fd23d1..f9bc23b0e 100644 --- a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_distance.go +++ b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_distance.go @@ -23,8 +23,8 @@ import ( "amap-tools/config" - "github.com/higress-group/wasm-go/pkg/mcp/server" - "github.com/higress-group/wasm-go/pkg/mcp/utils" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" ) var _ server.Tool = DistanceRequest{} diff --git a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_geo.go b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_geo.go index 78cce2ce1..ca04fd6a2 100644 --- a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_geo.go +++ b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_geo.go @@ -23,8 +23,8 @@ import ( "amap-tools/config" - "github.com/higress-group/wasm-go/pkg/mcp/server" - "github.com/higress-group/wasm-go/pkg/mcp/utils" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" ) var _ server.Tool = GeoRequest{} diff --git a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_ip_location.go b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_ip_location.go index d1ff84d81..6ba6de50c 100644 --- a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_ip_location.go +++ b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_ip_location.go @@ -24,8 +24,8 @@ import ( "amap-tools/config" - "github.com/higress-group/wasm-go/pkg/mcp/server" - "github.com/higress-group/wasm-go/pkg/mcp/utils" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" ) diff --git a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_regeocode.go b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_regeocode.go index d1fb9e297..3d403453a 100644 --- a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_regeocode.go +++ b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_regeocode.go @@ -23,8 +23,8 @@ import ( "amap-tools/config" - "github.com/higress-group/wasm-go/pkg/mcp/server" - "github.com/higress-group/wasm-go/pkg/mcp/utils" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" ) var _ server.Tool = ReGeocodeRequest{} diff --git a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_search_detail.go b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_search_detail.go index cf2c7042d..a83980349 100644 --- a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_search_detail.go +++ b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_search_detail.go @@ -23,8 +23,8 @@ import ( "amap-tools/config" - "github.com/higress-group/wasm-go/pkg/mcp/server" - "github.com/higress-group/wasm-go/pkg/mcp/utils" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" ) var _ server.Tool = SearchDetailRequest{} diff --git a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_text_search.go b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_text_search.go index d70d76de2..c544f25b7 100644 --- a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_text_search.go +++ b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_text_search.go @@ -23,8 +23,8 @@ import ( "amap-tools/config" - "github.com/higress-group/wasm-go/pkg/mcp/server" - "github.com/higress-group/wasm-go/pkg/mcp/utils" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" ) var _ server.Tool = TextSearchRequest{} diff --git a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_weather.go b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_weather.go index 114235d25..656a8e2ac 100644 --- a/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_weather.go +++ b/plugins/wasm-go/mcp-servers/amap-tools/tools/maps_weather.go @@ -23,8 +23,8 @@ import ( "amap-tools/config" - "github.com/higress-group/wasm-go/pkg/mcp/server" - "github.com/higress-group/wasm-go/pkg/mcp/utils" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" ) var _ server.Tool = WeatherRequest{} diff --git a/plugins/wasm-go/mcp-servers/quark-search/go.mod b/plugins/wasm-go/mcp-servers/quark-search/go.mod index ebf1df7a7..c757a26aa 100644 --- a/plugins/wasm-go/mcp-servers/quark-search/go.mod +++ b/plugins/wasm-go/mcp-servers/quark-search/go.mod @@ -2,8 +2,11 @@ module quark-search go 1.24.1 +replace github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp + require ( - github.com/higress-group/wasm-go v1.0.0 + github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0 + github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 github.com/tidwall/gjson v1.18.0 ) @@ -16,7 +19,7 @@ require ( github.com/buger/jsonparser v1.1.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b // indirect - github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 // indirect + github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 // indirect github.com/huandu/xstrings v1.5.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect @@ -24,6 +27,7 @@ require ( github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/shopspring/decimal v1.4.0 // indirect github.com/spf13/cast v1.7.0 // indirect + github.com/tetratelabs/wazero v1.7.2 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/resp v0.1.1 // indirect diff --git a/plugins/wasm-go/mcp-servers/quark-search/main.go b/plugins/wasm-go/mcp-servers/quark-search/main.go index 6c04f9bd5..3b728c5fd 100644 --- a/plugins/wasm-go/mcp-servers/quark-search/main.go +++ b/plugins/wasm-go/mcp-servers/quark-search/main.go @@ -17,7 +17,7 @@ package main import ( "quark-search/tools" - "github.com/higress-group/wasm-go/pkg/mcp" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp" ) func main() {} diff --git a/plugins/wasm-go/mcp-servers/quark-search/tools/load_tools.go b/plugins/wasm-go/mcp-servers/quark-search/tools/load_tools.go index a5f917ad5..49d062888 100644 --- a/plugins/wasm-go/mcp-servers/quark-search/tools/load_tools.go +++ b/plugins/wasm-go/mcp-servers/quark-search/tools/load_tools.go @@ -15,8 +15,8 @@ package tools import ( - "github.com/higress-group/wasm-go/pkg/mcp" - "github.com/higress-group/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" ) func LoadTools(server *mcp.MCPServer) server.Server { diff --git a/plugins/wasm-go/mcp-servers/quark-search/tools/web_search.go b/plugins/wasm-go/mcp-servers/quark-search/tools/web_search.go index b1ed35108..f1d516602 100644 --- a/plugins/wasm-go/mcp-servers/quark-search/tools/web_search.go +++ b/plugins/wasm-go/mcp-servers/quark-search/tools/web_search.go @@ -24,8 +24,8 @@ import ( "quark-search/config" - "github.com/higress-group/wasm-go/pkg/mcp/server" - "github.com/higress-group/wasm-go/pkg/mcp/utils" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" "github.com/tidwall/gjson" ) diff --git a/plugins/wasm-go/pkg/log/log.go b/plugins/wasm-go/pkg/log/log.go deleted file mode 100644 index 131c32991..000000000 --- a/plugins/wasm-go/pkg/log/log.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) 2022 Alibaba Group Holding Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package log - -type Log interface { - Trace(msg string) - Tracef(format string, args ...interface{}) - Debug(msg string) - Debugf(format string, args ...interface{}) - Info(msg string) - Infof(format string, args ...interface{}) - Warn(msg string) - Warnf(format string, args ...interface{}) - Error(msg string) - Errorf(format string, args ...interface{}) - Critical(msg string) - Criticalf(format string, args ...interface{}) - ResetID(pluginID string) -} - -var pluginLog Log - -func SetPluginLog(log Log) { - pluginLog = log -} - -func Trace(msg string) { - pluginLog.Trace(msg) -} - -func Tracef(format string, args ...interface{}) { - pluginLog.Tracef(format, args...) -} - -func Debug(msg string) { - pluginLog.Debug(msg) -} - -func Debugf(format string, args ...interface{}) { - pluginLog.Debugf(format, args...) -} - -func Info(msg string) { - pluginLog.Info(msg) -} - -func Infof(format string, args ...interface{}) { - pluginLog.Infof(format, args...) -} - -func Warn(msg string) { - pluginLog.Warn(msg) -} - -func Warnf(format string, args ...interface{}) { - pluginLog.Warnf(format, args...) -} - -func Error(msg string) { - pluginLog.Error(msg) -} - -func Errorf(format string, args ...interface{}) { - pluginLog.Errorf(format, args...) -} - -func Critical(msg string) { - pluginLog.Critical(msg) -} - -func Criticalf(format string, args ...interface{}) { - pluginLog.Criticalf(format, args...) -} diff --git a/plugins/wasm-go/pkg/matcher/rule_matcher.go b/plugins/wasm-go/pkg/matcher/rule_matcher.go deleted file mode 100644 index acc0bbb82..000000000 --- a/plugins/wasm-go/pkg/matcher/rule_matcher.go +++ /dev/null @@ -1,300 +0,0 @@ -// Copyright (c) 2022 Alibaba Group Holding Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package matcher - -import ( - "errors" - "fmt" - "strings" - - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "github.com/tidwall/gjson" -) - -type Category int - -const ( - Route Category = iota - Host - Service - RoutePrefix -) - -type MatchType int - -const ( - Prefix MatchType = iota - Exact - Suffix -) - -const ( - RULES_KEY = "_rules_" - MATCH_ROUTE_KEY = "_match_route_" - MATCH_DOMAIN_KEY = "_match_domain_" - MATCH_SERVICE_KEY = "_match_service_" - MATCH_ROUTE_PREFIX_KEY = "_match_route_prefix_" -) - -type HostMatcher struct { - matchType MatchType - host string -} - -type RuleConfig[PluginConfig any] struct { - category Category - routes map[string]struct{} - services map[string]struct{} - routePrefixs map[string]struct{} - hosts []HostMatcher - config PluginConfig -} - -type RuleMatcher[PluginConfig any] struct { - ruleConfig []RuleConfig[PluginConfig] - globalConfig PluginConfig - hasGlobalConfig bool -} - -func (m RuleMatcher[PluginConfig]) GetMatchConfig() (*PluginConfig, error) { - host, err := proxywasm.GetHttpRequestHeader(":authority") - if err != nil { - return nil, err - } - routeName, err := proxywasm.GetProperty([]string{"route_name"}) - if err != nil && err != types.ErrorStatusNotFound { - return nil, err - } - serviceName, err := proxywasm.GetProperty([]string{"cluster_name"}) - if err != nil && err != types.ErrorStatusNotFound { - return nil, err - } - for _, rule := range m.ruleConfig { - // category == Host - if rule.category == Host { - if m.hostMatch(rule, host) { - return &rule.config, nil - } - } - // category == Route - if rule.category == Route { - if _, ok := rule.routes[string(routeName)]; ok { - return &rule.config, nil - } - } - // category == RoutePrefix - if rule.category == RoutePrefix { - for routePrefix := range rule.routePrefixs { - if strings.HasPrefix(string(routeName), routePrefix) { - return &rule.config, nil - } - } - } - // category == Cluster - if m.serviceMatch(rule, string(serviceName)) { - return &rule.config, nil - } - } - if m.hasGlobalConfig { - return &m.globalConfig, nil - } - return nil, nil -} - -func (m *RuleMatcher[PluginConfig]) ParseRuleConfig(config gjson.Result, - parsePluginConfig func(gjson.Result, *PluginConfig) error, - parseOverrideConfig func(gjson.Result, PluginConfig, *PluginConfig) error) error { - var rules []gjson.Result - obj := config.Map() - keyCount := len(obj) - if keyCount == 0 { - // enable globally for empty config - m.hasGlobalConfig = true - return parsePluginConfig(config, &m.globalConfig) - } - if rulesJson, ok := obj[RULES_KEY]; ok { - rules = rulesJson.Array() - keyCount-- - } - var pluginConfig PluginConfig - var globalConfigError error - if keyCount > 0 { - err := parsePluginConfig(config, &pluginConfig) - if err != nil { - globalConfigError = err - } else { - m.globalConfig = pluginConfig - m.hasGlobalConfig = true - } - } - if len(rules) == 0 { - if m.hasGlobalConfig { - return nil - } - return fmt.Errorf("parse config failed, no valid rules; global config parse error:%v", globalConfigError) - } - for _, ruleJson := range rules { - var ( - rule RuleConfig[PluginConfig] - err error - ) - if parseOverrideConfig != nil { - err = parseOverrideConfig(ruleJson, m.globalConfig, &rule.config) - } else { - err = parsePluginConfig(ruleJson, &rule.config) - } - if err != nil { - return err - } - rule.routes = m.parseRouteMatchConfig(ruleJson) - rule.hosts = m.parseHostMatchConfig(ruleJson) - rule.services = m.parseServiceMatchConfig(ruleJson) - rule.routePrefixs = m.parseRoutePrefixMatchConfig(ruleJson) - noRoute := len(rule.routes) == 0 - noHosts := len(rule.hosts) == 0 - noService := len(rule.services) == 0 - noRoutePrefix := len(rule.routePrefixs) == 0 - if boolToInt(noRoute)+boolToInt(noService)+boolToInt(noHosts)+boolToInt(noRoutePrefix) != 3 { - return errors.New("there is only one of '_match_route_', '_match_domain_', '_match_service_' and '_match_route_prefix_' can present in configuration.") - } - if !noRoute { - rule.category = Route - } else if !noHosts { - rule.category = Host - } else if !noService { - rule.category = Service - } else { - rule.category = RoutePrefix - } - m.ruleConfig = append(m.ruleConfig, rule) - } - return nil -} - -func (m RuleMatcher[PluginConfig]) parseRouteMatchConfig(config gjson.Result) map[string]struct{} { - keys := config.Get(MATCH_ROUTE_KEY).Array() - routes := make(map[string]struct{}) - for _, item := range keys { - routeName := item.String() - if routeName != "" { - routes[routeName] = struct{}{} - } - } - return routes -} - -func (m RuleMatcher[PluginConfig]) parseRoutePrefixMatchConfig(config gjson.Result) map[string]struct{} { - keys := config.Get(MATCH_ROUTE_PREFIX_KEY).Array() - routePrefixs := make(map[string]struct{}) - for _, item := range keys { - routePrefix := item.String() - if routePrefix != "" { - routePrefixs[routePrefix] = struct{}{} - } - } - return routePrefixs -} - -func (m RuleMatcher[PluginConfig]) parseServiceMatchConfig(config gjson.Result) map[string]struct{} { - keys := config.Get(MATCH_SERVICE_KEY).Array() - clusters := make(map[string]struct{}) - for _, item := range keys { - clusterName := item.String() - if clusterName != "" { - clusters[clusterName] = struct{}{} - } - } - return clusters -} - -func (m RuleMatcher[PluginConfig]) parseHostMatchConfig(config gjson.Result) []HostMatcher { - keys := config.Get(MATCH_DOMAIN_KEY).Array() - var hostMatchers []HostMatcher - for _, item := range keys { - host := item.String() - var hostMatcher HostMatcher - if strings.HasPrefix(host, "*") { - hostMatcher.matchType = Suffix - hostMatcher.host = host[1:] - } else if strings.HasSuffix(host, "*") { - hostMatcher.matchType = Prefix - hostMatcher.host = host[:len(host)-1] - } else { - hostMatcher.matchType = Exact - hostMatcher.host = host - } - hostMatchers = append(hostMatchers, hostMatcher) - } - return hostMatchers -} - -func stripPortFromHost(reqHost string) string { - // Port removing code is inspired by - // https://github.com/envoyproxy/envoy/blob/v1.17.0/source/common/http/header_utility.cc#L219 - portStart := strings.LastIndexByte(reqHost, ':') - if portStart != -1 { - // According to RFC3986 v6 address is always enclosed in "[]". - // section 3.2.2. - v6EndIndex := strings.LastIndexByte(reqHost, ']') - if v6EndIndex == -1 || v6EndIndex < portStart { - if portStart+1 <= len(reqHost) { - return reqHost[:portStart] - } - } - } - return reqHost -} - -func (m RuleMatcher[PluginConfig]) hostMatch(rule RuleConfig[PluginConfig], reqHost string) bool { - reqHost = stripPortFromHost(reqHost) - for _, hostMatch := range rule.hosts { - switch hostMatch.matchType { - case Suffix: - if strings.HasSuffix(reqHost, hostMatch.host) { - return true - } - case Prefix: - if strings.HasPrefix(reqHost, hostMatch.host) { - return true - } - case Exact: - if reqHost == hostMatch.host { - return true - } - default: - return false - } - } - return false -} - -func (m RuleMatcher[PluginConfig]) serviceMatch(rule RuleConfig[PluginConfig], serviceName string) bool { - parts := strings.Split(serviceName, "|") - if len(parts) != 4 { - return false - } - port := parts[1] - fqdn := parts[3] - for configServiceName := range rule.services { - colonIndex := strings.LastIndexByte(configServiceName, ':') - if colonIndex != -1 && fqdn == string(configServiceName[:colonIndex]) && port == string(configServiceName[colonIndex+1:]) { - return true - } else if fqdn == string(configServiceName) { - return true - } - } - return false -} diff --git a/plugins/wasm-go/pkg/matcher/rule_matcher_test.go b/plugins/wasm-go/pkg/matcher/rule_matcher_test.go deleted file mode 100644 index 460f414aa..000000000 --- a/plugins/wasm-go/pkg/matcher/rule_matcher_test.go +++ /dev/null @@ -1,438 +0,0 @@ -// Copyright (c) 2022 Alibaba Group Holding Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package matcher - -import ( - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/tidwall/gjson" -) - -type customConfig struct { - name string - age int64 -} - -func parseConfig(json gjson.Result, config *customConfig) error { - config.name = json.Get("name").String() - config.age = json.Get("age").Int() - return nil -} - -func TestHostMatch(t *testing.T) { - cases := []struct { - name string - config RuleConfig[customConfig] - host string - result bool - }{ - { - name: "prefix", - config: RuleConfig[customConfig]{ - hosts: []HostMatcher{ - { - matchType: Prefix, - host: "www.", - }, - }, - }, - host: "www.test.com", - result: true, - }, - { - name: "prefix failed", - config: RuleConfig[customConfig]{ - hosts: []HostMatcher{ - { - matchType: Prefix, - host: "www.", - }, - }, - }, - host: "test.com", - result: false, - }, - { - name: "suffix", - config: RuleConfig[customConfig]{ - hosts: []HostMatcher{ - { - matchType: Suffix, - host: ".example.com", - }, - }, - }, - host: "www.example.com", - result: true, - }, - { - name: "suffix failed", - config: RuleConfig[customConfig]{ - hosts: []HostMatcher{ - { - matchType: Suffix, - host: ".example.com", - }, - }, - }, - host: "example.com", - result: false, - }, - { - name: "exact", - config: RuleConfig[customConfig]{ - hosts: []HostMatcher{ - { - matchType: Exact, - host: "www.example.com", - }, - }, - }, - host: "www.example.com", - result: true, - }, - { - name: "exact failed", - config: RuleConfig[customConfig]{ - hosts: []HostMatcher{ - { - matchType: Exact, - host: "www.example.com", - }, - }, - }, - host: "example.com", - result: false, - }, - { - name: "exact port", - config: RuleConfig[customConfig]{ - hosts: []HostMatcher{ - { - matchType: Exact, - host: "www.example.com", - }, - }, - }, - host: "www.example.com:8080", - result: true, - }, - { - name: "any", - config: RuleConfig[customConfig]{ - hosts: []HostMatcher{ - { - matchType: Suffix, - host: "", - }, - }, - }, - host: "www.example.com", - result: true, - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - var m RuleMatcher[customConfig] - assert.Equal(t, c.result, m.hostMatch(c.config, c.host)) - }) - } -} - -func TestServiceMatch(t *testing.T) { - cases := []struct { - name string - config RuleConfig[customConfig] - service string - result bool - }{ - { - name: "fqdn", - config: RuleConfig[customConfig]{ - services: map[string]struct{}{ - "qwen.dns": {}, - }, - }, - service: "outbound|443||qwen.dns", - result: true, - }, - { - name: "fqdn with port", - config: RuleConfig[customConfig]{ - services: map[string]struct{}{ - "qwen.dns:443": {}, - }, - }, - service: "outbound|443||qwen.dns", - result: true, - }, - { - name: "not match", - config: RuleConfig[customConfig]{ - services: map[string]struct{}{ - "moonshot.dns:443": {}, - }, - }, - service: "outbound|443||qwen.dns", - result: false, - }, - { - name: "error config format", - config: RuleConfig[customConfig]{ - services: map[string]struct{}{ - "qwen.dns:": {}, - }, - }, - service: "outbound|443||qwen.dns", - result: false, - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - var m RuleMatcher[customConfig] - assert.Equal(t, c.result, m.serviceMatch(c.config, c.service)) - }) - } -} - -func TestParseRuleConfig(t *testing.T) { - cases := []struct { - name string - config string - errMsg string - expected RuleMatcher[customConfig] - }{ - { - name: "global config", - config: `{"name":"john", "age":18}`, - expected: RuleMatcher[customConfig]{ - globalConfig: customConfig{ - name: "john", - age: 18, - }, - hasGlobalConfig: true, - }, - }, - { - name: "rules config", - config: `{"_rules_":[{"_match_domain_":["*.example.com","www.*","*","www.abc.com"],"name":"john", "age":18},{"_match_route_":["test1","test2"],"name":"ann", "age":16},{"_match_service_":["test1.dns","test2.static:8080"],"name":"ann", "age":16},{"_match_route_prefix_":["api1","api2"],"name":"ann", "age":16}]}`, - expected: RuleMatcher[customConfig]{ - ruleConfig: []RuleConfig[customConfig]{ - { - category: Host, - hosts: []HostMatcher{ - { - matchType: Suffix, - host: ".example.com", - }, - { - matchType: Prefix, - host: "www.", - }, - { - matchType: Suffix, - host: "", - }, - { - matchType: Exact, - host: "www.abc.com", - }, - }, - routes: map[string]struct{}{}, - services: map[string]struct{}{}, - routePrefixs: map[string]struct{}{}, - config: customConfig{ - name: "john", - age: 18, - }, - }, - { - category: Route, - routes: map[string]struct{}{ - "test1": {}, - "test2": {}, - }, - services: map[string]struct{}{}, - routePrefixs: map[string]struct{}{}, - config: customConfig{ - name: "ann", - age: 16, - }, - }, - { - category: Service, - routes: map[string]struct{}{}, - services: map[string]struct{}{ - "test1.dns": {}, - "test2.static:8080": {}, - }, - routePrefixs: map[string]struct{}{}, - config: customConfig{ - name: "ann", - age: 16, - }, - }, - { - category: RoutePrefix, - routes: map[string]struct{}{}, - services: map[string]struct{}{}, - routePrefixs: map[string]struct{}{ - "api1": {}, - "api2": {}, - }, - config: customConfig{ - name: "ann", - age: 16, - }, - }, - }, - }, - }, - { - name: "no rule", - config: `{"_rules_":[]}`, - errMsg: "parse config failed, no valid rules; global config parse error:", - }, - { - name: "invalid rule", - config: `{"_rules_":[{"_match_domain_":["*"],"_match_route_":["test"]}]}`, - errMsg: "there is only one of '_match_route_', '_match_domain_', '_match_service_' and '_match_route_prefix_' can present in configuration.", - }, - { - name: "invalid rule", - config: `{"_rules_":[{"_match_domain_":["*"],"_match_service_":["test.dns"]}]}`, - errMsg: "there is only one of '_match_route_', '_match_domain_', '_match_service_' and '_match_route_prefix_' can present in configuration.", - }, - { - name: "invalid rule", - config: `{"_rules_":[{"age":16}]}`, - errMsg: "there is only one of '_match_route_', '_match_domain_', '_match_service_' and '_match_route_prefix_' can present in configuration.", - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - var actual RuleMatcher[customConfig] - err := actual.ParseRuleConfig(gjson.Parse(c.config), parseConfig, nil) - if err != nil { - if c.errMsg == "" { - t.Errorf("parse failed: %v", err) - } - if err.Error() != c.errMsg { - t.Errorf("expect err: %s, actual err: %s", c.errMsg, - err.Error()) - } - return - } - assert.Equal(t, c.expected, actual) - }) - } -} - -type completeConfig struct { - // global config - consumers []string - // rule config - allow []string -} - -func parseGlobalConfig(json gjson.Result, global *completeConfig) error { - if json.Get("consumers").Exists() && json.Get("allow").Exists() { - return errors.New("consumers and allow should not be configured at the same level") - } - - for _, item := range json.Get("consumers").Array() { - global.consumers = append(global.consumers, item.String()) - } - - return nil -} - -func parseOverrideRuleConfig(json gjson.Result, global completeConfig, config *completeConfig) error { - if json.Get("consumers").Exists() && json.Get("allow").Exists() { - return errors.New("consumers and allow should not be configured at the same level") - } - - // override config via global - *config = global - - for _, item := range json.Get("allow").Array() { - config.allow = append(config.allow, item.String()) - } - - return nil -} - -func TestParseOverrideConfig(t *testing.T) { - cases := []struct { - name string - config string - errMsg string - expected RuleMatcher[completeConfig] - }{ - { - name: "override rule config", - config: `{"consumers":["c1","c2","c3"],"_rules_":[{"_match_route_":["r1","r2"],"allow":["c1","c3"]}]}`, - expected: RuleMatcher[completeConfig]{ - ruleConfig: []RuleConfig[completeConfig]{ - { - category: Route, - routes: map[string]struct{}{ - "r1": {}, - "r2": {}, - }, - services: map[string]struct{}{}, - routePrefixs: map[string]struct{}{}, - config: completeConfig{ - consumers: []string{"c1", "c2", "c3"}, - allow: []string{"c1", "c3"}, - }, - }, - }, - globalConfig: completeConfig{ - consumers: []string{"c1", "c2", "c3"}, - }, - hasGlobalConfig: true, - }, - }, - { - name: "invalid config", - config: `{"consumers":["c1","c2","c3"],"allow":["c1"]}`, - errMsg: "parse config failed, no valid rules; global config parse error:consumers and allow should not be configured at the same level", - }, - { - name: "invalid config", - config: `{"_rules_":[{"_match_route_":["r1","r2"],"consumers":["c1","c2"],"allow":["c1"]}]}`, - errMsg: "consumers and allow should not be configured at the same level", - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - var actual RuleMatcher[completeConfig] - err := actual.ParseRuleConfig(gjson.Parse(c.config), parseGlobalConfig, parseOverrideRuleConfig) - if err != nil { - if c.errMsg == "" { - t.Errorf("parse failed: %v", err) - } - if err.Error() != c.errMsg { - t.Errorf("expect err: %s, actual err: %s", c.errMsg, err.Error()) - } - return - } - assert.Equal(t, c.expected, actual) - }) - } -} diff --git a/plugins/wasm-go/pkg/matcher/utils.go b/plugins/wasm-go/pkg/matcher/utils.go deleted file mode 100644 index daa8c1b1c..000000000 --- a/plugins/wasm-go/pkg/matcher/utils.go +++ /dev/null @@ -1,8 +0,0 @@ -package matcher - -func boolToInt(b bool) int { - if b { - return 1 - } - return 0 -} diff --git a/plugins/wasm-go/pkg/wrapper/response_wrapper.go b/plugins/wasm-go/pkg/mcp/consts/vars.go similarity index 57% rename from plugins/wasm-go/pkg/wrapper/response_wrapper.go rename to plugins/wasm-go/pkg/mcp/consts/vars.go index 3cd91daf3..d07833310 100644 --- a/plugins/wasm-go/pkg/wrapper/response_wrapper.go +++ b/plugins/wasm-go/pkg/mcp/consts/vars.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -12,17 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -package wrapper +package consts -import ( - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" +const ( + ToolSetNameSplitter = "___" ) - -func IsResponseFromUpstream() bool { - if codeDetails, err := proxywasm.GetProperty([]string{"response", "code_details"}); err == nil { - return string(codeDetails) == "via_upstream" - } else { - proxywasm.LogErrorf("get response code details failed: %v", err) - return false - } -} diff --git a/plugins/wasm-go/pkg/mcp/filter/plugin.go b/plugins/wasm-go/pkg/mcp/filter/plugin.go new file mode 100644 index 000000000..baae2ad09 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/filter/plugin.go @@ -0,0 +1,353 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package filter + +import ( + "github.com/tidwall/gjson" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + + "github.com/higress-group/wasm-go/pkg/log" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" + "github.com/higress-group/wasm-go/pkg/wrapper" +) + +const ( + defaultMaxBodyBytes uint32 = 100 * 1024 * 1024 +) + +type HTTPFilterF func(context wrapper.HttpContext, config any, headers [][2]string, body []byte) types.Action + +type ToolCallRequestFilterF func(context wrapper.HttpContext, config any, toolName string, toolArgs gjson.Result, rawBody []byte) types.Action + +type ToolCallResponseFilterF func(context wrapper.HttpContext, config any, isError bool, content gjson.Result, rawBody []byte) types.Action + +type ToolListResponseFilterF func(context wrapper.HttpContext, config any, tools gjson.Result, rawBody []byte) types.Action + +type JsonRpcRequestFilterF func(context wrapper.HttpContext, config any, id utils.JsonRpcID, method string, params gjson.Result, rawBody []byte) types.Action + +type JsonRpcResponseFilterF func(context wrapper.HttpContext, config any, id utils.JsonRpcID, result, error gjson.Result, rawBody []byte) types.Action + +type Context struct { + filterName string + httpRequestFilter HTTPFilterF + httpResponseFilter HTTPFilterF + jsonRpcRequestFilter JsonRpcRequestFilterF + jsonRpcResponseFilter JsonRpcResponseFilterF + toolCallRequestFilter ToolCallRequestFilterF + toolCallResponseFilter ToolCallResponseFilterF + toolListResponseFilter ToolListResponseFilterF + parseFilterConfig ParseFilterConfigF + parseFilterRuleOverrideConfig ParseFilterRuleOverrideConfigF +} + +type CtxOption interface { + Apply(*Context) +} + +var globalContext Context + +type ParseFilterConfigF func(configBytes []byte, filterConfig *any) error + +type ParseFilterRuleOverrideConfigF func(configBytes []byte, filterGlobalConfig any, filterConfig *any) error + +type setConfigParserOption struct { + f ParseFilterConfigF + g ParseFilterRuleOverrideConfigF +} + +func SetConfigParser(f ParseFilterConfigF) CtxOption { + return &setConfigParserOption{ + f: f, + } +} + +func SetConfigOverrideParser(f ParseFilterConfigF, g ParseFilterRuleOverrideConfigF) CtxOption { + return &setConfigParserOption{ + f: f, + g: g, + } +} + +func (o *setConfigParserOption) Apply(ctx *Context) { + ctx.parseFilterConfig = o.f + ctx.parseFilterRuleOverrideConfig = o.g +} + +type filterNameOption struct { + name string +} + +func FilterName(name string) CtxOption { + return &filterNameOption{name} +} + +func (o *filterNameOption) Apply(ctx *Context) { + ctx.filterName = o.name +} + +type setJsonRpcRequestFilterOption struct { + f JsonRpcRequestFilterF +} + +func SetJsonRpcRequestFilter(f JsonRpcRequestFilterF) CtxOption { + return &setJsonRpcRequestFilterOption{f} +} + +func (o *setJsonRpcRequestFilterOption) Apply(ctx *Context) { + ctx.jsonRpcRequestFilter = o.f +} + +type setJsonRpcResponseFilterOption struct { + f JsonRpcResponseFilterF +} + +func SetJsonRpcResponseFilter(f JsonRpcResponseFilterF) CtxOption { + return &setJsonRpcResponseFilterOption{f} +} + +func (o *setJsonRpcResponseFilterOption) Apply(ctx *Context) { + ctx.jsonRpcResponseFilter = o.f +} + +type setFallbackHTTPRequestFilterOption struct { + f HTTPFilterF +} + +func SetFallbackHTTPRequestFilter(f HTTPFilterF) CtxOption { + return &setFallbackHTTPRequestFilterOption{f} +} + +func (o *setFallbackHTTPRequestFilterOption) Apply(ctx *Context) { + ctx.httpRequestFilter = o.f +} + +type setFallbackHTTPResponseFilterOption struct { + f HTTPFilterF +} + +func SetFallbackHTTPResponseFilter(f HTTPFilterF) CtxOption { + return &setFallbackHTTPResponseFilterOption{f} +} + +func (o *setFallbackHTTPResponseFilterOption) Apply(ctx *Context) { + ctx.httpResponseFilter = o.f +} + +type toolCallRequestFilterOption struct { + f ToolCallRequestFilterF +} + +func SetToolCallRequestFilter(f ToolCallRequestFilterF) CtxOption { + return &toolCallRequestFilterOption{f: f} +} + +func (o *toolCallRequestFilterOption) Apply(ctx *Context) { + ctx.toolCallRequestFilter = o.f +} + +type toolCallResponseFilterOption struct { + f ToolCallResponseFilterF +} + +func SetToolCallResponseFilter(f ToolCallResponseFilterF) CtxOption { + return &toolCallResponseFilterOption{f: f} +} + +func (o *toolCallResponseFilterOption) Apply(ctx *Context) { + ctx.toolCallResponseFilter = o.f +} + +type toolListResponseFilterOption struct { + f ToolListResponseFilterF +} + +func SetToolListResponseFilter(f ToolListResponseFilterF) CtxOption { + return &toolListResponseFilterOption{f: f} +} + +func (o *toolListResponseFilterOption) Apply(ctx *Context) { + ctx.toolListResponseFilter = o.f +} + +func Load(options ...CtxOption) { + for _, opt := range options { + opt.Apply(&globalContext) + } +} + +func Initialize() { + if globalContext.filterName == "" { + panic("FilterName not set") + } + if globalContext.parseFilterConfig == nil { + panic("SetConfigParser not set") + } + var configOption wrapper.CtxOption[mcpFilterConfig] + if globalContext.parseFilterRuleOverrideConfig == nil { + configOption = wrapper.ParseRawConfig(parseRawConfig) + } else { + configOption = wrapper.ParseOverrideRawConfig(parseGlobalConfig, parseOverrideConfig) + } + wrapper.SetCtx( + globalContext.filterName, + configOption, + wrapper.ProcessRequestHeaders(onHttpRequestHeaders), + wrapper.ProcessResponseHeaders(onHttpResponseHeaders), + wrapper.ProcessRequestBody(onHttpRequestBody), + wrapper.ProcessResponseBody(onHttpResponseBody), + ) + +} + +type mcpFilterConfig struct { + config any + httpRequestHandler HTTPFilterF + httpResponseHandler HTTPFilterF + jsonRpcRequestHandler utils.JsonRpcRequestHandler + jsonRpcResponseHandler utils.JsonRpcResponseHandler +} + +func installHandler(config *mcpFilterConfig) { + config.httpRequestHandler = globalContext.httpRequestFilter + config.httpResponseHandler = globalContext.httpResponseFilter + bizConfig := config.config + if globalContext.jsonRpcRequestFilter != nil || globalContext.toolCallRequestFilter != nil { + config.jsonRpcRequestHandler = func(context wrapper.HttpContext, id utils.JsonRpcID, method string, params gjson.Result, rawBody []byte) types.Action { + if globalContext.jsonRpcRequestFilter != nil { + ret := globalContext.jsonRpcRequestFilter(context, bizConfig, id, method, params, rawBody) + if ret != types.ActionContinue { + return ret + } + } + context.SetContext("JSONRPC_METHOD", method) + if method == "tools/call" && globalContext.toolCallRequestFilter != nil { + toolName := params.Get("name").String() + toolArgs := params.Get("arguments") + return globalContext.toolCallRequestFilter(context, bizConfig, toolName, toolArgs, rawBody) + } + return types.ActionContinue + } + } + if globalContext.jsonRpcResponseFilter != nil || globalContext.toolListResponseFilter != nil || globalContext.toolCallResponseFilter != nil { + config.jsonRpcResponseHandler = func(context wrapper.HttpContext, id utils.JsonRpcID, result, error gjson.Result, rawBody []byte) types.Action { + if globalContext.jsonRpcResponseFilter != nil { + ret := globalContext.jsonRpcResponseFilter(context, bizConfig, id, result, error, rawBody) + if ret != types.ActionContinue { + return ret + } + } + method := context.GetStringContext("JSONRPC_METHOD", "") + if method == "tools/list" && globalContext.toolListResponseFilter != nil { + return globalContext.toolListResponseFilter(context, bizConfig, result.Get("tools"), rawBody) + } + if method == "tools/call" && globalContext.toolCallResponseFilter != nil { + return globalContext.toolCallResponseFilter(context, bizConfig, result.Get("isError").Bool(), result.Get("content"), rawBody) + } + return types.ActionContinue + } + } + log.Debugf("installHandler called, config is: %#v", config) +} + +func parseRawConfig(configBytes []byte, config *mcpFilterConfig) error { + err := globalContext.parseFilterConfig(configBytes, &config.config) + if err != nil { + return err + } + installHandler(config) + return nil +} + +func parseGlobalConfig(configBytes []byte, config *mcpFilterConfig) error { + err := globalContext.parseFilterConfig(configBytes, &config.config) + if err != nil { + return err + } + return nil +} + +func parseOverrideConfig(configBytes []byte, global mcpFilterConfig, config *mcpFilterConfig) error { + err := globalContext.parseFilterRuleOverrideConfig(configBytes, global.config, &config.config) + if err != nil { + return err + } + installHandler(config) + return nil +} + +func onHttpRequestHeaders(ctx wrapper.HttpContext, config mcpFilterConfig) types.Action { + log.Debugf("onHttpRequestHeaders called") + if !ctx.HasRequestBody() || (config.httpRequestHandler == nil && config.jsonRpcRequestHandler == nil) { + log.Debugf("no request body or no handler, skip reading body") + ctx.DontReadRequestBody() + return types.ActionContinue + } + log.Debugf("has request body and handler, read body") + ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) + return types.HeaderStopIteration +} + +func onHttpRequestBody(ctx wrapper.HttpContext, config mcpFilterConfig, body []byte) types.Action { + log.Debugf("onHttpRequestBody called, body size: %d", len(body)) + if !gjson.GetBytes(body, "jsonrpc").Exists() { + if config.httpRequestHandler != nil { + log.Debugf("body is not jsonrpc, using httpRequestHandler") + headers, err := proxywasm.GetHttpRequestHeaders() + if err != nil { + log.Errorf("get request headers failed, err:%v", err) + return types.ActionContinue + } + return config.httpRequestHandler(ctx, config.config, headers, body) + } + log.Debugf("body is not jsonrpc, but no httpRequestHandler, skip") + return types.ActionContinue + } + log.Debugf("body is jsonrpc, using HandleJsonRpcRequest") + return utils.HandleJsonRpcRequest(ctx, body, config.jsonRpcRequestHandler) +} + +func onHttpResponseHeaders(ctx wrapper.HttpContext, config mcpFilterConfig) types.Action { + log.Debugf("onHttpResponseHeaders called") + // IsApplicationJson checks if the content type is application/json, so we can skip reading the body if it's application/octet-stream + if !ctx.HasResponseBody() || !wrapper.IsApplicationJson() || (config.httpResponseHandler == nil && config.jsonRpcResponseHandler == nil) { + log.Debugf("no response body or no handler, skip reading body") + ctx.DontReadResponseBody() + return types.ActionContinue + } + log.Debugf("has response body and handler, read body") + ctx.SetResponseBodyBufferLimit(defaultMaxBodyBytes) + return types.HeaderStopIteration +} + +func onHttpResponseBody(ctx wrapper.HttpContext, config mcpFilterConfig, body []byte) types.Action { + log.Debugf("onHttpResponseBody called, body size: %d", len(body)) + if !gjson.GetBytes(body, "jsonrpc").Exists() { + if config.httpResponseHandler != nil { + log.Debugf("body is not jsonrpc, using httpResponseHandler") + headers, err := proxywasm.GetHttpResponseHeaders() + if err != nil { + log.Errorf("get response headers failed, err:%v", err) + return types.ActionContinue + } + return config.httpResponseHandler(ctx, config.config, headers, body) + } + log.Debugf("body is not jsonrpc, but no httpResponseHandler, skip") + return types.ActionContinue + } + log.Debugf("body is jsonrpc, using HandleJsonRpcResponse") + return utils.HandleJsonRpcResponse(ctx, body, config.jsonRpcResponseHandler) +} diff --git a/plugins/wasm-go/pkg/mcp/go.mod b/plugins/wasm-go/pkg/mcp/go.mod new file mode 100644 index 000000000..c4fbd3a16 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/go.mod @@ -0,0 +1,39 @@ +module github.com/alibaba/higress/plugins/wasm-go/pkg/mcp + +go 1.24.1 + +require ( + github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b + github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 + github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 + github.com/invopop/jsonschema v0.13.0 + github.com/stretchr/testify v1.9.0 + github.com/tidwall/gjson v1.18.0 + github.com/tidwall/sjson v1.2.5 +) + +require ( + dario.cat/mergo v1.0.1 // indirect + github.com/Masterminds/goutils v1.1.1 // indirect + github.com/Masterminds/semver/v3 v3.3.0 // indirect + github.com/Masterminds/sprig/v3 v3.3.0 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/huandu/xstrings v1.5.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/mitchellh/copystructure v1.2.0 // indirect + github.com/mitchellh/reflectwalk v1.0.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/shopspring/decimal v1.4.0 // indirect + github.com/spf13/cast v1.7.0 // indirect + github.com/tetratelabs/wazero v1.7.2 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/resp v0.1.1 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + golang.org/x/crypto v0.26.0 // indirect + google.golang.org/protobuf v1.36.6 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/plugins/wasm-go/pkg/mcp/mcp.go b/plugins/wasm-go/pkg/mcp/mcp.go new file mode 100644 index 000000000..09d04a7f9 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/mcp.go @@ -0,0 +1,84 @@ +package mcp + +import ( + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/filter" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" +) + +var _ server.Server = &MCPServer{} + +// MCPServer implements the Server interface using BaseMCPServer +type MCPServer struct { + base server.BaseMCPServer +} + +// NewMCPServer creates a new MCPServer +func NewMCPServer() *MCPServer { + return &MCPServer{ + base: server.NewBaseMCPServer(), + } +} + +// Clone implements Server interface +func (s *MCPServer) Clone() server.Server { + return &MCPServer{ + base: s.base.CloneBase(), + } +} + +// AddMCPTool implements Server interface +func (s *MCPServer) AddMCPTool(name string, tool server.Tool) server.Server { + s.base.AddMCPTool(name, tool) + return s +} + +// GetConfig implements Server interface +func (s *MCPServer) GetConfig(v any) { + s.base.GetConfig(v) +} + +// GetMCPTools implements Server interface +func (s *MCPServer) GetMCPTools() map[string]server.Tool { + return s.base.GetMCPTools() +} + +// SetConfig implements Server interface +func (s *MCPServer) SetConfig(config []byte) { + s.base.SetConfig(config) +} + +// mcp server function +var ( + LoadMCPServer = server.Load + + InitMCPServer = server.Initialize + + AddMCPServer = server.AddMCPServer +) + +// mcp filter function +var ( + LoadMCPFilter = filter.Load + + InitMCPFilter = filter.Initialize + + SetConfigParser = filter.SetConfigParser + + SetConfigOverrideParser = filter.SetConfigOverrideParser + + FilterName = filter.FilterName + + SetJsonRpcRequestFilter = filter.SetJsonRpcRequestFilter + + SetJsonRpcResponseFilter = filter.SetJsonRpcResponseFilter + + SetFallbackHTTPRequestFilter = filter.SetFallbackHTTPRequestFilter + + SetFallbackHTTPResponseFilter = filter.SetFallbackHTTPResponseFilter + + SetToolCallRequestFilter = filter.SetToolCallRequestFilter + + SetToolCallResponseFilter = filter.SetToolCallResponseFilter + + SetToolListResponseFilter = filter.SetToolListResponseFilter +) diff --git a/plugins/wasm-go/pkg/mcp/server/auth_utils.go b/plugins/wasm-go/pkg/mcp/server/auth_utils.go new file mode 100644 index 000000000..e0dfca35e --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/auth_utils.go @@ -0,0 +1,232 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "encoding/base64" + "errors" + "fmt" + "net/url" + "strings" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/wasm-go/pkg/log" +) + +// setOrReplaceHeader sets or replaces a header in the headers slice. +// If the header exists (case-insensitive comparison), it replaces the value. +// If the header doesn't exist, it appends a new header. +func setOrReplaceHeader(headers *[][2]string, key, value string) { + lowerKey := strings.ToLower(key) + + // Check if header already exists + for i, header := range *headers { + if strings.ToLower(header[0]) == lowerKey { + // Replace existing header value + (*headers)[i][1] = value + return + } + } + + // Header doesn't exist, append new one + *headers = append(*headers, [2]string{key, value}) +} + +// SecurityScheme defines a security scheme for the REST API +type SecurityScheme struct { + ID string `json:"id"` + Type string `json:"type"` // http, apiKey + Scheme string `json:"scheme,omitempty"` // basic, bearer (for type: http) + In string `json:"in,omitempty"` // header, query (for type: apiKey) + Name string `json:"name,omitempty"` // Header or query parameter name (for type: apiKey) + DefaultCredential string `json:"defaultCredential,omitempty"` +} + +// SecurityRequirement specifies a security scheme requirement for a tool +type SecurityRequirement struct { + ID string `json:"id"` // References a security scheme ID + Credential string `json:"credential,omitempty"` // Overrides default credential + Passthrough bool `json:"passthrough,omitempty"` // If true, credentials from client request will be passed through +} + +// AuthRequestContext holds the data needed for applying security schemes. +type AuthRequestContext struct { + Method string + Headers [][2]string // Direct slice, modifications within applySecurity will update this field in the struct instance + ParsedURL *url.URL // Pointer to allow modification (e.g., RawQuery) + RequestBody []byte // For future security types that might inspect the body + PassthroughCredential string // Credential extracted from client request for passthrough +} + +// SecuritySchemeProvider provides access to security schemes +type SecuritySchemeProvider interface { + GetSecurityScheme(id string) (SecurityScheme, bool) +} + +// ExtractAndRemoveIncomingCredential extracts a credential from the current incoming HTTP request +// and removes it. It uses global proxywasm functions to access request details. +// For query parameters, "removal" is conceptual as we build a new request; +// this function primarily extracts the value for potential passthrough. +func ExtractAndRemoveIncomingCredential(scheme SecurityScheme) (string, error) { + credentialValue := "" + var err error + + switch scheme.Type { + case "http": + authHeader, _ := proxywasm.GetHttpRequestHeader("Authorization") // Error ignored, check content + if authHeader == "" { + // If no header, it's not an error for extraction if not required, but indicates not found. + // For removal, there's nothing to remove. + return "", nil // Or a specific "not found" error if scheme implies it must be there. + } + + if scheme.Scheme == "bearer" { + if !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { + return "", fmt.Errorf("incoming Authorization header is not Bearer auth: %s", authHeader) + } + credentialValue = strings.TrimSpace(authHeader[len("Bearer "):]) + } else if scheme.Scheme == "basic" { + if !strings.HasPrefix(strings.ToLower(authHeader), "basic ") { + return "", fmt.Errorf("incoming Authorization header is not Basic auth: %s", authHeader) + } + credentialValue = strings.TrimSpace(authHeader[len("Basic "):]) + } else { + return "", fmt.Errorf("unsupported http scheme for credential extraction/removal: %s", scheme.Scheme) + } + proxywasm.RemoveHttpRequestHeader("Authorization") + log.Debugf("Extracted and removed Authorization header for incoming %s scheme.", scheme.Scheme) + + case "apiKey": + if scheme.In == "header" { + if scheme.Name == "" { + return "", errors.New("apiKey in header requires a name for the header") + } + headerValue, _ := proxywasm.GetHttpRequestHeader(scheme.Name) // Error ignored, check content + if headerValue == "" { + return "", nil // Not found, not necessarily an error for extraction. + } + credentialValue = headerValue + proxywasm.RemoveHttpRequestHeader(scheme.Name) + log.Debugf("Extracted and removed %s header for incoming apiKey auth.", scheme.Name) + } else if scheme.In == "query" { + if scheme.Name == "" { + return "", errors.New("apiKey in query requires a name for the query parameter") + } + pathHeader, _ := proxywasm.GetHttpRequestHeader(":path") // Error ignored, check content + if pathHeader == "" { + // This case might be an error as :path should generally exist. + return "", fmt.Errorf("no :path header found in incoming request for apiKey in query") + } + + requestURL, parseErr := url.Parse(pathHeader) + if parseErr != nil { + return "", fmt.Errorf("failed to parse incoming :path header '%s': %v", pathHeader, parseErr) + } + + queryValues := requestURL.Query() + apiKeyValue := queryValues.Get(scheme.Name) + if apiKeyValue == "" { + return "", nil // Not found + } + credentialValue = apiKeyValue + log.Debugf("Extracted %s query parameter from incoming request. Removal from original :path is implicit.", scheme.Name) + } else { + return "", fmt.Errorf("unsupported apiKey 'in' value: %s", scheme.In) + } + default: + return "", fmt.Errorf("unsupported security scheme type for credential extraction/removal: %s", scheme.Type) + } + + return credentialValue, err +} + +// ApplySecurity applies the configured security scheme to the request. +// It modifies reqCtx.Headers and reqCtx.ParsedURL (specifically RawQuery) in place if necessary. +func ApplySecurity(securityConfig SecurityRequirement, provider SecuritySchemeProvider, reqCtx *AuthRequestContext) error { + if securityConfig.ID == "" { + return nil // No security scheme defined + } + if reqCtx.ParsedURL == nil { + return errors.New("ParsedURL in AuthRequestContext cannot be nil for ApplySecurity") + } + + upstreamScheme, schemeOk := provider.GetSecurityScheme(securityConfig.ID) + if !schemeOk { + return fmt.Errorf("upstream security scheme with id '%s' not found", securityConfig.ID) + } + + var credentialToUse string + if reqCtx.PassthroughCredential != "" { + // Use the passthrough credential value. + // The upstreamScheme dictates how this value is formatted and applied. + credentialToUse = reqCtx.PassthroughCredential + log.Debugf("Using passthrough credential for upstream request with scheme %s.", upstreamScheme.ID) + } else { + // Use configured credential for the upstream request. + credentialToUse = upstreamScheme.DefaultCredential + if securityConfig.Credential != "" { + credentialToUse = securityConfig.Credential + } + if credentialToUse == "" { + return fmt.Errorf("no credential found or configured for upstream security scheme '%s'", upstreamScheme.ID) + } + log.Debugf("Using configured credential for upstream request with scheme %s.", upstreamScheme.ID) + } + + switch upstreamScheme.Type { + case "http": + authValue := credentialToUse + if upstreamScheme.Scheme == "basic" { + if !strings.HasPrefix(authValue, "Basic ") { + if reqCtx.PassthroughCredential != "" { // Came from passthrough, it's the base64 token part + authValue = "Basic " + credentialToUse + } else { // Came from config + if strings.Contains(credentialToUse, ":") { // Assumed to be "user:pass" + authValue = "Basic " + base64.StdEncoding.EncodeToString([]byte(credentialToUse)) + } else { // Assumed to be already base64 encoded string (token part) + authValue = "Basic " + credentialToUse + } + } + } + } else if upstreamScheme.Scheme == "bearer" { + // Passthrough for Bearer gives the token part. Configured credential is the token. + if !strings.HasPrefix(authValue, "Bearer ") { + authValue = "Bearer " + credentialToUse + } + } else { + return fmt.Errorf("unsupported http scheme type for upstream: %s", upstreamScheme.Scheme) + } + setOrReplaceHeader(&reqCtx.Headers, "Authorization", authValue) + case "apiKey": + if upstreamScheme.In == "header" { + if upstreamScheme.Name == "" { + return errors.New("apiKey in header requires a name for the header for upstream") + } + setOrReplaceHeader(&reqCtx.Headers, upstreamScheme.Name, credentialToUse) + } else if upstreamScheme.In == "query" { + if upstreamScheme.Name == "" { + return errors.New("apiKey in query requires a name for the query parameter for upstream") + } + queryValues := reqCtx.ParsedURL.Query() + queryValues.Set(upstreamScheme.Name, credentialToUse) + reqCtx.ParsedURL.RawQuery = queryValues.Encode() + } else { + return fmt.Errorf("unsupported apiKey 'in' value for upstream: %s", upstreamScheme.In) + } + default: + return fmt.Errorf("unsupported security scheme type: %s", upstreamScheme.Type) + } + return nil +} diff --git a/plugins/wasm-go/pkg/mcp/server/base_server.go b/plugins/wasm-go/pkg/mcp/server/base_server.go new file mode 100644 index 000000000..81e3b28f0 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/base_server.go @@ -0,0 +1,100 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package server + +import ( + "encoding/base64" + "encoding/json" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + + "github.com/higress-group/wasm-go/pkg/log" +) + +// BaseMCPServer provides common functionality for MCP servers +type BaseMCPServer struct { + tools map[string]Tool + config []byte +} + +// NewBaseMCPServer creates a new BaseMCPServer +func NewBaseMCPServer() BaseMCPServer { + return BaseMCPServer{ + tools: make(map[string]Tool), + } +} + +// AddMCPTool adds a tool to the server +func (s *BaseMCPServer) AddMCPTool(name string, tool Tool) Server { + if _, exist := s.tools[name]; exist { + log.Errorf("Conflict! There is a tool with the same name:%s", name) + return s + } + s.tools[name] = tool + return s +} + +// GetMCPTools returns all tools registered with the server +func (s *BaseMCPServer) GetMCPTools() map[string]Tool { + return s.tools +} + +// SetConfig sets the server configuration +func (s *BaseMCPServer) SetConfig(config []byte) { + s.config = config +} + +// GetConfig gets the server configuration +// It first tries to get the config from the request header, then falls back to the stored config +func (s *BaseMCPServer) GetConfig(v any) { + var config []byte + serverConfigBase64, _ := proxywasm.GetHttpRequestHeader("x-higress-mcpserver-config") + proxywasm.RemoveHttpRequestHeader("x-higress-mcpserver-config") + if serverConfigBase64 != "" { + serverConfig, err := base64.StdEncoding.DecodeString(serverConfigBase64) + if err != nil { + log.Errorf("base64 decode mcp server config failed:%s, bytes:%s", err, serverConfigBase64) + } else { + config = serverConfig + } + log.Infof("parse server config from request, config:%s", serverConfig) + } else { + config = s.config + } + if len(config) == 0 { + return + } + err := json.Unmarshal(config, v) + if err != nil { + log.Errorf("json unmarshal server config failed:%v, config:%s", err, config) + } +} + +// Clone creates a copy of the server +// This method should be overridden by derived types +func (s *BaseMCPServer) Clone() Server { + panic("Clone method must be implemented by derived types") +} + +// CloneBase creates a copy of the base server +func (s *BaseMCPServer) CloneBase() BaseMCPServer { + newServer := BaseMCPServer{ + tools: make(map[string]Tool), + config: s.config, + } + for k, v := range s.tools { + newServer.tools[k] = v + } + return newServer +} diff --git a/plugins/wasm-go/pkg/mcp/server/composed_server.go b/plugins/wasm-go/pkg/mcp/server/composed_server.go new file mode 100644 index 000000000..b2e7b7f75 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/composed_server.go @@ -0,0 +1,127 @@ +package server + +import ( + "fmt" + + "github.com/higress-group/wasm-go/pkg/log" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/consts" +) + +// ComposedMCPServer represents a server composed of tools from other servers. +type ComposedMCPServer struct { + name string // Name of the composed server (from toolSet.name) + serverTools []ServerToolConfig // Configuration of which tools to include + registry *GlobalToolRegistry // Reference to the global tool registry + config []byte // Configuration for the composed server itself (if any) +} + +// NewComposedMCPServer creates a new ComposedMCPServer. +func NewComposedMCPServer(name string, serverToolsConfig []ServerToolConfig, registry *GlobalToolRegistry) *ComposedMCPServer { + return &ComposedMCPServer{ + name: name, + serverTools: serverToolsConfig, + registry: registry, + } +} + +// GetName returns the name of the composed server. +func (cs *ComposedMCPServer) GetName() string { + return cs.name +} + +// AddMCPTool for ComposedMCPServer is a no-op as tools are defined by toolSet. +func (cs *ComposedMCPServer) AddMCPTool(name string, tool Tool) Server { + log.Warnf("AddMCPTool called on ComposedMCPServer '%s'; this is a no-op.", cs.name) + return cs +} + +// GetMCPTools constructs and returns the map of tools exposed by this composed server. +// The tool names are prefixed with their original server name, e.g., "${originalServer}___${toolName}". +// The Tool instances are DescriptiveTool, only providing Description and InputSchema. +func (cs *ComposedMCPServer) GetMCPTools() map[string]Tool { + composedTools := make(map[string]Tool) + for _, stc := range cs.serverTools { + originalServerName := stc.ServerName + for _, originalToolName := range stc.Tools { + toolInfo, found := cs.registry.GetToolInfo(originalServerName, originalToolName) + if !found { + log.Warnf("Tool %s/%s not found in global registry for composed server %s", originalServerName, originalToolName, cs.name) + continue + } + + composedToolName := fmt.Sprintf("%s%s%s", originalServerName, consts.ToolSetNameSplitter, originalToolName) + composedTools[composedToolName] = &DescriptiveTool{ + description: toolInfo.Description, + inputSchema: toolInfo.InputSchema, + outputSchema: toolInfo.OutputSchema, // New field for MCP Protocol Version 2025-06-18 + } + } + } + return composedTools +} + +// SetConfig sets the configuration for the composed server itself. +func (cs *ComposedMCPServer) SetConfig(config []byte) { + cs.config = config +} + +// GetConfig retrieves the configuration of the composed server itself. +func (cs *ComposedMCPServer) GetConfig(v any) { + if len(cs.config) == 0 { + return + } + if ptrBytes, ok := v.(*[]byte); ok { + *ptrBytes = cs.config + } else { + // If you need to unmarshal to a struct, you'd do it here. + // For now, keeping it simple as per previous discussions. + log.Warnf("ComposedMCPServer.GetConfig called with unhandled type for v. Config not set.") + } +} + +// Clone creates a new instance of the ComposedMCPServer with the same configuration. +func (cs *ComposedMCPServer) Clone() Server { + cloned := NewComposedMCPServer(cs.name, cs.serverTools, cs.registry) + cloned.SetConfig(cs.config) + return cloned +} + +// DescriptiveTool is a placeholder Tool implementation for ComposedMCPServer. +// Its Call and Create methods should never be invoked. +type DescriptiveTool struct { + description string + inputSchema map[string]any + outputSchema map[string]any // New field for MCP Protocol Version 2025-06-18 +} + +// Create for DescriptiveTool should not be called. +func (dt *DescriptiveTool) Create(params []byte) Tool { + log.Errorf("DescriptiveTool.Create called for tool used in ComposedMCPServer. This should not happen.") + // Return a new instance to fulfill the interface, though it's an error state. + return &DescriptiveTool{ + description: dt.description, + inputSchema: dt.inputSchema, + outputSchema: dt.outputSchema, + } +} + +// Call for DescriptiveTool should not be called. +func (dt *DescriptiveTool) Call(httpCtx HttpContext, server Server) error { + log.Errorf("DescriptiveTool.Call called for tool used in ComposedMCPServer. This should not happen.") + return fmt.Errorf("DescriptiveTool.Call should not be invoked on a ComposedMCPServer's tool") +} + +// Description returns the tool's description. +func (dt *DescriptiveTool) Description() string { + return dt.description +} + +// InputSchema returns the tool's input schema. +func (dt *DescriptiveTool) InputSchema() map[string]any { + return dt.inputSchema +} + +// OutputSchema returns the tool's output schema (MCP Protocol Version 2025-06-18). +func (dt *DescriptiveTool) OutputSchema() map[string]any { + return dt.outputSchema +} diff --git a/plugins/wasm-go/pkg/mcp/server/config_validator_test.go b/plugins/wasm-go/pkg/mcp/server/config_validator_test.go new file mode 100644 index 000000000..fd579d268 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/config_validator_test.go @@ -0,0 +1,328 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "os" + "testing" + + "github.com/higress-group/wasm-go/pkg/log" + "github.com/stretchr/testify/assert" + "github.com/tidwall/gjson" +) + +// testLogger is a mock logger for testing to prevent panics +type testLogger struct{} + +func (l *testLogger) Trace(msg string) { fmt.Fprintf(os.Stderr, "[TRACE] %s\n", msg) } +func (l *testLogger) Tracef(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, "[TRACE] "+format+"\n", args...) +} +func (l *testLogger) Debug(msg string) { fmt.Fprintf(os.Stderr, "[DEBUG] %s\n", msg) } +func (l *testLogger) Debugf(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, "[DEBUG] "+format+"\n", args...) +} +func (l *testLogger) Info(msg string) { fmt.Fprintf(os.Stderr, "[INFO] %s\n", msg) } +func (l *testLogger) Infof(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, "[INFO] "+format+"\n", args...) +} +func (l *testLogger) Warn(msg string) { fmt.Fprintf(os.Stderr, "[WARN] %s\n", msg) } +func (l *testLogger) Warnf(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, "[WARN] "+format+"\n", args...) +} +func (l *testLogger) Error(msg string) { fmt.Fprintf(os.Stderr, "[ERROR] %s\n", msg) } +func (l *testLogger) Errorf(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, "[ERROR] "+format+"\n", args...) +} +func (l *testLogger) Critical(msg string) { fmt.Fprintf(os.Stderr, "[CRITICAL] %s\n", msg) } +func (l *testLogger) Criticalf(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, "[CRITICAL] "+format+"\n", args...) +} +func (l *testLogger) ResetID(pluginID string) {} + +func init() { + // Set a custom logger for testing to prevent panics + log.SetPluginLog(&testLogger{}) +} + +// TestMcpProxyConfigValidation tests configuration validation for mcp-proxy servers +func TestMcpProxyConfigValidation(t *testing.T) { + tests := []struct { + name string + config string + shouldErr bool + errMsg string + }{ + { + name: "valid basic proxy config", + config: `{ + "server": { + "name": "test-proxy", + "type": "mcp-proxy", + "transport": "http", + "mcpServerURL": "http://backend.example.com/mcp", + "timeout": 5000 + }, + "tools": [ + { + "name": "test-tool", + "description": "Test tool", + "args": [ + { + "name": "input", + "description": "Input parameter", + "type": "string", + "required": true + } + ] + } + ] + }`, + shouldErr: false, + }, + { + name: "proxy config with security schemes", + config: `{ + "server": { + "name": "secure-proxy", + "type": "mcp-proxy", + "transport": "http", + "mcpServerURL": "https://secure.example.com/mcp", + "timeout": 8000, + "securitySchemes": [ + { + "id": "ApiKeyAuth", + "type": "apiKey", + "in": "header", + "name": "X-API-Key", + "defaultCredential": "test-key" + } + ] + }, + "tools": [ + { + "name": "secure-tool", + "description": "Secure tool", + "args": [ + { + "name": "data", + "description": "Data parameter", + "type": "object", + "required": true + } + ], + "requestTemplate": { + "security": { + "id": "ApiKeyAuth" + } + } + } + ] + }`, + shouldErr: false, + }, + { + name: "missing mcpServerURL should fail", + config: `{ + "server": { + "name": "invalid-proxy", + "type": "mcp-proxy", + "transport": "http", + "timeout": 5000 + }, + "tools": [ + { + "name": "test-tool", + "description": "Test tool", + "args": [] + } + ] + }`, + shouldErr: true, + errMsg: "mcpServerURL is required", + }, + { + name: "invalid server type should use default REST handling", + config: `{ + "server": { + "name": "rest-server", + "type": "rest-api" + }, + "tools": [ + { + "name": "rest-tool", + "description": "REST tool", + "args": [], + "requestTemplate": { + "url": "http://example.com/api", + "method": "GET" + }, + "responseTemplate": { + "body": "$.result" + } + } + ] + }`, + shouldErr: false, // Should fall back to REST server logic + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + configJson := gjson.Parse(tt.config) + config := &McpServerConfig{} + + // Create validation options (similar to validator package) + toolRegistry := &GlobalToolRegistry{} + toolRegistry.Initialize() + + opts := &ConfigOptions{ + Servers: make(map[string]Server), + ToolRegistry: toolRegistry, + SkipPreRegisteredServers: true, + } + + err := ParseConfigCore(configJson, config, opts) + + if tt.shouldErr { + assert.Error(t, err) + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, config) + } + }) + } +} + +// TestSecuritySchemeValidation tests security scheme configuration validation +func TestSecuritySchemeValidation(t *testing.T) { + tests := []struct { + name string + scheme SecurityScheme + shouldErr bool + }{ + { + name: "valid API key scheme", + scheme: SecurityScheme{ + ID: "ApiKeyAuth", + Type: "apiKey", + In: "header", + Name: "X-API-Key", + }, + shouldErr: false, + }, + { + name: "valid HTTP bearer scheme", + scheme: SecurityScheme{ + ID: "BearerAuth", + Type: "http", + Scheme: "bearer", + }, + shouldErr: false, + }, + { + name: "invalid scheme - missing ID", + scheme: SecurityScheme{ + Type: "apiKey", + In: "header", + Name: "X-API-Key", + }, + shouldErr: true, + }, + { + name: "invalid scheme - missing Name for apiKey", + scheme: SecurityScheme{ + ID: "ApiKeyAuth", + Type: "apiKey", + In: "header", + }, + shouldErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // This will test the validation logic once SecurityScheme validation is implemented + err := ValidateSecurityScheme(tt.scheme) + + if tt.shouldErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestToolConfigValidation tests tool configuration validation +func TestToolConfigValidation(t *testing.T) { + tests := []struct { + name string + toolCfg McpProxyToolConfig + shouldErr bool + }{ + { + name: "valid tool config", + toolCfg: McpProxyToolConfig{ + Name: "valid-tool", + Description: "A valid tool", + Args: []ToolArg{ + { + Name: "param1", + Description: "Parameter 1", + Type: "string", + Required: true, + }, + }, + }, + shouldErr: false, + }, + { + name: "invalid tool - missing name", + toolCfg: McpProxyToolConfig{ + Description: "Tool without name", + Args: []ToolArg{}, + }, + shouldErr: true, + }, + { + name: "invalid tool - empty description", + toolCfg: McpProxyToolConfig{ + Name: "tool-no-desc", + Description: "", + Args: []ToolArg{}, + }, + shouldErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateToolConfig(tt.toolCfg) + + if tt.shouldErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// These validation functions are now implemented in proxy_server.go diff --git a/plugins/wasm-go/pkg/mcp/server/plugin.go b/plugins/wasm-go/pkg/mcp/server/plugin.go new file mode 100644 index 000000000..765e9f1bc --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/plugin.go @@ -0,0 +1,817 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package server + +import ( + "encoding/json" + "errors" + "fmt" + "net/url" + "reflect" + "slices" + "strings" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/invopop/jsonschema" + "github.com/tidwall/gjson" + + "github.com/higress-group/wasm-go/pkg/log" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" + "github.com/higress-group/wasm-go/pkg/wrapper" +) + +const ( + DefaultMaxBodyBytes uint32 = 100 * 1024 * 1024 + GlobalToolRegistryKey = "GlobalToolRegistry" +) + +// SupportedMCPVersions contains all supported MCP protocol versions +var SupportedMCPVersions = []string{"2024-11-05", "2025-03-26", "2025-06-18"} + +// validateURL validates that the given string is a valid URL +func validateURL(urlStr string) error { + if urlStr == "" { + return errors.New("url cannot be empty") + } + + parsedURL, err := url.Parse(urlStr) + if err != nil { + return fmt.Errorf("invalid URL format: %v", err) + } + + // Allow both full URLs (with scheme and host) and path-only URLs + // Path-only URLs will be resolved against the cluster's base URL + if parsedURL.Scheme != "" { + // If scheme is provided, host must also be provided + if parsedURL.Host == "" { + return errors.New("url with scheme must include a host") + } + + // Only allow http and https schemes for security + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return fmt.Errorf("unsupported URL scheme '%s', only http and https are allowed", parsedURL.Scheme) + } + } + + return nil +} + +// setupMcpProxyServer creates and configures an MCP proxy server +func setupMcpProxyServer(serverName string, serverJson gjson.Result, serverConfigJsonForInstance string) (*McpProxyServer, error) { + proxyServer := NewMcpProxyServer(serverName) + proxyServer.SetConfig([]byte(serverConfigJsonForInstance)) + + // Parse and validate transport (required for mcp-proxy) + transportStr := serverJson.Get("transport").String() + if transportStr == "" { + return nil, errors.New("transport field is required for mcp-proxy server type") + } + transport := TransportProtocol(transportStr) + if transport != TransportHTTP && transport != TransportSSE { + return nil, fmt.Errorf("invalid transport value: %s, must be 'http' or 'sse'", transportStr) + } + proxyServer.SetTransport(transport) + + // Parse and validate mcpServerURL (required for mcp-proxy) + mcpServerURL := serverJson.Get("mcpServerURL").String() + if mcpServerURL == "" { + return nil, errors.New("mcpServerURL is required for mcp-proxy server type") + } + if err := validateURL(mcpServerURL); err != nil { + return nil, fmt.Errorf("invalid mcpServerURL: %v", err) + } + proxyServer.SetMcpServerURL(mcpServerURL) + + // Parse timeout (optional) + timeout := serverJson.Get("timeout").Int() + if timeout > 0 { + proxyServer.SetTimeout(int(timeout)) + } + + // Parse passthroughAuthHeader (optional, defaults to false) + passthroughAuthHeader := serverJson.Get("passthroughAuthHeader").Bool() + proxyServer.SetPassthroughAuthHeader(passthroughAuthHeader) + + // Parse security schemes + securitySchemesJson := serverJson.Get("securitySchemes") + if securitySchemesJson.Exists() { + for _, schemeJson := range securitySchemesJson.Array() { + var scheme SecurityScheme + if err := json.Unmarshal([]byte(schemeJson.Raw), &scheme); err != nil { + return nil, fmt.Errorf("failed to parse security scheme config: %v", err) + } + proxyServer.AddSecurityScheme(scheme) + } + } + + // Parse default downstream security + defaultDownstreamSecurityJson := serverJson.Get("defaultDownstreamSecurity") + if defaultDownstreamSecurityJson.Exists() { + var defaultDownstreamSecurity SecurityRequirement + if err := json.Unmarshal([]byte(defaultDownstreamSecurityJson.Raw), &defaultDownstreamSecurity); err != nil { + return nil, fmt.Errorf("failed to parse defaultDownstreamSecurity config: %v", err) + } + proxyServer.SetDefaultDownstreamSecurity(defaultDownstreamSecurity) + } + + // Parse default upstream security + defaultUpstreamSecurityJson := serverJson.Get("defaultUpstreamSecurity") + if defaultUpstreamSecurityJson.Exists() { + var defaultUpstreamSecurity SecurityRequirement + if err := json.Unmarshal([]byte(defaultUpstreamSecurityJson.Raw), &defaultUpstreamSecurity); err != nil { + return nil, fmt.Errorf("failed to parse defaultUpstreamSecurity config: %v", err) + } + proxyServer.SetDefaultUpstreamSecurity(defaultUpstreamSecurity) + } + + return proxyServer, nil +} + +type HttpContext wrapper.HttpContext + +type Context struct { + servers map[string]Server +} + +type CtxOption interface { + Apply(*Context) +} + +var globalContext Context + +// ToolInfo stores information about a tool for the global registry. +type ToolInfo struct { + Name string + Description string + InputSchema map[string]any + OutputSchema map[string]any // New field for MCP Protocol Version 2025-06-18 + ServerName string // Original server name + Tool Tool // The actual tool instance for cloning +} + +// GlobalToolRegistry holds all tools from all servers. +type GlobalToolRegistry struct { + // serverName -> toolName -> toolInfo + serverTools map[string]map[string]ToolInfo +} + +// Initialize initializes the GlobalToolRegistry +func (r *GlobalToolRegistry) Initialize() { + r.serverTools = make(map[string]map[string]ToolInfo) +} + +// RegisterTool registers a tool into the global registry. +func (r *GlobalToolRegistry) RegisterTool(serverName string, toolName string, tool Tool) { + if _, ok := r.serverTools[serverName]; !ok { + r.serverTools[serverName] = make(map[string]ToolInfo) + } + toolInfo := ToolInfo{ + Name: toolName, + Description: tool.Description(), + InputSchema: tool.InputSchema(), + ServerName: serverName, + Tool: tool, + } + // Check if tool implements OutputSchema (MCP Protocol Version 2025-06-18) + if toolWithSchema, ok := tool.(ToolWithOutputSchema); ok { + toolInfo.OutputSchema = toolWithSchema.OutputSchema() + } + r.serverTools[serverName][toolName] = toolInfo + log.Debugf("Registered tool %s/%s", serverName, toolName) +} + +// GetToolInfo retrieves tool information from the global registry. +func (r *GlobalToolRegistry) GetToolInfo(serverName string, toolName string) (ToolInfo, bool) { + if serverTools, ok := r.serverTools[serverName]; ok { + toolInfo, found := serverTools[toolName] + return toolInfo, found + } + return ToolInfo{}, false +} + +func onPluginStartOrReload(context wrapper.PluginContext) error { + toolRegistry := &GlobalToolRegistry{} + toolRegistry.Initialize() + context.SetContext(GlobalToolRegistryKey, toolRegistry) + context.EnableRuleLevelConfigIsolation() + return nil +} + +// GetServer retrieves a server instance from the global context. +// This is needed by ComposedMCPServer to get original server instances. +func GetServerFromGlobalContext(serverName string) (Server, bool) { + server, exist := globalContext.servers[serverName] + return server, exist +} + +type Server interface { + AddMCPTool(name string, tool Tool) Server + GetMCPTools() map[string]Tool // For single server, returns its tools. For composed, returns composed tools. + SetConfig(config []byte) + GetConfig(v any) + Clone() Server + // GetName() string // Returns the server name - REMOVED +} + +type Tool interface { + Create(params []byte) Tool + Call(httpCtx HttpContext, server Server) error + Description() string + InputSchema() map[string]any +} + +// ToolWithOutputSchema is an optional interface for tools that support output schema +// (MCP Protocol Version 2025-06-18). Tools can optionally implement this interface +// to provide output schema information. +type ToolWithOutputSchema interface { + Tool + OutputSchema() map[string]any +} + +// ToolSetConfig defines the configuration for a toolset. +type ToolSetConfig struct { + Name string `json:"name"` + ServerTools []ServerToolConfig `json:"serverTools"` +} + +// ServerToolConfig specifies which tools from a server to include in a toolset. +type ServerToolConfig struct { + ServerName string `json:"serverName"` + Tools []string `json:"tools"` +} + +// ConfigOptions contains the dependencies needed for config parsing +type ConfigOptions struct { + Servers map[string]Server + ToolRegistry *GlobalToolRegistry + // Skip validation for pre-registered Go-based servers + SkipPreRegisteredServers bool +} + +type McpServerConfig struct { + serverName string // Store the server name directly + server Server // Can be a single server or a composed server + methodHandlers utils.MethodHandlers + toolSet *ToolSetConfig // Parsed toolset configuration + isComposed bool +} + +// GetServerName returns the server name for external access +func (c *McpServerConfig) GetServerName() string { + return c.serverName +} + +// GetIsComposed returns whether this is a composed server for external access +func (c *McpServerConfig) GetIsComposed() bool { + return c.isComposed +} + +// computeEffectiveAllowTools computes the effective allowTools by taking the intersection +// of config allowTools and request header allowTools. +// Returns nil if no restrictions (allow all), otherwise returns a pointer to the effective set. +func computeEffectiveAllowTools(configAllowTools *map[string]struct{}) *map[string]struct{} { + // Get allowTools from request header + allowToolsHeaderStr, _ := proxywasm.GetHttpRequestHeader("x-envoy-allow-mcp-tools") + proxywasm.RemoveHttpRequestHeader("x-envoy-allow-mcp-tools") + // Only consider header as "present" if it has non-empty value + // Empty string means header is not set or explicitly empty, both treated as "no restriction" + headerExists := allowToolsHeaderStr != "" + return computeEffectiveAllowToolsFromHeader(configAllowTools, allowToolsHeaderStr, headerExists) +} + +// computeEffectiveAllowToolsFromHeader computes the effective allowTools by taking the intersection +// of config allowTools and header allowTools string. +// This is useful when the header string is already extracted (e.g., in async callbacks). +// Returns nil if no restrictions (allow all), otherwise returns a pointer to the effective set. +func computeEffectiveAllowToolsFromHeader(configAllowTools *map[string]struct{}, allowToolsHeaderStr string, headerExists bool) *map[string]struct{} { + var allowToolsFromHeader *map[string]struct{} + if headerExists { + // Header is present (even if empty string), parse it + headerMap := make(map[string]struct{}) + for tool := range strings.SplitSeq(allowToolsHeaderStr, ",") { + trimmedTool := strings.TrimSpace(tool) + if trimmedTool == "" { + continue + } + headerMap[trimmedTool] = struct{}{} + } + // Always create pointer even if map is empty, to distinguish from "not configured" + allowToolsFromHeader = &headerMap + } + + // Compute effective allowTools (intersection of config and header) + if configAllowTools == nil && allowToolsFromHeader == nil { + // Both not configured, allow all tools + return nil + } else if configAllowTools == nil { + // Only header restrictions + return allowToolsFromHeader + } else if allowToolsFromHeader == nil { + // Only config restrictions + return configAllowTools + } else { + // Both restrictions exist, compute intersection + intersection := make(map[string]struct{}) + for tool := range *configAllowTools { + if _, exists := (*allowToolsFromHeader)[tool]; exists { + intersection[tool] = struct{}{} + } + } + return &intersection + } +} + +// parseConfigCore contains the core config parsing logic with dependency injection +func parseConfigCore(configJson gjson.Result, config *McpServerConfig, opts *ConfigOptions) error { + toolSetJson := configJson.Get("toolSet") + serverJson := configJson.Get("server") // This is for single server or REST server definition + pluginServerConfigJson := configJson.Get("server.config").Raw // Config for the plugin instance itself, if any. + + // serverConfigJsonForInstance is the config passed to the specific server instance (single or REST) + // It's distinct from pluginServerConfigJson which might be for the mcp-server plugin itself. + var serverConfigJsonForInstance string + + if toolSetJson.Exists() { + config.isComposed = true + var tsConfig ToolSetConfig + if err := json.Unmarshal([]byte(toolSetJson.Raw), &tsConfig); err != nil { + return fmt.Errorf("failed to parse toolSet config: %v", err) + } + config.toolSet = &tsConfig + config.serverName = tsConfig.Name // Use toolSet name as the server name for composed server + log.Infof("Parsing toolSet configuration: %s", config.serverName) + + composedServer := NewComposedMCPServer(config.serverName, tsConfig.ServerTools, opts.ToolRegistry) + // A composed server itself might have a config block, e.g. for shared settings, though not typical. + composedServer.SetConfig([]byte(pluginServerConfigJson)) + config.server = composedServer + } else if serverJson.Exists() { + config.isComposed = false + config.serverName = serverJson.Get("name").String() + if config.serverName == "" { + return errors.New("server.name field is missing for single server config") + } + // This is the config for the specific server being defined (e.g. REST server's own config) + serverConfigJsonForInstance = serverJson.Get("config").Raw + log.Infof("Parsing single server configuration: %s", config.serverName) + + // Check server type to determine which type of server to create + serverType := serverJson.Get("type").String() + if serverType == "" { + serverType = "rest" // Default to REST server type + } + + toolsJson := configJson.Get("tools") // These are REST tools for this server instance or MCP proxy tools + + if serverType == "mcp-proxy" { + // Create MCP proxy server + proxyServer, err := setupMcpProxyServer(config.serverName, serverJson, serverConfigJsonForInstance) + if err != nil { + return err + } + + // Handle tools configuration (optional for MCP proxy) + if toolsJson.Exists() && len(toolsJson.Array()) > 0 { + for _, toolJson := range toolsJson.Array() { + var proxyTool McpProxyToolConfig + if err := json.Unmarshal([]byte(toolJson.Raw), &proxyTool); err != nil { + return fmt.Errorf("failed to parse proxy tool config: %v", err) + } + + if err := proxyServer.AddProxyTool(proxyTool); err != nil { + return fmt.Errorf("failed to add proxy tool %s: %v", proxyTool.Name, err) + } + // Register tool to registry + opts.ToolRegistry.RegisterTool(config.serverName, proxyTool.Name, proxyServer.GetMCPTools()[proxyTool.Name]) + } + } + // Set the proxy server regardless of whether tools are configured + config.server = proxyServer + } else if toolsJson.Exists() && len(toolsJson.Array()) > 0 { + // Handle REST-to-MCP server (requires tools configuration) + // Create REST-to-MCP server (default behavior) + restServer := NewRestMCPServer(config.serverName) // Pass the server name + restServer.SetConfig([]byte(serverConfigJsonForInstance)) // Pass the server's specific config + + securitySchemesJson := serverJson.Get("securitySchemes") + if securitySchemesJson.Exists() { + for _, schemeJson := range securitySchemesJson.Array() { + var scheme SecurityScheme + if err := json.Unmarshal([]byte(schemeJson.Raw), &scheme); err != nil { + return fmt.Errorf("failed to parse security scheme config: %v", err) + } + restServer.AddSecurityScheme(scheme) + } + } + + // Parse default downstream security + defaultDownstreamSecurityJson := serverJson.Get("defaultDownstreamSecurity") + if defaultDownstreamSecurityJson.Exists() { + var defaultDownstreamSecurity SecurityRequirement + if err := json.Unmarshal([]byte(defaultDownstreamSecurityJson.Raw), &defaultDownstreamSecurity); err != nil { + return fmt.Errorf("failed to parse defaultDownstreamSecurity config: %v", err) + } + restServer.SetDefaultDownstreamSecurity(defaultDownstreamSecurity) + } + + // Parse default upstream security + defaultUpstreamSecurityJson := serverJson.Get("defaultUpstreamSecurity") + if defaultUpstreamSecurityJson.Exists() { + var defaultUpstreamSecurity SecurityRequirement + if err := json.Unmarshal([]byte(defaultUpstreamSecurityJson.Raw), &defaultUpstreamSecurity); err != nil { + return fmt.Errorf("failed to parse defaultUpstreamSecurity config: %v", err) + } + restServer.SetDefaultUpstreamSecurity(defaultUpstreamSecurity) + } + + // Parse passthroughAuthHeader (optional, defaults to false) + passthroughAuthHeader := serverJson.Get("passthroughAuthHeader").Bool() + restServer.SetPassthroughAuthHeader(passthroughAuthHeader) + + for _, toolJson := range toolsJson.Array() { + var restTool RestTool + if err := json.Unmarshal([]byte(toolJson.Raw), &restTool); err != nil { + return fmt.Errorf("failed to parse tool config: %v", err) + } + + if err := restServer.AddRestTool(restTool); err != nil { + return fmt.Errorf("failed to add tool %s: %v", restTool.Name, err) + } + // Register tool to registry + opts.ToolRegistry.RegisterTool(config.serverName, restTool.Name, restServer.GetMCPTools()[restTool.Name]) + } + config.server = restServer + } else { + // Logic for pre-registered Go-based servers (non-REST) + if opts.SkipPreRegisteredServers { + // In validation mode, skip pre-registered servers validation + // Just validate the basic structure without actual server instance + config.server = nil // Will be handled appropriately in validation context + } else { + if serverInstance, exist := opts.Servers[config.serverName]; exist { + clonedServer := serverInstance.Clone() + clonedServer.SetConfig([]byte(serverConfigJsonForInstance)) // Pass the server's specific config + config.server = clonedServer + // Register tools from this server to registry + for toolName, toolInstance := range clonedServer.GetMCPTools() { + opts.ToolRegistry.RegisterTool(config.serverName, toolName, toolInstance) + } + } else { + return fmt.Errorf("mcp server type '%s' not registered", config.serverName) + } + } + } + } else { + return errors.New("either 'server' or 'toolSet' field must be present in the configuration") + } + + // Parse allowTools - this might need adjustment for composed servers + // Use pointer to distinguish between "not configured" (nil) and "configured as empty" (empty map) + var allowTools *map[string]struct{} // For single server, tool name. For composed, serverName/toolName. + allowToolsResult := configJson.Get("allowTools") + if allowToolsResult.Exists() { + // allowTools is configured, create the map + toolsMap := make(map[string]struct{}) + allowToolsArray := allowToolsResult.Array() + for _, toolJson := range allowToolsArray { + toolsMap[toolJson.String()] = struct{}{} + } + allowTools = &toolsMap + } + // If allowTools is nil, it means not configured (allow all) + + config.methodHandlers = make(utils.MethodHandlers) + // Use config.serverName which is now reliably set + currentServerNameForHandlers := config.serverName + + config.methodHandlers["ping"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error { + utils.OnMCPResponseSuccess(ctx, map[string]any{}, fmt.Sprintf("mcp:%s:ping", currentServerNameForHandlers)) + return nil + } + config.methodHandlers["notifications/initialized"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error { + proxywasm.SendHttpResponseWithDetail(202, fmt.Sprintf("mcp:%s:notifications/initialized", currentServerNameForHandlers), nil, nil, -1) + return nil + } + config.methodHandlers["notifications/cancelled"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error { + proxywasm.SendHttpResponseWithDetail(202, fmt.Sprintf("mcp:%s:notifications/cancelled", currentServerNameForHandlers), nil, nil, -1) + return nil + } + config.methodHandlers["initialize"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error { + requestedVersion := params.Get("protocolVersion").String() + if requestedVersion == "" { + utils.OnMCPResponseError(ctx, errors.New("protocolVersion is required"), utils.ErrInvalidParams, fmt.Sprintf("mcp:%s:initialize:error", currentServerNameForHandlers)) + return nil + } + + // MCP specification compliant version negotiation: + // If the server supports the requested protocol version, it MUST respond with the same version. + // Otherwise, the server MUST respond with another protocol version it supports. + // This SHOULD be the latest version supported by the server. + negotiatedVersion := requestedVersion + if !slices.Contains(SupportedMCPVersions, requestedVersion) { + // Return the latest supported version instead of rejecting the request + negotiatedVersion = SupportedMCPVersions[len(SupportedMCPVersions)-1] + log.Warnf("Client requested unsupported version %s, responding with latest supported version %s", + requestedVersion, negotiatedVersion) + } + + utils.OnMCPResponseSuccess(ctx, map[string]any{ + "protocolVersion": negotiatedVersion, + "capabilities": map[string]any{ + "tools": map[string]any{}, + }, + "serverInfo": map[string]any{ + "name": currentServerNameForHandlers, // Use the actual server name (single or composed) + "version": "1.0.0", + }, + }, fmt.Sprintf("mcp:%s:initialize", currentServerNameForHandlers)) + return nil + } + + // Override tools/list and tools/call handlers for MCP proxy servers first + if config.server != nil { + if proxyServer, ok := config.server.(*McpProxyServer); ok { + // Use MCP proxy specific handlers that support ActionPause + proxyHandlers := CreateMcpProxyMethodHandlers(proxyServer, allowTools) + config.methodHandlers["tools/list"] = proxyHandlers["tools/list"] + config.methodHandlers["tools/call"] = proxyHandlers["tools/call"] + } + } + + // Default tools/list handler for non-proxy servers + if config.methodHandlers["tools/list"] == nil { + config.methodHandlers["tools/list"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error { + var listedTools []map[string]any + // GetMCPTools() will return appropriately formatted tools for both single and composed servers + allTools := config.server.GetMCPTools() // For composed, keys are "serverName/toolName" + + // Compute effective allowTools using helper function + effectiveAllowTools := computeEffectiveAllowTools(allowTools) + + for toolFullName, tool := range allTools { + // For composed server, toolFullName is "originalServerName/originalToolName" + // For single server, toolFullName is "originalToolName" + // The allowTools map should use the same format as toolFullName + if effectiveAllowTools != nil { + if _, allow := (*effectiveAllowTools)[toolFullName]; !allow { + continue + } + } + toolDef := map[string]any{ + "name": toolFullName, + "description": tool.Description(), + "inputSchema": tool.InputSchema(), + } + // Add outputSchema if tool implements ToolWithOutputSchema (MCP Protocol Version 2025-06-18) + if toolWithSchema, ok := tool.(ToolWithOutputSchema); ok { + if outputSchema := toolWithSchema.OutputSchema(); len(outputSchema) > 0 { + toolDef["outputSchema"] = outputSchema + } + } + listedTools = append(listedTools, toolDef) + } + utils.OnMCPResponseSuccess(ctx, map[string]any{ + "tools": listedTools, + }, fmt.Sprintf("mcp:%s:tools/list", currentServerNameForHandlers)) + return nil + } + } + + // Default tools/call handler for non-proxy servers + if config.methodHandlers["tools/call"] == nil { + config.methodHandlers["tools/call"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error { + if config.isComposed { + // This endpoint is for a composed server (toolSet). + // Actual tool calls should be routed by mcp-router to individual servers. + // If a tools/call request reaches here, it's a misconfiguration or unexpected. + errMsg := fmt.Sprintf("tools/call is not supported on a composed toolSet endpoint ('%s'). It should be routed by mcp-router to the target server.", currentServerNameForHandlers) + log.Errorf(errMsg) + utils.OnMCPResponseError(ctx, errors.New(errMsg), utils.ErrMethodNotFound, fmt.Sprintf("mcp:%s:tools/call:not_supported_on_toolset", currentServerNameForHandlers)) + return nil + } + + // Logic for single (non-composed) server + toolName := params.Get("name").String() // For single server, this is the direct tool name + args := params.Get("arguments") + + // Compute effective allowTools using helper function + effectiveAllowTools := computeEffectiveAllowTools(allowTools) + + // Check if tool is allowed + if effectiveAllowTools != nil { + if _, allow := (*effectiveAllowTools)[toolName]; !allow { + utils.OnMCPResponseError(ctx, fmt.Errorf("Tool not allowed: %s", toolName), utils.ErrInvalidParams, fmt.Sprintf("mcp:%s:tools/call:tool_not_allowed", currentServerNameForHandlers)) + return nil + } + } + + proxywasm.SetProperty([]string{"mcp_server_name"}, []byte(currentServerNameForHandlers)) + proxywasm.SetProperty([]string{"mcp_tool_name"}, []byte(toolName)) + + toolToCall, ok := config.server.GetMCPTools()[toolName] + if !ok { + utils.OnMCPResponseError(ctx, fmt.Errorf("unknown tool: %s", toolName), utils.ErrInvalidParams, fmt.Sprintf("mcp:%s:tools/call:invalid_tool_name", currentServerNameForHandlers)) + return nil + } + + log.Debugf("Tool call [%s] on server [%s] with arguments[%s]", toolName, currentServerNameForHandlers, args.Raw) + toolInstance := toolToCall.Create([]byte(args.Raw)) + err := toolInstance.Call(ctx, config.server) // Pass the single server instance + if err != nil { + utils.OnMCPToolCallError(ctx, err) + return nil + } + return nil + } + } + + return nil +} + +// ParseConfigCore exports the core parsing logic for external use (e.g., validation) +func ParseConfigCore(configJson gjson.Result, config *McpServerConfig, opts *ConfigOptions) error { + return parseConfigCore(configJson, config, opts) +} + +func parseConfig(context wrapper.PluginContext, configJson gjson.Result, config *McpServerConfig) error { + registryI := context.GetContext(GlobalToolRegistryKey) + if registryI == nil { + return errors.New("GlobalToolRegistry not found") + } + registry, ok := registryI.(*GlobalToolRegistry) + if !ok { + return errors.New("invalid GlobalToolRegistry") + } + // Build runtime dependencies using global variables + opts := &ConfigOptions{ + Servers: globalContext.servers, + ToolRegistry: registry, + } + + // Call the core parsing logic + return parseConfigCore(configJson, config, opts) +} + +func Load(options ...CtxOption) { + for _, opt := range options { + opt.Apply(&globalContext) + } +} + +func Initialize() { + if globalContext.servers == nil { + panic("At least one mcpserver needs to be added.") + } + wrapper.SetCtx( + "mcp-server", + wrapper.PrePluginStartOrReload[McpServerConfig](onPluginStartOrReload), + wrapper.ParseConfigWithContext(parseConfig), + wrapper.WithLogger[McpServerConfig](&utils.MCPServerLog{}), + wrapper.ProcessRequestHeaders(onHttpRequestHeaders), + wrapper.ProcessRequestBody(onHttpRequestBody), + wrapper.ProcessResponseHeaders(onHttpResponseHeaders), + wrapper.ProcessStreamingResponseBody(onHttpStreamingResponseBody), + wrapper.WithRebuildMaxMemBytes[McpServerConfig](200*1024*1024), + ) +} + +type addMCPServerOption struct { + name string + server Server +} + +func AddMCPServer(name string, server Server) CtxOption { + return &addMCPServerOption{ + name: name, + server: server, + } +} + +func (o *addMCPServerOption) Apply(ctx *Context) { + if ctx.servers == nil { + ctx.servers = make(map[string]Server) + } + if _, exist := ctx.servers[o.name]; exist { + panic(fmt.Sprintf("Conflict! There is a mcp server with the same name:%s", + o.name)) + } + ctx.servers[o.name] = o.server +} + +func ToInputSchema(v any) map[string]any { + t := reflect.TypeOf(v) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + inputSchema := jsonschema.Reflect(v).Definitions[t.Name()] + inputSchemaBytes, _ := json.Marshal(inputSchema) + var result map[string]any + json.Unmarshal(inputSchemaBytes, &result) + return result +} + +func StoreServerState(ctx wrapper.HttpContext, config any) { + if utils.IsStatefulSession(ctx) { + log.Warnf("There is no session ID, unable to store state.") + return + } + configBytes, err := json.Marshal(config) + if err != nil { + log.Errorf("Server config marshal failed:%v, config:%s", err, configBytes) + return + } + proxywasm.SetProperty([]string{"mcp_server_config"}, configBytes) +} + +func onHttpRequestHeaders(ctx wrapper.HttpContext, config McpServerConfig) types.Action { + ctx.DisableReroute() + ctx.SetRequestBodyBufferLimit(DefaultMaxBodyBytes) + ctx.SetResponseBodyBufferLimit(DefaultMaxBodyBytes) + + // Remove accept-encoding header to prevent backend from compressing the response + // This ensures we can properly process and modify the response body + proxywasm.RemoveHttpRequestHeader("accept-encoding") + + // Parse MCP-Protocol-Version header and store in context + // This allows clients to specify the MCP protocol version via HTTP header + // instead of only through the JSON-RPC initialize method + protocolVersion, _ := proxywasm.GetHttpRequestHeader("MCP-Protocol-Version") + if protocolVersion != "" { + // Validate the protocol version against supported versions + if slices.Contains(SupportedMCPVersions, protocolVersion) { + log.Debugf("MCP Protocol Version set from header: %s", protocolVersion) + } else { + log.Warnf("Unsupported MCP Protocol Version in header: %s", protocolVersion) + } + + // Remove the header from the request to prevent it from being forwarded + proxywasm.RemoveHttpRequestHeader("MCP-Protocol-Version") + } + + if ctx.Method() == "GET" { + proxywasm.SendHttpResponseWithDetail(405, "not_support_sse_on_this_endpoint", nil, nil, -1) + return types.HeaderStopAllIterationAndWatermark + } + // Handle DELETE request for session termination (MCP 2025-06-18 spec) + // Per spec: "Clients that no longer need a particular session SHOULD send an HTTP DELETE + // to the MCP endpoint with the Mcp-Session-Id header, to explicitly terminate the session." + // Per spec: "The server MAY respond to this request with HTTP 405 Method Not Allowed, + // indicating that the server does not allow clients to terminate sessions." + if ctx.Method() == "DELETE" { + proxywasm.SendHttpResponseWithDetail(405, "session_termination_not_supported", nil, nil, -1) + return types.HeaderStopAllIterationAndWatermark + } + if !ctx.HasRequestBody() { + proxywasm.SendHttpResponseWithDetail(400, "missing_body_in_mcp_request", nil, nil, -1) + return types.HeaderStopAllIterationAndWatermark + } + return types.HeaderStopIteration +} + +func onHttpRequestBody(ctx wrapper.HttpContext, config McpServerConfig, body []byte) types.Action { + return utils.HandleJsonRpcMethod(ctx, body, config.methodHandlers) +} + +func onHttpResponseHeaders(ctx wrapper.HttpContext, config McpServerConfig) types.Action { + // Check if this request initiated SSE channel (tools/list or tools/call with SSE transport) + // Only these requests need special SSE streaming response processing + if ctx.GetContext(CtxSSEProxyState) != nil { + // Check if response has a body + if ctx.HasResponseBody() { + // Pause streaming response for processing + // Content-type validation will be done in onHttpStreamingResponseBody + ctx.NeedPauseStreamingResponse() + return types.HeaderStopIteration + } else { + // No body, return error + utils.OnMCPResponseError(ctx, fmt.Errorf("no response body in SSE response"), utils.ErrInternalError, "mcp-proxy:sse:no_body") + return types.HeaderStopAllIterationAndWatermark + } + } + + // For non-SSE streaming requests, continue normally + return types.HeaderContinue +} + +func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config McpServerConfig, data []byte, endOfStream bool) []byte { + // Check if this request initiated SSE channel (tools/list or tools/call with SSE transport) + // Only these requests need special SSE streaming response processing + if ctx.GetContext(CtxSSEProxyState) != nil { + return handleSSEStreamingResponse(ctx, config, data, endOfStream) + } + + // For non-SSE streaming requests, return data as-is + return data +} diff --git a/plugins/wasm-go/pkg/mcp/server/proxy_auth_test.go b/plugins/wasm-go/pkg/mcp/server/proxy_auth_test.go new file mode 100644 index 000000000..a477fbfab --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/proxy_auth_test.go @@ -0,0 +1,429 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestApiKeyAuthentication tests API key authentication forwarding +func TestApiKeyAuthentication(t *testing.T) { + server := NewMcpProxyServer("auth-test") + + // Configure security scheme + scheme := SecurityScheme{ + ID: "ApiKeyAuth", + Type: "apiKey", + In: "header", + Name: "X-API-Key", + DefaultCredential: "default-api-key", + } + + server.AddSecurityScheme(scheme) + + // Set server fields directly + server.SetMcpServerURL("http://secure-backend.example.com/mcp") + server.SetTimeout(5000) + + // Create tool with client-to-gateway and gateway-to-backend security + toolConfig := McpProxyToolConfig{ + Name: "secure_tool", + Description: "Tool requiring authentication", + Security: SecurityRequirement{ + ID: "ApiKeyAuth", // Client-to-gateway authentication + Passthrough: true, // Extract client credential for backend use + }, + Args: []ToolArg{ + { + Name: "data", + Description: "Data parameter", + Type: "string", + Required: true, + }, + }, + OutputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "result": map[string]any{ + "type": "string", + "description": "The result of the operation", + }, + }, + }, + RequestTemplate: RequestTemplate{ + Security: SecurityRequirement{ + ID: "ApiKeyAuth", // Gateway-to-backend authentication (same scheme for simplicity) + }, + }, + } + + err := server.AddProxyTool(toolConfig) + require.NoError(t, err) + + tool, exists := server.GetMCPTools()["secure_tool"] + require.True(t, exists) + + params := map[string]interface{}{ + "data": "test data", + } + paramsBytes, err := json.Marshal(params) + require.NoError(t, err) + + toolInstance := tool.Create(paramsBytes) + require.NotNil(t, toolInstance) + + // Authentication is now handled automatically during tool calls + // The actual authentication flow is tested in integration tests +} + +// TestBearerAuthentication tests Bearer token authentication +func TestBearerAuthentication(t *testing.T) { + server := NewMcpProxyServer("bearer-auth-test") + + // Configure Bearer security scheme + scheme := SecurityScheme{ + ID: "BearerAuth", + Type: "http", + Scheme: "bearer", + } + + server.AddSecurityScheme(scheme) + + // Set server fields directly + server.SetMcpServerURL("https://secure-backend.example.com/mcp") + server.SetTimeout(8000) + + // Create tool with Bearer authentication + // Create tool using only gateway-to-backend authentication (no client auth required) + toolConfig := McpProxyToolConfig{ + Name: "bearer_tool", + Description: "Tool with Bearer authentication to backend only", + Args: []ToolArg{ + { + Name: "query", + Description: "Query parameter", + Type: "string", + Required: true, + }, + }, + RequestTemplate: RequestTemplate{ + Security: SecurityRequirement{ + ID: "BearerAuth", // Only gateway-to-backend authentication + }, + }, + } + + err := server.AddProxyTool(toolConfig) + require.NoError(t, err) + + tool, exists := server.GetMCPTools()["bearer_tool"] + require.True(t, exists) + + params := map[string]interface{}{ + "query": "test query", + } + paramsBytes, err := json.Marshal(params) + require.NoError(t, err) + + toolInstance := tool.Create(paramsBytes) + require.NotNil(t, toolInstance) + + // Authentication is now handled automatically during tool calls + // The actual authentication flow is tested in integration tests + + // Test backward compatibility: this tool uses RequestTemplate.Security (legacy way) + // which should still work +} + +// TestBasicAuthentication tests Basic authentication +func TestBasicAuthentication(t *testing.T) { + server := NewMcpProxyServer("basic-auth-test") + + // Configure Basic security scheme + scheme := SecurityScheme{ + ID: "BasicAuth", + Type: "http", + Scheme: "basic", + } + + server.AddSecurityScheme(scheme) + + // Test tool call with Basic authentication + toolConfig := McpProxyToolConfig{ + Name: "basic_tool", + Description: "Tool with Basic authentication", + Args: []ToolArg{ + { + Name: "resource", + Description: "Resource identifier", + Type: "string", + Required: true, + }, + }, + RequestTemplate: RequestTemplate{ + Security: SecurityRequirement{ + ID: "BasicAuth", + }, + }, + } + + err := server.AddProxyTool(toolConfig) + require.NoError(t, err) + + tool, exists := server.GetMCPTools()["basic_tool"] + require.True(t, exists) + + params := map[string]interface{}{ + "resource": "test-resource", + } + paramsBytes, err := json.Marshal(params) + require.NoError(t, err) + + toolInstance := tool.Create(paramsBytes) + require.NotNil(t, toolInstance) + + // Authentication is now handled automatically during tool calls + // The actual authentication flow is tested in integration tests + + // Test OutputSchema functionality (only for tools that have it configured) + if toolWithOutputSchema, ok := tool.(ToolWithOutputSchema); ok { + outputSchema := toolWithOutputSchema.OutputSchema() + if outputSchema != nil { + // Only validate if outputSchema is configured + assert.Equal(t, "object", outputSchema["type"]) + properties, hasProperties := outputSchema["properties"].(map[string]any) + require.True(t, hasProperties) + resultSchema, hasResult := properties["result"].(map[string]any) + require.True(t, hasResult) + assert.Equal(t, "string", resultSchema["type"]) + assert.Equal(t, "The result of the operation", resultSchema["description"]) + } + } +} + +// TestMultipleSecuritySchemes tests multiple security schemes in one server +func TestMultipleSecuritySchemes(t *testing.T) { + server := NewMcpProxyServer("multi-auth-test") + + // Add multiple security schemes + schemes := []SecurityScheme{ + { + ID: "ApiKeyAuth", + Type: "apiKey", + In: "header", + Name: "X-API-Key", + }, + { + ID: "BearerAuth", + Type: "http", + Scheme: "bearer", + }, + } + + for _, scheme := range schemes { + server.AddSecurityScheme(scheme) + } + + // Test that both schemes are available + for _, scheme := range schemes { + retrievedScheme, exists := server.GetSecurityScheme(scheme.ID) + assert.True(t, exists) + assert.Equal(t, scheme.ID, retrievedScheme.ID) + assert.Equal(t, scheme.Type, retrievedScheme.Type) + } +} + +// ProxyAuthContext, RequestTemplate, SecurityConfig and authentication methods +// are now implemented in proxy_server.go + +// TestToolsListAuthentication tests authentication configuration for tools/list requests +func TestToolsListAuthentication(t *testing.T) { + server := NewMcpProxyServer("test-server") + + // Add a security scheme for global authentication + scheme := SecurityScheme{ + ID: "GlobalAuth", + Type: "apiKey", + In: "header", + Name: "X-API-Key", + DefaultCredential: "default-global-key", + } + server.AddSecurityScheme(scheme) + + // Test that we can retrieve the security scheme + retrievedScheme, exists := server.GetSecurityScheme("GlobalAuth") + assert.True(t, exists) + assert.Equal(t, "GlobalAuth", retrievedScheme.ID) + assert.Equal(t, "apiKey", retrievedScheme.Type) + assert.Equal(t, "header", retrievedScheme.In) + assert.Equal(t, "X-API-Key", retrievedScheme.Name) + + // Test setting default security directly on server + defaultDownstreamSecurity := SecurityRequirement{ + ID: "GlobalAuth", + Passthrough: true, + } + defaultUpstreamSecurity := SecurityRequirement{ + ID: "GlobalAuth", + } + + server.SetDefaultDownstreamSecurity(defaultDownstreamSecurity) + server.SetDefaultUpstreamSecurity(defaultUpstreamSecurity) + + // Verify default security settings + retrievedDownstream := server.GetDefaultDownstreamSecurity() + assert.Equal(t, "GlobalAuth", retrievedDownstream.ID) + assert.True(t, retrievedDownstream.Passthrough) + + retrievedUpstream := server.GetDefaultUpstreamSecurity() + assert.Equal(t, "GlobalAuth", retrievedUpstream.ID) + + t.Logf("Tools/list authentication configuration test completed successfully") +} + +// TestDefaultSecurityFallback tests the fallback mechanism from tool-level to default security +func TestDefaultSecurityFallback(t *testing.T) { + server := NewMcpProxyServer("test-server") + + // Add security schemes + defaultScheme := SecurityScheme{ + ID: "DefaultAuth", + Type: "apiKey", + In: "header", + Name: "X-Default-Key", + DefaultCredential: "default-key", + } + toolScheme := SecurityScheme{ + ID: "ToolAuth", + Type: "apiKey", + In: "header", + Name: "X-Tool-Key", + DefaultCredential: "tool-key", + } + server.AddSecurityScheme(defaultScheme) + server.AddSecurityScheme(toolScheme) + + // Test tool configuration with tool-level security (should use tool-level, not default) + toolConfigWithSecurity := McpProxyToolConfig{ + Name: "secure_tool", + Description: "Tool with its own security", + Security: SecurityRequirement{ + ID: "ToolAuth", + Passthrough: true, + }, + RequestTemplate: RequestTemplate{ + Security: SecurityRequirement{ + ID: "ToolAuth", + }, + }, + } + + // Test tool configuration without tool-level security (should fallback to default) + toolConfigWithoutSecurity := McpProxyToolConfig{ + Name: "fallback_tool", + Description: "Tool that falls back to default security", + // No Security field configured, should use default + RequestTemplate: RequestTemplate{ + // No Security field configured, should use default + }, + } + + // Set default security directly on server + server.SetDefaultDownstreamSecurity(SecurityRequirement{ + ID: "DefaultAuth", + Passthrough: false, + }) + server.SetDefaultUpstreamSecurity(SecurityRequirement{ + ID: "DefaultAuth", + }) + + // Set server configuration directly + server.SetMcpServerURL("http://backend.example.com") + server.SetTimeout(5000) + + // Add tools to server + err := server.AddProxyTool(toolConfigWithSecurity) + assert.NoError(t, err) + err = server.AddProxyTool(toolConfigWithoutSecurity) + assert.NoError(t, err) + + // Verify tools were added + tools := server.GetMCPTools() + assert.Contains(t, tools, "secure_tool") + assert.Contains(t, tools, "fallback_tool") + + t.Logf("Default security fallback test completed successfully") +} + +// TestURLModificationInAuthentication tests that authentication can modify the URL (e.g., adding query parameters) +func TestURLModificationInAuthentication(t *testing.T) { + server := NewMcpProxyServer("test-server") + + // Add a security scheme that adds parameters to query (apiKey in query) + scheme := SecurityScheme{ + ID: "QueryApiKey", + Type: "apiKey", + In: "query", + Name: "api_key", + DefaultCredential: "test-key-123", + } + server.AddSecurityScheme(scheme) + + // Verify the security scheme was added correctly + retrievedScheme, exists := server.GetSecurityScheme("QueryApiKey") + assert.True(t, exists) + assert.Equal(t, "apiKey", retrievedScheme.Type) + assert.Equal(t, "query", retrievedScheme.In) + assert.Equal(t, "api_key", retrievedScheme.Name) + + t.Logf("URL modification authentication configuration test completed successfully") +} + +// TestProxyServerFields tests the server-level field setting and getting +func TestProxyServerFields(t *testing.T) { + server := NewMcpProxyServer("test-server") + + // Test mcpServerURL + testURL := "http://mcp.example.com:8080/mcp" + server.SetMcpServerURL(testURL) + assert.Equal(t, testURL, server.GetMcpServerURL()) + + // Test timeout + testTimeout := 10000 + server.SetTimeout(testTimeout) + assert.Equal(t, testTimeout, server.GetTimeout()) + + // Test default security settings + downstreamSec := SecurityRequirement{ + ID: "test-downstream", + Passthrough: true, + } + upstreamSec := SecurityRequirement{ + ID: "test-upstream", + } + + server.SetDefaultDownstreamSecurity(downstreamSec) + server.SetDefaultUpstreamSecurity(upstreamSec) + + assert.Equal(t, "test-downstream", server.GetDefaultDownstreamSecurity().ID) + assert.True(t, server.GetDefaultDownstreamSecurity().Passthrough) + assert.Equal(t, "test-upstream", server.GetDefaultUpstreamSecurity().ID) + + t.Logf("Proxy server fields test completed successfully") +} diff --git a/plugins/wasm-go/pkg/mcp/server/proxy_integration_test.go b/plugins/wasm-go/pkg/mcp/server/proxy_integration_test.go new file mode 100644 index 000000000..d943a80ab --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/proxy_integration_test.go @@ -0,0 +1,302 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// MockHttpContext is a mock implementation for testing - skipping interface implementation for now +// Tests that require full HttpContext will be tested in integration tests with real host +type MockHttpContext struct { + responseBody []byte + responseStatus int + headers map[string]string +} + +// TestMcpProtocolInitialization tests the MCP protocol initialization flow +func TestMcpProtocolInitialization(t *testing.T) { + // Create proxy server + server := NewMcpProxyServer("test-proxy") + + // Set server fields directly + server.SetMcpServerURL("http://mock-backend.example.com/mcp") + server.SetTimeout(5000) + + // Create proxy tool + toolConfig := McpProxyToolConfig{ + Name: "test-tool", + Description: "Test tool for initialization", + Args: []ToolArg{ + { + Name: "input", + Description: "Test input", + Type: "string", + Required: true, + }, + }, + } + + err := server.AddProxyTool(toolConfig) + require.NoError(t, err) + + tool, exists := server.GetMCPTools()["test-tool"] + require.True(t, exists) + + // Create tool instance with parameters + params := map[string]interface{}{ + "input": "test value", + } + paramsBytes, err := json.Marshal(params) + require.NoError(t, err) + + toolInstance := tool.Create(paramsBytes) + require.NotNil(t, toolInstance) + + // Skip HttpContext-dependent test for now - will be tested in integration + // mockCtx := &MockHttpContext{} + // err = toolInstance.Call(mockCtx, server) + // assert.NoError(t, err) + + // Test the tool creation was successful + assert.NotNil(t, toolInstance) +} + +// TestMcpSessionManagement tests temporary session creation and cleanup +func TestMcpSessionManagement(t *testing.T) { + _ = NewMcpProxyServer("session-test") + + // Skip session management test until implemented + t.Skip("Session management not implemented yet") + + // Test session creation + sessionManager := NewMcpSessionManager() + sessionID, err := sessionManager.CreateSession("http://backend.example.com/mcp") + + // This will fail until session management is implemented + assert.NoError(t, err) + assert.NotEmpty(t, sessionID) + + // Test session retrieval + session, exists := sessionManager.GetSession(sessionID) + assert.True(t, exists) + assert.NotNil(t, session) + + // Test session cleanup + sessionManager.CleanupSession(sessionID) + _, exists = sessionManager.GetSession(sessionID) + assert.False(t, exists) +} + +// TestMcpProtocolVersionNegotiation tests protocol version handling +func TestMcpProtocolVersionNegotiation(t *testing.T) { + tests := []struct { + name string + requestedVersion string + supportedVersions []string + shouldSucceed bool + expectedVersion string + }{ + { + name: "supported version 2025-03-26", + requestedVersion: "2025-03-26", + supportedVersions: []string{"2024-11-05", "2025-03-26"}, + shouldSucceed: true, + expectedVersion: "2025-03-26", + }, + { + name: "unsupported version", + requestedVersion: "2026-01-01", + supportedVersions: []string{"2024-11-05", "2025-03-26"}, + shouldSucceed: false, + expectedVersion: "", + }, + { + name: "fallback to supported version", + requestedVersion: "2025-06-18", + supportedVersions: []string{"2024-11-05", "2025-03-26"}, + shouldSucceed: false, + expectedVersion: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Skip until NewMcpVersionNegotiator is implemented + t.Skip("Version negotiation not implemented yet") + + negotiator := NewMcpVersionNegotiator(tt.supportedVersions) + version, err := negotiator.NegotiateVersion(tt.requestedVersion) + + if tt.shouldSucceed { + assert.NoError(t, err) + assert.Equal(t, tt.expectedVersion, version) + } else { + assert.Error(t, err) + } + }) + } +} + +// TestMcpInitializeRequest tests the initialize request format and handling +func TestMcpInitializeRequest(t *testing.T) { + _ = NewMcpProxyServer("init-test") + + // Skip until CreateInitializeRequest is implemented + t.Skip("MCP protocol initialization not implemented yet") + + // Test initialize request creation + initRequest := CreateInitializeRequest() + + assert.Equal(t, "2.0", initRequest.JsonRPC) + assert.Equal(t, "initialize", initRequest.Method) + assert.NotNil(t, initRequest.Params) + + // Validate client info + params := initRequest.Params.(map[string]interface{}) + clientInfo := params["clientInfo"].(map[string]interface{}) + assert.Equal(t, "Higress-mcp-proxy", clientInfo["name"]) + assert.Equal(t, "1.0.0", clientInfo["version"]) + + // Test protocol version + assert.Equal(t, "2025-03-26", params["protocolVersion"]) +} + +// TestMcpNotificationsInitialized tests the notifications/initialized message +func TestMcpNotificationsInitialized(t *testing.T) { + // Skip until CreateInitializedNotification is implemented + t.Skip("MCP notifications not implemented yet") + + // Test notifications/initialized request creation + notification := CreateInitializedNotification() + + assert.Equal(t, "2.0", notification.JsonRPC) + assert.Equal(t, "notifications/initialized", notification.Method) + assert.Nil(t, notification.ID) // Notifications don't have IDs +} + +// TestMcpErrorHandling tests error response handling and source identification +func TestMcpErrorHandling(t *testing.T) { + tests := []struct { + name string + errorType string + originalError error + expectedSource string + expectedCode int + }{ + { + name: "backend connection error", + errorType: "connection", + originalError: assert.AnError, + expectedSource: "mcp-proxy", + expectedCode: -32603, + }, + { + name: "backend timeout error", + errorType: "timeout", + originalError: assert.AnError, + expectedSource: "mcp-proxy", + expectedCode: -32000, + }, + { + name: "protocol version error", + errorType: "version", + originalError: assert.AnError, + expectedSource: "mcp-proxy", + expectedCode: -32602, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Skip until CreateMcpErrorResponse is implemented + t.Skip("MCP error handling not implemented yet") + + errorResponse := CreateMcpErrorResponse(tt.errorType, tt.originalError, "http://backend.example.com/mcp") + + assert.Equal(t, "2.0", errorResponse.JsonRPC) + assert.NotNil(t, errorResponse.Error) + assert.Equal(t, tt.expectedCode, errorResponse.Error.Code) + assert.Equal(t, tt.expectedSource, errorResponse.Error.Data["source"]) + }) + } +} + +// Helper types and functions that will fail until implemented + +type McpSessionManager struct{} + +func NewMcpSessionManager() *McpSessionManager { + panic("McpSessionManager not implemented yet") +} + +func (m *McpSessionManager) CreateSession(backendURL string) (string, error) { + panic("CreateSession not implemented yet") +} + +func (m *McpSessionManager) GetSession(sessionID string) (interface{}, bool) { + panic("GetSession not implemented yet") +} + +func (m *McpSessionManager) CleanupSession(sessionID string) { + panic("CleanupSession not implemented yet") +} + +type McpVersionNegotiator struct { + supportedVersions []string +} + +func NewMcpVersionNegotiator(versions []string) *McpVersionNegotiator { + panic("McpVersionNegotiator not implemented yet") +} + +func (n *McpVersionNegotiator) NegotiateVersion(requested string) (string, error) { + panic("NegotiateVersion not implemented yet") +} + +type McpRequest struct { + JsonRPC string `json:"jsonrpc"` + ID interface{} `json:"id,omitempty"` + Method string `json:"method"` + Params interface{} `json:"params,omitempty"` +} + +type McpErrorResponse struct { + JsonRPC string `json:"jsonrpc"` + ID interface{} `json:"id,omitempty"` + Error *McpError `json:"error"` +} + +type McpError struct { + Code int `json:"code"` + Message string `json:"message"` + Data map[string]interface{} `json:"data,omitempty"` +} + +func CreateInitializeRequest() *McpRequest { + panic("CreateInitializeRequest not implemented yet") +} + +func CreateInitializedNotification() *McpRequest { + panic("CreateInitializedNotification not implemented yet") +} + +func CreateMcpErrorResponse(errorType string, originalError error, backendURL string) *McpErrorResponse { + panic("CreateMcpErrorResponse not implemented yet") +} diff --git a/plugins/wasm-go/pkg/mcp/server/proxy_server.go b/plugins/wasm-go/pkg/mcp/server/proxy_server.go new file mode 100644 index 000000000..c468b0e98 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/proxy_server.go @@ -0,0 +1,500 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "encoding/json" + "fmt" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" +) + +// McpProxyConfig represents the configuration for MCP proxy server +// Note: mcpServerURL, timeout, defaultDownstreamSecurity, and defaultUpstreamSecurity +// are now direct server fields, not part of this config structure +type McpProxyConfig struct { + // This structure is kept for any additional server configuration that may be needed in the future + // Currently, most configuration is handled as direct server fields +} + +// TransportProtocol represents the transport protocol type for MCP proxy +type TransportProtocol string + +const ( + TransportHTTP TransportProtocol = "http" // StreamableHTTP protocol + TransportSSE TransportProtocol = "sse" // SSE protocol +) + +// ToolArg represents an argument for a proxy tool +type ToolArg struct { + Name string `json:"name"` + Description string `json:"description"` + Type string `json:"type"` + Required bool `json:"required"` + Default interface{} `json:"default,omitempty"` + Enum []interface{} `json:"enum,omitempty"` +} + +// McpProxyToolConfig represents a tool configuration for MCP proxy +type McpProxyToolConfig struct { + Name string `json:"name"` + Description string `json:"description"` + Security SecurityRequirement `json:"security,omitempty"` // Tool-level security for MCP Client to MCP Server + Args []ToolArg `json:"args"` + OutputSchema map[string]any `json:"outputSchema,omitempty"` // Output schema for MCP Protocol Version 2025-06-18 + RequestTemplate RequestTemplate `json:"requestTemplate,omitempty"` +} + +// RequestTemplate defines request template configuration for proxy tools +type RequestTemplate struct { + Security SecurityRequirement `json:"security,omitempty"` +} + +// McpProxyServer implements Server interface for MCP-to-MCP proxy +type McpProxyServer struct { + Name string + base BaseMCPServer + toolsConfig map[string]McpProxyToolConfig + securitySchemes map[string]SecurityScheme + defaultDownstreamSecurity SecurityRequirement // Default client-to-gateway authentication + defaultUpstreamSecurity SecurityRequirement // Default gateway-to-backend authentication + mcpServerURL string // Backend MCP server URL + timeout int // Request timeout in milliseconds + transport TransportProtocol // Transport protocol (http or sse) + passthroughAuthHeader bool // If true, pass through Authorization header even without downstream security +} + +// NewMcpProxyServer creates a new MCP proxy server +func NewMcpProxyServer(name string) *McpProxyServer { + return &McpProxyServer{ + Name: name, + base: NewBaseMCPServer(), + toolsConfig: make(map[string]McpProxyToolConfig), + securitySchemes: make(map[string]SecurityScheme), + } +} + +// AddSecurityScheme adds a security scheme to the server's map +func (s *McpProxyServer) AddSecurityScheme(scheme SecurityScheme) { + if s.securitySchemes == nil { + s.securitySchemes = make(map[string]SecurityScheme) + } + s.securitySchemes[scheme.ID] = scheme +} + +// GetSecurityScheme retrieves a security scheme by its ID from the map +func (s *McpProxyServer) GetSecurityScheme(id string) (SecurityScheme, bool) { + scheme, ok := s.securitySchemes[id] + return scheme, ok +} + +// SetDefaultDownstreamSecurity sets the default downstream security configuration +func (s *McpProxyServer) SetDefaultDownstreamSecurity(security SecurityRequirement) { + s.defaultDownstreamSecurity = security +} + +// GetDefaultDownstreamSecurity gets the default downstream security configuration +func (s *McpProxyServer) GetDefaultDownstreamSecurity() SecurityRequirement { + return s.defaultDownstreamSecurity +} + +// SetDefaultUpstreamSecurity sets the default upstream security configuration +func (s *McpProxyServer) SetDefaultUpstreamSecurity(security SecurityRequirement) { + s.defaultUpstreamSecurity = security +} + +// GetDefaultUpstreamSecurity gets the default upstream security configuration +func (s *McpProxyServer) GetDefaultUpstreamSecurity() SecurityRequirement { + return s.defaultUpstreamSecurity +} + +// SetMcpServerURL sets the backend MCP server URL +func (s *McpProxyServer) SetMcpServerURL(url string) { + s.mcpServerURL = url +} + +// GetMcpServerURL gets the backend MCP server URL +func (s *McpProxyServer) GetMcpServerURL() string { + return s.mcpServerURL +} + +// SetTimeout sets the request timeout in milliseconds +func (s *McpProxyServer) SetTimeout(timeout int) { + s.timeout = timeout +} + +// GetTimeout gets the request timeout in milliseconds +func (s *McpProxyServer) GetTimeout() int { + return s.timeout +} + +// SetTransport sets the transport protocol +func (s *McpProxyServer) SetTransport(transport TransportProtocol) { + s.transport = transport +} + +// GetTransport gets the transport protocol +func (s *McpProxyServer) GetTransport() TransportProtocol { + return s.transport +} + +// AddMCPTool implements Server interface +func (s *McpProxyServer) AddMCPTool(name string, tool Tool) Server { + s.base.AddMCPTool(name, tool) + return s +} + +// AddProxyTool adds a proxy tool configuration +func (s *McpProxyServer) AddProxyTool(toolConfig McpProxyToolConfig) error { + s.toolsConfig[toolConfig.Name] = toolConfig + s.base.AddMCPTool(toolConfig.Name, &McpProxyTool{ + serverName: s.Name, + name: toolConfig.Name, + toolConfig: toolConfig, + }) + return nil +} + +// GetMCPTools implements Server interface +func (s *McpProxyServer) GetMCPTools() map[string]Tool { + return s.base.GetMCPTools() +} + +// SetConfig implements Server interface +func (s *McpProxyServer) SetConfig(config []byte) { + s.base.SetConfig(config) +} + +// GetConfig implements Server interface +func (s *McpProxyServer) GetConfig(v any) { + s.base.GetConfig(v) +} + +// Clone implements Server interface +func (s *McpProxyServer) Clone() Server { + newServer := &McpProxyServer{ + Name: s.Name, + base: s.base.CloneBase(), + toolsConfig: make(map[string]McpProxyToolConfig), + securitySchemes: make(map[string]SecurityScheme), + } + for k, v := range s.toolsConfig { + newServer.toolsConfig[k] = v + } + // Deep copy securitySchemes + if s.securitySchemes != nil { + for k, v := range s.securitySchemes { + newServer.securitySchemes[k] = v + } + } + return newServer +} + +// GetToolConfig returns the proxy tool configuration for a given tool name +func (s *McpProxyServer) GetToolConfig(name string) (McpProxyToolConfig, bool) { + config, ok := s.toolsConfig[name] + return config, ok +} + +// SetPassthroughAuthHeader sets the passthrough auth header flag +func (s *McpProxyServer) SetPassthroughAuthHeader(passthrough bool) { + s.passthroughAuthHeader = passthrough +} + +// GetPassthroughAuthHeader gets the passthrough auth header flag +func (s *McpProxyServer) GetPassthroughAuthHeader() bool { + return s.passthroughAuthHeader +} + +// ForwardToolsList forwards tools/list request to backend MCP server +func (s *McpProxyServer) ForwardToolsList(ctx HttpContext, cursor *string) error { + wrapperCtx := ctx.(wrapper.HttpContext) + + // Handle default downstream security for tools/list requests + // tools/list requests use server-level default authentication configuration + passthroughCredential := "" + downstreamSecurity := s.GetDefaultDownstreamSecurity() + if downstreamSecurity.ID != "" { + clientScheme, schemeOk := s.GetSecurityScheme(downstreamSecurity.ID) + if !schemeOk { + log.Warnf("Default downstream security scheme ID '%s' not found for tools/list request.", downstreamSecurity.ID) + } else { + // Extract and remove the credential from the incoming request + extractedCred, err := ExtractAndRemoveIncomingCredential(clientScheme) + if err != nil { + log.Warnf("Failed to extract/remove incoming credential for tools/list using scheme %s: %v", clientScheme.ID, err) + } else if extractedCred == "" { + log.Debugf("No incoming credential found for tools/list using scheme %s for extraction/removal.", clientScheme.ID) + } + + // Only use passthrough if explicitly configured + if downstreamSecurity.Passthrough && extractedCred != "" { + passthroughCredential = extractedCred + log.Debugf("Passthrough credential set for tools/list request.") + } + } + } else { + // Fallback: Remove Authorization header if no downstream security is defined + // This prevents downstream credentials from being mistakenly passed to upstream + // Unless passthroughAuthHeader is explicitly set to true + if !s.GetPassthroughAuthHeader() { + proxywasm.RemoveHttpRequestHeader("Authorization") + } + } + + // Create protocol handler using server fields + handler := NewMcpProtocolHandler(s.GetMcpServerURL(), s.GetTimeout()) + + // Prepare authentication information for gateway-to-backend communication + var authInfo *ProxyAuthInfo + upstreamSecurity := s.GetDefaultUpstreamSecurity() + if upstreamSecurity.ID != "" { + authInfo = &ProxyAuthInfo{ + SecuritySchemeID: upstreamSecurity.ID, + PassthroughCredential: passthroughCredential, + Server: s, + } + } + + // This will handle initialization asynchronously if needed and use ActionPause/Resume + return handler.ForwardToolsList(wrapperCtx, cursor, authInfo) +} + +// McpProxyTool implements Tool interface for MCP-to-MCP proxy +type McpProxyTool struct { + serverName string + name string + toolConfig McpProxyToolConfig + arguments map[string]interface{} +} + +// Create implements Tool interface +func (t *McpProxyTool) Create(params []byte) Tool { + newTool := &McpProxyTool{ + serverName: t.serverName, + name: t.name, + toolConfig: t.toolConfig, + arguments: make(map[string]interface{}), + } + + if len(params) > 0 { + json.Unmarshal(params, &newTool.arguments) + } + + return newTool +} + +// Call implements Tool interface - this is where the MCP protocol handling happens +func (t *McpProxyTool) Call(httpCtx HttpContext, server Server) error { + ctx := httpCtx.(wrapper.HttpContext) + + // Get proxy server instance to access configuration + proxyServer, ok := server.(*McpProxyServer) + if !ok { + return fmt.Errorf("server is not a McpProxyServer") + } + + // Handle tool-level or default downstream security: extract credential for passthrough if configured + // toolConfig.Security represents client-to-gateway authentication, falls back to server's defaultDownstreamSecurity + passthroughCredential := "" + var downstreamSecurity SecurityRequirement + if t.toolConfig.Security.ID != "" { + // Use tool-level security if configured + downstreamSecurity = t.toolConfig.Security + log.Debugf("Using tool-level downstream security for tool %s: %s", t.name, downstreamSecurity.ID) + } else { + // Fall back to server's default downstream security + downstreamSecurity = proxyServer.GetDefaultDownstreamSecurity() + if downstreamSecurity.ID != "" { + log.Debugf("Using default downstream security for tool %s: %s", t.name, downstreamSecurity.ID) + } + } + + if downstreamSecurity.ID != "" { + clientScheme, schemeOk := proxyServer.GetSecurityScheme(downstreamSecurity.ID) + if !schemeOk { + log.Warnf("Downstream security scheme ID '%s' not found for tool %s.", downstreamSecurity.ID, t.name) + } else { + // Extract and remove the credential from the incoming request + extractedCred, err := ExtractAndRemoveIncomingCredential(clientScheme) + if err != nil { + log.Warnf("Failed to extract/remove incoming credential for tool %s using scheme %s: %v", t.name, clientScheme.ID, err) + } else if extractedCred == "" { + log.Debugf("No incoming credential found for tool %s using scheme %s for extraction/removal.", t.name, clientScheme.ID) + } + + // Only use passthrough if explicitly configured + if downstreamSecurity.Passthrough && extractedCred != "" { + passthroughCredential = extractedCred + log.Debugf("Passthrough credential set for tool %s.", t.name) + } + } + } else { + // Fallback: Remove Authorization header if no downstream security is defined + // This prevents downstream credentials from being mistakenly passed to upstream + // Unless passthroughAuthHeader is explicitly set to true + if !proxyServer.GetPassthroughAuthHeader() { + proxywasm.RemoveHttpRequestHeader("Authorization") + } + } + + // Create protocol handler using server fields + handler := NewMcpProtocolHandler(proxyServer.GetMcpServerURL(), proxyServer.GetTimeout()) + + // Prepare authentication information for gateway-to-backend communication + // toolConfig.RequestTemplate.Security represents gateway-to-backend authentication, falls back to server's defaultUpstreamSecurity + var authInfo *ProxyAuthInfo + var upstreamSecurity SecurityRequirement + if t.toolConfig.RequestTemplate.Security.ID != "" { + // Use tool-level upstream security if configured + upstreamSecurity = t.toolConfig.RequestTemplate.Security + log.Debugf("Using tool-level upstream security for tool %s: %s", t.name, upstreamSecurity.ID) + } else { + // Fall back to server's default upstream security + upstreamSecurity = proxyServer.GetDefaultUpstreamSecurity() + if upstreamSecurity.ID != "" { + log.Debugf("Using default upstream security for tool %s: %s", t.name, upstreamSecurity.ID) + } + } + + if upstreamSecurity.ID != "" { + authInfo = &ProxyAuthInfo{ + SecuritySchemeID: upstreamSecurity.ID, + PassthroughCredential: passthroughCredential, + Server: proxyServer, + } + } + + // This will handle initialization asynchronously if needed and use ActionPause/Resume + return handler.ForwardToolsCall(ctx, t.name, t.arguments, authInfo) +} + +// Description implements Tool interface +func (t *McpProxyTool) Description() string { + return t.toolConfig.Description +} + +// InputSchema implements Tool interface +func (t *McpProxyTool) InputSchema() map[string]any { + schema := map[string]any{ + "type": "object", + "properties": make(map[string]any), + "required": []string{}, + } + + properties := schema["properties"].(map[string]any) + var required []string + + for _, arg := range t.toolConfig.Args { + argSchema := map[string]any{ + "type": arg.Type, + "description": arg.Description, + } + + if arg.Default != nil { + argSchema["default"] = arg.Default + } + + if len(arg.Enum) > 0 { + argSchema["enum"] = arg.Enum + } + + properties[arg.Name] = argSchema + + if arg.Required { + required = append(required, arg.Name) + } + } + + schema["required"] = required + return schema +} + +// OutputSchema implements Tool interface (MCP Protocol Version 2025-06-18) +func (t *McpProxyTool) OutputSchema() map[string]any { + return t.toolConfig.OutputSchema +} + +// ValidateSecurityScheme validates a security scheme configuration +func ValidateSecurityScheme(scheme SecurityScheme) error { + if scheme.ID == "" { + return fmt.Errorf("security scheme ID is required") + } + + if scheme.Type != "apiKey" && scheme.Type != "http" { + return fmt.Errorf("invalid security scheme type: %s", scheme.Type) + } + + if scheme.Type == "apiKey" { + if scheme.Name == "" { + return fmt.Errorf("security scheme name is required for apiKey type") + } + if scheme.In != "header" && scheme.In != "query" && scheme.In != "cookie" { + return fmt.Errorf("invalid security scheme location: %s", scheme.In) + } + } + + if scheme.Type == "http" { + if scheme.Scheme == "" { + return fmt.Errorf("security scheme scheme is required for http type") + } + } + + return nil +} + +// ValidateToolConfig validates a tool configuration +func ValidateToolConfig(config McpProxyToolConfig) error { + if config.Name == "" { + return fmt.Errorf("tool name is required") + } + + if config.Description == "" { + return fmt.Errorf("tool description is required") + } + + // Validate arguments + argNames := make(map[string]bool) + for _, arg := range config.Args { + if arg.Name == "" { + return fmt.Errorf("argument name is required") + } + + if argNames[arg.Name] { + return fmt.Errorf("duplicate argument name: %s", arg.Name) + } + argNames[arg.Name] = true + + if arg.Description == "" { + return fmt.Errorf("argument description is required for %s", arg.Name) + } + + validTypes := []string{"string", "number", "integer", "boolean", "array", "object"} + validType := false + for _, t := range validTypes { + if arg.Type == t { + validType = true + break + } + } + if !validType { + return fmt.Errorf("invalid argument type %s for %s", arg.Type, arg.Name) + } + } + + return nil +} diff --git a/plugins/wasm-go/pkg/mcp/server/proxy_server_test.go b/plugins/wasm-go/pkg/mcp/server/proxy_server_test.go new file mode 100644 index 000000000..6c7079e2e --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/proxy_server_test.go @@ -0,0 +1,112 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestMcpProxyServerBasicInterface tests that McpProxyServer implements the Server interface +func TestMcpProxyServerBasicInterface(t *testing.T) { + // This test will fail until we implement McpProxyServer + server := NewMcpProxyServer("test-proxy") + + // Test Server interface implementation + assert.NotNil(t, server) + assert.Equal(t, "test-proxy", server.Name) + + // Test that it implements all required methods + tools := server.GetMCPTools() + assert.NotNil(t, tools) + assert.Equal(t, 0, len(tools)) + + // Test Clone method + cloned := server.Clone() + assert.NotNil(t, cloned) +} + +// TestMcpProxyServerConfiguration tests configuration setting and getting +func TestMcpProxyServerConfiguration(t *testing.T) { + server := NewMcpProxyServer("test-proxy") + + // Set server fields directly + server.SetMcpServerURL("http://backend.example.com/mcp") + server.SetTimeout(5000) + + // Add security scheme + scheme := SecurityScheme{ + ID: "test-auth", + Type: "apiKey", + In: "header", + Name: "X-API-Key", + } + server.AddSecurityScheme(scheme) + + // Verify server fields + assert.Equal(t, "http://backend.example.com/mcp", server.GetMcpServerURL()) + assert.Equal(t, 5000, server.GetTimeout()) + + // Verify security scheme + retrievedScheme, exists := server.GetSecurityScheme("test-auth") + assert.True(t, exists) + assert.Equal(t, "test-auth", retrievedScheme.ID) + assert.Equal(t, "apiKey", retrievedScheme.Type) +} + +// TestMcpProxyServerAddTool tests adding proxy tools +func TestMcpProxyServerAddTool(t *testing.T) { + server := NewMcpProxyServer("test-proxy") + + toolConfig := McpProxyToolConfig{ + Name: "test-tool", + Description: "Test tool for proxy", + Args: []ToolArg{ + { + Name: "input", + Description: "Test input", + Type: "string", + Required: true, + }, + }, + } + + err := server.AddProxyTool(toolConfig) + assert.NoError(t, err) + + tools := server.GetMCPTools() + assert.Len(t, tools, 1) + assert.Contains(t, tools, "test-tool") +} + +// TestMcpProxyServerSecuritySchemes tests security scheme management +func TestMcpProxyServerSecuritySchemes(t *testing.T) { + server := NewMcpProxyServer("test-proxy") + + scheme := SecurityScheme{ + ID: "test-auth", + Type: "apiKey", + In: "header", + Name: "X-API-Key", + } + + server.AddSecurityScheme(scheme) + + retrievedScheme, exists := server.GetSecurityScheme("test-auth") + assert.True(t, exists) + assert.Equal(t, scheme.ID, retrievedScheme.ID) + assert.Equal(t, scheme.Type, retrievedScheme.Type) +} diff --git a/plugins/wasm-go/pkg/mcp/server/proxy_tool.go b/plugins/wasm-go/pkg/mcp/server/proxy_tool.go new file mode 100644 index 000000000..d6812d7dc --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/proxy_tool.go @@ -0,0 +1,1269 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + // Context keys for MCP proxy state management + CtxMcpProxyInitialized = "mcp_proxy_initialized" + CtxMcpProxySessionID = "mcp_proxy_session_id" + CtxMcpProxyToolName = "mcp_proxy_tool_name" + CtxMcpProxyToolArgs = "mcp_proxy_tool_args" + CtxMcpProxyOperation = "mcp_proxy_operation" +) + +// ProxyAuthInfo holds authentication information for proxy tool calls +type ProxyAuthInfo struct { + SecuritySchemeID string // RequestTemplate.Security.ID for gateway-to-backend auth + PassthroughCredential string // Credential extracted from client request (if passthrough enabled) + Server *McpProxyServer // Server instance for accessing security schemes +} + +// McpProxyOperation represents the current operation type +type McpProxyOperation string + +const ( + OpToolsList McpProxyOperation = "tools/list" + OpToolsCall McpProxyOperation = "tools/call" +) + +// McpProtocolHandler handles MCP protocol initialization and communication +type McpProtocolHandler struct { + backendURL string + timeout int + sessionID string +} + +// NewMcpProtocolHandler creates a new MCP protocol handler +func NewMcpProtocolHandler(backendURL string, timeout int) *McpProtocolHandler { + return &McpProtocolHandler{ + backendURL: backendURL, + timeout: timeout, + } +} + +// parseSSEResponse parses Server-Sent Events format and extracts data field content +func parseSSEResponse(sseData []byte) ([]byte, error) { + scanner := bufio.NewScanner(bytes.NewReader(sseData)) + // Set max token size to 32MB to handle large messages + maxTokenSize := 32 * 1024 * 1024 // 32MB + scanner.Buffer(make([]byte, 0, 64*1024), maxTokenSize) + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, ":") { + continue + } + + // Look for data field + if strings.HasPrefix(line, "data: ") { + dataContent := strings.TrimPrefix(line, "data: ") + return []byte(dataContent), nil + } + } + + if err := scanner.Err(); err != nil { + if errors.Is(err, bufio.ErrTooLong) { + return nil, fmt.Errorf("SSE response line exceeds maximum token size (32MB): %w", err) + } + return nil, fmt.Errorf("error reading SSE data: %v", err) + } + + return nil, fmt.Errorf("no data field found in SSE response") +} + +// Initialize performs the MCP protocol initialization sequence asynchronously +func (h *McpProtocolHandler) Initialize(ctx wrapper.HttpContext, authInfo *ProxyAuthInfo) error { + log.Infof("Starting MCP protocol initialization for %s", h.backendURL) + + // Check if already initialized for this context + if initialized := ctx.GetContext(CtxMcpProxyInitialized); initialized != nil { + if sessionID := ctx.GetContext(CtxMcpProxySessionID); sessionID != nil { + h.sessionID = sessionID.(string) + log.Debugf("MCP proxy already initialized with session ID: %s", h.sessionID) + return nil + } + } + + // Step 1: Send initialize request + initRequest := h.createInitializeRequest() + requestBody, err := json.Marshal(initRequest) + if err != nil { + return fmt.Errorf("failed to marshal initialize request: %v", err) + } + + // Send initialize request to backend asynchronously + err = h.sendMcpRequest(ctx, requestBody, authInfo, func(statusCode int, responseHeaders [][2]string, responseBody []byte) { + // Don't resume here - either OnMCPResponseError will send response directly, + // or sendInitializedNotification will continue the async flow + if statusCode != 200 { + log.Errorf("Initialize request failed with status %d: %s", statusCode, string(responseBody)) + utils.OnMCPResponseError(ctx, fmt.Errorf("backend initialization failed"), utils.ErrInternalError, "mcp-proxy:initialize:backend_error") + return + } + + // Determine response content type and parse accordingly + var jsonResponseBody []byte + var contentType string + + // Find content-type header + for _, header := range responseHeaders { + if strings.ToLower(header[0]) == "content-type" { + contentType = strings.ToLower(header[1]) + break + } + } + + // Parse response based on content type + if strings.Contains(contentType, "text/event-stream") { + // Handle SSE format + log.Debugf("Processing SSE response for initialize request") + parsedJSON, err := parseSSEResponse(responseBody) + if err != nil { + log.Errorf("Failed to parse SSE response: %v", err) + utils.OnMCPResponseError(ctx, err, utils.ErrInternalError, "mcp-proxy:initialize:sse_parse_error") + return + } + jsonResponseBody = parsedJSON + } else { + // Handle JSON format (default) + log.Debugf("Processing JSON response for initialize request") + jsonResponseBody = responseBody + } + + // Parse initialize response + var response map[string]interface{} + if err := json.Unmarshal(jsonResponseBody, &response); err != nil { + log.Errorf("Failed to parse initialize response: %v", err) + utils.OnMCPResponseError(ctx, err, utils.ErrInternalError, "mcp-proxy:initialize:parse_error") + return + } + + // Check for protocol version compatibility + if errorObj, exists := response["error"]; exists { + log.Errorf("Backend initialization error: %v", errorObj) + + // Check if it's a version compatibility error + if errorMap, ok := errorObj.(map[string]interface{}); ok { + if code, codeOk := errorMap["code"]; codeOk && code == -32602 { + // Protocol version not supported + utils.OnMCPResponseError(ctx, fmt.Errorf("protocol version not supported by backend"), utils.ErrInvalidParams, "mcp-proxy:initialize:version_incompatible") + return + } + } + + utils.OnMCPResponseError(ctx, fmt.Errorf("backend initialization failed"), utils.ErrInternalError, "mcp-proxy:initialize:backend_error") + return + } + + // Extract session ID from response headers if present + for _, header := range responseHeaders { + if header[0] == "Mcp-Session-Id" { + h.sessionID = header[1] + ctx.SetContext(CtxMcpProxySessionID, h.sessionID) + log.Infof("Received MCP session ID: %s", h.sessionID) + break + } + } + + // Step 2: Send notifications/initialized + h.sendInitializedNotification(ctx, authInfo) + }) + + return err +} + +// ForwardToolsList forwards tools/list request to backend MCP server +func (h *McpProtocolHandler) ForwardToolsList(ctx wrapper.HttpContext, cursor *string, authInfo *ProxyAuthInfo) error { + log.Debugf("Forwarding tools/list request to %s", h.backendURL) + + // Store the cursor for later execution + ctx.SetContext(CtxMcpProxyOperation, OpToolsList) + if cursor != nil { + ctx.SetContext("mcp_proxy_cursor", *cursor) + } + if authInfo != nil { + ctx.SetContext("mcp_proxy_auth_info", authInfo) + } + + // Check if MCP is already initialized + if initialized := ctx.GetContext(CtxMcpProxyInitialized); initialized != nil { + // Already initialized, execute directly + return h.executeToolsList(ctx) + } + + // Need to initialize first, which will execute tools/list in its callback + return h.Initialize(ctx, authInfo) +} + +// executeToolsList executes the actual tools/list request +func (h *McpProtocolHandler) executeToolsList(ctx wrapper.HttpContext) error { + var cursor *string + if cursorVal := ctx.GetContext("mcp_proxy_cursor"); cursorVal != nil { + cursorStr := cursorVal.(string) + cursor = &cursorStr + } + + listRequest := h.createToolsListRequest(cursor) + requestBody, err := json.Marshal(listRequest) + if err != nil { + return fmt.Errorf("failed to marshal tools/list request: %v", err) + } + + headers := [][2]string{ + {"Content-Type", "application/json"}, + {"Accept", "application/json,text/event-stream"}, + } + + // Add session ID if we have one + if h.sessionID != "" { + headers = append(headers, [2]string{"Mcp-Session-Id", h.sessionID}) + } + + // Start with the original backend URL + finalURL := h.backendURL + + // Apply authentication if auth info was provided + if authInfoCtx := ctx.GetContext("mcp_proxy_auth_info"); authInfoCtx != nil { + if authInfo, ok := authInfoCtx.(*ProxyAuthInfo); ok && authInfo.SecuritySchemeID != "" { + // Apply authentication using shared utilities + modifiedURL, err := h.applyProxyAuthentication(authInfo.Server, authInfo.SecuritySchemeID, authInfo.PassthroughCredential, &headers) + if err != nil { + log.Errorf("Failed to apply authentication for tools/list request: %v", err) + } else { + // Use the modified URL if authentication was applied successfully + finalURL = modifiedURL + log.Debugf("Using modified URL for tools/list request: %s", finalURL) + } + } + } + + // Use RouteCall for the final tools/list request with potentially modified URL + return ctx.RouteCall("POST", finalURL, headers, requestBody, func(statusCode int, responseHeaders [][2]string, responseBody []byte) { + if statusCode != 200 { + log.Errorf("Tools/list request failed with status %d: %s", statusCode, string(responseBody)) + utils.OnMCPResponseError(ctx, fmt.Errorf("backend tools/list failed"), utils.ErrInternalError, "mcp-proxy:tools/list:backend_error") + return + } + + // Determine response content type and parse accordingly + var jsonResponseBody []byte + var contentType string + + // Find content-type header + for _, header := range responseHeaders { + if strings.ToLower(header[0]) == "content-type" { + contentType = strings.ToLower(header[1]) + break + } + } + + // Parse response based on content type + if strings.Contains(contentType, "text/event-stream") { + // Handle SSE format + log.Debugf("Processing SSE response for tools/list request") + parsedJSON, err := parseSSEResponse(responseBody) + if err != nil { + log.Errorf("Failed to parse SSE response: %v", err) + utils.OnMCPResponseError(ctx, err, utils.ErrInternalError, "mcp-proxy:tools/list:sse_parse_error") + return + } + jsonResponseBody = parsedJSON + } else { + // Handle JSON format (default) + log.Debugf("Processing JSON response for tools/list request") + jsonResponseBody = responseBody + } + + // Parse response and forward to client + var response map[string]interface{} + if err := json.Unmarshal(jsonResponseBody, &response); err != nil { + log.Errorf("Failed to parse tools/list response: %v", err) + utils.OnMCPResponseError(ctx, err, utils.ErrInternalError, "mcp-proxy:tools/list:parse_error") + return + } + + // Forward the tools/list result with allowTools filtering + if result, hasResult := response["result"]; hasResult { + if resultMap, ok := result.(map[string]interface{}); ok { + // Apply allowTools filtering if needed + filteredResult := h.applyAllowToolsFilter(ctx, resultMap) + utils.OnMCPResponseSuccess(ctx, filteredResult, "mcp-proxy:tools/list:success") + } else { + utils.OnMCPResponseError(ctx, fmt.Errorf("invalid tools/list result type"), utils.ErrInternalError, "mcp-proxy:tools/list:invalid_type") + } + } else { + utils.OnMCPResponseError(ctx, fmt.Errorf("invalid tools/list response"), utils.ErrInternalError, "mcp-proxy:tools/list:invalid_response") + } + }) +} + +// ForwardToolsCall forwards tools/call request to backend MCP server +func (h *McpProtocolHandler) ForwardToolsCall(ctx wrapper.HttpContext, toolName string, arguments map[string]interface{}, authInfo *ProxyAuthInfo) error { + log.Debugf("Forwarding tools/call request for tool %s to %s", toolName, h.backendURL) + + // Store the tool call parameters for later execution + ctx.SetContext(CtxMcpProxyOperation, OpToolsCall) + ctx.SetContext(CtxMcpProxyToolName, toolName) + ctx.SetContext(CtxMcpProxyToolArgs, arguments) + if authInfo != nil { + ctx.SetContext("mcp_proxy_auth_info", authInfo) + } + + // Check if MCP is already initialized + if initialized := ctx.GetContext(CtxMcpProxyInitialized); initialized != nil { + // Already initialized, execute directly + return h.executeToolsCall(ctx) + } + + // Need to initialize first, which will execute tools/call in its callback + return h.Initialize(ctx, authInfo) +} + +// executeToolsCall executes the actual tools/call request +func (h *McpProtocolHandler) executeToolsCall(ctx wrapper.HttpContext) error { + toolName := ctx.GetContext(CtxMcpProxyToolName).(string) + arguments := ctx.GetContext(CtxMcpProxyToolArgs).(map[string]interface{}) + + callRequest := h.createToolsCallRequest(toolName, arguments) + requestBody, err := json.Marshal(callRequest) + if err != nil { + return fmt.Errorf("failed to marshal tools/call request: %v", err) + } + + headers := [][2]string{ + {"Content-Type", "application/json"}, + {"Accept", "application/json,text/event-stream"}, + } + + // Add session ID if we have one + if h.sessionID != "" { + headers = append(headers, [2]string{"Mcp-Session-Id", h.sessionID}) + } + + // Start with the original backend URL + finalURL := h.backendURL + + // Apply authentication if auth info was provided + if authInfoCtx := ctx.GetContext("mcp_proxy_auth_info"); authInfoCtx != nil { + if authInfo, ok := authInfoCtx.(*ProxyAuthInfo); ok && authInfo.SecuritySchemeID != "" { + // Apply authentication using shared utilities + modifiedURL, err := h.applyProxyAuthentication(authInfo.Server, authInfo.SecuritySchemeID, authInfo.PassthroughCredential, &headers) + if err != nil { + log.Errorf("Failed to apply authentication for proxy tool call: %v", err) + } else { + // Use the modified URL if authentication was applied successfully + finalURL = modifiedURL + log.Debugf("Using modified URL for tools/call request: %s", finalURL) + } + } + } + + // Use RouteCall for the final tools/call request with potentially modified URL + return ctx.RouteCall("POST", finalURL, headers, requestBody, func(statusCode int, responseHeaders [][2]string, responseBody []byte) { + if statusCode != 200 { + log.Errorf("Tools/call request failed with status %d: %s", statusCode, string(responseBody)) + utils.OnMCPResponseError(ctx, fmt.Errorf("backend tools/call failed"), utils.ErrInternalError, "mcp-proxy:tools/call:backend_error") + return + } + + // Determine response content type and parse accordingly + var jsonResponseBody []byte + var contentType string + + // Find content-type header + for _, header := range responseHeaders { + if strings.ToLower(header[0]) == "content-type" { + contentType = strings.ToLower(header[1]) + break + } + } + + // Parse response based on content type + if strings.Contains(contentType, "text/event-stream") { + // Handle SSE format + log.Debugf("Processing SSE response for tools/call request") + parsedJSON, err := parseSSEResponse(responseBody) + if err != nil { + log.Errorf("Failed to parse SSE response: %v", err) + utils.OnMCPResponseError(ctx, err, utils.ErrInternalError, "mcp-proxy:tools/call:sse_parse_error") + return + } + jsonResponseBody = parsedJSON + } else { + // Handle JSON format (default) + log.Debugf("Processing JSON response for tools/call request") + jsonResponseBody = responseBody + } + + // Parse response and check for backend errors (single unmarshal) + parsedResponse, isError, errorType := ParseBackendResponse(jsonResponseBody) + if parsedResponse == nil { + log.Errorf("Failed to parse tools/call response") + utils.OnMCPResponseError(ctx, fmt.Errorf("invalid JSON response"), utils.ErrInternalError, "mcp-proxy:tools/call:parse_error") + return + } + + // Log backend errors for observability + if isError { + log.Warnf("Backend reported %s for %s", errorType, toolName) + } + + // Forward the tools/call result (pass through both success and error responses) + if result, hasResult := parsedResponse["result"]; hasResult { + if resultMap, ok := result.(map[string]interface{}); ok { + utils.OnMCPResponseSuccess(ctx, resultMap, "mcp-proxy:tools/call:success") + } else { + utils.OnMCPResponseError(ctx, fmt.Errorf("invalid tools/call result type"), utils.ErrInternalError, "mcp-proxy:tools/call:invalid_type") + } + } else if errorField, hasError := parsedResponse["error"]; hasError { + // Pass through JSON-RPC error as MCP error + if errorMap, ok := errorField.(map[string]interface{}); ok { + errorMsg := "Backend error" + if msg, hasMsg := errorMap["message"]; hasMsg { + errorMsg = fmt.Sprintf("%v", msg) + } + utils.OnMCPResponseError(ctx, fmt.Errorf("%s", errorMsg), utils.ErrInternalError, "mcp-proxy:tools/call:backend_error") + } else { + utils.OnMCPResponseError(ctx, fmt.Errorf("backend error"), utils.ErrInternalError, "mcp-proxy:tools/call:backend_error") + } + } else { + utils.OnMCPResponseError(ctx, fmt.Errorf("invalid tools/call response"), utils.ErrInternalError, "mcp-proxy:tools/call:invalid_response") + } + }) +} + +// sendMcpRequest sends an MCP request to the backend server using POST method +func (h *McpProtocolHandler) sendMcpRequest(ctx wrapper.HttpContext, body []byte, authInfo *ProxyAuthInfo, callback func(int, [][2]string, []byte)) error { + // Copy headers from current request + headers := copyHeadersForStreamableHTTP(ctx) + + // Override/ensure required headers for MCP request + ensureHeader(&headers, "Content-Type", "application/json") + ensureHeader(&headers, "Accept", "application/json,text/event-stream") + + // Add session ID if we have one + if h.sessionID != "" { + ensureHeader(&headers, "Mcp-Session-Id", h.sessionID) + } + + // Start with the original backend URL + finalURL := h.backendURL + + // Apply authentication if auth info was provided + if authInfo != nil && authInfo.SecuritySchemeID != "" { + modifiedURL, err := h.applyProxyAuthentication(authInfo.Server, authInfo.SecuritySchemeID, authInfo.PassthroughCredential, &headers) + if err != nil { + log.Errorf("Failed to apply authentication for MCP request: %v", err) + } else { + // Use the modified URL if authentication was applied successfully + finalURL = modifiedURL + log.Debugf("Using modified URL for MCP request: %s", finalURL) + } + } + + // Determine timeout + timeout := uint32(h.timeout) + if timeout == 0 { + timeout = 5000 // Default 5 seconds + } + + // Create HTTP client using RouteCluster + client := wrapper.NewClusterClient(wrapper.RouteCluster{}) + + // Convert callback to the expected format + wrappedCallback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { + // Convert http.Header to [][2]string format + headerSlice := make([][2]string, 0, len(responseHeaders)) + for key, values := range responseHeaders { + if len(values) > 0 { + headerSlice = append(headerSlice, [2]string{key, values[0]}) + } + } + callback(statusCode, headerSlice, responseBody) + } + + // All MCP requests use POST method with potentially modified URL + return client.Post(finalURL, headers, body, wrappedCallback, timeout) +} + +// createInitializeRequest creates an MCP initialize request +func (h *McpProtocolHandler) createInitializeRequest() map[string]interface{} { + return map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]interface{}{ + "protocolVersion": "2025-03-26", + "capabilities": map[string]interface{}{}, + "clientInfo": map[string]interface{}{ + "name": "Higress-mcp-proxy", + "version": "1.0.0", + }, + }, + } +} + +// sendInitializedNotification sends the notifications/initialized message +func (h *McpProtocolHandler) sendInitializedNotification(ctx wrapper.HttpContext, authInfo *ProxyAuthInfo) { + notification := map[string]interface{}{ + "jsonrpc": "2.0", + "method": "notifications/initialized", + } + + requestBody, err := json.Marshal(notification) + if err != nil { + log.Errorf("Failed to marshal initialized notification: %v", err) + utils.OnMCPResponseError(ctx, err, utils.ErrInternalError, "mcp-proxy:notifications/initialized:marshal_error") + return + } + + // Send the notification (no response expected) + err = h.sendMcpRequest(ctx, requestBody, authInfo, func(statusCode int, responseHeaders [][2]string, responseBody []byte) { + // Always resume at the end, regardless of success or failure + defer proxywasm.ResumeHttpRequest() + + if statusCode >= 300 { + log.Warnf("Initialized notification failed with status %d: %s", statusCode, string(responseBody)) + // Even if notification fails, we can still proceed with the operation + // The backend might still be functional for actual tool calls + } else { + log.Debugf("MCP initialization completed successfully") + } + + // Mark initialization as complete + ctx.SetContext(CtxMcpProxyInitialized, true) + + // Now execute the originally requested operation + operation := ctx.GetContext(CtxMcpProxyOperation) + if operation != nil { + switch operation.(McpProxyOperation) { + case OpToolsList: + if err := h.executeToolsList(ctx); err != nil { + log.Errorf("Failed to execute tools/list: %v", err) + utils.OnMCPResponseError(ctx, err, utils.ErrInternalError, "mcp-proxy:tools/list:execution_error") + } + case OpToolsCall: + if err := h.executeToolsCall(ctx); err != nil { + log.Errorf("Failed to execute tools/call: %v", err) + utils.OnMCPResponseError(ctx, err, utils.ErrInternalError, "mcp-proxy:tools/call:execution_error") + } + default: + log.Warnf("Unknown MCP proxy operation: %v", operation) + utils.OnMCPResponseError(ctx, fmt.Errorf("unknown operation"), utils.ErrInternalError, "mcp-proxy:unknown_operation") + } + } else { + // No pending operation, just complete the initialization + log.Debugf("MCP initialization completed, no pending operation") + } + }) + + if err != nil { + log.Errorf("Failed to send initialized notification: %v", err) + utils.OnMCPResponseError(ctx, err, utils.ErrInternalError, "mcp-proxy:notifications/initialized:send_error") + } +} + +// createToolsListRequest creates a tools/list request +func (h *McpProtocolHandler) createToolsListRequest(cursor *string) map[string]interface{} { + request := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + "params": map[string]interface{}{}, + } + + if cursor != nil && *cursor != "" { + request["params"].(map[string]interface{})["cursor"] = *cursor + } + + return request +} + +// createToolsCallRequest creates a tools/call request +func (h *McpProtocolHandler) createToolsCallRequest(toolName string, arguments map[string]interface{}) map[string]interface{} { + return map[string]interface{}{ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": map[string]interface{}{ + "name": toolName, + "arguments": arguments, + }, + } +} + +// ParseBackendResponse parses the response body and checks if it's a backend error +// Returns the parsed response, whether it's an error, and the error type +func ParseBackendResponse(responseBody []byte) (response map[string]interface{}, isError bool, errorType string) { + if err := json.Unmarshal(responseBody, &response); err != nil { + return nil, false, "" + } + + // Check for JSON-RPC 2.0 error format (top-level error field) + if _, hasError := response["error"]; hasError { + return response, true, "jsonrpc_error" + } + + // Check for error in result.isError format + if result, hasResult := response["result"]; hasResult { + if resultMap, ok := result.(map[string]interface{}); ok { + if isErr, hasIsError := resultMap["isError"]; hasIsError && isErr == true { + return response, true, "result_isError" + } + } + } + + return response, false, "" +} + +// IsBackendError checks if the response is a backend error (JSON-RPC 2.0 error or result.isError) +// Returns true if it's an error response, and the error type ("jsonrpc_error" or "result_isError") +func IsBackendError(responseBody []byte) (isError bool, errorType string) { + _, isError, errorType = ParseBackendResponse(responseBody) + return isError, errorType +} + +// McpSession represents a temporary MCP session +type McpSession struct { + ID string + BackendURL string + CreatedAt time.Time + LastUsed time.Time +} + +// McpSessionManagerImpl manages temporary MCP sessions +type McpSessionManagerImpl struct { + sessions map[string]*McpSession +} + +// NewMcpSessionManagerImpl creates a new session manager +func NewMcpSessionManagerImpl() *McpSessionManagerImpl { + return &McpSessionManagerImpl{ + sessions: make(map[string]*McpSession), + } +} + +// CreateSession creates a new temporary session +func (m *McpSessionManagerImpl) CreateSession(backendURL string) (string, error) { + sessionID := fmt.Sprintf("mcp-session-%d", time.Now().UnixNano()) + session := &McpSession{ + ID: sessionID, + BackendURL: backendURL, + CreatedAt: time.Now(), + LastUsed: time.Now(), + } + + m.sessions[sessionID] = session + log.Debugf("Created MCP session %s for %s", sessionID, backendURL) + + return sessionID, nil +} + +// GetSession retrieves a session by ID +func (m *McpSessionManagerImpl) GetSession(sessionID string) (*McpSession, bool) { + session, exists := m.sessions[sessionID] + if exists { + session.LastUsed = time.Now() + } + return session, exists +} + +// CleanupSession removes a session +func (m *McpSessionManagerImpl) CleanupSession(sessionID string) { + if _, exists := m.sessions[sessionID]; exists { + delete(m.sessions, sessionID) + log.Debugf("Cleaned up MCP session %s", sessionID) + } +} + +// CleanupExpiredSessions removes sessions older than specified duration +func (m *McpSessionManagerImpl) CleanupExpiredSessions(maxAge time.Duration) { + now := time.Now() + for sessionID, session := range m.sessions { + if now.Sub(session.LastUsed) > maxAge { + delete(m.sessions, sessionID) + log.Debugf("Cleaned up expired MCP session %s", sessionID) + } + } +} + +// CreateMcpProxyMethodHandlers creates JSON-RPC method handlers for MCP proxy operations +func CreateMcpProxyMethodHandlers(server *McpProxyServer, allowTools *map[string]struct{}) utils.MethodHandlers { + return utils.MethodHandlers{ + "tools/list": func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error { + // Check transport type + if server.GetTransport() == TransportSSE { + return handleSSEToolsList(ctx, id, params, server, allowTools) + } + + // StreamableHTTP transport (original logic) + // Extract cursor parameter if present + var cursor *string + if cursorResult := params.Get("cursor"); cursorResult.Exists() { + cursorStr := cursorResult.String() + cursor = &cursorStr + } + + // Extract allowTools from header and compute effective allowTools + allowToolsHeaderStr, _ := proxywasm.GetHttpRequestHeader("x-envoy-allow-mcp-tools") + proxywasm.RemoveHttpRequestHeader("x-envoy-allow-mcp-tools") + // Only consider header as "present" if it has non-empty value + // Empty string means header is not set or explicitly empty, both treated as "no restriction" + headerExists := allowToolsHeaderStr != "" + effectiveAllowTools := computeEffectiveAllowToolsFromHeader(allowTools, allowToolsHeaderStr, headerExists) + + // Store server reference and effective allowTools in context for callback use + ctx.SetContext("mcp_proxy_server", server) + ctx.SetContext("mcp_proxy_effective_allow_tools", effectiveAllowTools) + + // This will trigger async initialization if needed + if err := server.ForwardToolsList(ctx, cursor); err != nil { + return err + } + + // Signal that we need to pause and wait for async response + ctx.SetContext(utils.CtxNeedPause, true) + return nil + }, + "tools/call": func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error { + // Check transport type + if server.GetTransport() == TransportSSE { + return handleSSEToolsCall(ctx, id, params, server, allowTools) + } + + // StreamableHTTP transport (original logic) + // Extract tool name and arguments + toolName := params.Get("name").String() + if toolName == "" { + return fmt.Errorf("missing tool name") + } + + // Compute effective allowTools using helper function + effectiveAllowTools := computeEffectiveAllowTools(allowTools) + + // Check if tool is allowed + if effectiveAllowTools != nil { + if _, allow := (*effectiveAllowTools)[toolName]; !allow { + utils.OnMCPResponseError(ctx, fmt.Errorf("Tool not allowed: %s", toolName), utils.ErrInvalidParams, fmt.Sprintf("mcp-proxy:%s:tools/call:tool_not_allowed", server.Name)) + return nil + } + } + + // Extract arguments (optional) + arguments := make(map[string]interface{}) + argsResult := params.Get("arguments") + if argsResult.Exists() { + if err := json.Unmarshal([]byte(argsResult.Raw), &arguments); err != nil { + return fmt.Errorf("invalid arguments: %v", err) + } + } + + // Set properties for monitoring and debugging (consistent with default handler) + proxywasm.SetProperty([]string{"mcp_server_name"}, []byte(server.Name)) + proxywasm.SetProperty([]string{"mcp_tool_name"}, []byte(toolName)) + + // Create a tool instance and call it + toolConfig, exists := server.GetToolConfig(toolName) + if !exists { + log.Warnf("tool not found: %s, will not use tool specifiy security config", toolName) + } + + // Debug logging (consistent with default handler) + log.Debugf("Tool call [%s] on server [%s] with arguments[%s]", toolName, server.Name, argsResult.Raw) + + tool := &McpProxyTool{ + serverName: server.Name, + name: toolName, + toolConfig: toolConfig, + arguments: arguments, + } + + // This will trigger async initialization if needed + err := tool.Call(ctx, server) + if err != nil { + return err + } + + // Signal that we need to pause and wait for async response + ctx.SetContext(utils.CtxNeedPause, true) + return nil + }, + } +} + +// applyAllowToolsFilter applies allowTools filtering to the tools/list response +func (h *McpProtocolHandler) applyAllowToolsFilter(ctx wrapper.HttpContext, resultMap map[string]interface{}) map[string]interface{} { + // Get pre-computed effective allowTools from context + var effectiveAllowTools *map[string]struct{} + if allowToolsCtx := ctx.GetContext("mcp_proxy_effective_allow_tools"); allowToolsCtx != nil { + if allowToolsPtr, ok := allowToolsCtx.(*map[string]struct{}); ok { + effectiveAllowTools = allowToolsPtr + } + } + + // If no restrictions, return original result + if effectiveAllowTools == nil { + return resultMap + } + + // Apply filtering to tools array + if tools, hasTools := resultMap["tools"]; hasTools { + if toolsArray, ok := tools.([]interface{}); ok { + filteredTools := make([]interface{}, 0) + + for _, tool := range toolsArray { + if toolMap, ok := tool.(map[string]interface{}); ok { + if name, hasName := toolMap["name"]; hasName { + if toolName, ok := name.(string); ok { + // Check if tool is allowed + if _, allow := (*effectiveAllowTools)[toolName]; !allow { + continue + } + // Tool is allowed, add to filtered list + filteredTools = append(filteredTools, tool) + } + } + } + } + + // Create new result with filtered tools + filteredResult := make(map[string]interface{}) + for k, v := range resultMap { + filteredResult[k] = v + } + filteredResult["tools"] = filteredTools + return filteredResult + } + } + + // If tools array not found or invalid format, return original + return resultMap +} + +// applyProxyAuthentication applies authentication to the proxy request headers and URL +func (h *McpProtocolHandler) applyProxyAuthentication(server *McpProxyServer, schemeID string, passthroughCredential string, headers *[][2]string) (string, error) { + // Parse the backend URL to create a proper URL object for the shared function + parsedURL, err := url.Parse(h.backendURL) + if err != nil { + return "", fmt.Errorf("failed to parse backend URL: %v", err) + } + + // Create authentication context + authCtx := AuthRequestContext{ + Method: "POST", + Headers: *headers, + ParsedURL: parsedURL, + RequestBody: []byte{}, // Not used for header/query auth + PassthroughCredential: passthroughCredential, + } + + // Create security config for gateway-to-backend authentication + // The passthrough credential (if any) comes from client-to-gateway authentication + securityConfig := SecurityRequirement{ + ID: schemeID, + Credential: "", // Will use passthrough credential or default credential from scheme + Passthrough: passthroughCredential != "", // Use passthrough if we have a credential + } + + // Apply authentication using shared utilities + err = ApplySecurity(securityConfig, server, &authCtx) + if err != nil { + return "", err + } + + // Update headers with authentication applied + *headers = authCtx.Headers + + // Reconstruct URL from potentially modified ParsedURL (similar to rest_server.go logic) + u := authCtx.ParsedURL + encodedPath := u.EscapedPath() + var urlStr string + if u.Scheme != "" && u.Host != "" { + urlStr = u.Scheme + "://" + u.Host + encodedPath + } else { + urlStr = "/" + strings.TrimPrefix(encodedPath, "/") + } + if u.RawQuery != "" { + urlStr += "?" + u.RawQuery + } + if u.Fragment != "" { + urlStr += "#" + u.Fragment + } + + return urlStr, nil +} + +// handleSSEToolsList handles tools/list request for SSE transport +func handleSSEToolsList(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result, server *McpProxyServer, allowTools *map[string]struct{}) error { + // Extract allowTools from header and compute effective allowTools + allowToolsHeaderStr, _ := proxywasm.GetHttpRequestHeader("x-envoy-allow-mcp-tools") + proxywasm.RemoveHttpRequestHeader("x-envoy-allow-mcp-tools") + headerExists := allowToolsHeaderStr != "" + effectiveAllowTools := computeEffectiveAllowToolsFromHeader(allowTools, allowToolsHeaderStr, headerExists) + + // Store server reference, effective allowTools, and JSON-RPC ID in context + ctx.SetContext("mcp_proxy_server", server) + ctx.SetContext("mcp_proxy_effective_allow_tools", effectiveAllowTools) + + // Prepare request body for tools/list + listRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + "params": map[string]interface{}{}, + } + + if cursorResult := params.Get("cursor"); cursorResult.Exists() { + listRequest["params"].(map[string]interface{})["cursor"] = cursorResult.String() + } + + requestBody, err := json.Marshal(listRequest) + if err != nil { + return fmt.Errorf("failed to marshal tools/list request: %v", err) + } + + // Use common function to handle SSE request + return handleSSERequest(ctx, id, requestBody, server, server.GetDefaultDownstreamSecurity(), server.GetDefaultUpstreamSecurity()) +} + +// handleSSEToolsCall handles tools/call request for SSE transport +func handleSSEToolsCall(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result, server *McpProxyServer, allowTools *map[string]struct{}) error { + // Extract tool name and arguments + toolName := params.Get("name").String() + if toolName == "" { + return fmt.Errorf("missing tool name") + } + + // Compute effective allowTools + effectiveAllowTools := computeEffectiveAllowTools(allowTools) + + // Check if tool is allowed + if effectiveAllowTools != nil { + if _, allow := (*effectiveAllowTools)[toolName]; !allow { + utils.OnMCPResponseError(ctx, fmt.Errorf("Tool not allowed: %s", toolName), utils.ErrInvalidParams, fmt.Sprintf("mcp-proxy:%s:tools/call:tool_not_allowed", server.Name)) + return nil + } + } + + // Store server reference in context + ctx.SetContext("mcp_proxy_server", server) + + // Extract arguments + arguments := make(map[string]interface{}) + argsResult := params.Get("arguments") + if argsResult.Exists() { + if err := json.Unmarshal([]byte(argsResult.Raw), &arguments); err != nil { + return fmt.Errorf("invalid arguments: %v", err) + } + } + + // Set properties for monitoring + proxywasm.SetProperty([]string{"mcp_server_name"}, []byte(server.Name)) + proxywasm.SetProperty([]string{"mcp_tool_name"}, []byte(toolName)) + + log.Debugf("Tool call [%s] on server [%s] with arguments[%s]", toolName, server.Name, argsResult.Raw) + + // Prepare request body for tools/call + // Use id: 2 because initialize uses id: 1, and we only send one tool request (list or call) + callRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": map[string]interface{}{ + "name": toolName, + "arguments": arguments, + }, + } + + requestBody, err := json.Marshal(callRequest) + if err != nil { + return fmt.Errorf("failed to marshal tools/call request: %v", err) + } + + // Get tool config for tool-level security + toolConfig, _ := server.GetToolConfig(toolName) + + // Determine downstream and upstream security (tool-level or server default) + var downstreamSecurity SecurityRequirement + if toolConfig.Security.ID != "" { + downstreamSecurity = toolConfig.Security + } else { + downstreamSecurity = server.GetDefaultDownstreamSecurity() + } + + var upstreamSecurity SecurityRequirement + if toolConfig.RequestTemplate.Security.ID != "" { + upstreamSecurity = toolConfig.RequestTemplate.Security + } else { + upstreamSecurity = server.GetDefaultUpstreamSecurity() + } + + // Use common function to handle SSE request + return handleSSERequest(ctx, id, requestBody, server, downstreamSecurity, upstreamSecurity) +} + +// handleSSERequest is the common function to handle SSE requests for tools/list and tools/call +func handleSSERequest(ctx wrapper.HttpContext, id utils.JsonRpcID, requestBody []byte, server *McpProxyServer, downstreamSecurity SecurityRequirement, upstreamSecurity SecurityRequirement) error { + // Store JSON-RPC ID in context + ctx.SetContext(CtxSSEProxyJsonRpcID, id) + + // Store request body in context for later use + ctx.SetContext(CtxSSEProxyRequestBody, requestBody) + + // Handle downstream security first (to extract and remove credentials before copying headers) + passthroughCredential := "" + if downstreamSecurity.ID != "" { + clientScheme, schemeOk := server.GetSecurityScheme(downstreamSecurity.ID) + if schemeOk { + extractedCred, err := ExtractAndRemoveIncomingCredential(clientScheme) + if err == nil && extractedCred != "" && downstreamSecurity.Passthrough { + passthroughCredential = extractedCred + } + } + } else { + // Fallback: Remove Authorization header if no downstream security is defined + // This prevents downstream credentials from being mistakenly passed to upstream + // Unless passthroughAuthHeader is explicitly set to true + if !server.GetPassthroughAuthHeader() { + proxywasm.RemoveHttpRequestHeader("Authorization") + } + } + + // Prepare authentication info + var authInfo *ProxyAuthInfo + if upstreamSecurity.ID != "" { + authInfo = &ProxyAuthInfo{ + SecuritySchemeID: upstreamSecurity.ID, + PassthroughCredential: passthroughCredential, + Server: server, + } + } + + // Store auth info in context (headers will be copied directly in response phase) + ctx.SetContext(CtxSSEProxyAuthInfo, authInfo) + + // Convert current request to SSE GET request + // The request will continue through the filter chain and be routed to backend + // The response will be handled by onHttpResponseHeaders and onHttpStreamingResponseBody + err := initiateSSEChannelInRequestPhase(ctx, server, authInfo) + if err != nil { + log.Errorf("Failed to convert request to SSE GET: %v", err) + return err + } + + // Explicitly set to NOT pause - let the request continue to establish SSE channel + ctx.SetContext(utils.CtxNeedPause, false) + return nil +} + +// initiateSSEChannelInRequestPhase modifies the current request to be a GET request for establishing SSE channel +func initiateSSEChannelInRequestPhase(ctx wrapper.HttpContext, server *McpProxyServer, authInfo *ProxyAuthInfo) error { + // Copy original request headers + getHeaders := copyAndCleanHeadersForSSE(ctx) + + // Apply authentication to headers and URL + finalURL := server.GetMcpServerURL() + finalHeaders := getHeaders + + if authInfo != nil && authInfo.SecuritySchemeID != "" { + modifiedURL, err := applyProxyAuthenticationForSSE(server, authInfo.SecuritySchemeID, authInfo.PassthroughCredential, &finalHeaders, finalURL) + if err != nil { + log.Errorf("Failed to apply authentication for SSE GET: %v", err) + } else { + finalURL = modifiedURL + } + } + + // Parse the target URL + parsedURL, err := url.Parse(finalURL) + if err != nil { + return fmt.Errorf("failed to parse MCP server URL: %v", err) + } + + // Store initial state + ctx.SetContext(CtxSSEProxyState, SSEStateWaitingEndpoint) + + log.Infof("Converting request to SSE GET request for: %s", finalURL) + + // Modify the current request to be a GET request + // Replace :method pseudo-header + if err := proxywasm.ReplaceHttpRequestHeader(":method", "GET"); err != nil { + log.Warnf("Failed to replace :method header: %v", err) + } + + // Replace :path pseudo-header + path := parsedURL.Path + if parsedURL.RawQuery != "" { + path += "?" + parsedURL.RawQuery + } + if path == "" { + path = "/" + } + if err := proxywasm.ReplaceHttpRequestHeader(":path", path); err != nil { + log.Warnf("Failed to replace :path header: %v", err) + } + + // Replace :authority pseudo-header (host:port or just host) + authority := parsedURL.Host + if authority == "" { + authority = parsedURL.Hostname() + if parsedURL.Port() != "" { + authority += ":" + parsedURL.Port() + } + } + if err := proxywasm.ReplaceHttpRequestHeader(":authority", authority); err != nil { + log.Warnf("Failed to replace :authority header: %v", err) + } + + // Note: :scheme pseudo-header is managed by Envoy and should not be modified + + // Remove headers that are not appropriate for GET requests + proxywasm.RemoveHttpRequestHeader("content-type") + proxywasm.RemoveHttpRequestHeader("content-length") + proxywasm.RemoveHttpRequestHeader("transfer-encoding") + + // Set Accept header for SSE + if err := proxywasm.ReplaceHttpRequestHeader("accept", "text/event-stream"); err != nil { + log.Warnf("Failed to set Accept header: %v", err) + } + + // Apply any additional headers from authentication + for _, header := range finalHeaders { + // Skip pseudo-headers and headers already set + headerName := strings.ToLower(header[0]) + if strings.HasPrefix(headerName, ":") { + continue + } + if headerName == "accept" || headerName == "content-type" || headerName == "content-length" || headerName == "transfer-encoding" { + continue + } + if err := proxywasm.ReplaceHttpRequestHeader(header[0], header[1]); err != nil { + log.Warnf("Failed to set header %s: %v", header[0], err) + } + } + + log.Debugf("SSE GET request prepared: %s %s (authority: %s)", "GET", path, authority) + return nil +} + +// copyHeadersForStreamableHTTP copies headers from current request for StreamableHTTP requests +// This is used for initialize/notification requests in non-SSE mode +func copyHeadersForStreamableHTTP(ctx wrapper.HttpContext) [][2]string { + headers := make([][2]string, 0) + + // Headers to skip + skipHeaders := map[string]bool{ + "content-length": true, // Will be set by the client + "transfer-encoding": true, // Will be set by the client + ":path": true, // Pseudo-header, not needed + ":method": true, // Pseudo-header, not needed + ":scheme": true, // Pseudo-header, not needed + ":authority": true, // Pseudo-header, not needed + } + + // Get all request headers + headerMap, err := proxywasm.GetHttpRequestHeaders() + if err != nil { + log.Warnf("Failed to get request headers: %v", err) + // Return minimal headers + return [][2]string{} + } + + // Copy headers, skipping unwanted ones + for _, header := range headerMap { + headerName := strings.ToLower(header[0]) + if skipHeaders[headerName] { + continue + } + headers = append(headers, header) + } + + return headers +} + +// ensureHeader ensures a header is set to a specific value, replacing if it exists +func ensureHeader(headers *[][2]string, key, value string) { + keyLower := strings.ToLower(key) + // Check if header already exists + for i, h := range *headers { + if strings.ToLower(h[0]) == keyLower { + // Replace existing header + (*headers)[i] = [2]string{key, value} + return + } + } + // Header doesn't exist, add it + *headers = append(*headers, [2]string{key, value}) +} + +// copyAndCleanHeadersForSSE copies original request headers and cleans them for SSE GET request +func copyAndCleanHeadersForSSE(ctx wrapper.HttpContext) [][2]string { + headers := make([][2]string, 0) + + // Headers to skip for GET request + skipHeaders := map[string]bool{ + "content-type": true, + "content-length": true, + "transfer-encoding": true, + "accept": true, // Will be set explicitly for SSE + ":path": true, + ":method": true, + ":scheme": true, + ":authority": true, + } + + // Get all request headers + headerMap, err := proxywasm.GetHttpRequestHeaders() + if err != nil { + log.Warnf("Failed to get request headers: %v", err) + // Return minimal headers with Accept + return [][2]string{{"Accept", "text/event-stream"}} + } + + // Copy headers, skipping unwanted ones + for _, header := range headerMap { + headerName := strings.ToLower(header[0]) + if skipHeaders[headerName] { + continue + } + headers = append(headers, header) + } + + // Set/override Accept header for SSE + headers = append(headers, [2]string{"Accept", "text/event-stream"}) + + log.Debugf("Prepared %d headers for SSE GET request", len(headers)) + return headers +} diff --git a/plugins/wasm-go/pkg/mcp/server/proxy_tools_test.go b/plugins/wasm-go/pkg/mcp/server/proxy_tools_test.go new file mode 100644 index 000000000..3917ab0a0 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/proxy_tools_test.go @@ -0,0 +1,485 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestToolsListForwarding tests the tools/list request forwarding +func TestToolsListForwarding(t *testing.T) { + // Create proxy server with tools + server := NewMcpProxyServer("tools-list-test") + + // Set server fields directly + server.SetMcpServerURL("http://backend.example.com/mcp") + server.SetTimeout(5000) + + // Add test tools + toolConfigs := []McpProxyToolConfig{ + { + Name: "get_weather", + Description: "Get weather information", + Args: []ToolArg{ + { + Name: "location", + Description: "City name", + Type: "string", + Required: true, + }, + }, + }, + { + Name: "get_news", + Description: "Get latest news", + Args: []ToolArg{ + { + Name: "category", + Description: "News category", + Type: "string", + Required: false, + }, + }, + }, + } + + for _, toolConfig := range toolConfigs { + err := server.AddProxyTool(toolConfig) + require.NoError(t, err) + } + + // Skip HttpContext-dependent test for now - will be tested in integration + // Test that tools were added to server successfully + tools := server.GetMCPTools() + assert.Len(t, tools, 2) + assert.Contains(t, tools, "get_weather") + assert.Contains(t, tools, "get_news") +} + +// TestToolsCallForwarding tests the tools/call request forwarding +func TestToolsCallForwarding(t *testing.T) { + server := NewMcpProxyServer("tools-call-test") + + // Set server fields directly + server.SetMcpServerURL("http://backend.example.com/mcp") + server.SetTimeout(5000) + + // Add test tool + toolConfig := McpProxyToolConfig{ + Name: "test_tool", + Description: "Test tool for call forwarding", + Args: []ToolArg{ + { + Name: "input", + Description: "Input parameter", + Type: "string", + Required: true, + }, + }, + } + + err := server.AddProxyTool(toolConfig) + require.NoError(t, err) + + // Get the tool and create instance + tool, exists := server.GetMCPTools()["test_tool"] + require.True(t, exists) + + params := map[string]interface{}{ + "input": "test value", + } + paramsBytes, err := json.Marshal(params) + require.NoError(t, err) + + toolInstance := tool.Create(paramsBytes) + require.NotNil(t, toolInstance) + + // Skip HttpContext-dependent test for now - will be tested in integration + // Test tool instance creation was successful + assert.NotNil(t, toolInstance) + assert.Equal(t, "test_tool", toolInstance.(*McpProxyTool).name) + assert.Equal(t, "test value", toolInstance.(*McpProxyTool).arguments["input"]) +} + +// TestToolsCallWithParameters tests tool call with various parameter types +func TestToolsCallWithParameters(t *testing.T) { + tests := []struct { + name string + toolConfig McpProxyToolConfig + params map[string]interface{} + shouldErr bool + }{ + { + name: "string parameter", + toolConfig: McpProxyToolConfig{ + Name: "string_tool", + Description: "Tool with string parameter", + Args: []ToolArg{ + { + Name: "text", + Description: "Text input", + Type: "string", + Required: true, + }, + }, + }, + params: map[string]interface{}{ + "text": "hello world", + }, + shouldErr: false, + }, + { + name: "number parameter", + toolConfig: McpProxyToolConfig{ + Name: "number_tool", + Description: "Tool with number parameter", + Args: []ToolArg{ + { + Name: "value", + Description: "Numeric value", + Type: "number", + Required: true, + }, + }, + }, + params: map[string]interface{}{ + "value": 42.5, + }, + shouldErr: false, + }, + { + name: "object parameter", + toolConfig: McpProxyToolConfig{ + Name: "object_tool", + Description: "Tool with object parameter", + Args: []ToolArg{ + { + Name: "data", + Description: "Object data", + Type: "object", + Required: true, + }, + }, + }, + params: map[string]interface{}{ + "data": map[string]interface{}{ + "key1": "value1", + "key2": 123, + }, + }, + shouldErr: false, + }, + { + name: "missing required parameter", + toolConfig: McpProxyToolConfig{ + Name: "required_tool", + Description: "Tool with required parameter", + Args: []ToolArg{ + { + Name: "required_param", + Description: "Required parameter", + Type: "string", + Required: true, + }, + }, + }, + params: map[string]interface{}{}, + shouldErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := NewMcpProxyServer("param-test") + + // Set server fields directly + server.SetMcpServerURL("http://backend.example.com/mcp") + server.SetTimeout(5000) + + err := server.AddProxyTool(tt.toolConfig) + require.NoError(t, err) + + tool, exists := server.GetMCPTools()[tt.toolConfig.Name] + require.True(t, exists) + + paramsBytes, err := json.Marshal(tt.params) + require.NoError(t, err) + + toolInstance := tool.Create(paramsBytes) + require.NotNil(t, toolInstance) + + // Skip HttpContext-dependent test for now - will be tested in integration + // Test tool instance creation + assert.NotNil(t, toolInstance) + if !tt.shouldErr { + assert.Equal(t, tt.toolConfig.Name, toolInstance.(*McpProxyTool).name) + } + }) + } +} + +// TestToolsCallWithCursor tests tools/list with pagination cursor +func TestToolsCallWithCursor(t *testing.T) { + server := NewMcpProxyServer("cursor-test") + + // Set server fields directly + server.SetMcpServerURL("http://backend.example.com/mcp") + server.SetTimeout(5000) + + // Skip HttpContext-dependent test for now - will be tested in integration + // Test cursor parameter handling logic (basic validation) + cursor := "page-2-cursor" + assert.NotNil(t, cursor) + assert.NotEmpty(t, cursor) +} + +// TestBackendErrorHandling tests handling of backend MCP server errors +func TestBackendErrorHandling(t *testing.T) { + server := NewMcpProxyServer("error-test") + + // Set server fields directly + server.SetMcpServerURL("http://failing-backend.example.com/mcp") + server.SetTimeout(5000) + + toolConfig := McpProxyToolConfig{ + Name: "failing_tool", + Description: "Tool that will fail on backend", + Args: []ToolArg{ + { + Name: "input", + Description: "Input parameter", + Type: "string", + Required: true, + }, + }, + } + + err := server.AddProxyTool(toolConfig) + require.NoError(t, err) + + tool, exists := server.GetMCPTools()["failing_tool"] + require.True(t, exists) + + params := map[string]interface{}{ + "input": "test value", + } + paramsBytes, err := json.Marshal(params) + require.NoError(t, err) + + toolInstance := tool.Create(paramsBytes) + require.NotNil(t, toolInstance) + + // Skip HttpContext-dependent test for now - will be tested in integration + // Test tool instance creation for error scenario + assert.NotNil(t, toolInstance) + assert.Equal(t, "failing_tool", toolInstance.(*McpProxyTool).name) +} + +// TestParseSSEResponse tests the SSE response parsing functionality +func TestParseSSEResponse(t *testing.T) { + tests := []struct { + name string + sseData string + expectedData string + shouldErr bool + }{ + { + name: "valid SSE with JSON data", + sseData: `event: message +data: {"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{"experimental":{},"prompts":{"listChanged":true},"resources":{"subscribe":false,"listChanged":true},"tools":{"listChanged":true}},"serverInfo":{"name":"Echo Server","version":"1.17.0"}}} + +`, + expectedData: `{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{"experimental":{},"prompts":{"listChanged":true},"resources":{"subscribe":false,"listChanged":true},"tools":{"listChanged":true}},"serverInfo":{"name":"Echo Server","version":"1.17.0"}}}`, + shouldErr: false, + }, + { + name: "SSE with multiple lines", + sseData: `event: message +data: {"jsonrpc":"2.0","id":2,"result":{"success":true}} + +event: close +data: {"jsonrpc":"2.0","method":"close"} + +`, + expectedData: `{"jsonrpc":"2.0","id":2,"result":{"success":true}}`, + shouldErr: false, + }, + { + name: "SSE with comments and empty lines", + sseData: `: This is a comment +event: message + +data: {"jsonrpc":"2.0","id":3,"result":{"test":true}} + +: Another comment +`, + expectedData: `{"jsonrpc":"2.0","id":3,"result":{"test":true}}`, + shouldErr: false, + }, + { + name: "SSE with any data content", + sseData: `event: message +data: {invalid json} + +`, + expectedData: `{invalid json}`, + shouldErr: false, + }, + { + name: "SSE with no data field", + sseData: `event: message +id: 123 + +`, + shouldErr: true, + }, + { + name: "empty SSE data", + sseData: ``, + shouldErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseSSEResponse([]byte(tt.sseData)) + + if tt.shouldErr { + assert.Error(t, err) + assert.Nil(t, result) + } else { + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, tt.expectedData, string(result)) + } + }) + } +} + +// TestIsBackendError tests detection of backend error responses +func TestIsBackendError(t *testing.T) { + tests := []struct { + name string + response string + expectError bool + expectErrType string + }{ + { + name: "JSON-RPC 2.0 error with unknown tool", + response: `{ + "jsonrpc": "2.0", + "id": 3, + "error": { + "code": -32602, + "message": "Unknown tool: invalid_tool_name" + } + }`, + expectError: true, + expectErrType: "jsonrpc_error", + }, + { + name: "JSON-RPC 2.0 error with method not found", + response: `{ + "jsonrpc": "2.0", + "id": 1, + "error": { + "code": -32601, + "message": "Method not found" + } + }`, + expectError: true, + expectErrType: "jsonrpc_error", + }, + { + name: "result.isError format", + response: `{ + "jsonrpc": "2.0", + "id": 3, + "result": { + "isError": true, + "content": [ + { + "type": "text", + "text": "Tool execution failed: connection timeout" + } + ] + } + }`, + expectError: true, + expectErrType: "result_isError", + }, + { + name: "successful response with result", + response: `{ + "jsonrpc": "2.0", + "id": 3, + "result": { + "content": [ + { + "type": "text", + "text": "Success!" + } + ] + } + }`, + expectError: false, + expectErrType: "", + }, + { + name: "successful response with isError false", + response: `{ + "jsonrpc": "2.0", + "id": 3, + "result": { + "isError": false, + "content": [ + { + "type": "text", + "text": "Success!" + } + ] + } + }`, + expectError: false, + expectErrType: "", + }, + { + name: "invalid JSON", + response: `{invalid json}`, + expectError: false, + expectErrType: "", + }, + { + name: "empty response", + response: `{}`, + expectError: false, + expectErrType: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isError, errType := IsBackendError([]byte(tt.response)) + assert.Equal(t, tt.expectError, isError, "isError mismatch") + assert.Equal(t, tt.expectErrType, errType, "error type mismatch") + }) + } +} + +// ForwardToolsList is now implemented in proxy_server.go diff --git a/plugins/wasm-go/pkg/mcp/server/rest_server.go b/plugins/wasm-go/pkg/mcp/server/rest_server.go new file mode 100644 index 000000000..f0c769b1c --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/rest_server.go @@ -0,0 +1,1027 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package server + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/url" + "strings" + _ "time/tzdata" + + template "github.com/higress-group/gjson_template" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/tidwall/sjson" + + "github.com/higress-group/wasm-go/pkg/log" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" + "github.com/higress-group/wasm-go/pkg/wrapper" +) + +// RestMCPConfig represents the configuration for REST MCP server +type RestMCPConfig struct { + SecuritySchemes []SecurityScheme `json:"securitySchemes,omitempty"` + DefaultDownstreamSecurity SecurityRequirement `json:"defaultDownstreamSecurity,omitempty"` // Default client-to-gateway authentication for all tools + DefaultUpstreamSecurity SecurityRequirement `json:"defaultUpstreamSecurity,omitempty"` // Default gateway-to-backend authentication for all tools +} + +// RestToolArg represents an argument for a REST tool +type RestToolArg struct { + Name string `json:"name"` + Description string `json:"description"` + Type string `json:"type,omitempty"` // JSON Schema type: string, number, integer, boolean, array, object + Required bool `json:"required,omitempty"` + Default interface{} `json:"default,omitempty"` + Enum []interface{} `json:"enum,omitempty"` + // For array type + Items interface{} `json:"items,omitempty"` + // For object type + Properties interface{} `json:"properties,omitempty"` + // Position specifies where the argument should be placed in the request + // Valid values: query, path, header, cookie, body + Position string `json:"position,omitempty"` +} + +// RestToolHeader represents an HTTP header +type RestToolHeader struct { + Key string `json:"key"` + Value string `json:"value"` +} + +// RestToolRequestTemplate defines how to construct the HTTP request +type RestToolRequestTemplate struct { + URL string `json:"url"` + Method string `json:"method"` + Headers []RestToolHeader `json:"headers"` + Body string `json:"body"` + ArgsToJsonBody bool `json:"argsToJsonBody,omitempty"` // Use args as JSON body + ArgsToUrlParam bool `json:"argsToUrlParam,omitempty"` // Add args to URL parameters + ArgsToFormBody bool `json:"argsToFormBody,omitempty"` // Use args as form-urlencoded body + Security SecurityRequirement `json:"security,omitempty"` +} + +// RestToolResponseTemplate defines how to transform the HTTP response +type RestToolResponseTemplate struct { + Body string `json:"body"` + PrependBody string `json:"prependBody,omitempty"` // Text to insert before the response body + AppendBody string `json:"appendBody,omitempty"` // Text to insert after the response body +} + +// RestTool represents a REST API that can be called as an MCP tool +type RestTool struct { + Name string `json:"name"` + Description string `json:"description"` + Security SecurityRequirement `json:"security,omitempty"` // Tool-level security for MCP Client to MCP Server + Args []RestToolArg `json:"args"` + OutputSchema map[string]any `json:"outputSchema,omitempty"` // Output schema for MCP Protocol Version 2025-06-18 + RequestTemplate RestToolRequestTemplate `json:"requestTemplate,omitempty"` + ResponseTemplate RestToolResponseTemplate `json:"responseTemplate"` + ErrorResponseTemplate string `json:"errorResponseTemplate"` + + // Parsed templates (not from JSON) + parsedURLTemplate *template.Template + parsedHeaderTemplates map[string]*template.Template + parsedBodyTemplate *template.Template + parsedResponseTemplate *template.Template + parsedErrorResponseTemplate *template.Template + + // Map of argument names to their positions + argPositions map[string]string + + // Flag to indicate if this is a direct response tool (no HTTP request) + isDirectResponseTool bool +} + +// parseIP +func parseIP(source string, fromHeader bool) string { + if fromHeader { + source = strings.Split(source, ",")[0] + } + source = strings.Trim(source, " ") + if strings.Contains(source, ".") { + // parse ipv4 + return strings.Split(source, ":")[0] + } + //parse ipv6 + if strings.Contains(source, "]") { + return strings.Split(source, "]")[0][1:] + } + return source +} + +// templateFuncs returns the template functions map +func templateFuncs() template.FuncMap { + return template.FuncMap{ + // Get IP from socket + "getSocketIP": func() string { + bs, _ := proxywasm.GetProperty([]string{"source", "address"}) + if len(bs) > 0 { + return parseIP(string(bs), false) + } + return "" + }, + // Get IP from header, fallback to socket if not available + "getRealIP": func() string { + ipStr, _ := proxywasm.GetHttpRequestHeader("x-forwarded-for") + if ipStr != "" { + return parseIP(ipStr, true) + } + // Fallback to socket IP if header is not available + bs, _ := proxywasm.GetProperty([]string{"source", "address"}) + if len(bs) > 0 { + return parseIP(string(bs), false) + } + return "" + }, + } +} + +// parseTemplates parses all templates in the tool configuration +func (t *RestTool) parseTemplates() error { + var err error + + // Check if this is a direct response tool (no RequestTemplate) + if t.RequestTemplate.URL == "" { + t.isDirectResponseTool = true + } else { + // Validate args configuration - only one of the three options can be true + argsOptionCount := 0 + if t.RequestTemplate.ArgsToJsonBody { + argsOptionCount++ + } + if t.RequestTemplate.ArgsToUrlParam { + argsOptionCount++ + } + if t.RequestTemplate.ArgsToFormBody { + argsOptionCount++ + } + if argsOptionCount > 1 { + return fmt.Errorf("only one of argsToJsonBody, argsToUrlParam, or argsToFormBody can be set to true") + } + + // Parse URL template + t.parsedURLTemplate, err = template.New("url").Funcs(templateFuncs()).Parse(t.RequestTemplate.URL) + if err != nil { + return fmt.Errorf("error parsing URL template: %v", err) + } + + // Parse header templates + t.parsedHeaderTemplates = make(map[string]*template.Template) + for i, header := range t.RequestTemplate.Headers { + if header.Key == "" { + log.Warnf("Skipping header with empty key at index %d", i) + continue + } + + tmplName := fmt.Sprintf("header_%d", i) + t.parsedHeaderTemplates[header.Key], err = template.New(tmplName).Funcs(templateFuncs()).Parse(header.Value) + if err != nil { + return fmt.Errorf("error parsing header template for %s: %v", header.Key, err) + } + } + + // Parse body template if present + if t.RequestTemplate.Body != "" { + t.parsedBodyTemplate, err = template.New("body").Funcs(templateFuncs()).Parse(t.RequestTemplate.Body) + if err != nil { + return fmt.Errorf("error parsing body template: %v", err) + } + } + } + + // Parse response template if present + if t.ResponseTemplate.Body != "" { + // Validate that PrependBody and AppendBody are not used with Body + if t.ResponseTemplate.PrependBody != "" || t.ResponseTemplate.AppendBody != "" { + return fmt.Errorf("PrependBody and AppendBody cannot be used when Body is specified") + } + + t.parsedResponseTemplate, err = template.New("response").Funcs(templateFuncs()).Parse(t.ResponseTemplate.Body) + if err != nil { + return fmt.Errorf("error parsing response template: %v", err) + } + } else if t.isDirectResponseTool { + return errors.New("direct response mode must set responseTemplate.body") + } + + // Parse error response template if present + if t.ErrorResponseTemplate != "" { + t.parsedErrorResponseTemplate, err = template.New("errorResponse").Funcs(templateFuncs()).Parse(t.ErrorResponseTemplate) + if err != nil { + return fmt.Errorf("error parsing error response template: %v", err) + } + } + + // Initialize argument positions map + t.argPositions = make(map[string]string) + for _, arg := range t.Args { + if arg.Position != "" { + t.argPositions[arg.Name] = strings.ToLower(arg.Position) + } + } + + return nil +} + +// executeTemplate executes a parsed template with the given data +func executeTemplate(tmpl *template.Template, data []byte) (string, error) { + if tmpl == nil { + return "", errors.New("template is nil") + } + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, data); err != nil { + return "", err + } + + return buf.String(), nil +} + +// RestMCPServer implements Server interface for REST-to-MCP conversion +type RestMCPServer struct { + name string + base BaseMCPServer + toolsConfig map[string]RestTool // Store original tool configs for template rendering + securitySchemes map[string]SecurityScheme + defaultDownstreamSecurity SecurityRequirement // Default client-to-gateway authentication + defaultUpstreamSecurity SecurityRequirement // Default gateway-to-backend authentication + passthroughAuthHeader bool // If true, pass through Authorization header even without downstream security +} + +// NewRestMCPServer creates a new REST-to-MCP server +func NewRestMCPServer(name string) *RestMCPServer { + return &RestMCPServer{ + name: name, + base: NewBaseMCPServer(), + toolsConfig: make(map[string]RestTool), + securitySchemes: make(map[string]SecurityScheme), // Initialize the map + } +} + +// AddSecurityScheme adds a security scheme to the server's map +func (s *RestMCPServer) AddSecurityScheme(scheme SecurityScheme) { + if s.securitySchemes == nil { + s.securitySchemes = make(map[string]SecurityScheme) + } + s.securitySchemes[scheme.ID] = scheme +} + +// GetSecurityScheme retrieves a security scheme by its ID from the map +func (s *RestMCPServer) GetSecurityScheme(id string) (SecurityScheme, bool) { + scheme, ok := s.securitySchemes[id] + return scheme, ok +} + +// SetDefaultDownstreamSecurity sets the default downstream security configuration +func (s *RestMCPServer) SetDefaultDownstreamSecurity(security SecurityRequirement) { + s.defaultDownstreamSecurity = security +} + +// GetDefaultDownstreamSecurity gets the default downstream security configuration +func (s *RestMCPServer) GetDefaultDownstreamSecurity() SecurityRequirement { + return s.defaultDownstreamSecurity +} + +// SetDefaultUpstreamSecurity sets the default upstream security configuration +func (s *RestMCPServer) SetDefaultUpstreamSecurity(security SecurityRequirement) { + s.defaultUpstreamSecurity = security +} + +// GetDefaultUpstreamSecurity gets the default upstream security configuration +func (s *RestMCPServer) GetDefaultUpstreamSecurity() SecurityRequirement { + return s.defaultUpstreamSecurity +} + +// SetPassthroughAuthHeader sets the passthrough auth header flag +func (s *RestMCPServer) SetPassthroughAuthHeader(passthrough bool) { + s.passthroughAuthHeader = passthrough +} + +// GetPassthroughAuthHeader gets the passthrough auth header flag +func (s *RestMCPServer) GetPassthroughAuthHeader() bool { + return s.passthroughAuthHeader +} + +// AddMCPTool implements Server interface +func (s *RestMCPServer) AddMCPTool(name string, tool Tool) Server { + s.base.AddMCPTool(name, tool) + return s +} + +// AddRestTool adds a REST tool configuration +func (s *RestMCPServer) AddRestTool(toolConfig RestTool) error { + // Parse templates at configuration time + if err := toolConfig.parseTemplates(); err != nil { + return err + } + + s.toolsConfig[toolConfig.Name] = toolConfig + s.base.AddMCPTool(toolConfig.Name, &RestMCPTool{ + serverName: s.name, + name: toolConfig.Name, + toolConfig: toolConfig, + }) + + return nil +} + +// GetMCPTools implements Server interface +func (s *RestMCPServer) GetMCPTools() map[string]Tool { + return s.base.GetMCPTools() +} + +// SetConfig implements Server interface +func (s *RestMCPServer) SetConfig(config []byte) { + s.base.SetConfig(config) +} + +// GetConfig implements Server interface +func (s *RestMCPServer) GetConfig(v any) { + s.base.GetConfig(v) +} + +// Clone implements Server interface +func (s *RestMCPServer) Clone() Server { + newServer := &RestMCPServer{ + name: s.name, + base: s.base.CloneBase(), + toolsConfig: make(map[string]RestTool), + securitySchemes: make(map[string]SecurityScheme), // Initialize the map + } + for k, v := range s.toolsConfig { + newServer.toolsConfig[k] = v + } + // Deep copy securitySchemes + if s.securitySchemes != nil { + for k, v := range s.securitySchemes { + newServer.securitySchemes[k] = v + } + } + return newServer +} + +// GetToolConfig returns the REST tool configuration for a given tool name +func (s *RestMCPServer) GetToolConfig(name string) (RestTool, bool) { + config, ok := s.toolsConfig[name] + return config, ok +} + +// RestMCPTool implements Tool interface for REST-to-MCP +type RestMCPTool struct { + serverName string + name string + toolConfig RestTool + arguments map[string]interface{} +} + +// Create implements Tool interface +func (t *RestMCPTool) Create(params []byte) Tool { + newTool := &RestMCPTool{ + serverName: t.serverName, + name: t.name, + toolConfig: t.toolConfig, + arguments: make(map[string]interface{}), + } + + // Parse raw arguments + var rawArgs map[string]interface{} + if err := json.Unmarshal(params, &rawArgs); err != nil { + log.Warnf("Failed to parse tool arguments: %v", err) + } + + // Process arguments with type conversion + for _, arg := range t.toolConfig.Args { + // Check if argument was provided + rawValue, exists := rawArgs[arg.Name] + if !exists { + // Apply default if available + if arg.Default != nil { + newTool.arguments[arg.Name] = arg.Default + } + continue + } + + // Convert value based on type + switch arg.Type { + case "boolean": + // Convert to boolean + switch v := rawValue.(type) { + case bool: + newTool.arguments[arg.Name] = v + case string: + if v == "true" { + newTool.arguments[arg.Name] = true + } else if v == "false" { + newTool.arguments[arg.Name] = false + } else { + newTool.arguments[arg.Name] = rawValue + } + default: + newTool.arguments[arg.Name] = rawValue + } + case "integer": + // Convert to integer + switch v := rawValue.(type) { + case float64: + newTool.arguments[arg.Name] = int(v) + case string: + if intVal, err := json.Number(v).Int64(); err == nil { + newTool.arguments[arg.Name] = int(intVal) + } else { + newTool.arguments[arg.Name] = rawValue + } + default: + newTool.arguments[arg.Name] = rawValue + } + case "number": + // Convert to number (float64) + switch v := rawValue.(type) { + case string: + if floatVal, err := json.Number(v).Float64(); err == nil { + newTool.arguments[arg.Name] = floatVal + } else { + newTool.arguments[arg.Name] = rawValue + } + default: + newTool.arguments[arg.Name] = rawValue + } + default: + // For string, array, object, or unspecified types, use as is + newTool.arguments[arg.Name] = rawValue + } + } + + return newTool +} + +// convertArgToString converts an argument value to a string representation +func convertArgToString(value interface{}) string { + switch v := value.(type) { + case string: + return v + case bool, int, int64, float64: + return fmt.Sprintf("%v", v) + default: + // For complex types, try to marshal to JSON + if jsonBytes, err := json.Marshal(v); err == nil { + return string(jsonBytes) + } + return fmt.Sprintf("%v", v) + } +} + +// hasContentType checks if the headers contain a specific content type +func hasContentType(headers [][2]string, contentTypeSubstr string) bool { + for _, header := range headers { + if strings.EqualFold(header[0], "Content-Type") && strings.Contains(strings.ToLower(header[1]), contentTypeSubstr) { + return true + } + } + return false +} + +// applySecurity applies the configured security scheme to the request with fallback to default upstream security. +// It modifies reqCtx.Headers and reqCtx.ParsedURL (specifically RawQuery) in place if necessary. +func (t *RestMCPTool) applySecurity(serverObj Server, reqCtx *AuthRequestContext) error { + restServer, ok := serverObj.(*RestMCPServer) + if !ok { + return errors.New("server is not a RestMCPServer") + } + + // Determine which upstream security to use: tool-level or server's default + var upstreamSecurity SecurityRequirement + if t.toolConfig.RequestTemplate.Security.ID != "" { + // Use tool-level upstream security if configured + upstreamSecurity = t.toolConfig.RequestTemplate.Security + log.Debugf("Using tool-level upstream security for tool %s: %s", t.name, upstreamSecurity.ID) + } else { + // Fall back to server's default upstream security + upstreamSecurity = restServer.GetDefaultUpstreamSecurity() + if upstreamSecurity.ID != "" { + log.Debugf("Using default upstream security for tool %s: %s", t.name, upstreamSecurity.ID) + } + } + + // Apply security using the determined configuration + return ApplySecurity(upstreamSecurity, restServer, reqCtx) +} + +// Call implements Tool interface +func (t *RestMCPTool) Call(httpCtx HttpContext, server Server) error { + ctx := httpCtx.(wrapper.HttpContext) + + // Get server instance for configuration access + restServer, ok := server.(*RestMCPServer) + if !ok { + return fmt.Errorf("server is not a RestMCPServer") + } + + // Handle tool-level or default downstream security: extract credential for passthrough if configured + // toolConfig.Security represents client-to-gateway authentication, falls back to server's defaultDownstreamSecurity + passthroughCredential := "" + var downstreamSecurity SecurityRequirement + if t.toolConfig.Security.ID != "" { + // Use tool-level security if configured + downstreamSecurity = t.toolConfig.Security + log.Debugf("Using tool-level downstream security for tool %s: %s", t.name, downstreamSecurity.ID) + } else { + // Fall back to server's default downstream security + downstreamSecurity = restServer.GetDefaultDownstreamSecurity() + if downstreamSecurity.ID != "" { + log.Debugf("Using default downstream security for tool %s: %s", t.name, downstreamSecurity.ID) + } + } + + if downstreamSecurity.ID != "" { + clientScheme, schemeOk := restServer.GetSecurityScheme(downstreamSecurity.ID) + if !schemeOk { + log.Warnf("Downstream security scheme ID '%s' not found for tool %s.", downstreamSecurity.ID, t.name) + } else { + // Extract and remove the credential from the incoming request + extractedCred, err := ExtractAndRemoveIncomingCredential(clientScheme) + if err != nil { + log.Warnf("Failed to extract/remove incoming credential for tool %s using scheme %s: %v", t.name, clientScheme.ID, err) + } else if extractedCred == "" { + log.Debugf("No incoming credential found for tool %s using scheme %s for extraction/removal.", t.name, clientScheme.ID) + } + + // Only use passthrough if explicitly configured + if downstreamSecurity.Passthrough && extractedCred != "" { + passthroughCredential = extractedCred + log.Debugf("Passthrough credential set for tool %s.", t.name) + } + } + } + + var templateDataBytes []byte + // Get server config for template data if needed (but don't use for default security) + var serverConfig map[string]interface{} + restServer.GetConfig(&serverConfig) + templateDataBytes, _ = sjson.SetBytes(templateDataBytes, "config", serverConfig) + templateDataBytes, _ = sjson.SetBytes(templateDataBytes, "args", t.arguments) + + // Check if this is a direct response tool (no HTTP request needed) + if t.toolConfig.isDirectResponseTool { + // Process response directly + var result string + + // Render the response template with the arguments + templateResult, err := executeTemplate(t.toolConfig.parsedResponseTemplate, templateDataBytes) + if err != nil { + return fmt.Errorf("error executing response template: %v", err) + } + result = templateResult + + // Check if tool has outputSchema and try to parse templateResult as structured content + var structuredContent json.RawMessage + if t.toolConfig.OutputSchema != nil && len(t.toolConfig.OutputSchema) > 0 { + // For direct response tools, check if templateResult is valid JSON + if json.Valid([]byte(result)) { + structuredContent = json.RawMessage(result) + } + } + + // Send the result using structured content if available + if structuredContent != nil { + utils.SendMCPToolTextResultWithStructuredContent(ctx, result, structuredContent, fmt.Sprintf("mcp:tools/call:%s/%s:result", t.serverName, t.name)) + } else { + utils.SendMCPToolTextResult(ctx, result, fmt.Sprintf("mcp:tools/call:%s/%s:result", t.serverName, t.name)) + } + return nil + } + + // Regular REST tool with HTTP request + // Execute URL template + urlStr, err := executeTemplate(t.toolConfig.parsedURLTemplate, templateDataBytes) + if err != nil { + return fmt.Errorf("error executing URL template: %v", err) + } + + // Execute header templates from tool config + headers := make([][2]string, 0, len(t.toolConfig.RequestTemplate.Headers)) + for i, header := range t.toolConfig.RequestTemplate.Headers { + if header.Key == "" { + log.Warnf("Skipping header with empty key at index %d", i) + continue + } + tmpl, ok := t.toolConfig.parsedHeaderTemplates[header.Key] + if !ok { + return fmt.Errorf("header template not found for %s", header.Key) + } + value, err := executeTemplate(tmpl, templateDataBytes) + if err != nil { + return fmt.Errorf("error executing header template for %s: %v", header.Key, err) + } + headers = append(headers, [2]string{header.Key, value}) + } + + // Authorization or specific API key headers are handled by extractAndRemoveIncomingCredential if tool-level security is defined. + // If no tool-level security is defined, this generic RemoveHttpRequestHeader("Authorization") acts as a fallback. + // Unless passthroughAuthHeader is explicitly set to true. + if t.toolConfig.Security.ID == "" { + if !restServer.GetPassthroughAuthHeader() { + proxywasm.RemoveHttpRequestHeader("Authorization") // Remove if not handled by specific scheme + } + } + // General cleanup of Accept header from original client request. + proxywasm.RemoveHttpRequestHeader("Accept") + + // After applySecurity, urlStr, headers, and parsedURL might have been modified. + + // Categorize arguments by position + pathArgs := make(map[string]interface{}) + queryArgs := make(map[string]interface{}) + headerArgs := make(map[string]interface{}) + cookieArgs := make(map[string]interface{}) + bodyArgs := make(map[string]interface{}) + defaultArgs := make(map[string]interface{}) // Args without explicit position + + // Categorize arguments based on their position + for name, value := range t.arguments { + position, hasPosition := t.toolConfig.argPositions[name] + if !hasPosition { + defaultArgs[name] = value + continue + } + + switch position { + case "path": + pathArgs[name] = value + case "query": + queryArgs[name] = value + case "header": + headerArgs[name] = value + case "cookie": + cookieArgs[name] = value + case "body": + bodyArgs[name] = value + default: + // If position is invalid, treat as default + defaultArgs[name] = value + } + } + + // Process path parameters + for name, value := range pathArgs { + placeholder := fmt.Sprintf("{%s}", name) + // Path parameters are substituted directly into urlStr + urlStr = strings.Replace(urlStr, placeholder, convertArgToString(value), -1) + } + + // After path parameters are substituted, parse urlStr to create/update parsedURL. + // This is the primary point where parsedURL is established before query manipulations. + parsedURL, err := url.Parse(urlStr) + if err != nil { + return fmt.Errorf("error parsing URL after path param substitution: %v", err) + } + + // Get existing query values + query := parsedURL.Query() + + // Add query parameters from args + for name, value := range queryArgs { + query.Set(name, convertArgToString(value)) + } + + // Process URL parameters if argsToUrlParam is true (add defaultArgs to query) + if t.toolConfig.RequestTemplate.ArgsToUrlParam { + for name, value := range defaultArgs { + query.Set(name, convertArgToString(value)) + } + } + + // Update the URL with the new query string + parsedURL.RawQuery = query.Encode() + + // Add header parameters from args + for name, value := range headerArgs { + headers = append(headers, [2]string{name, convertArgToString(value)}) + } + + // Add cookie parameters from args + for name, value := range cookieArgs { + cookie := fmt.Sprintf("%s=%s", name, convertArgToString(value)) + cookieHeaderFound := false + for i, header := range headers { + if strings.EqualFold(header[0], "Cookie") { + headers[i][1] = header[1] + "; " + cookie + cookieHeaderFound = true + break + } + } + if !cookieHeaderFound { + headers = append(headers, [2]string{"Cookie", cookie}) + } + } + + // Check for existing content types from tool config headers + hasJsonContentType := hasContentType(headers, "application/json") + hasFormContentType := hasContentType(headers, "application/x-www-form-urlencoded") + + // Prepare request body + var requestBody []byte + hasExplicitBody := t.toolConfig.parsedBodyTemplate != nil + + if hasExplicitBody { + // If explicit body template is provided, use it + body, err := executeTemplate(t.toolConfig.parsedBodyTemplate, templateDataBytes) + if err != nil { + return fmt.Errorf("error executing body template: %v", err) + } + requestBody = []byte(body) + + // Check if body is JSON and add content type if needed + trimmedBody := bytes.TrimSpace(requestBody) + if !hasJsonContentType && len(trimmedBody) > 0 && + ((trimmedBody[0] == '{' && trimmedBody[len(trimmedBody)-1] == '}') || + (trimmedBody[0] == '[' && trimmedBody[len(trimmedBody)-1] == ']')) { + // Try to parse as JSON to confirm + var js interface{} + if json.Unmarshal(trimmedBody, &js) == nil { + headers = append(headers, [2]string{"Content-Type", "application/json; charset=utf-8"}) + } + } + } else if t.toolConfig.RequestTemplate.ArgsToJsonBody { + // Combine body args and default args for JSON body + combinedArgs := make(map[string]interface{}) + for k, v := range bodyArgs { + combinedArgs[k] = v + } + for k, v := range defaultArgs { + combinedArgs[k] = v + } + + // Use args directly as JSON in the request body + argsJson, err := json.Marshal(combinedArgs) + if err != nil { + return fmt.Errorf("error marshaling args to JSON: %v", err) + } + requestBody = argsJson + + // Add JSON content type if not already present + if !hasJsonContentType { + headers = append(headers, [2]string{"Content-Type", "application/json; charset=utf-8"}) + } + } else if t.toolConfig.RequestTemplate.ArgsToFormBody { + // Use args as form-urlencoded body + formValues := url.Values{} + for name, value := range bodyArgs { + formValues.Set(name, convertArgToString(value)) + } + for name, value := range defaultArgs { + formValues.Set(name, convertArgToString(value)) + } + + requestBody = []byte(formValues.Encode()) + + // Add form content type if not already present + if !hasFormContentType { + headers = append(headers, [2]string{"Content-Type", "application/x-www-form-urlencoded"}) + } + } else if len(bodyArgs) > 0 { + // If we have body args but no explicit body handling method, + // check if there's already a form content type + if hasFormContentType { + // Format as form-urlencoded + formValues := url.Values{} + for name, value := range bodyArgs { + formValues.Set(name, convertArgToString(value)) + } + requestBody = []byte(formValues.Encode()) + } else { + // Default to JSON + argsJson, err := json.Marshal(bodyArgs) + if err != nil { + return fmt.Errorf("error marshaling body args to JSON: %v", err) + } + requestBody = argsJson + + // Add JSON content type if not already present + if !hasJsonContentType { + headers = append(headers, [2]string{"Content-Type", "application/json; charset=utf-8"}) + } + } + } + + // Ensure Accept header if not already set by tool config or args + hasAcceptHeader := false + for _, kv := range headers { + if strings.EqualFold(kv[0], "accept") { + hasAcceptHeader = true + break + } + } + if !hasAcceptHeader { + headers = append(headers, [2]string{"Accept", "*/*"}) + } + + // Apply security scheme just before making the call, after all other modifications + authReqCtx := AuthRequestContext{ + Method: t.toolConfig.RequestTemplate.Method, + Headers: headers, // Pass the current headers slice + ParsedURL: parsedURL, + RequestBody: requestBody, + PassthroughCredential: passthroughCredential, + } + if err := t.applySecurity(server, &authReqCtx); err != nil { + // Log the error and continue, rather than failing the entire call. + // The request will proceed without the intended security modifications if applySecurity failed. + log.Errorf("Failed to apply security scheme for tool %s: %v. Request will proceed with potentially incomplete authentication.", t.name, err) + } + // After applySecurity, authReqCtx.Headers and authReqCtx.ParsedURL (RawQuery) might have been modified. + // Update urlStr from the potentially modified ParsedURL. + u := authReqCtx.ParsedURL + encodedPath := u.EscapedPath() + if u.Scheme != "" && u.Host != "" { + urlStr = u.Scheme + "://" + u.Host + encodedPath + } else { + urlStr = "/" + strings.TrimPrefix(encodedPath, "/") + } + if u.RawQuery != "" { + urlStr += "?" + u.RawQuery + } + if u.Fragment != "" { + urlStr += "#" + u.Fragment + } + // Make HTTP request using potentially modified headers from authReqCtx + err = ctx.RouteCall(authReqCtx.Method, urlStr, authReqCtx.Headers, authReqCtx.RequestBody, + func(statusCode int, responseHeaders [][2]string, responseBody []byte) { + + if statusCode >= 300 || statusCode < 200 { + if t.toolConfig.parsedErrorResponseTemplate != nil { + // Error response template is provided to customize the error response result. + // Based on the responseBody, access the map-structured responseHeaders through _headers to reference their values within the errorResponseTemplate. + // Usage examples in errorResponseTemplate: + // - {{gjson "_headers.\\:status"}} -> Get HTTP status code + // - {{gjson "_headers.x-ca-error-code"}} -> Get value of header key "x-ca-error-code" + // - {{.data.value}} -> Access original responseBody content (e.g., JSON field "data.value") + errorResponseTemplateDataBytes, _ := sjson.SetBytes(responseBody, "_headers", convertHeaders(responseHeaders)) + errorTemplateResult, err := executeTemplate(t.toolConfig.parsedErrorResponseTemplate, errorResponseTemplateDataBytes) + if err != nil { + utils.OnMCPToolCallError(ctx, fmt.Errorf("error executing error response template: %v", err)) + return + } + if errorTemplateResult != "" { + utils.OnMCPToolCallError(ctx, fmt.Errorf("%s", errorTemplateResult)) + return + } + } + utils.OnMCPToolCallError(ctx, fmt.Errorf("call failed, status: %d, response: %s", statusCode, responseBody)) + return + } + + // Process response + var result string + + headerMap := convertHeaders(responseHeaders) + contentType := headerMap[strings.ToLower("Content-Type")] + // Check if the response is an image + if strings.HasPrefix(contentType, "image/") { + // Handle image response by sending it as an MCP tool result + utils.SendMCPToolImageResult(ctx, responseBody, contentType, fmt.Sprintf("mcp:tools/call:%s/%s:result", t.serverName, t.name)) + return + } + + // Case 1: Full response template is provided + if t.toolConfig.parsedResponseTemplate != nil { + templateResult, err := executeTemplate(t.toolConfig.parsedResponseTemplate, responseBody) + if err != nil { + utils.OnMCPToolCallError(ctx, fmt.Errorf("error executing response template: %v", err)) + return + } + result = templateResult + } else { + // Case 2: No template, but prepend/append might be used + rawResponse := string(responseBody) + + // Apply prepend/append if specified + if t.toolConfig.ResponseTemplate.PrependBody != "" || t.toolConfig.ResponseTemplate.AppendBody != "" { + result = t.toolConfig.ResponseTemplate.PrependBody + rawResponse + t.toolConfig.ResponseTemplate.AppendBody + } else { + // Case 3: No template and no prepend/append, just use raw response + result = rawResponse + } + } + if result == "" { + result = "success" + } + + // Check if tool has outputSchema and try to parse response as structured content + var structuredContent json.RawMessage + if t.toolConfig.OutputSchema != nil && len(t.toolConfig.OutputSchema) > 0 { + // Try to parse response as JSON for structured content + if json.Valid(responseBody) { + structuredContent = json.RawMessage(responseBody) + } + // If not valid JSON, don't force structuredContent creation + // Standard approach: use isError: true + error text (type: "text") + // Only add structuredContent when there's a structured need for errors + } + + // Send the result using structured content if available + if structuredContent != nil { + utils.SendMCPToolTextResultWithStructuredContent(ctx, result, structuredContent, fmt.Sprintf("mcp:tools/call:%s/%s:result", t.serverName, t.name)) + } else { + utils.SendMCPToolTextResult(ctx, result, fmt.Sprintf("mcp:tools/call:%s/%s:result", t.serverName, t.name)) + } + }) + if err != nil { + utils.OnMCPToolCallError(ctx, errors.New("route failed")) + log.Errorf("call api failed, err:%v", err) + } + return nil +} + +// Description implements Tool interface +func (t *RestMCPTool) Description() string { + return t.toolConfig.Description +} + +// InputSchema implements Tool interface +func (t *RestMCPTool) InputSchema() map[string]any { + // Convert tool args to JSON schema + properties := make(map[string]interface{}) + required := []string{} + + for _, arg := range t.toolConfig.Args { + argSchema := map[string]interface{}{ + "description": arg.Description, + } + + // Set type (default to string if not specified) + argType := arg.Type + if argType == "" { + argType = "string" + } + argSchema["type"] = argType + + // Add enum if specified + if arg.Enum != nil && len(arg.Enum) > 0 { + argSchema["enum"] = arg.Enum + } + + // Add default if specified + if arg.Default != nil { + argSchema["default"] = arg.Default + } + + // Add items for array type + if argType == "array" && arg.Items != nil { + argSchema["items"] = arg.Items + } + + // Add properties for object type + if argType == "object" && arg.Properties != nil { + argSchema["properties"] = arg.Properties + } + + properties[arg.Name] = argSchema + + // Add to required list if needed + if arg.Required { + required = append(required, arg.Name) + } + } + + schema := map[string]interface{}{ + "type": "object", + "properties": properties, + } + + // Add required field only if there are required properties + if len(required) > 0 { + schema["required"] = required + } + + return schema +} + +// OutputSchema implements Tool interface (MCP Protocol Version 2025-06-18) +func (t *RestMCPTool) OutputSchema() map[string]any { + return t.toolConfig.OutputSchema +} + +func convertHeaders(responseHeaders [][2]string) map[string]string { + headerMap := make(map[string]string) + for _, h := range responseHeaders { + if len(h) >= 2 { + key := h[0] + value := h[1] + headerMap[key] = value + } + } + return headerMap +} diff --git a/plugins/wasm-go/pkg/mcp/server/rest_server_test.go b/plugins/wasm-go/pkg/mcp/server/rest_server_test.go new file mode 100644 index 000000000..6f77c08b5 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/rest_server_test.go @@ -0,0 +1,922 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package server + +import ( + "encoding/json" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/tidwall/sjson" +) + +func TestConvertArgToString(t *testing.T) { + tests := []struct { + name string + input interface{} + expected string + }{ + { + name: "string value", + input: "test string", + expected: "test string", + }, + { + name: "boolean true", + input: true, + expected: "true", + }, + { + name: "boolean false", + input: false, + expected: "false", + }, + { + name: "integer", + input: 42, + expected: "42", + }, + { + name: "float", + input: 3.14, + expected: "3.14", + }, + { + name: "map", + input: map[string]interface{}{"key": "value"}, + expected: `{"key":"value"}`, + }, + { + name: "array", + input: []interface{}{1, 2, 3}, + expected: "[1,2,3]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertArgToString(tt.input) + if result != tt.expected { + t.Errorf("convertArgToString(%v) = %v, want %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestResponseTemplatePrependAppend(t *testing.T) { + // Test response template with PrependBody and AppendBody + sampleResponse := `{"result": "success", "data": {"name": "Test", "value": 42}}` + + tests := []struct { + name string + template RestToolResponseTemplate + expected []string + notExpected []string + }{ + { + name: "with body template only", + template: RestToolResponseTemplate{ + Body: "# Result\n- Name: {{.data.name}}\n- Value: {{.data.value}}", + }, + expected: []string{ + "# Result", + "- Name: Test", + "- Value: 42", + }, + notExpected: []string{ + "Field Descriptions:", + "End of Response", + `{"result": "success"`, + }, + }, + { + name: "with prepend only", + template: RestToolResponseTemplate{ + PrependBody: "# Field Descriptions:\n- result: Operation result\n- data: Response data\n\n", + }, + expected: []string{ + "# Field Descriptions:", + "- result: Operation result", + "- data: Response data", + `{"result": "success"`, + `"name": "Test"`, + }, + }, + { + name: "with append only", + template: RestToolResponseTemplate{ + AppendBody: "\n\n*End of Response*", + }, + expected: []string{ + `{"result": "success"`, + `"name": "Test"`, + "*End of Response*", + }, + }, + { + name: "with both prepend and append", + template: RestToolResponseTemplate{ + PrependBody: "# API Response:\n\n", + AppendBody: "\n\n*This is raw JSON data with field 'name' = Test and 'value' = 42*", + }, + expected: []string{ + "# API Response:", + `{"result": "success"`, + `"name": "Test"`, + "*This is raw JSON data with field 'name' = Test and 'value' = 42*", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a tool with the test template + // For tests with only prepend/append (no body), add a RequestTemplate.URL + // to avoid direct response mode validation + tool := RestTool{ + ResponseTemplate: tt.template, + } + if tt.template.Body == "" && (tt.template.PrependBody != "" || tt.template.AppendBody != "") { + tool.RequestTemplate.URL = "http://example.com/api" + } + + // Parse templates + err := tool.parseTemplates() + if err != nil { + t.Fatalf("Failed to parse templates: %v", err) + } + + // Simulate response processing + var result string + responseBody := []byte(sampleResponse) + + // Case 1: Full response template is provided + if tool.parsedResponseTemplate != nil { + templateResult, err := executeTemplate(tool.parsedResponseTemplate, responseBody) + if err != nil { + t.Fatalf("Failed to execute response template: %v", err) + } + result = templateResult + } else { + // Case 2: No template, but prepend/append might be used + rawResponse := string(responseBody) + + // Apply prepend/append if specified + if tool.ResponseTemplate.PrependBody != "" || tool.ResponseTemplate.AppendBody != "" { + result = tool.ResponseTemplate.PrependBody + rawResponse + tool.ResponseTemplate.AppendBody + } else { + // Case 3: No template and no prepend/append, just use raw response + result = rawResponse + } + } + + // Check that the result contains expected substrings + for _, substr := range tt.expected { + if !strings.Contains(result, substr) { + t.Errorf("Expected substring not found: %s", substr) + } + } + + // Check that the result does not contain unexpected substrings + for _, substr := range tt.notExpected { + if strings.Contains(result, substr) { + t.Errorf("Unexpected substring found: %s", substr) + } + } + }) + } +} + +func TestHasContentType(t *testing.T) { + tests := []struct { + name string + headers [][2]string + contentTypeStr string + expectedOutcome bool + }{ + { + name: "exact match", + headers: [][2]string{ + {"Content-Type", "application/json"}, + }, + contentTypeStr: "application/json", + expectedOutcome: true, + }, + { + name: "case insensitive match", + headers: [][2]string{ + {"content-type", "application/JSON"}, + }, + contentTypeStr: "application/json", + expectedOutcome: true, + }, + { + name: "substring match", + headers: [][2]string{ + {"Content-Type", "application/json; charset=utf-8"}, + }, + contentTypeStr: "application/json", + expectedOutcome: true, + }, + { + name: "no match", + headers: [][2]string{ + {"Content-Type", "text/plain"}, + }, + contentTypeStr: "application/json", + expectedOutcome: false, + }, + { + name: "header not present", + headers: [][2]string{ + {"Accept", "application/json"}, + }, + contentTypeStr: "application/json", + expectedOutcome: false, + }, + { + name: "empty headers", + headers: [][2]string{}, + contentTypeStr: "application/json", + expectedOutcome: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := hasContentType(tt.headers, tt.contentTypeStr) + if result != tt.expectedOutcome { + t.Errorf("hasContentType(%v, %v) = %v, want %v", tt.headers, tt.contentTypeStr, result, tt.expectedOutcome) + } + }) + } +} + +func TestRestToolValidation(t *testing.T) { + tests := []struct { + name string + tool RestTool + expectedError bool + }{ + { + name: "valid tool with no args options", + tool: RestTool{ + RequestTemplate: RestToolRequestTemplate{ + URL: "https://example.com", + Method: "GET", + }, + }, + expectedError: false, + }, + { + name: "valid tool with argsToJsonBody", + tool: RestTool{ + RequestTemplate: RestToolRequestTemplate{ + URL: "https://example.com", + Method: "POST", + ArgsToJsonBody: true, + }, + }, + expectedError: false, + }, + { + name: "valid tool with argsToUrlParam", + tool: RestTool{ + RequestTemplate: RestToolRequestTemplate{ + URL: "https://example.com", + Method: "GET", + ArgsToUrlParam: true, + }, + }, + expectedError: false, + }, + { + name: "valid tool with argsToFormBody", + tool: RestTool{ + RequestTemplate: RestToolRequestTemplate{ + URL: "https://example.com", + Method: "POST", + ArgsToFormBody: true, + }, + }, + expectedError: false, + }, + { + name: "invalid tool with multiple args options", + tool: RestTool{ + RequestTemplate: RestToolRequestTemplate{ + URL: "https://example.com", + Method: "POST", + ArgsToJsonBody: true, + ArgsToFormBody: true, + }, + }, + expectedError: true, + }, + { + name: "invalid tool with all args options", + tool: RestTool{ + RequestTemplate: RestToolRequestTemplate{ + URL: "https://example.com", + Method: "POST", + ArgsToJsonBody: true, + ArgsToUrlParam: true, + ArgsToFormBody: true, + }, + }, + expectedError: true, + }, + { + name: "invalid tool with both Body and PrependBody", + tool: RestTool{ + RequestTemplate: RestToolRequestTemplate{ + URL: "https://example.com", + Method: "GET", + }, + ResponseTemplate: RestToolResponseTemplate{ + Body: "# Result\n{{.data}}", + PrependBody: "# Field Descriptions:\n", + }, + }, + expectedError: true, + }, + { + name: "invalid tool with both Body and AppendBody", + tool: RestTool{ + RequestTemplate: RestToolRequestTemplate{ + URL: "https://example.com", + Method: "GET", + }, + ResponseTemplate: RestToolResponseTemplate{ + Body: "# Result\n{{.data}}", + AppendBody: "\n*End of response*", + }, + }, + expectedError: true, + }, + { + name: "invalid tool with Body, PrependBody, and AppendBody", + tool: RestTool{ + RequestTemplate: RestToolRequestTemplate{ + URL: "https://example.com", + Method: "GET", + }, + ResponseTemplate: RestToolResponseTemplate{ + Body: "# Result\n{{.data}}", + PrependBody: "# Field Descriptions:\n", + AppendBody: "\n*End of response*", + }, + }, + expectedError: true, + }, + { + name: "valid tool with PrependBody and AppendBody but no Body", + tool: RestTool{ + RequestTemplate: RestToolRequestTemplate{ + URL: "https://example.com", + Method: "GET", + }, + ResponseTemplate: RestToolResponseTemplate{ + PrependBody: "# Field Descriptions:\n", + AppendBody: "\n*End of response*", + }, + }, + expectedError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.tool.parseTemplates() + if (err != nil) != tt.expectedError { + t.Errorf("parseTemplates() error = %v, expectedError %v", err, tt.expectedError) + } + }) + } +} + +func TestInputSchemaWithComplexTypes(t *testing.T) { + // Create a tool with array and object type arguments + tool := RestMCPTool{ + toolConfig: RestTool{ + Args: []RestToolArg{ + { + Name: "stringArg", + Description: "A string argument", + Type: "string", + }, + { + Name: "arrayArg", + Description: "An array argument", + Type: "array", + Items: map[string]interface{}{ + "type": "string", + }, + }, + { + Name: "objectArg", + Description: "An object argument", + Type: "object", + Properties: map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + "description": "Name property", + }, + "age": map[string]interface{}{ + "type": "integer", + "description": "Age property", + }, + }, + }, + { + Name: "arrayOfObjects", + Description: "An array of objects", + Type: "array", + Items: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "id": map[string]interface{}{ + "type": "string", + }, + "value": map[string]interface{}{ + "type": "number", + }, + }, + }, + }, + }, + }, + } + + schema := tool.InputSchema() + + // Check schema structure + if schema["type"] != "object" { + t.Errorf("Expected schema type to be 'object', got %v", schema["type"]) + } + + properties, ok := schema["properties"].(map[string]interface{}) + if !ok { + t.Fatalf("Expected properties to be a map, got %T", schema["properties"]) + } + + // Check individual property types + checkProperty := func(name, expectedType string) { + prop, ok := properties[name].(map[string]interface{}) + if !ok { + t.Fatalf("Expected property %s to be a map, got %T", name, properties[name]) + } + if prop["type"] != expectedType { + t.Errorf("Expected property %s type to be '%s', got %v", name, expectedType, prop["type"]) + } + } + + checkProperty("stringArg", "string") + checkProperty("arrayArg", "array") + checkProperty("objectArg", "object") + checkProperty("arrayOfObjects", "array") + + // Check array items + arrayArg, _ := properties["arrayArg"].(map[string]interface{}) + if arrayArg["items"] == nil { + t.Errorf("Expected arrayArg to have items property") + } + + // Check object properties + objectArg, _ := properties["objectArg"].(map[string]interface{}) + if objectArg["properties"] == nil { + t.Errorf("Expected objectArg to have properties property") + } + + // Check array of objects + arrayOfObjects, _ := properties["arrayOfObjects"].(map[string]interface{}) + items, ok := arrayOfObjects["items"].(map[string]interface{}) + if !ok || items["type"] != "object" { + t.Errorf("Expected arrayOfObjects items to be of type object") + } +} + +func TestArgsToUrlParamAndFormBody(t *testing.T) { + // Test argsToUrlParam + t.Run("argsToUrlParam", func(t *testing.T) { + args := map[string]interface{}{ + "string": "value", + "int": 42, + "bool": true, + "array": []interface{}{1, 2, 3}, + "object": map[string]interface{}{"key": "value"}, + } + + // Parse URL and add parameters + baseURL := "https://example.com/api" + parsedURL, _ := url.Parse(baseURL) + query := parsedURL.Query() + + for key, value := range args { + query.Set(key, convertArgToString(value)) + } + + parsedURL.RawQuery = query.Encode() + result := parsedURL.String() + + // Verify each parameter is in the URL + for key, value := range args { + strValue := convertArgToString(value) + encodedValue := url.QueryEscape(strValue) + paramStr := key + "=" + encodedValue + + if !strings.Contains(result, paramStr) { + t.Errorf("URL parameter missing: %s", paramStr) + } + } + }) + + // Test argsToFormBody + t.Run("argsToFormBody", func(t *testing.T) { + args := map[string]interface{}{ + "string": "value", + "int": 42, + "bool": true, + "array": []interface{}{1, 2, 3}, + "object": map[string]interface{}{"key": "value"}, + } + + // Create form values + formValues := url.Values{} + for key, value := range args { + formValues.Set(key, convertArgToString(value)) + } + + formBody := formValues.Encode() + + // Verify each parameter is in the form body + for key, value := range args { + strValue := convertArgToString(value) + encodedValue := url.QueryEscape(strValue) + paramStr := key + "=" + encodedValue + + if !strings.Contains(formBody, paramStr) { + t.Errorf("Form body missing parameter: %s", paramStr) + } + } + }) +} + +func TestRestToolConfig(t *testing.T) { + // Example REST tool configuration + configJSON := ` +{ + "server": { + "name": "rest-amap-server", + "config": { + "apiKey": "xxxxx" + } + }, + "tools": [ + { + "name": "maps-geo", + "description": "å°†čÆ¦ē»†ēš„ē»“ęž„åŒ–åœ°å€č½¬ę¢äøŗē»ēŗ¬åŗ¦åę ‡ć€‚ę”ÆęŒåÆ¹åœ°ę ‡ę€§åčƒœę™ÆåŒŗć€å»ŗē­‘ē‰©åē§°č§£ęžäøŗē»ēŗ¬åŗ¦åę ‡", + "args": [ + { + "name": "address", + "description": "å¾…č§£ęžēš„ē»“ęž„åŒ–åœ°å€äæ”ęÆ", + "type": "string", + "required": true + }, + { + "name": "city", + "description": "ęŒ‡å®šęŸ„čÆ¢ēš„åŸŽåø‚", + "required": false + }, + { + "name": "output", + "description": "č¾“å‡ŗę ¼å¼", + "type": "string", + "enum": ["json", "xml"], + "default": "json" + }, + { + "name": "options", + "description": "é«˜ēŗ§é€‰é”¹", + "type": "object", + "properties": { + "extensions": { + "type": "string", + "enum": ["base", "all"] + }, + "batch": { + "type": "boolean" + } + } + }, + { + "name": "batch_addresses", + "description": "ę‰¹é‡åœ°å€", + "type": "array", + "items": { + "type": "string" + } + } + ], + "requestTemplate": { + "url": "https://restapi.amap.com/v3/geocode/geo?key={{.config.apiKey}}&address={{.args.address}}&city={{.args.city}}&output={{.args.output}}&source=ts_mcp", + "method": "GET", + "headers": [ + { + "key": "Content-Type", + "value": "application/json" + } + ] + }, + "responseTemplate": { + "body": "# åœ°ē†ē¼–ē äæ”ęÆ\n{{- range $index, $geo := .Geocodes }}\n## åœ°ē‚¹ {{add $index 1}}\n\n- **国家**: {{ $geo.Country }}\n- **省份**: {{ $geo.Province }}\n- **åŸŽåø‚**: {{ $geo.City }}\n- **åŸŽåø‚ä»£ē **: {{ $geo.Citycode }}\n- **区/åŽæ**: {{ $geo.District }}\n- **蔗道**: {{ $geo.Street }}\n- **é—Øē‰Œå·**: {{ $geo.Number }}\n- **č”Œę”æē¼–ē **: {{ $geo.Adcode }}\n- **åę ‡**: {{ $geo.Location }}\n- **级别**: {{ $geo.Level }}\n{{- end }}" + } + } + ] +} +` + + // Parse the config to verify it's valid JSON + var configData map[string]interface{} + err := json.Unmarshal([]byte(configJSON), &configData) + if err != nil { + t.Fatalf("Invalid JSON config: %v", err) + } + + // Example tool configuration + tool := RestTool{ + Name: "maps-geo", + Description: "å°†čÆ¦ē»†ēš„ē»“ęž„åŒ–åœ°å€č½¬ę¢äøŗē»ēŗ¬åŗ¦åę ‡ć€‚ę”ÆęŒåÆ¹åœ°ę ‡ę€§åčƒœę™ÆåŒŗć€å»ŗē­‘ē‰©åē§°č§£ęžäøŗē»ēŗ¬åŗ¦åę ‡", + Args: []RestToolArg{ + { + Name: "address", + Description: "å¾…č§£ęžēš„ē»“ęž„åŒ–åœ°å€äæ”ęÆ", + Type: "string", + Required: true, + }, + { + Name: "city", + Description: "ęŒ‡å®šęŸ„čÆ¢ēš„åŸŽåø‚", + Required: false, + }, + { + Name: "output", + Description: "č¾“å‡ŗę ¼å¼", + Type: "string", + Enum: []interface{}{"json", "xml"}, + Default: "json", + }, + { + Name: "options", + Description: "é«˜ēŗ§é€‰é”¹", + Type: "object", + Properties: map[string]interface{}{ + "extensions": map[string]interface{}{ + "type": "string", + "enum": []interface{}{"base", "all"}, + }, + "batch": map[string]interface{}{ + "type": "boolean", + }, + }, + }, + { + Name: "batch_addresses", + Description: "ę‰¹é‡åœ°å€", + Type: "array", + Items: map[string]interface{}{ + "type": "string", + }, + }, + }, + RequestTemplate: RestToolRequestTemplate{ + URL: "https://restapi.amap.com/v3/geocode/geo?key={{.config.apiKey}}&address={{.args.address}}&city={{.args.city}}&output={{.args.output}}&source=ts_mcp", + Method: "GET", + Headers: []RestToolHeader{ + { + Key: "Content-Type", + Value: "application/json", + }, + }, + }, + ResponseTemplate: RestToolResponseTemplate{ + Body: `# åœ°ē†ē¼–ē äæ”ęÆ +{{- range $index, $geo := .Geocodes }} +## åœ°ē‚¹ {{add $index 1}} + +- **国家**: {{ $geo.Country }} +- **省份**: {{ $geo.Province }} +- **åŸŽåø‚**: {{ $geo.City }} +- **åŸŽåø‚ä»£ē **: {{ $geo.Citycode }} +- **区/åŽæ**: {{ $geo.District }} +- **蔗道**: {{ $geo.Street }} +- **é—Øē‰Œå·**: {{ $geo.Number }} +- **č”Œę”æē¼–ē **: {{ $geo.Adcode }} +- **åę ‡**: {{ $geo.Location }} +- **级别**: {{ $geo.Level }} +{{- end }}`, + }, + } + + // Parse templates + err = tool.parseTemplates() + if err != nil { + t.Fatalf("Failed to parse templates: %v", err) + } + + var templateData []byte + templateData, _ = sjson.SetBytes(templateData, "config", map[string]interface{}{"apiKey": "test-api-key"}) + templateData, _ = sjson.SetBytes(templateData, "args", map[string]interface{}{ + "address": "åŒ—äŗ¬åø‚ęœé˜³åŒŗé˜œé€šäøœå¤§č”—6号", + "city": "åŒ—äŗ¬", + "output": "json", + }) + + // Test URL template + url, err := executeTemplate(tool.parsedURLTemplate, templateData) + if err != nil { + t.Fatalf("Failed to execute URL template: %v", err) + } + + expectedURL := "https://restapi.amap.com/v3/geocode/geo?key=test-api-key&address=åŒ—äŗ¬åø‚ęœé˜³åŒŗé˜œé€šäøœå¤§č”—6号&city=åŒ—äŗ¬&output=json&source=ts_mcp" + if url != expectedURL { + t.Errorf("URL template rendering failed. Expected: %s, Got: %s", expectedURL, url) + } + + // Test InputSchema for complex types + mcpTool := &RestMCPTool{ + toolConfig: tool, + } + + schema := mcpTool.InputSchema() + properties := schema["properties"].(map[string]interface{}) + + // Check object type + options, ok := properties["options"].(map[string]interface{}) + if !ok || options["type"] != "object" { + t.Errorf("Expected options to be of type object") + } + + // Check array type + batchAddresses, ok := properties["batch_addresses"].(map[string]interface{}) + if !ok || batchAddresses["type"] != "array" { + t.Errorf("Expected batch_addresses to be of type array") + } + + // Test response template with sample data + sampleResponse := ` + {"Geocodes": [ + { + "Country": "中国", + "Province": "åŒ—äŗ¬åø‚", + "City": "åŒ—äŗ¬åø‚", + "Citycode": "010", + "District": "ęœé˜³åŒŗ", + "Street": "é˜œé€šäøœå¤§č”—", + "Number": "6号", + "Adcode": "110105", + "Location": "116.483038,39.990633", + "Level": "é—Øē‰Œå·", + }]}` + + result, err := executeTemplate(tool.parsedResponseTemplate, []byte(sampleResponse)) + if err != nil { + t.Fatalf("Failed to execute response template: %v", err) + } + + // Just check that the result contains expected substrings + expectedSubstrings := []string{ + "# åœ°ē†ē¼–ē äæ”ęÆ", + "## åœ°ē‚¹ 1", + "**国家**: 中国", + "**省份**: åŒ—äŗ¬åø‚", + "**åę ‡**: 116.483038,39.990633", + } + + for _, substr := range expectedSubstrings { + if !strings.Contains(result, substr) { + t.Errorf("Response template rendering failed. Expected substring not found: %s", substr) + } + } +} + +// TestRestServerDefaultSecurity tests the default security configuration for REST MCP server +func TestRestServerDefaultSecurity(t *testing.T) { + server := NewRestMCPServer("test-rest-server") + + // Add security schemes + defaultScheme := SecurityScheme{ + ID: "DefaultAuth", + Type: "apiKey", + In: "header", + Name: "X-Default-Key", + DefaultCredential: "default-key", + } + toolScheme := SecurityScheme{ + ID: "ToolAuth", + Type: "apiKey", + In: "header", + Name: "X-Tool-Key", + DefaultCredential: "tool-key", + } + server.AddSecurityScheme(defaultScheme) + server.AddSecurityScheme(toolScheme) + + // Test setting default security directly on server + server.SetDefaultDownstreamSecurity(SecurityRequirement{ + ID: "DefaultAuth", + Passthrough: false, + }) + server.SetDefaultUpstreamSecurity(SecurityRequirement{ + ID: "DefaultAuth", + }) + + // Verify default security settings + retrievedDownstream := server.GetDefaultDownstreamSecurity() + assert.Equal(t, "DefaultAuth", retrievedDownstream.ID) + assert.False(t, retrievedDownstream.Passthrough) + + retrievedUpstream := server.GetDefaultUpstreamSecurity() + assert.Equal(t, "DefaultAuth", retrievedUpstream.ID) + + t.Logf("REST server default security configuration test completed successfully") +} + +// TestRestServerSecurityFallback tests the fallback mechanism from tool-level to default security +func TestRestServerSecurityFallback(t *testing.T) { + server := NewRestMCPServer("test-rest-server") + + // Add security schemes + defaultScheme := SecurityScheme{ + ID: "DefaultAuth", + Type: "apiKey", + In: "header", + Name: "X-Default-Key", + DefaultCredential: "default-key", + } + toolScheme := SecurityScheme{ + ID: "ToolAuth", + Type: "apiKey", + In: "header", + Name: "X-Tool-Key", + DefaultCredential: "tool-key", + } + server.AddSecurityScheme(defaultScheme) + server.AddSecurityScheme(toolScheme) + + // Test tool configuration with tool-level security (should use tool-level, not default) + toolConfigWithSecurity := RestTool{ + Name: "secure_tool", + Description: "Tool with its own security", + Security: SecurityRequirement{ + ID: "ToolAuth", + Passthrough: true, + }, + RequestTemplate: RestToolRequestTemplate{ + URL: "http://api.example.com/secure", + Method: "GET", + Security: SecurityRequirement{ + ID: "ToolAuth", + }, + }, + } + + // Test tool configuration without tool-level security (should fallback to default) + toolConfigWithoutSecurity := RestTool{ + Name: "fallback_tool", + Description: "Tool that falls back to default security", + // No Security field configured, should use default + RequestTemplate: RestToolRequestTemplate{ + URL: "http://api.example.com/fallback", + Method: "GET", + // No Security field configured, should use default + }, + } + + // Add tools to server + err := server.AddRestTool(toolConfigWithSecurity) + assert.NoError(t, err) + err = server.AddRestTool(toolConfigWithoutSecurity) + assert.NoError(t, err) + + // Verify tools were added + tools := server.GetMCPTools() + assert.Contains(t, tools, "secure_tool") + assert.Contains(t, tools, "fallback_tool") + + t.Logf("REST server security fallback test completed successfully") +} diff --git a/plugins/wasm-go/pkg/mcp/server/sse_proxy.go b/plugins/wasm-go/pkg/mcp/server/sse_proxy.go new file mode 100644 index 000000000..7a72d22bc --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/sse_proxy.go @@ -0,0 +1,874 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils" + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + // Context keys for SSE proxy state management + CtxSSEProxyState = "sse_proxy_state" + CtxSSEProxyEndpointURL = "sse_proxy_endpoint_url" + CtxSSEProxyBuffer = "sse_proxy_buffer" + CtxSSEProxyAuthInfo = "sse_proxy_auth_info" + CtxSSEProxyRequestBody = "sse_proxy_request_body" + CtxSSEProxyRequestID = "sse_proxy_request_id" + CtxSSEProxyFirstChunk = "sse_proxy_first_chunk" + CtxSSEProxyJsonRpcID = "sse_proxy_jsonrpc_id" + + // SSE proxy state values + SSEStateWaitingEndpoint = "waiting_endpoint" + SSEStateWaitingInitResp = "waiting_init_resp" + SSEStateWaitingNotifyResp = "waiting_notify_resp" + SSEStateWaitingToolResp = "waiting_tool_resp" + + // Buffer size limit: 100MB + MaxSSEBufferSize = 100 * 1024 * 1024 +) + +// injectSSEResponseSuccess injects a successful JSON-RPC response in streaming response body phase +func injectSSEResponseSuccess(ctx wrapper.HttpContext, result map[string]any) { + // Get JSON-RPC ID from context + jsonRpcIDRaw := ctx.GetContext(CtxSSEProxyJsonRpcID) + if jsonRpcIDRaw == nil { + log.Errorf("JSON-RPC ID not found in context for SSE response") + return + } + jsonRpcID := jsonRpcIDRaw.(utils.JsonRpcID) + + var body []byte + var err error + if jsonRpcID.IsString { + body, err = json.Marshal(map[string]interface{}{ + "jsonrpc": "2.0", + "id": jsonRpcID.StringValue, + "result": result, + }) + } else { + body, err = json.Marshal(map[string]interface{}{ + "jsonrpc": "2.0", + "id": jsonRpcID.IntValue, + "result": result, + }) + } + + if err != nil { + log.Errorf("Failed to marshal JSON-RPC success response: %v", err) + return + } + + proxywasm.InjectEncodedDataToFilterChain(body, true) +} + +// injectSSEResponseError injects an error JSON-RPC response in streaming response body phase +func injectSSEResponseError(ctx wrapper.HttpContext, err error, errorCode int) { + // Get JSON-RPC ID from context + jsonRpcIDRaw := ctx.GetContext(CtxSSEProxyJsonRpcID) + if jsonRpcIDRaw == nil { + log.Errorf("JSON-RPC ID not found in context for SSE error response") + return + } + jsonRpcID := jsonRpcIDRaw.(utils.JsonRpcID) + + var body []byte + var marshalErr error + if jsonRpcID.IsString { + body, marshalErr = json.Marshal(map[string]interface{}{ + "jsonrpc": "2.0", + "id": jsonRpcID.StringValue, + "error": map[string]interface{}{ + "code": errorCode, + "message": err.Error(), + }, + }) + } else { + body, marshalErr = json.Marshal(map[string]interface{}{ + "jsonrpc": "2.0", + "id": jsonRpcID.IntValue, + "error": map[string]interface{}{ + "code": errorCode, + "message": err.Error(), + }, + }) + } + + if marshalErr != nil { + log.Errorf("Failed to marshal JSON-RPC error response: %v", marshalErr) + return + } + + proxywasm.InjectEncodedDataToFilterChain(body, true) +} + +// SSEMessage represents a parsed SSE message +type SSEMessage struct { + Event string + Data string + ID string +} + +// ParseSSEMessage parses SSE format data and returns complete messages +// Returns the parsed message and the remaining unparsed data +func ParseSSEMessage(data []byte) (*SSEMessage, []byte, error) { + scanner := bufio.NewScanner(bytes.NewReader(data)) + // Set max token size to 32MB to handle large messages + maxTokenSize := 32 * 1024 * 1024 // 32MB + scanner.Buffer(make([]byte, 0, 64*1024), maxTokenSize) + msg := &SSEMessage{} + lineCount := 0 + lastPos := 0 + + for scanner.Scan() { + line := scanner.Text() + lineCount++ + lastPos += len(line) + 1 // +1 for newline + + // Empty line indicates end of message + if strings.TrimSpace(line) == "" { + if msg.Event != "" || msg.Data != "" || msg.ID != "" { + // Found a complete message + return msg, data[lastPos:], nil + } + // Empty message, continue + continue + } + + // Skip comment lines (lines starting with ':') + if strings.HasPrefix(line, ":") { + continue + } + + // Parse field + parts := strings.SplitN(line, ":", 2) + if len(parts) < 2 { + continue + } + + field := parts[0] + value := strings.TrimSpace(parts[1]) + + switch field { + case "event": + msg.Event = value + case "data": + if msg.Data != "" { + msg.Data += "\n" + value + } else { + msg.Data = value + } + case "id": + msg.ID = value + } + } + + if err := scanner.Err(); err != nil { + if errors.Is(err, bufio.ErrTooLong) { + return nil, nil, fmt.Errorf("SSE message line exceeds maximum token size (32MB): %w", err) + } + return nil, nil, fmt.Errorf("error scanning SSE data: %v", err) + } + + // No complete message found, return all data as remaining + return nil, data, nil +} + +// ExtractEndpointURL extracts the endpoint URL from an SSE endpoint message +// It handles two cases: +// 1. endpointData is a full URL (e.g., http://example.com/sse) - return as-is +// 2. endpointData is a path - if baseURL has scheme and host, combine them; otherwise return the path as-is +func ExtractEndpointURL(endpointData string, baseURL string) (string, error) { + // Case 1: endpointData is a full URL + if strings.HasPrefix(endpointData, "http://") || strings.HasPrefix(endpointData, "https://") { + return endpointData, nil + } + + // endpointData is a path + parsedBase, err := url.Parse(baseURL) + if err != nil { + return "", fmt.Errorf("failed to parse base URL: %v", err) + } + + // Case 2: baseURL has scheme and host, combine them + if parsedBase.Scheme != "" && parsedBase.Host != "" { + // Combine scheme, host, and the new path + // Ensure endpointData starts with "/" + if !strings.HasPrefix(endpointData, "/") { + endpointData = "/" + endpointData + } + result := parsedBase.Scheme + "://" + parsedBase.Host + endpointData + return result, nil + } + + // Case 3: baseURL is also just a path, return endpointData as-is + return endpointData, nil +} + +// sendSSEInitialize sends the initialize request for SSE protocol +func sendSSEInitialize(ctx wrapper.HttpContext, endpointURL string, authInfo *ProxyAuthInfo, proxyServer *McpProxyServer) error { + initRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]interface{}{ + "protocolVersion": "2025-03-26", + "capabilities": map[string]interface{}{ + "roots": map[string]interface{}{ + "listChanged": true, + }, + "sampling": map[string]interface{}{}, + "elicitation": map[string]interface{}{}, + }, + "clientInfo": map[string]interface{}{ + "name": "Higress-mcp-proxy", + "title": "Higress MCP Proxy", + "version": "1.0.0", + }, + }, + } + + requestBody, err := json.Marshal(initRequest) + if err != nil { + return fmt.Errorf("failed to marshal initialize request: %v", err) + } + + // Copy headers from current request (now supported in response phase by Envoy) + finalHeaders := copyHeadersForSSERequest(ctx) + + // Override required headers for SSE initialize + ensureHeader(&finalHeaders, "Content-Type", "application/json") + + // Apply authentication to headers and URL + finalURL := endpointURL + if authInfo != nil && authInfo.SecuritySchemeID != "" { + modifiedURL, err := applyProxyAuthenticationForSSE(proxyServer, authInfo.SecuritySchemeID, authInfo.PassthroughCredential, &finalHeaders, endpointURL) + if err != nil { + log.Errorf("Failed to apply authentication for SSE initialize: %v", err) + } else { + finalURL = modifiedURL + } + } + + // Note: headers are already copied from the current request (which has server-level headers applied) + // via copyHeadersForSSERequest, so no need to apply them again + + // Store state for tracking + ctx.SetContext(CtxSSEProxyState, SSEStateWaitingInitResp) + ctx.SetContext(CtxSSEProxyRequestID, 1) + + // Use RouteCluster client to send initialize request + client := wrapper.NewClusterClient(wrapper.RouteCluster{}) + timeout := uint32(proxyServer.GetTimeout()) + if timeout == 0 { + timeout = 5000 // Default 5 seconds + } + + return client.Post(finalURL, finalHeaders, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + if statusCode != 200 && statusCode != 202 { + log.Errorf("SSE initialize request failed with status %d: %s", statusCode, string(responseBody)) + // At this point, we're in streaming response phase, must use injectSSEResponseError + injectSSEResponseError(ctx, fmt.Errorf("SSE initialize failed with status %d", statusCode), utils.ErrInternalError) + return + } + + log.Debugf("SSE initialize request sent successfully") + // The response will be received through SSE channel and processed in streaming response handler + // State has already been set to SSEStateWaitingInitResp before this POST request + // No need to change state here + }, timeout) +} + +// sendSSENotification sends the notifications/initialized message for SSE protocol +func sendSSENotification(ctx wrapper.HttpContext, endpointURL string, authInfo *ProxyAuthInfo, proxyServer *McpProxyServer) error { + notification := map[string]interface{}{ + "jsonrpc": "2.0", + "method": "notifications/initialized", + } + + requestBody, err := json.Marshal(notification) + if err != nil { + return fmt.Errorf("failed to marshal notification: %v", err) + } + + // Copy headers from current request (now supported in response phase by Envoy) + finalHeaders := copyHeadersForSSERequest(ctx) + + // Override required headers for SSE notification + ensureHeader(&finalHeaders, "Content-Type", "application/json") + + // Apply authentication to headers and URL + finalURL := endpointURL + if authInfo != nil && authInfo.SecuritySchemeID != "" { + modifiedURL, err := applyProxyAuthenticationForSSE(proxyServer, authInfo.SecuritySchemeID, authInfo.PassthroughCredential, &finalHeaders, endpointURL) + if err != nil { + log.Errorf("Failed to apply authentication for SSE notification: %v", err) + } else { + finalURL = modifiedURL + } + } + + // Note: headers are already copied from the current request (which has server-level headers applied) + // via copyHeadersForSSERequest, so no need to apply them again + + // Store state for tracking + ctx.SetContext(CtxSSEProxyState, SSEStateWaitingNotifyResp) + + // Use RouteCluster client to send notification + client := wrapper.NewClusterClient(wrapper.RouteCluster{}) + timeout := uint32(proxyServer.GetTimeout()) + if timeout == 0 { + timeout = 5000 // Default 5 seconds + } + + return client.Post(finalURL, finalHeaders, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + if statusCode != 200 && statusCode != 202 { + log.Warnf("SSE notification request failed with status %d: %s", statusCode, string(responseBody)) + // Even if notification fails, we should try to continue + // Some servers may not strictly require notification success + } + + log.Debugf("SSE notification sent successfully") + + // Now we can send the actual tool request + // Get stored context + endpointURLRaw := ctx.GetContext(CtxSSEProxyEndpointURL) + authInfoRaw := ctx.GetContext(CtxSSEProxyAuthInfo) + proxyServerRaw := ctx.GetContext("mcp_proxy_server") + requestBodyRaw := ctx.GetContext(CtxSSEProxyRequestBody) + + if endpointURLRaw == nil || proxyServerRaw == nil || requestBodyRaw == nil { + log.Errorf("Missing context for sending tool request") + // At this point, we're in streaming response phase, must use injectSSEResponseError + injectSSEResponseError(ctx, fmt.Errorf("internal error: missing context"), utils.ErrInternalError) + return + } + + endpointURL := endpointURLRaw.(string) + proxyServer := proxyServerRaw.(*McpProxyServer) + requestBody := requestBodyRaw.([]byte) + + var authInfo *ProxyAuthInfo + if authInfoRaw != nil { + authInfo = authInfoRaw.(*ProxyAuthInfo) + } + + // Parse to get request ID + reqID := gjson.GetBytes(requestBody, "id").Int() + if err := sendSSEToolRequest(ctx, endpointURL, authInfo, proxyServer, requestBody, int(reqID)); err != nil { + log.Errorf("Failed to send SSE tool request: %v", err) + injectSSEResponseError(ctx, err, utils.ErrInternalError) + } + }, timeout) +} + +// sendSSEToolRequest sends the tools/list or tools/call request for SSE protocol +func sendSSEToolRequest(ctx wrapper.HttpContext, endpointURL string, authInfo *ProxyAuthInfo, proxyServer *McpProxyServer, requestBody []byte, requestID int) error { + // Copy headers from current request (now supported in response phase by Envoy) + finalHeaders := copyHeadersForSSERequest(ctx) + + // Override required headers for SSE tool request + ensureHeader(&finalHeaders, "Content-Type", "application/json") + + // Apply authentication to headers and URL + finalURL := endpointURL + if authInfo != nil && authInfo.SecuritySchemeID != "" { + modifiedURL, err := applyProxyAuthenticationForSSE(proxyServer, authInfo.SecuritySchemeID, authInfo.PassthroughCredential, &finalHeaders, endpointURL) + if err != nil { + log.Errorf("Failed to apply authentication for SSE tool request: %v", err) + } else { + finalURL = modifiedURL + } + } + + // Note: headers are already copied from the current request (which has server-level headers applied) + // via copyHeadersForSSERequest, so no need to apply them again + + // Store state for tracking + ctx.SetContext(CtxSSEProxyState, SSEStateWaitingToolResp) + ctx.SetContext(CtxSSEProxyRequestID, requestID) + + // Use RouteCluster client to send tool request + client := wrapper.NewClusterClient(wrapper.RouteCluster{}) + timeout := uint32(proxyServer.GetTimeout()) + if timeout == 0 { + timeout = 5000 // Default 5 seconds + } + + return client.Post(finalURL, finalHeaders, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + if statusCode != 200 && statusCode != 202 { + log.Errorf("SSE tool request failed with status %d: %s", statusCode, string(responseBody)) + // At this point, we're in streaming response phase, must use injectSSEResponseError + injectSSEResponseError(ctx, fmt.Errorf("SSE tool request failed with status %d", statusCode), utils.ErrInternalError) + return + } + + log.Debugf("SSE tool request sent successfully") + // The response will be received through SSE channel and processed in streaming response handler + }, timeout) +} + +// copyHeadersForSSERequest copies headers from current request for SSE RouteCluster calls +// This leverages Envoy's new capability to access request headers in response phase +func copyHeadersForSSERequest(ctx wrapper.HttpContext) [][2]string { + headers := make([][2]string, 0) + + // Headers to skip + skipHeaders := map[string]bool{ + "content-length": true, // Will be set by the client + "transfer-encoding": true, // Will be set by the client + "accept": true, // Will be set explicitly for SSE requests + ":path": true, // Pseudo-header, not needed + ":method": true, // Pseudo-header, not needed + ":scheme": true, // Pseudo-header, not needed + ":authority": true, // Pseudo-header, not needed + } + + // Get all request headers (now supported in response phase by Envoy) + headerMap, err := proxywasm.GetHttpRequestHeaders() + if err != nil { + log.Warnf("Failed to get request headers in response phase: %v", err) + // Return minimal headers + return [][2]string{} + } + + // Copy headers, skipping unwanted ones + for _, header := range headerMap { + headerName := strings.ToLower(header[0]) + if skipHeaders[headerName] { + continue + } + headers = append(headers, header) + } + + log.Debugf("Copied %d headers from request in response phase for SSE", len(headers)) + return headers +} + +// applyProxyAuthenticationForSSE applies authentication for SSE proxy requests +func applyProxyAuthenticationForSSE(server *McpProxyServer, schemeID string, passthroughCredential string, headers *[][2]string, targetURL string) (string, error) { + // Parse the target URL + parsedURL, err := url.Parse(targetURL) + if err != nil { + return "", fmt.Errorf("failed to parse target URL: %v", err) + } + + // Create authentication context + authCtx := AuthRequestContext{ + Method: "POST", + Headers: *headers, + ParsedURL: parsedURL, + RequestBody: []byte{}, + PassthroughCredential: passthroughCredential, + } + + // Create security config + securityConfig := SecurityRequirement{ + ID: schemeID, + Credential: "", + Passthrough: passthroughCredential != "", + } + + // Apply authentication + err = ApplySecurity(securityConfig, server, &authCtx) + if err != nil { + return "", err + } + + // Update headers + *headers = authCtx.Headers + + // Reconstruct URL + u := authCtx.ParsedURL + encodedPath := u.EscapedPath() + var urlStr string + if u.Scheme != "" && u.Host != "" { + urlStr = u.Scheme + "://" + u.Host + encodedPath + } else { + urlStr = "/" + strings.TrimPrefix(encodedPath, "/") + } + if u.RawQuery != "" { + urlStr += "?" + u.RawQuery + } + if u.Fragment != "" { + urlStr += "#" + u.Fragment + } + + return urlStr, nil +} + +// handleSSEStreamingResponse handles the streaming SSE response +func handleSSEStreamingResponse(ctx wrapper.HttpContext, config McpServerConfig, data []byte, endOfStream bool) []byte { + // Get the first chunk flag + isFirstChunk := ctx.GetBoolContext(CtxSSEProxyFirstChunk, true) + if isFirstChunk { + ctx.SetContext(CtxSSEProxyFirstChunk, false) + } + log.Debugf("Handling chunk of SSE response, data: %q", string(data)) + // On first chunk, validate content-type and modify headers + if isFirstChunk { + // Validate that backend returned text/event-stream + contentType, err := proxywasm.GetHttpResponseHeader("content-type") + if err != nil || !strings.Contains(strings.ToLower(contentType), "text/event-stream") { + log.Errorf("Backend did not return text/event-stream content-type, got: %s", contentType) + // Return JSON-RPC error + injectSSEResponseError(ctx, fmt.Errorf("invalid content-type, expected text/event-stream but got: %s", contentType), utils.ErrInternalError) + return []byte{} + } + + // Remove content-length and modify content-type + proxywasm.RemoveHttpResponseHeader("content-length") + proxywasm.ReplaceHttpResponseHeader("content-type", "application/json; charset=utf-8") + proxywasm.ReplaceHttpResponseHeader(":status", "200") + } + + // Get or initialize buffer + var buffer []byte + if bufferRaw := ctx.GetContext(CtxSSEProxyBuffer); bufferRaw != nil { + buffer = bufferRaw.([]byte) + } + + // Append new data to buffer + buffer = append(buffer, data...) + + // Check buffer size limit + if len(buffer) > MaxSSEBufferSize { + log.Errorf("SSE buffer exceeded maximum size of %d bytes", MaxSSEBufferSize) + injectSSEResponseError(ctx, errors.New("response too large"), utils.ErrInternalError) + return []byte{} + } + + // Store buffer back + ctx.SetContext(CtxSSEProxyBuffer, buffer) + + // Get current state + state := ctx.GetContext(CtxSSEProxyState) + if state == nil { + state = SSEStateWaitingEndpoint + ctx.SetContext(CtxSSEProxyState, state) + } + + log.Debugf("SSE proxy state: %s, now buffering data: %q", state.(string), string(buffer)) + + // Process based on state + switch state.(string) { + case SSEStateWaitingEndpoint: + return handleWaitingEndpoint(ctx, config, &buffer) + + case SSEStateWaitingInitResp: + return handleWaitingInitResp(ctx, config, &buffer) + + case SSEStateWaitingNotifyResp: + return handleWaitingNotifyResp(ctx, config, &buffer) + + case SSEStateWaitingToolResp: + return handleWaitingToolResp(ctx, config, &buffer) + + default: + log.Warnf("Unknown SSE proxy state: %v", state) + return []byte{} + } +} + +// handleWaitingEndpoint processes SSE messages waiting for endpoint message +func handleWaitingEndpoint(ctx wrapper.HttpContext, config McpServerConfig, buffer *[]byte) []byte { + for { + msg, remaining, err := ParseSSEMessage(*buffer) + if err != nil { + log.Errorf("Failed to parse SSE message: %v", err) + injectSSEResponseError(ctx, err, utils.ErrInternalError) + return []byte{} + } + + if msg == nil { + // No complete message yet + *buffer = remaining + return []byte{} + } + + // Update buffer + *buffer = remaining + ctx.SetContext(CtxSSEProxyBuffer, *buffer) + + // Check for endpoint message + if msg.Event == "endpoint" { + // Extract and store endpoint URL + proxyServerRaw := ctx.GetContext("mcp_proxy_server") + if proxyServerRaw == nil { + log.Errorf("mcp_proxy_server not found in context") + injectSSEResponseError(ctx, errors.New("internal error"), utils.ErrInternalError) + return []byte{} + } + proxyServer := proxyServerRaw.(*McpProxyServer) + + endpointURL, err := ExtractEndpointURL(msg.Data, proxyServer.GetMcpServerURL()) + if err != nil { + log.Errorf("Failed to extract endpoint URL: %v", err) + injectSSEResponseError(ctx, err, utils.ErrInternalError) + return []byte{} + } + + log.Infof("Received SSE endpoint URL: %s", endpointURL) + ctx.SetContext(CtxSSEProxyEndpointURL, endpointURL) + + // Get stored auth info + authInfoRaw := ctx.GetContext(CtxSSEProxyAuthInfo) + + var authInfo *ProxyAuthInfo + if authInfoRaw != nil { + authInfo = authInfoRaw.(*ProxyAuthInfo) + } + + // Send initialize request + if err := sendSSEInitialize(ctx, endpointURL, authInfo, proxyServer); err != nil { + log.Errorf("Failed to send SSE initialize: %v", err) + injectSSEResponseError(ctx, err, utils.ErrInternalError) + return []byte{} + } + + // State has been changed to SSEStateWaitingInitResp in sendSSEInitialize + // Return immediately to allow next chunk to be processed in the new state + return []byte{} + } + + // Skip other message types (like ping) while waiting for endpoint + // Continue to process next message in buffer + log.Debugf("Skipping SSE message with event '%s' while waiting for endpoint", msg.Event) + continue + } +} + +// handleWaitingInitResp processes SSE messages waiting for initialize response +func handleWaitingInitResp(ctx wrapper.HttpContext, config McpServerConfig, buffer *[]byte) []byte { + requestID := ctx.GetContext(CtxSSEProxyRequestID) + if requestID == nil { + requestID = 1 + } + + for { + msg, remaining, err := ParseSSEMessage(*buffer) + if err != nil { + log.Errorf("Failed to parse SSE message: %v", err) + injectSSEResponseError(ctx, err, utils.ErrInternalError) + return []byte{} + } + + if msg == nil { + // No complete message yet + *buffer = remaining + return []byte{} + } + + // Update buffer + *buffer = remaining + ctx.SetContext(CtxSSEProxyBuffer, *buffer) + + // Check for message event + if msg.Event == "message" { + // Parse JSON-RPC response + var jsonRpcResp map[string]interface{} + if err := json.Unmarshal([]byte(msg.Data), &jsonRpcResp); err != nil { + log.Errorf("Failed to parse JSON-RPC response: %v", err) + continue + } + + // Check if this is the initialize response + respID := jsonRpcResp["id"] + if respID != nil { + var idMatch bool + switch v := respID.(type) { + case float64: + idMatch = int(v) == requestID.(int) + case int: + idMatch = v == requestID.(int) + } + + if idMatch { + // Check for errors + if errorObj, hasError := jsonRpcResp["error"]; hasError { + log.Errorf("Backend initialize error: %v", errorObj) + injectSSEResponseError(ctx, fmt.Errorf("backend initialize failed"), utils.ErrInternalError) + return []byte{} + } + + log.Debugf("Received initialize response, sending notification") + + // Get endpoint URL and auth info + endpointURL := ctx.GetContext(CtxSSEProxyEndpointURL).(string) + authInfoRaw := ctx.GetContext(CtxSSEProxyAuthInfo) + proxyServerRaw := ctx.GetContext("mcp_proxy_server") + + var authInfo *ProxyAuthInfo + if authInfoRaw != nil { + authInfo = authInfoRaw.(*ProxyAuthInfo) + } + + proxyServer := proxyServerRaw.(*McpProxyServer) + + // Send notification + // The notification callback will send the tool request after notification succeeds + if err := sendSSENotification(ctx, endpointURL, authInfo, proxyServer); err != nil { + log.Errorf("Failed to send SSE notification: %v", err) + injectSSEResponseError(ctx, err, utils.ErrInternalError) + return []byte{} + } + + // State has been changed to SSEStateWaitingNotifyResp in sendSSENotification + // The tool request will be sent in the notification callback + // Return immediately to allow next chunk to be processed in the new state + return []byte{} + } + } + } + + // Skip other message types (like ping) while waiting for init response + // Continue to process next message in buffer + log.Debugf("Skipping SSE message with event '%s' while waiting for init response", msg.Event) + continue + } +} + +// handleWaitingNotifyResp processes SSE messages waiting for notification response +func handleWaitingNotifyResp(ctx wrapper.HttpContext, config McpServerConfig, buffer *[]byte) []byte { + // For notifications, we don't expect a response in SSE channel + // Just continue to send tool request + // This state should be very brief + return []byte{} +} + +// handleWaitingToolResp processes SSE messages waiting for tool response +func handleWaitingToolResp(ctx wrapper.HttpContext, config McpServerConfig, buffer *[]byte) []byte { + requestID := ctx.GetContext(CtxSSEProxyRequestID) + if requestID == nil { + log.Errorf("Request ID not found in context") + injectSSEResponseError(ctx, errors.New("internal error"), utils.ErrInternalError) + return []byte{} + } + + for { + msg, remaining, err := ParseSSEMessage(*buffer) + if err != nil { + log.Errorf("Failed to parse SSE message: %v", err) + injectSSEResponseError(ctx, err, utils.ErrInternalError) + return []byte{} + } + + if msg == nil { + // No complete message yet + *buffer = remaining + return []byte{} + } + + // Update buffer + *buffer = remaining + ctx.SetContext(CtxSSEProxyBuffer, *buffer) + + // Check for message event + if msg.Event == "message" { + // Parse JSON-RPC response + var jsonRpcResp map[string]interface{} + if err := json.Unmarshal([]byte(msg.Data), &jsonRpcResp); err != nil { + log.Errorf("Failed to parse JSON-RPC response: %v", err) + continue + } + + // Check if this is the expected response + respID := jsonRpcResp["id"] + if respID != nil { + var idMatch bool + switch v := respID.(type) { + case float64: + idMatch = int(v) == requestID.(int) + case int: + idMatch = v == requestID.(int) + } + + if idMatch { + // Check for errors + if errorObj, hasError := jsonRpcResp["error"]; hasError { + log.Errorf("Backend tool error: %v", errorObj) + injectSSEResponseError(ctx, fmt.Errorf("backend tool call failed"), utils.ErrInternalError) + return []byte{} + } + + // Extract result and return to client + if result, hasResult := jsonRpcResp["result"]; hasResult { + if resultMap, ok := result.(map[string]interface{}); ok { + // Apply allowTools filtering if this is a tools/list response + filteredResult := resultMap + if _, hasTools := resultMap["tools"]; hasTools { + // Get pre-computed effective allowTools from context + if allowToolsCtx := ctx.GetContext("mcp_proxy_effective_allow_tools"); allowToolsCtx != nil { + if effectiveAllowTools, ok := allowToolsCtx.(*map[string]struct{}); ok && effectiveAllowTools != nil { + // Apply filtering + if tools, hasToolsArray := resultMap["tools"]; hasToolsArray { + if toolsArray, ok := tools.([]interface{}); ok { + filteredTools := make([]interface{}, 0) + for _, tool := range toolsArray { + if toolMap, ok := tool.(map[string]interface{}); ok { + if name, hasName := toolMap["name"]; hasName { + if toolName, ok := name.(string); ok { + if _, allow := (*effectiveAllowTools)[toolName]; allow { + filteredTools = append(filteredTools, tool) + } + } + } + } + } + // Create filtered result + filteredResult = make(map[string]interface{}) + for k, v := range resultMap { + filteredResult[k] = v + } + filteredResult["tools"] = filteredTools + } + } + } + } + } + + injectSSEResponseSuccess(ctx, filteredResult) + // Clear buffer as we've processed the response + *buffer = []byte{} + ctx.SetContext(CtxSSEProxyBuffer, *buffer) + return []byte{} + } + } + + log.Errorf("Invalid tool response format") + injectSSEResponseError(ctx, errors.New("invalid response format"), utils.ErrInternalError) + return []byte{} + } + } + } + + // Skip other message types (like ping) while waiting for tool response + // Continue to process next message in buffer + log.Debugf("Skipping SSE message with event '%s' while waiting for tool response", msg.Event) + continue + } +} diff --git a/plugins/wasm-go/pkg/mcp/server/sse_proxy_test.go b/plugins/wasm-go/pkg/mcp/server/sse_proxy_test.go new file mode 100644 index 000000000..7994ec525 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/sse_proxy_test.go @@ -0,0 +1,297 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "testing" +) + +// TestParseSSEMessage tests SSE message parsing +func TestParseSSEMessage(t *testing.T) { + tests := []struct { + name string + input []byte + wantEvent string + wantData string + wantID string + shouldParse bool + }{ + { + name: "endpoint message", + input: []byte(`event: endpoint +data: /messages/?session_id=test123 + +`), + wantEvent: "endpoint", + wantData: "/messages/?session_id=test123", + shouldParse: true, + }, + { + name: "message with JSON data", + input: []byte(`event: message +data: {"jsonrpc":"2.0","id":1,"result":{"test":"value"}} + +`), + wantEvent: "message", + wantData: `{"jsonrpc":"2.0","id":1,"result":{"test":"value"}}`, + shouldParse: true, + }, + { + name: "incomplete message", + input: []byte(`event: message +data: {"jsonrpc":"2.0"`), + shouldParse: false, + }, + { + name: "message with id", + input: []byte(`id: 123 +event: message +data: test data + +`), + wantEvent: "message", + wantData: "test data", + wantID: "123", + shouldParse: true, + }, + { + name: "comment line ignored", + input: []byte(`: this is a comment +event: message +data: test data + +`), + wantEvent: "message", + wantData: "test data", + shouldParse: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg, remaining, err := ParseSSEMessage(tt.input) + + if err != nil { + t.Fatalf("parseSSEMessage() error = %v", err) + } + + if tt.shouldParse { + if msg == nil { + t.Errorf("parseSSEMessage() expected message but got nil") + return + } + if msg.Event != tt.wantEvent { + t.Errorf("parseSSEMessage() Event = %v, want %v", msg.Event, tt.wantEvent) + } + if msg.Data != tt.wantData { + t.Errorf("parseSSEMessage() Data = %v, want %v", msg.Data, tt.wantData) + } + if msg.ID != tt.wantID { + t.Errorf("parseSSEMessage() ID = %v, want %v", msg.ID, tt.wantID) + } + if len(remaining) != 0 { + t.Errorf("parseSSEMessage() expected no remaining bytes, got %d bytes", len(remaining)) + } + } else { + if msg != nil { + t.Errorf("parseSSEMessage() expected no message but got %v", msg) + } + if len(remaining) != len(tt.input) { + t.Errorf("parseSSEMessage() expected all data as remaining, got %d bytes instead of %d", len(remaining), len(tt.input)) + } + } + }) + } +} + +// TestExtractEndpointURL tests endpoint URL extraction +func TestExtractEndpointURL(t *testing.T) { + tests := []struct { + name string + endpointData string + baseURL string + want string + wantErr bool + }{ + { + name: "full URL", + endpointData: "http://example.com/messages?session=123", + baseURL: "http://backend.com/mcp", + want: "http://example.com/messages?session=123", + wantErr: false, + }, + { + name: "path only", + endpointData: "/messages/?session_id=abc", + baseURL: "http://backend.com/mcp", + want: "http://backend.com/messages/?session_id=abc", + wantErr: false, + }, + { + name: "https base URL", + endpointData: "/sse/endpoint", + baseURL: "https://secure.backend.com:8443/api", + want: "https://secure.backend.com:8443/sse/endpoint", + wantErr: false, + }, + { + name: "path-only base URL", + endpointData: "/messages", + baseURL: "/api/v1", + want: "/messages", + wantErr: false, + }, + { + name: "path without leading slash", + endpointData: "api/v1/messages", + baseURL: "http://backend.com", + want: "http://backend.com/api/v1/messages", + wantErr: false, + }, + { + name: "path without leading slash with port", + endpointData: "sse/endpoint", + baseURL: "https://secure.backend.com:8443", + want: "https://secure.backend.com:8443/sse/endpoint", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ExtractEndpointURL(tt.endpointData, tt.baseURL) + if (err != nil) != tt.wantErr { + t.Errorf("extractEndpointURL() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("extractEndpointURL() = %v, want %v", got, tt.want) + } + }) + } +} + +// TestTransportProtocolValidation tests transport protocol validation +func TestTransportProtocolValidation(t *testing.T) { + tests := []struct { + name string + transport string + wantValid bool + }{ + { + name: "valid http transport", + transport: "http", + wantValid: true, + }, + { + name: "valid sse transport", + transport: "sse", + wantValid: true, + }, + { + name: "invalid transport", + transport: "websocket", + wantValid: false, + }, + { + name: "empty transport", + transport: "", + wantValid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := TransportProtocol(tt.transport) + isValid := transport == TransportHTTP || transport == TransportSSE + if isValid != tt.wantValid { + t.Errorf("TransportProtocol validation = %v, want %v for %s", isValid, tt.wantValid, tt.transport) + } + }) + } +} + +// TestMcpProxyServerTransport tests transport getter/setter +func TestMcpProxyServerTransport(t *testing.T) { + server := NewMcpProxyServer("test-server") + + // Test default transport + if server.GetTransport() != "" { + t.Errorf("Expected empty default transport, got %v", server.GetTransport()) + } + + // Test setting HTTP transport + server.SetTransport(TransportHTTP) + if server.GetTransport() != TransportHTTP { + t.Errorf("Expected HTTP transport, got %v", server.GetTransport()) + } + + // Test setting SSE transport + server.SetTransport(TransportSSE) + if server.GetTransport() != TransportSSE { + t.Errorf("Expected SSE transport, got %v", server.GetTransport()) + } +} + +// TestSSEMessageParsing_MultipleMessages tests parsing multiple SSE messages +func TestSSEMessageParsing_MultipleMessages(t *testing.T) { + data := []byte(`event: endpoint +data: /messages/123 + +event: message +data: {"id":1} + +: comment line +event: message +data: {"id":2} + +`) + + // First message + msg1, remaining, err := ParseSSEMessage(data) + if err != nil { + t.Fatalf("Failed to parse first message: %v", err) + } + if msg1 == nil || msg1.Event != "endpoint" || msg1.Data != "/messages/123" { + t.Errorf("First message incorrect: %+v", msg1) + } + + // Second message + msg2, remaining, err := ParseSSEMessage(remaining) + if err != nil { + t.Fatalf("Failed to parse second message: %v", err) + } + if msg2 == nil || msg2.Event != "message" || msg2.Data != `{"id":1}` { + t.Errorf("Second message incorrect: %+v", msg2) + } + + // Third message + msg3, remaining, err := ParseSSEMessage(remaining) + if err != nil { + t.Fatalf("Failed to parse third message: %v", err) + } + if msg3 == nil || msg3.Event != "message" || msg3.Data != `{"id":2}` { + t.Errorf("Third message incorrect: %+v", msg3) + } + + // Should be no more complete messages + msg4, _, err := ParseSSEMessage(remaining) + if err != nil { + t.Fatalf("Error parsing remaining data: %v", err) + } + if msg4 != nil { + t.Errorf("Expected no more messages, got: %+v", msg4) + } +} diff --git a/plugins/wasm-go/pkg/mcp/utils/json_rpc.go b/plugins/wasm-go/pkg/mcp/utils/json_rpc.go new file mode 100644 index 000000000..8f3c9bd4f --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/utils/json_rpc.go @@ -0,0 +1,209 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "fmt" + "strconv" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "google.golang.org/protobuf/proto" + + "github.com/higress-group/wasm-go/pkg/iface" + "github.com/higress-group/wasm-go/pkg/log" + pb "github.com/higress-group/wasm-go/pkg/protos" + "github.com/higress-group/wasm-go/pkg/wrapper" +) + +const ( + CtxJsonRpcID = "jsonRpcID" + CtxNeedPause = "needPause" // Context key to signal if the handler needs to pause + JError = "error" + JCode = "code" + JMessage = "message" + JResult = "result" + + ErrParseError = -32700 + ErrInvalidRequest = -32600 + ErrMethodNotFound = -32601 + ErrInvalidParams = -32602 + ErrInternalError = -32603 +) + +// JsonRpcID represents a JSON-RPC ID which can be either a string or a number +type JsonRpcID struct { + StringValue string + IntValue int64 + IsString bool +} + +// NewJsonRpcIDFromGjson creates a JsonRpcID from a gjson.Result +func NewJsonRpcIDFromGjson(result gjson.Result) JsonRpcID { + if result.Type == gjson.String { + return JsonRpcID{ + StringValue: result.String(), + IsString: true, + } + } + return JsonRpcID{ + IntValue: result.Int(), + IsString: false, + } +} + +type JsonRpcRequestHandler func(context wrapper.HttpContext, id JsonRpcID, method string, params gjson.Result, rawBody []byte) types.Action + +type JsonRpcResponseHandler func(context wrapper.HttpContext, id JsonRpcID, result gjson.Result, error gjson.Result, rawBody []byte) types.Action + +type JsonRpcMethodHandler func(context wrapper.HttpContext, id JsonRpcID, params gjson.Result) error + +type MethodHandlers map[string]JsonRpcMethodHandler + +func makeHttpResponse(ctx wrapper.HttpContext, code uint32, debugInfo string, headers [][2]string, body []byte) { + phase := ctx.GetExecutionPhase() + if phase < iface.EncodeHeader { + proxywasm.SendHttpResponseWithDetail(code, debugInfo, headers, body, -1) + return + } + if debugInfo != "" { + log.Infof("response detail info:%s", debugInfo) + } + proxywasm.RemoveHttpResponseHeader("content-length") + proxywasm.ReplaceHttpResponseHeader(":status", strconv.Itoa(int(code))) + for _, kv := range headers { + proxywasm.ReplaceHttpResponseHeader(kv[0], kv[1]) + } + if phase == iface.EncodeData { + proxywasm.ReplaceHttpResponseBody(body) + return + } + // EncodeHeader phase + args := &pb.InjectEncodedDataToFilterChainArguments{ + Body: string(body), + Endstream: true, + } + argsStr, _ := proto.Marshal(args) + _, err := proxywasm.CallForeignFunction("inject_encoded_data_to_filter_chain_on_header", argsStr) + if err != nil { + log.Warnf("call inject_encoded_data_to_filter_chain_on_header failed, err:%v, fallback to send directly", err) + proxywasm.SendHttpResponseWithDetail(code, debugInfo, headers, body, -1) + return + } +} + +func sendJsonRpcResponse(ctx wrapper.HttpContext, id JsonRpcID, extras map[string]any, debugInfo string) { + body := []byte(`{"jsonrpc": "2.0"}`) + if id.IsString { + body, _ = sjson.SetBytes(body, "id", id.StringValue) + } else { + body, _ = sjson.SetBytes(body, "id", id.IntValue) + } + for key, value := range extras { + body, _ = sjson.SetBytes(body, key, value) + } + makeHttpResponse(ctx, 200, debugInfo, [][2]string{{"Content-Type", "application/json; charset=utf-8"}}, body) +} + +func OnJsonRpcResponseSuccess(ctx wrapper.HttpContext, result map[string]any, debugInfo ...string) { + var ( + id JsonRpcID + ok bool + ) + idRaw := ctx.GetContext(CtxJsonRpcID) + if id, ok = idRaw.(JsonRpcID); !ok { + makeHttpResponse(ctx, 500, "not_found_json_rpc_id", nil, []byte("not found json rpc id")) + return + } + responseDebugInfo := "json_rpc_success" + if len(debugInfo) > 0 { + responseDebugInfo = debugInfo[0] + } + sendJsonRpcResponse(ctx, id, map[string]any{JResult: result}, responseDebugInfo) +} + +func OnJsonRpcResponseError(ctx wrapper.HttpContext, err error, errorCode int, debugInfo ...string) { + var ( + id JsonRpcID + ok bool + ) + idRaw := ctx.GetContext(CtxJsonRpcID) + if id, ok = idRaw.(JsonRpcID); !ok { + makeHttpResponse(ctx, 500, "not_found_json_rpc_id", nil, []byte("not found json rpc id")) + return + } + responseDebugInfo := fmt.Sprintf("json_rpc_error(%s)", err) + if len(debugInfo) > 0 { + responseDebugInfo = debugInfo[0] + } + sendJsonRpcResponse(ctx, id, map[string]any{JError: map[string]any{ + JMessage: err.Error(), + JCode: errorCode, + }}, responseDebugInfo) +} + +func HandleJsonRpcMethod(ctx wrapper.HttpContext, body []byte, handles MethodHandlers) types.Action { + idResult := gjson.GetBytes(body, "id") + id := NewJsonRpcIDFromGjson(idResult) + ctx.SetContext(CtxJsonRpcID, id) + method := gjson.GetBytes(body, "method").String() + params := gjson.GetBytes(body, "params") + if method != "" { + if handle, ok := handles[method]; ok { + log.Debugf("json rpc call method[%s] with params[%s]", method, params.Raw) + + // Clear pause flag before calling handler + ctx.SetContext(CtxNeedPause, false) + + err := handle(ctx, id, params) + if err != nil { + OnJsonRpcResponseError(ctx, err, ErrInvalidRequest) + return types.ActionContinue + } + + // Check if the handler set the pause flag + if needPause := ctx.GetContext(CtxNeedPause); needPause != nil && needPause.(bool) { + return types.ActionPause + } + + return types.ActionContinue + } + OnJsonRpcResponseError(ctx, fmt.Errorf("method not found:%s", method), ErrMethodNotFound) + } else { + proxywasm.SendHttpResponseWithDetail(202, "json_rpc_ack", nil, nil, -1) + } + return types.ActionContinue +} + +func HandleJsonRpcRequest(ctx wrapper.HttpContext, body []byte, handle JsonRpcRequestHandler) types.Action { + idResult := gjson.GetBytes(body, "id") + id := NewJsonRpcIDFromGjson(idResult) + ctx.SetContext(CtxJsonRpcID, id) + method := gjson.GetBytes(body, "method").String() + params := gjson.GetBytes(body, "params") + log.Debugf("json rpc call method[%s] with params[%s]", method, params.Raw) + return handle(ctx, id, method, params, body) +} + +func HandleJsonRpcResponse(ctx wrapper.HttpContext, body []byte, handle JsonRpcResponseHandler) types.Action { + idResult := gjson.GetBytes(body, "id") + id := NewJsonRpcIDFromGjson(idResult) + error := gjson.GetBytes(body, "error") + result := gjson.GetBytes(body, "result") + log.Debugf("json rpc response error[%s] result[%s]", error.Raw, result.Raw) + return handle(ctx, id, result, error, body) +} diff --git a/plugins/wasm-go/pkg/mcp/utils/json_rpc_test.go b/plugins/wasm-go/pkg/mcp/utils/json_rpc_test.go new file mode 100644 index 000000000..72aaa7a6b --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/utils/json_rpc_test.go @@ -0,0 +1,160 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/tidwall/gjson" +) + +func TestJsonRpcIDFromGjson(t *testing.T) { + tests := []struct { + name string + jsonData string + expected JsonRpcID + }{ + { + name: "integer id", + jsonData: `{"id": 123}`, + expected: JsonRpcID{ + IntValue: 123, + IsString: false, + }, + }, + { + name: "string id", + jsonData: `{"id": "abc-123"}`, + expected: JsonRpcID{ + StringValue: "abc-123", + IsString: true, + }, + }, + { + name: "float id treated as int", + jsonData: `{"id": 123.45}`, + expected: JsonRpcID{ + IntValue: 123, + IsString: false, + }, + }, + { + name: "boolean id treated as int", + jsonData: `{"id": true}`, + expected: JsonRpcID{ + IntValue: 1, + IsString: false, + }, + }, + { + name: "null id treated as int zero", + jsonData: `{"id": null}`, + expected: JsonRpcID{ + IntValue: 0, + IsString: false, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + idResult := gjson.Get(tt.jsonData, "id") + result := NewJsonRpcIDFromGjson(idResult) + + if result.IsString != tt.expected.IsString { + t.Errorf("IsString = %v, want %v", result.IsString, tt.expected.IsString) + } + + if result.IsString { + if result.StringValue != tt.expected.StringValue { + t.Errorf("StringValue = %v, want %v", result.StringValue, tt.expected.StringValue) + } + } else { + if result.IntValue != tt.expected.IntValue { + t.Errorf("IntValue = %v, want %v", result.IntValue, tt.expected.IntValue) + } + } + }) + } +} + +// Skip TestSendJsonRpcResponse because it requires proxywasm which is not available in the test environment +// This function would normally test that sendJsonRpcResponse correctly handles different ID types +func TestSendJsonRpcResponse(t *testing.T) { + t.Skip("Skipping test that requires proxywasm") +} + +func TestJsonRpcIDMarshaling(t *testing.T) { + // Test that JsonRpcID is correctly marshaled in a JSON response + + tests := []struct { + name string + id JsonRpcID + expected string + }{ + { + name: "integer id", + id: JsonRpcID{ + IntValue: 123, + IsString: false, + }, + expected: `"id":123`, + }, + { + name: "string id", + id: JsonRpcID{ + StringValue: "abc-123", + IsString: true, + }, + expected: `"id":"abc-123"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a JSON object with the ID + var jsonObj map[string]interface{} + if tt.id.IsString { + jsonObj = map[string]interface{}{ + "jsonrpc": "2.0", + "id": tt.id.StringValue, + } + } else { + jsonObj = map[string]interface{}{ + "jsonrpc": "2.0", + "id": tt.id.IntValue, + } + } + + // Marshal to JSON + body, err := json.Marshal(jsonObj) + if err != nil { + t.Errorf("Failed to marshal JSON: %v", err) + } + + // Check that the ID is correctly marshaled + if !json.Valid(body) { + t.Errorf("Invalid JSON: %s", string(body)) + } + + // Check that the ID is correctly formatted + if !strings.Contains(string(body), tt.expected) { + t.Errorf("ID not correctly formatted. Expected to contain %s, got %s", tt.expected, string(body)) + } + }) + } +} diff --git a/plugins/wasm-go/pkg/mcp/utils/log.go b/plugins/wasm-go/pkg/mcp/utils/log.go new file mode 100644 index 000000000..644b7d0e2 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/utils/log.go @@ -0,0 +1,130 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "fmt" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + + "github.com/higress-group/wasm-go/pkg/wrapper" +) + +type MCPServerLog struct { +} + +func setMCPInfo(msg string) string { + requestIDRaw, _ := proxywasm.GetProperty([]string{"x_request_id"}) + requestID := string(requestIDRaw) + if requestID == "" { + requestID = "nil" + } + mcpServerNameRaw, _ := proxywasm.GetProperty([]string{"mcp_server_name"}) + mcpServerName := string(mcpServerNameRaw) + mcpToolNameRaw, _ := proxywasm.GetProperty([]string{"mcp_tool_name"}) + mcpToolName := string(mcpToolNameRaw) + mcpInfo := mcpServerName + if mcpToolName != "" { + mcpInfo = fmt.Sprintf("%s/%s", mcpServerName, mcpToolName) + } + return fmt.Sprintf("[mcp-server] [%s] [%s] %s", mcpInfo, requestID, msg) +} + +func (l MCPServerLog) log(level wrapper.LogLevel, msg string) { + msg = setMCPInfo(msg) + switch level { + case wrapper.LogLevelTrace: + proxywasm.LogTrace(msg) + case wrapper.LogLevelDebug: + proxywasm.LogDebug(msg) + case wrapper.LogLevelInfo: + proxywasm.LogInfo(msg) + case wrapper.LogLevelWarn: + proxywasm.LogWarn(msg) + case wrapper.LogLevelError: + proxywasm.LogError(msg) + case wrapper.LogLevelCritical: + proxywasm.LogCritical(msg) + } +} + +func (l MCPServerLog) logFormat(level wrapper.LogLevel, format string, args ...interface{}) { + format = setMCPInfo(format) + switch level { + case wrapper.LogLevelTrace: + proxywasm.LogTracef(format, args...) + case wrapper.LogLevelDebug: + proxywasm.LogDebugf(format, args...) + case wrapper.LogLevelInfo: + proxywasm.LogInfof(format, args...) + case wrapper.LogLevelWarn: + proxywasm.LogWarnf(format, args...) + case wrapper.LogLevelError: + proxywasm.LogErrorf(format, args...) + case wrapper.LogLevelCritical: + proxywasm.LogCriticalf(format, args...) + } +} + +func (l MCPServerLog) Trace(msg string) { + l.log(wrapper.LogLevelTrace, msg) +} + +func (l MCPServerLog) Tracef(format string, args ...interface{}) { + l.logFormat(wrapper.LogLevelTrace, format, args...) +} + +func (l MCPServerLog) Debug(msg string) { + l.log(wrapper.LogLevelDebug, msg) +} + +func (l MCPServerLog) Debugf(format string, args ...interface{}) { + l.logFormat(wrapper.LogLevelDebug, format, args...) +} + +func (l MCPServerLog) Info(msg string) { + l.log(wrapper.LogLevelInfo, msg) +} + +func (l MCPServerLog) Infof(format string, args ...interface{}) { + l.logFormat(wrapper.LogLevelInfo, format, args...) +} + +func (l MCPServerLog) Warn(msg string) { + l.log(wrapper.LogLevelWarn, msg) +} + +func (l MCPServerLog) Warnf(format string, args ...interface{}) { + l.logFormat(wrapper.LogLevelWarn, format, args...) +} + +func (l MCPServerLog) Error(msg string) { + l.log(wrapper.LogLevelError, msg) +} + +func (l MCPServerLog) Errorf(format string, args ...interface{}) { + l.logFormat(wrapper.LogLevelError, format, args...) +} + +func (l MCPServerLog) Critical(msg string) { + l.log(wrapper.LogLevelCritical, msg) +} + +func (l MCPServerLog) Criticalf(format string, args ...interface{}) { + l.logFormat(wrapper.LogLevelCritical, format, args...) +} + +func (l MCPServerLog) ResetID(pluginID string) { +} diff --git a/plugins/wasm-go/pkg/mcp/utils/mcp_rpc.go b/plugins/wasm-go/pkg/mcp/utils/mcp_rpc.go new file mode 100644 index 000000000..d8f3e6663 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/utils/mcp_rpc.go @@ -0,0 +1,117 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "encoding/base64" + "encoding/json" + "fmt" + + "github.com/higress-group/wasm-go/pkg/wrapper" +) + +func OnMCPResponseSuccess(ctx wrapper.HttpContext, result map[string]any, debugInfo string) { + OnJsonRpcResponseSuccess(ctx, result, debugInfo) + // TODO: support pub to redis when use POST + SSE +} + +func OnMCPResponseError(ctx wrapper.HttpContext, err error, code int, debugInfo string) { + OnJsonRpcResponseError(ctx, err, code, debugInfo) + // TODO: support pub to redis when use POST + SSE +} + +func OnMCPToolCallSuccess(ctx wrapper.HttpContext, content []map[string]any, debugInfo string) { + OnMCPResponseSuccess(ctx, map[string]any{ + "content": content, + "isError": false, + }, debugInfo) +} + +// OnMCPToolCallSuccessWithStructuredContent sends a successful MCP tool response with structured content +// According to MCP spec, structuredContent is a field in tool results, not a capability +func OnMCPToolCallSuccessWithStructuredContent(ctx wrapper.HttpContext, content []map[string]any, structuredContent json.RawMessage, debugInfo string) { + response := map[string]any{ + "content": content, + "isError": false, + } + if structuredContent != nil && len(structuredContent) > 0 { + response["structuredContent"] = structuredContent + } + OnMCPResponseSuccess(ctx, response, debugInfo) +} + +func OnMCPToolCallError(ctx wrapper.HttpContext, err error, debugInfo ...string) { + responseDebugInfo := fmt.Sprintf("mcp:tools/call:error(%s)", err) + if len(debugInfo) > 0 { + responseDebugInfo = debugInfo[0] + } + OnMCPResponseSuccess(ctx, map[string]any{ + "content": []map[string]any{ + { + "type": "text", + "text": err.Error(), + }, + }, + "isError": true, + }, responseDebugInfo) +} + +func SendMCPToolTextResult(ctx wrapper.HttpContext, result string, debugInfo ...string) { + responseDebugInfo := "mcp:tools/call::result" + if len(debugInfo) > 0 { + responseDebugInfo = debugInfo[0] + } + OnMCPToolCallSuccess(ctx, []map[string]any{ + { + "type": "text", + "text": result, + }, + }, responseDebugInfo) +} + +func SendMCPToolImageResult(ctx wrapper.HttpContext, image []byte, contentType string, debugInfo ...string) { + responseDebugInfo := "mcp:tools/call::result" + if len(debugInfo) > 0 { + responseDebugInfo = debugInfo[0] + } + + content := []map[string]any{ + { + "type": "image", + "data": base64.StdEncoding.EncodeToString(image), + "mimeType": contentType, + }, + } + + // Use traditional response format since no structured data is provided + OnMCPToolCallSuccess(ctx, content, responseDebugInfo) +} + +// SendMCPToolTextResultWithStructuredContent sends a tool result with both text content and structured content +// According to MCP spec, for backward compatibility, tools that return structured content +// SHOULD also return the serialized JSON in a TextContent block +func SendMCPToolTextResultWithStructuredContent(ctx wrapper.HttpContext, textResult string, structuredContent json.RawMessage, debugInfo ...string) { + responseDebugInfo := "mcp:tools/call::result" + if len(debugInfo) > 0 { + responseDebugInfo = debugInfo[0] + } + content := []map[string]any{ + { + "type": "text", + "text": textResult, + }, + } + OnMCPToolCallSuccessWithStructuredContent(ctx, content, structuredContent, responseDebugInfo) +} diff --git a/plugins/wasm-go/pkg/mcp/utils/session.go b/plugins/wasm-go/pkg/mcp/utils/session.go new file mode 100644 index 000000000..7f3bbafd7 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/utils/session.go @@ -0,0 +1,51 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "net/url" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + + "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" +) + +func IsStatefulSession(ctx wrapper.HttpContext) bool { + parse, err := url.Parse(ctx.Path()) + if err != nil { + log.Errorf("failed to parse request path: %v", err) + return false + } + query, err := url.ParseQuery(parse.RawQuery) + if err != nil { + log.Errorf("failed to parse query params: %v", err) + return false + } + // Protocol version: 2024-11-05 + if query.Get("sessionId") != "" { + return true + } + // Protocol version: 2025-03-26 + sessionHeader, err := proxywasm.GetHttpRequestHeader("mcp-session-id") + if err != nil { + log.Errorf("failed to get request header: %v", err) + return false + } + if sessionHeader != "" { + return true + } + return false +} diff --git a/plugins/wasm-go/pkg/mcp/validator/README.md b/plugins/wasm-go/pkg/mcp/validator/README.md new file mode 100644 index 000000000..701e64681 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/validator/README.md @@ -0,0 +1,164 @@ +# MCP Configuration Validator + +This package provides a configuration validation library for MCP (Model Context Protocol) server configurations. It allows you to validate MCP configurations without requiring the full runtime environment, making it perfect for use in management platforms and frontend applications. + +## Features + +- **Lightweight Validation**: Validates configuration structure and syntax without requiring actual server instances +- **REST Tool Support**: Full validation for REST-based MCP tools including request/response templates +- **ToolSet Support**: Validates composed server configurations (toolSets) +- **Pre-registered Server Handling**: Gracefully handles pre-registered Go-based servers by skipping their validation +- **Minimal Dependencies**: Reuses the core parsing logic from the main MCP server implementation + +## Usage + +### Basic Validation + +```go +import "github.com/higress-group/wasm-go/pkg/mcp/validator" + +// Validate a configuration YAML string +yamlConfig := ` +server: + name: my-server + config: + apiKey: secret +tools: + - name: my-tool + description: A sample tool + args: + - name: input + type: string + required: true + requestTemplate: + url: https://api.example.com/endpoint + method: POST + responseTemplate: + body: "{{.}}" +` +result, err := validator.ValidateConfigYAML(yamlConfig) +if err != nil { + // Handle error + return +} + +if result.IsValid { + fmt.Printf("Configuration is valid for server: %s\n", result.ServerName) + if result.IsComposed { + fmt.Println("This is a composed server (toolSet)") + } else { + fmt.Println("This is a single server") + } +} else { + fmt.Printf("Configuration is invalid: %v\n", result.Error) +} +``` + +## Supported Configuration Types + +### 1. REST Server Configuration + +Validates REST-based MCP servers with tools, security schemes, and templates: + +```yaml +server: + name: weather-api + config: + apiKey: your-api-key + securitySchemes: + - id: bearer-auth + type: http + scheme: bearer +tools: + - name: get_weather + description: Get current weather + args: + - name: city + type: string + required: true + requestTemplate: + url: "https://api.weather.com/v1/current?city={{.args.city}}" + method: GET + responseTemplate: + body: "Weather: {{.temperature}}°C" +``` + +### 2. ToolSet Configuration (Composed Server) + +Validates composed servers that aggregate tools from multiple servers: + +```yaml +toolSet: + name: ai-assistant-tools + serverTools: + - serverName: weather-api + tools: ["get_weather", "get_forecast"] + - serverName: search-api + tools: ["web_search"] +``` + +### 3. Pre-registered Go-based Server + +For pre-registered Go-based servers, validation focuses on basic structure and skips server instance validation: + +```yaml +server: + name: custom-go-server + config: + database_url: "postgres://localhost:5432/mydb" +allowTools: ["query_database"] +``` + +## Validation Result + +The `ValidationResult` struct provides detailed information about the validation: + +```go +type ValidationResult struct { + IsValid bool `json:"isValid"` // Whether the configuration is valid + Error error `json:"error"` // Validation error if any + ServerName string `json:"serverName"` // Parsed server name + IsComposed bool `json:"isComposed"` // Whether it's a composed server +} +``` + +## Architecture + +The validator reuses the core parsing logic from the main MCP server implementation through dependency injection: + +- **parseConfigCore**: Core parsing logic with configurable dependencies +- **ConfigOptions**: Dependency config options +- **SkipPreRegisteredServers**: Flag to skip validation of pre-registered Go servers + +This approach ensures: +- **Consistency**: Same validation logic as runtime +- **Maintainability**: Single source of truth for parsing logic +- **Minimal Code Duplication**: Reuses existing implementation + +## Testing + +Run the tests to verify the validator works correctly: + +```bash +cd pkg/mcp/validator +go test -v +``` + +The test suite covers: +- REST server configuration validation +- ToolSet configuration validation +- Pre-registered server handling +- Invalid configuration detection +- Error cases + +## Error Handling + +The validator provides detailed error messages for common configuration issues: + +- Missing required fields (e.g., `server.name`) +- Invalid JSON structure +- Malformed tool definitions +- Invalid template syntax +- Missing server or toolSet configuration + +These errors help developers quickly identify and fix configuration problems before deployment. diff --git a/plugins/wasm-go/pkg/mcp/validator/config_validator.go b/plugins/wasm-go/pkg/mcp/validator/config_validator.go new file mode 100644 index 000000000..0ee45b086 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/validator/config_validator.go @@ -0,0 +1,129 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validator + +import ( + "encoding/json" + "fmt" + "os" + + "github.com/tidwall/gjson" + "gopkg.in/yaml.v3" + + "github.com/higress-group/wasm-go/pkg/log" + "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server" +) + +// validatorLogger is a simple logger implementation for validation mode +type validatorLogger struct{} + +func (l *validatorLogger) Trace(msg string) { fmt.Fprintf(os.Stderr, "[TRACE] %s\n", msg) } +func (l *validatorLogger) Tracef(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, "[TRACE] "+format+"\n", args...) +} +func (l *validatorLogger) Debug(msg string) { fmt.Fprintf(os.Stderr, "[DEBUG] %s\n", msg) } +func (l *validatorLogger) Debugf(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, "[DEBUG] "+format+"\n", args...) +} +func (l *validatorLogger) Info(msg string) { fmt.Fprintf(os.Stderr, "[INFO] %s\n", msg) } +func (l *validatorLogger) Infof(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, "[INFO] "+format+"\n", args...) +} +func (l *validatorLogger) Warn(msg string) { fmt.Fprintf(os.Stderr, "[WARN] %s\n", msg) } +func (l *validatorLogger) Warnf(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, "[WARN] "+format+"\n", args...) +} +func (l *validatorLogger) Error(msg string) { fmt.Fprintf(os.Stderr, "[ERROR] %s\n", msg) } +func (l *validatorLogger) Errorf(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, "[ERROR] "+format+"\n", args...) +} +func (l *validatorLogger) Critical(msg string) { fmt.Fprintf(os.Stderr, "[CRITICAL] %s\n", msg) } +func (l *validatorLogger) Criticalf(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, "[CRITICAL] "+format+"\n", args...) +} +func (l *validatorLogger) ResetID(pluginID string) {} + +// init initializes the validator package +func init() { + // Set a custom logger for validation mode to prevent panics + log.SetPluginLog(&validatorLogger{}) +} + +// ValidationResult contains the result of configuration validation +type ValidationResult struct { + IsValid bool `json:"isValid"` + Error error `json:"error,omitempty"` + ServerName string `json:"serverName,omitempty"` + IsComposed bool `json:"isComposed"` +} + +// ValidateConfig validates MCP configuration +// This function focuses on validating REST tools and toolSet configurations +// It skips validation for pre-registered Go-based servers +func ValidateConfig(configJSON string) (*ValidationResult, error) { + // Create empty dependencies for validation mode + // We skip pre-registered servers validation by setting SkipPreRegisteredServers to true + toolRegistry := &server.GlobalToolRegistry{} + toolRegistry.Initialize() // Initialize the registry to prevent nil map assignment panic + + deps := &server.ConfigOptions{ + Servers: make(map[string]server.Server), // Empty servers map + ToolRegistry: toolRegistry, // Initialized registry + SkipPreRegisteredServers: true, // Skip pre-registered servers + } + + // Call core parsing logic for validation + configGjson := gjson.Parse(configJSON) + mockConfig := &server.McpServerConfig{} + + err := server.ParseConfigCore(configGjson, mockConfig, deps) + + result := &ValidationResult{ + IsValid: err == nil, + Error: err, + } + + if err == nil { + result.ServerName = mockConfig.GetServerName() + result.IsComposed = mockConfig.GetIsComposed() + } + + return result, nil +} + +// ValidateConfigYAML validates MCP configuration from YAML format +// This function converts YAML to JSON first, then validates using the same logic +func ValidateConfigYAML(configYAML string) (*ValidationResult, error) { + // Parse YAML into a generic interface + var yamlData interface{} + if err := yaml.Unmarshal([]byte(configYAML), &yamlData); err != nil { + return &ValidationResult{ + IsValid: false, + Error: fmt.Errorf("failed to parse YAML: %v", err), + }, nil + } + + // Convert to JSON + jsonBytes, err := json.Marshal(yamlData) + if err != nil { + return &ValidationResult{ + IsValid: false, + Error: fmt.Errorf("failed to convert YAML to JSON: %v", err), + }, nil + } + + // Use the existing JSON validation logic + return ValidateConfig(string(jsonBytes)) +} diff --git a/plugins/wasm-go/pkg/mcp/validator/config_validator_test.go b/plugins/wasm-go/pkg/mcp/validator/config_validator_test.go new file mode 100644 index 000000000..0740b27d2 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/validator/config_validator_test.go @@ -0,0 +1,266 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validator + +import ( + "testing" +) + +func TestValidateConfig_RestServer(t *testing.T) { + // Test REST server configuration + configJSON := `{ + "server": { + "name": "test-rest-server", + "config": {} + }, + "tools": [ + { + "name": "test-tool", + "description": "A test tool", + "args": [ + { + "name": "input", + "description": "Input parameter", + "type": "string", + "required": true + } + ], + "requestTemplate": { + "url": "https://api.example.com/test", + "method": "POST" + }, + "responseTemplate": { + "body": "{{.}}" + } + } + ] + }` + + result, err := ValidateConfig(configJSON) + if err != nil { + t.Fatalf("ValidateConfig returned error: %v", err) + } + + if !result.IsValid { + t.Errorf("Expected config to be valid, but got invalid with error: %v", result.Error) + } + + if result.ServerName != "test-rest-server" { + t.Errorf("Expected server name 'test-rest-server', got '%s'", result.ServerName) + } + + if result.IsComposed { + t.Errorf("Expected single server (not composed), but got composed") + } +} + +func TestValidateConfig_ToolSet(t *testing.T) { + // Test toolSet configuration + configJSON := `{ + "toolSet": { + "name": "test-toolset", + "serverTools": [ + { + "serverName": "server1", + "tools": ["tool1", "tool2"] + }, + { + "serverName": "server2", + "tools": ["tool3"] + } + ] + } + }` + + result, err := ValidateConfig(configJSON) + if err != nil { + t.Fatalf("ValidateConfig returned error: %v", err) + } + + if !result.IsValid { + t.Errorf("Expected config to be valid, but got invalid with error: %v", result.Error) + } + + if result.ServerName != "test-toolset" { + t.Errorf("Expected server name 'test-toolset', got '%s'", result.ServerName) + } + + if !result.IsComposed { + t.Errorf("Expected composed server, but got single server") + } +} + +func TestValidateConfig_PreRegisteredServer(t *testing.T) { + // Test pre-registered Go-based server configuration (should be skipped in validation) + configJSON := `{ + "server": { + "name": "some-go-server", + "config": { + "someParam": "value" + } + } + }` + + result, err := ValidateConfig(configJSON) + if err != nil { + t.Fatalf("ValidateConfig returned error: %v", err) + } + + if !result.IsValid { + t.Errorf("Expected config to be valid (pre-registered servers should be skipped), but got invalid with error: %v", result.Error) + } + + if result.ServerName != "some-go-server" { + t.Errorf("Expected server name 'some-go-server', got '%s'", result.ServerName) + } + + if result.IsComposed { + t.Errorf("Expected single server (not composed), but got composed") + } +} + +func TestValidateConfig_InvalidConfig(t *testing.T) { + // Test invalid configuration (missing required fields) + configJSON := `{ + "server": { + "config": {} + } + }` + + result, err := ValidateConfig(configJSON) + if err != nil { + t.Fatalf("ValidateConfig returned error: %v", err) + } + + if result.IsValid { + t.Errorf("Expected config to be invalid, but got valid") + } + + if result.Error == nil { + t.Errorf("Expected validation error, but got nil") + } +} + +func TestValidateConfig_MissingServerAndToolSet(t *testing.T) { + // Test configuration missing both server and toolSet + configJSON := `{ + "allowTools": ["tool1"] + }` + + result, err := ValidateConfig(configJSON) + if err != nil { + t.Fatalf("ValidateConfig returned error: %v", err) + } + + if result.IsValid { + t.Errorf("Expected config to be invalid, but got valid") + } + + if result.Error == nil { + t.Errorf("Expected validation error, but got nil") + } +} + +func TestValidateConfigYAML_RestServer(t *testing.T) { + // Test REST server configuration in YAML format + configYAML := ` +server: + name: test-rest-server-yaml + config: {} +tools: + - name: test-tool-yaml + description: A test tool from YAML + args: + - name: input + description: Input parameter + type: string + required: true + requestTemplate: + url: https://api.example.com/test + method: POST + responseTemplate: + body: "{{.}}" +` + + result, err := ValidateConfigYAML(configYAML) + if err != nil { + t.Fatalf("ValidateConfigYAML returned error: %v", err) + } + + if !result.IsValid { + t.Errorf("Expected config to be valid, but got invalid with error: %v", result.Error) + } + + if result.ServerName != "test-rest-server-yaml" { + t.Errorf("Expected server name 'test-rest-server-yaml', got '%s'", result.ServerName) + } + + if result.IsComposed { + t.Errorf("Expected single server (not composed), but got composed") + } +} + +func TestValidateConfigYAML_ToolSet(t *testing.T) { + // Test toolSet configuration in YAML format + configYAML := ` +toolSet: + name: test-toolset-yaml + serverTools: + - serverName: server1 + tools: ["tool1", "tool2"] + - serverName: server2 + tools: ["tool3"] +` + + result, err := ValidateConfigYAML(configYAML) + if err != nil { + t.Fatalf("ValidateConfigYAML returned error: %v", err) + } + + if !result.IsValid { + t.Errorf("Expected config to be valid, but got invalid with error: %v", result.Error) + } + + if result.ServerName != "test-toolset-yaml" { + t.Errorf("Expected server name 'test-toolset-yaml', got '%s'", result.ServerName) + } + + if !result.IsComposed { + t.Errorf("Expected composed server, but got single server") + } +} + +func TestValidateConfigYAML_InvalidYAML(t *testing.T) { + // Test invalid YAML syntax + configYAML := ` +server: + name: test-server + config: { + invalid yaml syntax here +` + + result, err := ValidateConfigYAML(configYAML) + if err != nil { + t.Fatalf("ValidateConfigYAML returned error: %v", err) + } + + if result.IsValid { + t.Errorf("Expected config to be invalid due to YAML syntax error, but got valid") + } + + if result.Error == nil { + t.Errorf("Expected YAML parsing error, but got nil") + } +} diff --git a/plugins/wasm-go/pkg/mcp/validator/example_usage.go b/plugins/wasm-go/pkg/mcp/validator/example_usage.go new file mode 100644 index 000000000..fcb58e9b0 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/validator/example_usage.go @@ -0,0 +1,250 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validator + +import ( + "encoding/json" + "fmt" +) + +// ExampleUsage demonstrates how to use the ValidateConfig function +func ExampleUsage() { + // Example 1: REST server configuration + restServerConfig := `{ + "server": { + "name": "weather-api", + "config": { + "apiKey": "your-api-key" + } + }, + "tools": [ + { + "name": "get_weather", + "description": "Get current weather for a city", + "args": [ + { + "name": "city", + "description": "City name", + "type": "string", + "required": true + }, + { + "name": "units", + "description": "Temperature units", + "type": "string", + "enum": ["celsius", "fahrenheit"], + "default": "celsius" + } + ], + "requestTemplate": { + "url": "https://api.weather.com/v1/current?city={{.args.city}}&units={{.args.units}}", + "method": "GET", + "headers": [ + { + "key": "Authorization", + "value": "Bearer {{.config.apiKey}}" + } + ] + }, + "responseTemplate": { + "body": "Current weather in {{.args.city}}: {{.temperature}}°{{.args.units}}" + } + } + ], + "allowTools": ["get_weather"] + }` + + result, err := ValidateConfig(restServerConfig) + if err != nil { + fmt.Printf("Error validating REST server config: %v\n", err) + return + } + + fmt.Printf("REST Server Config Validation:\n") + fmt.Printf(" Valid: %t\n", result.IsValid) + fmt.Printf(" Server Name: %s\n", result.ServerName) + fmt.Printf(" Is Composed: %t\n", result.IsComposed) + if result.Error != nil { + fmt.Printf(" Error: %v\n", result.Error) + } + fmt.Println() + + // Example 2: ToolSet configuration + toolSetConfig := `{ + "toolSet": { + "name": "ai-assistant-tools", + "serverTools": [ + { + "serverName": "weather-api", + "tools": ["get_weather", "get_forecast"] + }, + { + "serverName": "search-api", + "tools": ["web_search", "image_search"] + } + ] + }, + "allowTools": ["weather-api/get_weather", "search-api/web_search"] + }` + + result, err = ValidateConfig(toolSetConfig) + if err != nil { + fmt.Printf("Error validating toolSet config: %v\n", err) + return + } + + fmt.Printf("ToolSet Config Validation:\n") + fmt.Printf(" Valid: %t\n", result.IsValid) + fmt.Printf(" Server Name: %s\n", result.ServerName) + fmt.Printf(" Is Composed: %t\n", result.IsComposed) + if result.Error != nil { + fmt.Printf(" Error: %v\n", result.Error) + } + fmt.Println() + + // Example 3: Pre-registered Go-based server (validation skipped) + goServerConfig := `{ + "server": { + "name": "custom-go-server", + "config": { + "database_url": "postgres://localhost:5432/mydb", + "max_connections": 10 + } + }, + "allowTools": ["query_database", "update_record"] + }` + + result, err = ValidateConfig(goServerConfig) + if err != nil { + fmt.Printf("Error validating Go server config: %v\n", err) + return + } + + fmt.Printf("Go Server Config Validation (skipped):\n") + fmt.Printf(" Valid: %t\n", result.IsValid) + fmt.Printf(" Server Name: %s\n", result.ServerName) + fmt.Printf(" Is Composed: %t\n", result.IsComposed) + if result.Error != nil { + fmt.Printf(" Error: %v\n", result.Error) + } + fmt.Println() + + // Example 4: Invalid configuration + invalidConfig := `{ + "server": { + "config": {} + } + }` + + result, err = ValidateConfig(invalidConfig) + if err != nil { + fmt.Printf("Error validating invalid config: %v\n", err) + return + } + + fmt.Printf("Invalid Config Validation:\n") + fmt.Printf(" Valid: %t\n", result.IsValid) + if result.Error != nil { + fmt.Printf(" Error: %v\n", result.Error) + } +} + +// ValidateConfigFromBytes validates configuration from byte array +func ValidateConfigFromBytes(configBytes []byte) (*ValidationResult, error) { + return ValidateConfig(string(configBytes)) +} + +// ValidateConfigFromMap validates configuration from a map +func ValidateConfigFromMap(configMap map[string]interface{}) (*ValidationResult, error) { + configBytes, err := json.Marshal(configMap) + if err != nil { + return nil, fmt.Errorf("failed to marshal config map: %v", err) + } + return ValidateConfig(string(configBytes)) +} + +// ExampleYAMLUsage demonstrates how to use the ValidateConfigYAML function +func ExampleYAMLUsage() { + // Example YAML configuration for REST server + yamlConfig := ` +server: + name: weather-api-yaml + config: + apiKey: your-api-key +tools: + - name: get_weather + description: Get current weather for a city + args: + - name: city + description: City name + type: string + required: true + - name: units + description: Temperature units + type: string + enum: ["celsius", "fahrenheit"] + default: celsius + requestTemplate: + url: "https://api.weather.com/v1/current?city={{.args.city}}&units={{.args.units}}" + method: GET + headers: + - key: Authorization + value: "Bearer {{.config.apiKey}}" + responseTemplate: + body: "Current weather in {{.args.city}}: {{.temperature}}°{{.args.units}}" +allowTools: ["get_weather"] +` + + result, err := ValidateConfigYAML(yamlConfig) + if err != nil { + fmt.Printf("Error validating YAML config: %v\n", err) + return + } + + fmt.Printf("YAML Config Validation:\n") + fmt.Printf(" Valid: %t\n", result.IsValid) + fmt.Printf(" Server Name: %s\n", result.ServerName) + fmt.Printf(" Is Composed: %t\n", result.IsComposed) + if result.Error != nil { + fmt.Printf(" Error: %v\n", result.Error) + } + fmt.Println() + + // Example YAML configuration for ToolSet + yamlToolSetConfig := ` +toolSet: + name: ai-assistant-tools-yaml + serverTools: + - serverName: weather-api + tools: ["get_weather", "get_forecast"] + - serverName: search-api + tools: ["web_search", "image_search"] +allowTools: ["weather-api/get_weather", "search-api/web_search"] +` + + result, err = ValidateConfigYAML(yamlToolSetConfig) + if err != nil { + fmt.Printf("Error validating YAML toolSet config: %v\n", err) + return + } + + fmt.Printf("YAML ToolSet Config Validation:\n") + fmt.Printf(" Valid: %t\n", result.IsValid) + fmt.Printf(" Server Name: %s\n", result.ServerName) + fmt.Printf(" Is Composed: %t\n", result.IsComposed) + if result.Error != nil { + fmt.Printf(" Error: %v\n", result.Error) + } +} diff --git a/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go b/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go deleted file mode 100644 index e797394b5..000000000 --- a/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright (c) 2022 Alibaba Group Holding Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wrapper - -import ( - "fmt" - "strings" - - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" -) - -type Cluster interface { - ClusterName() string - HostName() string -} - -type RouteCluster struct { - Host string -} - -func (c RouteCluster) ClusterName() string { - routeName, err := proxywasm.GetProperty([]string{"cluster_name"}) - if err != nil { - proxywasm.LogErrorf("get route cluster failed, err:%v", err) - } - return string(routeName) -} - -func (c RouteCluster) HostName() string { - if c.Host != "" { - return c.Host - } - return GetRequestHost() -} - -type TargetCluster struct { - Host string - Cluster string -} - -func (c TargetCluster) ClusterName() string { - return c.Cluster -} - -func (c TargetCluster) HostName() string { - return c.Host -} - -type K8sCluster struct { - ServiceName string - Namespace string - Port int64 - Version string - Host string -} - -func (c K8sCluster) ClusterName() string { - namespace := "default" - if c.Namespace != "" { - namespace = c.Namespace - } - return fmt.Sprintf("outbound|%d|%s|%s.%s.svc.cluster.local", - c.Port, c.Version, c.ServiceName, namespace) -} - -func (c K8sCluster) HostName() string { - if c.Host != "" { - return c.Host - } - return fmt.Sprintf("%s.%s.svc.cluster.local", c.ServiceName, c.Namespace) -} - -type NacosCluster struct { - ServiceName string - // use DEFAULT-GROUP by default - Group string - NamespaceID string - Port int64 - // set true if use edas/sae registry - IsExtRegistry bool - Version string - Host string -} - -func (c NacosCluster) ClusterName() string { - group := "DEFAULT-GROUP" - if c.Group != "" { - group = strings.ReplaceAll(c.Group, "_", "-") - } - tail := "nacos" - if c.IsExtRegistry { - tail += "-ext" - } - return fmt.Sprintf("outbound|%d|%s|%s.%s.%s.%s", - c.Port, c.Version, c.ServiceName, group, c.NamespaceID, tail) -} - -func (c NacosCluster) HostName() string { - if c.Host != "" { - return c.Host - } - return c.ServiceName -} - -type StaticIpCluster struct { - ServiceName string - Port int64 - Host string -} - -func (c StaticIpCluster) ClusterName() string { - return fmt.Sprintf("outbound|%d||%s.static", c.Port, c.ServiceName) -} - -func (c StaticIpCluster) HostName() string { - if c.Host != "" { - return c.Host - } - return c.ServiceName -} - -type DnsCluster struct { - ServiceName string - Domain string - Port int64 -} - -func (c DnsCluster) ClusterName() string { - return fmt.Sprintf("outbound|%d||%s.dns", c.Port, c.ServiceName) -} - -func (c DnsCluster) HostName() string { - return c.Domain -} - -type ConsulCluster struct { - ServiceName string - Datacenter string - Port int64 - Host string -} - -func (c ConsulCluster) ClusterName() string { - tail := "consul" - return fmt.Sprintf("outbound|%d||%s.%s.%s", - c.Port, c.ServiceName, c.Datacenter, tail) -} - -func (c ConsulCluster) HostName() string { - if c.Host != "" { - return c.Host - } - return c.ServiceName -} - -type FQDNCluster struct { - FQDN string - Host string - Port int64 -} - -func (c FQDNCluster) ClusterName() string { - return fmt.Sprintf("outbound|%d||%s", c.Port, c.FQDN) -} - -func (c FQDNCluster) HostName() string { - if c.Host != "" { - return c.Host - } - return c.FQDN -} diff --git a/plugins/wasm-go/pkg/wrapper/cluster_wrapper_test.go b/plugins/wasm-go/pkg/wrapper/cluster_wrapper_test.go deleted file mode 100644 index 01823851b..000000000 --- a/plugins/wasm-go/pkg/wrapper/cluster_wrapper_test.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright (c) 2022 Alibaba Group Holding Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wrapper - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestClusterAndHost(t *testing.T) { - cases := []struct { - name string - cluster Cluster - expectCluster string - expectHost string - }{ - { - name: "k8s", - cluster: K8sCluster{ - ServiceName: "foo", - Namespace: "bar", - Port: 8080, - Version: "1.0", - }, - expectCluster: "outbound|8080|1.0|foo.bar.svc.cluster.local", - expectHost: "foo.bar.svc.cluster.local", - }, - { - name: "k8s default", - cluster: K8sCluster{ - ServiceName: "foo", - Port: 8080, - Host: "www.example.com", - }, - expectCluster: "outbound|8080||foo.default.svc.cluster.local", - expectHost: "www.example.com", - }, - { - name: "nacos", - cluster: NacosCluster{ - ServiceName: "foo", - Group: "DEFAULT_GROUP", - NamespaceID: "xxxx", - Port: 8080, - Version: "1.0", - }, - expectCluster: "outbound|8080|1.0|foo.DEFAULT-GROUP.xxxx.nacos", - expectHost: "foo", - }, - { - name: "nacos ext", - cluster: NacosCluster{ - ServiceName: "foo", - NamespaceID: "xxxx", - Port: 8080, - IsExtRegistry: true, - Host: "www.test.com", - }, - expectCluster: "outbound|8080||foo.DEFAULT-GROUP.xxxx.nacos-ext", - expectHost: "www.test.com", - }, - { - name: "static", - cluster: StaticIpCluster{ - ServiceName: "foo", - Port: 8080, - Host: "www.test.com", - }, - expectCluster: "outbound|8080||foo.static", - expectHost: "www.test.com", - }, - { - name: "dns", - cluster: DnsCluster{ - ServiceName: "foo", - Port: 8080, - Domain: "www.test.com", - }, - expectCluster: "outbound|8080||foo.dns", - expectHost: "www.test.com", - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - assert.Equal(t, c.expectCluster, c.cluster.ClusterName()) - assert.Equal(t, c.expectHost, c.cluster.HostName()) - }) - } -} diff --git a/plugins/wasm-go/pkg/wrapper/http_wrapper.go b/plugins/wasm-go/pkg/wrapper/http_wrapper.go deleted file mode 100644 index 85220834c..000000000 --- a/plugins/wasm-go/pkg/wrapper/http_wrapper.go +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright (c) 2022 Alibaba Group Holding Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wrapper - -import ( - "fmt" - "net/http" - "net/url" - "strconv" - "strings" - - "github.com/google/uuid" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" -) - -type ResponseCallback func(statusCode int, responseHeaders http.Header, responseBody []byte) - -type HttpClient interface { - Get(rawURL string, headers [][2]string, cb ResponseCallback, timeoutMillisecond ...uint32) error - Head(rawURL string, headers [][2]string, cb ResponseCallback, timeoutMillisecond ...uint32) error - Options(rawURL string, headers [][2]string, cb ResponseCallback, timeoutMillisecond ...uint32) error - Post(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error - Put(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error - Patch(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error - Delete(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error - Connect(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error - Trace(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error - Call(method, rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error - ClusterName() string -} - -type ClusterClient[C Cluster] struct { - cluster C -} - -func NewClusterClient[C Cluster](cluster C) *ClusterClient[C] { - return &ClusterClient[C]{cluster: cluster} -} - -func (c ClusterClient[C]) Get(rawURL string, headers [][2]string, cb ResponseCallback, timeoutMillisecond ...uint32) error { - return HttpCall(c.cluster, http.MethodGet, rawURL, headers, nil, cb, timeoutMillisecond...) -} -func (c ClusterClient[C]) Head(rawURL string, headers [][2]string, cb ResponseCallback, timeoutMillisecond ...uint32) error { - return HttpCall(c.cluster, http.MethodHead, rawURL, headers, nil, cb, timeoutMillisecond...) -} -func (c ClusterClient[C]) Options(rawURL string, headers [][2]string, cb ResponseCallback, timeoutMillisecond ...uint32) error { - return HttpCall(c.cluster, http.MethodOptions, rawURL, headers, nil, cb, timeoutMillisecond...) -} -func (c ClusterClient[C]) Post(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error { - return HttpCall(c.cluster, http.MethodPost, rawURL, headers, body, cb, timeoutMillisecond...) -} -func (c ClusterClient[C]) Put(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error { - return HttpCall(c.cluster, http.MethodPut, rawURL, headers, body, cb, timeoutMillisecond...) -} -func (c ClusterClient[C]) Patch(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error { - return HttpCall(c.cluster, http.MethodPatch, rawURL, headers, body, cb, timeoutMillisecond...) -} -func (c ClusterClient[C]) Delete(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error { - return HttpCall(c.cluster, http.MethodDelete, rawURL, headers, body, cb, timeoutMillisecond...) -} -func (c ClusterClient[C]) Connect(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error { - return HttpCall(c.cluster, http.MethodConnect, rawURL, headers, body, cb, timeoutMillisecond...) -} -func (c ClusterClient[C]) Trace(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error { - return HttpCall(c.cluster, http.MethodTrace, rawURL, headers, body, cb, timeoutMillisecond...) -} - -func (c ClusterClient[C]) Call(method, rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error { - return HttpCall(c.cluster, method, rawURL, headers, body, cb, timeoutMillisecond...) -} - -func (c ClusterClient[C]) ClusterName() string { return c.cluster.ClusterName() } - -func HttpCall(cluster Cluster, method, rawURL string, headers [][2]string, body []byte, - callback ResponseCallback, timeoutMillisecond ...uint32) error { - for i := len(headers) - 1; i >= 0; i-- { - key := headers[i][0] - if key == ":method" || key == ":path" || key == ":authority" { - headers = append(headers[:i], headers[i+1:]...) - } - } - parsedURL, err := url.Parse(rawURL) - if err != nil { - proxywasm.LogCriticalf("invalid rawURL:%s", rawURL) - return err - } - authority := cluster.HostName() - if parsedURL.Host != "" { - authority = parsedURL.Host - } - path := "/" + strings.TrimPrefix(parsedURL.Path, "/") - if parsedURL.RawQuery != "" { - path = fmt.Sprintf("%s?%s", path, parsedURL.RawQuery) - } - // default timeout is 500ms - var timeout uint32 = 500 - if len(timeoutMillisecond) > 0 { - timeout = timeoutMillisecond[0] - } - headers = append(headers, [2]string{":method", method}, [2]string{":path", path}, [2]string{":authority", authority}) - requestID := uuid.New().String() - _, err = proxywasm.DispatchHttpCall(cluster.ClusterName(), headers, body, nil, timeout, func(numHeaders, bodySize, numTrailers int) { - respBody, err := proxywasm.GetHttpCallResponseBody(0, bodySize) - if err != nil { - proxywasm.LogCriticalf("failed to get response body: %v", err) - } - respHeaders, err := proxywasm.GetHttpCallResponseHeaders() - if err != nil { - proxywasm.LogCriticalf("failed to get response headers: %v", err) - } - code := http.StatusBadGateway - var normalResponse bool - headers := make(http.Header) - for _, h := range respHeaders { - if h[0] == ":status" { - code, err = strconv.Atoi(h[1]) - if err != nil { - proxywasm.LogErrorf("failed to parse status: %v", err) - code = http.StatusInternalServerError - } else { - normalResponse = true - } - } - headers.Add(h[0], h[1]) - } - proxywasm.LogDebugf("http call end, id: %s, code: %d, normal: %t, body: %s", - requestID, code, normalResponse, respBody) - callback(code, headers, respBody) - }) - proxywasm.LogDebugf("http call start, id: %s, cluster: %s, method: %s, url: %s, headers: %#v, body: %s, timeout: %d", - requestID, cluster.ClusterName(), method, rawURL, headers, body, timeout) - return err -} diff --git a/plugins/wasm-go/pkg/wrapper/log_wrapper.go b/plugins/wasm-go/pkg/wrapper/log_wrapper.go deleted file mode 100644 index 866081f8a..000000000 --- a/plugins/wasm-go/pkg/wrapper/log_wrapper.go +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (c) 2022 Alibaba Group Holding Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wrapper - -import ( - "fmt" - - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" -) - -type LogLevel uint32 - -const ( - LogLevelTrace LogLevel = iota - LogLevelDebug - LogLevelInfo - LogLevelWarn - LogLevelError - LogLevelCritical -) - -type DefaultLog struct { - pluginName string - pluginID string -} - -func (l *DefaultLog) log(level LogLevel, msg string) { - requestIDRaw, _ := proxywasm.GetProperty([]string{"x_request_id"}) - requestID := string(requestIDRaw) - if requestID == "" { - requestID = "nil" - } - msg = fmt.Sprintf("[%s] [%s] [%s] %s", l.pluginName, l.pluginID, requestID, msg) - switch level { - case LogLevelTrace: - proxywasm.LogTrace(msg) - case LogLevelDebug: - proxywasm.LogDebug(msg) - case LogLevelInfo: - proxywasm.LogInfo(msg) - case LogLevelWarn: - proxywasm.LogWarn(msg) - case LogLevelError: - proxywasm.LogError(msg) - case LogLevelCritical: - proxywasm.LogCritical(msg) - } -} - -func (l *DefaultLog) logFormat(level LogLevel, format string, args ...interface{}) { - requestIDRaw, _ := proxywasm.GetProperty([]string{"x_request_id"}) - requestID := string(requestIDRaw) - if requestID == "" { - requestID = "nil" - } - format = fmt.Sprintf("[%s] [%s] [%s] %s", l.pluginName, l.pluginID, requestID, format) - switch level { - case LogLevelTrace: - proxywasm.LogTracef(format, args...) - case LogLevelDebug: - proxywasm.LogDebugf(format, args...) - case LogLevelInfo: - proxywasm.LogInfof(format, args...) - case LogLevelWarn: - proxywasm.LogWarnf(format, args...) - case LogLevelError: - proxywasm.LogErrorf(format, args...) - case LogLevelCritical: - proxywasm.LogCriticalf(format, args...) - } -} - -func (l *DefaultLog) Trace(msg string) { - l.log(LogLevelTrace, msg) -} - -func (l *DefaultLog) Tracef(format string, args ...interface{}) { - l.logFormat(LogLevelTrace, format, args...) -} - -func (l *DefaultLog) Debug(msg string) { - l.log(LogLevelDebug, msg) -} - -func (l *DefaultLog) Debugf(format string, args ...interface{}) { - l.logFormat(LogLevelDebug, format, args...) -} - -func (l *DefaultLog) Info(msg string) { - l.log(LogLevelInfo, msg) -} - -func (l *DefaultLog) Infof(format string, args ...interface{}) { - l.logFormat(LogLevelInfo, format, args...) -} - -func (l *DefaultLog) Warn(msg string) { - l.log(LogLevelWarn, msg) -} - -func (l *DefaultLog) Warnf(format string, args ...interface{}) { - l.logFormat(LogLevelWarn, format, args...) -} - -func (l *DefaultLog) Error(msg string) { - l.log(LogLevelError, msg) -} - -func (l *DefaultLog) Errorf(format string, args ...interface{}) { - l.logFormat(LogLevelError, format, args...) -} - -func (l *DefaultLog) Critical(msg string) { - l.log(LogLevelCritical, msg) -} - -func (l *DefaultLog) Criticalf(format string, args ...interface{}) { - l.logFormat(LogLevelCritical, format, args...) -} - -func (l *DefaultLog) ResetID(pluginID string) { - l.pluginID = pluginID -} diff --git a/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go b/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go deleted file mode 100644 index 7bfc7b9f9..000000000 --- a/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go +++ /dev/null @@ -1,791 +0,0 @@ -// Copyright (c) 2022 Alibaba Group Holding Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wrapper - -import ( - "encoding/json" - "fmt" - "strconv" - "time" - "unsafe" - - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - - "github.com/alibaba/higress/plugins/wasm-go/pkg/log" - "github.com/alibaba/higress/plugins/wasm-go/pkg/matcher" -) - -type Log log.Log - -const ( - CustomLogKey = "custom_log" - AILogKey = "ai_log" - TraceSpanTagPrefix = "trace_span_tag." - PluginIDKey = "_plugin_id_" -) - -type HttpContext interface { - Scheme() string - Host() string - Path() string - Method() string - SetContext(key string, value interface{}) - GetContext(key string) interface{} - GetBoolContext(key string, defaultValue bool) bool - GetStringContext(key, defaultValue string) string - GetByteSliceContext(key string, defaultValue []byte) []byte - GetUserAttribute(key string) interface{} - SetUserAttribute(key string, value interface{}) - SetUserAttributeMap(kvmap map[string]interface{}) - GetUserAttributeMap() map[string]interface{} - // You can call this function to set custom log - WriteUserAttributeToLog() error - // You can call this function to set custom log with your specific key - WriteUserAttributeToLogWithKey(key string) error - // You can call this function to set custom trace span attribute - WriteUserAttributeToTrace() error - // If the onHttpRequestBody handle is not set, the request body will not be read by default - DontReadRequestBody() - // If the onHttpResponseBody handle is not set, the request body will not be read by default - DontReadResponseBody() - // If the onHttpStreamingRequestBody handle is not set, and the onHttpRequestBody handle is set, the request body will be buffered by default - BufferRequestBody() - // If the onHttpStreamingResponseBody handle is not set, and the onHttpResponseBody handle is set, the response body will be buffered by default - BufferResponseBody() - // If any request header is changed in onHttpRequestHeaders, envoy will re-calculate the route. Call this function to disable the re-routing. - // You need to call this before making any header modification operations. - DisableReroute() - // Note that this parameter affects the gateway's memory usage!Support setting a maximum buffer size for each request body individually in request phase. - SetRequestBodyBufferLimit(byteSize uint32) - // Note that this parameter affects the gateway's memory usage! Support setting a maximum buffer size for each response body individually in response phase. - SetResponseBodyBufferLimit(byteSize uint32) - // Get contextId of HttpContext - GetContextId() uint32 -} - -type oldParseConfigFunc[PluginConfig any] func(json gjson.Result, config *PluginConfig, log Log) error -type oldParseRuleConfigFunc[PluginConfig any] func(json gjson.Result, global PluginConfig, config *PluginConfig, log Log) error -type oldOnHttpHeadersFunc[PluginConfig any] func(context HttpContext, config PluginConfig, log Log) types.Action -type oldOnHttpBodyFunc[PluginConfig any] func(context HttpContext, config PluginConfig, body []byte, log Log) types.Action -type oldOnHttpStreamingBodyFunc[PluginConfig any] func(context HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log Log) []byte -type oldOnHttpStreamDoneFunc[PluginConfig any] func(context HttpContext, config PluginConfig, log Log) - -type ParseConfigFunc[PluginConfig any] func(json gjson.Result, config *PluginConfig) error -type ParseRuleConfigFunc[PluginConfig any] func(json gjson.Result, global PluginConfig, config *PluginConfig) error -type onHttpHeadersFunc[PluginConfig any] func(context HttpContext, config PluginConfig) types.Action -type onHttpBodyFunc[PluginConfig any] func(context HttpContext, config PluginConfig, body []byte) types.Action -type onHttpStreamingBodyFunc[PluginConfig any] func(context HttpContext, config PluginConfig, chunk []byte, isLastChunk bool) []byte -type onHttpStreamDoneFunc[PluginConfig any] func(context HttpContext, config PluginConfig) - -type CommonVmCtx[PluginConfig any] struct { - types.DefaultVMContext - pluginName string - log Log - hasCustomConfig bool - parseConfig ParseConfigFunc[PluginConfig] - parseRuleConfig ParseRuleConfigFunc[PluginConfig] - onHttpRequestHeaders onHttpHeadersFunc[PluginConfig] - onHttpRequestBody onHttpBodyFunc[PluginConfig] - onHttpStreamingRequestBody onHttpStreamingBodyFunc[PluginConfig] - onHttpResponseHeaders onHttpHeadersFunc[PluginConfig] - onHttpResponseBody onHttpBodyFunc[PluginConfig] - onHttpStreamingResponseBody onHttpStreamingBodyFunc[PluginConfig] - onHttpStreamDone onHttpStreamDoneFunc[PluginConfig] -} - -type TickFuncEntry struct { - lastExecuted int64 - tickPeriod int64 - tickFunc func() -} - -var globalOnTickFuncs []TickFuncEntry = []TickFuncEntry{} - -// Register multiple onTick functions. Parameters include: -// 1) tickPeriod: the execution period of tickFunc, this value should be a multiple of 100 -// 2) tickFunc: the function to be executed -// -// You should call this function in parseConfig phase, for example: -// -// func parseConfig(json gjson.Result, config *HelloWorldConfig, log wrapper.Log) error { -// wrapper.RegisterTickFunc(1000, func() { proxywasm.LogInfo("onTick 1s") }) -// wrapper.RegisterTickFunc(3000, func() { proxywasm.LogInfo("onTick 3s") }) -// return nil -// } -func RegisterTickFunc(tickPeriod int64, tickFunc func()) { - globalOnTickFuncs = append(globalOnTickFuncs, TickFuncEntry{0, tickPeriod, tickFunc}) -} - -func SetCtx[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) { - proxywasm.SetVMContext(NewCommonVmCtx(pluginName, options...)) -} - -func SetCtxWithOptions[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) { - proxywasm.SetVMContext(NewCommonVmCtxWithOptions(pluginName, options...)) -} - -type CtxOption[PluginConfig any] interface { - Apply(*CommonVmCtx[PluginConfig]) -} - -type parseConfigOption[PluginConfig any] struct { - f ParseConfigFunc[PluginConfig] - oldF oldParseConfigFunc[PluginConfig] -} - -func (o parseConfigOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { - if o.f != nil { - ctx.parseConfig = o.f - } else { - ctx.parseConfig = func(json gjson.Result, config *PluginConfig) error { return o.oldF(json, config, ctx.log) } - } -} - -// Deprecated: Please use `ParseConfig` instead. -func ParseConfigBy[PluginConfig any](f oldParseConfigFunc[PluginConfig]) CtxOption[PluginConfig] { - return &parseConfigOption[PluginConfig]{oldF: f} -} - -func ParseConfig[PluginConfig any](f ParseConfigFunc[PluginConfig]) CtxOption[PluginConfig] { - return &parseConfigOption[PluginConfig]{f: f} -} - -type parseOverrideConfigOption[PluginConfig any] struct { - parseConfigF ParseConfigFunc[PluginConfig] - parseRuleConfigF ParseRuleConfigFunc[PluginConfig] - oldParseConfigF oldParseConfigFunc[PluginConfig] - oldParseRuleConfigF oldParseRuleConfigFunc[PluginConfig] -} - -func (o *parseOverrideConfigOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { - if o.parseConfigF != nil && o.parseRuleConfigF != nil { - ctx.parseConfig = o.parseConfigF - ctx.parseRuleConfig = o.parseRuleConfigF - } else { - ctx.parseConfig = func(json gjson.Result, config *PluginConfig) error { - return o.oldParseConfigF(json, config, ctx.log) - } - ctx.parseRuleConfig = func(json gjson.Result, global PluginConfig, config *PluginConfig) error { - return o.oldParseRuleConfigF(json, global, config, ctx.log) - } - } -} - -// Deprecated: Please use `ParseOverrideConfig` instead. -func ParseOverrideConfigBy[PluginConfig any](f oldParseConfigFunc[PluginConfig], g oldParseRuleConfigFunc[PluginConfig]) CtxOption[PluginConfig] { - return &parseOverrideConfigOption[PluginConfig]{ - oldParseConfigF: f, - oldParseRuleConfigF: g, - } -} - -func ParseOverrideConfig[PluginConfig any](f ParseConfigFunc[PluginConfig], g ParseRuleConfigFunc[PluginConfig]) CtxOption[PluginConfig] { - return &parseOverrideConfigOption[PluginConfig]{ - parseConfigF: f, - parseRuleConfigF: g, - } -} - -type onProcessRequestHeadersOption[PluginConfig any] struct { - f onHttpHeadersFunc[PluginConfig] - oldF oldOnHttpHeadersFunc[PluginConfig] -} - -func (o *onProcessRequestHeadersOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { - if o.f != nil { - ctx.onHttpRequestHeaders = o.f - } else { - ctx.onHttpRequestHeaders = func(context HttpContext, config PluginConfig) types.Action { - return o.oldF(context, config, ctx.log) - } - } -} - -// Deprecated: Please use `ProcessRequestHeaders` instead. -func ProcessRequestHeadersBy[PluginConfig any](f oldOnHttpHeadersFunc[PluginConfig]) CtxOption[PluginConfig] { - return &onProcessRequestHeadersOption[PluginConfig]{oldF: f} -} - -func ProcessRequestHeaders[PluginConfig any](f onHttpHeadersFunc[PluginConfig]) CtxOption[PluginConfig] { - return &onProcessRequestHeadersOption[PluginConfig]{f: f} -} - -type onProcessRequestBodyOption[PluginConfig any] struct { - f onHttpBodyFunc[PluginConfig] - oldF oldOnHttpBodyFunc[PluginConfig] -} - -func (o *onProcessRequestBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { - if o.f != nil { - ctx.onHttpRequestBody = o.f - } else { - ctx.onHttpRequestBody = func(context HttpContext, config PluginConfig, body []byte) types.Action { - return o.oldF(context, config, body, ctx.log) - } - } -} - -// Deprecated: Please use `ProcessRequestBody` instead. -func ProcessRequestBodyBy[PluginConfig any](f oldOnHttpBodyFunc[PluginConfig]) CtxOption[PluginConfig] { - return &onProcessRequestBodyOption[PluginConfig]{oldF: f} -} - -func ProcessRequestBody[PluginConfig any](f onHttpBodyFunc[PluginConfig]) CtxOption[PluginConfig] { - return &onProcessRequestBodyOption[PluginConfig]{f: f} -} - -type onProcessStreamingRequestBodyOption[PluginConfig any] struct { - f onHttpStreamingBodyFunc[PluginConfig] - oldF oldOnHttpStreamingBodyFunc[PluginConfig] -} - -func (o *onProcessStreamingRequestBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { - if o.f != nil { - ctx.onHttpStreamingRequestBody = o.f - } else { - ctx.onHttpStreamingRequestBody = func(context HttpContext, config PluginConfig, chunk []byte, isLastChunk bool) []byte { - return o.oldF(context, config, chunk, isLastChunk, ctx.log) - } - } -} - -// Deprecated: Please use `ProcessStreamingRequestBody` instead. -func ProcessStreamingRequestBodyBy[PluginConfig any](f oldOnHttpStreamingBodyFunc[PluginConfig]) CtxOption[PluginConfig] { - return &onProcessStreamingRequestBodyOption[PluginConfig]{oldF: f} -} - -func ProcessStreamingRequestBody[PluginConfig any](f onHttpStreamingBodyFunc[PluginConfig]) CtxOption[PluginConfig] { - return &onProcessStreamingRequestBodyOption[PluginConfig]{f: f} -} - -type onProcessResponseHeadersOption[PluginConfig any] struct { - f onHttpHeadersFunc[PluginConfig] - oldF oldOnHttpHeadersFunc[PluginConfig] -} - -func (o *onProcessResponseHeadersOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { - if o.f != nil { - ctx.onHttpResponseHeaders = o.f - } else { - ctx.onHttpResponseHeaders = func(context HttpContext, config PluginConfig) types.Action { - return o.oldF(context, config, ctx.log) - } - } -} - -// Deprecated: Please use `ProcessResponseHeaders` instead. -func ProcessResponseHeadersBy[PluginConfig any](f oldOnHttpHeadersFunc[PluginConfig]) CtxOption[PluginConfig] { - return &onProcessResponseHeadersOption[PluginConfig]{oldF: f} -} - -func ProcessResponseHeaders[PluginConfig any](f onHttpHeadersFunc[PluginConfig]) CtxOption[PluginConfig] { - return &onProcessResponseHeadersOption[PluginConfig]{f: f} -} - -type onProcessResponseBodyOption[PluginConfig any] struct { - f onHttpBodyFunc[PluginConfig] - oldF oldOnHttpBodyFunc[PluginConfig] -} - -func (o *onProcessResponseBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { - if o.f != nil { - ctx.onHttpResponseBody = o.f - } else { - ctx.onHttpResponseBody = func(context HttpContext, config PluginConfig, body []byte) types.Action { - return o.oldF(context, config, body, ctx.log) - } - } -} - -// Deprecated: Please use `ProcessResponseBody` instead. -func ProcessResponseBodyBy[PluginConfig any](f oldOnHttpBodyFunc[PluginConfig]) CtxOption[PluginConfig] { - return &onProcessResponseBodyOption[PluginConfig]{oldF: f} -} - -func ProcessResponseBody[PluginConfig any](f onHttpBodyFunc[PluginConfig]) CtxOption[PluginConfig] { - return &onProcessResponseBodyOption[PluginConfig]{f: f} -} - -type onProcessStreamingResponseBodyOption[PluginConfig any] struct { - f onHttpStreamingBodyFunc[PluginConfig] - oldF oldOnHttpStreamingBodyFunc[PluginConfig] -} - -func (o *onProcessStreamingResponseBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { - if o.f != nil { - ctx.onHttpStreamingResponseBody = o.f - } else { - ctx.onHttpStreamingResponseBody = func(context HttpContext, config PluginConfig, chunk []byte, isLastChunk bool) []byte { - return o.oldF(context, config, chunk, isLastChunk, ctx.log) - } - } -} - -// Deprecated: Please use `ProcessStreamingResponseBody` instead. -func ProcessStreamingResponseBodyBy[PluginConfig any](f oldOnHttpStreamingBodyFunc[PluginConfig]) CtxOption[PluginConfig] { - return &onProcessStreamingResponseBodyOption[PluginConfig]{oldF: f} -} - -func ProcessStreamingResponseBody[PluginConfig any](f onHttpStreamingBodyFunc[PluginConfig]) CtxOption[PluginConfig] { - return &onProcessStreamingResponseBodyOption[PluginConfig]{f: f} -} - -type onProcessStreamDoneOption[PluginConfig any] struct { - f onHttpStreamDoneFunc[PluginConfig] - oldF oldOnHttpStreamDoneFunc[PluginConfig] -} - -func (o *onProcessStreamDoneOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { - if o.f != nil { - ctx.onHttpStreamDone = o.f - } else { - ctx.onHttpStreamDone = func(context HttpContext, config PluginConfig) { o.oldF(context, config, ctx.log) } - } - -} - -// Deprecated: Please use `ProcessStreamDoneBy` instead. -func ProcessStreamDoneBy[PluginConfig any](f oldOnHttpStreamDoneFunc[PluginConfig]) CtxOption[PluginConfig] { - return &onProcessStreamDoneOption[PluginConfig]{oldF: f} -} - -func ProcessStreamDone[PluginConfig any](f onHttpStreamDoneFunc[PluginConfig]) CtxOption[PluginConfig] { - return &onProcessStreamDoneOption[PluginConfig]{f: f} -} - -type logOption[PluginConfig any] struct { - logger Log -} - -func (o *logOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) { - log.SetPluginLog(o.logger) - ctx.log = o.logger -} - -func WithLogger[PluginConfig any](logger Log) CtxOption[PluginConfig] { - return &logOption[PluginConfig]{logger} -} - -func parseEmptyPluginConfig[PluginConfig any](gjson.Result, *PluginConfig) error { - return nil -} - -func NewCommonVmCtx[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) *CommonVmCtx[PluginConfig] { - logger := &DefaultLog{pluginName, "nil"} - opts := []CtxOption[PluginConfig]{WithLogger[PluginConfig](logger)} - for _, opt := range options { - if opt == nil { - continue - } - opts = append(opts, opt) - } - return NewCommonVmCtxWithOptions(pluginName, opts...) -} - -func NewCommonVmCtxWithOptions[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) *CommonVmCtx[PluginConfig] { - ctx := &CommonVmCtx[PluginConfig]{ - pluginName: pluginName, - hasCustomConfig: true, - } - for _, opt := range options { - opt.Apply(ctx) - } - if ctx.parseConfig == nil { - var config PluginConfig - if unsafe.Sizeof(config) != 0 { - msg := "the `parseConfig` is missing in NewCommonVmCtx's arguments" - panic(msg) - } - ctx.hasCustomConfig = false - ctx.parseConfig = parseEmptyPluginConfig[PluginConfig] - } - return ctx -} - -func (ctx *CommonVmCtx[PluginConfig]) NewPluginContext(uint32) types.PluginContext { - return &CommonPluginCtx[PluginConfig]{ - vm: ctx, - } -} - -type CommonPluginCtx[PluginConfig any] struct { - types.DefaultPluginContext - matcher.RuleMatcher[PluginConfig] - vm *CommonVmCtx[PluginConfig] - onTickFuncs []TickFuncEntry -} - -func (ctx *CommonPluginCtx[PluginConfig]) OnPluginStart(int) types.OnPluginStartStatus { - data, err := proxywasm.GetPluginConfiguration() - globalOnTickFuncs = nil - if err != nil && err != types.ErrorStatusNotFound { - ctx.vm.log.Criticalf("error reading plugin configuration: %v", err) - return types.OnPluginStartStatusFailed - } - var jsonData gjson.Result - if len(data) == 0 { - if ctx.vm.hasCustomConfig { - ctx.vm.log.Warn("config is empty, but has ParseConfigFunc") - } - } else { - if !gjson.ValidBytes(data) { - ctx.vm.log.Warnf("the plugin configuration is not a valid json: %s", string(data)) - return types.OnPluginStartStatusFailed - } - pluginID := gjson.GetBytes(data, PluginIDKey).String() - if pluginID != "" { - ctx.vm.log.ResetID(pluginID) - data, _ = sjson.DeleteBytes([]byte(data), PluginIDKey) - } - jsonData = gjson.ParseBytes(data) - } - - var parseOverrideConfig func(gjson.Result, PluginConfig, *PluginConfig) error - if ctx.vm.parseRuleConfig != nil { - parseOverrideConfig = func(js gjson.Result, global PluginConfig, cfg *PluginConfig) error { - return ctx.vm.parseRuleConfig(js, global, cfg) - } - } - err = ctx.ParseRuleConfig(jsonData, - func(js gjson.Result, cfg *PluginConfig) error { - return ctx.vm.parseConfig(js, cfg) - }, - parseOverrideConfig, - ) - if err != nil { - ctx.vm.log.Warnf("parse rule config failed: %v", err) - ctx.vm.log.Error("plugin start failed") - return types.OnPluginStartStatusFailed - } - if globalOnTickFuncs != nil { - ctx.onTickFuncs = globalOnTickFuncs - if err := proxywasm.SetTickPeriodMilliSeconds(100); err != nil { - ctx.vm.log.Error("SetTickPeriodMilliSeconds failed, onTick functions will not take effect.") - ctx.vm.log.Error("plugin start failed") - return types.OnPluginStartStatusFailed - } - } - ctx.vm.log.Info("plugin start successfully") - return types.OnPluginStartStatusOK -} - -func (ctx *CommonPluginCtx[PluginConfig]) OnTick() { - for i := range ctx.onTickFuncs { - currentTimeStamp := time.Now().UnixMilli() - if currentTimeStamp-ctx.onTickFuncs[i].lastExecuted >= ctx.onTickFuncs[i].tickPeriod { - ctx.onTickFuncs[i].tickFunc() - ctx.onTickFuncs[i].lastExecuted = currentTimeStamp - } - } -} - -func (ctx *CommonPluginCtx[PluginConfig]) NewHttpContext(contextID uint32) types.HttpContext { - httpCtx := &CommonHttpCtx[PluginConfig]{ - plugin: ctx, - contextID: contextID, - userContext: map[string]interface{}{}, - userAttribute: map[string]interface{}{}, - } - if ctx.vm.onHttpRequestBody != nil || ctx.vm.onHttpStreamingRequestBody != nil { - httpCtx.needRequestBody = true - } - if ctx.vm.onHttpResponseBody != nil || ctx.vm.onHttpStreamingResponseBody != nil { - httpCtx.needResponseBody = true - } - if ctx.vm.onHttpStreamingRequestBody != nil { - httpCtx.streamingRequestBody = true - } - if ctx.vm.onHttpStreamingResponseBody != nil { - httpCtx.streamingResponseBody = true - } - - return httpCtx -} - -type CommonHttpCtx[PluginConfig any] struct { - types.DefaultHttpContext - plugin *CommonPluginCtx[PluginConfig] - config *PluginConfig - needRequestBody bool - needResponseBody bool - streamingRequestBody bool - streamingResponseBody bool - requestBodySize int - responseBodySize int - contextID uint32 - userContext map[string]interface{} - userAttribute map[string]interface{} -} - -func (ctx *CommonHttpCtx[PluginConfig]) SetContext(key string, value interface{}) { - ctx.userContext[key] = value -} - -func (ctx *CommonHttpCtx[PluginConfig]) GetContext(key string) interface{} { - return ctx.userContext[key] -} - -func (ctx *CommonHttpCtx[PluginConfig]) SetUserAttribute(key string, value interface{}) { - ctx.userAttribute[key] = value -} - -func (ctx *CommonHttpCtx[PluginConfig]) GetUserAttribute(key string) interface{} { - return ctx.userAttribute[key] -} - -func (ctx *CommonHttpCtx[PluginConfig]) SetUserAttributeMap(kvmap map[string]interface{}) { - ctx.userAttribute = kvmap -} - -func (ctx *CommonHttpCtx[PluginConfig]) GetUserAttributeMap() map[string]interface{} { - return ctx.userAttribute -} - -func (ctx *CommonHttpCtx[PluginConfig]) WriteUserAttributeToLog() error { - return ctx.WriteUserAttributeToLogWithKey(CustomLogKey) -} - -func (ctx *CommonHttpCtx[PluginConfig]) WriteUserAttributeToLogWithKey(key string) error { - // e.g. {\"field1\":\"value1\",\"field2\":\"value2\"} - preMarshalledJsonLogStr, _ := proxywasm.GetProperty([]string{key}) - newAttributeMap := map[string]interface{}{} - if string(preMarshalledJsonLogStr) != "" { - // e.g. {"field1":"value1","field2":"value2"} - preJsonLogStr := UnmarshalStr(fmt.Sprintf(`"%s"`, string(preMarshalledJsonLogStr))) - err := json.Unmarshal([]byte(preJsonLogStr), &newAttributeMap) - if err != nil { - ctx.plugin.vm.log.Warnf("Unmarshal failed, will overwrite %s, pre value is: %s", key, string(preMarshalledJsonLogStr)) - return err - } - } - // update customLog - for k, v := range ctx.userAttribute { - newAttributeMap[k] = v - } - // e.g. {"field1":"value1","field2":2,"field3":"value3"} - jsonStr, _ := json.Marshal(newAttributeMap) - // e.g. {\"field1\":\"value1\",\"field2\":2,\"field3\":\"value3\"} - marshalledJsonStr := MarshalStr(string(jsonStr)) - if err := proxywasm.SetProperty([]string{key}, []byte(marshalledJsonStr)); err != nil { - ctx.plugin.vm.log.Warnf("failed to set %s in filter state, raw is %s, err is %v", key, marshalledJsonStr, err) - return err - } - return nil -} - -func (ctx *CommonHttpCtx[PluginConfig]) WriteUserAttributeToTrace() error { - for k, v := range ctx.userAttribute { - traceSpanTag := TraceSpanTagPrefix + k - traceSpanValue := fmt.Sprint(v) - var err error - if traceSpanValue != "" { - err = proxywasm.SetProperty([]string{traceSpanTag}, []byte(traceSpanValue)) - } else { - err = fmt.Errorf("value of %s is empty", traceSpanTag) - } - if err != nil { - ctx.plugin.vm.log.Warnf("Failed to set trace attribute - %s: %s, error message: %v", traceSpanTag, traceSpanValue, err) - } - } - return nil -} - -func (ctx *CommonHttpCtx[PluginConfig]) GetBoolContext(key string, defaultValue bool) bool { - if b, ok := ctx.userContext[key].(bool); ok { - return b - } - return defaultValue -} - -func (ctx *CommonHttpCtx[PluginConfig]) GetStringContext(key, defaultValue string) string { - if s, ok := ctx.userContext[key].(string); ok { - return s - } - return defaultValue -} - -func (ctx *CommonHttpCtx[PluginConfig]) GetByteSliceContext(key string, defaultValue []byte) []byte { - if s, ok := ctx.userContext[key].([]byte); ok { - return s - } - return defaultValue -} - -func (ctx *CommonHttpCtx[PluginConfig]) Scheme() string { - proxywasm.SetEffectiveContext(ctx.contextID) - return GetRequestScheme() -} - -func (ctx *CommonHttpCtx[PluginConfig]) Host() string { - proxywasm.SetEffectiveContext(ctx.contextID) - return GetRequestHost() -} - -func (ctx *CommonHttpCtx[PluginConfig]) Path() string { - proxywasm.SetEffectiveContext(ctx.contextID) - return GetRequestPath() -} - -func (ctx *CommonHttpCtx[PluginConfig]) Method() string { - proxywasm.SetEffectiveContext(ctx.contextID) - return GetRequestMethod() -} - -func (ctx *CommonHttpCtx[PluginConfig]) DontReadRequestBody() { - ctx.needRequestBody = false -} - -func (ctx *CommonHttpCtx[PluginConfig]) DontReadResponseBody() { - ctx.needResponseBody = false -} - -func (ctx *CommonHttpCtx[PluginConfig]) BufferRequestBody() { - ctx.streamingRequestBody = false -} - -func (ctx *CommonHttpCtx[PluginConfig]) BufferResponseBody() { - ctx.streamingResponseBody = false -} - -func (ctx *CommonHttpCtx[PluginConfig]) DisableReroute() { - _ = proxywasm.SetProperty([]string{"clear_route_cache"}, []byte("off")) -} - -func (ctx *CommonHttpCtx[PluginConfig]) SetRequestBodyBufferLimit(size uint32) { - ctx.plugin.vm.log.Infof("SetRequestBodyBufferLimit: %d", size) - _ = proxywasm.SetProperty([]string{"set_decoder_buffer_limit"}, []byte(strconv.Itoa(int(size)))) -} - -func (ctx *CommonHttpCtx[PluginConfig]) SetResponseBodyBufferLimit(size uint32) { - ctx.plugin.vm.log.Infof("SetResponseBodyBufferLimit: %d", size) - _ = proxywasm.SetProperty([]string{"set_encoder_buffer_limit"}, []byte(strconv.Itoa(int(size)))) -} - -func (ctx *CommonHttpCtx[PluginConfig]) GetContextId() uint32 { - return ctx.contextID -} - -func (ctx *CommonHttpCtx[PluginConfig]) OnHttpRequestHeaders(numHeaders int, endOfStream bool) types.Action { - requestID, _ := proxywasm.GetHttpRequestHeader("x-request-id") - _ = proxywasm.SetProperty([]string{"x_request_id"}, []byte(requestID)) - config, err := ctx.plugin.GetMatchConfig() - if err != nil { - ctx.plugin.vm.log.Errorf("get match config failed, err:%v", err) - return types.ActionContinue - } - if config == nil { - return types.ActionContinue - } - ctx.config = config - // To avoid unexpected operations, plugins do not read the binary content body - if IsBinaryRequestBody() { - ctx.needRequestBody = false - } - if ctx.plugin.vm.onHttpRequestHeaders == nil { - return types.ActionContinue - } - return ctx.plugin.vm.onHttpRequestHeaders(ctx, *config) -} - -func (ctx *CommonHttpCtx[PluginConfig]) OnHttpRequestBody(bodySize int, endOfStream bool) types.Action { - if ctx.config == nil { - return types.ActionContinue - } - if !ctx.needRequestBody { - return types.ActionContinue - } - if ctx.plugin.vm.onHttpStreamingRequestBody != nil && ctx.streamingRequestBody { - chunk, _ := proxywasm.GetHttpRequestBody(0, bodySize) - modifiedChunk := ctx.plugin.vm.onHttpStreamingRequestBody(ctx, *ctx.config, chunk, endOfStream) - err := proxywasm.ReplaceHttpRequestBody(modifiedChunk) - if err != nil { - ctx.plugin.vm.log.Warnf("replace request body chunk failed: %v", err) - return types.ActionContinue - } - return types.ActionContinue - } - if ctx.plugin.vm.onHttpRequestBody != nil { - ctx.requestBodySize += bodySize - if !endOfStream { - return types.ActionPause - } - body, err := proxywasm.GetHttpRequestBody(0, ctx.requestBodySize) - if err != nil { - ctx.plugin.vm.log.Warnf("get request body failed: %v", err) - return types.ActionContinue - } - return ctx.plugin.vm.onHttpRequestBody(ctx, *ctx.config, body) - } - return types.ActionContinue -} - -func (ctx *CommonHttpCtx[PluginConfig]) OnHttpResponseHeaders(numHeaders int, endOfStream bool) types.Action { - if ctx.config == nil { - return types.ActionContinue - } - // To avoid unexpected operations, plugins do not read the binary content body - if IsBinaryResponseBody() { - ctx.needResponseBody = false - } - if ctx.plugin.vm.onHttpResponseHeaders == nil { - return types.ActionContinue - } - return ctx.plugin.vm.onHttpResponseHeaders(ctx, *ctx.config) -} - -func (ctx *CommonHttpCtx[PluginConfig]) OnHttpResponseBody(bodySize int, endOfStream bool) types.Action { - if ctx.config == nil { - return types.ActionContinue - } - if !ctx.needResponseBody { - return types.ActionContinue - } - if ctx.plugin.vm.onHttpStreamingResponseBody != nil && ctx.streamingResponseBody { - chunk, _ := proxywasm.GetHttpResponseBody(0, bodySize) - modifiedChunk := ctx.plugin.vm.onHttpStreamingResponseBody(ctx, *ctx.config, chunk, endOfStream) - err := proxywasm.ReplaceHttpResponseBody(modifiedChunk) - if err != nil { - ctx.plugin.vm.log.Warnf("replace response body chunk failed: %v", err) - return types.ActionContinue - } - return types.ActionContinue - } - if ctx.plugin.vm.onHttpResponseBody != nil { - ctx.responseBodySize += bodySize - if !endOfStream { - return types.ActionPause - } - body, err := proxywasm.GetHttpResponseBody(0, ctx.responseBodySize) - if err != nil { - ctx.plugin.vm.log.Warnf("get response body failed: %v", err) - return types.ActionContinue - } - return ctx.plugin.vm.onHttpResponseBody(ctx, *ctx.config, body) - } - return types.ActionContinue -} - -func (ctx *CommonHttpCtx[PluginConfig]) OnHttpStreamDone() { - if ctx.config == nil { - return - } - if ctx.plugin.vm.onHttpStreamDone == nil { - return - } - ctx.plugin.vm.onHttpStreamDone(ctx, *ctx.config) -} diff --git a/plugins/wasm-go/pkg/wrapper/redis_wrapper.go b/plugins/wasm-go/pkg/wrapper/redis_wrapper.go deleted file mode 100644 index 91422cf41..000000000 --- a/plugins/wasm-go/pkg/wrapper/redis_wrapper.go +++ /dev/null @@ -1,902 +0,0 @@ -// Copyright (c) 2022 Alibaba Group Holding Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wrapper - -import ( - "bytes" - "encoding/base64" - "errors" - "fmt" - "io" - - "github.com/google/uuid" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" - "github.com/tidwall/resp" -) - -type RedisResponseCallback func(response resp.Value) - -type RedisClient interface { - Init(username, password string, timeout int64, opts ...optionFunc) error - // return whether redis client is ready - Ready() bool - // with this function, you can call redis as if you are using redis-cli - Command(cmds []interface{}, callback RedisResponseCallback) error - Eval(script string, numkeys int, keys, args []interface{}, callback RedisResponseCallback) error - - // Key - Del(key string, callback RedisResponseCallback) error - Exists(key string, callback RedisResponseCallback) error - Expire(key string, ttl int, callback RedisResponseCallback) error - Persist(key string, callback RedisResponseCallback) error - - // String - Get(key string, callback RedisResponseCallback) error - Set(key string, value interface{}, callback RedisResponseCallback) error - SetEx(key string, value interface{}, ttl int, callback RedisResponseCallback) error - SetNX(key string, value interface{}, ttl int, callback RedisResponseCallback) error - MGet(keys []string, callback RedisResponseCallback) error - MSet(kvMap map[string]interface{}, callback RedisResponseCallback) error - Incr(key string, callback RedisResponseCallback) error - Decr(key string, callback RedisResponseCallback) error - IncrBy(key string, delta int, callback RedisResponseCallback) error - DecrBy(key string, delta int, callback RedisResponseCallback) error - - // List - LLen(key string, callback RedisResponseCallback) error - RPush(key string, vals []interface{}, callback RedisResponseCallback) error - RPop(key string, callback RedisResponseCallback) error - LPush(key string, vals []interface{}, callback RedisResponseCallback) error - LPop(key string, callback RedisResponseCallback) error - LIndex(key string, index int, callback RedisResponseCallback) error - LRange(key string, start, stop int, callback RedisResponseCallback) error - LRem(key string, count int, value interface{}, callback RedisResponseCallback) error - LInsertBefore(key string, pivot, value interface{}, callback RedisResponseCallback) error - LInsertAfter(key string, pivot, value interface{}, callback RedisResponseCallback) error - - // Hash - HExists(key, field string, callback RedisResponseCallback) error - HDel(key string, fields []string, callback RedisResponseCallback) error - HLen(key string, callback RedisResponseCallback) error - HGet(key, field string, callback RedisResponseCallback) error - HSet(key, field string, value interface{}, callback RedisResponseCallback) error - HMGet(key string, fields []string, callback RedisResponseCallback) error - HMSet(key string, kvMap map[string]interface{}, callback RedisResponseCallback) error - HKeys(key string, callback RedisResponseCallback) error - HVals(key string, callback RedisResponseCallback) error - HGetAll(key string, callback RedisResponseCallback) error - HIncrBy(key, field string, delta int, callback RedisResponseCallback) error - HIncrByFloat(key, field string, delta float64, callback RedisResponseCallback) error - - // Set - SCard(key string, callback RedisResponseCallback) error - SAdd(key string, value []interface{}, callback RedisResponseCallback) error - SRem(key string, values []interface{}, callback RedisResponseCallback) error - SIsMember(key string, value interface{}, callback RedisResponseCallback) error - SMembers(key string, callback RedisResponseCallback) error - SDiff(key1, key2 string, callback RedisResponseCallback) error - SDiffStore(destination, key1, key2 string, callback RedisResponseCallback) error - SInter(key1, key2 string, callback RedisResponseCallback) error - SInterStore(destination, key1, key2 string, callback RedisResponseCallback) error - SUnion(key1, key2 string, callback RedisResponseCallback) error - SUnionStore(destination, key1, key2 string, callback RedisResponseCallback) error - - // Sorted Set - ZCard(key string, callback RedisResponseCallback) error - ZAdd(key string, msMap map[string]interface{}, callback RedisResponseCallback) error - ZCount(key string, min interface{}, max interface{}, callback RedisResponseCallback) error - ZIncrBy(key string, member string, delta interface{}, callback RedisResponseCallback) error - ZScore(key, member string, callback RedisResponseCallback) error - ZRank(key, member string, callback RedisResponseCallback) error - ZRevRank(key, member string, callback RedisResponseCallback) error - ZRem(key string, members []string, callback RedisResponseCallback) error - ZRange(key string, start, stop int, callback RedisResponseCallback) error - ZRevRange(key string, start, stop int, callback RedisResponseCallback) error -} - -type RedisClusterClient[C Cluster] struct { - cluster C - ready bool - checkReadyFunc func() error - option redisOption -} - -type redisOption struct { - dataBase int -} - -type optionFunc func(*redisOption) - -func WithDataBase(dataBase int) optionFunc { - return func(o *redisOption) { - o.dataBase = dataBase - } -} - -func NewRedisClusterClient[C Cluster](cluster C) *RedisClusterClient[C] { - return &RedisClusterClient[C]{ - cluster: cluster, - checkReadyFunc: func() error { - return errors.New("redis client is not ready, please call Init() first") - }, - } -} - -func RedisCall(cluster Cluster, respQuery []byte, callback RedisResponseCallback) error { - requestID := uuid.New().String() - _, err := proxywasm.DispatchRedisCall( - cluster.ClusterName(), - respQuery, - func(status int, responseSize int) { - response, err := proxywasm.GetRedisCallResponse(0, responseSize) - var responseValue resp.Value - if status != 0 { - proxywasm.LogCriticalf("Error occurred while calling redis, it seems cannot connect to the redis cluster. request-id: %s", requestID) - responseValue = resp.ErrorValue(fmt.Errorf("cannot connect to redis cluster")) - } else { - if err != nil { - proxywasm.LogCriticalf("failed to get redis response body, request-id: %s, error: %v", requestID, err) - responseValue = resp.ErrorValue(fmt.Errorf("cannot get redis response")) - } else { - rd := resp.NewReader(bytes.NewReader(response)) - value, _, err := rd.ReadValue() - if err != nil && err != io.EOF { - proxywasm.LogCriticalf("failed to read redis response body, request-id: %s, error: %v", requestID, err) - responseValue = resp.ErrorValue(fmt.Errorf("cannot read redis response")) - } else { - responseValue = value - proxywasm.LogDebugf("redis call end, request-id: %s, respQuery: %s, respValue: %s", - requestID, base64.StdEncoding.EncodeToString([]byte(respQuery)), base64.StdEncoding.EncodeToString(response)) - } - } - } - if callback != nil { - callback(responseValue) - } - }) - if err != nil { - proxywasm.LogCriticalf("redis call failed, request-id: %s, error: %v", requestID, err) - } else { - proxywasm.LogDebugf("redis call start, request-id: %s, respQuery: %s", requestID, base64.StdEncoding.EncodeToString([]byte(respQuery))) - } - return err -} - -func respString(args []interface{}) []byte { - var buf bytes.Buffer - wr := resp.NewWriter(&buf) - arr := make([]resp.Value, 0) - for _, arg := range args { - arr = append(arr, resp.StringValue(fmt.Sprint(arg))) - } - wr.WriteArray(arr) - return buf.Bytes() -} - -func (c *RedisClusterClient[C]) Ready() bool { - return c.ready -} - -func (c *RedisClusterClient[C]) Init(username, password string, timeout int64, opts ...optionFunc) error { - for _, opt := range opts { - opt(&c.option) - } - clusterName := c.cluster.ClusterName() - if c.option.dataBase != 0 { - clusterName = fmt.Sprintf("%s?db=%d", clusterName, c.option.dataBase) - } - err := proxywasm.RedisInit(clusterName, username, password, uint32(timeout)) - if err != nil { - c.checkReadyFunc = func() error { - if c.ready { - return nil - } - initErr := proxywasm.RedisInit(clusterName, username, password, uint32(timeout)) - if initErr != nil { - return initErr - } - c.ready = true - return nil - } - proxywasm.LogWarnf("failed to init redis: %v, will retry after", err) - return nil - } - c.checkReadyFunc = func() error { return nil } - c.ready = true - return nil -} - -func (c *RedisClusterClient[C]) Command(cmds []interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - return RedisCall(c.cluster, respString(cmds), callback) -} - -func (c *RedisClusterClient[C]) Eval(script string, numkeys int, keys, args []interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - params := make([]interface{}, 0) - params = append(params, "eval") - params = append(params, script) - params = append(params, numkeys) - params = append(params, keys...) - params = append(params, args...) - return RedisCall(c.cluster, respString(params), callback) -} - -// Key -func (c *RedisClusterClient[C]) Del(key string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "del") - args = append(args, key) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) Exists(key string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "exists") - args = append(args, key) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) Expire(key string, ttl int, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "expire") - args = append(args, key) - args = append(args, ttl) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) Persist(key string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "persist") - args = append(args, key) - return RedisCall(c.cluster, respString(args), callback) -} - -// String -func (c *RedisClusterClient[C]) Get(key string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "get") - args = append(args, key) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) Set(key string, value interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "set") - args = append(args, key) - args = append(args, value) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) SetEx(key string, value interface{}, ttl int, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "set") - args = append(args, key) - args = append(args, value) - args = append(args, "ex") - args = append(args, ttl) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) SetNX(key string, value interface{}, ttl int, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "set") - args = append(args, key) - args = append(args, value) - args = append(args, "nx") - if ttl > 0 { - args = append(args, "ex") - args = append(args, ttl) - } - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) MGet(keys []string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "mget") - for _, k := range keys { - args = append(args, k) - } - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) MSet(kvMap map[string]interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "mset") - for k, v := range kvMap { - args = append(args, k) - args = append(args, v) - } - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) Incr(key string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "incr") - args = append(args, key) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) Decr(key string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "decr") - args = append(args, key) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) IncrBy(key string, delta int, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "incrby") - args = append(args, key) - args = append(args, delta) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) DecrBy(key string, delta int, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "decrby") - args = append(args, key) - args = append(args, delta) - return RedisCall(c.cluster, respString(args), callback) -} - -// List -func (c *RedisClusterClient[C]) LLen(key string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "llen") - args = append(args, key) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) RPush(key string, vals []interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "rpush") - args = append(args, key) - for _, val := range vals { - args = append(args, val) - } - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) RPop(key string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "rpop") - args = append(args, key) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) LPush(key string, vals []interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "lpush") - args = append(args, key) - for _, val := range vals { - args = append(args, val) - } - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) LPop(key string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "lpop") - args = append(args, key) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) LIndex(key string, index int, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "lindex") - args = append(args, key) - args = append(args, index) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) LRange(key string, start, stop int, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "lrange") - args = append(args, key) - args = append(args, start) - args = append(args, stop) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) LRem(key string, count int, value interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "lrem") - args = append(args, key) - args = append(args, count) - args = append(args, value) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) LInsertBefore(key string, pivot, value interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "linsert") - args = append(args, key) - args = append(args, "before") - args = append(args, pivot) - args = append(args, value) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) LInsertAfter(key string, pivot, value interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "linsert") - args = append(args, key) - args = append(args, "after") - args = append(args, pivot) - args = append(args, value) - return RedisCall(c.cluster, respString(args), callback) -} - -// Hash -func (c *RedisClusterClient[C]) HExists(key, field string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "hexists") - args = append(args, key) - args = append(args, field) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) HDel(key string, fields []string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "hdel") - args = append(args, key) - for _, field := range fields { - args = append(args, field) - } - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) HLen(key string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "hlen") - args = append(args, key) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) HGet(key, field string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "hget") - args = append(args, key) - args = append(args, field) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) HSet(key, field string, value interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "hset") - args = append(args, key) - args = append(args, field) - args = append(args, value) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) HMGet(key string, fields []string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "hmget") - args = append(args, key) - for _, field := range fields { - args = append(args, field) - } - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) HMSet(key string, kvMap map[string]interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "hmset") - args = append(args, key) - for k, v := range kvMap { - args = append(args, k) - args = append(args, v) - } - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) HKeys(key string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "hkeys") - args = append(args, key) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) HVals(key string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "hvals") - args = append(args, key) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) HGetAll(key string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "hgetall") - args = append(args, key) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) HIncrBy(key, field string, delta int, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "hincrby") - args = append(args, key) - args = append(args, field) - args = append(args, delta) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) HIncrByFloat(key, field string, delta float64, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "hincrbyfloat") - args = append(args, key) - args = append(args, field) - args = append(args, delta) - return RedisCall(c.cluster, respString(args), callback) -} - -// Set -func (c *RedisClusterClient[C]) SCard(key string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "scard") - args = append(args, key) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) SAdd(key string, vals []interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "sadd") - args = append(args, key) - for _, val := range vals { - args = append(args, val) - } - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) SRem(key string, vals []interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "srem") - args = append(args, key) - for _, val := range vals { - args = append(args, val) - } - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) SIsMember(key string, value interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "sismember") - args = append(args, key) - args = append(args, value) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) SMembers(key string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "smembers") - args = append(args, key) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) SDiff(key1, key2 string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "sdiff") - args = append(args, key1) - args = append(args, key2) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) SDiffStore(destination, key1, key2 string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "sdiffstore") - args = append(args, destination) - args = append(args, key1) - args = append(args, key2) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) SInter(key1, key2 string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "sinter") - args = append(args, key1) - args = append(args, key2) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) SInterStore(destination, key1, key2 string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "sinterstore") - args = append(args, destination) - args = append(args, key1) - args = append(args, key2) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) SUnion(key1, key2 string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "sunion") - args = append(args, key1) - args = append(args, key2) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) SUnionStore(destination, key1, key2 string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "sunionstore") - args = append(args, destination) - args = append(args, key1) - args = append(args, key2) - return RedisCall(c.cluster, respString(args), callback) -} - -// ZSet -func (c *RedisClusterClient[C]) ZCard(key string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "zcard") - args = append(args, key) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) ZAdd(key string, msMap map[string]interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "zadd") - args = append(args, key) - for m, s := range msMap { - args = append(args, s) - args = append(args, m) - } - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) ZCount(key string, min interface{}, max interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "zcount") - args = append(args, key) - args = append(args, min) - args = append(args, max) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) ZIncrBy(key string, member string, delta interface{}, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "zincrby") - args = append(args, key) - args = append(args, delta) - args = append(args, member) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) ZScore(key, member string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "zscore") - args = append(args, key) - args = append(args, member) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) ZRank(key, member string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "zrank") - args = append(args, key) - args = append(args, member) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) ZRevRank(key, member string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "zrevrank") - args = append(args, key) - args = append(args, member) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) ZRem(key string, members []string, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "zrem") - args = append(args, key) - for _, m := range members { - args = append(args, m) - } - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) ZRange(key string, start, stop int, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "zrange") - args = append(args, key) - args = append(args, start) - args = append(args, stop) - return RedisCall(c.cluster, respString(args), callback) -} - -func (c *RedisClusterClient[C]) ZRevRange(key string, start, stop int, callback RedisResponseCallback) error { - if err := c.checkReadyFunc(); err != nil { - return err - } - args := make([]interface{}, 0) - args = append(args, "zrevrange") - args = append(args, key) - args = append(args, start) - args = append(args, stop) - return RedisCall(c.cluster, respString(args), callback) -} diff --git a/plugins/wasm-go/pkg/wrapper/request_wrapper.go b/plugins/wasm-go/pkg/wrapper/request_wrapper.go deleted file mode 100644 index 896f1bb8f..000000000 --- a/plugins/wasm-go/pkg/wrapper/request_wrapper.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (c) 2022 Alibaba Group Holding Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wrapper - -import ( - "net/url" - "strconv" - "strings" - - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" -) - -func GetRequestScheme() string { - scheme, err := proxywasm.GetHttpRequestHeader(":scheme") - if err != nil { - proxywasm.LogErrorf("get request scheme failed: %v", err) - return "" - } - return scheme -} - -func GetRequestHost() string { - host, err := proxywasm.GetHttpRequestHeader(":authority") - if err != nil { - proxywasm.LogErrorf("get request host failed: %v", err) - return "" - } - return host -} - -func GetRequestPath() string { - path, err := proxywasm.GetHttpRequestHeader(":path") - if err != nil { - proxywasm.LogErrorf("get request path failed: %v", err) - return "" - } - return path -} - -func GetRequestPathWithoutQuery() string { - rawPath := GetRequestPath() - if rawPath == "" { - return "" - } - path, err := url.Parse(rawPath) - if err != nil { - proxywasm.LogErrorf("failed to parse request path '%s': %v", rawPath, err) - return "" - } - return path.Path -} - -func GetRequestMethod() string { - method, err := proxywasm.GetHttpRequestHeader(":method") - if err != nil { - proxywasm.LogErrorf("get request method failed: %v", err) - return "" - } - return method -} - -func IsBinaryRequestBody() bool { - contentType, _ := proxywasm.GetHttpRequestHeader("content-type") - if strings.Contains(contentType, "octet-stream") || - strings.Contains(contentType, "grpc") { - return true - } - encoding, _ := proxywasm.GetHttpRequestHeader("content-encoding") - if encoding != "" { - return true - } - return false -} - -func IsBinaryResponseBody() bool { - contentType, _ := proxywasm.GetHttpResponseHeader("content-type") - if strings.Contains(contentType, "octet-stream") || - strings.Contains(contentType, "grpc") { - return true - } - encoding, _ := proxywasm.GetHttpResponseHeader("content-encoding") - if encoding != "" { - return true - } - return false -} - -func HasRequestBody() bool { - contentTypeStr, _ := proxywasm.GetHttpRequestHeader("content-type") - contentLengthStr, _ := proxywasm.GetHttpRequestHeader("content-length") - transferEncodingStr, _ := proxywasm.GetHttpRequestHeader("transfer-encoding") - proxywasm.LogDebugf("check has request body: contentType:%s, contentLengthStr:%s, transferEncodingStr:%s", - contentTypeStr, contentLengthStr, transferEncodingStr) - if contentTypeStr != "" { - return true - } - if contentLengthStr != "" { - contentLength, err := strconv.Atoi(contentLengthStr) - if err == nil && contentLength > 0 { - return true - } - } - return strings.Contains(transferEncodingStr, "chunked") -} diff --git a/plugins/wasm-go/pkg/wrapper/utils.go b/plugins/wasm-go/pkg/wrapper/utils.go deleted file mode 100644 index a34d1d619..000000000 --- a/plugins/wasm-go/pkg/wrapper/utils.go +++ /dev/null @@ -1,36 +0,0 @@ -package wrapper - -import ( - "encoding/json" - - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" - "github.com/tidwall/gjson" -) - -func UnmarshalStr(marshalledJsonStr string) string { - // e.g. "{\"field1\":\"value1\",\"field2\":\"value2\"}" - var jsonStr string - err := json.Unmarshal([]byte(marshalledJsonStr), &jsonStr) - if err != nil { - proxywasm.LogErrorf("failed to unmarshal json string, raw string is: %s, err is: %v", marshalledJsonStr, err) - return "" - } - // e.g. {"field1":"value1","field2":"value2"} - return jsonStr -} - -func MarshalStr(raw string) string { - // e.g. {"field1":"value1","field2":"value2"} - helper := map[string]string{ - "placeholder": raw, - } - marshalledHelper, _ := json.Marshal(helper) - marshalledRaw := gjson.GetBytes(marshalledHelper, "placeholder").Raw - if len(marshalledRaw) >= 2 { - // e.g. {\"field1\":\"value1\",\"field2\":\"value2\"} - return marshalledRaw[1 : len(marshalledRaw)-1] - } else { - proxywasm.LogErrorf("failed to marshal json string, raw string is: %s", raw) - return "" - } -}