refactor: migrate MCP SDK to main repo (#3516)

This commit is contained in:
澄潭
2026-02-16 23:39:18 +08:00
committed by GitHub
parent 87c6cc9c9f
commit 9346f1340b
75 changed files with 10117 additions and 3392 deletions

View File

@@ -1,10 +1,14 @@
module jsonrpc-converter module jsonrpc-converter
go 1.24.3 go 1.24.1
replace github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp
require ( require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0
github.com/higress-group/wasm-go v1.0.4 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
) )
@@ -15,6 +19,7 @@ require (
github.com/Masterminds/sprig/v3 v3.3.0 // indirect github.com/Masterminds/sprig/v3 v3.3.0 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect github.com/buger/jsonparser v1.1.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b // indirect github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b // indirect
github.com/huandu/xstrings v1.5.0 // indirect github.com/huandu/xstrings v1.5.0 // indirect
@@ -22,8 +27,10 @@ require (
github.com/mailru/easyjson v0.7.7 // indirect github.com/mailru/easyjson v0.7.7 // indirect
github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect
github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/shopspring/decimal v1.4.0 // indirect github.com/shopspring/decimal v1.4.0 // indirect
github.com/spf13/cast v1.7.0 // indirect github.com/spf13/cast v1.7.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect github.com/tidwall/resp v0.1.1 // indirect

View File

@@ -20,10 +20,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b h1:rRI9+ThQbe+nw4jUiYEyOFaREkXCMMW9k1X2gy2d6pE= github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b h1:rRI9+ThQbe+nw4jUiYEyOFaREkXCMMW9k1X2gy2d6pE=
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y= github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.4 h1:/GqbzCw4oWqJc8UbKEfF94E3/+4CPZGbzxpKo2L3Ldk= github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 h1:sUuUXZwr50l3W1St7MESlFmxmUAu+QUNNfJXx4P6bas=
github.com/higress-group/wasm-go v1.0.4/go.mod h1:B8C6+OlpnyYyZUBEdUXA7tYZYD+uwZTNjfkE5FywA+A= github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
@@ -49,6 +49,8 @@ github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=

View File

@@ -9,8 +9,8 @@ import (
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/mcp" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
"github.com/higress-group/wasm-go/pkg/mcp/utils" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
"github.com/higress-group/wasm-go/pkg/wrapper" "github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )

View File

@@ -1,9 +1,15 @@
package main package main
import ( import (
"encoding/json"
"testing" "testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
) )
// TestTruncateString tests the truncateString function
func TestTruncateString(t *testing.T) { func TestTruncateString(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -14,6 +20,8 @@ func TestTruncateString(t *testing.T) {
{"Short String", "Higress Is an AI-Native API Gateway", 1000, "Higress Is an AI-Native API Gateway"}, {"Short String", "Higress Is an AI-Native API Gateway", 1000, "Higress Is an AI-Native API Gateway"},
{"Exact Length", "Higress Is an AI-Native API Gateway", 35, "Higress Is an AI-Native API Gateway"}, {"Exact Length", "Higress Is an AI-Native API Gateway", 35, "Higress Is an AI-Native API Gateway"},
{"Truncated String", "Higress Is an AI-Native API Gateway", 20, "Higress Is...(truncated)...PI Gateway"}, {"Truncated String", "Higress Is an AI-Native API Gateway", 20, "Higress Is...(truncated)...PI Gateway"},
{"Empty String", "", 10, ""},
{"Single Char", "A", 10, "A"},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -26,3 +34,248 @@ func TestTruncateString(t *testing.T) {
}) })
} }
} }
// TestIsPreRequestStage tests the isPreRequestStage function
func TestIsPreRequestStage(t *testing.T) {
config := McpConverterConfig{Stage: ProcessRequest}
require.True(t, isPreRequestStage(config))
config = McpConverterConfig{Stage: ProcessResponse}
require.False(t, isPreRequestStage(config))
}
// TestIsPreResponseStage tests the isPreResponseStage function
func TestIsPreResponseStage(t *testing.T) {
config := McpConverterConfig{Stage: ProcessResponse}
require.True(t, isPreResponseStage(config))
config = McpConverterConfig{Stage: ProcessRequest}
require.False(t, isPreResponseStage(config))
}
// TestIsMethodAllowed tests the isMethodAllowed function
func TestIsMethodAllowed(t *testing.T) {
config := McpConverterConfig{AllowedMethods: []string{MethodToolList, MethodToolCall}}
require.True(t, isMethodAllowed(config, MethodToolList))
require.True(t, isMethodAllowed(config, MethodToolCall))
require.False(t, isMethodAllowed(config, "invalid/method"))
}
// TestConstants tests the constant values
func TestConstants(t *testing.T) {
require.Equal(t, "x-envoy-jsonrpc-id", JsonRpcId)
require.Equal(t, "x-envoy-jsonrpc-method", JsonRpcMethod)
require.Equal(t, "x-envoy-jsonrpc-params", JsonRpcParams)
require.Equal(t, "x-envoy-jsonrpc-result", JsonRpcResult)
require.Equal(t, "x-envoy-jsonrpc-error", JsonRpcError)
require.Equal(t, "x-envoy-mcp-tool-name", McpToolName)
require.Equal(t, "x-envoy-mcp-tool-arguments", McpToolArguments)
require.Equal(t, "x-envoy-mcp-tool-response", McpToolResponse)
require.Equal(t, "x-envoy-mcp-tool-error", McpToolError)
require.Equal(t, 4000, DefaultMaxHeaderLength)
require.Equal(t, "tools/list", MethodToolList)
require.Equal(t, "tools/call", MethodToolCall)
require.Equal(t, ProcessStage("request"), ProcessRequest)
require.Equal(t, ProcessStage("response"), ProcessResponse)
}
// TestMcpConverterConfigDefaults tests config default values
func TestMcpConverterConfigDefaults(t *testing.T) {
config := McpConverterConfig{}
require.Equal(t, 0, config.MaxHeaderLength)
require.Equal(t, ProcessStage(""), config.Stage)
require.Nil(t, config.AllowedMethods)
}
// TestProcessStage tests ProcessStage type
func TestProcessStage(t *testing.T) {
require.Equal(t, ProcessStage("request"), ProcessRequest)
require.Equal(t, ProcessStage("response"), ProcessResponse)
}
// TestRemoveJsonRpcHeadersFunction tests removeJsonRpcHeaders function logic
func TestRemoveJsonRpcHeadersFunction(t *testing.T) {
headersToRemove := []string{
JsonRpcId,
JsonRpcMethod,
JsonRpcParams,
JsonRpcResult,
McpToolName,
McpToolArguments,
McpToolResponse,
McpToolError,
}
require.Len(t, headersToRemove, 8)
}
// TestTruncateStringLong tests truncation of very long strings
func TestTruncateStringLong(t *testing.T) {
longString := ""
for i := 0; i < 5000; i++ {
longString += "a"
}
config := McpConverterConfig{MaxHeaderLength: 1000}
result := truncateString(longString, config)
require.Contains(t, result, "...(truncated)...")
require.LessOrEqual(t, len(result), 1020)
}
// TestTruncateStringWithSmallMaxLength tests truncation with small max length
func TestTruncateStringWithSmallMaxLength(t *testing.T) {
config := McpConverterConfig{MaxHeaderLength: 10}
result := truncateString("This is a very long string", config)
require.Contains(t, result, "...(truncated)...")
}
// TestPluginInit tests plugin initialization
func TestPluginInit(t *testing.T) {
configBytes, _ := json.Marshal(McpConverterConfig{
Stage: ProcessRequest,
MaxHeaderLength: DefaultMaxHeaderLength,
AllowedMethods: []string{MethodToolList, MethodToolCall},
})
host, status := test.NewTestHost(configBytes)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
}
// TestProcessJsonRpcRequest tests processJsonRpcRequest function
func TestProcessJsonRpcRequest(t *testing.T) {
configBytes, _ := json.Marshal(McpConverterConfig{
Stage: ProcessRequest,
MaxHeaderLength: DefaultMaxHeaderLength,
AllowedMethods: []string{MethodToolList, MethodToolCall},
})
host, status := test.NewTestHost(configBytes)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "mcp-server.example.com"},
{":method", "POST"},
{":path", "/mcp"},
{"content-type", "application/json"},
})
toolsListRequest := `{
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list",
"params": {}
}`
action := host.CallOnHttpRequestBody([]byte(toolsListRequest))
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
}
// TestProcessToolCallRequest tests processToolCallRequest function
func TestProcessToolCallRequest(t *testing.T) {
configBytes, _ := json.Marshal(McpConverterConfig{
Stage: ProcessRequest,
MaxHeaderLength: DefaultMaxHeaderLength,
AllowedMethods: []string{MethodToolCall},
})
host, status := test.NewTestHost(configBytes)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "mcp-server.example.com"},
{":method", "POST"},
{":path", "/mcp"},
{"content-type", "application/json"},
})
toolCallRequest := `{
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": {
"name": "test_tool",
"arguments": {"arg1": "value1"}
}
}`
action := host.CallOnHttpRequestBody([]byte(toolCallRequest))
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
}
// TestProcessJsonRpcResponse tests processJsonRpcResponse function
func TestProcessJsonRpcResponse(t *testing.T) {
configBytes, _ := json.Marshal(McpConverterConfig{
Stage: ProcessResponse,
MaxHeaderLength: DefaultMaxHeaderLength,
AllowedMethods: []string{MethodToolList, MethodToolCall},
})
host, status := test.NewTestHost(configBytes)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "mcp-server.example.com"},
{":method", "POST"},
{":path", "/mcp"},
{"content-type", "application/json"},
})
responseBody := `{
"jsonrpc": "2.0",
"id": 1,
"result": {
"tools": [{"name": "test_tool"}]
}
}`
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "application/json"},
})
host.CallOnHttpResponseBody([]byte(responseBody))
host.CompleteHttp()
}
// TestProcessToolListResponse tests processToolListResponse function
func TestProcessToolListResponse(t *testing.T) {
configBytes, _ := json.Marshal(McpConverterConfig{
Stage: ProcessResponse,
MaxHeaderLength: DefaultMaxHeaderLength,
AllowedMethods: []string{MethodToolList},
})
host, status := test.NewTestHost(configBytes)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "mcp-server.example.com"},
{":method", "POST"},
{":path", "/mcp"},
{"content-type", "application/json"},
})
responseBody := `{
"jsonrpc": "2.0",
"id": 1,
"result": {
"tools": [{"name": "test_tool"}]
}
}`
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "application/json"},
})
host.CallOnHttpResponseBody([]byte(responseBody))
host.CompleteHttp()
}

View File

@@ -2,9 +2,12 @@ module mcp-router
go 1.24.1 go 1.24.1
replace github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp
require ( require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0
github.com/higress-group/wasm-go v1.0.2-0.20250911113549-cbf1cfcce774 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5 github.com/tidwall/sjson v1.2.5
) )

View File

@@ -20,12 +20,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b h1:rRI9+ThQbe+nw4jUiYEyOFaREkXCMMW9k1X2gy2d6pE= github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b h1:rRI9+ThQbe+nw4jUiYEyOFaREkXCMMW9k1X2gy2d6pE=
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y= github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250807064511-eb1cd98e1f57 h1:WhNdnKSDtAQrh4Yil8HAtbl7VW+WC85m7WS8kirnHAA= github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 h1:sUuUXZwr50l3W1St7MESlFmxmUAu+QUNNfJXx4P6bas=
github.com/higress-group/wasm-go v1.0.2-0.20250807064511-eb1cd98e1f57/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M= github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/higress-group/wasm-go v1.0.2-0.20250911113549-cbf1cfcce774 h1:2wlbNpFJCQNbPBFYgswz7Zvxo9O3L0PH0AJxwiCc5lk=
github.com/higress-group/wasm-go v1.0.2-0.20250911113549-cbf1cfcce774/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=

View File

@@ -22,8 +22,8 @@ import (
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/mcp" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
"github.com/higress-group/wasm-go/pkg/mcp/consts" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/consts"
"github.com/higress-group/wasm-go/pkg/wrapper" "github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"

View File

@@ -1,13 +1,16 @@
module all-in-one module mcp-server
go 1.24.1 go 1.24.1
replace quark-search => ../quark-search replace (
amap-tools => ../../mcp-servers/amap-tools
replace amap-tools => ../amap-tools github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp
quark-search => ../../mcp-servers/quark-search
)
require ( require (
amap-tools v0.0.0-00010101000000-000000000000 amap-tools v0.0.0-00010101000000-000000000000
github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0

View File

@@ -22,10 +22,6 @@ github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b h1:rR
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y= github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.9-0.20251223122142-eae11e33a500 h1:4BKKZ3BreIaIGub88nlvzihTK1uJmZYYoQ7r7Xkgb5Q=
github.com/higress-group/wasm-go v1.0.9-0.20251223122142-eae11e33a500/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/higress-group/wasm-go v1.0.10-0.20260115083526-76699a1df2c1 h1:+usoX0B1cwECTA2qf73IaLGyCIMVopIMev5cBWGgEZk=
github.com/higress-group/wasm-go v1.0.10-0.20260115083526-76699a1df2c1/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 h1:sUuUXZwr50l3W1St7MESlFmxmUAu+QUNNfJXx4P6bas= github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 h1:sUuUXZwr50l3W1St7MESlFmxmUAu+QUNNfJXx4P6bas=
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8= github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=

View File

@@ -18,7 +18,7 @@ import (
amap "amap-tools/tools" amap "amap-tools/tools"
quark "quark-search/tools" quark "quark-search/tools"
"github.com/higress-group/wasm-go/pkg/mcp" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
) )
func main() {} func main() {}

View File

@@ -1,14 +0,0 @@
# Use a minimal base image as we only need to store the wasm file.
FROM scratch
# Add build argument for the filter name. This will be passed by the Makefile.
ARG FILTER_NAME
# Copy the compiled WASM binary into the image's root directory.
# The wasm file will be named after the filter.
COPY ${FILTER_NAME}/main.wasm /plugin.wasm
# Metadata
LABEL org.opencontainers.image.title="${FILTER_NAME}"
LABEL org.opencontainers.image.description="Higress MCP filter - ${FILTER_NAME}"
LABEL org.opencontainers.image.source="https://github.com/alibaba/higress"

View File

@@ -1,54 +0,0 @@
# MCP Filter Makefile
# Variables
FILTER_NAME ?= mcp-router
REGISTRY ?= higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/
BUILD_TIME := $(shell date "+%Y%m%d-%H%M%S")
COMMIT_ID := $(shell git rev-parse --short HEAD 2>/dev/null)
IMAGE_TAG = $(if $(strip $(FILTER_VERSION)),${FILTER_VERSION},${BUILD_TIME}-${COMMIT_ID})
IMG ?= ${REGISTRY}${FILTER_NAME}:${IMAGE_TAG}
# Default target
.DEFAULT: build
build:
@echo "Building WASM binary for filter: ${FILTER_NAME}..."
@if [ ! -d "${FILTER_NAME}" ]; then \
echo "Error: Filter directory '${FILTER_NAME}' not found."; \
exit 1; \
fi
cd ${FILTER_NAME} && GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o main.wasm main.go
@echo ""
@echo "Output WASM file: ${FILTER_NAME}/main.wasm"
# Build Docker image (depends on build target to ensure WASM binary exists)
build-image: build
@echo "Building Docker image for ${FILTER_NAME}..."
docker build -t ${IMG} \
--build-arg FILTER_NAME=${FILTER_NAME} \
-f Dockerfile .
@echo ""
@echo "Image: ${IMG}"
# Build and push Docker image
build-push: build-image
docker push ${IMG}
# Clean build artifacts
clean:
@echo "Cleaning build artifacts for filter: ${FILTER_NAME}..."
rm -f ${FILTER_NAME}/main.wasm
# Help
help:
@echo "Available targets:"
@echo " build - Build WASM binary for a specific filter"
@echo " build-image - Build Docker image"
@echo " build-push - Build and push Docker image"
@echo " clean - Remove build artifacts for a specific filter"
@echo ""
@echo "Variables:"
@echo " FILTER_NAME - Name of the MCP filter to build (default: ${FILTER_NAME})"
@echo " REGISTRY - Docker registry (default: ${REGISTRY})"
@echo " FILTER_VERSION - Version tag for the image (default: timestamp-commit)"
@echo " IMG - Full image name (default: ${IMG})"

View File

@@ -80,8 +80,8 @@ import (
"net/http" "net/http"
"my-mcp-server/config" "my-mcp-server/config"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
"github.com/higress-group/wasm-go/pkg/mcp/utils" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
) )
// Define your tool structure with input parameters // Define your tool structure with input parameters
@@ -145,8 +145,8 @@ For better organization, you can create a separate file to load all your tools:
package tools package tools
import ( import (
"github.com/higress-group/wasm-go/pkg/mcp" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
) )
func LoadTools(server *mcp.MCPServer) server.Server { func LoadTools(server *mcp.MCPServer) server.Server {
@@ -170,7 +170,7 @@ import (
amap "amap-tools/tools" amap "amap-tools/tools"
quark "quark-search/tools" quark "quark-search/tools"
"github.com/higress-group/wasm-go/pkg/mcp" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
) )
func main() {} func main() {}
@@ -375,7 +375,7 @@ package main
import ( import (
"my-mcp-server/tools" "my-mcp-server/tools"
"github.com/higress-group/wasm-go/pkg/mcp" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
) )
func main() {} func main() {}

View File

@@ -2,9 +2,12 @@ module amap-tools
go 1.24.1 go 1.24.1
replace github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp
require ( require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0
github.com/higress-group/wasm-go v1.0.0 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9
) )
require ( require (
@@ -23,6 +26,7 @@ require (
github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect
github.com/shopspring/decimal v1.4.0 // indirect github.com/shopspring/decimal v1.4.0 // indirect
github.com/spf13/cast v1.7.0 // indirect github.com/spf13/cast v1.7.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect

View File

@@ -17,7 +17,7 @@ package main
import ( import (
"amap-tools/tools" "amap-tools/tools"
"github.com/higress-group/wasm-go/pkg/mcp" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
) )
func main() {} func main() {}

View File

@@ -15,8 +15,8 @@
package tools package tools
import ( import (
"github.com/higress-group/wasm-go/pkg/mcp" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
) )
func LoadTools(server *mcp.MCPServer) server.Server { func LoadTools(server *mcp.MCPServer) server.Server {

View File

@@ -23,8 +23,8 @@ import (
"amap-tools/config" "amap-tools/config"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
"github.com/higress-group/wasm-go/pkg/mcp/utils" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
) )
var _ server.Tool = AroundSearchRequest{} var _ server.Tool = AroundSearchRequest{}

View File

@@ -23,8 +23,8 @@ import (
"amap-tools/config" "amap-tools/config"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
"github.com/higress-group/wasm-go/pkg/mcp/utils" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
) )
var _ server.Tool = BicyclingRequest{} var _ server.Tool = BicyclingRequest{}

View File

@@ -23,8 +23,8 @@ import (
"amap-tools/config" "amap-tools/config"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
"github.com/higress-group/wasm-go/pkg/mcp/utils" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
) )
var _ server.Tool = DrivingRequest{} var _ server.Tool = DrivingRequest{}

View File

@@ -23,8 +23,8 @@ import (
"amap-tools/config" "amap-tools/config"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
"github.com/higress-group/wasm-go/pkg/mcp/utils" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
) )
var _ server.Tool = TransitIntegratedRequest{} var _ server.Tool = TransitIntegratedRequest{}

View File

@@ -23,8 +23,8 @@ import (
"amap-tools/config" "amap-tools/config"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
"github.com/higress-group/wasm-go/pkg/mcp/utils" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
) )
var _ server.Tool = WalkingRequest{} var _ server.Tool = WalkingRequest{}

View File

@@ -23,8 +23,8 @@ import (
"amap-tools/config" "amap-tools/config"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
"github.com/higress-group/wasm-go/pkg/mcp/utils" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
) )
var _ server.Tool = DistanceRequest{} var _ server.Tool = DistanceRequest{}

View File

@@ -23,8 +23,8 @@ import (
"amap-tools/config" "amap-tools/config"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
"github.com/higress-group/wasm-go/pkg/mcp/utils" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
) )
var _ server.Tool = GeoRequest{} var _ server.Tool = GeoRequest{}

View File

@@ -24,8 +24,8 @@ import (
"amap-tools/config" "amap-tools/config"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
"github.com/higress-group/wasm-go/pkg/mcp/utils" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
) )

View File

@@ -23,8 +23,8 @@ import (
"amap-tools/config" "amap-tools/config"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
"github.com/higress-group/wasm-go/pkg/mcp/utils" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
) )
var _ server.Tool = ReGeocodeRequest{} var _ server.Tool = ReGeocodeRequest{}

View File

@@ -23,8 +23,8 @@ import (
"amap-tools/config" "amap-tools/config"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
"github.com/higress-group/wasm-go/pkg/mcp/utils" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
) )
var _ server.Tool = SearchDetailRequest{} var _ server.Tool = SearchDetailRequest{}

View File

@@ -23,8 +23,8 @@ import (
"amap-tools/config" "amap-tools/config"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
"github.com/higress-group/wasm-go/pkg/mcp/utils" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
) )
var _ server.Tool = TextSearchRequest{} var _ server.Tool = TextSearchRequest{}

View File

@@ -23,8 +23,8 @@ import (
"amap-tools/config" "amap-tools/config"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
"github.com/higress-group/wasm-go/pkg/mcp/utils" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
) )
var _ server.Tool = WeatherRequest{} var _ server.Tool = WeatherRequest{}

View File

@@ -2,8 +2,11 @@ module quark-search
go 1.24.1 go 1.24.1
replace github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp
require ( require (
github.com/higress-group/wasm-go v1.0.0 github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
) )
@@ -16,7 +19,7 @@ require (
github.com/buger/jsonparser v1.1.1 // indirect github.com/buger/jsonparser v1.1.1 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b // indirect github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b // indirect
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 // indirect github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 // indirect
github.com/huandu/xstrings v1.5.0 // indirect github.com/huandu/xstrings v1.5.0 // indirect
github.com/invopop/jsonschema v0.13.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect github.com/mailru/easyjson v0.7.7 // indirect
@@ -24,6 +27,7 @@ require (
github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect
github.com/shopspring/decimal v1.4.0 // indirect github.com/shopspring/decimal v1.4.0 // indirect
github.com/spf13/cast v1.7.0 // indirect github.com/spf13/cast v1.7.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect github.com/tidwall/resp v0.1.1 // indirect

View File

@@ -17,7 +17,7 @@ package main
import ( import (
"quark-search/tools" "quark-search/tools"
"github.com/higress-group/wasm-go/pkg/mcp" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
) )
func main() {} func main() {}

View File

@@ -15,8 +15,8 @@
package tools package tools
import ( import (
"github.com/higress-group/wasm-go/pkg/mcp" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
) )
func LoadTools(server *mcp.MCPServer) server.Server { func LoadTools(server *mcp.MCPServer) server.Server {

View File

@@ -24,8 +24,8 @@ import (
"quark-search/config" "quark-search/config"
"github.com/higress-group/wasm-go/pkg/mcp/server" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
"github.com/higress-group/wasm-go/pkg/mcp/utils" "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )

View File

@@ -1,85 +0,0 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package log
type Log interface {
Trace(msg string)
Tracef(format string, args ...interface{})
Debug(msg string)
Debugf(format string, args ...interface{})
Info(msg string)
Infof(format string, args ...interface{})
Warn(msg string)
Warnf(format string, args ...interface{})
Error(msg string)
Errorf(format string, args ...interface{})
Critical(msg string)
Criticalf(format string, args ...interface{})
ResetID(pluginID string)
}
var pluginLog Log
func SetPluginLog(log Log) {
pluginLog = log
}
func Trace(msg string) {
pluginLog.Trace(msg)
}
func Tracef(format string, args ...interface{}) {
pluginLog.Tracef(format, args...)
}
func Debug(msg string) {
pluginLog.Debug(msg)
}
func Debugf(format string, args ...interface{}) {
pluginLog.Debugf(format, args...)
}
func Info(msg string) {
pluginLog.Info(msg)
}
func Infof(format string, args ...interface{}) {
pluginLog.Infof(format, args...)
}
func Warn(msg string) {
pluginLog.Warn(msg)
}
func Warnf(format string, args ...interface{}) {
pluginLog.Warnf(format, args...)
}
func Error(msg string) {
pluginLog.Error(msg)
}
func Errorf(format string, args ...interface{}) {
pluginLog.Errorf(format, args...)
}
func Critical(msg string) {
pluginLog.Critical(msg)
}
func Criticalf(format string, args ...interface{}) {
pluginLog.Criticalf(format, args...)
}

View File

@@ -1,300 +0,0 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package matcher
import (
"errors"
"fmt"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
)
type Category int
const (
Route Category = iota
Host
Service
RoutePrefix
)
type MatchType int
const (
Prefix MatchType = iota
Exact
Suffix
)
const (
RULES_KEY = "_rules_"
MATCH_ROUTE_KEY = "_match_route_"
MATCH_DOMAIN_KEY = "_match_domain_"
MATCH_SERVICE_KEY = "_match_service_"
MATCH_ROUTE_PREFIX_KEY = "_match_route_prefix_"
)
type HostMatcher struct {
matchType MatchType
host string
}
type RuleConfig[PluginConfig any] struct {
category Category
routes map[string]struct{}
services map[string]struct{}
routePrefixs map[string]struct{}
hosts []HostMatcher
config PluginConfig
}
type RuleMatcher[PluginConfig any] struct {
ruleConfig []RuleConfig[PluginConfig]
globalConfig PluginConfig
hasGlobalConfig bool
}
func (m RuleMatcher[PluginConfig]) GetMatchConfig() (*PluginConfig, error) {
host, err := proxywasm.GetHttpRequestHeader(":authority")
if err != nil {
return nil, err
}
routeName, err := proxywasm.GetProperty([]string{"route_name"})
if err != nil && err != types.ErrorStatusNotFound {
return nil, err
}
serviceName, err := proxywasm.GetProperty([]string{"cluster_name"})
if err != nil && err != types.ErrorStatusNotFound {
return nil, err
}
for _, rule := range m.ruleConfig {
// category == Host
if rule.category == Host {
if m.hostMatch(rule, host) {
return &rule.config, nil
}
}
// category == Route
if rule.category == Route {
if _, ok := rule.routes[string(routeName)]; ok {
return &rule.config, nil
}
}
// category == RoutePrefix
if rule.category == RoutePrefix {
for routePrefix := range rule.routePrefixs {
if strings.HasPrefix(string(routeName), routePrefix) {
return &rule.config, nil
}
}
}
// category == Cluster
if m.serviceMatch(rule, string(serviceName)) {
return &rule.config, nil
}
}
if m.hasGlobalConfig {
return &m.globalConfig, nil
}
return nil, nil
}
func (m *RuleMatcher[PluginConfig]) ParseRuleConfig(config gjson.Result,
parsePluginConfig func(gjson.Result, *PluginConfig) error,
parseOverrideConfig func(gjson.Result, PluginConfig, *PluginConfig) error) error {
var rules []gjson.Result
obj := config.Map()
keyCount := len(obj)
if keyCount == 0 {
// enable globally for empty config
m.hasGlobalConfig = true
return parsePluginConfig(config, &m.globalConfig)
}
if rulesJson, ok := obj[RULES_KEY]; ok {
rules = rulesJson.Array()
keyCount--
}
var pluginConfig PluginConfig
var globalConfigError error
if keyCount > 0 {
err := parsePluginConfig(config, &pluginConfig)
if err != nil {
globalConfigError = err
} else {
m.globalConfig = pluginConfig
m.hasGlobalConfig = true
}
}
if len(rules) == 0 {
if m.hasGlobalConfig {
return nil
}
return fmt.Errorf("parse config failed, no valid rules; global config parse error:%v", globalConfigError)
}
for _, ruleJson := range rules {
var (
rule RuleConfig[PluginConfig]
err error
)
if parseOverrideConfig != nil {
err = parseOverrideConfig(ruleJson, m.globalConfig, &rule.config)
} else {
err = parsePluginConfig(ruleJson, &rule.config)
}
if err != nil {
return err
}
rule.routes = m.parseRouteMatchConfig(ruleJson)
rule.hosts = m.parseHostMatchConfig(ruleJson)
rule.services = m.parseServiceMatchConfig(ruleJson)
rule.routePrefixs = m.parseRoutePrefixMatchConfig(ruleJson)
noRoute := len(rule.routes) == 0
noHosts := len(rule.hosts) == 0
noService := len(rule.services) == 0
noRoutePrefix := len(rule.routePrefixs) == 0
if boolToInt(noRoute)+boolToInt(noService)+boolToInt(noHosts)+boolToInt(noRoutePrefix) != 3 {
return errors.New("there is only one of '_match_route_', '_match_domain_', '_match_service_' and '_match_route_prefix_' can present in configuration.")
}
if !noRoute {
rule.category = Route
} else if !noHosts {
rule.category = Host
} else if !noService {
rule.category = Service
} else {
rule.category = RoutePrefix
}
m.ruleConfig = append(m.ruleConfig, rule)
}
return nil
}
func (m RuleMatcher[PluginConfig]) parseRouteMatchConfig(config gjson.Result) map[string]struct{} {
keys := config.Get(MATCH_ROUTE_KEY).Array()
routes := make(map[string]struct{})
for _, item := range keys {
routeName := item.String()
if routeName != "" {
routes[routeName] = struct{}{}
}
}
return routes
}
func (m RuleMatcher[PluginConfig]) parseRoutePrefixMatchConfig(config gjson.Result) map[string]struct{} {
keys := config.Get(MATCH_ROUTE_PREFIX_KEY).Array()
routePrefixs := make(map[string]struct{})
for _, item := range keys {
routePrefix := item.String()
if routePrefix != "" {
routePrefixs[routePrefix] = struct{}{}
}
}
return routePrefixs
}
func (m RuleMatcher[PluginConfig]) parseServiceMatchConfig(config gjson.Result) map[string]struct{} {
keys := config.Get(MATCH_SERVICE_KEY).Array()
clusters := make(map[string]struct{})
for _, item := range keys {
clusterName := item.String()
if clusterName != "" {
clusters[clusterName] = struct{}{}
}
}
return clusters
}
func (m RuleMatcher[PluginConfig]) parseHostMatchConfig(config gjson.Result) []HostMatcher {
keys := config.Get(MATCH_DOMAIN_KEY).Array()
var hostMatchers []HostMatcher
for _, item := range keys {
host := item.String()
var hostMatcher HostMatcher
if strings.HasPrefix(host, "*") {
hostMatcher.matchType = Suffix
hostMatcher.host = host[1:]
} else if strings.HasSuffix(host, "*") {
hostMatcher.matchType = Prefix
hostMatcher.host = host[:len(host)-1]
} else {
hostMatcher.matchType = Exact
hostMatcher.host = host
}
hostMatchers = append(hostMatchers, hostMatcher)
}
return hostMatchers
}
func stripPortFromHost(reqHost string) string {
// Port removing code is inspired by
// https://github.com/envoyproxy/envoy/blob/v1.17.0/source/common/http/header_utility.cc#L219
portStart := strings.LastIndexByte(reqHost, ':')
if portStart != -1 {
// According to RFC3986 v6 address is always enclosed in "[]".
// section 3.2.2.
v6EndIndex := strings.LastIndexByte(reqHost, ']')
if v6EndIndex == -1 || v6EndIndex < portStart {
if portStart+1 <= len(reqHost) {
return reqHost[:portStart]
}
}
}
return reqHost
}
func (m RuleMatcher[PluginConfig]) hostMatch(rule RuleConfig[PluginConfig], reqHost string) bool {
reqHost = stripPortFromHost(reqHost)
for _, hostMatch := range rule.hosts {
switch hostMatch.matchType {
case Suffix:
if strings.HasSuffix(reqHost, hostMatch.host) {
return true
}
case Prefix:
if strings.HasPrefix(reqHost, hostMatch.host) {
return true
}
case Exact:
if reqHost == hostMatch.host {
return true
}
default:
return false
}
}
return false
}
func (m RuleMatcher[PluginConfig]) serviceMatch(rule RuleConfig[PluginConfig], serviceName string) bool {
parts := strings.Split(serviceName, "|")
if len(parts) != 4 {
return false
}
port := parts[1]
fqdn := parts[3]
for configServiceName := range rule.services {
colonIndex := strings.LastIndexByte(configServiceName, ':')
if colonIndex != -1 && fqdn == string(configServiceName[:colonIndex]) && port == string(configServiceName[colonIndex+1:]) {
return true
} else if fqdn == string(configServiceName) {
return true
}
}
return false
}

View File

@@ -1,438 +0,0 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package matcher
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
)
type customConfig struct {
name string
age int64
}
func parseConfig(json gjson.Result, config *customConfig) error {
config.name = json.Get("name").String()
config.age = json.Get("age").Int()
return nil
}
func TestHostMatch(t *testing.T) {
cases := []struct {
name string
config RuleConfig[customConfig]
host string
result bool
}{
{
name: "prefix",
config: RuleConfig[customConfig]{
hosts: []HostMatcher{
{
matchType: Prefix,
host: "www.",
},
},
},
host: "www.test.com",
result: true,
},
{
name: "prefix failed",
config: RuleConfig[customConfig]{
hosts: []HostMatcher{
{
matchType: Prefix,
host: "www.",
},
},
},
host: "test.com",
result: false,
},
{
name: "suffix",
config: RuleConfig[customConfig]{
hosts: []HostMatcher{
{
matchType: Suffix,
host: ".example.com",
},
},
},
host: "www.example.com",
result: true,
},
{
name: "suffix failed",
config: RuleConfig[customConfig]{
hosts: []HostMatcher{
{
matchType: Suffix,
host: ".example.com",
},
},
},
host: "example.com",
result: false,
},
{
name: "exact",
config: RuleConfig[customConfig]{
hosts: []HostMatcher{
{
matchType: Exact,
host: "www.example.com",
},
},
},
host: "www.example.com",
result: true,
},
{
name: "exact failed",
config: RuleConfig[customConfig]{
hosts: []HostMatcher{
{
matchType: Exact,
host: "www.example.com",
},
},
},
host: "example.com",
result: false,
},
{
name: "exact port",
config: RuleConfig[customConfig]{
hosts: []HostMatcher{
{
matchType: Exact,
host: "www.example.com",
},
},
},
host: "www.example.com:8080",
result: true,
},
{
name: "any",
config: RuleConfig[customConfig]{
hosts: []HostMatcher{
{
matchType: Suffix,
host: "",
},
},
},
host: "www.example.com",
result: true,
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
var m RuleMatcher[customConfig]
assert.Equal(t, c.result, m.hostMatch(c.config, c.host))
})
}
}
func TestServiceMatch(t *testing.T) {
cases := []struct {
name string
config RuleConfig[customConfig]
service string
result bool
}{
{
name: "fqdn",
config: RuleConfig[customConfig]{
services: map[string]struct{}{
"qwen.dns": {},
},
},
service: "outbound|443||qwen.dns",
result: true,
},
{
name: "fqdn with port",
config: RuleConfig[customConfig]{
services: map[string]struct{}{
"qwen.dns:443": {},
},
},
service: "outbound|443||qwen.dns",
result: true,
},
{
name: "not match",
config: RuleConfig[customConfig]{
services: map[string]struct{}{
"moonshot.dns:443": {},
},
},
service: "outbound|443||qwen.dns",
result: false,
},
{
name: "error config format",
config: RuleConfig[customConfig]{
services: map[string]struct{}{
"qwen.dns:": {},
},
},
service: "outbound|443||qwen.dns",
result: false,
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
var m RuleMatcher[customConfig]
assert.Equal(t, c.result, m.serviceMatch(c.config, c.service))
})
}
}
func TestParseRuleConfig(t *testing.T) {
cases := []struct {
name string
config string
errMsg string
expected RuleMatcher[customConfig]
}{
{
name: "global config",
config: `{"name":"john", "age":18}`,
expected: RuleMatcher[customConfig]{
globalConfig: customConfig{
name: "john",
age: 18,
},
hasGlobalConfig: true,
},
},
{
name: "rules config",
config: `{"_rules_":[{"_match_domain_":["*.example.com","www.*","*","www.abc.com"],"name":"john", "age":18},{"_match_route_":["test1","test2"],"name":"ann", "age":16},{"_match_service_":["test1.dns","test2.static:8080"],"name":"ann", "age":16},{"_match_route_prefix_":["api1","api2"],"name":"ann", "age":16}]}`,
expected: RuleMatcher[customConfig]{
ruleConfig: []RuleConfig[customConfig]{
{
category: Host,
hosts: []HostMatcher{
{
matchType: Suffix,
host: ".example.com",
},
{
matchType: Prefix,
host: "www.",
},
{
matchType: Suffix,
host: "",
},
{
matchType: Exact,
host: "www.abc.com",
},
},
routes: map[string]struct{}{},
services: map[string]struct{}{},
routePrefixs: map[string]struct{}{},
config: customConfig{
name: "john",
age: 18,
},
},
{
category: Route,
routes: map[string]struct{}{
"test1": {},
"test2": {},
},
services: map[string]struct{}{},
routePrefixs: map[string]struct{}{},
config: customConfig{
name: "ann",
age: 16,
},
},
{
category: Service,
routes: map[string]struct{}{},
services: map[string]struct{}{
"test1.dns": {},
"test2.static:8080": {},
},
routePrefixs: map[string]struct{}{},
config: customConfig{
name: "ann",
age: 16,
},
},
{
category: RoutePrefix,
routes: map[string]struct{}{},
services: map[string]struct{}{},
routePrefixs: map[string]struct{}{
"api1": {},
"api2": {},
},
config: customConfig{
name: "ann",
age: 16,
},
},
},
},
},
{
name: "no rule",
config: `{"_rules_":[]}`,
errMsg: "parse config failed, no valid rules; global config parse error:<nil>",
},
{
name: "invalid rule",
config: `{"_rules_":[{"_match_domain_":["*"],"_match_route_":["test"]}]}`,
errMsg: "there is only one of '_match_route_', '_match_domain_', '_match_service_' and '_match_route_prefix_' can present in configuration.",
},
{
name: "invalid rule",
config: `{"_rules_":[{"_match_domain_":["*"],"_match_service_":["test.dns"]}]}`,
errMsg: "there is only one of '_match_route_', '_match_domain_', '_match_service_' and '_match_route_prefix_' can present in configuration.",
},
{
name: "invalid rule",
config: `{"_rules_":[{"age":16}]}`,
errMsg: "there is only one of '_match_route_', '_match_domain_', '_match_service_' and '_match_route_prefix_' can present in configuration.",
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
var actual RuleMatcher[customConfig]
err := actual.ParseRuleConfig(gjson.Parse(c.config), parseConfig, nil)
if err != nil {
if c.errMsg == "" {
t.Errorf("parse failed: %v", err)
}
if err.Error() != c.errMsg {
t.Errorf("expect err: %s, actual err: %s", c.errMsg,
err.Error())
}
return
}
assert.Equal(t, c.expected, actual)
})
}
}
type completeConfig struct {
// global config
consumers []string
// rule config
allow []string
}
func parseGlobalConfig(json gjson.Result, global *completeConfig) error {
if json.Get("consumers").Exists() && json.Get("allow").Exists() {
return errors.New("consumers and allow should not be configured at the same level")
}
for _, item := range json.Get("consumers").Array() {
global.consumers = append(global.consumers, item.String())
}
return nil
}
func parseOverrideRuleConfig(json gjson.Result, global completeConfig, config *completeConfig) error {
if json.Get("consumers").Exists() && json.Get("allow").Exists() {
return errors.New("consumers and allow should not be configured at the same level")
}
// override config via global
*config = global
for _, item := range json.Get("allow").Array() {
config.allow = append(config.allow, item.String())
}
return nil
}
func TestParseOverrideConfig(t *testing.T) {
cases := []struct {
name string
config string
errMsg string
expected RuleMatcher[completeConfig]
}{
{
name: "override rule config",
config: `{"consumers":["c1","c2","c3"],"_rules_":[{"_match_route_":["r1","r2"],"allow":["c1","c3"]}]}`,
expected: RuleMatcher[completeConfig]{
ruleConfig: []RuleConfig[completeConfig]{
{
category: Route,
routes: map[string]struct{}{
"r1": {},
"r2": {},
},
services: map[string]struct{}{},
routePrefixs: map[string]struct{}{},
config: completeConfig{
consumers: []string{"c1", "c2", "c3"},
allow: []string{"c1", "c3"},
},
},
},
globalConfig: completeConfig{
consumers: []string{"c1", "c2", "c3"},
},
hasGlobalConfig: true,
},
},
{
name: "invalid config",
config: `{"consumers":["c1","c2","c3"],"allow":["c1"]}`,
errMsg: "parse config failed, no valid rules; global config parse error:consumers and allow should not be configured at the same level",
},
{
name: "invalid config",
config: `{"_rules_":[{"_match_route_":["r1","r2"],"consumers":["c1","c2"],"allow":["c1"]}]}`,
errMsg: "consumers and allow should not be configured at the same level",
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
var actual RuleMatcher[completeConfig]
err := actual.ParseRuleConfig(gjson.Parse(c.config), parseGlobalConfig, parseOverrideRuleConfig)
if err != nil {
if c.errMsg == "" {
t.Errorf("parse failed: %v", err)
}
if err.Error() != c.errMsg {
t.Errorf("expect err: %s, actual err: %s", c.errMsg, err.Error())
}
return
}
assert.Equal(t, c.expected, actual)
})
}
}

View File

@@ -1,8 +0,0 @@
package matcher
func boolToInt(b bool) int {
if b {
return 1
}
return 0
}

View File

@@ -12,17 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package wrapper package consts
import ( const (
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" ToolSetNameSplitter = "___"
) )
func IsResponseFromUpstream() bool {
if codeDetails, err := proxywasm.GetProperty([]string{"response", "code_details"}); err == nil {
return string(codeDetails) == "via_upstream"
} else {
proxywasm.LogErrorf("get response code details failed: %v", err)
return false
}
}

View File

@@ -0,0 +1,353 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package filter
import (
"github.com/tidwall/gjson"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
const (
defaultMaxBodyBytes uint32 = 100 * 1024 * 1024
)
type HTTPFilterF func(context wrapper.HttpContext, config any, headers [][2]string, body []byte) types.Action
type ToolCallRequestFilterF func(context wrapper.HttpContext, config any, toolName string, toolArgs gjson.Result, rawBody []byte) types.Action
type ToolCallResponseFilterF func(context wrapper.HttpContext, config any, isError bool, content gjson.Result, rawBody []byte) types.Action
type ToolListResponseFilterF func(context wrapper.HttpContext, config any, tools gjson.Result, rawBody []byte) types.Action
type JsonRpcRequestFilterF func(context wrapper.HttpContext, config any, id utils.JsonRpcID, method string, params gjson.Result, rawBody []byte) types.Action
type JsonRpcResponseFilterF func(context wrapper.HttpContext, config any, id utils.JsonRpcID, result, error gjson.Result, rawBody []byte) types.Action
type Context struct {
filterName string
httpRequestFilter HTTPFilterF
httpResponseFilter HTTPFilterF
jsonRpcRequestFilter JsonRpcRequestFilterF
jsonRpcResponseFilter JsonRpcResponseFilterF
toolCallRequestFilter ToolCallRequestFilterF
toolCallResponseFilter ToolCallResponseFilterF
toolListResponseFilter ToolListResponseFilterF
parseFilterConfig ParseFilterConfigF
parseFilterRuleOverrideConfig ParseFilterRuleOverrideConfigF
}
type CtxOption interface {
Apply(*Context)
}
var globalContext Context
type ParseFilterConfigF func(configBytes []byte, filterConfig *any) error
type ParseFilterRuleOverrideConfigF func(configBytes []byte, filterGlobalConfig any, filterConfig *any) error
type setConfigParserOption struct {
f ParseFilterConfigF
g ParseFilterRuleOverrideConfigF
}
func SetConfigParser(f ParseFilterConfigF) CtxOption {
return &setConfigParserOption{
f: f,
}
}
func SetConfigOverrideParser(f ParseFilterConfigF, g ParseFilterRuleOverrideConfigF) CtxOption {
return &setConfigParserOption{
f: f,
g: g,
}
}
func (o *setConfigParserOption) Apply(ctx *Context) {
ctx.parseFilterConfig = o.f
ctx.parseFilterRuleOverrideConfig = o.g
}
type filterNameOption struct {
name string
}
func FilterName(name string) CtxOption {
return &filterNameOption{name}
}
func (o *filterNameOption) Apply(ctx *Context) {
ctx.filterName = o.name
}
type setJsonRpcRequestFilterOption struct {
f JsonRpcRequestFilterF
}
func SetJsonRpcRequestFilter(f JsonRpcRequestFilterF) CtxOption {
return &setJsonRpcRequestFilterOption{f}
}
func (o *setJsonRpcRequestFilterOption) Apply(ctx *Context) {
ctx.jsonRpcRequestFilter = o.f
}
type setJsonRpcResponseFilterOption struct {
f JsonRpcResponseFilterF
}
func SetJsonRpcResponseFilter(f JsonRpcResponseFilterF) CtxOption {
return &setJsonRpcResponseFilterOption{f}
}
func (o *setJsonRpcResponseFilterOption) Apply(ctx *Context) {
ctx.jsonRpcResponseFilter = o.f
}
type setFallbackHTTPRequestFilterOption struct {
f HTTPFilterF
}
func SetFallbackHTTPRequestFilter(f HTTPFilterF) CtxOption {
return &setFallbackHTTPRequestFilterOption{f}
}
func (o *setFallbackHTTPRequestFilterOption) Apply(ctx *Context) {
ctx.httpRequestFilter = o.f
}
type setFallbackHTTPResponseFilterOption struct {
f HTTPFilterF
}
func SetFallbackHTTPResponseFilter(f HTTPFilterF) CtxOption {
return &setFallbackHTTPResponseFilterOption{f}
}
func (o *setFallbackHTTPResponseFilterOption) Apply(ctx *Context) {
ctx.httpResponseFilter = o.f
}
type toolCallRequestFilterOption struct {
f ToolCallRequestFilterF
}
func SetToolCallRequestFilter(f ToolCallRequestFilterF) CtxOption {
return &toolCallRequestFilterOption{f: f}
}
func (o *toolCallRequestFilterOption) Apply(ctx *Context) {
ctx.toolCallRequestFilter = o.f
}
type toolCallResponseFilterOption struct {
f ToolCallResponseFilterF
}
func SetToolCallResponseFilter(f ToolCallResponseFilterF) CtxOption {
return &toolCallResponseFilterOption{f: f}
}
func (o *toolCallResponseFilterOption) Apply(ctx *Context) {
ctx.toolCallResponseFilter = o.f
}
type toolListResponseFilterOption struct {
f ToolListResponseFilterF
}
func SetToolListResponseFilter(f ToolListResponseFilterF) CtxOption {
return &toolListResponseFilterOption{f: f}
}
func (o *toolListResponseFilterOption) Apply(ctx *Context) {
ctx.toolListResponseFilter = o.f
}
func Load(options ...CtxOption) {
for _, opt := range options {
opt.Apply(&globalContext)
}
}
func Initialize() {
if globalContext.filterName == "" {
panic("FilterName not set")
}
if globalContext.parseFilterConfig == nil {
panic("SetConfigParser not set")
}
var configOption wrapper.CtxOption[mcpFilterConfig]
if globalContext.parseFilterRuleOverrideConfig == nil {
configOption = wrapper.ParseRawConfig(parseRawConfig)
} else {
configOption = wrapper.ParseOverrideRawConfig(parseGlobalConfig, parseOverrideConfig)
}
wrapper.SetCtx(
globalContext.filterName,
configOption,
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
wrapper.ProcessRequestBody(onHttpRequestBody),
wrapper.ProcessResponseBody(onHttpResponseBody),
)
}
type mcpFilterConfig struct {
config any
httpRequestHandler HTTPFilterF
httpResponseHandler HTTPFilterF
jsonRpcRequestHandler utils.JsonRpcRequestHandler
jsonRpcResponseHandler utils.JsonRpcResponseHandler
}
func installHandler(config *mcpFilterConfig) {
config.httpRequestHandler = globalContext.httpRequestFilter
config.httpResponseHandler = globalContext.httpResponseFilter
bizConfig := config.config
if globalContext.jsonRpcRequestFilter != nil || globalContext.toolCallRequestFilter != nil {
config.jsonRpcRequestHandler = func(context wrapper.HttpContext, id utils.JsonRpcID, method string, params gjson.Result, rawBody []byte) types.Action {
if globalContext.jsonRpcRequestFilter != nil {
ret := globalContext.jsonRpcRequestFilter(context, bizConfig, id, method, params, rawBody)
if ret != types.ActionContinue {
return ret
}
}
context.SetContext("JSONRPC_METHOD", method)
if method == "tools/call" && globalContext.toolCallRequestFilter != nil {
toolName := params.Get("name").String()
toolArgs := params.Get("arguments")
return globalContext.toolCallRequestFilter(context, bizConfig, toolName, toolArgs, rawBody)
}
return types.ActionContinue
}
}
if globalContext.jsonRpcResponseFilter != nil || globalContext.toolListResponseFilter != nil || globalContext.toolCallResponseFilter != nil {
config.jsonRpcResponseHandler = func(context wrapper.HttpContext, id utils.JsonRpcID, result, error gjson.Result, rawBody []byte) types.Action {
if globalContext.jsonRpcResponseFilter != nil {
ret := globalContext.jsonRpcResponseFilter(context, bizConfig, id, result, error, rawBody)
if ret != types.ActionContinue {
return ret
}
}
method := context.GetStringContext("JSONRPC_METHOD", "")
if method == "tools/list" && globalContext.toolListResponseFilter != nil {
return globalContext.toolListResponseFilter(context, bizConfig, result.Get("tools"), rawBody)
}
if method == "tools/call" && globalContext.toolCallResponseFilter != nil {
return globalContext.toolCallResponseFilter(context, bizConfig, result.Get("isError").Bool(), result.Get("content"), rawBody)
}
return types.ActionContinue
}
}
log.Debugf("installHandler called, config is: %#v", config)
}
func parseRawConfig(configBytes []byte, config *mcpFilterConfig) error {
err := globalContext.parseFilterConfig(configBytes, &config.config)
if err != nil {
return err
}
installHandler(config)
return nil
}
func parseGlobalConfig(configBytes []byte, config *mcpFilterConfig) error {
err := globalContext.parseFilterConfig(configBytes, &config.config)
if err != nil {
return err
}
return nil
}
func parseOverrideConfig(configBytes []byte, global mcpFilterConfig, config *mcpFilterConfig) error {
err := globalContext.parseFilterRuleOverrideConfig(configBytes, global.config, &config.config)
if err != nil {
return err
}
installHandler(config)
return nil
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config mcpFilterConfig) types.Action {
log.Debugf("onHttpRequestHeaders called")
if !ctx.HasRequestBody() || (config.httpRequestHandler == nil && config.jsonRpcRequestHandler == nil) {
log.Debugf("no request body or no handler, skip reading body")
ctx.DontReadRequestBody()
return types.ActionContinue
}
log.Debugf("has request body and handler, read body")
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
return types.HeaderStopIteration
}
func onHttpRequestBody(ctx wrapper.HttpContext, config mcpFilterConfig, body []byte) types.Action {
log.Debugf("onHttpRequestBody called, body size: %d", len(body))
if !gjson.GetBytes(body, "jsonrpc").Exists() {
if config.httpRequestHandler != nil {
log.Debugf("body is not jsonrpc, using httpRequestHandler")
headers, err := proxywasm.GetHttpRequestHeaders()
if err != nil {
log.Errorf("get request headers failed, err:%v", err)
return types.ActionContinue
}
return config.httpRequestHandler(ctx, config.config, headers, body)
}
log.Debugf("body is not jsonrpc, but no httpRequestHandler, skip")
return types.ActionContinue
}
log.Debugf("body is jsonrpc, using HandleJsonRpcRequest")
return utils.HandleJsonRpcRequest(ctx, body, config.jsonRpcRequestHandler)
}
func onHttpResponseHeaders(ctx wrapper.HttpContext, config mcpFilterConfig) types.Action {
log.Debugf("onHttpResponseHeaders called")
// IsApplicationJson checks if the content type is application/json, so we can skip reading the body if it's application/octet-stream
if !ctx.HasResponseBody() || !wrapper.IsApplicationJson() || (config.httpResponseHandler == nil && config.jsonRpcResponseHandler == nil) {
log.Debugf("no response body or no handler, skip reading body")
ctx.DontReadResponseBody()
return types.ActionContinue
}
log.Debugf("has response body and handler, read body")
ctx.SetResponseBodyBufferLimit(defaultMaxBodyBytes)
return types.HeaderStopIteration
}
func onHttpResponseBody(ctx wrapper.HttpContext, config mcpFilterConfig, body []byte) types.Action {
log.Debugf("onHttpResponseBody called, body size: %d", len(body))
if !gjson.GetBytes(body, "jsonrpc").Exists() {
if config.httpResponseHandler != nil {
log.Debugf("body is not jsonrpc, using httpResponseHandler")
headers, err := proxywasm.GetHttpResponseHeaders()
if err != nil {
log.Errorf("get response headers failed, err:%v", err)
return types.ActionContinue
}
return config.httpResponseHandler(ctx, config.config, headers, body)
}
log.Debugf("body is not jsonrpc, but no httpResponseHandler, skip")
return types.ActionContinue
}
log.Debugf("body is jsonrpc, using HandleJsonRpcResponse")
return utils.HandleJsonRpcResponse(ctx, body, config.jsonRpcResponseHandler)
}

View File

@@ -0,0 +1,39 @@
module github.com/alibaba/higress/plugins/wasm-go/pkg/mcp
go 1.24.1
require (
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9
github.com/invopop/jsonschema v0.13.0
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
)
require (
dario.cat/mergo v1.0.1 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver/v3 v3.3.0 // indirect
github.com/Masterminds/sprig/v3 v3.3.0 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/huandu/xstrings v1.5.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mitchellh/copystructure v1.2.0 // indirect
github.com/mitchellh/reflectwalk v1.0.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
github.com/spf13/cast v1.7.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
golang.org/x/crypto v0.26.0 // indirect
google.golang.org/protobuf v1.36.6 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -0,0 +1,84 @@
package mcp
import (
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/filter"
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
)
var _ server.Server = &MCPServer{}
// MCPServer implements the Server interface using BaseMCPServer
type MCPServer struct {
base server.BaseMCPServer
}
// NewMCPServer creates a new MCPServer
func NewMCPServer() *MCPServer {
return &MCPServer{
base: server.NewBaseMCPServer(),
}
}
// Clone implements Server interface
func (s *MCPServer) Clone() server.Server {
return &MCPServer{
base: s.base.CloneBase(),
}
}
// AddMCPTool implements Server interface
func (s *MCPServer) AddMCPTool(name string, tool server.Tool) server.Server {
s.base.AddMCPTool(name, tool)
return s
}
// GetConfig implements Server interface
func (s *MCPServer) GetConfig(v any) {
s.base.GetConfig(v)
}
// GetMCPTools implements Server interface
func (s *MCPServer) GetMCPTools() map[string]server.Tool {
return s.base.GetMCPTools()
}
// SetConfig implements Server interface
func (s *MCPServer) SetConfig(config []byte) {
s.base.SetConfig(config)
}
// mcp server function
var (
LoadMCPServer = server.Load
InitMCPServer = server.Initialize
AddMCPServer = server.AddMCPServer
)
// mcp filter function
var (
LoadMCPFilter = filter.Load
InitMCPFilter = filter.Initialize
SetConfigParser = filter.SetConfigParser
SetConfigOverrideParser = filter.SetConfigOverrideParser
FilterName = filter.FilterName
SetJsonRpcRequestFilter = filter.SetJsonRpcRequestFilter
SetJsonRpcResponseFilter = filter.SetJsonRpcResponseFilter
SetFallbackHTTPRequestFilter = filter.SetFallbackHTTPRequestFilter
SetFallbackHTTPResponseFilter = filter.SetFallbackHTTPResponseFilter
SetToolCallRequestFilter = filter.SetToolCallRequestFilter
SetToolCallResponseFilter = filter.SetToolCallResponseFilter
SetToolListResponseFilter = filter.SetToolListResponseFilter
)

View File

@@ -0,0 +1,232 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"encoding/base64"
"errors"
"fmt"
"net/url"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/log"
)
// setOrReplaceHeader sets or replaces a header in the headers slice.
// If the header exists (case-insensitive comparison), it replaces the value.
// If the header doesn't exist, it appends a new header.
func setOrReplaceHeader(headers *[][2]string, key, value string) {
lowerKey := strings.ToLower(key)
// Check if header already exists
for i, header := range *headers {
if strings.ToLower(header[0]) == lowerKey {
// Replace existing header value
(*headers)[i][1] = value
return
}
}
// Header doesn't exist, append new one
*headers = append(*headers, [2]string{key, value})
}
// SecurityScheme defines a security scheme for the REST API
type SecurityScheme struct {
ID string `json:"id"`
Type string `json:"type"` // http, apiKey
Scheme string `json:"scheme,omitempty"` // basic, bearer (for type: http)
In string `json:"in,omitempty"` // header, query (for type: apiKey)
Name string `json:"name,omitempty"` // Header or query parameter name (for type: apiKey)
DefaultCredential string `json:"defaultCredential,omitempty"`
}
// SecurityRequirement specifies a security scheme requirement for a tool
type SecurityRequirement struct {
ID string `json:"id"` // References a security scheme ID
Credential string `json:"credential,omitempty"` // Overrides default credential
Passthrough bool `json:"passthrough,omitempty"` // If true, credentials from client request will be passed through
}
// AuthRequestContext holds the data needed for applying security schemes.
type AuthRequestContext struct {
Method string
Headers [][2]string // Direct slice, modifications within applySecurity will update this field in the struct instance
ParsedURL *url.URL // Pointer to allow modification (e.g., RawQuery)
RequestBody []byte // For future security types that might inspect the body
PassthroughCredential string // Credential extracted from client request for passthrough
}
// SecuritySchemeProvider provides access to security schemes
type SecuritySchemeProvider interface {
GetSecurityScheme(id string) (SecurityScheme, bool)
}
// ExtractAndRemoveIncomingCredential extracts a credential from the current incoming HTTP request
// and removes it. It uses global proxywasm functions to access request details.
// For query parameters, "removal" is conceptual as we build a new request;
// this function primarily extracts the value for potential passthrough.
func ExtractAndRemoveIncomingCredential(scheme SecurityScheme) (string, error) {
credentialValue := ""
var err error
switch scheme.Type {
case "http":
authHeader, _ := proxywasm.GetHttpRequestHeader("Authorization") // Error ignored, check content
if authHeader == "" {
// If no header, it's not an error for extraction if not required, but indicates not found.
// For removal, there's nothing to remove.
return "", nil // Or a specific "not found" error if scheme implies it must be there.
}
if scheme.Scheme == "bearer" {
if !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
return "", fmt.Errorf("incoming Authorization header is not Bearer auth: %s", authHeader)
}
credentialValue = strings.TrimSpace(authHeader[len("Bearer "):])
} else if scheme.Scheme == "basic" {
if !strings.HasPrefix(strings.ToLower(authHeader), "basic ") {
return "", fmt.Errorf("incoming Authorization header is not Basic auth: %s", authHeader)
}
credentialValue = strings.TrimSpace(authHeader[len("Basic "):])
} else {
return "", fmt.Errorf("unsupported http scheme for credential extraction/removal: %s", scheme.Scheme)
}
proxywasm.RemoveHttpRequestHeader("Authorization")
log.Debugf("Extracted and removed Authorization header for incoming %s scheme.", scheme.Scheme)
case "apiKey":
if scheme.In == "header" {
if scheme.Name == "" {
return "", errors.New("apiKey in header requires a name for the header")
}
headerValue, _ := proxywasm.GetHttpRequestHeader(scheme.Name) // Error ignored, check content
if headerValue == "" {
return "", nil // Not found, not necessarily an error for extraction.
}
credentialValue = headerValue
proxywasm.RemoveHttpRequestHeader(scheme.Name)
log.Debugf("Extracted and removed %s header for incoming apiKey auth.", scheme.Name)
} else if scheme.In == "query" {
if scheme.Name == "" {
return "", errors.New("apiKey in query requires a name for the query parameter")
}
pathHeader, _ := proxywasm.GetHttpRequestHeader(":path") // Error ignored, check content
if pathHeader == "" {
// This case might be an error as :path should generally exist.
return "", fmt.Errorf("no :path header found in incoming request for apiKey in query")
}
requestURL, parseErr := url.Parse(pathHeader)
if parseErr != nil {
return "", fmt.Errorf("failed to parse incoming :path header '%s': %v", pathHeader, parseErr)
}
queryValues := requestURL.Query()
apiKeyValue := queryValues.Get(scheme.Name)
if apiKeyValue == "" {
return "", nil // Not found
}
credentialValue = apiKeyValue
log.Debugf("Extracted %s query parameter from incoming request. Removal from original :path is implicit.", scheme.Name)
} else {
return "", fmt.Errorf("unsupported apiKey 'in' value: %s", scheme.In)
}
default:
return "", fmt.Errorf("unsupported security scheme type for credential extraction/removal: %s", scheme.Type)
}
return credentialValue, err
}
// ApplySecurity applies the configured security scheme to the request.
// It modifies reqCtx.Headers and reqCtx.ParsedURL (specifically RawQuery) in place if necessary.
func ApplySecurity(securityConfig SecurityRequirement, provider SecuritySchemeProvider, reqCtx *AuthRequestContext) error {
if securityConfig.ID == "" {
return nil // No security scheme defined
}
if reqCtx.ParsedURL == nil {
return errors.New("ParsedURL in AuthRequestContext cannot be nil for ApplySecurity")
}
upstreamScheme, schemeOk := provider.GetSecurityScheme(securityConfig.ID)
if !schemeOk {
return fmt.Errorf("upstream security scheme with id '%s' not found", securityConfig.ID)
}
var credentialToUse string
if reqCtx.PassthroughCredential != "" {
// Use the passthrough credential value.
// The upstreamScheme dictates how this value is formatted and applied.
credentialToUse = reqCtx.PassthroughCredential
log.Debugf("Using passthrough credential for upstream request with scheme %s.", upstreamScheme.ID)
} else {
// Use configured credential for the upstream request.
credentialToUse = upstreamScheme.DefaultCredential
if securityConfig.Credential != "" {
credentialToUse = securityConfig.Credential
}
if credentialToUse == "" {
return fmt.Errorf("no credential found or configured for upstream security scheme '%s'", upstreamScheme.ID)
}
log.Debugf("Using configured credential for upstream request with scheme %s.", upstreamScheme.ID)
}
switch upstreamScheme.Type {
case "http":
authValue := credentialToUse
if upstreamScheme.Scheme == "basic" {
if !strings.HasPrefix(authValue, "Basic ") {
if reqCtx.PassthroughCredential != "" { // Came from passthrough, it's the base64 token part
authValue = "Basic " + credentialToUse
} else { // Came from config
if strings.Contains(credentialToUse, ":") { // Assumed to be "user:pass"
authValue = "Basic " + base64.StdEncoding.EncodeToString([]byte(credentialToUse))
} else { // Assumed to be already base64 encoded string (token part)
authValue = "Basic " + credentialToUse
}
}
}
} else if upstreamScheme.Scheme == "bearer" {
// Passthrough for Bearer gives the token part. Configured credential is the token.
if !strings.HasPrefix(authValue, "Bearer ") {
authValue = "Bearer " + credentialToUse
}
} else {
return fmt.Errorf("unsupported http scheme type for upstream: %s", upstreamScheme.Scheme)
}
setOrReplaceHeader(&reqCtx.Headers, "Authorization", authValue)
case "apiKey":
if upstreamScheme.In == "header" {
if upstreamScheme.Name == "" {
return errors.New("apiKey in header requires a name for the header for upstream")
}
setOrReplaceHeader(&reqCtx.Headers, upstreamScheme.Name, credentialToUse)
} else if upstreamScheme.In == "query" {
if upstreamScheme.Name == "" {
return errors.New("apiKey in query requires a name for the query parameter for upstream")
}
queryValues := reqCtx.ParsedURL.Query()
queryValues.Set(upstreamScheme.Name, credentialToUse)
reqCtx.ParsedURL.RawQuery = queryValues.Encode()
} else {
return fmt.Errorf("unsupported apiKey 'in' value for upstream: %s", upstreamScheme.In)
}
default:
return fmt.Errorf("unsupported security scheme type: %s", upstreamScheme.Type)
}
return nil
}

View File

@@ -0,0 +1,100 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"encoding/base64"
"encoding/json"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/log"
)
// BaseMCPServer provides common functionality for MCP servers
type BaseMCPServer struct {
tools map[string]Tool
config []byte
}
// NewBaseMCPServer creates a new BaseMCPServer
func NewBaseMCPServer() BaseMCPServer {
return BaseMCPServer{
tools: make(map[string]Tool),
}
}
// AddMCPTool adds a tool to the server
func (s *BaseMCPServer) AddMCPTool(name string, tool Tool) Server {
if _, exist := s.tools[name]; exist {
log.Errorf("Conflict! There is a tool with the same name:%s", name)
return s
}
s.tools[name] = tool
return s
}
// GetMCPTools returns all tools registered with the server
func (s *BaseMCPServer) GetMCPTools() map[string]Tool {
return s.tools
}
// SetConfig sets the server configuration
func (s *BaseMCPServer) SetConfig(config []byte) {
s.config = config
}
// GetConfig gets the server configuration
// It first tries to get the config from the request header, then falls back to the stored config
func (s *BaseMCPServer) GetConfig(v any) {
var config []byte
serverConfigBase64, _ := proxywasm.GetHttpRequestHeader("x-higress-mcpserver-config")
proxywasm.RemoveHttpRequestHeader("x-higress-mcpserver-config")
if serverConfigBase64 != "" {
serverConfig, err := base64.StdEncoding.DecodeString(serverConfigBase64)
if err != nil {
log.Errorf("base64 decode mcp server config failed:%s, bytes:%s", err, serverConfigBase64)
} else {
config = serverConfig
}
log.Infof("parse server config from request, config:%s", serverConfig)
} else {
config = s.config
}
if len(config) == 0 {
return
}
err := json.Unmarshal(config, v)
if err != nil {
log.Errorf("json unmarshal server config failed:%v, config:%s", err, config)
}
}
// Clone creates a copy of the server
// This method should be overridden by derived types
func (s *BaseMCPServer) Clone() Server {
panic("Clone method must be implemented by derived types")
}
// CloneBase creates a copy of the base server
func (s *BaseMCPServer) CloneBase() BaseMCPServer {
newServer := BaseMCPServer{
tools: make(map[string]Tool),
config: s.config,
}
for k, v := range s.tools {
newServer.tools[k] = v
}
return newServer
}

View File

@@ -0,0 +1,127 @@
package server
import (
"fmt"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/consts"
)
// ComposedMCPServer represents a server composed of tools from other servers.
type ComposedMCPServer struct {
name string // Name of the composed server (from toolSet.name)
serverTools []ServerToolConfig // Configuration of which tools to include
registry *GlobalToolRegistry // Reference to the global tool registry
config []byte // Configuration for the composed server itself (if any)
}
// NewComposedMCPServer creates a new ComposedMCPServer.
func NewComposedMCPServer(name string, serverToolsConfig []ServerToolConfig, registry *GlobalToolRegistry) *ComposedMCPServer {
return &ComposedMCPServer{
name: name,
serverTools: serverToolsConfig,
registry: registry,
}
}
// GetName returns the name of the composed server.
func (cs *ComposedMCPServer) GetName() string {
return cs.name
}
// AddMCPTool for ComposedMCPServer is a no-op as tools are defined by toolSet.
func (cs *ComposedMCPServer) AddMCPTool(name string, tool Tool) Server {
log.Warnf("AddMCPTool called on ComposedMCPServer '%s'; this is a no-op.", cs.name)
return cs
}
// GetMCPTools constructs and returns the map of tools exposed by this composed server.
// The tool names are prefixed with their original server name, e.g., "${originalServer}___${toolName}".
// The Tool instances are DescriptiveTool, only providing Description and InputSchema.
func (cs *ComposedMCPServer) GetMCPTools() map[string]Tool {
composedTools := make(map[string]Tool)
for _, stc := range cs.serverTools {
originalServerName := stc.ServerName
for _, originalToolName := range stc.Tools {
toolInfo, found := cs.registry.GetToolInfo(originalServerName, originalToolName)
if !found {
log.Warnf("Tool %s/%s not found in global registry for composed server %s", originalServerName, originalToolName, cs.name)
continue
}
composedToolName := fmt.Sprintf("%s%s%s", originalServerName, consts.ToolSetNameSplitter, originalToolName)
composedTools[composedToolName] = &DescriptiveTool{
description: toolInfo.Description,
inputSchema: toolInfo.InputSchema,
outputSchema: toolInfo.OutputSchema, // New field for MCP Protocol Version 2025-06-18
}
}
}
return composedTools
}
// SetConfig sets the configuration for the composed server itself.
func (cs *ComposedMCPServer) SetConfig(config []byte) {
cs.config = config
}
// GetConfig retrieves the configuration of the composed server itself.
func (cs *ComposedMCPServer) GetConfig(v any) {
if len(cs.config) == 0 {
return
}
if ptrBytes, ok := v.(*[]byte); ok {
*ptrBytes = cs.config
} else {
// If you need to unmarshal to a struct, you'd do it here.
// For now, keeping it simple as per previous discussions.
log.Warnf("ComposedMCPServer.GetConfig called with unhandled type for v. Config not set.")
}
}
// Clone creates a new instance of the ComposedMCPServer with the same configuration.
func (cs *ComposedMCPServer) Clone() Server {
cloned := NewComposedMCPServer(cs.name, cs.serverTools, cs.registry)
cloned.SetConfig(cs.config)
return cloned
}
// DescriptiveTool is a placeholder Tool implementation for ComposedMCPServer.
// Its Call and Create methods should never be invoked.
type DescriptiveTool struct {
description string
inputSchema map[string]any
outputSchema map[string]any // New field for MCP Protocol Version 2025-06-18
}
// Create for DescriptiveTool should not be called.
func (dt *DescriptiveTool) Create(params []byte) Tool {
log.Errorf("DescriptiveTool.Create called for tool used in ComposedMCPServer. This should not happen.")
// Return a new instance to fulfill the interface, though it's an error state.
return &DescriptiveTool{
description: dt.description,
inputSchema: dt.inputSchema,
outputSchema: dt.outputSchema,
}
}
// Call for DescriptiveTool should not be called.
func (dt *DescriptiveTool) Call(httpCtx HttpContext, server Server) error {
log.Errorf("DescriptiveTool.Call called for tool used in ComposedMCPServer. This should not happen.")
return fmt.Errorf("DescriptiveTool.Call should not be invoked on a ComposedMCPServer's tool")
}
// Description returns the tool's description.
func (dt *DescriptiveTool) Description() string {
return dt.description
}
// InputSchema returns the tool's input schema.
func (dt *DescriptiveTool) InputSchema() map[string]any {
return dt.inputSchema
}
// OutputSchema returns the tool's output schema (MCP Protocol Version 2025-06-18).
func (dt *DescriptiveTool) OutputSchema() map[string]any {
return dt.outputSchema
}

View File

@@ -0,0 +1,328 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"fmt"
"os"
"testing"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
)
// testLogger is a mock logger for testing to prevent panics
type testLogger struct{}
func (l *testLogger) Trace(msg string) { fmt.Fprintf(os.Stderr, "[TRACE] %s\n", msg) }
func (l *testLogger) Tracef(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "[TRACE] "+format+"\n", args...)
}
func (l *testLogger) Debug(msg string) { fmt.Fprintf(os.Stderr, "[DEBUG] %s\n", msg) }
func (l *testLogger) Debugf(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "[DEBUG] "+format+"\n", args...)
}
func (l *testLogger) Info(msg string) { fmt.Fprintf(os.Stderr, "[INFO] %s\n", msg) }
func (l *testLogger) Infof(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "[INFO] "+format+"\n", args...)
}
func (l *testLogger) Warn(msg string) { fmt.Fprintf(os.Stderr, "[WARN] %s\n", msg) }
func (l *testLogger) Warnf(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "[WARN] "+format+"\n", args...)
}
func (l *testLogger) Error(msg string) { fmt.Fprintf(os.Stderr, "[ERROR] %s\n", msg) }
func (l *testLogger) Errorf(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "[ERROR] "+format+"\n", args...)
}
func (l *testLogger) Critical(msg string) { fmt.Fprintf(os.Stderr, "[CRITICAL] %s\n", msg) }
func (l *testLogger) Criticalf(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "[CRITICAL] "+format+"\n", args...)
}
func (l *testLogger) ResetID(pluginID string) {}
func init() {
// Set a custom logger for testing to prevent panics
log.SetPluginLog(&testLogger{})
}
// TestMcpProxyConfigValidation tests configuration validation for mcp-proxy servers
func TestMcpProxyConfigValidation(t *testing.T) {
tests := []struct {
name string
config string
shouldErr bool
errMsg string
}{
{
name: "valid basic proxy config",
config: `{
"server": {
"name": "test-proxy",
"type": "mcp-proxy",
"transport": "http",
"mcpServerURL": "http://backend.example.com/mcp",
"timeout": 5000
},
"tools": [
{
"name": "test-tool",
"description": "Test tool",
"args": [
{
"name": "input",
"description": "Input parameter",
"type": "string",
"required": true
}
]
}
]
}`,
shouldErr: false,
},
{
name: "proxy config with security schemes",
config: `{
"server": {
"name": "secure-proxy",
"type": "mcp-proxy",
"transport": "http",
"mcpServerURL": "https://secure.example.com/mcp",
"timeout": 8000,
"securitySchemes": [
{
"id": "ApiKeyAuth",
"type": "apiKey",
"in": "header",
"name": "X-API-Key",
"defaultCredential": "test-key"
}
]
},
"tools": [
{
"name": "secure-tool",
"description": "Secure tool",
"args": [
{
"name": "data",
"description": "Data parameter",
"type": "object",
"required": true
}
],
"requestTemplate": {
"security": {
"id": "ApiKeyAuth"
}
}
}
]
}`,
shouldErr: false,
},
{
name: "missing mcpServerURL should fail",
config: `{
"server": {
"name": "invalid-proxy",
"type": "mcp-proxy",
"transport": "http",
"timeout": 5000
},
"tools": [
{
"name": "test-tool",
"description": "Test tool",
"args": []
}
]
}`,
shouldErr: true,
errMsg: "mcpServerURL is required",
},
{
name: "invalid server type should use default REST handling",
config: `{
"server": {
"name": "rest-server",
"type": "rest-api"
},
"tools": [
{
"name": "rest-tool",
"description": "REST tool",
"args": [],
"requestTemplate": {
"url": "http://example.com/api",
"method": "GET"
},
"responseTemplate": {
"body": "$.result"
}
}
]
}`,
shouldErr: false, // Should fall back to REST server logic
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
configJson := gjson.Parse(tt.config)
config := &McpServerConfig{}
// Create validation options (similar to validator package)
toolRegistry := &GlobalToolRegistry{}
toolRegistry.Initialize()
opts := &ConfigOptions{
Servers: make(map[string]Server),
ToolRegistry: toolRegistry,
SkipPreRegisteredServers: true,
}
err := ParseConfigCore(configJson, config, opts)
if tt.shouldErr {
assert.Error(t, err)
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg)
}
} else {
assert.NoError(t, err)
assert.NotNil(t, config)
}
})
}
}
// TestSecuritySchemeValidation tests security scheme configuration validation
func TestSecuritySchemeValidation(t *testing.T) {
tests := []struct {
name string
scheme SecurityScheme
shouldErr bool
}{
{
name: "valid API key scheme",
scheme: SecurityScheme{
ID: "ApiKeyAuth",
Type: "apiKey",
In: "header",
Name: "X-API-Key",
},
shouldErr: false,
},
{
name: "valid HTTP bearer scheme",
scheme: SecurityScheme{
ID: "BearerAuth",
Type: "http",
Scheme: "bearer",
},
shouldErr: false,
},
{
name: "invalid scheme - missing ID",
scheme: SecurityScheme{
Type: "apiKey",
In: "header",
Name: "X-API-Key",
},
shouldErr: true,
},
{
name: "invalid scheme - missing Name for apiKey",
scheme: SecurityScheme{
ID: "ApiKeyAuth",
Type: "apiKey",
In: "header",
},
shouldErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// This will test the validation logic once SecurityScheme validation is implemented
err := ValidateSecurityScheme(tt.scheme)
if tt.shouldErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// TestToolConfigValidation tests tool configuration validation
func TestToolConfigValidation(t *testing.T) {
tests := []struct {
name string
toolCfg McpProxyToolConfig
shouldErr bool
}{
{
name: "valid tool config",
toolCfg: McpProxyToolConfig{
Name: "valid-tool",
Description: "A valid tool",
Args: []ToolArg{
{
Name: "param1",
Description: "Parameter 1",
Type: "string",
Required: true,
},
},
},
shouldErr: false,
},
{
name: "invalid tool - missing name",
toolCfg: McpProxyToolConfig{
Description: "Tool without name",
Args: []ToolArg{},
},
shouldErr: true,
},
{
name: "invalid tool - empty description",
toolCfg: McpProxyToolConfig{
Name: "tool-no-desc",
Description: "",
Args: []ToolArg{},
},
shouldErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateToolConfig(tt.toolCfg)
if tt.shouldErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// These validation functions are now implemented in proxy_server.go

View File

@@ -0,0 +1,817 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"encoding/json"
"errors"
"fmt"
"net/url"
"reflect"
"slices"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/invopop/jsonschema"
"github.com/tidwall/gjson"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
const (
DefaultMaxBodyBytes uint32 = 100 * 1024 * 1024
GlobalToolRegistryKey = "GlobalToolRegistry"
)
// SupportedMCPVersions contains all supported MCP protocol versions
var SupportedMCPVersions = []string{"2024-11-05", "2025-03-26", "2025-06-18"}
// validateURL validates that the given string is a valid URL
func validateURL(urlStr string) error {
if urlStr == "" {
return errors.New("url cannot be empty")
}
parsedURL, err := url.Parse(urlStr)
if err != nil {
return fmt.Errorf("invalid URL format: %v", err)
}
// Allow both full URLs (with scheme and host) and path-only URLs
// Path-only URLs will be resolved against the cluster's base URL
if parsedURL.Scheme != "" {
// If scheme is provided, host must also be provided
if parsedURL.Host == "" {
return errors.New("url with scheme must include a host")
}
// Only allow http and https schemes for security
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return fmt.Errorf("unsupported URL scheme '%s', only http and https are allowed", parsedURL.Scheme)
}
}
return nil
}
// setupMcpProxyServer creates and configures an MCP proxy server
func setupMcpProxyServer(serverName string, serverJson gjson.Result, serverConfigJsonForInstance string) (*McpProxyServer, error) {
proxyServer := NewMcpProxyServer(serverName)
proxyServer.SetConfig([]byte(serverConfigJsonForInstance))
// Parse and validate transport (required for mcp-proxy)
transportStr := serverJson.Get("transport").String()
if transportStr == "" {
return nil, errors.New("transport field is required for mcp-proxy server type")
}
transport := TransportProtocol(transportStr)
if transport != TransportHTTP && transport != TransportSSE {
return nil, fmt.Errorf("invalid transport value: %s, must be 'http' or 'sse'", transportStr)
}
proxyServer.SetTransport(transport)
// Parse and validate mcpServerURL (required for mcp-proxy)
mcpServerURL := serverJson.Get("mcpServerURL").String()
if mcpServerURL == "" {
return nil, errors.New("mcpServerURL is required for mcp-proxy server type")
}
if err := validateURL(mcpServerURL); err != nil {
return nil, fmt.Errorf("invalid mcpServerURL: %v", err)
}
proxyServer.SetMcpServerURL(mcpServerURL)
// Parse timeout (optional)
timeout := serverJson.Get("timeout").Int()
if timeout > 0 {
proxyServer.SetTimeout(int(timeout))
}
// Parse passthroughAuthHeader (optional, defaults to false)
passthroughAuthHeader := serverJson.Get("passthroughAuthHeader").Bool()
proxyServer.SetPassthroughAuthHeader(passthroughAuthHeader)
// Parse security schemes
securitySchemesJson := serverJson.Get("securitySchemes")
if securitySchemesJson.Exists() {
for _, schemeJson := range securitySchemesJson.Array() {
var scheme SecurityScheme
if err := json.Unmarshal([]byte(schemeJson.Raw), &scheme); err != nil {
return nil, fmt.Errorf("failed to parse security scheme config: %v", err)
}
proxyServer.AddSecurityScheme(scheme)
}
}
// Parse default downstream security
defaultDownstreamSecurityJson := serverJson.Get("defaultDownstreamSecurity")
if defaultDownstreamSecurityJson.Exists() {
var defaultDownstreamSecurity SecurityRequirement
if err := json.Unmarshal([]byte(defaultDownstreamSecurityJson.Raw), &defaultDownstreamSecurity); err != nil {
return nil, fmt.Errorf("failed to parse defaultDownstreamSecurity config: %v", err)
}
proxyServer.SetDefaultDownstreamSecurity(defaultDownstreamSecurity)
}
// Parse default upstream security
defaultUpstreamSecurityJson := serverJson.Get("defaultUpstreamSecurity")
if defaultUpstreamSecurityJson.Exists() {
var defaultUpstreamSecurity SecurityRequirement
if err := json.Unmarshal([]byte(defaultUpstreamSecurityJson.Raw), &defaultUpstreamSecurity); err != nil {
return nil, fmt.Errorf("failed to parse defaultUpstreamSecurity config: %v", err)
}
proxyServer.SetDefaultUpstreamSecurity(defaultUpstreamSecurity)
}
return proxyServer, nil
}
type HttpContext wrapper.HttpContext
type Context struct {
servers map[string]Server
}
type CtxOption interface {
Apply(*Context)
}
var globalContext Context
// ToolInfo stores information about a tool for the global registry.
type ToolInfo struct {
Name string
Description string
InputSchema map[string]any
OutputSchema map[string]any // New field for MCP Protocol Version 2025-06-18
ServerName string // Original server name
Tool Tool // The actual tool instance for cloning
}
// GlobalToolRegistry holds all tools from all servers.
type GlobalToolRegistry struct {
// serverName -> toolName -> toolInfo
serverTools map[string]map[string]ToolInfo
}
// Initialize initializes the GlobalToolRegistry
func (r *GlobalToolRegistry) Initialize() {
r.serverTools = make(map[string]map[string]ToolInfo)
}
// RegisterTool registers a tool into the global registry.
func (r *GlobalToolRegistry) RegisterTool(serverName string, toolName string, tool Tool) {
if _, ok := r.serverTools[serverName]; !ok {
r.serverTools[serverName] = make(map[string]ToolInfo)
}
toolInfo := ToolInfo{
Name: toolName,
Description: tool.Description(),
InputSchema: tool.InputSchema(),
ServerName: serverName,
Tool: tool,
}
// Check if tool implements OutputSchema (MCP Protocol Version 2025-06-18)
if toolWithSchema, ok := tool.(ToolWithOutputSchema); ok {
toolInfo.OutputSchema = toolWithSchema.OutputSchema()
}
r.serverTools[serverName][toolName] = toolInfo
log.Debugf("Registered tool %s/%s", serverName, toolName)
}
// GetToolInfo retrieves tool information from the global registry.
func (r *GlobalToolRegistry) GetToolInfo(serverName string, toolName string) (ToolInfo, bool) {
if serverTools, ok := r.serverTools[serverName]; ok {
toolInfo, found := serverTools[toolName]
return toolInfo, found
}
return ToolInfo{}, false
}
func onPluginStartOrReload(context wrapper.PluginContext) error {
toolRegistry := &GlobalToolRegistry{}
toolRegistry.Initialize()
context.SetContext(GlobalToolRegistryKey, toolRegistry)
context.EnableRuleLevelConfigIsolation()
return nil
}
// GetServer retrieves a server instance from the global context.
// This is needed by ComposedMCPServer to get original server instances.
func GetServerFromGlobalContext(serverName string) (Server, bool) {
server, exist := globalContext.servers[serverName]
return server, exist
}
type Server interface {
AddMCPTool(name string, tool Tool) Server
GetMCPTools() map[string]Tool // For single server, returns its tools. For composed, returns composed tools.
SetConfig(config []byte)
GetConfig(v any)
Clone() Server
// GetName() string // Returns the server name - REMOVED
}
type Tool interface {
Create(params []byte) Tool
Call(httpCtx HttpContext, server Server) error
Description() string
InputSchema() map[string]any
}
// ToolWithOutputSchema is an optional interface for tools that support output schema
// (MCP Protocol Version 2025-06-18). Tools can optionally implement this interface
// to provide output schema information.
type ToolWithOutputSchema interface {
Tool
OutputSchema() map[string]any
}
// ToolSetConfig defines the configuration for a toolset.
type ToolSetConfig struct {
Name string `json:"name"`
ServerTools []ServerToolConfig `json:"serverTools"`
}
// ServerToolConfig specifies which tools from a server to include in a toolset.
type ServerToolConfig struct {
ServerName string `json:"serverName"`
Tools []string `json:"tools"`
}
// ConfigOptions contains the dependencies needed for config parsing
type ConfigOptions struct {
Servers map[string]Server
ToolRegistry *GlobalToolRegistry
// Skip validation for pre-registered Go-based servers
SkipPreRegisteredServers bool
}
type McpServerConfig struct {
serverName string // Store the server name directly
server Server // Can be a single server or a composed server
methodHandlers utils.MethodHandlers
toolSet *ToolSetConfig // Parsed toolset configuration
isComposed bool
}
// GetServerName returns the server name for external access
func (c *McpServerConfig) GetServerName() string {
return c.serverName
}
// GetIsComposed returns whether this is a composed server for external access
func (c *McpServerConfig) GetIsComposed() bool {
return c.isComposed
}
// computeEffectiveAllowTools computes the effective allowTools by taking the intersection
// of config allowTools and request header allowTools.
// Returns nil if no restrictions (allow all), otherwise returns a pointer to the effective set.
func computeEffectiveAllowTools(configAllowTools *map[string]struct{}) *map[string]struct{} {
// Get allowTools from request header
allowToolsHeaderStr, _ := proxywasm.GetHttpRequestHeader("x-envoy-allow-mcp-tools")
proxywasm.RemoveHttpRequestHeader("x-envoy-allow-mcp-tools")
// Only consider header as "present" if it has non-empty value
// Empty string means header is not set or explicitly empty, both treated as "no restriction"
headerExists := allowToolsHeaderStr != ""
return computeEffectiveAllowToolsFromHeader(configAllowTools, allowToolsHeaderStr, headerExists)
}
// computeEffectiveAllowToolsFromHeader computes the effective allowTools by taking the intersection
// of config allowTools and header allowTools string.
// This is useful when the header string is already extracted (e.g., in async callbacks).
// Returns nil if no restrictions (allow all), otherwise returns a pointer to the effective set.
func computeEffectiveAllowToolsFromHeader(configAllowTools *map[string]struct{}, allowToolsHeaderStr string, headerExists bool) *map[string]struct{} {
var allowToolsFromHeader *map[string]struct{}
if headerExists {
// Header is present (even if empty string), parse it
headerMap := make(map[string]struct{})
for tool := range strings.SplitSeq(allowToolsHeaderStr, ",") {
trimmedTool := strings.TrimSpace(tool)
if trimmedTool == "" {
continue
}
headerMap[trimmedTool] = struct{}{}
}
// Always create pointer even if map is empty, to distinguish from "not configured"
allowToolsFromHeader = &headerMap
}
// Compute effective allowTools (intersection of config and header)
if configAllowTools == nil && allowToolsFromHeader == nil {
// Both not configured, allow all tools
return nil
} else if configAllowTools == nil {
// Only header restrictions
return allowToolsFromHeader
} else if allowToolsFromHeader == nil {
// Only config restrictions
return configAllowTools
} else {
// Both restrictions exist, compute intersection
intersection := make(map[string]struct{})
for tool := range *configAllowTools {
if _, exists := (*allowToolsFromHeader)[tool]; exists {
intersection[tool] = struct{}{}
}
}
return &intersection
}
}
// parseConfigCore contains the core config parsing logic with dependency injection
func parseConfigCore(configJson gjson.Result, config *McpServerConfig, opts *ConfigOptions) error {
toolSetJson := configJson.Get("toolSet")
serverJson := configJson.Get("server") // This is for single server or REST server definition
pluginServerConfigJson := configJson.Get("server.config").Raw // Config for the plugin instance itself, if any.
// serverConfigJsonForInstance is the config passed to the specific server instance (single or REST)
// It's distinct from pluginServerConfigJson which might be for the mcp-server plugin itself.
var serverConfigJsonForInstance string
if toolSetJson.Exists() {
config.isComposed = true
var tsConfig ToolSetConfig
if err := json.Unmarshal([]byte(toolSetJson.Raw), &tsConfig); err != nil {
return fmt.Errorf("failed to parse toolSet config: %v", err)
}
config.toolSet = &tsConfig
config.serverName = tsConfig.Name // Use toolSet name as the server name for composed server
log.Infof("Parsing toolSet configuration: %s", config.serverName)
composedServer := NewComposedMCPServer(config.serverName, tsConfig.ServerTools, opts.ToolRegistry)
// A composed server itself might have a config block, e.g. for shared settings, though not typical.
composedServer.SetConfig([]byte(pluginServerConfigJson))
config.server = composedServer
} else if serverJson.Exists() {
config.isComposed = false
config.serverName = serverJson.Get("name").String()
if config.serverName == "" {
return errors.New("server.name field is missing for single server config")
}
// This is the config for the specific server being defined (e.g. REST server's own config)
serverConfigJsonForInstance = serverJson.Get("config").Raw
log.Infof("Parsing single server configuration: %s", config.serverName)
// Check server type to determine which type of server to create
serverType := serverJson.Get("type").String()
if serverType == "" {
serverType = "rest" // Default to REST server type
}
toolsJson := configJson.Get("tools") // These are REST tools for this server instance or MCP proxy tools
if serverType == "mcp-proxy" {
// Create MCP proxy server
proxyServer, err := setupMcpProxyServer(config.serverName, serverJson, serverConfigJsonForInstance)
if err != nil {
return err
}
// Handle tools configuration (optional for MCP proxy)
if toolsJson.Exists() && len(toolsJson.Array()) > 0 {
for _, toolJson := range toolsJson.Array() {
var proxyTool McpProxyToolConfig
if err := json.Unmarshal([]byte(toolJson.Raw), &proxyTool); err != nil {
return fmt.Errorf("failed to parse proxy tool config: %v", err)
}
if err := proxyServer.AddProxyTool(proxyTool); err != nil {
return fmt.Errorf("failed to add proxy tool %s: %v", proxyTool.Name, err)
}
// Register tool to registry
opts.ToolRegistry.RegisterTool(config.serverName, proxyTool.Name, proxyServer.GetMCPTools()[proxyTool.Name])
}
}
// Set the proxy server regardless of whether tools are configured
config.server = proxyServer
} else if toolsJson.Exists() && len(toolsJson.Array()) > 0 {
// Handle REST-to-MCP server (requires tools configuration)
// Create REST-to-MCP server (default behavior)
restServer := NewRestMCPServer(config.serverName) // Pass the server name
restServer.SetConfig([]byte(serverConfigJsonForInstance)) // Pass the server's specific config
securitySchemesJson := serverJson.Get("securitySchemes")
if securitySchemesJson.Exists() {
for _, schemeJson := range securitySchemesJson.Array() {
var scheme SecurityScheme
if err := json.Unmarshal([]byte(schemeJson.Raw), &scheme); err != nil {
return fmt.Errorf("failed to parse security scheme config: %v", err)
}
restServer.AddSecurityScheme(scheme)
}
}
// Parse default downstream security
defaultDownstreamSecurityJson := serverJson.Get("defaultDownstreamSecurity")
if defaultDownstreamSecurityJson.Exists() {
var defaultDownstreamSecurity SecurityRequirement
if err := json.Unmarshal([]byte(defaultDownstreamSecurityJson.Raw), &defaultDownstreamSecurity); err != nil {
return fmt.Errorf("failed to parse defaultDownstreamSecurity config: %v", err)
}
restServer.SetDefaultDownstreamSecurity(defaultDownstreamSecurity)
}
// Parse default upstream security
defaultUpstreamSecurityJson := serverJson.Get("defaultUpstreamSecurity")
if defaultUpstreamSecurityJson.Exists() {
var defaultUpstreamSecurity SecurityRequirement
if err := json.Unmarshal([]byte(defaultUpstreamSecurityJson.Raw), &defaultUpstreamSecurity); err != nil {
return fmt.Errorf("failed to parse defaultUpstreamSecurity config: %v", err)
}
restServer.SetDefaultUpstreamSecurity(defaultUpstreamSecurity)
}
// Parse passthroughAuthHeader (optional, defaults to false)
passthroughAuthHeader := serverJson.Get("passthroughAuthHeader").Bool()
restServer.SetPassthroughAuthHeader(passthroughAuthHeader)
for _, toolJson := range toolsJson.Array() {
var restTool RestTool
if err := json.Unmarshal([]byte(toolJson.Raw), &restTool); err != nil {
return fmt.Errorf("failed to parse tool config: %v", err)
}
if err := restServer.AddRestTool(restTool); err != nil {
return fmt.Errorf("failed to add tool %s: %v", restTool.Name, err)
}
// Register tool to registry
opts.ToolRegistry.RegisterTool(config.serverName, restTool.Name, restServer.GetMCPTools()[restTool.Name])
}
config.server = restServer
} else {
// Logic for pre-registered Go-based servers (non-REST)
if opts.SkipPreRegisteredServers {
// In validation mode, skip pre-registered servers validation
// Just validate the basic structure without actual server instance
config.server = nil // Will be handled appropriately in validation context
} else {
if serverInstance, exist := opts.Servers[config.serverName]; exist {
clonedServer := serverInstance.Clone()
clonedServer.SetConfig([]byte(serverConfigJsonForInstance)) // Pass the server's specific config
config.server = clonedServer
// Register tools from this server to registry
for toolName, toolInstance := range clonedServer.GetMCPTools() {
opts.ToolRegistry.RegisterTool(config.serverName, toolName, toolInstance)
}
} else {
return fmt.Errorf("mcp server type '%s' not registered", config.serverName)
}
}
}
} else {
return errors.New("either 'server' or 'toolSet' field must be present in the configuration")
}
// Parse allowTools - this might need adjustment for composed servers
// Use pointer to distinguish between "not configured" (nil) and "configured as empty" (empty map)
var allowTools *map[string]struct{} // For single server, tool name. For composed, serverName/toolName.
allowToolsResult := configJson.Get("allowTools")
if allowToolsResult.Exists() {
// allowTools is configured, create the map
toolsMap := make(map[string]struct{})
allowToolsArray := allowToolsResult.Array()
for _, toolJson := range allowToolsArray {
toolsMap[toolJson.String()] = struct{}{}
}
allowTools = &toolsMap
}
// If allowTools is nil, it means not configured (allow all)
config.methodHandlers = make(utils.MethodHandlers)
// Use config.serverName which is now reliably set
currentServerNameForHandlers := config.serverName
config.methodHandlers["ping"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error {
utils.OnMCPResponseSuccess(ctx, map[string]any{}, fmt.Sprintf("mcp:%s:ping", currentServerNameForHandlers))
return nil
}
config.methodHandlers["notifications/initialized"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error {
proxywasm.SendHttpResponseWithDetail(202, fmt.Sprintf("mcp:%s:notifications/initialized", currentServerNameForHandlers), nil, nil, -1)
return nil
}
config.methodHandlers["notifications/cancelled"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error {
proxywasm.SendHttpResponseWithDetail(202, fmt.Sprintf("mcp:%s:notifications/cancelled", currentServerNameForHandlers), nil, nil, -1)
return nil
}
config.methodHandlers["initialize"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error {
requestedVersion := params.Get("protocolVersion").String()
if requestedVersion == "" {
utils.OnMCPResponseError(ctx, errors.New("protocolVersion is required"), utils.ErrInvalidParams, fmt.Sprintf("mcp:%s:initialize:error", currentServerNameForHandlers))
return nil
}
// MCP specification compliant version negotiation:
// If the server supports the requested protocol version, it MUST respond with the same version.
// Otherwise, the server MUST respond with another protocol version it supports.
// This SHOULD be the latest version supported by the server.
negotiatedVersion := requestedVersion
if !slices.Contains(SupportedMCPVersions, requestedVersion) {
// Return the latest supported version instead of rejecting the request
negotiatedVersion = SupportedMCPVersions[len(SupportedMCPVersions)-1]
log.Warnf("Client requested unsupported version %s, responding with latest supported version %s",
requestedVersion, negotiatedVersion)
}
utils.OnMCPResponseSuccess(ctx, map[string]any{
"protocolVersion": negotiatedVersion,
"capabilities": map[string]any{
"tools": map[string]any{},
},
"serverInfo": map[string]any{
"name": currentServerNameForHandlers, // Use the actual server name (single or composed)
"version": "1.0.0",
},
}, fmt.Sprintf("mcp:%s:initialize", currentServerNameForHandlers))
return nil
}
// Override tools/list and tools/call handlers for MCP proxy servers first
if config.server != nil {
if proxyServer, ok := config.server.(*McpProxyServer); ok {
// Use MCP proxy specific handlers that support ActionPause
proxyHandlers := CreateMcpProxyMethodHandlers(proxyServer, allowTools)
config.methodHandlers["tools/list"] = proxyHandlers["tools/list"]
config.methodHandlers["tools/call"] = proxyHandlers["tools/call"]
}
}
// Default tools/list handler for non-proxy servers
if config.methodHandlers["tools/list"] == nil {
config.methodHandlers["tools/list"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error {
var listedTools []map[string]any
// GetMCPTools() will return appropriately formatted tools for both single and composed servers
allTools := config.server.GetMCPTools() // For composed, keys are "serverName/toolName"
// Compute effective allowTools using helper function
effectiveAllowTools := computeEffectiveAllowTools(allowTools)
for toolFullName, tool := range allTools {
// For composed server, toolFullName is "originalServerName/originalToolName"
// For single server, toolFullName is "originalToolName"
// The allowTools map should use the same format as toolFullName
if effectiveAllowTools != nil {
if _, allow := (*effectiveAllowTools)[toolFullName]; !allow {
continue
}
}
toolDef := map[string]any{
"name": toolFullName,
"description": tool.Description(),
"inputSchema": tool.InputSchema(),
}
// Add outputSchema if tool implements ToolWithOutputSchema (MCP Protocol Version 2025-06-18)
if toolWithSchema, ok := tool.(ToolWithOutputSchema); ok {
if outputSchema := toolWithSchema.OutputSchema(); len(outputSchema) > 0 {
toolDef["outputSchema"] = outputSchema
}
}
listedTools = append(listedTools, toolDef)
}
utils.OnMCPResponseSuccess(ctx, map[string]any{
"tools": listedTools,
}, fmt.Sprintf("mcp:%s:tools/list", currentServerNameForHandlers))
return nil
}
}
// Default tools/call handler for non-proxy servers
if config.methodHandlers["tools/call"] == nil {
config.methodHandlers["tools/call"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error {
if config.isComposed {
// This endpoint is for a composed server (toolSet).
// Actual tool calls should be routed by mcp-router to individual servers.
// If a tools/call request reaches here, it's a misconfiguration or unexpected.
errMsg := fmt.Sprintf("tools/call is not supported on a composed toolSet endpoint ('%s'). It should be routed by mcp-router to the target server.", currentServerNameForHandlers)
log.Errorf(errMsg)
utils.OnMCPResponseError(ctx, errors.New(errMsg), utils.ErrMethodNotFound, fmt.Sprintf("mcp:%s:tools/call:not_supported_on_toolset", currentServerNameForHandlers))
return nil
}
// Logic for single (non-composed) server
toolName := params.Get("name").String() // For single server, this is the direct tool name
args := params.Get("arguments")
// Compute effective allowTools using helper function
effectiveAllowTools := computeEffectiveAllowTools(allowTools)
// Check if tool is allowed
if effectiveAllowTools != nil {
if _, allow := (*effectiveAllowTools)[toolName]; !allow {
utils.OnMCPResponseError(ctx, fmt.Errorf("Tool not allowed: %s", toolName), utils.ErrInvalidParams, fmt.Sprintf("mcp:%s:tools/call:tool_not_allowed", currentServerNameForHandlers))
return nil
}
}
proxywasm.SetProperty([]string{"mcp_server_name"}, []byte(currentServerNameForHandlers))
proxywasm.SetProperty([]string{"mcp_tool_name"}, []byte(toolName))
toolToCall, ok := config.server.GetMCPTools()[toolName]
if !ok {
utils.OnMCPResponseError(ctx, fmt.Errorf("unknown tool: %s", toolName), utils.ErrInvalidParams, fmt.Sprintf("mcp:%s:tools/call:invalid_tool_name", currentServerNameForHandlers))
return nil
}
log.Debugf("Tool call [%s] on server [%s] with arguments[%s]", toolName, currentServerNameForHandlers, args.Raw)
toolInstance := toolToCall.Create([]byte(args.Raw))
err := toolInstance.Call(ctx, config.server) // Pass the single server instance
if err != nil {
utils.OnMCPToolCallError(ctx, err)
return nil
}
return nil
}
}
return nil
}
// ParseConfigCore exports the core parsing logic for external use (e.g., validation)
func ParseConfigCore(configJson gjson.Result, config *McpServerConfig, opts *ConfigOptions) error {
return parseConfigCore(configJson, config, opts)
}
func parseConfig(context wrapper.PluginContext, configJson gjson.Result, config *McpServerConfig) error {
registryI := context.GetContext(GlobalToolRegistryKey)
if registryI == nil {
return errors.New("GlobalToolRegistry not found")
}
registry, ok := registryI.(*GlobalToolRegistry)
if !ok {
return errors.New("invalid GlobalToolRegistry")
}
// Build runtime dependencies using global variables
opts := &ConfigOptions{
Servers: globalContext.servers,
ToolRegistry: registry,
}
// Call the core parsing logic
return parseConfigCore(configJson, config, opts)
}
func Load(options ...CtxOption) {
for _, opt := range options {
opt.Apply(&globalContext)
}
}
func Initialize() {
if globalContext.servers == nil {
panic("At least one mcpserver needs to be added.")
}
wrapper.SetCtx(
"mcp-server",
wrapper.PrePluginStartOrReload[McpServerConfig](onPluginStartOrReload),
wrapper.ParseConfigWithContext(parseConfig),
wrapper.WithLogger[McpServerConfig](&utils.MCPServerLog{}),
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
wrapper.ProcessRequestBody(onHttpRequestBody),
wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
wrapper.ProcessStreamingResponseBody(onHttpStreamingResponseBody),
wrapper.WithRebuildMaxMemBytes[McpServerConfig](200*1024*1024),
)
}
type addMCPServerOption struct {
name string
server Server
}
func AddMCPServer(name string, server Server) CtxOption {
return &addMCPServerOption{
name: name,
server: server,
}
}
func (o *addMCPServerOption) Apply(ctx *Context) {
if ctx.servers == nil {
ctx.servers = make(map[string]Server)
}
if _, exist := ctx.servers[o.name]; exist {
panic(fmt.Sprintf("Conflict! There is a mcp server with the same name:%s",
o.name))
}
ctx.servers[o.name] = o.server
}
func ToInputSchema(v any) map[string]any {
t := reflect.TypeOf(v)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
inputSchema := jsonschema.Reflect(v).Definitions[t.Name()]
inputSchemaBytes, _ := json.Marshal(inputSchema)
var result map[string]any
json.Unmarshal(inputSchemaBytes, &result)
return result
}
func StoreServerState(ctx wrapper.HttpContext, config any) {
if utils.IsStatefulSession(ctx) {
log.Warnf("There is no session ID, unable to store state.")
return
}
configBytes, err := json.Marshal(config)
if err != nil {
log.Errorf("Server config marshal failed:%v, config:%s", err, configBytes)
return
}
proxywasm.SetProperty([]string{"mcp_server_config"}, configBytes)
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config McpServerConfig) types.Action {
ctx.DisableReroute()
ctx.SetRequestBodyBufferLimit(DefaultMaxBodyBytes)
ctx.SetResponseBodyBufferLimit(DefaultMaxBodyBytes)
// Remove accept-encoding header to prevent backend from compressing the response
// This ensures we can properly process and modify the response body
proxywasm.RemoveHttpRequestHeader("accept-encoding")
// Parse MCP-Protocol-Version header and store in context
// This allows clients to specify the MCP protocol version via HTTP header
// instead of only through the JSON-RPC initialize method
protocolVersion, _ := proxywasm.GetHttpRequestHeader("MCP-Protocol-Version")
if protocolVersion != "" {
// Validate the protocol version against supported versions
if slices.Contains(SupportedMCPVersions, protocolVersion) {
log.Debugf("MCP Protocol Version set from header: %s", protocolVersion)
} else {
log.Warnf("Unsupported MCP Protocol Version in header: %s", protocolVersion)
}
// Remove the header from the request to prevent it from being forwarded
proxywasm.RemoveHttpRequestHeader("MCP-Protocol-Version")
}
if ctx.Method() == "GET" {
proxywasm.SendHttpResponseWithDetail(405, "not_support_sse_on_this_endpoint", nil, nil, -1)
return types.HeaderStopAllIterationAndWatermark
}
// Handle DELETE request for session termination (MCP 2025-06-18 spec)
// Per spec: "Clients that no longer need a particular session SHOULD send an HTTP DELETE
// to the MCP endpoint with the Mcp-Session-Id header, to explicitly terminate the session."
// Per spec: "The server MAY respond to this request with HTTP 405 Method Not Allowed,
// indicating that the server does not allow clients to terminate sessions."
if ctx.Method() == "DELETE" {
proxywasm.SendHttpResponseWithDetail(405, "session_termination_not_supported", nil, nil, -1)
return types.HeaderStopAllIterationAndWatermark
}
if !ctx.HasRequestBody() {
proxywasm.SendHttpResponseWithDetail(400, "missing_body_in_mcp_request", nil, nil, -1)
return types.HeaderStopAllIterationAndWatermark
}
return types.HeaderStopIteration
}
func onHttpRequestBody(ctx wrapper.HttpContext, config McpServerConfig, body []byte) types.Action {
return utils.HandleJsonRpcMethod(ctx, body, config.methodHandlers)
}
func onHttpResponseHeaders(ctx wrapper.HttpContext, config McpServerConfig) types.Action {
// Check if this request initiated SSE channel (tools/list or tools/call with SSE transport)
// Only these requests need special SSE streaming response processing
if ctx.GetContext(CtxSSEProxyState) != nil {
// Check if response has a body
if ctx.HasResponseBody() {
// Pause streaming response for processing
// Content-type validation will be done in onHttpStreamingResponseBody
ctx.NeedPauseStreamingResponse()
return types.HeaderStopIteration
} else {
// No body, return error
utils.OnMCPResponseError(ctx, fmt.Errorf("no response body in SSE response"), utils.ErrInternalError, "mcp-proxy:sse:no_body")
return types.HeaderStopAllIterationAndWatermark
}
}
// For non-SSE streaming requests, continue normally
return types.HeaderContinue
}
func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config McpServerConfig, data []byte, endOfStream bool) []byte {
// Check if this request initiated SSE channel (tools/list or tools/call with SSE transport)
// Only these requests need special SSE streaming response processing
if ctx.GetContext(CtxSSEProxyState) != nil {
return handleSSEStreamingResponse(ctx, config, data, endOfStream)
}
// For non-SSE streaming requests, return data as-is
return data
}

View File

@@ -0,0 +1,429 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestApiKeyAuthentication tests API key authentication forwarding
func TestApiKeyAuthentication(t *testing.T) {
server := NewMcpProxyServer("auth-test")
// Configure security scheme
scheme := SecurityScheme{
ID: "ApiKeyAuth",
Type: "apiKey",
In: "header",
Name: "X-API-Key",
DefaultCredential: "default-api-key",
}
server.AddSecurityScheme(scheme)
// Set server fields directly
server.SetMcpServerURL("http://secure-backend.example.com/mcp")
server.SetTimeout(5000)
// Create tool with client-to-gateway and gateway-to-backend security
toolConfig := McpProxyToolConfig{
Name: "secure_tool",
Description: "Tool requiring authentication",
Security: SecurityRequirement{
ID: "ApiKeyAuth", // Client-to-gateway authentication
Passthrough: true, // Extract client credential for backend use
},
Args: []ToolArg{
{
Name: "data",
Description: "Data parameter",
Type: "string",
Required: true,
},
},
OutputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"result": map[string]any{
"type": "string",
"description": "The result of the operation",
},
},
},
RequestTemplate: RequestTemplate{
Security: SecurityRequirement{
ID: "ApiKeyAuth", // Gateway-to-backend authentication (same scheme for simplicity)
},
},
}
err := server.AddProxyTool(toolConfig)
require.NoError(t, err)
tool, exists := server.GetMCPTools()["secure_tool"]
require.True(t, exists)
params := map[string]interface{}{
"data": "test data",
}
paramsBytes, err := json.Marshal(params)
require.NoError(t, err)
toolInstance := tool.Create(paramsBytes)
require.NotNil(t, toolInstance)
// Authentication is now handled automatically during tool calls
// The actual authentication flow is tested in integration tests
}
// TestBearerAuthentication tests Bearer token authentication
func TestBearerAuthentication(t *testing.T) {
server := NewMcpProxyServer("bearer-auth-test")
// Configure Bearer security scheme
scheme := SecurityScheme{
ID: "BearerAuth",
Type: "http",
Scheme: "bearer",
}
server.AddSecurityScheme(scheme)
// Set server fields directly
server.SetMcpServerURL("https://secure-backend.example.com/mcp")
server.SetTimeout(8000)
// Create tool with Bearer authentication
// Create tool using only gateway-to-backend authentication (no client auth required)
toolConfig := McpProxyToolConfig{
Name: "bearer_tool",
Description: "Tool with Bearer authentication to backend only",
Args: []ToolArg{
{
Name: "query",
Description: "Query parameter",
Type: "string",
Required: true,
},
},
RequestTemplate: RequestTemplate{
Security: SecurityRequirement{
ID: "BearerAuth", // Only gateway-to-backend authentication
},
},
}
err := server.AddProxyTool(toolConfig)
require.NoError(t, err)
tool, exists := server.GetMCPTools()["bearer_tool"]
require.True(t, exists)
params := map[string]interface{}{
"query": "test query",
}
paramsBytes, err := json.Marshal(params)
require.NoError(t, err)
toolInstance := tool.Create(paramsBytes)
require.NotNil(t, toolInstance)
// Authentication is now handled automatically during tool calls
// The actual authentication flow is tested in integration tests
// Test backward compatibility: this tool uses RequestTemplate.Security (legacy way)
// which should still work
}
// TestBasicAuthentication tests Basic authentication
func TestBasicAuthentication(t *testing.T) {
server := NewMcpProxyServer("basic-auth-test")
// Configure Basic security scheme
scheme := SecurityScheme{
ID: "BasicAuth",
Type: "http",
Scheme: "basic",
}
server.AddSecurityScheme(scheme)
// Test tool call with Basic authentication
toolConfig := McpProxyToolConfig{
Name: "basic_tool",
Description: "Tool with Basic authentication",
Args: []ToolArg{
{
Name: "resource",
Description: "Resource identifier",
Type: "string",
Required: true,
},
},
RequestTemplate: RequestTemplate{
Security: SecurityRequirement{
ID: "BasicAuth",
},
},
}
err := server.AddProxyTool(toolConfig)
require.NoError(t, err)
tool, exists := server.GetMCPTools()["basic_tool"]
require.True(t, exists)
params := map[string]interface{}{
"resource": "test-resource",
}
paramsBytes, err := json.Marshal(params)
require.NoError(t, err)
toolInstance := tool.Create(paramsBytes)
require.NotNil(t, toolInstance)
// Authentication is now handled automatically during tool calls
// The actual authentication flow is tested in integration tests
// Test OutputSchema functionality (only for tools that have it configured)
if toolWithOutputSchema, ok := tool.(ToolWithOutputSchema); ok {
outputSchema := toolWithOutputSchema.OutputSchema()
if outputSchema != nil {
// Only validate if outputSchema is configured
assert.Equal(t, "object", outputSchema["type"])
properties, hasProperties := outputSchema["properties"].(map[string]any)
require.True(t, hasProperties)
resultSchema, hasResult := properties["result"].(map[string]any)
require.True(t, hasResult)
assert.Equal(t, "string", resultSchema["type"])
assert.Equal(t, "The result of the operation", resultSchema["description"])
}
}
}
// TestMultipleSecuritySchemes tests multiple security schemes in one server
func TestMultipleSecuritySchemes(t *testing.T) {
server := NewMcpProxyServer("multi-auth-test")
// Add multiple security schemes
schemes := []SecurityScheme{
{
ID: "ApiKeyAuth",
Type: "apiKey",
In: "header",
Name: "X-API-Key",
},
{
ID: "BearerAuth",
Type: "http",
Scheme: "bearer",
},
}
for _, scheme := range schemes {
server.AddSecurityScheme(scheme)
}
// Test that both schemes are available
for _, scheme := range schemes {
retrievedScheme, exists := server.GetSecurityScheme(scheme.ID)
assert.True(t, exists)
assert.Equal(t, scheme.ID, retrievedScheme.ID)
assert.Equal(t, scheme.Type, retrievedScheme.Type)
}
}
// ProxyAuthContext, RequestTemplate, SecurityConfig and authentication methods
// are now implemented in proxy_server.go
// TestToolsListAuthentication tests authentication configuration for tools/list requests
func TestToolsListAuthentication(t *testing.T) {
server := NewMcpProxyServer("test-server")
// Add a security scheme for global authentication
scheme := SecurityScheme{
ID: "GlobalAuth",
Type: "apiKey",
In: "header",
Name: "X-API-Key",
DefaultCredential: "default-global-key",
}
server.AddSecurityScheme(scheme)
// Test that we can retrieve the security scheme
retrievedScheme, exists := server.GetSecurityScheme("GlobalAuth")
assert.True(t, exists)
assert.Equal(t, "GlobalAuth", retrievedScheme.ID)
assert.Equal(t, "apiKey", retrievedScheme.Type)
assert.Equal(t, "header", retrievedScheme.In)
assert.Equal(t, "X-API-Key", retrievedScheme.Name)
// Test setting default security directly on server
defaultDownstreamSecurity := SecurityRequirement{
ID: "GlobalAuth",
Passthrough: true,
}
defaultUpstreamSecurity := SecurityRequirement{
ID: "GlobalAuth",
}
server.SetDefaultDownstreamSecurity(defaultDownstreamSecurity)
server.SetDefaultUpstreamSecurity(defaultUpstreamSecurity)
// Verify default security settings
retrievedDownstream := server.GetDefaultDownstreamSecurity()
assert.Equal(t, "GlobalAuth", retrievedDownstream.ID)
assert.True(t, retrievedDownstream.Passthrough)
retrievedUpstream := server.GetDefaultUpstreamSecurity()
assert.Equal(t, "GlobalAuth", retrievedUpstream.ID)
t.Logf("Tools/list authentication configuration test completed successfully")
}
// TestDefaultSecurityFallback tests the fallback mechanism from tool-level to default security
func TestDefaultSecurityFallback(t *testing.T) {
server := NewMcpProxyServer("test-server")
// Add security schemes
defaultScheme := SecurityScheme{
ID: "DefaultAuth",
Type: "apiKey",
In: "header",
Name: "X-Default-Key",
DefaultCredential: "default-key",
}
toolScheme := SecurityScheme{
ID: "ToolAuth",
Type: "apiKey",
In: "header",
Name: "X-Tool-Key",
DefaultCredential: "tool-key",
}
server.AddSecurityScheme(defaultScheme)
server.AddSecurityScheme(toolScheme)
// Test tool configuration with tool-level security (should use tool-level, not default)
toolConfigWithSecurity := McpProxyToolConfig{
Name: "secure_tool",
Description: "Tool with its own security",
Security: SecurityRequirement{
ID: "ToolAuth",
Passthrough: true,
},
RequestTemplate: RequestTemplate{
Security: SecurityRequirement{
ID: "ToolAuth",
},
},
}
// Test tool configuration without tool-level security (should fallback to default)
toolConfigWithoutSecurity := McpProxyToolConfig{
Name: "fallback_tool",
Description: "Tool that falls back to default security",
// No Security field configured, should use default
RequestTemplate: RequestTemplate{
// No Security field configured, should use default
},
}
// Set default security directly on server
server.SetDefaultDownstreamSecurity(SecurityRequirement{
ID: "DefaultAuth",
Passthrough: false,
})
server.SetDefaultUpstreamSecurity(SecurityRequirement{
ID: "DefaultAuth",
})
// Set server configuration directly
server.SetMcpServerURL("http://backend.example.com")
server.SetTimeout(5000)
// Add tools to server
err := server.AddProxyTool(toolConfigWithSecurity)
assert.NoError(t, err)
err = server.AddProxyTool(toolConfigWithoutSecurity)
assert.NoError(t, err)
// Verify tools were added
tools := server.GetMCPTools()
assert.Contains(t, tools, "secure_tool")
assert.Contains(t, tools, "fallback_tool")
t.Logf("Default security fallback test completed successfully")
}
// TestURLModificationInAuthentication tests that authentication can modify the URL (e.g., adding query parameters)
func TestURLModificationInAuthentication(t *testing.T) {
server := NewMcpProxyServer("test-server")
// Add a security scheme that adds parameters to query (apiKey in query)
scheme := SecurityScheme{
ID: "QueryApiKey",
Type: "apiKey",
In: "query",
Name: "api_key",
DefaultCredential: "test-key-123",
}
server.AddSecurityScheme(scheme)
// Verify the security scheme was added correctly
retrievedScheme, exists := server.GetSecurityScheme("QueryApiKey")
assert.True(t, exists)
assert.Equal(t, "apiKey", retrievedScheme.Type)
assert.Equal(t, "query", retrievedScheme.In)
assert.Equal(t, "api_key", retrievedScheme.Name)
t.Logf("URL modification authentication configuration test completed successfully")
}
// TestProxyServerFields tests the server-level field setting and getting
func TestProxyServerFields(t *testing.T) {
server := NewMcpProxyServer("test-server")
// Test mcpServerURL
testURL := "http://mcp.example.com:8080/mcp"
server.SetMcpServerURL(testURL)
assert.Equal(t, testURL, server.GetMcpServerURL())
// Test timeout
testTimeout := 10000
server.SetTimeout(testTimeout)
assert.Equal(t, testTimeout, server.GetTimeout())
// Test default security settings
downstreamSec := SecurityRequirement{
ID: "test-downstream",
Passthrough: true,
}
upstreamSec := SecurityRequirement{
ID: "test-upstream",
}
server.SetDefaultDownstreamSecurity(downstreamSec)
server.SetDefaultUpstreamSecurity(upstreamSec)
assert.Equal(t, "test-downstream", server.GetDefaultDownstreamSecurity().ID)
assert.True(t, server.GetDefaultDownstreamSecurity().Passthrough)
assert.Equal(t, "test-upstream", server.GetDefaultUpstreamSecurity().ID)
t.Logf("Proxy server fields test completed successfully")
}

View File

@@ -0,0 +1,302 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// MockHttpContext is a mock implementation for testing - skipping interface implementation for now
// Tests that require full HttpContext will be tested in integration tests with real host
type MockHttpContext struct {
responseBody []byte
responseStatus int
headers map[string]string
}
// TestMcpProtocolInitialization tests the MCP protocol initialization flow
func TestMcpProtocolInitialization(t *testing.T) {
// Create proxy server
server := NewMcpProxyServer("test-proxy")
// Set server fields directly
server.SetMcpServerURL("http://mock-backend.example.com/mcp")
server.SetTimeout(5000)
// Create proxy tool
toolConfig := McpProxyToolConfig{
Name: "test-tool",
Description: "Test tool for initialization",
Args: []ToolArg{
{
Name: "input",
Description: "Test input",
Type: "string",
Required: true,
},
},
}
err := server.AddProxyTool(toolConfig)
require.NoError(t, err)
tool, exists := server.GetMCPTools()["test-tool"]
require.True(t, exists)
// Create tool instance with parameters
params := map[string]interface{}{
"input": "test value",
}
paramsBytes, err := json.Marshal(params)
require.NoError(t, err)
toolInstance := tool.Create(paramsBytes)
require.NotNil(t, toolInstance)
// Skip HttpContext-dependent test for now - will be tested in integration
// mockCtx := &MockHttpContext{}
// err = toolInstance.Call(mockCtx, server)
// assert.NoError(t, err)
// Test the tool creation was successful
assert.NotNil(t, toolInstance)
}
// TestMcpSessionManagement tests temporary session creation and cleanup
func TestMcpSessionManagement(t *testing.T) {
_ = NewMcpProxyServer("session-test")
// Skip session management test until implemented
t.Skip("Session management not implemented yet")
// Test session creation
sessionManager := NewMcpSessionManager()
sessionID, err := sessionManager.CreateSession("http://backend.example.com/mcp")
// This will fail until session management is implemented
assert.NoError(t, err)
assert.NotEmpty(t, sessionID)
// Test session retrieval
session, exists := sessionManager.GetSession(sessionID)
assert.True(t, exists)
assert.NotNil(t, session)
// Test session cleanup
sessionManager.CleanupSession(sessionID)
_, exists = sessionManager.GetSession(sessionID)
assert.False(t, exists)
}
// TestMcpProtocolVersionNegotiation tests protocol version handling
func TestMcpProtocolVersionNegotiation(t *testing.T) {
tests := []struct {
name string
requestedVersion string
supportedVersions []string
shouldSucceed bool
expectedVersion string
}{
{
name: "supported version 2025-03-26",
requestedVersion: "2025-03-26",
supportedVersions: []string{"2024-11-05", "2025-03-26"},
shouldSucceed: true,
expectedVersion: "2025-03-26",
},
{
name: "unsupported version",
requestedVersion: "2026-01-01",
supportedVersions: []string{"2024-11-05", "2025-03-26"},
shouldSucceed: false,
expectedVersion: "",
},
{
name: "fallback to supported version",
requestedVersion: "2025-06-18",
supportedVersions: []string{"2024-11-05", "2025-03-26"},
shouldSucceed: false,
expectedVersion: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Skip until NewMcpVersionNegotiator is implemented
t.Skip("Version negotiation not implemented yet")
negotiator := NewMcpVersionNegotiator(tt.supportedVersions)
version, err := negotiator.NegotiateVersion(tt.requestedVersion)
if tt.shouldSucceed {
assert.NoError(t, err)
assert.Equal(t, tt.expectedVersion, version)
} else {
assert.Error(t, err)
}
})
}
}
// TestMcpInitializeRequest tests the initialize request format and handling
func TestMcpInitializeRequest(t *testing.T) {
_ = NewMcpProxyServer("init-test")
// Skip until CreateInitializeRequest is implemented
t.Skip("MCP protocol initialization not implemented yet")
// Test initialize request creation
initRequest := CreateInitializeRequest()
assert.Equal(t, "2.0", initRequest.JsonRPC)
assert.Equal(t, "initialize", initRequest.Method)
assert.NotNil(t, initRequest.Params)
// Validate client info
params := initRequest.Params.(map[string]interface{})
clientInfo := params["clientInfo"].(map[string]interface{})
assert.Equal(t, "Higress-mcp-proxy", clientInfo["name"])
assert.Equal(t, "1.0.0", clientInfo["version"])
// Test protocol version
assert.Equal(t, "2025-03-26", params["protocolVersion"])
}
// TestMcpNotificationsInitialized tests the notifications/initialized message
func TestMcpNotificationsInitialized(t *testing.T) {
// Skip until CreateInitializedNotification is implemented
t.Skip("MCP notifications not implemented yet")
// Test notifications/initialized request creation
notification := CreateInitializedNotification()
assert.Equal(t, "2.0", notification.JsonRPC)
assert.Equal(t, "notifications/initialized", notification.Method)
assert.Nil(t, notification.ID) // Notifications don't have IDs
}
// TestMcpErrorHandling tests error response handling and source identification
func TestMcpErrorHandling(t *testing.T) {
tests := []struct {
name string
errorType string
originalError error
expectedSource string
expectedCode int
}{
{
name: "backend connection error",
errorType: "connection",
originalError: assert.AnError,
expectedSource: "mcp-proxy",
expectedCode: -32603,
},
{
name: "backend timeout error",
errorType: "timeout",
originalError: assert.AnError,
expectedSource: "mcp-proxy",
expectedCode: -32000,
},
{
name: "protocol version error",
errorType: "version",
originalError: assert.AnError,
expectedSource: "mcp-proxy",
expectedCode: -32602,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Skip until CreateMcpErrorResponse is implemented
t.Skip("MCP error handling not implemented yet")
errorResponse := CreateMcpErrorResponse(tt.errorType, tt.originalError, "http://backend.example.com/mcp")
assert.Equal(t, "2.0", errorResponse.JsonRPC)
assert.NotNil(t, errorResponse.Error)
assert.Equal(t, tt.expectedCode, errorResponse.Error.Code)
assert.Equal(t, tt.expectedSource, errorResponse.Error.Data["source"])
})
}
}
// Helper types and functions that will fail until implemented
type McpSessionManager struct{}
func NewMcpSessionManager() *McpSessionManager {
panic("McpSessionManager not implemented yet")
}
func (m *McpSessionManager) CreateSession(backendURL string) (string, error) {
panic("CreateSession not implemented yet")
}
func (m *McpSessionManager) GetSession(sessionID string) (interface{}, bool) {
panic("GetSession not implemented yet")
}
func (m *McpSessionManager) CleanupSession(sessionID string) {
panic("CleanupSession not implemented yet")
}
type McpVersionNegotiator struct {
supportedVersions []string
}
func NewMcpVersionNegotiator(versions []string) *McpVersionNegotiator {
panic("McpVersionNegotiator not implemented yet")
}
func (n *McpVersionNegotiator) NegotiateVersion(requested string) (string, error) {
panic("NegotiateVersion not implemented yet")
}
type McpRequest struct {
JsonRPC string `json:"jsonrpc"`
ID interface{} `json:"id,omitempty"`
Method string `json:"method"`
Params interface{} `json:"params,omitempty"`
}
type McpErrorResponse struct {
JsonRPC string `json:"jsonrpc"`
ID interface{} `json:"id,omitempty"`
Error *McpError `json:"error"`
}
type McpError struct {
Code int `json:"code"`
Message string `json:"message"`
Data map[string]interface{} `json:"data,omitempty"`
}
func CreateInitializeRequest() *McpRequest {
panic("CreateInitializeRequest not implemented yet")
}
func CreateInitializedNotification() *McpRequest {
panic("CreateInitializedNotification not implemented yet")
}
func CreateMcpErrorResponse(errorType string, originalError error, backendURL string) *McpErrorResponse {
panic("CreateMcpErrorResponse not implemented yet")
}

View File

@@ -0,0 +1,500 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"encoding/json"
"fmt"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
// McpProxyConfig represents the configuration for MCP proxy server
// Note: mcpServerURL, timeout, defaultDownstreamSecurity, and defaultUpstreamSecurity
// are now direct server fields, not part of this config structure
type McpProxyConfig struct {
// This structure is kept for any additional server configuration that may be needed in the future
// Currently, most configuration is handled as direct server fields
}
// TransportProtocol represents the transport protocol type for MCP proxy
type TransportProtocol string
const (
TransportHTTP TransportProtocol = "http" // StreamableHTTP protocol
TransportSSE TransportProtocol = "sse" // SSE protocol
)
// ToolArg represents an argument for a proxy tool
type ToolArg struct {
Name string `json:"name"`
Description string `json:"description"`
Type string `json:"type"`
Required bool `json:"required"`
Default interface{} `json:"default,omitempty"`
Enum []interface{} `json:"enum,omitempty"`
}
// McpProxyToolConfig represents a tool configuration for MCP proxy
type McpProxyToolConfig struct {
Name string `json:"name"`
Description string `json:"description"`
Security SecurityRequirement `json:"security,omitempty"` // Tool-level security for MCP Client to MCP Server
Args []ToolArg `json:"args"`
OutputSchema map[string]any `json:"outputSchema,omitempty"` // Output schema for MCP Protocol Version 2025-06-18
RequestTemplate RequestTemplate `json:"requestTemplate,omitempty"`
}
// RequestTemplate defines request template configuration for proxy tools
type RequestTemplate struct {
Security SecurityRequirement `json:"security,omitempty"`
}
// McpProxyServer implements Server interface for MCP-to-MCP proxy
type McpProxyServer struct {
Name string
base BaseMCPServer
toolsConfig map[string]McpProxyToolConfig
securitySchemes map[string]SecurityScheme
defaultDownstreamSecurity SecurityRequirement // Default client-to-gateway authentication
defaultUpstreamSecurity SecurityRequirement // Default gateway-to-backend authentication
mcpServerURL string // Backend MCP server URL
timeout int // Request timeout in milliseconds
transport TransportProtocol // Transport protocol (http or sse)
passthroughAuthHeader bool // If true, pass through Authorization header even without downstream security
}
// NewMcpProxyServer creates a new MCP proxy server
func NewMcpProxyServer(name string) *McpProxyServer {
return &McpProxyServer{
Name: name,
base: NewBaseMCPServer(),
toolsConfig: make(map[string]McpProxyToolConfig),
securitySchemes: make(map[string]SecurityScheme),
}
}
// AddSecurityScheme adds a security scheme to the server's map
func (s *McpProxyServer) AddSecurityScheme(scheme SecurityScheme) {
if s.securitySchemes == nil {
s.securitySchemes = make(map[string]SecurityScheme)
}
s.securitySchemes[scheme.ID] = scheme
}
// GetSecurityScheme retrieves a security scheme by its ID from the map
func (s *McpProxyServer) GetSecurityScheme(id string) (SecurityScheme, bool) {
scheme, ok := s.securitySchemes[id]
return scheme, ok
}
// SetDefaultDownstreamSecurity sets the default downstream security configuration
func (s *McpProxyServer) SetDefaultDownstreamSecurity(security SecurityRequirement) {
s.defaultDownstreamSecurity = security
}
// GetDefaultDownstreamSecurity gets the default downstream security configuration
func (s *McpProxyServer) GetDefaultDownstreamSecurity() SecurityRequirement {
return s.defaultDownstreamSecurity
}
// SetDefaultUpstreamSecurity sets the default upstream security configuration
func (s *McpProxyServer) SetDefaultUpstreamSecurity(security SecurityRequirement) {
s.defaultUpstreamSecurity = security
}
// GetDefaultUpstreamSecurity gets the default upstream security configuration
func (s *McpProxyServer) GetDefaultUpstreamSecurity() SecurityRequirement {
return s.defaultUpstreamSecurity
}
// SetMcpServerURL sets the backend MCP server URL
func (s *McpProxyServer) SetMcpServerURL(url string) {
s.mcpServerURL = url
}
// GetMcpServerURL gets the backend MCP server URL
func (s *McpProxyServer) GetMcpServerURL() string {
return s.mcpServerURL
}
// SetTimeout sets the request timeout in milliseconds
func (s *McpProxyServer) SetTimeout(timeout int) {
s.timeout = timeout
}
// GetTimeout gets the request timeout in milliseconds
func (s *McpProxyServer) GetTimeout() int {
return s.timeout
}
// SetTransport sets the transport protocol
func (s *McpProxyServer) SetTransport(transport TransportProtocol) {
s.transport = transport
}
// GetTransport gets the transport protocol
func (s *McpProxyServer) GetTransport() TransportProtocol {
return s.transport
}
// AddMCPTool implements Server interface
func (s *McpProxyServer) AddMCPTool(name string, tool Tool) Server {
s.base.AddMCPTool(name, tool)
return s
}
// AddProxyTool adds a proxy tool configuration
func (s *McpProxyServer) AddProxyTool(toolConfig McpProxyToolConfig) error {
s.toolsConfig[toolConfig.Name] = toolConfig
s.base.AddMCPTool(toolConfig.Name, &McpProxyTool{
serverName: s.Name,
name: toolConfig.Name,
toolConfig: toolConfig,
})
return nil
}
// GetMCPTools implements Server interface
func (s *McpProxyServer) GetMCPTools() map[string]Tool {
return s.base.GetMCPTools()
}
// SetConfig implements Server interface
func (s *McpProxyServer) SetConfig(config []byte) {
s.base.SetConfig(config)
}
// GetConfig implements Server interface
func (s *McpProxyServer) GetConfig(v any) {
s.base.GetConfig(v)
}
// Clone implements Server interface
func (s *McpProxyServer) Clone() Server {
newServer := &McpProxyServer{
Name: s.Name,
base: s.base.CloneBase(),
toolsConfig: make(map[string]McpProxyToolConfig),
securitySchemes: make(map[string]SecurityScheme),
}
for k, v := range s.toolsConfig {
newServer.toolsConfig[k] = v
}
// Deep copy securitySchemes
if s.securitySchemes != nil {
for k, v := range s.securitySchemes {
newServer.securitySchemes[k] = v
}
}
return newServer
}
// GetToolConfig returns the proxy tool configuration for a given tool name
func (s *McpProxyServer) GetToolConfig(name string) (McpProxyToolConfig, bool) {
config, ok := s.toolsConfig[name]
return config, ok
}
// SetPassthroughAuthHeader sets the passthrough auth header flag
func (s *McpProxyServer) SetPassthroughAuthHeader(passthrough bool) {
s.passthroughAuthHeader = passthrough
}
// GetPassthroughAuthHeader gets the passthrough auth header flag
func (s *McpProxyServer) GetPassthroughAuthHeader() bool {
return s.passthroughAuthHeader
}
// ForwardToolsList forwards tools/list request to backend MCP server
func (s *McpProxyServer) ForwardToolsList(ctx HttpContext, cursor *string) error {
wrapperCtx := ctx.(wrapper.HttpContext)
// Handle default downstream security for tools/list requests
// tools/list requests use server-level default authentication configuration
passthroughCredential := ""
downstreamSecurity := s.GetDefaultDownstreamSecurity()
if downstreamSecurity.ID != "" {
clientScheme, schemeOk := s.GetSecurityScheme(downstreamSecurity.ID)
if !schemeOk {
log.Warnf("Default downstream security scheme ID '%s' not found for tools/list request.", downstreamSecurity.ID)
} else {
// Extract and remove the credential from the incoming request
extractedCred, err := ExtractAndRemoveIncomingCredential(clientScheme)
if err != nil {
log.Warnf("Failed to extract/remove incoming credential for tools/list using scheme %s: %v", clientScheme.ID, err)
} else if extractedCred == "" {
log.Debugf("No incoming credential found for tools/list using scheme %s for extraction/removal.", clientScheme.ID)
}
// Only use passthrough if explicitly configured
if downstreamSecurity.Passthrough && extractedCred != "" {
passthroughCredential = extractedCred
log.Debugf("Passthrough credential set for tools/list request.")
}
}
} else {
// Fallback: Remove Authorization header if no downstream security is defined
// This prevents downstream credentials from being mistakenly passed to upstream
// Unless passthroughAuthHeader is explicitly set to true
if !s.GetPassthroughAuthHeader() {
proxywasm.RemoveHttpRequestHeader("Authorization")
}
}
// Create protocol handler using server fields
handler := NewMcpProtocolHandler(s.GetMcpServerURL(), s.GetTimeout())
// Prepare authentication information for gateway-to-backend communication
var authInfo *ProxyAuthInfo
upstreamSecurity := s.GetDefaultUpstreamSecurity()
if upstreamSecurity.ID != "" {
authInfo = &ProxyAuthInfo{
SecuritySchemeID: upstreamSecurity.ID,
PassthroughCredential: passthroughCredential,
Server: s,
}
}
// This will handle initialization asynchronously if needed and use ActionPause/Resume
return handler.ForwardToolsList(wrapperCtx, cursor, authInfo)
}
// McpProxyTool implements Tool interface for MCP-to-MCP proxy
type McpProxyTool struct {
serverName string
name string
toolConfig McpProxyToolConfig
arguments map[string]interface{}
}
// Create implements Tool interface
func (t *McpProxyTool) Create(params []byte) Tool {
newTool := &McpProxyTool{
serverName: t.serverName,
name: t.name,
toolConfig: t.toolConfig,
arguments: make(map[string]interface{}),
}
if len(params) > 0 {
json.Unmarshal(params, &newTool.arguments)
}
return newTool
}
// Call implements Tool interface - this is where the MCP protocol handling happens
func (t *McpProxyTool) Call(httpCtx HttpContext, server Server) error {
ctx := httpCtx.(wrapper.HttpContext)
// Get proxy server instance to access configuration
proxyServer, ok := server.(*McpProxyServer)
if !ok {
return fmt.Errorf("server is not a McpProxyServer")
}
// Handle tool-level or default downstream security: extract credential for passthrough if configured
// toolConfig.Security represents client-to-gateway authentication, falls back to server's defaultDownstreamSecurity
passthroughCredential := ""
var downstreamSecurity SecurityRequirement
if t.toolConfig.Security.ID != "" {
// Use tool-level security if configured
downstreamSecurity = t.toolConfig.Security
log.Debugf("Using tool-level downstream security for tool %s: %s", t.name, downstreamSecurity.ID)
} else {
// Fall back to server's default downstream security
downstreamSecurity = proxyServer.GetDefaultDownstreamSecurity()
if downstreamSecurity.ID != "" {
log.Debugf("Using default downstream security for tool %s: %s", t.name, downstreamSecurity.ID)
}
}
if downstreamSecurity.ID != "" {
clientScheme, schemeOk := proxyServer.GetSecurityScheme(downstreamSecurity.ID)
if !schemeOk {
log.Warnf("Downstream security scheme ID '%s' not found for tool %s.", downstreamSecurity.ID, t.name)
} else {
// Extract and remove the credential from the incoming request
extractedCred, err := ExtractAndRemoveIncomingCredential(clientScheme)
if err != nil {
log.Warnf("Failed to extract/remove incoming credential for tool %s using scheme %s: %v", t.name, clientScheme.ID, err)
} else if extractedCred == "" {
log.Debugf("No incoming credential found for tool %s using scheme %s for extraction/removal.", t.name, clientScheme.ID)
}
// Only use passthrough if explicitly configured
if downstreamSecurity.Passthrough && extractedCred != "" {
passthroughCredential = extractedCred
log.Debugf("Passthrough credential set for tool %s.", t.name)
}
}
} else {
// Fallback: Remove Authorization header if no downstream security is defined
// This prevents downstream credentials from being mistakenly passed to upstream
// Unless passthroughAuthHeader is explicitly set to true
if !proxyServer.GetPassthroughAuthHeader() {
proxywasm.RemoveHttpRequestHeader("Authorization")
}
}
// Create protocol handler using server fields
handler := NewMcpProtocolHandler(proxyServer.GetMcpServerURL(), proxyServer.GetTimeout())
// Prepare authentication information for gateway-to-backend communication
// toolConfig.RequestTemplate.Security represents gateway-to-backend authentication, falls back to server's defaultUpstreamSecurity
var authInfo *ProxyAuthInfo
var upstreamSecurity SecurityRequirement
if t.toolConfig.RequestTemplate.Security.ID != "" {
// Use tool-level upstream security if configured
upstreamSecurity = t.toolConfig.RequestTemplate.Security
log.Debugf("Using tool-level upstream security for tool %s: %s", t.name, upstreamSecurity.ID)
} else {
// Fall back to server's default upstream security
upstreamSecurity = proxyServer.GetDefaultUpstreamSecurity()
if upstreamSecurity.ID != "" {
log.Debugf("Using default upstream security for tool %s: %s", t.name, upstreamSecurity.ID)
}
}
if upstreamSecurity.ID != "" {
authInfo = &ProxyAuthInfo{
SecuritySchemeID: upstreamSecurity.ID,
PassthroughCredential: passthroughCredential,
Server: proxyServer,
}
}
// This will handle initialization asynchronously if needed and use ActionPause/Resume
return handler.ForwardToolsCall(ctx, t.name, t.arguments, authInfo)
}
// Description implements Tool interface
func (t *McpProxyTool) Description() string {
return t.toolConfig.Description
}
// InputSchema implements Tool interface
func (t *McpProxyTool) InputSchema() map[string]any {
schema := map[string]any{
"type": "object",
"properties": make(map[string]any),
"required": []string{},
}
properties := schema["properties"].(map[string]any)
var required []string
for _, arg := range t.toolConfig.Args {
argSchema := map[string]any{
"type": arg.Type,
"description": arg.Description,
}
if arg.Default != nil {
argSchema["default"] = arg.Default
}
if len(arg.Enum) > 0 {
argSchema["enum"] = arg.Enum
}
properties[arg.Name] = argSchema
if arg.Required {
required = append(required, arg.Name)
}
}
schema["required"] = required
return schema
}
// OutputSchema implements Tool interface (MCP Protocol Version 2025-06-18)
func (t *McpProxyTool) OutputSchema() map[string]any {
return t.toolConfig.OutputSchema
}
// ValidateSecurityScheme validates a security scheme configuration
func ValidateSecurityScheme(scheme SecurityScheme) error {
if scheme.ID == "" {
return fmt.Errorf("security scheme ID is required")
}
if scheme.Type != "apiKey" && scheme.Type != "http" {
return fmt.Errorf("invalid security scheme type: %s", scheme.Type)
}
if scheme.Type == "apiKey" {
if scheme.Name == "" {
return fmt.Errorf("security scheme name is required for apiKey type")
}
if scheme.In != "header" && scheme.In != "query" && scheme.In != "cookie" {
return fmt.Errorf("invalid security scheme location: %s", scheme.In)
}
}
if scheme.Type == "http" {
if scheme.Scheme == "" {
return fmt.Errorf("security scheme scheme is required for http type")
}
}
return nil
}
// ValidateToolConfig validates a tool configuration
func ValidateToolConfig(config McpProxyToolConfig) error {
if config.Name == "" {
return fmt.Errorf("tool name is required")
}
if config.Description == "" {
return fmt.Errorf("tool description is required")
}
// Validate arguments
argNames := make(map[string]bool)
for _, arg := range config.Args {
if arg.Name == "" {
return fmt.Errorf("argument name is required")
}
if argNames[arg.Name] {
return fmt.Errorf("duplicate argument name: %s", arg.Name)
}
argNames[arg.Name] = true
if arg.Description == "" {
return fmt.Errorf("argument description is required for %s", arg.Name)
}
validTypes := []string{"string", "number", "integer", "boolean", "array", "object"}
validType := false
for _, t := range validTypes {
if arg.Type == t {
validType = true
break
}
}
if !validType {
return fmt.Errorf("invalid argument type %s for %s", arg.Type, arg.Name)
}
}
return nil
}

View File

@@ -0,0 +1,112 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestMcpProxyServerBasicInterface tests that McpProxyServer implements the Server interface
func TestMcpProxyServerBasicInterface(t *testing.T) {
// This test will fail until we implement McpProxyServer
server := NewMcpProxyServer("test-proxy")
// Test Server interface implementation
assert.NotNil(t, server)
assert.Equal(t, "test-proxy", server.Name)
// Test that it implements all required methods
tools := server.GetMCPTools()
assert.NotNil(t, tools)
assert.Equal(t, 0, len(tools))
// Test Clone method
cloned := server.Clone()
assert.NotNil(t, cloned)
}
// TestMcpProxyServerConfiguration tests configuration setting and getting
func TestMcpProxyServerConfiguration(t *testing.T) {
server := NewMcpProxyServer("test-proxy")
// Set server fields directly
server.SetMcpServerURL("http://backend.example.com/mcp")
server.SetTimeout(5000)
// Add security scheme
scheme := SecurityScheme{
ID: "test-auth",
Type: "apiKey",
In: "header",
Name: "X-API-Key",
}
server.AddSecurityScheme(scheme)
// Verify server fields
assert.Equal(t, "http://backend.example.com/mcp", server.GetMcpServerURL())
assert.Equal(t, 5000, server.GetTimeout())
// Verify security scheme
retrievedScheme, exists := server.GetSecurityScheme("test-auth")
assert.True(t, exists)
assert.Equal(t, "test-auth", retrievedScheme.ID)
assert.Equal(t, "apiKey", retrievedScheme.Type)
}
// TestMcpProxyServerAddTool tests adding proxy tools
func TestMcpProxyServerAddTool(t *testing.T) {
server := NewMcpProxyServer("test-proxy")
toolConfig := McpProxyToolConfig{
Name: "test-tool",
Description: "Test tool for proxy",
Args: []ToolArg{
{
Name: "input",
Description: "Test input",
Type: "string",
Required: true,
},
},
}
err := server.AddProxyTool(toolConfig)
assert.NoError(t, err)
tools := server.GetMCPTools()
assert.Len(t, tools, 1)
assert.Contains(t, tools, "test-tool")
}
// TestMcpProxyServerSecuritySchemes tests security scheme management
func TestMcpProxyServerSecuritySchemes(t *testing.T) {
server := NewMcpProxyServer("test-proxy")
scheme := SecurityScheme{
ID: "test-auth",
Type: "apiKey",
In: "header",
Name: "X-API-Key",
}
server.AddSecurityScheme(scheme)
retrievedScheme, exists := server.GetSecurityScheme("test-auth")
assert.True(t, exists)
assert.Equal(t, scheme.ID, retrievedScheme.ID)
assert.Equal(t, scheme.Type, retrievedScheme.Type)
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,485 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestToolsListForwarding tests the tools/list request forwarding
func TestToolsListForwarding(t *testing.T) {
// Create proxy server with tools
server := NewMcpProxyServer("tools-list-test")
// Set server fields directly
server.SetMcpServerURL("http://backend.example.com/mcp")
server.SetTimeout(5000)
// Add test tools
toolConfigs := []McpProxyToolConfig{
{
Name: "get_weather",
Description: "Get weather information",
Args: []ToolArg{
{
Name: "location",
Description: "City name",
Type: "string",
Required: true,
},
},
},
{
Name: "get_news",
Description: "Get latest news",
Args: []ToolArg{
{
Name: "category",
Description: "News category",
Type: "string",
Required: false,
},
},
},
}
for _, toolConfig := range toolConfigs {
err := server.AddProxyTool(toolConfig)
require.NoError(t, err)
}
// Skip HttpContext-dependent test for now - will be tested in integration
// Test that tools were added to server successfully
tools := server.GetMCPTools()
assert.Len(t, tools, 2)
assert.Contains(t, tools, "get_weather")
assert.Contains(t, tools, "get_news")
}
// TestToolsCallForwarding tests the tools/call request forwarding
func TestToolsCallForwarding(t *testing.T) {
server := NewMcpProxyServer("tools-call-test")
// Set server fields directly
server.SetMcpServerURL("http://backend.example.com/mcp")
server.SetTimeout(5000)
// Add test tool
toolConfig := McpProxyToolConfig{
Name: "test_tool",
Description: "Test tool for call forwarding",
Args: []ToolArg{
{
Name: "input",
Description: "Input parameter",
Type: "string",
Required: true,
},
},
}
err := server.AddProxyTool(toolConfig)
require.NoError(t, err)
// Get the tool and create instance
tool, exists := server.GetMCPTools()["test_tool"]
require.True(t, exists)
params := map[string]interface{}{
"input": "test value",
}
paramsBytes, err := json.Marshal(params)
require.NoError(t, err)
toolInstance := tool.Create(paramsBytes)
require.NotNil(t, toolInstance)
// Skip HttpContext-dependent test for now - will be tested in integration
// Test tool instance creation was successful
assert.NotNil(t, toolInstance)
assert.Equal(t, "test_tool", toolInstance.(*McpProxyTool).name)
assert.Equal(t, "test value", toolInstance.(*McpProxyTool).arguments["input"])
}
// TestToolsCallWithParameters tests tool call with various parameter types
func TestToolsCallWithParameters(t *testing.T) {
tests := []struct {
name string
toolConfig McpProxyToolConfig
params map[string]interface{}
shouldErr bool
}{
{
name: "string parameter",
toolConfig: McpProxyToolConfig{
Name: "string_tool",
Description: "Tool with string parameter",
Args: []ToolArg{
{
Name: "text",
Description: "Text input",
Type: "string",
Required: true,
},
},
},
params: map[string]interface{}{
"text": "hello world",
},
shouldErr: false,
},
{
name: "number parameter",
toolConfig: McpProxyToolConfig{
Name: "number_tool",
Description: "Tool with number parameter",
Args: []ToolArg{
{
Name: "value",
Description: "Numeric value",
Type: "number",
Required: true,
},
},
},
params: map[string]interface{}{
"value": 42.5,
},
shouldErr: false,
},
{
name: "object parameter",
toolConfig: McpProxyToolConfig{
Name: "object_tool",
Description: "Tool with object parameter",
Args: []ToolArg{
{
Name: "data",
Description: "Object data",
Type: "object",
Required: true,
},
},
},
params: map[string]interface{}{
"data": map[string]interface{}{
"key1": "value1",
"key2": 123,
},
},
shouldErr: false,
},
{
name: "missing required parameter",
toolConfig: McpProxyToolConfig{
Name: "required_tool",
Description: "Tool with required parameter",
Args: []ToolArg{
{
Name: "required_param",
Description: "Required parameter",
Type: "string",
Required: true,
},
},
},
params: map[string]interface{}{},
shouldErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := NewMcpProxyServer("param-test")
// Set server fields directly
server.SetMcpServerURL("http://backend.example.com/mcp")
server.SetTimeout(5000)
err := server.AddProxyTool(tt.toolConfig)
require.NoError(t, err)
tool, exists := server.GetMCPTools()[tt.toolConfig.Name]
require.True(t, exists)
paramsBytes, err := json.Marshal(tt.params)
require.NoError(t, err)
toolInstance := tool.Create(paramsBytes)
require.NotNil(t, toolInstance)
// Skip HttpContext-dependent test for now - will be tested in integration
// Test tool instance creation
assert.NotNil(t, toolInstance)
if !tt.shouldErr {
assert.Equal(t, tt.toolConfig.Name, toolInstance.(*McpProxyTool).name)
}
})
}
}
// TestToolsCallWithCursor tests tools/list with pagination cursor
func TestToolsCallWithCursor(t *testing.T) {
server := NewMcpProxyServer("cursor-test")
// Set server fields directly
server.SetMcpServerURL("http://backend.example.com/mcp")
server.SetTimeout(5000)
// Skip HttpContext-dependent test for now - will be tested in integration
// Test cursor parameter handling logic (basic validation)
cursor := "page-2-cursor"
assert.NotNil(t, cursor)
assert.NotEmpty(t, cursor)
}
// TestBackendErrorHandling tests handling of backend MCP server errors
func TestBackendErrorHandling(t *testing.T) {
server := NewMcpProxyServer("error-test")
// Set server fields directly
server.SetMcpServerURL("http://failing-backend.example.com/mcp")
server.SetTimeout(5000)
toolConfig := McpProxyToolConfig{
Name: "failing_tool",
Description: "Tool that will fail on backend",
Args: []ToolArg{
{
Name: "input",
Description: "Input parameter",
Type: "string",
Required: true,
},
},
}
err := server.AddProxyTool(toolConfig)
require.NoError(t, err)
tool, exists := server.GetMCPTools()["failing_tool"]
require.True(t, exists)
params := map[string]interface{}{
"input": "test value",
}
paramsBytes, err := json.Marshal(params)
require.NoError(t, err)
toolInstance := tool.Create(paramsBytes)
require.NotNil(t, toolInstance)
// Skip HttpContext-dependent test for now - will be tested in integration
// Test tool instance creation for error scenario
assert.NotNil(t, toolInstance)
assert.Equal(t, "failing_tool", toolInstance.(*McpProxyTool).name)
}
// TestParseSSEResponse tests the SSE response parsing functionality
func TestParseSSEResponse(t *testing.T) {
tests := []struct {
name string
sseData string
expectedData string
shouldErr bool
}{
{
name: "valid SSE with JSON data",
sseData: `event: message
data: {"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{"experimental":{},"prompts":{"listChanged":true},"resources":{"subscribe":false,"listChanged":true},"tools":{"listChanged":true}},"serverInfo":{"name":"Echo Server","version":"1.17.0"}}}
`,
expectedData: `{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{"experimental":{},"prompts":{"listChanged":true},"resources":{"subscribe":false,"listChanged":true},"tools":{"listChanged":true}},"serverInfo":{"name":"Echo Server","version":"1.17.0"}}}`,
shouldErr: false,
},
{
name: "SSE with multiple lines",
sseData: `event: message
data: {"jsonrpc":"2.0","id":2,"result":{"success":true}}
event: close
data: {"jsonrpc":"2.0","method":"close"}
`,
expectedData: `{"jsonrpc":"2.0","id":2,"result":{"success":true}}`,
shouldErr: false,
},
{
name: "SSE with comments and empty lines",
sseData: `: This is a comment
event: message
data: {"jsonrpc":"2.0","id":3,"result":{"test":true}}
: Another comment
`,
expectedData: `{"jsonrpc":"2.0","id":3,"result":{"test":true}}`,
shouldErr: false,
},
{
name: "SSE with any data content",
sseData: `event: message
data: {invalid json}
`,
expectedData: `{invalid json}`,
shouldErr: false,
},
{
name: "SSE with no data field",
sseData: `event: message
id: 123
`,
shouldErr: true,
},
{
name: "empty SSE data",
sseData: ``,
shouldErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parseSSEResponse([]byte(tt.sseData))
if tt.shouldErr {
assert.Error(t, err)
assert.Nil(t, result)
} else {
assert.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, tt.expectedData, string(result))
}
})
}
}
// TestIsBackendError tests detection of backend error responses
func TestIsBackendError(t *testing.T) {
tests := []struct {
name string
response string
expectError bool
expectErrType string
}{
{
name: "JSON-RPC 2.0 error with unknown tool",
response: `{
"jsonrpc": "2.0",
"id": 3,
"error": {
"code": -32602,
"message": "Unknown tool: invalid_tool_name"
}
}`,
expectError: true,
expectErrType: "jsonrpc_error",
},
{
name: "JSON-RPC 2.0 error with method not found",
response: `{
"jsonrpc": "2.0",
"id": 1,
"error": {
"code": -32601,
"message": "Method not found"
}
}`,
expectError: true,
expectErrType: "jsonrpc_error",
},
{
name: "result.isError format",
response: `{
"jsonrpc": "2.0",
"id": 3,
"result": {
"isError": true,
"content": [
{
"type": "text",
"text": "Tool execution failed: connection timeout"
}
]
}
}`,
expectError: true,
expectErrType: "result_isError",
},
{
name: "successful response with result",
response: `{
"jsonrpc": "2.0",
"id": 3,
"result": {
"content": [
{
"type": "text",
"text": "Success!"
}
]
}
}`,
expectError: false,
expectErrType: "",
},
{
name: "successful response with isError false",
response: `{
"jsonrpc": "2.0",
"id": 3,
"result": {
"isError": false,
"content": [
{
"type": "text",
"text": "Success!"
}
]
}
}`,
expectError: false,
expectErrType: "",
},
{
name: "invalid JSON",
response: `{invalid json}`,
expectError: false,
expectErrType: "",
},
{
name: "empty response",
response: `{}`,
expectError: false,
expectErrType: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isError, errType := IsBackendError([]byte(tt.response))
assert.Equal(t, tt.expectError, isError, "isError mismatch")
assert.Equal(t, tt.expectErrType, errType, "error type mismatch")
})
}
}
// ForwardToolsList is now implemented in proxy_server.go

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,922 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"encoding/json"
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tidwall/sjson"
)
func TestConvertArgToString(t *testing.T) {
tests := []struct {
name string
input interface{}
expected string
}{
{
name: "string value",
input: "test string",
expected: "test string",
},
{
name: "boolean true",
input: true,
expected: "true",
},
{
name: "boolean false",
input: false,
expected: "false",
},
{
name: "integer",
input: 42,
expected: "42",
},
{
name: "float",
input: 3.14,
expected: "3.14",
},
{
name: "map",
input: map[string]interface{}{"key": "value"},
expected: `{"key":"value"}`,
},
{
name: "array",
input: []interface{}{1, 2, 3},
expected: "[1,2,3]",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertArgToString(tt.input)
if result != tt.expected {
t.Errorf("convertArgToString(%v) = %v, want %v", tt.input, result, tt.expected)
}
})
}
}
func TestResponseTemplatePrependAppend(t *testing.T) {
// Test response template with PrependBody and AppendBody
sampleResponse := `{"result": "success", "data": {"name": "Test", "value": 42}}`
tests := []struct {
name string
template RestToolResponseTemplate
expected []string
notExpected []string
}{
{
name: "with body template only",
template: RestToolResponseTemplate{
Body: "# Result\n- Name: {{.data.name}}\n- Value: {{.data.value}}",
},
expected: []string{
"# Result",
"- Name: Test",
"- Value: 42",
},
notExpected: []string{
"Field Descriptions:",
"End of Response",
`{"result": "success"`,
},
},
{
name: "with prepend only",
template: RestToolResponseTemplate{
PrependBody: "# Field Descriptions:\n- result: Operation result\n- data: Response data\n\n",
},
expected: []string{
"# Field Descriptions:",
"- result: Operation result",
"- data: Response data",
`{"result": "success"`,
`"name": "Test"`,
},
},
{
name: "with append only",
template: RestToolResponseTemplate{
AppendBody: "\n\n*End of Response*",
},
expected: []string{
`{"result": "success"`,
`"name": "Test"`,
"*End of Response*",
},
},
{
name: "with both prepend and append",
template: RestToolResponseTemplate{
PrependBody: "# API Response:\n\n",
AppendBody: "\n\n*This is raw JSON data with field 'name' = Test and 'value' = 42*",
},
expected: []string{
"# API Response:",
`{"result": "success"`,
`"name": "Test"`,
"*This is raw JSON data with field 'name' = Test and 'value' = 42*",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a tool with the test template
// For tests with only prepend/append (no body), add a RequestTemplate.URL
// to avoid direct response mode validation
tool := RestTool{
ResponseTemplate: tt.template,
}
if tt.template.Body == "" && (tt.template.PrependBody != "" || tt.template.AppendBody != "") {
tool.RequestTemplate.URL = "http://example.com/api"
}
// Parse templates
err := tool.parseTemplates()
if err != nil {
t.Fatalf("Failed to parse templates: %v", err)
}
// Simulate response processing
var result string
responseBody := []byte(sampleResponse)
// Case 1: Full response template is provided
if tool.parsedResponseTemplate != nil {
templateResult, err := executeTemplate(tool.parsedResponseTemplate, responseBody)
if err != nil {
t.Fatalf("Failed to execute response template: %v", err)
}
result = templateResult
} else {
// Case 2: No template, but prepend/append might be used
rawResponse := string(responseBody)
// Apply prepend/append if specified
if tool.ResponseTemplate.PrependBody != "" || tool.ResponseTemplate.AppendBody != "" {
result = tool.ResponseTemplate.PrependBody + rawResponse + tool.ResponseTemplate.AppendBody
} else {
// Case 3: No template and no prepend/append, just use raw response
result = rawResponse
}
}
// Check that the result contains expected substrings
for _, substr := range tt.expected {
if !strings.Contains(result, substr) {
t.Errorf("Expected substring not found: %s", substr)
}
}
// Check that the result does not contain unexpected substrings
for _, substr := range tt.notExpected {
if strings.Contains(result, substr) {
t.Errorf("Unexpected substring found: %s", substr)
}
}
})
}
}
func TestHasContentType(t *testing.T) {
tests := []struct {
name string
headers [][2]string
contentTypeStr string
expectedOutcome bool
}{
{
name: "exact match",
headers: [][2]string{
{"Content-Type", "application/json"},
},
contentTypeStr: "application/json",
expectedOutcome: true,
},
{
name: "case insensitive match",
headers: [][2]string{
{"content-type", "application/JSON"},
},
contentTypeStr: "application/json",
expectedOutcome: true,
},
{
name: "substring match",
headers: [][2]string{
{"Content-Type", "application/json; charset=utf-8"},
},
contentTypeStr: "application/json",
expectedOutcome: true,
},
{
name: "no match",
headers: [][2]string{
{"Content-Type", "text/plain"},
},
contentTypeStr: "application/json",
expectedOutcome: false,
},
{
name: "header not present",
headers: [][2]string{
{"Accept", "application/json"},
},
contentTypeStr: "application/json",
expectedOutcome: false,
},
{
name: "empty headers",
headers: [][2]string{},
contentTypeStr: "application/json",
expectedOutcome: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := hasContentType(tt.headers, tt.contentTypeStr)
if result != tt.expectedOutcome {
t.Errorf("hasContentType(%v, %v) = %v, want %v", tt.headers, tt.contentTypeStr, result, tt.expectedOutcome)
}
})
}
}
func TestRestToolValidation(t *testing.T) {
tests := []struct {
name string
tool RestTool
expectedError bool
}{
{
name: "valid tool with no args options",
tool: RestTool{
RequestTemplate: RestToolRequestTemplate{
URL: "https://example.com",
Method: "GET",
},
},
expectedError: false,
},
{
name: "valid tool with argsToJsonBody",
tool: RestTool{
RequestTemplate: RestToolRequestTemplate{
URL: "https://example.com",
Method: "POST",
ArgsToJsonBody: true,
},
},
expectedError: false,
},
{
name: "valid tool with argsToUrlParam",
tool: RestTool{
RequestTemplate: RestToolRequestTemplate{
URL: "https://example.com",
Method: "GET",
ArgsToUrlParam: true,
},
},
expectedError: false,
},
{
name: "valid tool with argsToFormBody",
tool: RestTool{
RequestTemplate: RestToolRequestTemplate{
URL: "https://example.com",
Method: "POST",
ArgsToFormBody: true,
},
},
expectedError: false,
},
{
name: "invalid tool with multiple args options",
tool: RestTool{
RequestTemplate: RestToolRequestTemplate{
URL: "https://example.com",
Method: "POST",
ArgsToJsonBody: true,
ArgsToFormBody: true,
},
},
expectedError: true,
},
{
name: "invalid tool with all args options",
tool: RestTool{
RequestTemplate: RestToolRequestTemplate{
URL: "https://example.com",
Method: "POST",
ArgsToJsonBody: true,
ArgsToUrlParam: true,
ArgsToFormBody: true,
},
},
expectedError: true,
},
{
name: "invalid tool with both Body and PrependBody",
tool: RestTool{
RequestTemplate: RestToolRequestTemplate{
URL: "https://example.com",
Method: "GET",
},
ResponseTemplate: RestToolResponseTemplate{
Body: "# Result\n{{.data}}",
PrependBody: "# Field Descriptions:\n",
},
},
expectedError: true,
},
{
name: "invalid tool with both Body and AppendBody",
tool: RestTool{
RequestTemplate: RestToolRequestTemplate{
URL: "https://example.com",
Method: "GET",
},
ResponseTemplate: RestToolResponseTemplate{
Body: "# Result\n{{.data}}",
AppendBody: "\n*End of response*",
},
},
expectedError: true,
},
{
name: "invalid tool with Body, PrependBody, and AppendBody",
tool: RestTool{
RequestTemplate: RestToolRequestTemplate{
URL: "https://example.com",
Method: "GET",
},
ResponseTemplate: RestToolResponseTemplate{
Body: "# Result\n{{.data}}",
PrependBody: "# Field Descriptions:\n",
AppendBody: "\n*End of response*",
},
},
expectedError: true,
},
{
name: "valid tool with PrependBody and AppendBody but no Body",
tool: RestTool{
RequestTemplate: RestToolRequestTemplate{
URL: "https://example.com",
Method: "GET",
},
ResponseTemplate: RestToolResponseTemplate{
PrependBody: "# Field Descriptions:\n",
AppendBody: "\n*End of response*",
},
},
expectedError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.tool.parseTemplates()
if (err != nil) != tt.expectedError {
t.Errorf("parseTemplates() error = %v, expectedError %v", err, tt.expectedError)
}
})
}
}
func TestInputSchemaWithComplexTypes(t *testing.T) {
// Create a tool with array and object type arguments
tool := RestMCPTool{
toolConfig: RestTool{
Args: []RestToolArg{
{
Name: "stringArg",
Description: "A string argument",
Type: "string",
},
{
Name: "arrayArg",
Description: "An array argument",
Type: "array",
Items: map[string]interface{}{
"type": "string",
},
},
{
Name: "objectArg",
Description: "An object argument",
Type: "object",
Properties: map[string]interface{}{
"name": map[string]interface{}{
"type": "string",
"description": "Name property",
},
"age": map[string]interface{}{
"type": "integer",
"description": "Age property",
},
},
},
{
Name: "arrayOfObjects",
Description: "An array of objects",
Type: "array",
Items: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"id": map[string]interface{}{
"type": "string",
},
"value": map[string]interface{}{
"type": "number",
},
},
},
},
},
},
}
schema := tool.InputSchema()
// Check schema structure
if schema["type"] != "object" {
t.Errorf("Expected schema type to be 'object', got %v", schema["type"])
}
properties, ok := schema["properties"].(map[string]interface{})
if !ok {
t.Fatalf("Expected properties to be a map, got %T", schema["properties"])
}
// Check individual property types
checkProperty := func(name, expectedType string) {
prop, ok := properties[name].(map[string]interface{})
if !ok {
t.Fatalf("Expected property %s to be a map, got %T", name, properties[name])
}
if prop["type"] != expectedType {
t.Errorf("Expected property %s type to be '%s', got %v", name, expectedType, prop["type"])
}
}
checkProperty("stringArg", "string")
checkProperty("arrayArg", "array")
checkProperty("objectArg", "object")
checkProperty("arrayOfObjects", "array")
// Check array items
arrayArg, _ := properties["arrayArg"].(map[string]interface{})
if arrayArg["items"] == nil {
t.Errorf("Expected arrayArg to have items property")
}
// Check object properties
objectArg, _ := properties["objectArg"].(map[string]interface{})
if objectArg["properties"] == nil {
t.Errorf("Expected objectArg to have properties property")
}
// Check array of objects
arrayOfObjects, _ := properties["arrayOfObjects"].(map[string]interface{})
items, ok := arrayOfObjects["items"].(map[string]interface{})
if !ok || items["type"] != "object" {
t.Errorf("Expected arrayOfObjects items to be of type object")
}
}
func TestArgsToUrlParamAndFormBody(t *testing.T) {
// Test argsToUrlParam
t.Run("argsToUrlParam", func(t *testing.T) {
args := map[string]interface{}{
"string": "value",
"int": 42,
"bool": true,
"array": []interface{}{1, 2, 3},
"object": map[string]interface{}{"key": "value"},
}
// Parse URL and add parameters
baseURL := "https://example.com/api"
parsedURL, _ := url.Parse(baseURL)
query := parsedURL.Query()
for key, value := range args {
query.Set(key, convertArgToString(value))
}
parsedURL.RawQuery = query.Encode()
result := parsedURL.String()
// Verify each parameter is in the URL
for key, value := range args {
strValue := convertArgToString(value)
encodedValue := url.QueryEscape(strValue)
paramStr := key + "=" + encodedValue
if !strings.Contains(result, paramStr) {
t.Errorf("URL parameter missing: %s", paramStr)
}
}
})
// Test argsToFormBody
t.Run("argsToFormBody", func(t *testing.T) {
args := map[string]interface{}{
"string": "value",
"int": 42,
"bool": true,
"array": []interface{}{1, 2, 3},
"object": map[string]interface{}{"key": "value"},
}
// Create form values
formValues := url.Values{}
for key, value := range args {
formValues.Set(key, convertArgToString(value))
}
formBody := formValues.Encode()
// Verify each parameter is in the form body
for key, value := range args {
strValue := convertArgToString(value)
encodedValue := url.QueryEscape(strValue)
paramStr := key + "=" + encodedValue
if !strings.Contains(formBody, paramStr) {
t.Errorf("Form body missing parameter: %s", paramStr)
}
}
})
}
func TestRestToolConfig(t *testing.T) {
// Example REST tool configuration
configJSON := `
{
"server": {
"name": "rest-amap-server",
"config": {
"apiKey": "xxxxx"
}
},
"tools": [
{
"name": "maps-geo",
"description": "将详细的结构化地址转换为经纬度坐标。支持对地标性名胜景区、建筑物名称解析为经纬度坐标",
"args": [
{
"name": "address",
"description": "待解析的结构化地址信息",
"type": "string",
"required": true
},
{
"name": "city",
"description": "指定查询的城市",
"required": false
},
{
"name": "output",
"description": "输出格式",
"type": "string",
"enum": ["json", "xml"],
"default": "json"
},
{
"name": "options",
"description": "高级选项",
"type": "object",
"properties": {
"extensions": {
"type": "string",
"enum": ["base", "all"]
},
"batch": {
"type": "boolean"
}
}
},
{
"name": "batch_addresses",
"description": "批量地址",
"type": "array",
"items": {
"type": "string"
}
}
],
"requestTemplate": {
"url": "https://restapi.amap.com/v3/geocode/geo?key={{.config.apiKey}}&address={{.args.address}}&city={{.args.city}}&output={{.args.output}}&source=ts_mcp",
"method": "GET",
"headers": [
{
"key": "Content-Type",
"value": "application/json"
}
]
},
"responseTemplate": {
"body": "# 地理编码信息\n{{- range $index, $geo := .Geocodes }}\n## 地点 {{add $index 1}}\n\n- **国家**: {{ $geo.Country }}\n- **省份**: {{ $geo.Province }}\n- **城市**: {{ $geo.City }}\n- **城市代码**: {{ $geo.Citycode }}\n- **区/县**: {{ $geo.District }}\n- **街道**: {{ $geo.Street }}\n- **门牌号**: {{ $geo.Number }}\n- **行政编码**: {{ $geo.Adcode }}\n- **坐标**: {{ $geo.Location }}\n- **级别**: {{ $geo.Level }}\n{{- end }}"
}
}
]
}
`
// Parse the config to verify it's valid JSON
var configData map[string]interface{}
err := json.Unmarshal([]byte(configJSON), &configData)
if err != nil {
t.Fatalf("Invalid JSON config: %v", err)
}
// Example tool configuration
tool := RestTool{
Name: "maps-geo",
Description: "将详细的结构化地址转换为经纬度坐标。支持对地标性名胜景区、建筑物名称解析为经纬度坐标",
Args: []RestToolArg{
{
Name: "address",
Description: "待解析的结构化地址信息",
Type: "string",
Required: true,
},
{
Name: "city",
Description: "指定查询的城市",
Required: false,
},
{
Name: "output",
Description: "输出格式",
Type: "string",
Enum: []interface{}{"json", "xml"},
Default: "json",
},
{
Name: "options",
Description: "高级选项",
Type: "object",
Properties: map[string]interface{}{
"extensions": map[string]interface{}{
"type": "string",
"enum": []interface{}{"base", "all"},
},
"batch": map[string]interface{}{
"type": "boolean",
},
},
},
{
Name: "batch_addresses",
Description: "批量地址",
Type: "array",
Items: map[string]interface{}{
"type": "string",
},
},
},
RequestTemplate: RestToolRequestTemplate{
URL: "https://restapi.amap.com/v3/geocode/geo?key={{.config.apiKey}}&address={{.args.address}}&city={{.args.city}}&output={{.args.output}}&source=ts_mcp",
Method: "GET",
Headers: []RestToolHeader{
{
Key: "Content-Type",
Value: "application/json",
},
},
},
ResponseTemplate: RestToolResponseTemplate{
Body: `# 地理编码信息
{{- range $index, $geo := .Geocodes }}
## 地点 {{add $index 1}}
- **国家**: {{ $geo.Country }}
- **省份**: {{ $geo.Province }}
- **城市**: {{ $geo.City }}
- **城市代码**: {{ $geo.Citycode }}
- **区/县**: {{ $geo.District }}
- **街道**: {{ $geo.Street }}
- **门牌号**: {{ $geo.Number }}
- **行政编码**: {{ $geo.Adcode }}
- **坐标**: {{ $geo.Location }}
- **级别**: {{ $geo.Level }}
{{- end }}`,
},
}
// Parse templates
err = tool.parseTemplates()
if err != nil {
t.Fatalf("Failed to parse templates: %v", err)
}
var templateData []byte
templateData, _ = sjson.SetBytes(templateData, "config", map[string]interface{}{"apiKey": "test-api-key"})
templateData, _ = sjson.SetBytes(templateData, "args", map[string]interface{}{
"address": "北京市朝阳区阜通东大街6号",
"city": "北京",
"output": "json",
})
// Test URL template
url, err := executeTemplate(tool.parsedURLTemplate, templateData)
if err != nil {
t.Fatalf("Failed to execute URL template: %v", err)
}
expectedURL := "https://restapi.amap.com/v3/geocode/geo?key=test-api-key&address=北京市朝阳区阜通东大街6号&city=北京&output=json&source=ts_mcp"
if url != expectedURL {
t.Errorf("URL template rendering failed. Expected: %s, Got: %s", expectedURL, url)
}
// Test InputSchema for complex types
mcpTool := &RestMCPTool{
toolConfig: tool,
}
schema := mcpTool.InputSchema()
properties := schema["properties"].(map[string]interface{})
// Check object type
options, ok := properties["options"].(map[string]interface{})
if !ok || options["type"] != "object" {
t.Errorf("Expected options to be of type object")
}
// Check array type
batchAddresses, ok := properties["batch_addresses"].(map[string]interface{})
if !ok || batchAddresses["type"] != "array" {
t.Errorf("Expected batch_addresses to be of type array")
}
// Test response template with sample data
sampleResponse := `
{"Geocodes": [
{
"Country": "中国",
"Province": "北京市",
"City": "北京市",
"Citycode": "010",
"District": "朝阳区",
"Street": "阜通东大街",
"Number": "6号",
"Adcode": "110105",
"Location": "116.483038,39.990633",
"Level": "门牌号",
}]}`
result, err := executeTemplate(tool.parsedResponseTemplate, []byte(sampleResponse))
if err != nil {
t.Fatalf("Failed to execute response template: %v", err)
}
// Just check that the result contains expected substrings
expectedSubstrings := []string{
"# 地理编码信息",
"## 地点 1",
"**国家**: 中国",
"**省份**: 北京市",
"**坐标**: 116.483038,39.990633",
}
for _, substr := range expectedSubstrings {
if !strings.Contains(result, substr) {
t.Errorf("Response template rendering failed. Expected substring not found: %s", substr)
}
}
}
// TestRestServerDefaultSecurity tests the default security configuration for REST MCP server
func TestRestServerDefaultSecurity(t *testing.T) {
server := NewRestMCPServer("test-rest-server")
// Add security schemes
defaultScheme := SecurityScheme{
ID: "DefaultAuth",
Type: "apiKey",
In: "header",
Name: "X-Default-Key",
DefaultCredential: "default-key",
}
toolScheme := SecurityScheme{
ID: "ToolAuth",
Type: "apiKey",
In: "header",
Name: "X-Tool-Key",
DefaultCredential: "tool-key",
}
server.AddSecurityScheme(defaultScheme)
server.AddSecurityScheme(toolScheme)
// Test setting default security directly on server
server.SetDefaultDownstreamSecurity(SecurityRequirement{
ID: "DefaultAuth",
Passthrough: false,
})
server.SetDefaultUpstreamSecurity(SecurityRequirement{
ID: "DefaultAuth",
})
// Verify default security settings
retrievedDownstream := server.GetDefaultDownstreamSecurity()
assert.Equal(t, "DefaultAuth", retrievedDownstream.ID)
assert.False(t, retrievedDownstream.Passthrough)
retrievedUpstream := server.GetDefaultUpstreamSecurity()
assert.Equal(t, "DefaultAuth", retrievedUpstream.ID)
t.Logf("REST server default security configuration test completed successfully")
}
// TestRestServerSecurityFallback tests the fallback mechanism from tool-level to default security
func TestRestServerSecurityFallback(t *testing.T) {
server := NewRestMCPServer("test-rest-server")
// Add security schemes
defaultScheme := SecurityScheme{
ID: "DefaultAuth",
Type: "apiKey",
In: "header",
Name: "X-Default-Key",
DefaultCredential: "default-key",
}
toolScheme := SecurityScheme{
ID: "ToolAuth",
Type: "apiKey",
In: "header",
Name: "X-Tool-Key",
DefaultCredential: "tool-key",
}
server.AddSecurityScheme(defaultScheme)
server.AddSecurityScheme(toolScheme)
// Test tool configuration with tool-level security (should use tool-level, not default)
toolConfigWithSecurity := RestTool{
Name: "secure_tool",
Description: "Tool with its own security",
Security: SecurityRequirement{
ID: "ToolAuth",
Passthrough: true,
},
RequestTemplate: RestToolRequestTemplate{
URL: "http://api.example.com/secure",
Method: "GET",
Security: SecurityRequirement{
ID: "ToolAuth",
},
},
}
// Test tool configuration without tool-level security (should fallback to default)
toolConfigWithoutSecurity := RestTool{
Name: "fallback_tool",
Description: "Tool that falls back to default security",
// No Security field configured, should use default
RequestTemplate: RestToolRequestTemplate{
URL: "http://api.example.com/fallback",
Method: "GET",
// No Security field configured, should use default
},
}
// Add tools to server
err := server.AddRestTool(toolConfigWithSecurity)
assert.NoError(t, err)
err = server.AddRestTool(toolConfigWithoutSecurity)
assert.NoError(t, err)
// Verify tools were added
tools := server.GetMCPTools()
assert.Contains(t, tools, "secure_tool")
assert.Contains(t, tools, "fallback_tool")
t.Logf("REST server security fallback test completed successfully")
}

View File

@@ -0,0 +1,874 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
// Context keys for SSE proxy state management
CtxSSEProxyState = "sse_proxy_state"
CtxSSEProxyEndpointURL = "sse_proxy_endpoint_url"
CtxSSEProxyBuffer = "sse_proxy_buffer"
CtxSSEProxyAuthInfo = "sse_proxy_auth_info"
CtxSSEProxyRequestBody = "sse_proxy_request_body"
CtxSSEProxyRequestID = "sse_proxy_request_id"
CtxSSEProxyFirstChunk = "sse_proxy_first_chunk"
CtxSSEProxyJsonRpcID = "sse_proxy_jsonrpc_id"
// SSE proxy state values
SSEStateWaitingEndpoint = "waiting_endpoint"
SSEStateWaitingInitResp = "waiting_init_resp"
SSEStateWaitingNotifyResp = "waiting_notify_resp"
SSEStateWaitingToolResp = "waiting_tool_resp"
// Buffer size limit: 100MB
MaxSSEBufferSize = 100 * 1024 * 1024
)
// injectSSEResponseSuccess injects a successful JSON-RPC response in streaming response body phase
func injectSSEResponseSuccess(ctx wrapper.HttpContext, result map[string]any) {
// Get JSON-RPC ID from context
jsonRpcIDRaw := ctx.GetContext(CtxSSEProxyJsonRpcID)
if jsonRpcIDRaw == nil {
log.Errorf("JSON-RPC ID not found in context for SSE response")
return
}
jsonRpcID := jsonRpcIDRaw.(utils.JsonRpcID)
var body []byte
var err error
if jsonRpcID.IsString {
body, err = json.Marshal(map[string]interface{}{
"jsonrpc": "2.0",
"id": jsonRpcID.StringValue,
"result": result,
})
} else {
body, err = json.Marshal(map[string]interface{}{
"jsonrpc": "2.0",
"id": jsonRpcID.IntValue,
"result": result,
})
}
if err != nil {
log.Errorf("Failed to marshal JSON-RPC success response: %v", err)
return
}
proxywasm.InjectEncodedDataToFilterChain(body, true)
}
// injectSSEResponseError injects an error JSON-RPC response in streaming response body phase
func injectSSEResponseError(ctx wrapper.HttpContext, err error, errorCode int) {
// Get JSON-RPC ID from context
jsonRpcIDRaw := ctx.GetContext(CtxSSEProxyJsonRpcID)
if jsonRpcIDRaw == nil {
log.Errorf("JSON-RPC ID not found in context for SSE error response")
return
}
jsonRpcID := jsonRpcIDRaw.(utils.JsonRpcID)
var body []byte
var marshalErr error
if jsonRpcID.IsString {
body, marshalErr = json.Marshal(map[string]interface{}{
"jsonrpc": "2.0",
"id": jsonRpcID.StringValue,
"error": map[string]interface{}{
"code": errorCode,
"message": err.Error(),
},
})
} else {
body, marshalErr = json.Marshal(map[string]interface{}{
"jsonrpc": "2.0",
"id": jsonRpcID.IntValue,
"error": map[string]interface{}{
"code": errorCode,
"message": err.Error(),
},
})
}
if marshalErr != nil {
log.Errorf("Failed to marshal JSON-RPC error response: %v", marshalErr)
return
}
proxywasm.InjectEncodedDataToFilterChain(body, true)
}
// SSEMessage represents a parsed SSE message
type SSEMessage struct {
Event string
Data string
ID string
}
// ParseSSEMessage parses SSE format data and returns complete messages
// Returns the parsed message and the remaining unparsed data
func ParseSSEMessage(data []byte) (*SSEMessage, []byte, error) {
scanner := bufio.NewScanner(bytes.NewReader(data))
// Set max token size to 32MB to handle large messages
maxTokenSize := 32 * 1024 * 1024 // 32MB
scanner.Buffer(make([]byte, 0, 64*1024), maxTokenSize)
msg := &SSEMessage{}
lineCount := 0
lastPos := 0
for scanner.Scan() {
line := scanner.Text()
lineCount++
lastPos += len(line) + 1 // +1 for newline
// Empty line indicates end of message
if strings.TrimSpace(line) == "" {
if msg.Event != "" || msg.Data != "" || msg.ID != "" {
// Found a complete message
return msg, data[lastPos:], nil
}
// Empty message, continue
continue
}
// Skip comment lines (lines starting with ':')
if strings.HasPrefix(line, ":") {
continue
}
// Parse field
parts := strings.SplitN(line, ":", 2)
if len(parts) < 2 {
continue
}
field := parts[0]
value := strings.TrimSpace(parts[1])
switch field {
case "event":
msg.Event = value
case "data":
if msg.Data != "" {
msg.Data += "\n" + value
} else {
msg.Data = value
}
case "id":
msg.ID = value
}
}
if err := scanner.Err(); err != nil {
if errors.Is(err, bufio.ErrTooLong) {
return nil, nil, fmt.Errorf("SSE message line exceeds maximum token size (32MB): %w", err)
}
return nil, nil, fmt.Errorf("error scanning SSE data: %v", err)
}
// No complete message found, return all data as remaining
return nil, data, nil
}
// ExtractEndpointURL extracts the endpoint URL from an SSE endpoint message
// It handles two cases:
// 1. endpointData is a full URL (e.g., http://example.com/sse) - return as-is
// 2. endpointData is a path - if baseURL has scheme and host, combine them; otherwise return the path as-is
func ExtractEndpointURL(endpointData string, baseURL string) (string, error) {
// Case 1: endpointData is a full URL
if strings.HasPrefix(endpointData, "http://") || strings.HasPrefix(endpointData, "https://") {
return endpointData, nil
}
// endpointData is a path
parsedBase, err := url.Parse(baseURL)
if err != nil {
return "", fmt.Errorf("failed to parse base URL: %v", err)
}
// Case 2: baseURL has scheme and host, combine them
if parsedBase.Scheme != "" && parsedBase.Host != "" {
// Combine scheme, host, and the new path
// Ensure endpointData starts with "/"
if !strings.HasPrefix(endpointData, "/") {
endpointData = "/" + endpointData
}
result := parsedBase.Scheme + "://" + parsedBase.Host + endpointData
return result, nil
}
// Case 3: baseURL is also just a path, return endpointData as-is
return endpointData, nil
}
// sendSSEInitialize sends the initialize request for SSE protocol
func sendSSEInitialize(ctx wrapper.HttpContext, endpointURL string, authInfo *ProxyAuthInfo, proxyServer *McpProxyServer) error {
initRequest := map[string]interface{}{
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": map[string]interface{}{
"protocolVersion": "2025-03-26",
"capabilities": map[string]interface{}{
"roots": map[string]interface{}{
"listChanged": true,
},
"sampling": map[string]interface{}{},
"elicitation": map[string]interface{}{},
},
"clientInfo": map[string]interface{}{
"name": "Higress-mcp-proxy",
"title": "Higress MCP Proxy",
"version": "1.0.0",
},
},
}
requestBody, err := json.Marshal(initRequest)
if err != nil {
return fmt.Errorf("failed to marshal initialize request: %v", err)
}
// Copy headers from current request (now supported in response phase by Envoy)
finalHeaders := copyHeadersForSSERequest(ctx)
// Override required headers for SSE initialize
ensureHeader(&finalHeaders, "Content-Type", "application/json")
// Apply authentication to headers and URL
finalURL := endpointURL
if authInfo != nil && authInfo.SecuritySchemeID != "" {
modifiedURL, err := applyProxyAuthenticationForSSE(proxyServer, authInfo.SecuritySchemeID, authInfo.PassthroughCredential, &finalHeaders, endpointURL)
if err != nil {
log.Errorf("Failed to apply authentication for SSE initialize: %v", err)
} else {
finalURL = modifiedURL
}
}
// Note: headers are already copied from the current request (which has server-level headers applied)
// via copyHeadersForSSERequest, so no need to apply them again
// Store state for tracking
ctx.SetContext(CtxSSEProxyState, SSEStateWaitingInitResp)
ctx.SetContext(CtxSSEProxyRequestID, 1)
// Use RouteCluster client to send initialize request
client := wrapper.NewClusterClient(wrapper.RouteCluster{})
timeout := uint32(proxyServer.GetTimeout())
if timeout == 0 {
timeout = 5000 // Default 5 seconds
}
return client.Post(finalURL, finalHeaders, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode != 200 && statusCode != 202 {
log.Errorf("SSE initialize request failed with status %d: %s", statusCode, string(responseBody))
// At this point, we're in streaming response phase, must use injectSSEResponseError
injectSSEResponseError(ctx, fmt.Errorf("SSE initialize failed with status %d", statusCode), utils.ErrInternalError)
return
}
log.Debugf("SSE initialize request sent successfully")
// The response will be received through SSE channel and processed in streaming response handler
// State has already been set to SSEStateWaitingInitResp before this POST request
// No need to change state here
}, timeout)
}
// sendSSENotification sends the notifications/initialized message for SSE protocol
func sendSSENotification(ctx wrapper.HttpContext, endpointURL string, authInfo *ProxyAuthInfo, proxyServer *McpProxyServer) error {
notification := map[string]interface{}{
"jsonrpc": "2.0",
"method": "notifications/initialized",
}
requestBody, err := json.Marshal(notification)
if err != nil {
return fmt.Errorf("failed to marshal notification: %v", err)
}
// Copy headers from current request (now supported in response phase by Envoy)
finalHeaders := copyHeadersForSSERequest(ctx)
// Override required headers for SSE notification
ensureHeader(&finalHeaders, "Content-Type", "application/json")
// Apply authentication to headers and URL
finalURL := endpointURL
if authInfo != nil && authInfo.SecuritySchemeID != "" {
modifiedURL, err := applyProxyAuthenticationForSSE(proxyServer, authInfo.SecuritySchemeID, authInfo.PassthroughCredential, &finalHeaders, endpointURL)
if err != nil {
log.Errorf("Failed to apply authentication for SSE notification: %v", err)
} else {
finalURL = modifiedURL
}
}
// Note: headers are already copied from the current request (which has server-level headers applied)
// via copyHeadersForSSERequest, so no need to apply them again
// Store state for tracking
ctx.SetContext(CtxSSEProxyState, SSEStateWaitingNotifyResp)
// Use RouteCluster client to send notification
client := wrapper.NewClusterClient(wrapper.RouteCluster{})
timeout := uint32(proxyServer.GetTimeout())
if timeout == 0 {
timeout = 5000 // Default 5 seconds
}
return client.Post(finalURL, finalHeaders, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode != 200 && statusCode != 202 {
log.Warnf("SSE notification request failed with status %d: %s", statusCode, string(responseBody))
// Even if notification fails, we should try to continue
// Some servers may not strictly require notification success
}
log.Debugf("SSE notification sent successfully")
// Now we can send the actual tool request
// Get stored context
endpointURLRaw := ctx.GetContext(CtxSSEProxyEndpointURL)
authInfoRaw := ctx.GetContext(CtxSSEProxyAuthInfo)
proxyServerRaw := ctx.GetContext("mcp_proxy_server")
requestBodyRaw := ctx.GetContext(CtxSSEProxyRequestBody)
if endpointURLRaw == nil || proxyServerRaw == nil || requestBodyRaw == nil {
log.Errorf("Missing context for sending tool request")
// At this point, we're in streaming response phase, must use injectSSEResponseError
injectSSEResponseError(ctx, fmt.Errorf("internal error: missing context"), utils.ErrInternalError)
return
}
endpointURL := endpointURLRaw.(string)
proxyServer := proxyServerRaw.(*McpProxyServer)
requestBody := requestBodyRaw.([]byte)
var authInfo *ProxyAuthInfo
if authInfoRaw != nil {
authInfo = authInfoRaw.(*ProxyAuthInfo)
}
// Parse to get request ID
reqID := gjson.GetBytes(requestBody, "id").Int()
if err := sendSSEToolRequest(ctx, endpointURL, authInfo, proxyServer, requestBody, int(reqID)); err != nil {
log.Errorf("Failed to send SSE tool request: %v", err)
injectSSEResponseError(ctx, err, utils.ErrInternalError)
}
}, timeout)
}
// sendSSEToolRequest sends the tools/list or tools/call request for SSE protocol
func sendSSEToolRequest(ctx wrapper.HttpContext, endpointURL string, authInfo *ProxyAuthInfo, proxyServer *McpProxyServer, requestBody []byte, requestID int) error {
// Copy headers from current request (now supported in response phase by Envoy)
finalHeaders := copyHeadersForSSERequest(ctx)
// Override required headers for SSE tool request
ensureHeader(&finalHeaders, "Content-Type", "application/json")
// Apply authentication to headers and URL
finalURL := endpointURL
if authInfo != nil && authInfo.SecuritySchemeID != "" {
modifiedURL, err := applyProxyAuthenticationForSSE(proxyServer, authInfo.SecuritySchemeID, authInfo.PassthroughCredential, &finalHeaders, endpointURL)
if err != nil {
log.Errorf("Failed to apply authentication for SSE tool request: %v", err)
} else {
finalURL = modifiedURL
}
}
// Note: headers are already copied from the current request (which has server-level headers applied)
// via copyHeadersForSSERequest, so no need to apply them again
// Store state for tracking
ctx.SetContext(CtxSSEProxyState, SSEStateWaitingToolResp)
ctx.SetContext(CtxSSEProxyRequestID, requestID)
// Use RouteCluster client to send tool request
client := wrapper.NewClusterClient(wrapper.RouteCluster{})
timeout := uint32(proxyServer.GetTimeout())
if timeout == 0 {
timeout = 5000 // Default 5 seconds
}
return client.Post(finalURL, finalHeaders, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode != 200 && statusCode != 202 {
log.Errorf("SSE tool request failed with status %d: %s", statusCode, string(responseBody))
// At this point, we're in streaming response phase, must use injectSSEResponseError
injectSSEResponseError(ctx, fmt.Errorf("SSE tool request failed with status %d", statusCode), utils.ErrInternalError)
return
}
log.Debugf("SSE tool request sent successfully")
// The response will be received through SSE channel and processed in streaming response handler
}, timeout)
}
// copyHeadersForSSERequest copies headers from current request for SSE RouteCluster calls
// This leverages Envoy's new capability to access request headers in response phase
func copyHeadersForSSERequest(ctx wrapper.HttpContext) [][2]string {
headers := make([][2]string, 0)
// Headers to skip
skipHeaders := map[string]bool{
"content-length": true, // Will be set by the client
"transfer-encoding": true, // Will be set by the client
"accept": true, // Will be set explicitly for SSE requests
":path": true, // Pseudo-header, not needed
":method": true, // Pseudo-header, not needed
":scheme": true, // Pseudo-header, not needed
":authority": true, // Pseudo-header, not needed
}
// Get all request headers (now supported in response phase by Envoy)
headerMap, err := proxywasm.GetHttpRequestHeaders()
if err != nil {
log.Warnf("Failed to get request headers in response phase: %v", err)
// Return minimal headers
return [][2]string{}
}
// Copy headers, skipping unwanted ones
for _, header := range headerMap {
headerName := strings.ToLower(header[0])
if skipHeaders[headerName] {
continue
}
headers = append(headers, header)
}
log.Debugf("Copied %d headers from request in response phase for SSE", len(headers))
return headers
}
// applyProxyAuthenticationForSSE applies authentication for SSE proxy requests
func applyProxyAuthenticationForSSE(server *McpProxyServer, schemeID string, passthroughCredential string, headers *[][2]string, targetURL string) (string, error) {
// Parse the target URL
parsedURL, err := url.Parse(targetURL)
if err != nil {
return "", fmt.Errorf("failed to parse target URL: %v", err)
}
// Create authentication context
authCtx := AuthRequestContext{
Method: "POST",
Headers: *headers,
ParsedURL: parsedURL,
RequestBody: []byte{},
PassthroughCredential: passthroughCredential,
}
// Create security config
securityConfig := SecurityRequirement{
ID: schemeID,
Credential: "",
Passthrough: passthroughCredential != "",
}
// Apply authentication
err = ApplySecurity(securityConfig, server, &authCtx)
if err != nil {
return "", err
}
// Update headers
*headers = authCtx.Headers
// Reconstruct URL
u := authCtx.ParsedURL
encodedPath := u.EscapedPath()
var urlStr string
if u.Scheme != "" && u.Host != "" {
urlStr = u.Scheme + "://" + u.Host + encodedPath
} else {
urlStr = "/" + strings.TrimPrefix(encodedPath, "/")
}
if u.RawQuery != "" {
urlStr += "?" + u.RawQuery
}
if u.Fragment != "" {
urlStr += "#" + u.Fragment
}
return urlStr, nil
}
// handleSSEStreamingResponse handles the streaming SSE response
func handleSSEStreamingResponse(ctx wrapper.HttpContext, config McpServerConfig, data []byte, endOfStream bool) []byte {
// Get the first chunk flag
isFirstChunk := ctx.GetBoolContext(CtxSSEProxyFirstChunk, true)
if isFirstChunk {
ctx.SetContext(CtxSSEProxyFirstChunk, false)
}
log.Debugf("Handling chunk of SSE response, data: %q", string(data))
// On first chunk, validate content-type and modify headers
if isFirstChunk {
// Validate that backend returned text/event-stream
contentType, err := proxywasm.GetHttpResponseHeader("content-type")
if err != nil || !strings.Contains(strings.ToLower(contentType), "text/event-stream") {
log.Errorf("Backend did not return text/event-stream content-type, got: %s", contentType)
// Return JSON-RPC error
injectSSEResponseError(ctx, fmt.Errorf("invalid content-type, expected text/event-stream but got: %s", contentType), utils.ErrInternalError)
return []byte{}
}
// Remove content-length and modify content-type
proxywasm.RemoveHttpResponseHeader("content-length")
proxywasm.ReplaceHttpResponseHeader("content-type", "application/json; charset=utf-8")
proxywasm.ReplaceHttpResponseHeader(":status", "200")
}
// Get or initialize buffer
var buffer []byte
if bufferRaw := ctx.GetContext(CtxSSEProxyBuffer); bufferRaw != nil {
buffer = bufferRaw.([]byte)
}
// Append new data to buffer
buffer = append(buffer, data...)
// Check buffer size limit
if len(buffer) > MaxSSEBufferSize {
log.Errorf("SSE buffer exceeded maximum size of %d bytes", MaxSSEBufferSize)
injectSSEResponseError(ctx, errors.New("response too large"), utils.ErrInternalError)
return []byte{}
}
// Store buffer back
ctx.SetContext(CtxSSEProxyBuffer, buffer)
// Get current state
state := ctx.GetContext(CtxSSEProxyState)
if state == nil {
state = SSEStateWaitingEndpoint
ctx.SetContext(CtxSSEProxyState, state)
}
log.Debugf("SSE proxy state: %s, now buffering data: %q", state.(string), string(buffer))
// Process based on state
switch state.(string) {
case SSEStateWaitingEndpoint:
return handleWaitingEndpoint(ctx, config, &buffer)
case SSEStateWaitingInitResp:
return handleWaitingInitResp(ctx, config, &buffer)
case SSEStateWaitingNotifyResp:
return handleWaitingNotifyResp(ctx, config, &buffer)
case SSEStateWaitingToolResp:
return handleWaitingToolResp(ctx, config, &buffer)
default:
log.Warnf("Unknown SSE proxy state: %v", state)
return []byte{}
}
}
// handleWaitingEndpoint processes SSE messages waiting for endpoint message
func handleWaitingEndpoint(ctx wrapper.HttpContext, config McpServerConfig, buffer *[]byte) []byte {
for {
msg, remaining, err := ParseSSEMessage(*buffer)
if err != nil {
log.Errorf("Failed to parse SSE message: %v", err)
injectSSEResponseError(ctx, err, utils.ErrInternalError)
return []byte{}
}
if msg == nil {
// No complete message yet
*buffer = remaining
return []byte{}
}
// Update buffer
*buffer = remaining
ctx.SetContext(CtxSSEProxyBuffer, *buffer)
// Check for endpoint message
if msg.Event == "endpoint" {
// Extract and store endpoint URL
proxyServerRaw := ctx.GetContext("mcp_proxy_server")
if proxyServerRaw == nil {
log.Errorf("mcp_proxy_server not found in context")
injectSSEResponseError(ctx, errors.New("internal error"), utils.ErrInternalError)
return []byte{}
}
proxyServer := proxyServerRaw.(*McpProxyServer)
endpointURL, err := ExtractEndpointURL(msg.Data, proxyServer.GetMcpServerURL())
if err != nil {
log.Errorf("Failed to extract endpoint URL: %v", err)
injectSSEResponseError(ctx, err, utils.ErrInternalError)
return []byte{}
}
log.Infof("Received SSE endpoint URL: %s", endpointURL)
ctx.SetContext(CtxSSEProxyEndpointURL, endpointURL)
// Get stored auth info
authInfoRaw := ctx.GetContext(CtxSSEProxyAuthInfo)
var authInfo *ProxyAuthInfo
if authInfoRaw != nil {
authInfo = authInfoRaw.(*ProxyAuthInfo)
}
// Send initialize request
if err := sendSSEInitialize(ctx, endpointURL, authInfo, proxyServer); err != nil {
log.Errorf("Failed to send SSE initialize: %v", err)
injectSSEResponseError(ctx, err, utils.ErrInternalError)
return []byte{}
}
// State has been changed to SSEStateWaitingInitResp in sendSSEInitialize
// Return immediately to allow next chunk to be processed in the new state
return []byte{}
}
// Skip other message types (like ping) while waiting for endpoint
// Continue to process next message in buffer
log.Debugf("Skipping SSE message with event '%s' while waiting for endpoint", msg.Event)
continue
}
}
// handleWaitingInitResp processes SSE messages waiting for initialize response
func handleWaitingInitResp(ctx wrapper.HttpContext, config McpServerConfig, buffer *[]byte) []byte {
requestID := ctx.GetContext(CtxSSEProxyRequestID)
if requestID == nil {
requestID = 1
}
for {
msg, remaining, err := ParseSSEMessage(*buffer)
if err != nil {
log.Errorf("Failed to parse SSE message: %v", err)
injectSSEResponseError(ctx, err, utils.ErrInternalError)
return []byte{}
}
if msg == nil {
// No complete message yet
*buffer = remaining
return []byte{}
}
// Update buffer
*buffer = remaining
ctx.SetContext(CtxSSEProxyBuffer, *buffer)
// Check for message event
if msg.Event == "message" {
// Parse JSON-RPC response
var jsonRpcResp map[string]interface{}
if err := json.Unmarshal([]byte(msg.Data), &jsonRpcResp); err != nil {
log.Errorf("Failed to parse JSON-RPC response: %v", err)
continue
}
// Check if this is the initialize response
respID := jsonRpcResp["id"]
if respID != nil {
var idMatch bool
switch v := respID.(type) {
case float64:
idMatch = int(v) == requestID.(int)
case int:
idMatch = v == requestID.(int)
}
if idMatch {
// Check for errors
if errorObj, hasError := jsonRpcResp["error"]; hasError {
log.Errorf("Backend initialize error: %v", errorObj)
injectSSEResponseError(ctx, fmt.Errorf("backend initialize failed"), utils.ErrInternalError)
return []byte{}
}
log.Debugf("Received initialize response, sending notification")
// Get endpoint URL and auth info
endpointURL := ctx.GetContext(CtxSSEProxyEndpointURL).(string)
authInfoRaw := ctx.GetContext(CtxSSEProxyAuthInfo)
proxyServerRaw := ctx.GetContext("mcp_proxy_server")
var authInfo *ProxyAuthInfo
if authInfoRaw != nil {
authInfo = authInfoRaw.(*ProxyAuthInfo)
}
proxyServer := proxyServerRaw.(*McpProxyServer)
// Send notification
// The notification callback will send the tool request after notification succeeds
if err := sendSSENotification(ctx, endpointURL, authInfo, proxyServer); err != nil {
log.Errorf("Failed to send SSE notification: %v", err)
injectSSEResponseError(ctx, err, utils.ErrInternalError)
return []byte{}
}
// State has been changed to SSEStateWaitingNotifyResp in sendSSENotification
// The tool request will be sent in the notification callback
// Return immediately to allow next chunk to be processed in the new state
return []byte{}
}
}
}
// Skip other message types (like ping) while waiting for init response
// Continue to process next message in buffer
log.Debugf("Skipping SSE message with event '%s' while waiting for init response", msg.Event)
continue
}
}
// handleWaitingNotifyResp processes SSE messages waiting for notification response
func handleWaitingNotifyResp(ctx wrapper.HttpContext, config McpServerConfig, buffer *[]byte) []byte {
// For notifications, we don't expect a response in SSE channel
// Just continue to send tool request
// This state should be very brief
return []byte{}
}
// handleWaitingToolResp processes SSE messages waiting for tool response
func handleWaitingToolResp(ctx wrapper.HttpContext, config McpServerConfig, buffer *[]byte) []byte {
requestID := ctx.GetContext(CtxSSEProxyRequestID)
if requestID == nil {
log.Errorf("Request ID not found in context")
injectSSEResponseError(ctx, errors.New("internal error"), utils.ErrInternalError)
return []byte{}
}
for {
msg, remaining, err := ParseSSEMessage(*buffer)
if err != nil {
log.Errorf("Failed to parse SSE message: %v", err)
injectSSEResponseError(ctx, err, utils.ErrInternalError)
return []byte{}
}
if msg == nil {
// No complete message yet
*buffer = remaining
return []byte{}
}
// Update buffer
*buffer = remaining
ctx.SetContext(CtxSSEProxyBuffer, *buffer)
// Check for message event
if msg.Event == "message" {
// Parse JSON-RPC response
var jsonRpcResp map[string]interface{}
if err := json.Unmarshal([]byte(msg.Data), &jsonRpcResp); err != nil {
log.Errorf("Failed to parse JSON-RPC response: %v", err)
continue
}
// Check if this is the expected response
respID := jsonRpcResp["id"]
if respID != nil {
var idMatch bool
switch v := respID.(type) {
case float64:
idMatch = int(v) == requestID.(int)
case int:
idMatch = v == requestID.(int)
}
if idMatch {
// Check for errors
if errorObj, hasError := jsonRpcResp["error"]; hasError {
log.Errorf("Backend tool error: %v", errorObj)
injectSSEResponseError(ctx, fmt.Errorf("backend tool call failed"), utils.ErrInternalError)
return []byte{}
}
// Extract result and return to client
if result, hasResult := jsonRpcResp["result"]; hasResult {
if resultMap, ok := result.(map[string]interface{}); ok {
// Apply allowTools filtering if this is a tools/list response
filteredResult := resultMap
if _, hasTools := resultMap["tools"]; hasTools {
// Get pre-computed effective allowTools from context
if allowToolsCtx := ctx.GetContext("mcp_proxy_effective_allow_tools"); allowToolsCtx != nil {
if effectiveAllowTools, ok := allowToolsCtx.(*map[string]struct{}); ok && effectiveAllowTools != nil {
// Apply filtering
if tools, hasToolsArray := resultMap["tools"]; hasToolsArray {
if toolsArray, ok := tools.([]interface{}); ok {
filteredTools := make([]interface{}, 0)
for _, tool := range toolsArray {
if toolMap, ok := tool.(map[string]interface{}); ok {
if name, hasName := toolMap["name"]; hasName {
if toolName, ok := name.(string); ok {
if _, allow := (*effectiveAllowTools)[toolName]; allow {
filteredTools = append(filteredTools, tool)
}
}
}
}
}
// Create filtered result
filteredResult = make(map[string]interface{})
for k, v := range resultMap {
filteredResult[k] = v
}
filteredResult["tools"] = filteredTools
}
}
}
}
}
injectSSEResponseSuccess(ctx, filteredResult)
// Clear buffer as we've processed the response
*buffer = []byte{}
ctx.SetContext(CtxSSEProxyBuffer, *buffer)
return []byte{}
}
}
log.Errorf("Invalid tool response format")
injectSSEResponseError(ctx, errors.New("invalid response format"), utils.ErrInternalError)
return []byte{}
}
}
}
// Skip other message types (like ping) while waiting for tool response
// Continue to process next message in buffer
log.Debugf("Skipping SSE message with event '%s' while waiting for tool response", msg.Event)
continue
}
}

View File

@@ -0,0 +1,297 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"testing"
)
// TestParseSSEMessage tests SSE message parsing
func TestParseSSEMessage(t *testing.T) {
tests := []struct {
name string
input []byte
wantEvent string
wantData string
wantID string
shouldParse bool
}{
{
name: "endpoint message",
input: []byte(`event: endpoint
data: /messages/?session_id=test123
`),
wantEvent: "endpoint",
wantData: "/messages/?session_id=test123",
shouldParse: true,
},
{
name: "message with JSON data",
input: []byte(`event: message
data: {"jsonrpc":"2.0","id":1,"result":{"test":"value"}}
`),
wantEvent: "message",
wantData: `{"jsonrpc":"2.0","id":1,"result":{"test":"value"}}`,
shouldParse: true,
},
{
name: "incomplete message",
input: []byte(`event: message
data: {"jsonrpc":"2.0"`),
shouldParse: false,
},
{
name: "message with id",
input: []byte(`id: 123
event: message
data: test data
`),
wantEvent: "message",
wantData: "test data",
wantID: "123",
shouldParse: true,
},
{
name: "comment line ignored",
input: []byte(`: this is a comment
event: message
data: test data
`),
wantEvent: "message",
wantData: "test data",
shouldParse: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msg, remaining, err := ParseSSEMessage(tt.input)
if err != nil {
t.Fatalf("parseSSEMessage() error = %v", err)
}
if tt.shouldParse {
if msg == nil {
t.Errorf("parseSSEMessage() expected message but got nil")
return
}
if msg.Event != tt.wantEvent {
t.Errorf("parseSSEMessage() Event = %v, want %v", msg.Event, tt.wantEvent)
}
if msg.Data != tt.wantData {
t.Errorf("parseSSEMessage() Data = %v, want %v", msg.Data, tt.wantData)
}
if msg.ID != tt.wantID {
t.Errorf("parseSSEMessage() ID = %v, want %v", msg.ID, tt.wantID)
}
if len(remaining) != 0 {
t.Errorf("parseSSEMessage() expected no remaining bytes, got %d bytes", len(remaining))
}
} else {
if msg != nil {
t.Errorf("parseSSEMessage() expected no message but got %v", msg)
}
if len(remaining) != len(tt.input) {
t.Errorf("parseSSEMessage() expected all data as remaining, got %d bytes instead of %d", len(remaining), len(tt.input))
}
}
})
}
}
// TestExtractEndpointURL tests endpoint URL extraction
func TestExtractEndpointURL(t *testing.T) {
tests := []struct {
name string
endpointData string
baseURL string
want string
wantErr bool
}{
{
name: "full URL",
endpointData: "http://example.com/messages?session=123",
baseURL: "http://backend.com/mcp",
want: "http://example.com/messages?session=123",
wantErr: false,
},
{
name: "path only",
endpointData: "/messages/?session_id=abc",
baseURL: "http://backend.com/mcp",
want: "http://backend.com/messages/?session_id=abc",
wantErr: false,
},
{
name: "https base URL",
endpointData: "/sse/endpoint",
baseURL: "https://secure.backend.com:8443/api",
want: "https://secure.backend.com:8443/sse/endpoint",
wantErr: false,
},
{
name: "path-only base URL",
endpointData: "/messages",
baseURL: "/api/v1",
want: "/messages",
wantErr: false,
},
{
name: "path without leading slash",
endpointData: "api/v1/messages",
baseURL: "http://backend.com",
want: "http://backend.com/api/v1/messages",
wantErr: false,
},
{
name: "path without leading slash with port",
endpointData: "sse/endpoint",
baseURL: "https://secure.backend.com:8443",
want: "https://secure.backend.com:8443/sse/endpoint",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ExtractEndpointURL(tt.endpointData, tt.baseURL)
if (err != nil) != tt.wantErr {
t.Errorf("extractEndpointURL() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("extractEndpointURL() = %v, want %v", got, tt.want)
}
})
}
}
// TestTransportProtocolValidation tests transport protocol validation
func TestTransportProtocolValidation(t *testing.T) {
tests := []struct {
name string
transport string
wantValid bool
}{
{
name: "valid http transport",
transport: "http",
wantValid: true,
},
{
name: "valid sse transport",
transport: "sse",
wantValid: true,
},
{
name: "invalid transport",
transport: "websocket",
wantValid: false,
},
{
name: "empty transport",
transport: "",
wantValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
transport := TransportProtocol(tt.transport)
isValid := transport == TransportHTTP || transport == TransportSSE
if isValid != tt.wantValid {
t.Errorf("TransportProtocol validation = %v, want %v for %s", isValid, tt.wantValid, tt.transport)
}
})
}
}
// TestMcpProxyServerTransport tests transport getter/setter
func TestMcpProxyServerTransport(t *testing.T) {
server := NewMcpProxyServer("test-server")
// Test default transport
if server.GetTransport() != "" {
t.Errorf("Expected empty default transport, got %v", server.GetTransport())
}
// Test setting HTTP transport
server.SetTransport(TransportHTTP)
if server.GetTransport() != TransportHTTP {
t.Errorf("Expected HTTP transport, got %v", server.GetTransport())
}
// Test setting SSE transport
server.SetTransport(TransportSSE)
if server.GetTransport() != TransportSSE {
t.Errorf("Expected SSE transport, got %v", server.GetTransport())
}
}
// TestSSEMessageParsing_MultipleMessages tests parsing multiple SSE messages
func TestSSEMessageParsing_MultipleMessages(t *testing.T) {
data := []byte(`event: endpoint
data: /messages/123
event: message
data: {"id":1}
: comment line
event: message
data: {"id":2}
`)
// First message
msg1, remaining, err := ParseSSEMessage(data)
if err != nil {
t.Fatalf("Failed to parse first message: %v", err)
}
if msg1 == nil || msg1.Event != "endpoint" || msg1.Data != "/messages/123" {
t.Errorf("First message incorrect: %+v", msg1)
}
// Second message
msg2, remaining, err := ParseSSEMessage(remaining)
if err != nil {
t.Fatalf("Failed to parse second message: %v", err)
}
if msg2 == nil || msg2.Event != "message" || msg2.Data != `{"id":1}` {
t.Errorf("Second message incorrect: %+v", msg2)
}
// Third message
msg3, remaining, err := ParseSSEMessage(remaining)
if err != nil {
t.Fatalf("Failed to parse third message: %v", err)
}
if msg3 == nil || msg3.Event != "message" || msg3.Data != `{"id":2}` {
t.Errorf("Third message incorrect: %+v", msg3)
}
// Should be no more complete messages
msg4, _, err := ParseSSEMessage(remaining)
if err != nil {
t.Fatalf("Error parsing remaining data: %v", err)
}
if msg4 != nil {
t.Errorf("Expected no more messages, got: %+v", msg4)
}
}

View File

@@ -0,0 +1,209 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"fmt"
"strconv"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"google.golang.org/protobuf/proto"
"github.com/higress-group/wasm-go/pkg/iface"
"github.com/higress-group/wasm-go/pkg/log"
pb "github.com/higress-group/wasm-go/pkg/protos"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
const (
CtxJsonRpcID = "jsonRpcID"
CtxNeedPause = "needPause" // Context key to signal if the handler needs to pause
JError = "error"
JCode = "code"
JMessage = "message"
JResult = "result"
ErrParseError = -32700
ErrInvalidRequest = -32600
ErrMethodNotFound = -32601
ErrInvalidParams = -32602
ErrInternalError = -32603
)
// JsonRpcID represents a JSON-RPC ID which can be either a string or a number
type JsonRpcID struct {
StringValue string
IntValue int64
IsString bool
}
// NewJsonRpcIDFromGjson creates a JsonRpcID from a gjson.Result
func NewJsonRpcIDFromGjson(result gjson.Result) JsonRpcID {
if result.Type == gjson.String {
return JsonRpcID{
StringValue: result.String(),
IsString: true,
}
}
return JsonRpcID{
IntValue: result.Int(),
IsString: false,
}
}
type JsonRpcRequestHandler func(context wrapper.HttpContext, id JsonRpcID, method string, params gjson.Result, rawBody []byte) types.Action
type JsonRpcResponseHandler func(context wrapper.HttpContext, id JsonRpcID, result gjson.Result, error gjson.Result, rawBody []byte) types.Action
type JsonRpcMethodHandler func(context wrapper.HttpContext, id JsonRpcID, params gjson.Result) error
type MethodHandlers map[string]JsonRpcMethodHandler
func makeHttpResponse(ctx wrapper.HttpContext, code uint32, debugInfo string, headers [][2]string, body []byte) {
phase := ctx.GetExecutionPhase()
if phase < iface.EncodeHeader {
proxywasm.SendHttpResponseWithDetail(code, debugInfo, headers, body, -1)
return
}
if debugInfo != "" {
log.Infof("response detail info:%s", debugInfo)
}
proxywasm.RemoveHttpResponseHeader("content-length")
proxywasm.ReplaceHttpResponseHeader(":status", strconv.Itoa(int(code)))
for _, kv := range headers {
proxywasm.ReplaceHttpResponseHeader(kv[0], kv[1])
}
if phase == iface.EncodeData {
proxywasm.ReplaceHttpResponseBody(body)
return
}
// EncodeHeader phase
args := &pb.InjectEncodedDataToFilterChainArguments{
Body: string(body),
Endstream: true,
}
argsStr, _ := proto.Marshal(args)
_, err := proxywasm.CallForeignFunction("inject_encoded_data_to_filter_chain_on_header", argsStr)
if err != nil {
log.Warnf("call inject_encoded_data_to_filter_chain_on_header failed, err:%v, fallback to send directly", err)
proxywasm.SendHttpResponseWithDetail(code, debugInfo, headers, body, -1)
return
}
}
func sendJsonRpcResponse(ctx wrapper.HttpContext, id JsonRpcID, extras map[string]any, debugInfo string) {
body := []byte(`{"jsonrpc": "2.0"}`)
if id.IsString {
body, _ = sjson.SetBytes(body, "id", id.StringValue)
} else {
body, _ = sjson.SetBytes(body, "id", id.IntValue)
}
for key, value := range extras {
body, _ = sjson.SetBytes(body, key, value)
}
makeHttpResponse(ctx, 200, debugInfo, [][2]string{{"Content-Type", "application/json; charset=utf-8"}}, body)
}
func OnJsonRpcResponseSuccess(ctx wrapper.HttpContext, result map[string]any, debugInfo ...string) {
var (
id JsonRpcID
ok bool
)
idRaw := ctx.GetContext(CtxJsonRpcID)
if id, ok = idRaw.(JsonRpcID); !ok {
makeHttpResponse(ctx, 500, "not_found_json_rpc_id", nil, []byte("not found json rpc id"))
return
}
responseDebugInfo := "json_rpc_success"
if len(debugInfo) > 0 {
responseDebugInfo = debugInfo[0]
}
sendJsonRpcResponse(ctx, id, map[string]any{JResult: result}, responseDebugInfo)
}
func OnJsonRpcResponseError(ctx wrapper.HttpContext, err error, errorCode int, debugInfo ...string) {
var (
id JsonRpcID
ok bool
)
idRaw := ctx.GetContext(CtxJsonRpcID)
if id, ok = idRaw.(JsonRpcID); !ok {
makeHttpResponse(ctx, 500, "not_found_json_rpc_id", nil, []byte("not found json rpc id"))
return
}
responseDebugInfo := fmt.Sprintf("json_rpc_error(%s)", err)
if len(debugInfo) > 0 {
responseDebugInfo = debugInfo[0]
}
sendJsonRpcResponse(ctx, id, map[string]any{JError: map[string]any{
JMessage: err.Error(),
JCode: errorCode,
}}, responseDebugInfo)
}
func HandleJsonRpcMethod(ctx wrapper.HttpContext, body []byte, handles MethodHandlers) types.Action {
idResult := gjson.GetBytes(body, "id")
id := NewJsonRpcIDFromGjson(idResult)
ctx.SetContext(CtxJsonRpcID, id)
method := gjson.GetBytes(body, "method").String()
params := gjson.GetBytes(body, "params")
if method != "" {
if handle, ok := handles[method]; ok {
log.Debugf("json rpc call method[%s] with params[%s]", method, params.Raw)
// Clear pause flag before calling handler
ctx.SetContext(CtxNeedPause, false)
err := handle(ctx, id, params)
if err != nil {
OnJsonRpcResponseError(ctx, err, ErrInvalidRequest)
return types.ActionContinue
}
// Check if the handler set the pause flag
if needPause := ctx.GetContext(CtxNeedPause); needPause != nil && needPause.(bool) {
return types.ActionPause
}
return types.ActionContinue
}
OnJsonRpcResponseError(ctx, fmt.Errorf("method not found:%s", method), ErrMethodNotFound)
} else {
proxywasm.SendHttpResponseWithDetail(202, "json_rpc_ack", nil, nil, -1)
}
return types.ActionContinue
}
func HandleJsonRpcRequest(ctx wrapper.HttpContext, body []byte, handle JsonRpcRequestHandler) types.Action {
idResult := gjson.GetBytes(body, "id")
id := NewJsonRpcIDFromGjson(idResult)
ctx.SetContext(CtxJsonRpcID, id)
method := gjson.GetBytes(body, "method").String()
params := gjson.GetBytes(body, "params")
log.Debugf("json rpc call method[%s] with params[%s]", method, params.Raw)
return handle(ctx, id, method, params, body)
}
func HandleJsonRpcResponse(ctx wrapper.HttpContext, body []byte, handle JsonRpcResponseHandler) types.Action {
idResult := gjson.GetBytes(body, "id")
id := NewJsonRpcIDFromGjson(idResult)
error := gjson.GetBytes(body, "error")
result := gjson.GetBytes(body, "result")
log.Debugf("json rpc response error[%s] result[%s]", error.Raw, result.Raw)
return handle(ctx, id, result, error, body)
}

View File

@@ -0,0 +1,160 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"encoding/json"
"strings"
"testing"
"github.com/tidwall/gjson"
)
func TestJsonRpcIDFromGjson(t *testing.T) {
tests := []struct {
name string
jsonData string
expected JsonRpcID
}{
{
name: "integer id",
jsonData: `{"id": 123}`,
expected: JsonRpcID{
IntValue: 123,
IsString: false,
},
},
{
name: "string id",
jsonData: `{"id": "abc-123"}`,
expected: JsonRpcID{
StringValue: "abc-123",
IsString: true,
},
},
{
name: "float id treated as int",
jsonData: `{"id": 123.45}`,
expected: JsonRpcID{
IntValue: 123,
IsString: false,
},
},
{
name: "boolean id treated as int",
jsonData: `{"id": true}`,
expected: JsonRpcID{
IntValue: 1,
IsString: false,
},
},
{
name: "null id treated as int zero",
jsonData: `{"id": null}`,
expected: JsonRpcID{
IntValue: 0,
IsString: false,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
idResult := gjson.Get(tt.jsonData, "id")
result := NewJsonRpcIDFromGjson(idResult)
if result.IsString != tt.expected.IsString {
t.Errorf("IsString = %v, want %v", result.IsString, tt.expected.IsString)
}
if result.IsString {
if result.StringValue != tt.expected.StringValue {
t.Errorf("StringValue = %v, want %v", result.StringValue, tt.expected.StringValue)
}
} else {
if result.IntValue != tt.expected.IntValue {
t.Errorf("IntValue = %v, want %v", result.IntValue, tt.expected.IntValue)
}
}
})
}
}
// Skip TestSendJsonRpcResponse because it requires proxywasm which is not available in the test environment
// This function would normally test that sendJsonRpcResponse correctly handles different ID types
func TestSendJsonRpcResponse(t *testing.T) {
t.Skip("Skipping test that requires proxywasm")
}
func TestJsonRpcIDMarshaling(t *testing.T) {
// Test that JsonRpcID is correctly marshaled in a JSON response
tests := []struct {
name string
id JsonRpcID
expected string
}{
{
name: "integer id",
id: JsonRpcID{
IntValue: 123,
IsString: false,
},
expected: `"id":123`,
},
{
name: "string id",
id: JsonRpcID{
StringValue: "abc-123",
IsString: true,
},
expected: `"id":"abc-123"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a JSON object with the ID
var jsonObj map[string]interface{}
if tt.id.IsString {
jsonObj = map[string]interface{}{
"jsonrpc": "2.0",
"id": tt.id.StringValue,
}
} else {
jsonObj = map[string]interface{}{
"jsonrpc": "2.0",
"id": tt.id.IntValue,
}
}
// Marshal to JSON
body, err := json.Marshal(jsonObj)
if err != nil {
t.Errorf("Failed to marshal JSON: %v", err)
}
// Check that the ID is correctly marshaled
if !json.Valid(body) {
t.Errorf("Invalid JSON: %s", string(body))
}
// Check that the ID is correctly formatted
if !strings.Contains(string(body), tt.expected) {
t.Errorf("ID not correctly formatted. Expected to contain %s, got %s", tt.expected, string(body))
}
})
}
}

View File

@@ -0,0 +1,130 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"fmt"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
type MCPServerLog struct {
}
func setMCPInfo(msg string) string {
requestIDRaw, _ := proxywasm.GetProperty([]string{"x_request_id"})
requestID := string(requestIDRaw)
if requestID == "" {
requestID = "nil"
}
mcpServerNameRaw, _ := proxywasm.GetProperty([]string{"mcp_server_name"})
mcpServerName := string(mcpServerNameRaw)
mcpToolNameRaw, _ := proxywasm.GetProperty([]string{"mcp_tool_name"})
mcpToolName := string(mcpToolNameRaw)
mcpInfo := mcpServerName
if mcpToolName != "" {
mcpInfo = fmt.Sprintf("%s/%s", mcpServerName, mcpToolName)
}
return fmt.Sprintf("[mcp-server] [%s] [%s] %s", mcpInfo, requestID, msg)
}
func (l MCPServerLog) log(level wrapper.LogLevel, msg string) {
msg = setMCPInfo(msg)
switch level {
case wrapper.LogLevelTrace:
proxywasm.LogTrace(msg)
case wrapper.LogLevelDebug:
proxywasm.LogDebug(msg)
case wrapper.LogLevelInfo:
proxywasm.LogInfo(msg)
case wrapper.LogLevelWarn:
proxywasm.LogWarn(msg)
case wrapper.LogLevelError:
proxywasm.LogError(msg)
case wrapper.LogLevelCritical:
proxywasm.LogCritical(msg)
}
}
func (l MCPServerLog) logFormat(level wrapper.LogLevel, format string, args ...interface{}) {
format = setMCPInfo(format)
switch level {
case wrapper.LogLevelTrace:
proxywasm.LogTracef(format, args...)
case wrapper.LogLevelDebug:
proxywasm.LogDebugf(format, args...)
case wrapper.LogLevelInfo:
proxywasm.LogInfof(format, args...)
case wrapper.LogLevelWarn:
proxywasm.LogWarnf(format, args...)
case wrapper.LogLevelError:
proxywasm.LogErrorf(format, args...)
case wrapper.LogLevelCritical:
proxywasm.LogCriticalf(format, args...)
}
}
func (l MCPServerLog) Trace(msg string) {
l.log(wrapper.LogLevelTrace, msg)
}
func (l MCPServerLog) Tracef(format string, args ...interface{}) {
l.logFormat(wrapper.LogLevelTrace, format, args...)
}
func (l MCPServerLog) Debug(msg string) {
l.log(wrapper.LogLevelDebug, msg)
}
func (l MCPServerLog) Debugf(format string, args ...interface{}) {
l.logFormat(wrapper.LogLevelDebug, format, args...)
}
func (l MCPServerLog) Info(msg string) {
l.log(wrapper.LogLevelInfo, msg)
}
func (l MCPServerLog) Infof(format string, args ...interface{}) {
l.logFormat(wrapper.LogLevelInfo, format, args...)
}
func (l MCPServerLog) Warn(msg string) {
l.log(wrapper.LogLevelWarn, msg)
}
func (l MCPServerLog) Warnf(format string, args ...interface{}) {
l.logFormat(wrapper.LogLevelWarn, format, args...)
}
func (l MCPServerLog) Error(msg string) {
l.log(wrapper.LogLevelError, msg)
}
func (l MCPServerLog) Errorf(format string, args ...interface{}) {
l.logFormat(wrapper.LogLevelError, format, args...)
}
func (l MCPServerLog) Critical(msg string) {
l.log(wrapper.LogLevelCritical, msg)
}
func (l MCPServerLog) Criticalf(format string, args ...interface{}) {
l.logFormat(wrapper.LogLevelCritical, format, args...)
}
func (l MCPServerLog) ResetID(pluginID string) {
}

View File

@@ -0,0 +1,117 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"encoding/base64"
"encoding/json"
"fmt"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
func OnMCPResponseSuccess(ctx wrapper.HttpContext, result map[string]any, debugInfo string) {
OnJsonRpcResponseSuccess(ctx, result, debugInfo)
// TODO: support pub to redis when use POST + SSE
}
func OnMCPResponseError(ctx wrapper.HttpContext, err error, code int, debugInfo string) {
OnJsonRpcResponseError(ctx, err, code, debugInfo)
// TODO: support pub to redis when use POST + SSE
}
func OnMCPToolCallSuccess(ctx wrapper.HttpContext, content []map[string]any, debugInfo string) {
OnMCPResponseSuccess(ctx, map[string]any{
"content": content,
"isError": false,
}, debugInfo)
}
// OnMCPToolCallSuccessWithStructuredContent sends a successful MCP tool response with structured content
// According to MCP spec, structuredContent is a field in tool results, not a capability
func OnMCPToolCallSuccessWithStructuredContent(ctx wrapper.HttpContext, content []map[string]any, structuredContent json.RawMessage, debugInfo string) {
response := map[string]any{
"content": content,
"isError": false,
}
if structuredContent != nil && len(structuredContent) > 0 {
response["structuredContent"] = structuredContent
}
OnMCPResponseSuccess(ctx, response, debugInfo)
}
func OnMCPToolCallError(ctx wrapper.HttpContext, err error, debugInfo ...string) {
responseDebugInfo := fmt.Sprintf("mcp:tools/call:error(%s)", err)
if len(debugInfo) > 0 {
responseDebugInfo = debugInfo[0]
}
OnMCPResponseSuccess(ctx, map[string]any{
"content": []map[string]any{
{
"type": "text",
"text": err.Error(),
},
},
"isError": true,
}, responseDebugInfo)
}
func SendMCPToolTextResult(ctx wrapper.HttpContext, result string, debugInfo ...string) {
responseDebugInfo := "mcp:tools/call::result"
if len(debugInfo) > 0 {
responseDebugInfo = debugInfo[0]
}
OnMCPToolCallSuccess(ctx, []map[string]any{
{
"type": "text",
"text": result,
},
}, responseDebugInfo)
}
func SendMCPToolImageResult(ctx wrapper.HttpContext, image []byte, contentType string, debugInfo ...string) {
responseDebugInfo := "mcp:tools/call::result"
if len(debugInfo) > 0 {
responseDebugInfo = debugInfo[0]
}
content := []map[string]any{
{
"type": "image",
"data": base64.StdEncoding.EncodeToString(image),
"mimeType": contentType,
},
}
// Use traditional response format since no structured data is provided
OnMCPToolCallSuccess(ctx, content, responseDebugInfo)
}
// SendMCPToolTextResultWithStructuredContent sends a tool result with both text content and structured content
// According to MCP spec, for backward compatibility, tools that return structured content
// SHOULD also return the serialized JSON in a TextContent block
func SendMCPToolTextResultWithStructuredContent(ctx wrapper.HttpContext, textResult string, structuredContent json.RawMessage, debugInfo ...string) {
responseDebugInfo := "mcp:tools/call::result"
if len(debugInfo) > 0 {
responseDebugInfo = debugInfo[0]
}
content := []map[string]any{
{
"type": "text",
"text": textResult,
},
}
OnMCPToolCallSuccessWithStructuredContent(ctx, content, structuredContent, responseDebugInfo)
}

View File

@@ -0,0 +1,51 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"net/url"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
func IsStatefulSession(ctx wrapper.HttpContext) bool {
parse, err := url.Parse(ctx.Path())
if err != nil {
log.Errorf("failed to parse request path: %v", err)
return false
}
query, err := url.ParseQuery(parse.RawQuery)
if err != nil {
log.Errorf("failed to parse query params: %v", err)
return false
}
// Protocol version: 2024-11-05
if query.Get("sessionId") != "" {
return true
}
// Protocol version: 2025-03-26
sessionHeader, err := proxywasm.GetHttpRequestHeader("mcp-session-id")
if err != nil {
log.Errorf("failed to get request header: %v", err)
return false
}
if sessionHeader != "" {
return true
}
return false
}

View File

@@ -0,0 +1,164 @@
# MCP Configuration Validator
This package provides a configuration validation library for MCP (Model Context Protocol) server configurations. It allows you to validate MCP configurations without requiring the full runtime environment, making it perfect for use in management platforms and frontend applications.
## Features
- **Lightweight Validation**: Validates configuration structure and syntax without requiring actual server instances
- **REST Tool Support**: Full validation for REST-based MCP tools including request/response templates
- **ToolSet Support**: Validates composed server configurations (toolSets)
- **Pre-registered Server Handling**: Gracefully handles pre-registered Go-based servers by skipping their validation
- **Minimal Dependencies**: Reuses the core parsing logic from the main MCP server implementation
## Usage
### Basic Validation
```go
import "github.com/higress-group/wasm-go/pkg/mcp/validator"
// Validate a configuration YAML string
yamlConfig := `
server:
name: my-server
config:
apiKey: secret
tools:
- name: my-tool
description: A sample tool
args:
- name: input
type: string
required: true
requestTemplate:
url: https://api.example.com/endpoint
method: POST
responseTemplate:
body: "{{.}}"
`
result, err := validator.ValidateConfigYAML(yamlConfig)
if err != nil {
// Handle error
return
}
if result.IsValid {
fmt.Printf("Configuration is valid for server: %s\n", result.ServerName)
if result.IsComposed {
fmt.Println("This is a composed server (toolSet)")
} else {
fmt.Println("This is a single server")
}
} else {
fmt.Printf("Configuration is invalid: %v\n", result.Error)
}
```
## Supported Configuration Types
### 1. REST Server Configuration
Validates REST-based MCP servers with tools, security schemes, and templates:
```yaml
server:
name: weather-api
config:
apiKey: your-api-key
securitySchemes:
- id: bearer-auth
type: http
scheme: bearer
tools:
- name: get_weather
description: Get current weather
args:
- name: city
type: string
required: true
requestTemplate:
url: "https://api.weather.com/v1/current?city={{.args.city}}"
method: GET
responseTemplate:
body: "Weather: {{.temperature}}°C"
```
### 2. ToolSet Configuration (Composed Server)
Validates composed servers that aggregate tools from multiple servers:
```yaml
toolSet:
name: ai-assistant-tools
serverTools:
- serverName: weather-api
tools: ["get_weather", "get_forecast"]
- serverName: search-api
tools: ["web_search"]
```
### 3. Pre-registered Go-based Server
For pre-registered Go-based servers, validation focuses on basic structure and skips server instance validation:
```yaml
server:
name: custom-go-server
config:
database_url: "postgres://localhost:5432/mydb"
allowTools: ["query_database"]
```
## Validation Result
The `ValidationResult` struct provides detailed information about the validation:
```go
type ValidationResult struct {
IsValid bool `json:"isValid"` // Whether the configuration is valid
Error error `json:"error"` // Validation error if any
ServerName string `json:"serverName"` // Parsed server name
IsComposed bool `json:"isComposed"` // Whether it's a composed server
}
```
## Architecture
The validator reuses the core parsing logic from the main MCP server implementation through dependency injection:
- **parseConfigCore**: Core parsing logic with configurable dependencies
- **ConfigOptions**: Dependency config options
- **SkipPreRegisteredServers**: Flag to skip validation of pre-registered Go servers
This approach ensures:
- **Consistency**: Same validation logic as runtime
- **Maintainability**: Single source of truth for parsing logic
- **Minimal Code Duplication**: Reuses existing implementation
## Testing
Run the tests to verify the validator works correctly:
```bash
cd pkg/mcp/validator
go test -v
```
The test suite covers:
- REST server configuration validation
- ToolSet configuration validation
- Pre-registered server handling
- Invalid configuration detection
- Error cases
## Error Handling
The validator provides detailed error messages for common configuration issues:
- Missing required fields (e.g., `server.name`)
- Invalid JSON structure
- Malformed tool definitions
- Invalid template syntax
- Missing server or toolSet configuration
These errors help developers quickly identify and fix configuration problems before deployment.

View File

@@ -0,0 +1,129 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package validator
import (
"encoding/json"
"fmt"
"os"
"github.com/tidwall/gjson"
"gopkg.in/yaml.v3"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
)
// validatorLogger is a simple logger implementation for validation mode
type validatorLogger struct{}
func (l *validatorLogger) Trace(msg string) { fmt.Fprintf(os.Stderr, "[TRACE] %s\n", msg) }
func (l *validatorLogger) Tracef(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "[TRACE] "+format+"\n", args...)
}
func (l *validatorLogger) Debug(msg string) { fmt.Fprintf(os.Stderr, "[DEBUG] %s\n", msg) }
func (l *validatorLogger) Debugf(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "[DEBUG] "+format+"\n", args...)
}
func (l *validatorLogger) Info(msg string) { fmt.Fprintf(os.Stderr, "[INFO] %s\n", msg) }
func (l *validatorLogger) Infof(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "[INFO] "+format+"\n", args...)
}
func (l *validatorLogger) Warn(msg string) { fmt.Fprintf(os.Stderr, "[WARN] %s\n", msg) }
func (l *validatorLogger) Warnf(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "[WARN] "+format+"\n", args...)
}
func (l *validatorLogger) Error(msg string) { fmt.Fprintf(os.Stderr, "[ERROR] %s\n", msg) }
func (l *validatorLogger) Errorf(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "[ERROR] "+format+"\n", args...)
}
func (l *validatorLogger) Critical(msg string) { fmt.Fprintf(os.Stderr, "[CRITICAL] %s\n", msg) }
func (l *validatorLogger) Criticalf(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "[CRITICAL] "+format+"\n", args...)
}
func (l *validatorLogger) ResetID(pluginID string) {}
// init initializes the validator package
func init() {
// Set a custom logger for validation mode to prevent panics
log.SetPluginLog(&validatorLogger{})
}
// ValidationResult contains the result of configuration validation
type ValidationResult struct {
IsValid bool `json:"isValid"`
Error error `json:"error,omitempty"`
ServerName string `json:"serverName,omitempty"`
IsComposed bool `json:"isComposed"`
}
// ValidateConfig validates MCP configuration
// This function focuses on validating REST tools and toolSet configurations
// It skips validation for pre-registered Go-based servers
func ValidateConfig(configJSON string) (*ValidationResult, error) {
// Create empty dependencies for validation mode
// We skip pre-registered servers validation by setting SkipPreRegisteredServers to true
toolRegistry := &server.GlobalToolRegistry{}
toolRegistry.Initialize() // Initialize the registry to prevent nil map assignment panic
deps := &server.ConfigOptions{
Servers: make(map[string]server.Server), // Empty servers map
ToolRegistry: toolRegistry, // Initialized registry
SkipPreRegisteredServers: true, // Skip pre-registered servers
}
// Call core parsing logic for validation
configGjson := gjson.Parse(configJSON)
mockConfig := &server.McpServerConfig{}
err := server.ParseConfigCore(configGjson, mockConfig, deps)
result := &ValidationResult{
IsValid: err == nil,
Error: err,
}
if err == nil {
result.ServerName = mockConfig.GetServerName()
result.IsComposed = mockConfig.GetIsComposed()
}
return result, nil
}
// ValidateConfigYAML validates MCP configuration from YAML format
// This function converts YAML to JSON first, then validates using the same logic
func ValidateConfigYAML(configYAML string) (*ValidationResult, error) {
// Parse YAML into a generic interface
var yamlData interface{}
if err := yaml.Unmarshal([]byte(configYAML), &yamlData); err != nil {
return &ValidationResult{
IsValid: false,
Error: fmt.Errorf("failed to parse YAML: %v", err),
}, nil
}
// Convert to JSON
jsonBytes, err := json.Marshal(yamlData)
if err != nil {
return &ValidationResult{
IsValid: false,
Error: fmt.Errorf("failed to convert YAML to JSON: %v", err),
}, nil
}
// Use the existing JSON validation logic
return ValidateConfig(string(jsonBytes))
}

View File

@@ -0,0 +1,266 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package validator
import (
"testing"
)
func TestValidateConfig_RestServer(t *testing.T) {
// Test REST server configuration
configJSON := `{
"server": {
"name": "test-rest-server",
"config": {}
},
"tools": [
{
"name": "test-tool",
"description": "A test tool",
"args": [
{
"name": "input",
"description": "Input parameter",
"type": "string",
"required": true
}
],
"requestTemplate": {
"url": "https://api.example.com/test",
"method": "POST"
},
"responseTemplate": {
"body": "{{.}}"
}
}
]
}`
result, err := ValidateConfig(configJSON)
if err != nil {
t.Fatalf("ValidateConfig returned error: %v", err)
}
if !result.IsValid {
t.Errorf("Expected config to be valid, but got invalid with error: %v", result.Error)
}
if result.ServerName != "test-rest-server" {
t.Errorf("Expected server name 'test-rest-server', got '%s'", result.ServerName)
}
if result.IsComposed {
t.Errorf("Expected single server (not composed), but got composed")
}
}
func TestValidateConfig_ToolSet(t *testing.T) {
// Test toolSet configuration
configJSON := `{
"toolSet": {
"name": "test-toolset",
"serverTools": [
{
"serverName": "server1",
"tools": ["tool1", "tool2"]
},
{
"serverName": "server2",
"tools": ["tool3"]
}
]
}
}`
result, err := ValidateConfig(configJSON)
if err != nil {
t.Fatalf("ValidateConfig returned error: %v", err)
}
if !result.IsValid {
t.Errorf("Expected config to be valid, but got invalid with error: %v", result.Error)
}
if result.ServerName != "test-toolset" {
t.Errorf("Expected server name 'test-toolset', got '%s'", result.ServerName)
}
if !result.IsComposed {
t.Errorf("Expected composed server, but got single server")
}
}
func TestValidateConfig_PreRegisteredServer(t *testing.T) {
// Test pre-registered Go-based server configuration (should be skipped in validation)
configJSON := `{
"server": {
"name": "some-go-server",
"config": {
"someParam": "value"
}
}
}`
result, err := ValidateConfig(configJSON)
if err != nil {
t.Fatalf("ValidateConfig returned error: %v", err)
}
if !result.IsValid {
t.Errorf("Expected config to be valid (pre-registered servers should be skipped), but got invalid with error: %v", result.Error)
}
if result.ServerName != "some-go-server" {
t.Errorf("Expected server name 'some-go-server', got '%s'", result.ServerName)
}
if result.IsComposed {
t.Errorf("Expected single server (not composed), but got composed")
}
}
func TestValidateConfig_InvalidConfig(t *testing.T) {
// Test invalid configuration (missing required fields)
configJSON := `{
"server": {
"config": {}
}
}`
result, err := ValidateConfig(configJSON)
if err != nil {
t.Fatalf("ValidateConfig returned error: %v", err)
}
if result.IsValid {
t.Errorf("Expected config to be invalid, but got valid")
}
if result.Error == nil {
t.Errorf("Expected validation error, but got nil")
}
}
func TestValidateConfig_MissingServerAndToolSet(t *testing.T) {
// Test configuration missing both server and toolSet
configJSON := `{
"allowTools": ["tool1"]
}`
result, err := ValidateConfig(configJSON)
if err != nil {
t.Fatalf("ValidateConfig returned error: %v", err)
}
if result.IsValid {
t.Errorf("Expected config to be invalid, but got valid")
}
if result.Error == nil {
t.Errorf("Expected validation error, but got nil")
}
}
func TestValidateConfigYAML_RestServer(t *testing.T) {
// Test REST server configuration in YAML format
configYAML := `
server:
name: test-rest-server-yaml
config: {}
tools:
- name: test-tool-yaml
description: A test tool from YAML
args:
- name: input
description: Input parameter
type: string
required: true
requestTemplate:
url: https://api.example.com/test
method: POST
responseTemplate:
body: "{{.}}"
`
result, err := ValidateConfigYAML(configYAML)
if err != nil {
t.Fatalf("ValidateConfigYAML returned error: %v", err)
}
if !result.IsValid {
t.Errorf("Expected config to be valid, but got invalid with error: %v", result.Error)
}
if result.ServerName != "test-rest-server-yaml" {
t.Errorf("Expected server name 'test-rest-server-yaml', got '%s'", result.ServerName)
}
if result.IsComposed {
t.Errorf("Expected single server (not composed), but got composed")
}
}
func TestValidateConfigYAML_ToolSet(t *testing.T) {
// Test toolSet configuration in YAML format
configYAML := `
toolSet:
name: test-toolset-yaml
serverTools:
- serverName: server1
tools: ["tool1", "tool2"]
- serverName: server2
tools: ["tool3"]
`
result, err := ValidateConfigYAML(configYAML)
if err != nil {
t.Fatalf("ValidateConfigYAML returned error: %v", err)
}
if !result.IsValid {
t.Errorf("Expected config to be valid, but got invalid with error: %v", result.Error)
}
if result.ServerName != "test-toolset-yaml" {
t.Errorf("Expected server name 'test-toolset-yaml', got '%s'", result.ServerName)
}
if !result.IsComposed {
t.Errorf("Expected composed server, but got single server")
}
}
func TestValidateConfigYAML_InvalidYAML(t *testing.T) {
// Test invalid YAML syntax
configYAML := `
server:
name: test-server
config: {
invalid yaml syntax here
`
result, err := ValidateConfigYAML(configYAML)
if err != nil {
t.Fatalf("ValidateConfigYAML returned error: %v", err)
}
if result.IsValid {
t.Errorf("Expected config to be invalid due to YAML syntax error, but got valid")
}
if result.Error == nil {
t.Errorf("Expected YAML parsing error, but got nil")
}
}

View File

@@ -0,0 +1,250 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package validator
import (
"encoding/json"
"fmt"
)
// ExampleUsage demonstrates how to use the ValidateConfig function
func ExampleUsage() {
// Example 1: REST server configuration
restServerConfig := `{
"server": {
"name": "weather-api",
"config": {
"apiKey": "your-api-key"
}
},
"tools": [
{
"name": "get_weather",
"description": "Get current weather for a city",
"args": [
{
"name": "city",
"description": "City name",
"type": "string",
"required": true
},
{
"name": "units",
"description": "Temperature units",
"type": "string",
"enum": ["celsius", "fahrenheit"],
"default": "celsius"
}
],
"requestTemplate": {
"url": "https://api.weather.com/v1/current?city={{.args.city}}&units={{.args.units}}",
"method": "GET",
"headers": [
{
"key": "Authorization",
"value": "Bearer {{.config.apiKey}}"
}
]
},
"responseTemplate": {
"body": "Current weather in {{.args.city}}: {{.temperature}}°{{.args.units}}"
}
}
],
"allowTools": ["get_weather"]
}`
result, err := ValidateConfig(restServerConfig)
if err != nil {
fmt.Printf("Error validating REST server config: %v\n", err)
return
}
fmt.Printf("REST Server Config Validation:\n")
fmt.Printf(" Valid: %t\n", result.IsValid)
fmt.Printf(" Server Name: %s\n", result.ServerName)
fmt.Printf(" Is Composed: %t\n", result.IsComposed)
if result.Error != nil {
fmt.Printf(" Error: %v\n", result.Error)
}
fmt.Println()
// Example 2: ToolSet configuration
toolSetConfig := `{
"toolSet": {
"name": "ai-assistant-tools",
"serverTools": [
{
"serverName": "weather-api",
"tools": ["get_weather", "get_forecast"]
},
{
"serverName": "search-api",
"tools": ["web_search", "image_search"]
}
]
},
"allowTools": ["weather-api/get_weather", "search-api/web_search"]
}`
result, err = ValidateConfig(toolSetConfig)
if err != nil {
fmt.Printf("Error validating toolSet config: %v\n", err)
return
}
fmt.Printf("ToolSet Config Validation:\n")
fmt.Printf(" Valid: %t\n", result.IsValid)
fmt.Printf(" Server Name: %s\n", result.ServerName)
fmt.Printf(" Is Composed: %t\n", result.IsComposed)
if result.Error != nil {
fmt.Printf(" Error: %v\n", result.Error)
}
fmt.Println()
// Example 3: Pre-registered Go-based server (validation skipped)
goServerConfig := `{
"server": {
"name": "custom-go-server",
"config": {
"database_url": "postgres://localhost:5432/mydb",
"max_connections": 10
}
},
"allowTools": ["query_database", "update_record"]
}`
result, err = ValidateConfig(goServerConfig)
if err != nil {
fmt.Printf("Error validating Go server config: %v\n", err)
return
}
fmt.Printf("Go Server Config Validation (skipped):\n")
fmt.Printf(" Valid: %t\n", result.IsValid)
fmt.Printf(" Server Name: %s\n", result.ServerName)
fmt.Printf(" Is Composed: %t\n", result.IsComposed)
if result.Error != nil {
fmt.Printf(" Error: %v\n", result.Error)
}
fmt.Println()
// Example 4: Invalid configuration
invalidConfig := `{
"server": {
"config": {}
}
}`
result, err = ValidateConfig(invalidConfig)
if err != nil {
fmt.Printf("Error validating invalid config: %v\n", err)
return
}
fmt.Printf("Invalid Config Validation:\n")
fmt.Printf(" Valid: %t\n", result.IsValid)
if result.Error != nil {
fmt.Printf(" Error: %v\n", result.Error)
}
}
// ValidateConfigFromBytes validates configuration from byte array
func ValidateConfigFromBytes(configBytes []byte) (*ValidationResult, error) {
return ValidateConfig(string(configBytes))
}
// ValidateConfigFromMap validates configuration from a map
func ValidateConfigFromMap(configMap map[string]interface{}) (*ValidationResult, error) {
configBytes, err := json.Marshal(configMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal config map: %v", err)
}
return ValidateConfig(string(configBytes))
}
// ExampleYAMLUsage demonstrates how to use the ValidateConfigYAML function
func ExampleYAMLUsage() {
// Example YAML configuration for REST server
yamlConfig := `
server:
name: weather-api-yaml
config:
apiKey: your-api-key
tools:
- name: get_weather
description: Get current weather for a city
args:
- name: city
description: City name
type: string
required: true
- name: units
description: Temperature units
type: string
enum: ["celsius", "fahrenheit"]
default: celsius
requestTemplate:
url: "https://api.weather.com/v1/current?city={{.args.city}}&units={{.args.units}}"
method: GET
headers:
- key: Authorization
value: "Bearer {{.config.apiKey}}"
responseTemplate:
body: "Current weather in {{.args.city}}: {{.temperature}}°{{.args.units}}"
allowTools: ["get_weather"]
`
result, err := ValidateConfigYAML(yamlConfig)
if err != nil {
fmt.Printf("Error validating YAML config: %v\n", err)
return
}
fmt.Printf("YAML Config Validation:\n")
fmt.Printf(" Valid: %t\n", result.IsValid)
fmt.Printf(" Server Name: %s\n", result.ServerName)
fmt.Printf(" Is Composed: %t\n", result.IsComposed)
if result.Error != nil {
fmt.Printf(" Error: %v\n", result.Error)
}
fmt.Println()
// Example YAML configuration for ToolSet
yamlToolSetConfig := `
toolSet:
name: ai-assistant-tools-yaml
serverTools:
- serverName: weather-api
tools: ["get_weather", "get_forecast"]
- serverName: search-api
tools: ["web_search", "image_search"]
allowTools: ["weather-api/get_weather", "search-api/web_search"]
`
result, err = ValidateConfigYAML(yamlToolSetConfig)
if err != nil {
fmt.Printf("Error validating YAML toolSet config: %v\n", err)
return
}
fmt.Printf("YAML ToolSet Config Validation:\n")
fmt.Printf(" Valid: %t\n", result.IsValid)
fmt.Printf(" Server Name: %s\n", result.ServerName)
fmt.Printf(" Is Composed: %t\n", result.IsComposed)
if result.Error != nil {
fmt.Printf(" Error: %v\n", result.Error)
}
}

View File

@@ -1,183 +0,0 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package wrapper
import (
"fmt"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)
type Cluster interface {
ClusterName() string
HostName() string
}
type RouteCluster struct {
Host string
}
func (c RouteCluster) ClusterName() string {
routeName, err := proxywasm.GetProperty([]string{"cluster_name"})
if err != nil {
proxywasm.LogErrorf("get route cluster failed, err:%v", err)
}
return string(routeName)
}
func (c RouteCluster) HostName() string {
if c.Host != "" {
return c.Host
}
return GetRequestHost()
}
type TargetCluster struct {
Host string
Cluster string
}
func (c TargetCluster) ClusterName() string {
return c.Cluster
}
func (c TargetCluster) HostName() string {
return c.Host
}
type K8sCluster struct {
ServiceName string
Namespace string
Port int64
Version string
Host string
}
func (c K8sCluster) ClusterName() string {
namespace := "default"
if c.Namespace != "" {
namespace = c.Namespace
}
return fmt.Sprintf("outbound|%d|%s|%s.%s.svc.cluster.local",
c.Port, c.Version, c.ServiceName, namespace)
}
func (c K8sCluster) HostName() string {
if c.Host != "" {
return c.Host
}
return fmt.Sprintf("%s.%s.svc.cluster.local", c.ServiceName, c.Namespace)
}
type NacosCluster struct {
ServiceName string
// use DEFAULT-GROUP by default
Group string
NamespaceID string
Port int64
// set true if use edas/sae registry
IsExtRegistry bool
Version string
Host string
}
func (c NacosCluster) ClusterName() string {
group := "DEFAULT-GROUP"
if c.Group != "" {
group = strings.ReplaceAll(c.Group, "_", "-")
}
tail := "nacos"
if c.IsExtRegistry {
tail += "-ext"
}
return fmt.Sprintf("outbound|%d|%s|%s.%s.%s.%s",
c.Port, c.Version, c.ServiceName, group, c.NamespaceID, tail)
}
func (c NacosCluster) HostName() string {
if c.Host != "" {
return c.Host
}
return c.ServiceName
}
type StaticIpCluster struct {
ServiceName string
Port int64
Host string
}
func (c StaticIpCluster) ClusterName() string {
return fmt.Sprintf("outbound|%d||%s.static", c.Port, c.ServiceName)
}
func (c StaticIpCluster) HostName() string {
if c.Host != "" {
return c.Host
}
return c.ServiceName
}
type DnsCluster struct {
ServiceName string
Domain string
Port int64
}
func (c DnsCluster) ClusterName() string {
return fmt.Sprintf("outbound|%d||%s.dns", c.Port, c.ServiceName)
}
func (c DnsCluster) HostName() string {
return c.Domain
}
type ConsulCluster struct {
ServiceName string
Datacenter string
Port int64
Host string
}
func (c ConsulCluster) ClusterName() string {
tail := "consul"
return fmt.Sprintf("outbound|%d||%s.%s.%s",
c.Port, c.ServiceName, c.Datacenter, tail)
}
func (c ConsulCluster) HostName() string {
if c.Host != "" {
return c.Host
}
return c.ServiceName
}
type FQDNCluster struct {
FQDN string
Host string
Port int64
}
func (c FQDNCluster) ClusterName() string {
return fmt.Sprintf("outbound|%d||%s", c.Port, c.FQDN)
}
func (c FQDNCluster) HostName() string {
if c.Host != "" {
return c.Host
}
return c.FQDN
}

View File

@@ -1,102 +0,0 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package wrapper
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestClusterAndHost(t *testing.T) {
cases := []struct {
name string
cluster Cluster
expectCluster string
expectHost string
}{
{
name: "k8s",
cluster: K8sCluster{
ServiceName: "foo",
Namespace: "bar",
Port: 8080,
Version: "1.0",
},
expectCluster: "outbound|8080|1.0|foo.bar.svc.cluster.local",
expectHost: "foo.bar.svc.cluster.local",
},
{
name: "k8s default",
cluster: K8sCluster{
ServiceName: "foo",
Port: 8080,
Host: "www.example.com",
},
expectCluster: "outbound|8080||foo.default.svc.cluster.local",
expectHost: "www.example.com",
},
{
name: "nacos",
cluster: NacosCluster{
ServiceName: "foo",
Group: "DEFAULT_GROUP",
NamespaceID: "xxxx",
Port: 8080,
Version: "1.0",
},
expectCluster: "outbound|8080|1.0|foo.DEFAULT-GROUP.xxxx.nacos",
expectHost: "foo",
},
{
name: "nacos ext",
cluster: NacosCluster{
ServiceName: "foo",
NamespaceID: "xxxx",
Port: 8080,
IsExtRegistry: true,
Host: "www.test.com",
},
expectCluster: "outbound|8080||foo.DEFAULT-GROUP.xxxx.nacos-ext",
expectHost: "www.test.com",
},
{
name: "static",
cluster: StaticIpCluster{
ServiceName: "foo",
Port: 8080,
Host: "www.test.com",
},
expectCluster: "outbound|8080||foo.static",
expectHost: "www.test.com",
},
{
name: "dns",
cluster: DnsCluster{
ServiceName: "foo",
Port: 8080,
Domain: "www.test.com",
},
expectCluster: "outbound|8080||foo.dns",
expectHost: "www.test.com",
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
assert.Equal(t, c.expectCluster, c.cluster.ClusterName())
assert.Equal(t, c.expectHost, c.cluster.HostName())
})
}
}

View File

@@ -1,145 +0,0 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package wrapper
import (
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/google/uuid"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)
type ResponseCallback func(statusCode int, responseHeaders http.Header, responseBody []byte)
type HttpClient interface {
Get(rawURL string, headers [][2]string, cb ResponseCallback, timeoutMillisecond ...uint32) error
Head(rawURL string, headers [][2]string, cb ResponseCallback, timeoutMillisecond ...uint32) error
Options(rawURL string, headers [][2]string, cb ResponseCallback, timeoutMillisecond ...uint32) error
Post(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error
Put(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error
Patch(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error
Delete(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error
Connect(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error
Trace(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error
Call(method, rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error
ClusterName() string
}
type ClusterClient[C Cluster] struct {
cluster C
}
func NewClusterClient[C Cluster](cluster C) *ClusterClient[C] {
return &ClusterClient[C]{cluster: cluster}
}
func (c ClusterClient[C]) Get(rawURL string, headers [][2]string, cb ResponseCallback, timeoutMillisecond ...uint32) error {
return HttpCall(c.cluster, http.MethodGet, rawURL, headers, nil, cb, timeoutMillisecond...)
}
func (c ClusterClient[C]) Head(rawURL string, headers [][2]string, cb ResponseCallback, timeoutMillisecond ...uint32) error {
return HttpCall(c.cluster, http.MethodHead, rawURL, headers, nil, cb, timeoutMillisecond...)
}
func (c ClusterClient[C]) Options(rawURL string, headers [][2]string, cb ResponseCallback, timeoutMillisecond ...uint32) error {
return HttpCall(c.cluster, http.MethodOptions, rawURL, headers, nil, cb, timeoutMillisecond...)
}
func (c ClusterClient[C]) Post(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error {
return HttpCall(c.cluster, http.MethodPost, rawURL, headers, body, cb, timeoutMillisecond...)
}
func (c ClusterClient[C]) Put(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error {
return HttpCall(c.cluster, http.MethodPut, rawURL, headers, body, cb, timeoutMillisecond...)
}
func (c ClusterClient[C]) Patch(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error {
return HttpCall(c.cluster, http.MethodPatch, rawURL, headers, body, cb, timeoutMillisecond...)
}
func (c ClusterClient[C]) Delete(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error {
return HttpCall(c.cluster, http.MethodDelete, rawURL, headers, body, cb, timeoutMillisecond...)
}
func (c ClusterClient[C]) Connect(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error {
return HttpCall(c.cluster, http.MethodConnect, rawURL, headers, body, cb, timeoutMillisecond...)
}
func (c ClusterClient[C]) Trace(rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error {
return HttpCall(c.cluster, http.MethodTrace, rawURL, headers, body, cb, timeoutMillisecond...)
}
func (c ClusterClient[C]) Call(method, rawURL string, headers [][2]string, body []byte, cb ResponseCallback, timeoutMillisecond ...uint32) error {
return HttpCall(c.cluster, method, rawURL, headers, body, cb, timeoutMillisecond...)
}
func (c ClusterClient[C]) ClusterName() string { return c.cluster.ClusterName() }
func HttpCall(cluster Cluster, method, rawURL string, headers [][2]string, body []byte,
callback ResponseCallback, timeoutMillisecond ...uint32) error {
for i := len(headers) - 1; i >= 0; i-- {
key := headers[i][0]
if key == ":method" || key == ":path" || key == ":authority" {
headers = append(headers[:i], headers[i+1:]...)
}
}
parsedURL, err := url.Parse(rawURL)
if err != nil {
proxywasm.LogCriticalf("invalid rawURL:%s", rawURL)
return err
}
authority := cluster.HostName()
if parsedURL.Host != "" {
authority = parsedURL.Host
}
path := "/" + strings.TrimPrefix(parsedURL.Path, "/")
if parsedURL.RawQuery != "" {
path = fmt.Sprintf("%s?%s", path, parsedURL.RawQuery)
}
// default timeout is 500ms
var timeout uint32 = 500
if len(timeoutMillisecond) > 0 {
timeout = timeoutMillisecond[0]
}
headers = append(headers, [2]string{":method", method}, [2]string{":path", path}, [2]string{":authority", authority})
requestID := uuid.New().String()
_, err = proxywasm.DispatchHttpCall(cluster.ClusterName(), headers, body, nil, timeout, func(numHeaders, bodySize, numTrailers int) {
respBody, err := proxywasm.GetHttpCallResponseBody(0, bodySize)
if err != nil {
proxywasm.LogCriticalf("failed to get response body: %v", err)
}
respHeaders, err := proxywasm.GetHttpCallResponseHeaders()
if err != nil {
proxywasm.LogCriticalf("failed to get response headers: %v", err)
}
code := http.StatusBadGateway
var normalResponse bool
headers := make(http.Header)
for _, h := range respHeaders {
if h[0] == ":status" {
code, err = strconv.Atoi(h[1])
if err != nil {
proxywasm.LogErrorf("failed to parse status: %v", err)
code = http.StatusInternalServerError
} else {
normalResponse = true
}
}
headers.Add(h[0], h[1])
}
proxywasm.LogDebugf("http call end, id: %s, code: %d, normal: %t, body: %s",
requestID, code, normalResponse, respBody)
callback(code, headers, respBody)
})
proxywasm.LogDebugf("http call start, id: %s, cluster: %s, method: %s, url: %s, headers: %#v, body: %s, timeout: %d",
requestID, cluster.ClusterName(), method, rawURL, headers, body, timeout)
return err
}

View File

@@ -1,135 +0,0 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package wrapper
import (
"fmt"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)
type LogLevel uint32
const (
LogLevelTrace LogLevel = iota
LogLevelDebug
LogLevelInfo
LogLevelWarn
LogLevelError
LogLevelCritical
)
type DefaultLog struct {
pluginName string
pluginID string
}
func (l *DefaultLog) log(level LogLevel, msg string) {
requestIDRaw, _ := proxywasm.GetProperty([]string{"x_request_id"})
requestID := string(requestIDRaw)
if requestID == "" {
requestID = "nil"
}
msg = fmt.Sprintf("[%s] [%s] [%s] %s", l.pluginName, l.pluginID, requestID, msg)
switch level {
case LogLevelTrace:
proxywasm.LogTrace(msg)
case LogLevelDebug:
proxywasm.LogDebug(msg)
case LogLevelInfo:
proxywasm.LogInfo(msg)
case LogLevelWarn:
proxywasm.LogWarn(msg)
case LogLevelError:
proxywasm.LogError(msg)
case LogLevelCritical:
proxywasm.LogCritical(msg)
}
}
func (l *DefaultLog) logFormat(level LogLevel, format string, args ...interface{}) {
requestIDRaw, _ := proxywasm.GetProperty([]string{"x_request_id"})
requestID := string(requestIDRaw)
if requestID == "" {
requestID = "nil"
}
format = fmt.Sprintf("[%s] [%s] [%s] %s", l.pluginName, l.pluginID, requestID, format)
switch level {
case LogLevelTrace:
proxywasm.LogTracef(format, args...)
case LogLevelDebug:
proxywasm.LogDebugf(format, args...)
case LogLevelInfo:
proxywasm.LogInfof(format, args...)
case LogLevelWarn:
proxywasm.LogWarnf(format, args...)
case LogLevelError:
proxywasm.LogErrorf(format, args...)
case LogLevelCritical:
proxywasm.LogCriticalf(format, args...)
}
}
func (l *DefaultLog) Trace(msg string) {
l.log(LogLevelTrace, msg)
}
func (l *DefaultLog) Tracef(format string, args ...interface{}) {
l.logFormat(LogLevelTrace, format, args...)
}
func (l *DefaultLog) Debug(msg string) {
l.log(LogLevelDebug, msg)
}
func (l *DefaultLog) Debugf(format string, args ...interface{}) {
l.logFormat(LogLevelDebug, format, args...)
}
func (l *DefaultLog) Info(msg string) {
l.log(LogLevelInfo, msg)
}
func (l *DefaultLog) Infof(format string, args ...interface{}) {
l.logFormat(LogLevelInfo, format, args...)
}
func (l *DefaultLog) Warn(msg string) {
l.log(LogLevelWarn, msg)
}
func (l *DefaultLog) Warnf(format string, args ...interface{}) {
l.logFormat(LogLevelWarn, format, args...)
}
func (l *DefaultLog) Error(msg string) {
l.log(LogLevelError, msg)
}
func (l *DefaultLog) Errorf(format string, args ...interface{}) {
l.logFormat(LogLevelError, format, args...)
}
func (l *DefaultLog) Critical(msg string) {
l.log(LogLevelCritical, msg)
}
func (l *DefaultLog) Criticalf(format string, args ...interface{}) {
l.logFormat(LogLevelCritical, format, args...)
}
func (l *DefaultLog) ResetID(pluginID string) {
l.pluginID = pluginID
}

View File

@@ -1,791 +0,0 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package wrapper
import (
"encoding/json"
"fmt"
"strconv"
"time"
"unsafe"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/matcher"
)
type Log log.Log
const (
CustomLogKey = "custom_log"
AILogKey = "ai_log"
TraceSpanTagPrefix = "trace_span_tag."
PluginIDKey = "_plugin_id_"
)
type HttpContext interface {
Scheme() string
Host() string
Path() string
Method() string
SetContext(key string, value interface{})
GetContext(key string) interface{}
GetBoolContext(key string, defaultValue bool) bool
GetStringContext(key, defaultValue string) string
GetByteSliceContext(key string, defaultValue []byte) []byte
GetUserAttribute(key string) interface{}
SetUserAttribute(key string, value interface{})
SetUserAttributeMap(kvmap map[string]interface{})
GetUserAttributeMap() map[string]interface{}
// You can call this function to set custom log
WriteUserAttributeToLog() error
// You can call this function to set custom log with your specific key
WriteUserAttributeToLogWithKey(key string) error
// You can call this function to set custom trace span attribute
WriteUserAttributeToTrace() error
// If the onHttpRequestBody handle is not set, the request body will not be read by default
DontReadRequestBody()
// If the onHttpResponseBody handle is not set, the request body will not be read by default
DontReadResponseBody()
// If the onHttpStreamingRequestBody handle is not set, and the onHttpRequestBody handle is set, the request body will be buffered by default
BufferRequestBody()
// If the onHttpStreamingResponseBody handle is not set, and the onHttpResponseBody handle is set, the response body will be buffered by default
BufferResponseBody()
// If any request header is changed in onHttpRequestHeaders, envoy will re-calculate the route. Call this function to disable the re-routing.
// You need to call this before making any header modification operations.
DisableReroute()
// Note that this parameter affects the gateway's memory usageSupport setting a maximum buffer size for each request body individually in request phase.
SetRequestBodyBufferLimit(byteSize uint32)
// Note that this parameter affects the gateway's memory usage! Support setting a maximum buffer size for each response body individually in response phase.
SetResponseBodyBufferLimit(byteSize uint32)
// Get contextId of HttpContext
GetContextId() uint32
}
type oldParseConfigFunc[PluginConfig any] func(json gjson.Result, config *PluginConfig, log Log) error
type oldParseRuleConfigFunc[PluginConfig any] func(json gjson.Result, global PluginConfig, config *PluginConfig, log Log) error
type oldOnHttpHeadersFunc[PluginConfig any] func(context HttpContext, config PluginConfig, log Log) types.Action
type oldOnHttpBodyFunc[PluginConfig any] func(context HttpContext, config PluginConfig, body []byte, log Log) types.Action
type oldOnHttpStreamingBodyFunc[PluginConfig any] func(context HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log Log) []byte
type oldOnHttpStreamDoneFunc[PluginConfig any] func(context HttpContext, config PluginConfig, log Log)
type ParseConfigFunc[PluginConfig any] func(json gjson.Result, config *PluginConfig) error
type ParseRuleConfigFunc[PluginConfig any] func(json gjson.Result, global PluginConfig, config *PluginConfig) error
type onHttpHeadersFunc[PluginConfig any] func(context HttpContext, config PluginConfig) types.Action
type onHttpBodyFunc[PluginConfig any] func(context HttpContext, config PluginConfig, body []byte) types.Action
type onHttpStreamingBodyFunc[PluginConfig any] func(context HttpContext, config PluginConfig, chunk []byte, isLastChunk bool) []byte
type onHttpStreamDoneFunc[PluginConfig any] func(context HttpContext, config PluginConfig)
type CommonVmCtx[PluginConfig any] struct {
types.DefaultVMContext
pluginName string
log Log
hasCustomConfig bool
parseConfig ParseConfigFunc[PluginConfig]
parseRuleConfig ParseRuleConfigFunc[PluginConfig]
onHttpRequestHeaders onHttpHeadersFunc[PluginConfig]
onHttpRequestBody onHttpBodyFunc[PluginConfig]
onHttpStreamingRequestBody onHttpStreamingBodyFunc[PluginConfig]
onHttpResponseHeaders onHttpHeadersFunc[PluginConfig]
onHttpResponseBody onHttpBodyFunc[PluginConfig]
onHttpStreamingResponseBody onHttpStreamingBodyFunc[PluginConfig]
onHttpStreamDone onHttpStreamDoneFunc[PluginConfig]
}
type TickFuncEntry struct {
lastExecuted int64
tickPeriod int64
tickFunc func()
}
var globalOnTickFuncs []TickFuncEntry = []TickFuncEntry{}
// Register multiple onTick functions. Parameters include:
// 1) tickPeriod: the execution period of tickFunc, this value should be a multiple of 100
// 2) tickFunc: the function to be executed
//
// You should call this function in parseConfig phase, for example:
//
// func parseConfig(json gjson.Result, config *HelloWorldConfig, log wrapper.Log) error {
// wrapper.RegisterTickFunc(1000, func() { proxywasm.LogInfo("onTick 1s") })
// wrapper.RegisterTickFunc(3000, func() { proxywasm.LogInfo("onTick 3s") })
// return nil
// }
func RegisterTickFunc(tickPeriod int64, tickFunc func()) {
globalOnTickFuncs = append(globalOnTickFuncs, TickFuncEntry{0, tickPeriod, tickFunc})
}
func SetCtx[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) {
proxywasm.SetVMContext(NewCommonVmCtx(pluginName, options...))
}
func SetCtxWithOptions[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) {
proxywasm.SetVMContext(NewCommonVmCtxWithOptions(pluginName, options...))
}
type CtxOption[PluginConfig any] interface {
Apply(*CommonVmCtx[PluginConfig])
}
type parseConfigOption[PluginConfig any] struct {
f ParseConfigFunc[PluginConfig]
oldF oldParseConfigFunc[PluginConfig]
}
func (o parseConfigOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
if o.f != nil {
ctx.parseConfig = o.f
} else {
ctx.parseConfig = func(json gjson.Result, config *PluginConfig) error { return o.oldF(json, config, ctx.log) }
}
}
// Deprecated: Please use `ParseConfig` instead.
func ParseConfigBy[PluginConfig any](f oldParseConfigFunc[PluginConfig]) CtxOption[PluginConfig] {
return &parseConfigOption[PluginConfig]{oldF: f}
}
func ParseConfig[PluginConfig any](f ParseConfigFunc[PluginConfig]) CtxOption[PluginConfig] {
return &parseConfigOption[PluginConfig]{f: f}
}
type parseOverrideConfigOption[PluginConfig any] struct {
parseConfigF ParseConfigFunc[PluginConfig]
parseRuleConfigF ParseRuleConfigFunc[PluginConfig]
oldParseConfigF oldParseConfigFunc[PluginConfig]
oldParseRuleConfigF oldParseRuleConfigFunc[PluginConfig]
}
func (o *parseOverrideConfigOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
if o.parseConfigF != nil && o.parseRuleConfigF != nil {
ctx.parseConfig = o.parseConfigF
ctx.parseRuleConfig = o.parseRuleConfigF
} else {
ctx.parseConfig = func(json gjson.Result, config *PluginConfig) error {
return o.oldParseConfigF(json, config, ctx.log)
}
ctx.parseRuleConfig = func(json gjson.Result, global PluginConfig, config *PluginConfig) error {
return o.oldParseRuleConfigF(json, global, config, ctx.log)
}
}
}
// Deprecated: Please use `ParseOverrideConfig` instead.
func ParseOverrideConfigBy[PluginConfig any](f oldParseConfigFunc[PluginConfig], g oldParseRuleConfigFunc[PluginConfig]) CtxOption[PluginConfig] {
return &parseOverrideConfigOption[PluginConfig]{
oldParseConfigF: f,
oldParseRuleConfigF: g,
}
}
func ParseOverrideConfig[PluginConfig any](f ParseConfigFunc[PluginConfig], g ParseRuleConfigFunc[PluginConfig]) CtxOption[PluginConfig] {
return &parseOverrideConfigOption[PluginConfig]{
parseConfigF: f,
parseRuleConfigF: g,
}
}
type onProcessRequestHeadersOption[PluginConfig any] struct {
f onHttpHeadersFunc[PluginConfig]
oldF oldOnHttpHeadersFunc[PluginConfig]
}
func (o *onProcessRequestHeadersOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
if o.f != nil {
ctx.onHttpRequestHeaders = o.f
} else {
ctx.onHttpRequestHeaders = func(context HttpContext, config PluginConfig) types.Action {
return o.oldF(context, config, ctx.log)
}
}
}
// Deprecated: Please use `ProcessRequestHeaders` instead.
func ProcessRequestHeadersBy[PluginConfig any](f oldOnHttpHeadersFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessRequestHeadersOption[PluginConfig]{oldF: f}
}
func ProcessRequestHeaders[PluginConfig any](f onHttpHeadersFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessRequestHeadersOption[PluginConfig]{f: f}
}
type onProcessRequestBodyOption[PluginConfig any] struct {
f onHttpBodyFunc[PluginConfig]
oldF oldOnHttpBodyFunc[PluginConfig]
}
func (o *onProcessRequestBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
if o.f != nil {
ctx.onHttpRequestBody = o.f
} else {
ctx.onHttpRequestBody = func(context HttpContext, config PluginConfig, body []byte) types.Action {
return o.oldF(context, config, body, ctx.log)
}
}
}
// Deprecated: Please use `ProcessRequestBody` instead.
func ProcessRequestBodyBy[PluginConfig any](f oldOnHttpBodyFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessRequestBodyOption[PluginConfig]{oldF: f}
}
func ProcessRequestBody[PluginConfig any](f onHttpBodyFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessRequestBodyOption[PluginConfig]{f: f}
}
type onProcessStreamingRequestBodyOption[PluginConfig any] struct {
f onHttpStreamingBodyFunc[PluginConfig]
oldF oldOnHttpStreamingBodyFunc[PluginConfig]
}
func (o *onProcessStreamingRequestBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
if o.f != nil {
ctx.onHttpStreamingRequestBody = o.f
} else {
ctx.onHttpStreamingRequestBody = func(context HttpContext, config PluginConfig, chunk []byte, isLastChunk bool) []byte {
return o.oldF(context, config, chunk, isLastChunk, ctx.log)
}
}
}
// Deprecated: Please use `ProcessStreamingRequestBody` instead.
func ProcessStreamingRequestBodyBy[PluginConfig any](f oldOnHttpStreamingBodyFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessStreamingRequestBodyOption[PluginConfig]{oldF: f}
}
func ProcessStreamingRequestBody[PluginConfig any](f onHttpStreamingBodyFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessStreamingRequestBodyOption[PluginConfig]{f: f}
}
type onProcessResponseHeadersOption[PluginConfig any] struct {
f onHttpHeadersFunc[PluginConfig]
oldF oldOnHttpHeadersFunc[PluginConfig]
}
func (o *onProcessResponseHeadersOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
if o.f != nil {
ctx.onHttpResponseHeaders = o.f
} else {
ctx.onHttpResponseHeaders = func(context HttpContext, config PluginConfig) types.Action {
return o.oldF(context, config, ctx.log)
}
}
}
// Deprecated: Please use `ProcessResponseHeaders` instead.
func ProcessResponseHeadersBy[PluginConfig any](f oldOnHttpHeadersFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessResponseHeadersOption[PluginConfig]{oldF: f}
}
func ProcessResponseHeaders[PluginConfig any](f onHttpHeadersFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessResponseHeadersOption[PluginConfig]{f: f}
}
type onProcessResponseBodyOption[PluginConfig any] struct {
f onHttpBodyFunc[PluginConfig]
oldF oldOnHttpBodyFunc[PluginConfig]
}
func (o *onProcessResponseBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
if o.f != nil {
ctx.onHttpResponseBody = o.f
} else {
ctx.onHttpResponseBody = func(context HttpContext, config PluginConfig, body []byte) types.Action {
return o.oldF(context, config, body, ctx.log)
}
}
}
// Deprecated: Please use `ProcessResponseBody` instead.
func ProcessResponseBodyBy[PluginConfig any](f oldOnHttpBodyFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessResponseBodyOption[PluginConfig]{oldF: f}
}
func ProcessResponseBody[PluginConfig any](f onHttpBodyFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessResponseBodyOption[PluginConfig]{f: f}
}
type onProcessStreamingResponseBodyOption[PluginConfig any] struct {
f onHttpStreamingBodyFunc[PluginConfig]
oldF oldOnHttpStreamingBodyFunc[PluginConfig]
}
func (o *onProcessStreamingResponseBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
if o.f != nil {
ctx.onHttpStreamingResponseBody = o.f
} else {
ctx.onHttpStreamingResponseBody = func(context HttpContext, config PluginConfig, chunk []byte, isLastChunk bool) []byte {
return o.oldF(context, config, chunk, isLastChunk, ctx.log)
}
}
}
// Deprecated: Please use `ProcessStreamingResponseBody` instead.
func ProcessStreamingResponseBodyBy[PluginConfig any](f oldOnHttpStreamingBodyFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessStreamingResponseBodyOption[PluginConfig]{oldF: f}
}
func ProcessStreamingResponseBody[PluginConfig any](f onHttpStreamingBodyFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessStreamingResponseBodyOption[PluginConfig]{f: f}
}
type onProcessStreamDoneOption[PluginConfig any] struct {
f onHttpStreamDoneFunc[PluginConfig]
oldF oldOnHttpStreamDoneFunc[PluginConfig]
}
func (o *onProcessStreamDoneOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
if o.f != nil {
ctx.onHttpStreamDone = o.f
} else {
ctx.onHttpStreamDone = func(context HttpContext, config PluginConfig) { o.oldF(context, config, ctx.log) }
}
}
// Deprecated: Please use `ProcessStreamDoneBy` instead.
func ProcessStreamDoneBy[PluginConfig any](f oldOnHttpStreamDoneFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessStreamDoneOption[PluginConfig]{oldF: f}
}
func ProcessStreamDone[PluginConfig any](f onHttpStreamDoneFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessStreamDoneOption[PluginConfig]{f: f}
}
type logOption[PluginConfig any] struct {
logger Log
}
func (o *logOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
log.SetPluginLog(o.logger)
ctx.log = o.logger
}
func WithLogger[PluginConfig any](logger Log) CtxOption[PluginConfig] {
return &logOption[PluginConfig]{logger}
}
func parseEmptyPluginConfig[PluginConfig any](gjson.Result, *PluginConfig) error {
return nil
}
func NewCommonVmCtx[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) *CommonVmCtx[PluginConfig] {
logger := &DefaultLog{pluginName, "nil"}
opts := []CtxOption[PluginConfig]{WithLogger[PluginConfig](logger)}
for _, opt := range options {
if opt == nil {
continue
}
opts = append(opts, opt)
}
return NewCommonVmCtxWithOptions(pluginName, opts...)
}
func NewCommonVmCtxWithOptions[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) *CommonVmCtx[PluginConfig] {
ctx := &CommonVmCtx[PluginConfig]{
pluginName: pluginName,
hasCustomConfig: true,
}
for _, opt := range options {
opt.Apply(ctx)
}
if ctx.parseConfig == nil {
var config PluginConfig
if unsafe.Sizeof(config) != 0 {
msg := "the `parseConfig` is missing in NewCommonVmCtx's arguments"
panic(msg)
}
ctx.hasCustomConfig = false
ctx.parseConfig = parseEmptyPluginConfig[PluginConfig]
}
return ctx
}
func (ctx *CommonVmCtx[PluginConfig]) NewPluginContext(uint32) types.PluginContext {
return &CommonPluginCtx[PluginConfig]{
vm: ctx,
}
}
type CommonPluginCtx[PluginConfig any] struct {
types.DefaultPluginContext
matcher.RuleMatcher[PluginConfig]
vm *CommonVmCtx[PluginConfig]
onTickFuncs []TickFuncEntry
}
func (ctx *CommonPluginCtx[PluginConfig]) OnPluginStart(int) types.OnPluginStartStatus {
data, err := proxywasm.GetPluginConfiguration()
globalOnTickFuncs = nil
if err != nil && err != types.ErrorStatusNotFound {
ctx.vm.log.Criticalf("error reading plugin configuration: %v", err)
return types.OnPluginStartStatusFailed
}
var jsonData gjson.Result
if len(data) == 0 {
if ctx.vm.hasCustomConfig {
ctx.vm.log.Warn("config is empty, but has ParseConfigFunc")
}
} else {
if !gjson.ValidBytes(data) {
ctx.vm.log.Warnf("the plugin configuration is not a valid json: %s", string(data))
return types.OnPluginStartStatusFailed
}
pluginID := gjson.GetBytes(data, PluginIDKey).String()
if pluginID != "" {
ctx.vm.log.ResetID(pluginID)
data, _ = sjson.DeleteBytes([]byte(data), PluginIDKey)
}
jsonData = gjson.ParseBytes(data)
}
var parseOverrideConfig func(gjson.Result, PluginConfig, *PluginConfig) error
if ctx.vm.parseRuleConfig != nil {
parseOverrideConfig = func(js gjson.Result, global PluginConfig, cfg *PluginConfig) error {
return ctx.vm.parseRuleConfig(js, global, cfg)
}
}
err = ctx.ParseRuleConfig(jsonData,
func(js gjson.Result, cfg *PluginConfig) error {
return ctx.vm.parseConfig(js, cfg)
},
parseOverrideConfig,
)
if err != nil {
ctx.vm.log.Warnf("parse rule config failed: %v", err)
ctx.vm.log.Error("plugin start failed")
return types.OnPluginStartStatusFailed
}
if globalOnTickFuncs != nil {
ctx.onTickFuncs = globalOnTickFuncs
if err := proxywasm.SetTickPeriodMilliSeconds(100); err != nil {
ctx.vm.log.Error("SetTickPeriodMilliSeconds failed, onTick functions will not take effect.")
ctx.vm.log.Error("plugin start failed")
return types.OnPluginStartStatusFailed
}
}
ctx.vm.log.Info("plugin start successfully")
return types.OnPluginStartStatusOK
}
func (ctx *CommonPluginCtx[PluginConfig]) OnTick() {
for i := range ctx.onTickFuncs {
currentTimeStamp := time.Now().UnixMilli()
if currentTimeStamp-ctx.onTickFuncs[i].lastExecuted >= ctx.onTickFuncs[i].tickPeriod {
ctx.onTickFuncs[i].tickFunc()
ctx.onTickFuncs[i].lastExecuted = currentTimeStamp
}
}
}
func (ctx *CommonPluginCtx[PluginConfig]) NewHttpContext(contextID uint32) types.HttpContext {
httpCtx := &CommonHttpCtx[PluginConfig]{
plugin: ctx,
contextID: contextID,
userContext: map[string]interface{}{},
userAttribute: map[string]interface{}{},
}
if ctx.vm.onHttpRequestBody != nil || ctx.vm.onHttpStreamingRequestBody != nil {
httpCtx.needRequestBody = true
}
if ctx.vm.onHttpResponseBody != nil || ctx.vm.onHttpStreamingResponseBody != nil {
httpCtx.needResponseBody = true
}
if ctx.vm.onHttpStreamingRequestBody != nil {
httpCtx.streamingRequestBody = true
}
if ctx.vm.onHttpStreamingResponseBody != nil {
httpCtx.streamingResponseBody = true
}
return httpCtx
}
type CommonHttpCtx[PluginConfig any] struct {
types.DefaultHttpContext
plugin *CommonPluginCtx[PluginConfig]
config *PluginConfig
needRequestBody bool
needResponseBody bool
streamingRequestBody bool
streamingResponseBody bool
requestBodySize int
responseBodySize int
contextID uint32
userContext map[string]interface{}
userAttribute map[string]interface{}
}
func (ctx *CommonHttpCtx[PluginConfig]) SetContext(key string, value interface{}) {
ctx.userContext[key] = value
}
func (ctx *CommonHttpCtx[PluginConfig]) GetContext(key string) interface{} {
return ctx.userContext[key]
}
func (ctx *CommonHttpCtx[PluginConfig]) SetUserAttribute(key string, value interface{}) {
ctx.userAttribute[key] = value
}
func (ctx *CommonHttpCtx[PluginConfig]) GetUserAttribute(key string) interface{} {
return ctx.userAttribute[key]
}
func (ctx *CommonHttpCtx[PluginConfig]) SetUserAttributeMap(kvmap map[string]interface{}) {
ctx.userAttribute = kvmap
}
func (ctx *CommonHttpCtx[PluginConfig]) GetUserAttributeMap() map[string]interface{} {
return ctx.userAttribute
}
func (ctx *CommonHttpCtx[PluginConfig]) WriteUserAttributeToLog() error {
return ctx.WriteUserAttributeToLogWithKey(CustomLogKey)
}
func (ctx *CommonHttpCtx[PluginConfig]) WriteUserAttributeToLogWithKey(key string) error {
// e.g. {\"field1\":\"value1\",\"field2\":\"value2\"}
preMarshalledJsonLogStr, _ := proxywasm.GetProperty([]string{key})
newAttributeMap := map[string]interface{}{}
if string(preMarshalledJsonLogStr) != "" {
// e.g. {"field1":"value1","field2":"value2"}
preJsonLogStr := UnmarshalStr(fmt.Sprintf(`"%s"`, string(preMarshalledJsonLogStr)))
err := json.Unmarshal([]byte(preJsonLogStr), &newAttributeMap)
if err != nil {
ctx.plugin.vm.log.Warnf("Unmarshal failed, will overwrite %s, pre value is: %s", key, string(preMarshalledJsonLogStr))
return err
}
}
// update customLog
for k, v := range ctx.userAttribute {
newAttributeMap[k] = v
}
// e.g. {"field1":"value1","field2":2,"field3":"value3"}
jsonStr, _ := json.Marshal(newAttributeMap)
// e.g. {\"field1\":\"value1\",\"field2\":2,\"field3\":\"value3\"}
marshalledJsonStr := MarshalStr(string(jsonStr))
if err := proxywasm.SetProperty([]string{key}, []byte(marshalledJsonStr)); err != nil {
ctx.plugin.vm.log.Warnf("failed to set %s in filter state, raw is %s, err is %v", key, marshalledJsonStr, err)
return err
}
return nil
}
func (ctx *CommonHttpCtx[PluginConfig]) WriteUserAttributeToTrace() error {
for k, v := range ctx.userAttribute {
traceSpanTag := TraceSpanTagPrefix + k
traceSpanValue := fmt.Sprint(v)
var err error
if traceSpanValue != "" {
err = proxywasm.SetProperty([]string{traceSpanTag}, []byte(traceSpanValue))
} else {
err = fmt.Errorf("value of %s is empty", traceSpanTag)
}
if err != nil {
ctx.plugin.vm.log.Warnf("Failed to set trace attribute - %s: %s, error message: %v", traceSpanTag, traceSpanValue, err)
}
}
return nil
}
func (ctx *CommonHttpCtx[PluginConfig]) GetBoolContext(key string, defaultValue bool) bool {
if b, ok := ctx.userContext[key].(bool); ok {
return b
}
return defaultValue
}
func (ctx *CommonHttpCtx[PluginConfig]) GetStringContext(key, defaultValue string) string {
if s, ok := ctx.userContext[key].(string); ok {
return s
}
return defaultValue
}
func (ctx *CommonHttpCtx[PluginConfig]) GetByteSliceContext(key string, defaultValue []byte) []byte {
if s, ok := ctx.userContext[key].([]byte); ok {
return s
}
return defaultValue
}
func (ctx *CommonHttpCtx[PluginConfig]) Scheme() string {
proxywasm.SetEffectiveContext(ctx.contextID)
return GetRequestScheme()
}
func (ctx *CommonHttpCtx[PluginConfig]) Host() string {
proxywasm.SetEffectiveContext(ctx.contextID)
return GetRequestHost()
}
func (ctx *CommonHttpCtx[PluginConfig]) Path() string {
proxywasm.SetEffectiveContext(ctx.contextID)
return GetRequestPath()
}
func (ctx *CommonHttpCtx[PluginConfig]) Method() string {
proxywasm.SetEffectiveContext(ctx.contextID)
return GetRequestMethod()
}
func (ctx *CommonHttpCtx[PluginConfig]) DontReadRequestBody() {
ctx.needRequestBody = false
}
func (ctx *CommonHttpCtx[PluginConfig]) DontReadResponseBody() {
ctx.needResponseBody = false
}
func (ctx *CommonHttpCtx[PluginConfig]) BufferRequestBody() {
ctx.streamingRequestBody = false
}
func (ctx *CommonHttpCtx[PluginConfig]) BufferResponseBody() {
ctx.streamingResponseBody = false
}
func (ctx *CommonHttpCtx[PluginConfig]) DisableReroute() {
_ = proxywasm.SetProperty([]string{"clear_route_cache"}, []byte("off"))
}
func (ctx *CommonHttpCtx[PluginConfig]) SetRequestBodyBufferLimit(size uint32) {
ctx.plugin.vm.log.Infof("SetRequestBodyBufferLimit: %d", size)
_ = proxywasm.SetProperty([]string{"set_decoder_buffer_limit"}, []byte(strconv.Itoa(int(size))))
}
func (ctx *CommonHttpCtx[PluginConfig]) SetResponseBodyBufferLimit(size uint32) {
ctx.plugin.vm.log.Infof("SetResponseBodyBufferLimit: %d", size)
_ = proxywasm.SetProperty([]string{"set_encoder_buffer_limit"}, []byte(strconv.Itoa(int(size))))
}
func (ctx *CommonHttpCtx[PluginConfig]) GetContextId() uint32 {
return ctx.contextID
}
func (ctx *CommonHttpCtx[PluginConfig]) OnHttpRequestHeaders(numHeaders int, endOfStream bool) types.Action {
requestID, _ := proxywasm.GetHttpRequestHeader("x-request-id")
_ = proxywasm.SetProperty([]string{"x_request_id"}, []byte(requestID))
config, err := ctx.plugin.GetMatchConfig()
if err != nil {
ctx.plugin.vm.log.Errorf("get match config failed, err:%v", err)
return types.ActionContinue
}
if config == nil {
return types.ActionContinue
}
ctx.config = config
// To avoid unexpected operations, plugins do not read the binary content body
if IsBinaryRequestBody() {
ctx.needRequestBody = false
}
if ctx.plugin.vm.onHttpRequestHeaders == nil {
return types.ActionContinue
}
return ctx.plugin.vm.onHttpRequestHeaders(ctx, *config)
}
func (ctx *CommonHttpCtx[PluginConfig]) OnHttpRequestBody(bodySize int, endOfStream bool) types.Action {
if ctx.config == nil {
return types.ActionContinue
}
if !ctx.needRequestBody {
return types.ActionContinue
}
if ctx.plugin.vm.onHttpStreamingRequestBody != nil && ctx.streamingRequestBody {
chunk, _ := proxywasm.GetHttpRequestBody(0, bodySize)
modifiedChunk := ctx.plugin.vm.onHttpStreamingRequestBody(ctx, *ctx.config, chunk, endOfStream)
err := proxywasm.ReplaceHttpRequestBody(modifiedChunk)
if err != nil {
ctx.plugin.vm.log.Warnf("replace request body chunk failed: %v", err)
return types.ActionContinue
}
return types.ActionContinue
}
if ctx.plugin.vm.onHttpRequestBody != nil {
ctx.requestBodySize += bodySize
if !endOfStream {
return types.ActionPause
}
body, err := proxywasm.GetHttpRequestBody(0, ctx.requestBodySize)
if err != nil {
ctx.plugin.vm.log.Warnf("get request body failed: %v", err)
return types.ActionContinue
}
return ctx.plugin.vm.onHttpRequestBody(ctx, *ctx.config, body)
}
return types.ActionContinue
}
func (ctx *CommonHttpCtx[PluginConfig]) OnHttpResponseHeaders(numHeaders int, endOfStream bool) types.Action {
if ctx.config == nil {
return types.ActionContinue
}
// To avoid unexpected operations, plugins do not read the binary content body
if IsBinaryResponseBody() {
ctx.needResponseBody = false
}
if ctx.plugin.vm.onHttpResponseHeaders == nil {
return types.ActionContinue
}
return ctx.plugin.vm.onHttpResponseHeaders(ctx, *ctx.config)
}
func (ctx *CommonHttpCtx[PluginConfig]) OnHttpResponseBody(bodySize int, endOfStream bool) types.Action {
if ctx.config == nil {
return types.ActionContinue
}
if !ctx.needResponseBody {
return types.ActionContinue
}
if ctx.plugin.vm.onHttpStreamingResponseBody != nil && ctx.streamingResponseBody {
chunk, _ := proxywasm.GetHttpResponseBody(0, bodySize)
modifiedChunk := ctx.plugin.vm.onHttpStreamingResponseBody(ctx, *ctx.config, chunk, endOfStream)
err := proxywasm.ReplaceHttpResponseBody(modifiedChunk)
if err != nil {
ctx.plugin.vm.log.Warnf("replace response body chunk failed: %v", err)
return types.ActionContinue
}
return types.ActionContinue
}
if ctx.plugin.vm.onHttpResponseBody != nil {
ctx.responseBodySize += bodySize
if !endOfStream {
return types.ActionPause
}
body, err := proxywasm.GetHttpResponseBody(0, ctx.responseBodySize)
if err != nil {
ctx.plugin.vm.log.Warnf("get response body failed: %v", err)
return types.ActionContinue
}
return ctx.plugin.vm.onHttpResponseBody(ctx, *ctx.config, body)
}
return types.ActionContinue
}
func (ctx *CommonHttpCtx[PluginConfig]) OnHttpStreamDone() {
if ctx.config == nil {
return
}
if ctx.plugin.vm.onHttpStreamDone == nil {
return
}
ctx.plugin.vm.onHttpStreamDone(ctx, *ctx.config)
}

View File

@@ -1,902 +0,0 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package wrapper
import (
"bytes"
"encoding/base64"
"errors"
"fmt"
"io"
"github.com/google/uuid"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/tidwall/resp"
)
type RedisResponseCallback func(response resp.Value)
type RedisClient interface {
Init(username, password string, timeout int64, opts ...optionFunc) error
// return whether redis client is ready
Ready() bool
// with this function, you can call redis as if you are using redis-cli
Command(cmds []interface{}, callback RedisResponseCallback) error
Eval(script string, numkeys int, keys, args []interface{}, callback RedisResponseCallback) error
// Key
Del(key string, callback RedisResponseCallback) error
Exists(key string, callback RedisResponseCallback) error
Expire(key string, ttl int, callback RedisResponseCallback) error
Persist(key string, callback RedisResponseCallback) error
// String
Get(key string, callback RedisResponseCallback) error
Set(key string, value interface{}, callback RedisResponseCallback) error
SetEx(key string, value interface{}, ttl int, callback RedisResponseCallback) error
SetNX(key string, value interface{}, ttl int, callback RedisResponseCallback) error
MGet(keys []string, callback RedisResponseCallback) error
MSet(kvMap map[string]interface{}, callback RedisResponseCallback) error
Incr(key string, callback RedisResponseCallback) error
Decr(key string, callback RedisResponseCallback) error
IncrBy(key string, delta int, callback RedisResponseCallback) error
DecrBy(key string, delta int, callback RedisResponseCallback) error
// List
LLen(key string, callback RedisResponseCallback) error
RPush(key string, vals []interface{}, callback RedisResponseCallback) error
RPop(key string, callback RedisResponseCallback) error
LPush(key string, vals []interface{}, callback RedisResponseCallback) error
LPop(key string, callback RedisResponseCallback) error
LIndex(key string, index int, callback RedisResponseCallback) error
LRange(key string, start, stop int, callback RedisResponseCallback) error
LRem(key string, count int, value interface{}, callback RedisResponseCallback) error
LInsertBefore(key string, pivot, value interface{}, callback RedisResponseCallback) error
LInsertAfter(key string, pivot, value interface{}, callback RedisResponseCallback) error
// Hash
HExists(key, field string, callback RedisResponseCallback) error
HDel(key string, fields []string, callback RedisResponseCallback) error
HLen(key string, callback RedisResponseCallback) error
HGet(key, field string, callback RedisResponseCallback) error
HSet(key, field string, value interface{}, callback RedisResponseCallback) error
HMGet(key string, fields []string, callback RedisResponseCallback) error
HMSet(key string, kvMap map[string]interface{}, callback RedisResponseCallback) error
HKeys(key string, callback RedisResponseCallback) error
HVals(key string, callback RedisResponseCallback) error
HGetAll(key string, callback RedisResponseCallback) error
HIncrBy(key, field string, delta int, callback RedisResponseCallback) error
HIncrByFloat(key, field string, delta float64, callback RedisResponseCallback) error
// Set
SCard(key string, callback RedisResponseCallback) error
SAdd(key string, value []interface{}, callback RedisResponseCallback) error
SRem(key string, values []interface{}, callback RedisResponseCallback) error
SIsMember(key string, value interface{}, callback RedisResponseCallback) error
SMembers(key string, callback RedisResponseCallback) error
SDiff(key1, key2 string, callback RedisResponseCallback) error
SDiffStore(destination, key1, key2 string, callback RedisResponseCallback) error
SInter(key1, key2 string, callback RedisResponseCallback) error
SInterStore(destination, key1, key2 string, callback RedisResponseCallback) error
SUnion(key1, key2 string, callback RedisResponseCallback) error
SUnionStore(destination, key1, key2 string, callback RedisResponseCallback) error
// Sorted Set
ZCard(key string, callback RedisResponseCallback) error
ZAdd(key string, msMap map[string]interface{}, callback RedisResponseCallback) error
ZCount(key string, min interface{}, max interface{}, callback RedisResponseCallback) error
ZIncrBy(key string, member string, delta interface{}, callback RedisResponseCallback) error
ZScore(key, member string, callback RedisResponseCallback) error
ZRank(key, member string, callback RedisResponseCallback) error
ZRevRank(key, member string, callback RedisResponseCallback) error
ZRem(key string, members []string, callback RedisResponseCallback) error
ZRange(key string, start, stop int, callback RedisResponseCallback) error
ZRevRange(key string, start, stop int, callback RedisResponseCallback) error
}
type RedisClusterClient[C Cluster] struct {
cluster C
ready bool
checkReadyFunc func() error
option redisOption
}
type redisOption struct {
dataBase int
}
type optionFunc func(*redisOption)
func WithDataBase(dataBase int) optionFunc {
return func(o *redisOption) {
o.dataBase = dataBase
}
}
func NewRedisClusterClient[C Cluster](cluster C) *RedisClusterClient[C] {
return &RedisClusterClient[C]{
cluster: cluster,
checkReadyFunc: func() error {
return errors.New("redis client is not ready, please call Init() first")
},
}
}
func RedisCall(cluster Cluster, respQuery []byte, callback RedisResponseCallback) error {
requestID := uuid.New().String()
_, err := proxywasm.DispatchRedisCall(
cluster.ClusterName(),
respQuery,
func(status int, responseSize int) {
response, err := proxywasm.GetRedisCallResponse(0, responseSize)
var responseValue resp.Value
if status != 0 {
proxywasm.LogCriticalf("Error occurred while calling redis, it seems cannot connect to the redis cluster. request-id: %s", requestID)
responseValue = resp.ErrorValue(fmt.Errorf("cannot connect to redis cluster"))
} else {
if err != nil {
proxywasm.LogCriticalf("failed to get redis response body, request-id: %s, error: %v", requestID, err)
responseValue = resp.ErrorValue(fmt.Errorf("cannot get redis response"))
} else {
rd := resp.NewReader(bytes.NewReader(response))
value, _, err := rd.ReadValue()
if err != nil && err != io.EOF {
proxywasm.LogCriticalf("failed to read redis response body, request-id: %s, error: %v", requestID, err)
responseValue = resp.ErrorValue(fmt.Errorf("cannot read redis response"))
} else {
responseValue = value
proxywasm.LogDebugf("redis call end, request-id: %s, respQuery: %s, respValue: %s",
requestID, base64.StdEncoding.EncodeToString([]byte(respQuery)), base64.StdEncoding.EncodeToString(response))
}
}
}
if callback != nil {
callback(responseValue)
}
})
if err != nil {
proxywasm.LogCriticalf("redis call failed, request-id: %s, error: %v", requestID, err)
} else {
proxywasm.LogDebugf("redis call start, request-id: %s, respQuery: %s", requestID, base64.StdEncoding.EncodeToString([]byte(respQuery)))
}
return err
}
func respString(args []interface{}) []byte {
var buf bytes.Buffer
wr := resp.NewWriter(&buf)
arr := make([]resp.Value, 0)
for _, arg := range args {
arr = append(arr, resp.StringValue(fmt.Sprint(arg)))
}
wr.WriteArray(arr)
return buf.Bytes()
}
func (c *RedisClusterClient[C]) Ready() bool {
return c.ready
}
func (c *RedisClusterClient[C]) Init(username, password string, timeout int64, opts ...optionFunc) error {
for _, opt := range opts {
opt(&c.option)
}
clusterName := c.cluster.ClusterName()
if c.option.dataBase != 0 {
clusterName = fmt.Sprintf("%s?db=%d", clusterName, c.option.dataBase)
}
err := proxywasm.RedisInit(clusterName, username, password, uint32(timeout))
if err != nil {
c.checkReadyFunc = func() error {
if c.ready {
return nil
}
initErr := proxywasm.RedisInit(clusterName, username, password, uint32(timeout))
if initErr != nil {
return initErr
}
c.ready = true
return nil
}
proxywasm.LogWarnf("failed to init redis: %v, will retry after", err)
return nil
}
c.checkReadyFunc = func() error { return nil }
c.ready = true
return nil
}
func (c *RedisClusterClient[C]) Command(cmds []interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
return RedisCall(c.cluster, respString(cmds), callback)
}
func (c *RedisClusterClient[C]) Eval(script string, numkeys int, keys, args []interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
params := make([]interface{}, 0)
params = append(params, "eval")
params = append(params, script)
params = append(params, numkeys)
params = append(params, keys...)
params = append(params, args...)
return RedisCall(c.cluster, respString(params), callback)
}
// Key
func (c *RedisClusterClient[C]) Del(key string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "del")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) Exists(key string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "exists")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) Expire(key string, ttl int, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "expire")
args = append(args, key)
args = append(args, ttl)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) Persist(key string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "persist")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
// String
func (c *RedisClusterClient[C]) Get(key string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "get")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) Set(key string, value interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "set")
args = append(args, key)
args = append(args, value)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) SetEx(key string, value interface{}, ttl int, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "set")
args = append(args, key)
args = append(args, value)
args = append(args, "ex")
args = append(args, ttl)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) SetNX(key string, value interface{}, ttl int, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "set")
args = append(args, key)
args = append(args, value)
args = append(args, "nx")
if ttl > 0 {
args = append(args, "ex")
args = append(args, ttl)
}
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) MGet(keys []string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "mget")
for _, k := range keys {
args = append(args, k)
}
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) MSet(kvMap map[string]interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "mset")
for k, v := range kvMap {
args = append(args, k)
args = append(args, v)
}
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) Incr(key string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "incr")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) Decr(key string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "decr")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) IncrBy(key string, delta int, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "incrby")
args = append(args, key)
args = append(args, delta)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) DecrBy(key string, delta int, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "decrby")
args = append(args, key)
args = append(args, delta)
return RedisCall(c.cluster, respString(args), callback)
}
// List
func (c *RedisClusterClient[C]) LLen(key string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "llen")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) RPush(key string, vals []interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "rpush")
args = append(args, key)
for _, val := range vals {
args = append(args, val)
}
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) RPop(key string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "rpop")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) LPush(key string, vals []interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "lpush")
args = append(args, key)
for _, val := range vals {
args = append(args, val)
}
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) LPop(key string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "lpop")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) LIndex(key string, index int, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "lindex")
args = append(args, key)
args = append(args, index)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) LRange(key string, start, stop int, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "lrange")
args = append(args, key)
args = append(args, start)
args = append(args, stop)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) LRem(key string, count int, value interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "lrem")
args = append(args, key)
args = append(args, count)
args = append(args, value)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) LInsertBefore(key string, pivot, value interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "linsert")
args = append(args, key)
args = append(args, "before")
args = append(args, pivot)
args = append(args, value)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) LInsertAfter(key string, pivot, value interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "linsert")
args = append(args, key)
args = append(args, "after")
args = append(args, pivot)
args = append(args, value)
return RedisCall(c.cluster, respString(args), callback)
}
// Hash
func (c *RedisClusterClient[C]) HExists(key, field string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "hexists")
args = append(args, key)
args = append(args, field)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) HDel(key string, fields []string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "hdel")
args = append(args, key)
for _, field := range fields {
args = append(args, field)
}
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) HLen(key string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "hlen")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) HGet(key, field string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "hget")
args = append(args, key)
args = append(args, field)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) HSet(key, field string, value interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "hset")
args = append(args, key)
args = append(args, field)
args = append(args, value)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) HMGet(key string, fields []string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "hmget")
args = append(args, key)
for _, field := range fields {
args = append(args, field)
}
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) HMSet(key string, kvMap map[string]interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "hmset")
args = append(args, key)
for k, v := range kvMap {
args = append(args, k)
args = append(args, v)
}
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) HKeys(key string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "hkeys")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) HVals(key string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "hvals")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) HGetAll(key string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "hgetall")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) HIncrBy(key, field string, delta int, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "hincrby")
args = append(args, key)
args = append(args, field)
args = append(args, delta)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) HIncrByFloat(key, field string, delta float64, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "hincrbyfloat")
args = append(args, key)
args = append(args, field)
args = append(args, delta)
return RedisCall(c.cluster, respString(args), callback)
}
// Set
func (c *RedisClusterClient[C]) SCard(key string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "scard")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) SAdd(key string, vals []interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "sadd")
args = append(args, key)
for _, val := range vals {
args = append(args, val)
}
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) SRem(key string, vals []interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "srem")
args = append(args, key)
for _, val := range vals {
args = append(args, val)
}
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) SIsMember(key string, value interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "sismember")
args = append(args, key)
args = append(args, value)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) SMembers(key string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "smembers")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) SDiff(key1, key2 string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "sdiff")
args = append(args, key1)
args = append(args, key2)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) SDiffStore(destination, key1, key2 string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "sdiffstore")
args = append(args, destination)
args = append(args, key1)
args = append(args, key2)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) SInter(key1, key2 string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "sinter")
args = append(args, key1)
args = append(args, key2)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) SInterStore(destination, key1, key2 string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "sinterstore")
args = append(args, destination)
args = append(args, key1)
args = append(args, key2)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) SUnion(key1, key2 string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "sunion")
args = append(args, key1)
args = append(args, key2)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) SUnionStore(destination, key1, key2 string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "sunionstore")
args = append(args, destination)
args = append(args, key1)
args = append(args, key2)
return RedisCall(c.cluster, respString(args), callback)
}
// ZSet
func (c *RedisClusterClient[C]) ZCard(key string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "zcard")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) ZAdd(key string, msMap map[string]interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "zadd")
args = append(args, key)
for m, s := range msMap {
args = append(args, s)
args = append(args, m)
}
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) ZCount(key string, min interface{}, max interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "zcount")
args = append(args, key)
args = append(args, min)
args = append(args, max)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) ZIncrBy(key string, member string, delta interface{}, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "zincrby")
args = append(args, key)
args = append(args, delta)
args = append(args, member)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) ZScore(key, member string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "zscore")
args = append(args, key)
args = append(args, member)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) ZRank(key, member string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "zrank")
args = append(args, key)
args = append(args, member)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) ZRevRank(key, member string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "zrevrank")
args = append(args, key)
args = append(args, member)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) ZRem(key string, members []string, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "zrem")
args = append(args, key)
for _, m := range members {
args = append(args, m)
}
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) ZRange(key string, start, stop int, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "zrange")
args = append(args, key)
args = append(args, start)
args = append(args, stop)
return RedisCall(c.cluster, respString(args), callback)
}
func (c *RedisClusterClient[C]) ZRevRange(key string, start, stop int, callback RedisResponseCallback) error {
if err := c.checkReadyFunc(); err != nil {
return err
}
args := make([]interface{}, 0)
args = append(args, "zrevrange")
args = append(args, key)
args = append(args, start)
args = append(args, stop)
return RedisCall(c.cluster, respString(args), callback)
}

View File

@@ -1,116 +0,0 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package wrapper
import (
"net/url"
"strconv"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)
func GetRequestScheme() string {
scheme, err := proxywasm.GetHttpRequestHeader(":scheme")
if err != nil {
proxywasm.LogErrorf("get request scheme failed: %v", err)
return ""
}
return scheme
}
func GetRequestHost() string {
host, err := proxywasm.GetHttpRequestHeader(":authority")
if err != nil {
proxywasm.LogErrorf("get request host failed: %v", err)
return ""
}
return host
}
func GetRequestPath() string {
path, err := proxywasm.GetHttpRequestHeader(":path")
if err != nil {
proxywasm.LogErrorf("get request path failed: %v", err)
return ""
}
return path
}
func GetRequestPathWithoutQuery() string {
rawPath := GetRequestPath()
if rawPath == "" {
return ""
}
path, err := url.Parse(rawPath)
if err != nil {
proxywasm.LogErrorf("failed to parse request path '%s': %v", rawPath, err)
return ""
}
return path.Path
}
func GetRequestMethod() string {
method, err := proxywasm.GetHttpRequestHeader(":method")
if err != nil {
proxywasm.LogErrorf("get request method failed: %v", err)
return ""
}
return method
}
func IsBinaryRequestBody() bool {
contentType, _ := proxywasm.GetHttpRequestHeader("content-type")
if strings.Contains(contentType, "octet-stream") ||
strings.Contains(contentType, "grpc") {
return true
}
encoding, _ := proxywasm.GetHttpRequestHeader("content-encoding")
if encoding != "" {
return true
}
return false
}
func IsBinaryResponseBody() bool {
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
if strings.Contains(contentType, "octet-stream") ||
strings.Contains(contentType, "grpc") {
return true
}
encoding, _ := proxywasm.GetHttpResponseHeader("content-encoding")
if encoding != "" {
return true
}
return false
}
func HasRequestBody() bool {
contentTypeStr, _ := proxywasm.GetHttpRequestHeader("content-type")
contentLengthStr, _ := proxywasm.GetHttpRequestHeader("content-length")
transferEncodingStr, _ := proxywasm.GetHttpRequestHeader("transfer-encoding")
proxywasm.LogDebugf("check has request body: contentType:%s, contentLengthStr:%s, transferEncodingStr:%s",
contentTypeStr, contentLengthStr, transferEncodingStr)
if contentTypeStr != "" {
return true
}
if contentLengthStr != "" {
contentLength, err := strconv.Atoi(contentLengthStr)
if err == nil && contentLength > 0 {
return true
}
}
return strings.Contains(transferEncodingStr, "chunked")
}

View File

@@ -1,36 +0,0 @@
package wrapper
import (
"encoding/json"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/tidwall/gjson"
)
func UnmarshalStr(marshalledJsonStr string) string {
// e.g. "{\"field1\":\"value1\",\"field2\":\"value2\"}"
var jsonStr string
err := json.Unmarshal([]byte(marshalledJsonStr), &jsonStr)
if err != nil {
proxywasm.LogErrorf("failed to unmarshal json string, raw string is: %s, err is: %v", marshalledJsonStr, err)
return ""
}
// e.g. {"field1":"value1","field2":"value2"}
return jsonStr
}
func MarshalStr(raw string) string {
// e.g. {"field1":"value1","field2":"value2"}
helper := map[string]string{
"placeholder": raw,
}
marshalledHelper, _ := json.Marshal(helper)
marshalledRaw := gjson.GetBytes(marshalledHelper, "placeholder").Raw
if len(marshalledRaw) >= 2 {
// e.g. {\"field1\":\"value1\",\"field2\":\"value2\"}
return marshalledRaw[1 : len(marshalledRaw)-1]
} else {
proxywasm.LogErrorf("failed to marshal json string, raw string is: %s", raw)
return ""
}
}