test(wasm-go/mcp): expand unit test coverage for mcp-server framework (#3871)

Signed-off-by: jingze <daijingze.djz@alibaba-inc.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Jingze
2026-06-15 20:29:20 +08:00
committed by GitHub
parent c69526b30e
commit bf0b1e96c5
10 changed files with 3426 additions and 3 deletions

View File

@@ -2764,3 +2764,849 @@ data: {"jsonrpc":"2.0","id":2,"result":{"content":[{"type":"text","text":"Secure
})
})
}
// -----------------------------------------------------------------------------
// Phase 2.1 — REST Server Call() branch matrix
//
// Each sub-test stands up a minimal REST MCP config whose single tool exercises
// a specific branch of rest_server.go:Call (~lines 523-946). The plugin runs the
// real tool, makes one ctx.RouteCall, and we mock the backend reply via
// CallOnHttpResponseHeaders/Body. We inspect the upstream request via
// GetRequestHeaders/Body and the MCP-shaped response via GetResponseBody.
// -----------------------------------------------------------------------------
func TestRestMCPServer_CallBranches(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// -------------------------------------------------------------------
// argsToFormBody → upstream body is form-urlencoded; Content-Type set
// -------------------------------------------------------------------
t.Run("argsToFormBody encodes body as form-urlencoded", func(t *testing.T) {
cfg, _ := json.Marshal(map[string]interface{}{
"server": map[string]interface{}{"name": "rest-form", "type": "rest"},
"tools": []map[string]interface{}{{
"name": "submit",
"description": "form submit",
"args": []map[string]interface{}{
{"name": "user", "description": "u", "type": "string", "required": true},
{"name": "msg", "description": "m", "type": "string"},
},
"requestTemplate": map[string]interface{}{
"url": "http://backend.example/form",
"method": "POST",
"argsToFormBody": true,
},
}},
})
host, status := test.NewTestHost(cfg)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"},
{"content-type", "application/json"},
})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"submit","arguments":{"user":"alice 1","msg":"hi&there"}}}`))
upstreamHeaders := host.GetRequestHeaders()
upstreamBody := host.GetRequestBody()
ctHeader, has := test.GetHeaderValue(upstreamHeaders, "Content-Type")
require.True(t, has)
require.Contains(t, ctHeader, "application/x-www-form-urlencoded")
// `&` must be percent-encoded, space encoded as `+`.
require.Contains(t, string(upstreamBody), "user=alice+1")
require.Contains(t, string(upstreamBody), "msg=hi%26there")
host.CompleteHttp()
})
// -------------------------------------------------------------------
// argsToUrlParam → default args merged into URL query
// -------------------------------------------------------------------
t.Run("argsToUrlParam merges args into query", func(t *testing.T) {
cfg, _ := json.Marshal(map[string]interface{}{
"server": map[string]interface{}{"name": "rest-qp", "type": "rest"},
"tools": []map[string]interface{}{{
"name": "search",
"description": "search",
"args": []map[string]interface{}{
{"name": "q", "description": "q", "type": "string", "required": true},
{"name": "limit", "description": "lim", "type": "integer"},
},
"requestTemplate": map[string]interface{}{
"url": "http://backend.example/search?pre=set",
"method": "GET",
"argsToUrlParam": true,
},
}},
})
host, status := test.NewTestHost(cfg)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"},
{"content-type", "application/json"},
})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"search","arguments":{"q":"hello","limit":10}}}`))
upstreamHeaders := host.GetRequestHeaders()
pathVal, has := test.GetHeaderValue(upstreamHeaders, ":path")
require.True(t, has)
require.Contains(t, pathVal, "pre=set")
require.Contains(t, pathVal, "q=hello")
require.Contains(t, pathVal, "limit=10")
host.CompleteHttp()
})
// -------------------------------------------------------------------
// Direct-response tool → no backend call, response template fires
// -------------------------------------------------------------------
t.Run("direct response tool emits template result", func(t *testing.T) {
cfg, _ := json.Marshal(map[string]interface{}{
"server": map[string]interface{}{"name": "rest-dr", "type": "rest"},
"tools": []map[string]interface{}{{
"name": "ping",
"description": "static ping",
"args": []map[string]interface{}{
{"name": "name", "description": "n", "type": "string", "required": true},
},
// No requestTemplate.url → direct-response mode.
"responseTemplate": map[string]interface{}{
"body": "hello {{.args.name}}",
},
}},
})
host, status := test.NewTestHost(cfg)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"},
{"content-type", "application/json"},
})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"ping","arguments":{"name":"world"}}}`))
// Direct-response writes via SendLocalResponse, not the streaming response body.
localResp := host.GetLocalResponse()
require.NotNil(t, localResp, "direct-response must emit a local response with no backend call")
require.Contains(t, string(localResp.Data), "hello world")
host.CompleteHttp()
})
// -------------------------------------------------------------------
// Image content-type → SendMCPToolImageResult path (base64 in response)
// -------------------------------------------------------------------
t.Run("image content-type produces image MCP result", func(t *testing.T) {
cfg, _ := json.Marshal(map[string]interface{}{
"server": map[string]interface{}{"name": "rest-img", "type": "rest"},
"tools": []map[string]interface{}{{
"name": "get_image",
"description": "image fetch",
"args": []map[string]interface{}{{"name": "id", "description": "id", "type": "string"}},
"requestTemplate": map[string]interface{}{
"url": "http://backend.example/img/{{.args.id}}", "method": "GET",
},
}},
})
host, status := test.NewTestHost(cfg)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"get_image","arguments":{"id":"42"}}}`))
pngBytes := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
host.CallOnHttpResponseHeaders([][2]string{{":status", "200"}, {"Content-Type", "image/png"}})
host.CallOnHttpResponseBody(pngBytes)
respBody := host.GetResponseBody()
require.NotEmpty(t, respBody)
require.Contains(t, string(respBody), `"type":"image"`)
require.Contains(t, string(respBody), `"mimeType":"image/png"`)
host.CompleteHttp()
})
// -------------------------------------------------------------------
// outputSchema + JSON backend → structuredContent populated
// -------------------------------------------------------------------
t.Run("outputSchema with JSON body emits structuredContent", func(t *testing.T) {
cfg, _ := json.Marshal(map[string]interface{}{
"server": map[string]interface{}{"name": "rest-os", "type": "rest"},
"tools": []map[string]interface{}{{
"name": "info",
"description": "info",
"args": []map[string]interface{}{},
"outputSchema": map[string]interface{}{"type": "object"},
"requestTemplate": map[string]interface{}{
"url": "http://backend.example/info", "method": "GET",
},
}},
})
host, status := test.NewTestHost(cfg)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"info","arguments":{}}}`))
host.CallOnHttpResponseHeaders([][2]string{{":status", "200"}, {"Content-Type", "application/json"}})
host.CallOnHttpResponseBody([]byte(`{"a":1,"b":"two"}`))
respBody := host.GetResponseBody()
require.NotEmpty(t, respBody)
require.Contains(t, string(respBody), "structuredContent", "outputSchema must produce structuredContent field")
host.CompleteHttp()
})
// -------------------------------------------------------------------
// errorResponseTemplate fires on 4xx/5xx, _headers is accessible
// -------------------------------------------------------------------
t.Run("errorResponseTemplate renders on backend error", func(t *testing.T) {
cfg, _ := json.Marshal(map[string]interface{}{
"server": map[string]interface{}{"name": "rest-err", "type": "rest"},
"tools": []map[string]interface{}{{
"name": "fail",
"description": "fail",
"args": []map[string]interface{}{},
"errorResponseTemplate": `upstream said: {{gjson "_headers.:status"}} - {{gjson "message"}}`,
"requestTemplate": map[string]interface{}{
"url": "http://backend.example/fail", "method": "GET",
},
}},
})
host, status := test.NewTestHost(cfg)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"fail","arguments":{}}}`))
host.CallOnHttpResponseHeaders([][2]string{{":status", "500"}, {"Content-Type", "application/json"}})
host.CallOnHttpResponseBody([]byte(`{"message":"boom"}`))
respBody := host.GetResponseBody()
require.NotEmpty(t, respBody)
require.Contains(t, string(respBody), "upstream said: 500 - boom")
require.Contains(t, string(respBody), `"isError":true`)
host.CompleteHttp()
})
// -------------------------------------------------------------------
// prependBody + appendBody wrap raw response
// -------------------------------------------------------------------
t.Run("prependBody/appendBody wrap raw response", func(t *testing.T) {
cfg, _ := json.Marshal(map[string]interface{}{
"server": map[string]interface{}{"name": "rest-wrap", "type": "rest"},
"tools": []map[string]interface{}{{
"name": "say",
"description": "say",
"args": []map[string]interface{}{},
"requestTemplate": map[string]interface{}{
"url": "http://backend.example/say", "method": "GET",
},
"responseTemplate": map[string]interface{}{
"prependBody": "<<",
"appendBody": ">>",
},
}},
})
host, status := test.NewTestHost(cfg)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"say","arguments":{}}}`))
host.CallOnHttpResponseHeaders([][2]string{{":status", "200"}, {"Content-Type", "text/plain"}})
host.CallOnHttpResponseBody([]byte("hi"))
respBody := host.GetResponseBody()
require.NotEmpty(t, respBody)
// Unmarshal to dodge JSON unicode-escape encoding of < and >.
var parsed map[string]interface{}
require.NoError(t, json.Unmarshal(respBody, &parsed))
result := parsed["result"].(map[string]interface{})
content := result["content"].([]interface{})
text := content[0].(map[string]interface{})["text"].(string)
require.Equal(t, "<<hi>>", text)
host.CompleteHttp()
})
// -------------------------------------------------------------------
// path arg with reserved chars → substituted into URL template
// -------------------------------------------------------------------
t.Run("path arg substitutes into URL placeholder", func(t *testing.T) {
cfg, _ := json.Marshal(map[string]interface{}{
"server": map[string]interface{}{"name": "rest-path", "type": "rest"},
"tools": []map[string]interface{}{{
"name": "get_user",
"description": "by id",
"args": []map[string]interface{}{
{"name": "id", "description": "uid", "type": "string", "required": true, "position": "path"},
},
"requestTemplate": map[string]interface{}{
"url": "http://backend.example/users/{id}",
"method": "GET",
},
}},
})
host, status := test.NewTestHost(cfg)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"get_user","arguments":{"id":"alice%20smith"}}}`))
upstreamHeaders := host.GetRequestHeaders()
pathVal, has := test.GetHeaderValue(upstreamHeaders, ":path")
require.True(t, has)
// Path arg is substituted as-is into the URL template before being parsed.
require.Contains(t, pathVal, "/users/alice")
host.CompleteHttp()
})
// -------------------------------------------------------------------
// header arg from args list lands as upstream header
// -------------------------------------------------------------------
t.Run("header-position arg becomes upstream header", func(t *testing.T) {
cfg, _ := json.Marshal(map[string]interface{}{
"server": map[string]interface{}{"name": "rest-hdr", "type": "rest"},
"tools": []map[string]interface{}{{
"name": "auth_call",
"description": "with header",
"args": []map[string]interface{}{
{"name": "X-Trace", "description": "trace", "type": "string", "required": true, "position": "header"},
},
"requestTemplate": map[string]interface{}{
"url": "http://backend.example/protected",
"method": "GET",
},
}},
})
host, status := test.NewTestHost(cfg)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"auth_call","arguments":{"X-Trace":"abc-123"}}}`))
upstreamHeaders := host.GetRequestHeaders()
require.True(t, test.HasHeaderWithValue(upstreamHeaders, "X-Trace", "abc-123"))
host.CompleteHttp()
})
})
}
// -----------------------------------------------------------------------------
// Phase 2.2 — SSE state machine error / edge paths
//
// Drive sse_proxy.go's handleSSEStreamingResponse + handleWaitingEndpoint /
// handleWaitingInitResp / handleWaitingToolResp through error branches that
// the happy-path TestMcpProxyServerSSE* tests don't cover.
// -----------------------------------------------------------------------------
func TestMcpProxyServerSSE_NonSSEContentTypeRejected(t *testing.T) {
// Backend returns text/plain instead of text/event-stream → first-chunk
// content-type validation must inject a JSON-RPC error.
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(mcpProxyServerSSEConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"},
{"content-type", "application/json"},
})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`))
// First-chunk: backend returned a non-SSE content-type.
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "application/json"},
})
host.CallOnHttpStreamingResponseBody([]byte(`{"not":"sse"}`), true)
respBody := host.GetResponseBody()
require.NotEmpty(t, respBody)
require.Contains(t, string(respBody), "error", "rejected non-SSE response must surface a JSON-RPC error")
host.CompleteHttp()
})
}
func TestMcpProxyServerSSE_CharsetSuffixAccepted(t *testing.T) {
// content-type: text/event-stream;charset=utf-8 must still be accepted
// (substring match on text/event-stream).
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(mcpProxyServerSSEConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"},
{"content-type", "application/json"},
})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`))
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "text/event-stream; charset=utf-8"},
})
// Endpoint event triggers initialize — we don't drive past this,
// just confirm content-type-suffix wasn't rejected (no error in body).
host.CallOnHttpStreamingResponseBody([]byte("event: endpoint\ndata: /sse/abc\n\n"), false)
// No injected error yet — buffer is just being consumed.
// Local response should NOT have been set with an error.
localResp := host.GetLocalResponse()
if localResp != nil {
require.NotContains(t, string(localResp.Data), "invalid content-type",
"text/event-stream with charset suffix must NOT be rejected")
}
host.CompleteHttp()
})
}
func TestMcpProxyServerSSE_EndpointSkipsUnrelatedEvents(t *testing.T) {
// Send a `ping` event before `endpoint` — the state machine must skip
// non-endpoint messages while in WaitingEndpoint and not error out.
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(mcpProxyServerSSEConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`))
host.CallOnHttpResponseHeaders([][2]string{{":status", "200"}, {"content-type", "text/event-stream"}})
// ping first — must be ignored.
host.CallOnHttpStreamingResponseBody([]byte("event: ping\ndata: keep-alive\n\n"), false)
// Then the real endpoint event.
host.CallOnHttpStreamingResponseBody([]byte("event: endpoint\ndata: /sse/session-xyz\n\n"), false)
// Verify the initialize request was sent upstream (proxy moved past WaitingEndpoint).
callouts := host.GetHttpCalloutAttributes()
require.NotEmpty(t, callouts, "endpoint event should trigger an initialize HTTP callout")
var sawInit bool
for _, c := range callouts {
if strings.Contains(string(c.Body), `"method":"initialize"`) {
sawInit = true
break
}
}
require.True(t, sawInit, "an initialize JSON-RPC call must have been sent after endpoint event")
host.CompleteHttp()
})
}
// -----------------------------------------------------------------------------
// Phase 2.3 — mcp-proxy Initialize callback error paths (proxy_tool.go:107-204)
// -----------------------------------------------------------------------------
func TestMcpProxyServer_InitializeBackend500(t *testing.T) {
// Initialize HTTP callback receives non-2xx status → must inject error response.
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(mcpProxyServerConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`))
// Backend returns 500 on the initialize call.
host.CallOnHttpCall([][2]string{{":status", "500"}, {"content-type", "application/json"}}, []byte(`{"error":"down"}`))
// Errors go through SendLocalResponse, not the streaming body.
localResp := host.GetLocalResponse()
require.NotNil(t, localResp, "non-200 initialize response must inject a local error response")
require.Contains(t, string(localResp.Data), "error")
host.CompleteHttp()
})
}
func TestMcpProxyServer_InitializeMalformedJSON(t *testing.T) {
// Initialize response body is not valid JSON → parse error path.
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(mcpProxyServerConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`))
host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "application/json"}}, []byte(`{not valid json`))
localResp := host.GetLocalResponse()
require.NotNil(t, localResp, "unparseable initialize response must inject a local error response")
require.Contains(t, string(localResp.Data), "error")
host.CompleteHttp()
})
}
func TestMcpProxyServer_InitializeSSEContentType(t *testing.T) {
// Initialize response carries text/event-stream → parseSSEResponse path.
// We send a valid SSE-wrapped JSON-RPC success body so the proxy unwraps and
// progresses past initialize. End-to-end completion is not the point — we
// just need to drive the SSE-unwrap branch.
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(mcpProxyServerConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`))
// SSE-wrapped initialize success response — proxy must extract the data line.
sseInit := "event: message\ndata: " + `{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2025-03-26","capabilities":{"tools":{}},"serverInfo":{"name":"X","version":"1"}}}` + "\n\n"
host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "text/event-stream"}, {"mcp-session-id", "sse-session"}}, []byte(sseInit))
// If SSE-unwrap worked, the proxy will have moved to sending the initialized notification.
// Look for a 2nd outbound callout — its presence confirms initialize succeeded.
callouts := host.GetHttpCalloutAttributes()
var sawNotification bool
for _, c := range callouts {
if strings.Contains(string(c.Body), "notifications/initialized") {
sawNotification = true
break
}
}
require.True(t, sawNotification, "SSE-wrapped initialize response must be unwrapped so notification fires")
host.CompleteHttp()
})
}
func TestMcpProxyServer_InitializeUnknownErrorCode(t *testing.T) {
// Initialize returns a JSON-RPC error with a code OTHER than -32602
// → generic "backend initialization failed" path.
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(mcpProxyServerConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`))
// JSON-RPC error with code != -32602.
host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "application/json"}},
[]byte(`{"jsonrpc":"2.0","id":1,"error":{"code":-32603,"message":"internal"}}`))
localResp := host.GetLocalResponse()
require.NotNil(t, localResp, "unknown error code must inject a local error response")
require.Contains(t, string(localResp.Data), "error")
host.CompleteHttp()
})
}
// -----------------------------------------------------------------------------
// Phase 2.4 — plugin.go HOST entry edge cases (onHttpRequestHeaders)
// -----------------------------------------------------------------------------
func TestPlugin_GetMethodRejected(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(restMCPServerConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "x"}, {":method", "GET"}, {":path", "/mcp"},
})
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
localResp := host.GetLocalResponse()
require.NotNil(t, localResp, "GET must produce a local 405 response")
require.Equal(t, uint32(405), localResp.StatusCode)
host.CompleteHttp()
})
}
func TestPlugin_DeleteMethodRejected(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(restMCPServerConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "x"}, {":method", "DELETE"}, {":path", "/mcp"},
})
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
localResp := host.GetLocalResponse()
require.NotNil(t, localResp, "DELETE must produce a local 405 response")
require.Equal(t, uint32(405), localResp.StatusCode)
host.CompleteHttp()
})
}
// Note: the wasm-go test harness always reports a request body as present, so
// the "POST with no body → 400" branch is not exercisable here. The dedicated
// pure-test for ctx.HasRequestBody() handling lives in pkg/mcp/server tests.
func TestPlugin_McpProtocolVersionHeaderStripped(t *testing.T) {
// MCP-Protocol-Version is parsed and removed; request continues normally.
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(restMCPServerConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"},
{"content-type", "application/json"},
{"MCP-Protocol-Version", "2025-03-26"},
})
// Valid version: header is consumed, request continues.
require.Equal(t, types.HeaderStopIteration, action)
// The header should have been removed before forwarding.
_, stillPresent := test.GetHeaderValue(host.GetRequestHeaders(), "MCP-Protocol-Version")
require.False(t, stillPresent, "MCP-Protocol-Version header must be stripped from forwarded request")
host.CompleteHttp()
})
}
func TestPlugin_McpProtocolVersionUnsupportedStillStripped(t *testing.T) {
// Unsupported version logs a warning but still strips the header and continues.
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(restMCPServerConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"},
{"content-type", "application/json"},
{"MCP-Protocol-Version", "9999-99-99"},
})
require.Equal(t, types.HeaderStopIteration, action)
_, stillPresent := test.GetHeaderValue(host.GetRequestHeaders(), "MCP-Protocol-Version")
require.False(t, stillPresent, "even unsupported MCP-Protocol-Version must be stripped")
host.CompleteHttp()
})
}
func TestMcpProxyServerSSE_PartialChunkBuffered(t *testing.T) {
// SSE event arrives split across two chunks — first chunk has no terminator,
// second chunk completes the message. Proxy must wait, not error.
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(mcpProxyServerSSEConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`))
host.CallOnHttpResponseHeaders([][2]string{{":status", "200"}, {"content-type", "text/event-stream"}})
// First chunk: prefix only, no blank line terminator.
host.CallOnHttpStreamingResponseBody([]byte("event: endpoint\ndata: /sse/sess"), false)
// No callouts yet — message not yet complete.
require.Empty(t, host.GetHttpCalloutAttributes(),
"incomplete chunk must not trigger any upstream calls")
// Second chunk: completes data + blank line.
host.CallOnHttpStreamingResponseBody([]byte("ion-split\n\n"), false)
// Now initialize should have been sent.
callouts := host.GetHttpCalloutAttributes()
require.NotEmpty(t, callouts, "complete endpoint event must trigger initialize")
host.CompleteHttp()
})
}
// -----------------------------------------------------------------------------
// Phase 2.5 — ExtractAndRemoveIncomingCredential (HOST end-to-end)
//
// The pure function lives in pkg/mcp/server/auth_utils.go and is HOST-coupled
// because it calls proxywasm.GetHttpRequestHeader / RemoveHttpRequestHeader.
// We exercise it via the mcp-proxy tools/list path, which is the cleanest entry
// that hits ExtractAndRemoveIncomingCredential before any upstream call.
//
// Server shape: downstreamSecurity with Passthrough=true → the credential
// extracted from the *incoming* request is reused on the *upstream* request.
// Verifying the upstream callout headers proves both the extraction and the
// removal worked.
// -----------------------------------------------------------------------------
// mcpProxyPassthroughApiKeyConfig — downstream apiKey/header passthrough to
// upstream apiKey/header. The two schemes have different header names so we
// can independently assert (a) downstream header was removed and (b) upstream
// header carries the passthrough value.
var mcpProxyPassthroughApiKeyConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"server": map[string]interface{}{
"name": "proxy-passthrough-apikey",
"type": "mcp-proxy",
"transport": "http",
"mcpServerURL": "http://backend-mcp.example.com/mcp",
"timeout": 5000,
"defaultDownstreamSecurity": map[string]interface{}{
"id": "ClientKey",
"passthrough": true,
},
"defaultUpstreamSecurity": map[string]interface{}{
"id": "BackendKey",
},
"securitySchemes": []map[string]interface{}{
{
"id": "ClientKey",
"type": "apiKey",
"in": "header",
"name": "X-Client-Key",
},
{
"id": "BackendKey",
"type": "apiKey",
"in": "header",
"name": "X-Backend-Key",
"defaultCredential": "fallback-default",
},
},
},
"tools": []map[string]interface{}{
{"name": "noop", "type": "mcp-proxy", "description": "noop"},
},
})
return data
}()
func TestExtractAndRemoveIncomingCredential_ApiKeyHeaderPassthrough(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(mcpProxyPassthroughApiKeyConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "mcp.example.com"},
{":method", "POST"},
{":path", "/mcp"},
{"content-type", "application/json"},
{"X-Client-Key", "secret-from-client"},
})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`))
callouts := host.GetHttpCalloutAttributes()
require.NotEmpty(t, callouts, "tools/list must trigger an initialize callout")
init := callouts[0]
// Passthrough credential rides on the upstream scheme's header name.
require.True(t, test.HasHeaderWithValue(init.Headers, "X-Backend-Key", "secret-from-client"),
"upstream initialize must carry the extracted client credential under the upstream header name")
// The downstream header itself must NOT leak through to the upstream call.
if v, present := test.GetHeaderValue(init.Headers, "X-Client-Key"); present {
t.Errorf("downstream credential header X-Client-Key must be removed; got %q", v)
}
host.CompleteHttp()
})
}
// mcpProxyPassthroughBearerConfig — downstream http/bearer passthrough to
// upstream http/bearer. Stripping `Bearer ` from the incoming Authorization
// is part of ExtractAndRemoveIncomingCredential's contract.
var mcpProxyPassthroughBearerConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"server": map[string]interface{}{
"name": "proxy-passthrough-bearer",
"type": "mcp-proxy",
"transport": "http",
"mcpServerURL": "http://backend-mcp.example.com/mcp",
"timeout": 5000,
"defaultDownstreamSecurity": map[string]interface{}{
"id": "ClientBearer",
"passthrough": true,
},
"defaultUpstreamSecurity": map[string]interface{}{
"id": "BackendBearer",
},
"securitySchemes": []map[string]interface{}{
{"id": "ClientBearer", "type": "http", "scheme": "bearer"},
{"id": "BackendBearer", "type": "http", "scheme": "bearer", "defaultCredential": "default-token"},
},
},
"tools": []map[string]interface{}{
{"name": "noop", "type": "mcp-proxy", "description": "noop"},
},
})
return data
}()
func TestExtractAndRemoveIncomingCredential_HttpBearerPassthrough(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(mcpProxyPassthroughBearerConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "mcp.example.com"},
{":method", "POST"},
{":path", "/mcp"},
{"content-type", "application/json"},
{"Authorization", "Bearer client-token-xyz"},
})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`))
callouts := host.GetHttpCalloutAttributes()
require.NotEmpty(t, callouts, "tools/list must trigger an initialize callout")
init := callouts[0]
// Upstream Authorization must use the EXTRACTED token (Bearer prefix
// stripped on the way in, re-applied on the way out by ApplySecurity).
authValue, present := test.GetHeaderValue(init.Headers, "Authorization")
require.True(t, present, "upstream must carry Authorization for upstream bearer scheme")
require.Equal(t, "Bearer client-token-xyz", authValue,
"passthrough token must round-trip as `Bearer <token>` to upstream")
host.CompleteHttp()
})
}
// Missing downstream header — ExtractAndRemoveIncomingCredential returns ""
// (no error). With Passthrough=true but no incoming credential, passthrough
// is skipped and the upstream scheme falls back to its DefaultCredential.
func TestExtractAndRemoveIncomingCredential_MissingHeaderFallsBackToDefault(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
host, status := test.NewTestHost(mcpProxyPassthroughApiKeyConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
host.InitHttp()
// Note: no X-Client-Key header on the way in.
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "mcp.example.com"},
{":method", "POST"},
{":path", "/mcp"},
{"content-type", "application/json"},
})
host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`))
callouts := host.GetHttpCalloutAttributes()
require.NotEmpty(t, callouts, "tools/list must trigger an initialize callout")
init := callouts[0]
// Missing client credential → not an error, just fall through to default.
require.True(t, test.HasHeaderWithValue(init.Headers, "X-Backend-Key", "fallback-default"),
"missing client credential must NOT cause an error; upstream falls back to DefaultCredential")
host.CompleteHttp()
})
}

View File

@@ -10,6 +10,8 @@ require (
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
google.golang.org/protobuf v1.36.6
gopkg.in/yaml.v3 v3.0.1
)
require (
@@ -28,12 +30,9 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
github.com/spf13/cast v1.7.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
golang.org/x/crypto v0.26.0 // indirect
google.golang.org/protobuf v1.36.6 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -0,0 +1,73 @@
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
github.com/Masterminds/semver/v3 v3.3.0 h1:B8LGeaivUe71a5qox1ICM/JLl0NqZSW5CHyL+hmvYS0=
github.com/Masterminds/semver/v3 v3.3.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM=
github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe3tPhs=
github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b h1:rRI9+ThQbe+nw4jUiYEyOFaREkXCMMW9k1X2gy2d6pE=
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 h1:sUuUXZwr50l3W1St7MESlFmxmUAu+QUNNfJXx4P6bas=
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s=
github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ=
github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,432 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"encoding/base64"
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// stubProvider implements SecuritySchemeProvider for ApplySecurity tests.
type stubProvider struct {
schemes map[string]SecurityScheme
}
func (p *stubProvider) GetSecurityScheme(id string) (SecurityScheme, bool) {
s, ok := p.schemes[id]
return s, ok
}
func newProvider(schemes ...SecurityScheme) *stubProvider {
m := make(map[string]SecurityScheme, len(schemes))
for _, s := range schemes {
m[s.ID] = s
}
return &stubProvider{schemes: m}
}
// mustParseURL helps build the ParsedURL field of AuthRequestContext.
func mustParseURL(t *testing.T, raw string) *url.URL {
t.Helper()
u, err := url.Parse(raw)
require.NoError(t, err)
return u
}
func findHeader(headers [][2]string, key string) (string, bool) {
for _, kv := range headers {
if strings.EqualFold(kv[0], key) {
return kv[1], true
}
}
return "", false
}
func countHeader(headers [][2]string, key string) int {
c := 0
for _, kv := range headers {
if strings.EqualFold(kv[0], key) {
c++
}
}
return c
}
// -----------------------------------------------------------------------------
// setOrReplaceHeader
// -----------------------------------------------------------------------------
func TestSetOrReplaceHeader_AppendsWhenAbsent(t *testing.T) {
headers := [][2]string{{"X-Other", "1"}}
setOrReplaceHeader(&headers, "X-New", "v")
require.Len(t, headers, 2)
v, ok := findHeader(headers, "X-New")
require.True(t, ok)
assert.Equal(t, "v", v)
}
func TestSetOrReplaceHeader_ReplacesCaseInsensitively(t *testing.T) {
headers := [][2]string{
{"Content-Type", "text/plain"},
{"AUTHORIZATION", "old"},
}
setOrReplaceHeader(&headers, "authorization", "new")
v, ok := findHeader(headers, "Authorization")
require.True(t, ok)
assert.Equal(t, "new", v)
// Replacement is in-place, no duplicate header inserted.
assert.Equal(t, 1, countHeader(headers, "Authorization"))
assert.Len(t, headers, 2)
}
func TestSetOrReplaceHeader_PreservesOriginalKeyOnReplace(t *testing.T) {
headers := [][2]string{{"X-Token", "old"}}
setOrReplaceHeader(&headers, "x-token", "new")
// Replacement updates value but keeps the original key casing.
assert.Equal(t, [][2]string{{"X-Token", "new"}}, headers)
}
func TestSetOrReplaceHeader_FirstMatchWins(t *testing.T) {
headers := [][2]string{
{"X-Dup", "first"},
{"x-dup", "second"},
}
setOrReplaceHeader(&headers, "X-Dup", "new")
// Only the first occurrence is replaced; the second is left alone.
assert.Equal(t, [][2]string{{"X-Dup", "new"}, {"x-dup", "second"}}, headers)
}
func TestSetOrReplaceHeader_IdempotentOnSecondCall(t *testing.T) {
headers := [][2]string{}
setOrReplaceHeader(&headers, "X-K", "v")
setOrReplaceHeader(&headers, "X-K", "v")
require.Len(t, headers, 1)
assert.Equal(t, "v", headers[0][1])
}
// -----------------------------------------------------------------------------
// ApplySecurity — early returns / preconditions
// -----------------------------------------------------------------------------
func TestApplySecurity_EmptyIDIsNoOp(t *testing.T) {
reqCtx := &AuthRequestContext{
Headers: [][2]string{{"X-Other", "x"}},
ParsedURL: mustParseURL(t, "/p?a=1"),
}
err := ApplySecurity(SecurityRequirement{}, newProvider(), reqCtx)
require.NoError(t, err)
assert.Equal(t, [][2]string{{"X-Other", "x"}}, reqCtx.Headers)
assert.Equal(t, "a=1", reqCtx.ParsedURL.RawQuery)
}
func TestApplySecurity_NilParsedURLReturnsError(t *testing.T) {
reqCtx := &AuthRequestContext{}
err := ApplySecurity(
SecurityRequirement{ID: "x"},
newProvider(SecurityScheme{ID: "x", Type: "apiKey", In: "header", Name: "X"}),
reqCtx,
)
require.Error(t, err)
assert.Contains(t, err.Error(), "ParsedURL")
}
func TestApplySecurity_SchemeIDNotFound(t *testing.T) {
reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")}
err := ApplySecurity(SecurityRequirement{ID: "missing"}, newProvider(), reqCtx)
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
}
// -----------------------------------------------------------------------------
// ApplySecurity — apiKey × {header, query}
// -----------------------------------------------------------------------------
func TestApplySecurity_ApiKey_Header_DefaultCredential(t *testing.T) {
reqCtx := &AuthRequestContext{
Headers: [][2]string{{"X-Other", "x"}},
ParsedURL: mustParseURL(t, "/p"),
}
err := ApplySecurity(
SecurityRequirement{ID: "K"},
newProvider(SecurityScheme{
ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key",
DefaultCredential: "def",
}),
reqCtx,
)
require.NoError(t, err)
v, ok := findHeader(reqCtx.Headers, "X-Api-Key")
require.True(t, ok)
assert.Equal(t, "def", v)
}
func TestApplySecurity_ApiKey_Header_ExplicitOverridesDefault(t *testing.T) {
reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")}
err := ApplySecurity(
SecurityRequirement{ID: "K", Credential: "override"},
newProvider(SecurityScheme{
ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key",
DefaultCredential: "def",
}),
reqCtx,
)
require.NoError(t, err)
v, _ := findHeader(reqCtx.Headers, "X-Api-Key")
assert.Equal(t, "override", v)
}
func TestApplySecurity_ApiKey_Header_PassthroughBeatsExplicitAndDefault(t *testing.T) {
reqCtx := &AuthRequestContext{
ParsedURL: mustParseURL(t, "/p"),
PassthroughCredential: "from-client",
}
err := ApplySecurity(
SecurityRequirement{ID: "K", Credential: "configured"},
newProvider(SecurityScheme{
ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key",
DefaultCredential: "def",
}),
reqCtx,
)
require.NoError(t, err)
v, _ := findHeader(reqCtx.Headers, "X-Api-Key")
assert.Equal(t, "from-client", v, "passthrough wins over configured + default")
}
func TestApplySecurity_ApiKey_Header_ReplacesExisting(t *testing.T) {
reqCtx := &AuthRequestContext{
Headers: [][2]string{{"x-api-key", "stale"}},
ParsedURL: mustParseURL(t, "/p"),
}
err := ApplySecurity(
SecurityRequirement{ID: "K", Credential: "fresh"},
newProvider(SecurityScheme{
ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key",
}),
reqCtx,
)
require.NoError(t, err)
// Case-insensitive replace, no duplicate header.
assert.Equal(t, 1, countHeader(reqCtx.Headers, "X-Api-Key"))
v, _ := findHeader(reqCtx.Headers, "X-Api-Key")
assert.Equal(t, "fresh", v)
}
func TestApplySecurity_ApiKey_NoCredentialAvailable(t *testing.T) {
reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")}
err := ApplySecurity(
SecurityRequirement{ID: "K"},
newProvider(SecurityScheme{ID: "K", Type: "apiKey", In: "header", Name: "X"}),
reqCtx,
)
require.Error(t, err)
assert.Contains(t, err.Error(), "no credential")
}
func TestApplySecurity_ApiKey_Header_MissingName(t *testing.T) {
reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")}
err := ApplySecurity(
SecurityRequirement{ID: "K", Credential: "v"},
newProvider(SecurityScheme{ID: "K", Type: "apiKey", In: "header", Name: ""}),
reqCtx,
)
require.Error(t, err)
assert.Contains(t, err.Error(), "name")
}
func TestApplySecurity_ApiKey_Query_AppendsToExistingQuery(t *testing.T) {
reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p?existing=1")}
err := ApplySecurity(
SecurityRequirement{ID: "K", Credential: "secret"},
newProvider(SecurityScheme{ID: "K", Type: "apiKey", In: "query", Name: "api_key"}),
reqCtx,
)
require.NoError(t, err)
q := reqCtx.ParsedURL.Query()
assert.Equal(t, "secret", q.Get("api_key"))
assert.Equal(t, "1", q.Get("existing"), "existing query params must be preserved")
}
func TestApplySecurity_ApiKey_Query_NoExistingQuery(t *testing.T) {
reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")}
err := ApplySecurity(
SecurityRequirement{ID: "K", Credential: "secret"},
newProvider(SecurityScheme{ID: "K", Type: "apiKey", In: "query", Name: "api_key"}),
reqCtx,
)
require.NoError(t, err)
assert.Equal(t, "api_key=secret", reqCtx.ParsedURL.RawQuery)
}
func TestApplySecurity_ApiKey_Query_OverwritesExistingValue(t *testing.T) {
reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p?api_key=stale&keep=me")}
err := ApplySecurity(
SecurityRequirement{ID: "K", Credential: "fresh"},
newProvider(SecurityScheme{ID: "K", Type: "apiKey", In: "query", Name: "api_key"}),
reqCtx,
)
require.NoError(t, err)
q := reqCtx.ParsedURL.Query()
assert.Equal(t, "fresh", q.Get("api_key"))
assert.Equal(t, "me", q.Get("keep"))
// Sanity: no duplicate api_key entries.
assert.Len(t, q["api_key"], 1)
}
func TestApplySecurity_ApiKey_Query_MissingName(t *testing.T) {
reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")}
err := ApplySecurity(
SecurityRequirement{ID: "K", Credential: "v"},
newProvider(SecurityScheme{ID: "K", Type: "apiKey", In: "query", Name: ""}),
reqCtx,
)
require.Error(t, err)
assert.Contains(t, err.Error(), "name")
}
func TestApplySecurity_ApiKey_UnsupportedIn(t *testing.T) {
reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")}
err := ApplySecurity(
SecurityRequirement{ID: "K", Credential: "v"},
newProvider(SecurityScheme{ID: "K", Type: "apiKey", In: "cookie", Name: "X"}),
reqCtx,
)
require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported apiKey")
}
// -----------------------------------------------------------------------------
// ApplySecurity — http × {bearer, basic}
// -----------------------------------------------------------------------------
func TestApplySecurity_HttpBearer_AddsPrefix(t *testing.T) {
reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")}
err := ApplySecurity(
SecurityRequirement{ID: "B", Credential: "raw-token"},
newProvider(SecurityScheme{ID: "B", Type: "http", Scheme: "bearer"}),
reqCtx,
)
require.NoError(t, err)
v, _ := findHeader(reqCtx.Headers, "Authorization")
assert.Equal(t, "Bearer raw-token", v)
}
func TestApplySecurity_HttpBearer_RespectsExistingPrefix(t *testing.T) {
reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")}
err := ApplySecurity(
SecurityRequirement{ID: "B", Credential: "Bearer already-prefixed"},
newProvider(SecurityScheme{ID: "B", Type: "http", Scheme: "bearer"}),
reqCtx,
)
require.NoError(t, err)
v, _ := findHeader(reqCtx.Headers, "Authorization")
assert.Equal(t, "Bearer already-prefixed", v, "must not double-prefix")
}
func TestApplySecurity_HttpBasic_UserPassEncoded(t *testing.T) {
reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")}
err := ApplySecurity(
SecurityRequirement{ID: "B", Credential: "alice:s3cret"},
newProvider(SecurityScheme{ID: "B", Type: "http", Scheme: "basic"}),
reqCtx,
)
require.NoError(t, err)
v, _ := findHeader(reqCtx.Headers, "Authorization")
expected := "Basic " + base64.StdEncoding.EncodeToString([]byte("alice:s3cret"))
assert.Equal(t, expected, v)
}
func TestApplySecurity_HttpBasic_PreEncodedToken(t *testing.T) {
// No colon → treated as already-base64 token.
reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")}
err := ApplySecurity(
SecurityRequirement{ID: "B", Credential: "QWxpY2U6czNjcmV0"},
newProvider(SecurityScheme{ID: "B", Type: "http", Scheme: "basic"}),
reqCtx,
)
require.NoError(t, err)
v, _ := findHeader(reqCtx.Headers, "Authorization")
assert.Equal(t, "Basic QWxpY2U6czNjcmV0", v)
}
func TestApplySecurity_HttpBasic_RespectsExistingPrefix(t *testing.T) {
reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")}
err := ApplySecurity(
SecurityRequirement{ID: "B", Credential: "Basic ZXhpc3Rpbmc="},
newProvider(SecurityScheme{ID: "B", Type: "http", Scheme: "basic"}),
reqCtx,
)
require.NoError(t, err)
v, _ := findHeader(reqCtx.Headers, "Authorization")
assert.Equal(t, "Basic ZXhpc3Rpbmc=", v, "must not re-encode already-prefixed value")
}
func TestApplySecurity_HttpBasic_PassthroughTreatedAsTokenPart(t *testing.T) {
reqCtx := &AuthRequestContext{
ParsedURL: mustParseURL(t, "/p"),
PassthroughCredential: "QWxpY2U6czNjcmV0", // base64-encoded "alice:s3cret"
}
err := ApplySecurity(
SecurityRequirement{ID: "B"},
newProvider(SecurityScheme{ID: "B", Type: "http", Scheme: "basic"}),
reqCtx,
)
require.NoError(t, err)
v, _ := findHeader(reqCtx.Headers, "Authorization")
// Passthrough path must NOT re-base64-encode; only adds the prefix.
assert.Equal(t, "Basic QWxpY2U6czNjcmV0", v)
}
func TestApplySecurity_HttpBearer_Passthrough(t *testing.T) {
reqCtx := &AuthRequestContext{
ParsedURL: mustParseURL(t, "/p"),
PassthroughCredential: "client-token",
}
err := ApplySecurity(
SecurityRequirement{ID: "B"},
newProvider(SecurityScheme{ID: "B", Type: "http", Scheme: "bearer"}),
reqCtx,
)
require.NoError(t, err)
v, _ := findHeader(reqCtx.Headers, "Authorization")
assert.Equal(t, "Bearer client-token", v)
}
func TestApplySecurity_HttpUnsupportedScheme(t *testing.T) {
reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")}
err := ApplySecurity(
SecurityRequirement{ID: "B", Credential: "x"},
newProvider(SecurityScheme{ID: "B", Type: "http", Scheme: "digest"}),
reqCtx,
)
require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported http scheme")
}
func TestApplySecurity_UnsupportedSchemeType(t *testing.T) {
reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")}
err := ApplySecurity(
SecurityRequirement{ID: "B", Credential: "x"},
newProvider(SecurityScheme{ID: "B", Type: "oauth2"}),
reqCtx,
)
require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported security scheme type")
}

View File

@@ -0,0 +1,251 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// stubTool is a minimal Tool implementation for registry population in tests.
type stubTool struct {
desc string
input map[string]any
output map[string]any
}
func (s *stubTool) Create(_ []byte) Tool { return s }
func (s *stubTool) Call(_ HttpContext, _ Server) error { return nil }
func (s *stubTool) Description() string { return s.desc }
func (s *stubTool) InputSchema() map[string]any { return s.input }
func (s *stubTool) OutputSchema() map[string]any { return s.output }
func newPopulatedRegistry(t *testing.T) *GlobalToolRegistry {
t.Helper()
r := &GlobalToolRegistry{}
r.Initialize()
r.RegisterTool("alpha", "search", &stubTool{
desc: "alpha search",
input: map[string]any{"type": "object", "props": "a"},
output: map[string]any{"type": "string"},
})
r.RegisterTool("alpha", "fetch", &stubTool{
desc: "alpha fetch",
input: map[string]any{"type": "object", "props": "f"},
output: nil,
})
r.RegisterTool("beta", "search", &stubTool{
desc: "beta search",
input: map[string]any{"type": "object", "props": "bs"},
output: map[string]any{"type": "array"},
})
return r
}
func TestComposedMCPServer_NewAndGetName(t *testing.T) {
r := newPopulatedRegistry(t)
cs := NewComposedMCPServer("myset", []ServerToolConfig{
{ServerName: "alpha", Tools: []string{"search"}},
}, r)
require.NotNil(t, cs)
assert.Equal(t, "myset", cs.GetName())
}
func TestComposedMCPServer_AddMCPTool_IsNoOp(t *testing.T) {
r := newPopulatedRegistry(t)
cs := NewComposedMCPServer("set", []ServerToolConfig{
{ServerName: "alpha", Tools: []string{"search"}},
}, r)
// AddMCPTool should not panic and should be a no-op (tool not added).
ret := cs.AddMCPTool("ignored", &stubTool{desc: "x"})
assert.Same(t, cs, ret, "AddMCPTool should return the server itself")
tools := cs.GetMCPTools()
_, exists := tools["ignored"]
assert.False(t, exists, "no-op AddMCPTool must not register the tool")
// Only the one from registry should remain.
_, found := tools["alpha___search"]
assert.True(t, found, "registered tool should be present")
assert.Len(t, tools, 1)
}
func TestComposedMCPServer_GetMCPTools_AggregatesWithPrefix(t *testing.T) {
r := newPopulatedRegistry(t)
cs := NewComposedMCPServer("compound", []ServerToolConfig{
{ServerName: "alpha", Tools: []string{"search", "fetch"}},
{ServerName: "beta", Tools: []string{"search"}},
}, r)
tools := cs.GetMCPTools()
require.Len(t, tools, 3)
// All keys must be prefixed with the original server name and the splitter.
want := []string{"alpha___search", "alpha___fetch", "beta___search"}
for _, k := range want {
_, ok := tools[k]
assert.True(t, ok, "expected composed tool key %q", k)
}
// Descriptions / input schemas are forwarded from the registry's ToolInfo.
dt, ok := tools["alpha___search"].(*DescriptiveTool)
require.True(t, ok)
assert.Equal(t, "alpha search", dt.Description())
assert.Equal(t, "a", dt.InputSchema()["props"])
assert.Equal(t, "string", dt.OutputSchema()["type"])
// Tool without OutputSchema in registry produces a DescriptiveTool with nil output.
dt2, ok := tools["alpha___fetch"].(*DescriptiveTool)
require.True(t, ok)
assert.Nil(t, dt2.OutputSchema())
}
func TestComposedMCPServer_GetMCPTools_MissingToolIsSkipped(t *testing.T) {
r := newPopulatedRegistry(t)
cs := NewComposedMCPServer("set", []ServerToolConfig{
{ServerName: "alpha", Tools: []string{"search", "nonexistent"}},
{ServerName: "ghost", Tools: []string{"any"}}, // entire server missing
}, r)
tools := cs.GetMCPTools()
// Only "alpha___search" survives; missing ones are logged and skipped.
assert.Len(t, tools, 1)
_, ok := tools["alpha___search"]
assert.True(t, ok)
}
func TestComposedMCPServer_GetMCPTools_SameSimpleNameDifferentServersDoNotCollide(t *testing.T) {
r := newPopulatedRegistry(t)
cs := NewComposedMCPServer("set", []ServerToolConfig{
{ServerName: "alpha", Tools: []string{"search"}},
{ServerName: "beta", Tools: []string{"search"}},
}, r)
tools := cs.GetMCPTools()
require.Len(t, tools, 2)
assert.Contains(t, tools, "alpha___search")
assert.Contains(t, tools, "beta___search")
}
func TestComposedMCPServer_GetMCPTools_EmptyConfig(t *testing.T) {
r := newPopulatedRegistry(t)
cs := NewComposedMCPServer("empty", nil, r)
tools := cs.GetMCPTools()
assert.NotNil(t, tools, "should return a non-nil empty map")
assert.Empty(t, tools)
}
func TestComposedMCPServer_SetGetConfig_BytePointer(t *testing.T) {
r := newPopulatedRegistry(t)
cs := NewComposedMCPServer("set", nil, r)
// Empty config: GetConfig must not modify the destination.
var dst []byte
dst = []byte("untouched")
cs.GetConfig(&dst)
assert.Equal(t, []byte("untouched"), dst, "GetConfig on empty config must be a no-op")
// After SetConfig, byte-pointer destinations receive the stored bytes.
cs.SetConfig([]byte(`{"k":"v"}`))
var out []byte
cs.GetConfig(&out)
assert.Equal(t, []byte(`{"k":"v"}`), out)
}
func TestComposedMCPServer_GetConfig_UnhandledDestinationType(t *testing.T) {
r := newPopulatedRegistry(t)
cs := NewComposedMCPServer("set", nil, r)
cs.SetConfig([]byte(`{"k":"v"}`))
// Non-byte-pointer destinations are logged and left untouched (no panic).
var s string = "untouched"
cs.GetConfig(&s)
assert.Equal(t, "untouched", s)
type holder struct{ K string }
h := holder{K: "untouched"}
cs.GetConfig(&h)
assert.Equal(t, "untouched", h.K)
}
func TestComposedMCPServer_Clone_IndependentConfig(t *testing.T) {
r := newPopulatedRegistry(t)
cs := NewComposedMCPServer("orig", []ServerToolConfig{
{ServerName: "alpha", Tools: []string{"search"}},
}, r)
cs.SetConfig([]byte(`{"a":1}`))
clonedI := cs.Clone()
require.NotNil(t, clonedI)
cloned, ok := clonedI.(*ComposedMCPServer)
require.True(t, ok)
assert.NotSame(t, cs, cloned, "Clone must return a new struct pointer")
assert.Equal(t, cs.GetName(), cloned.GetName())
// Confirm both see the same config initially.
var origBytes, clonedBytes []byte
cs.GetConfig(&origBytes)
cloned.GetConfig(&clonedBytes)
assert.Equal(t, origBytes, clonedBytes)
// Mutating clone's config must not propagate to original.
cloned.SetConfig([]byte(`{"a":2}`))
cs.GetConfig(&origBytes)
assert.Equal(t, []byte(`{"a":1}`), origBytes, "original config must remain unchanged after cloning")
// Cloned still resolves tools through the shared registry.
assert.Contains(t, cloned.GetMCPTools(), "alpha___search")
}
func TestDescriptiveTool_Create_ReturnsNewInstanceWithSameFields(t *testing.T) {
dt := &DescriptiveTool{
description: "d",
inputSchema: map[string]any{"k": "v"},
outputSchema: map[string]any{"o": "w"},
}
created := dt.Create([]byte(`{"ignored":true}`))
require.NotNil(t, created)
cdt, ok := created.(*DescriptiveTool)
require.True(t, ok)
assert.NotSame(t, dt, cdt, "Create must return a new instance")
assert.Equal(t, dt.Description(), cdt.Description())
assert.Equal(t, dt.InputSchema(), cdt.InputSchema())
assert.Equal(t, dt.OutputSchema(), cdt.OutputSchema())
}
func TestDescriptiveTool_Call_ReturnsError(t *testing.T) {
dt := &DescriptiveTool{description: "d"}
err := dt.Call(nil, nil)
require.Error(t, err, "DescriptiveTool.Call is a guard rail — must return an error")
}
func TestDescriptiveTool_Accessors(t *testing.T) {
dt := &DescriptiveTool{
description: "desc",
inputSchema: map[string]any{"in": 1},
outputSchema: map[string]any{"out": 2},
}
assert.Equal(t, "desc", dt.Description())
assert.Equal(t, map[string]any{"in": 1}, dt.InputSchema())
assert.Equal(t, map[string]any{"out": 2}, dt.OutputSchema())
// Nil schemas must round-trip as nil.
empty := &DescriptiveTool{}
assert.Equal(t, "", empty.Description())
assert.Nil(t, empty.InputSchema())
assert.Nil(t, empty.OutputSchema())
}

View File

@@ -0,0 +1,620 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// -----------------------------------------------------------------------------
// validateURL
// -----------------------------------------------------------------------------
func TestValidateURL(t *testing.T) {
cases := []struct {
name string
in string
wantErr string // empty = expect no error
}{
{"empty string", "", "cannot be empty"},
{"path only", "/api/foo", ""},
{"http with host", "http://backend.example/mcp", ""},
{"https with host", "https://backend.example/mcp", ""},
{"http with userinfo", "http://user:pass@backend.example/mcp", ""},
{"http with port", "http://backend.example:8080/mcp", ""},
{"scheme without host", "http://", "must include a host"},
{"unsupported scheme ftp", "ftp://example/x", "unsupported URL scheme"},
{"unsupported scheme ws", "ws://example/x", "unsupported URL scheme"},
{"contains space - parse error", "http://exa mple/x", "invalid URL format"},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
err := validateURL(c.in)
if c.wantErr == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Contains(t, err.Error(), c.wantErr)
}
})
}
}
// -----------------------------------------------------------------------------
// computeEffectiveAllowToolsFromHeader (4-way matrix + edge cases)
// -----------------------------------------------------------------------------
func TestComputeEffectiveAllowToolsFromHeader_BothAbsent(t *testing.T) {
got := computeEffectiveAllowToolsFromHeader(nil, "", false)
assert.Nil(t, got, "no restrictions on either side → allow all (nil)")
}
func TestComputeEffectiveAllowToolsFromHeader_HeaderOnly(t *testing.T) {
got := computeEffectiveAllowToolsFromHeader(nil, "a,b,c", true)
require.NotNil(t, got)
assert.Len(t, *got, 3)
_, hasA := (*got)["a"]
_, hasB := (*got)["b"]
_, hasC := (*got)["c"]
assert.True(t, hasA && hasB && hasC)
}
func TestComputeEffectiveAllowToolsFromHeader_ConfigOnly(t *testing.T) {
cfg := map[string]struct{}{"x": {}, "y": {}}
got := computeEffectiveAllowToolsFromHeader(&cfg, "", false)
require.NotNil(t, got)
assert.Equal(t, &cfg, got, "config restrictions returned as-is when header absent")
}
func TestComputeEffectiveAllowToolsFromHeader_BothPresent_Intersection(t *testing.T) {
cfg := map[string]struct{}{"a": {}, "b": {}, "c": {}}
got := computeEffectiveAllowToolsFromHeader(&cfg, "b,c,d", true)
require.NotNil(t, got)
assert.Len(t, *got, 2)
_, hasB := (*got)["b"]
_, hasC := (*got)["c"]
_, hasD := (*got)["d"]
_, hasA := (*got)["a"]
assert.True(t, hasB && hasC, "intersection keeps common entries")
assert.False(t, hasD, "header-only entries are dropped")
assert.False(t, hasA, "config-only entries are dropped")
}
func TestComputeEffectiveAllowToolsFromHeader_HeaderEmptyStringButPresent(t *testing.T) {
// headerExists=true with empty string → produces empty map (deny all)
cfg := map[string]struct{}{"x": {}}
got := computeEffectiveAllowToolsFromHeader(&cfg, "", true)
require.NotNil(t, got)
assert.Empty(t, *got, "empty header with headerExists=true intersects to empty set")
}
func TestComputeEffectiveAllowToolsFromHeader_HeaderWhitespaceAndDuplicates(t *testing.T) {
got := computeEffectiveAllowToolsFromHeader(nil, " a , b , ,a, c ,", true)
require.NotNil(t, got)
// "a", "b", "c" — duplicates and empties dropped
assert.Len(t, *got, 3)
for _, k := range []string{"a", "b", "c"} {
_, ok := (*got)[k]
assert.True(t, ok, "missing %q", k)
}
}
// -----------------------------------------------------------------------------
// McpServerConfig accessors
// -----------------------------------------------------------------------------
func TestMcpServerConfig_GetServerName(t *testing.T) {
c := &McpServerConfig{serverName: "my-server"}
assert.Equal(t, "my-server", c.GetServerName())
}
func TestMcpServerConfig_GetIsComposed(t *testing.T) {
c1 := &McpServerConfig{isComposed: false}
c2 := &McpServerConfig{isComposed: true}
assert.False(t, c1.GetIsComposed())
assert.True(t, c2.GetIsComposed())
}
// -----------------------------------------------------------------------------
// GlobalToolRegistry — extra branches not covered elsewhere
// -----------------------------------------------------------------------------
// stubToolWithOutputSchema exercises the ToolWithOutputSchema dispatch in
// GlobalToolRegistry.RegisterTool.
type stubToolWithOutputSchema struct {
desc string
input map[string]any
output map[string]any
}
func (s *stubToolWithOutputSchema) Create(_ []byte) Tool { return s }
func (s *stubToolWithOutputSchema) Call(_ HttpContext, _ Server) error { return nil }
func (s *stubToolWithOutputSchema) Description() string { return s.desc }
func (s *stubToolWithOutputSchema) InputSchema() map[string]any { return s.input }
func (s *stubToolWithOutputSchema) OutputSchema() map[string]any { return s.output }
func TestGlobalToolRegistry_RegisterTool_CapturesOutputSchema(t *testing.T) {
r := &GlobalToolRegistry{}
r.Initialize()
r.RegisterTool("srv", "tool", &stubToolWithOutputSchema{
desc: "d",
input: map[string]any{"in": 1},
output: map[string]any{"out": 2},
})
info, ok := r.GetToolInfo("srv", "tool")
require.True(t, ok)
assert.Equal(t, "d", info.Description)
assert.Equal(t, map[string]any{"in": 1}, info.InputSchema)
assert.Equal(t, map[string]any{"out": 2}, info.OutputSchema, "OutputSchema must be captured when tool implements ToolWithOutputSchema")
}
func TestGlobalToolRegistry_RegisterTool_PlainToolHasNoOutputSchema(t *testing.T) {
r := &GlobalToolRegistry{}
r.Initialize()
r.RegisterTool("srv", "tool", &stubTool{desc: "d", input: map[string]any{"in": 1}})
info, ok := r.GetToolInfo("srv", "tool")
require.True(t, ok)
assert.Nil(t, info.OutputSchema)
}
func TestGlobalToolRegistry_GetToolInfo_Misses(t *testing.T) {
r := &GlobalToolRegistry{}
r.Initialize()
_, ok := r.GetToolInfo("missing", "any")
assert.False(t, ok, "unknown server → not found")
r.RegisterTool("srv", "real", &stubTool{desc: "d"})
_, ok = r.GetToolInfo("srv", "missing")
assert.False(t, ok, "unknown tool on known server → not found")
}
// -----------------------------------------------------------------------------
// AddMCPServer / addMCPServerOption.Apply
// -----------------------------------------------------------------------------
func TestAddMCPServer_FirstAndSecondAreStored(t *testing.T) {
ctx := &Context{}
AddMCPServer("alpha", &stubServer{}).Apply(ctx)
AddMCPServer("beta", &stubServer{}).Apply(ctx)
require.Len(t, ctx.servers, 2)
_, hasA := ctx.servers["alpha"]
_, hasB := ctx.servers["beta"]
assert.True(t, hasA && hasB)
}
func TestAddMCPServer_DuplicateNamePanics(t *testing.T) {
ctx := &Context{}
AddMCPServer("dup", &stubServer{}).Apply(ctx)
assert.PanicsWithValue(t,
"Conflict! There is a mcp server with the same name:dup",
func() { AddMCPServer("dup", &stubServer{}).Apply(ctx) })
}
// stubServer satisfies the Server interface for AddMCPServer tests.
type stubServer struct {
cfg []byte
}
func (s *stubServer) AddMCPTool(_ string, _ Tool) Server { return s }
func (s *stubServer) GetMCPTools() map[string]Tool { return map[string]Tool{} }
func (s *stubServer) SetConfig(c []byte) { s.cfg = c }
func (s *stubServer) GetConfig(_ any) {}
func (s *stubServer) Clone() Server { copy := *s; return &copy }
// -----------------------------------------------------------------------------
// ToInputSchema — exercises jsonschema.Reflect dispatch
// -----------------------------------------------------------------------------
type sampleStruct struct {
Name string `json:"name"`
Count int `json:"count"`
Tags []string `json:"tags"`
}
func TestToInputSchema_StructByValue(t *testing.T) {
out := ToInputSchema(sampleStruct{})
require.NotNil(t, out, "must return a populated schema map")
// Reflected schema always has a top-level "properties" map.
props, ok := out["properties"].(map[string]any)
require.True(t, ok, "schema should have a properties object: %v", out)
for _, k := range []string{"name", "count", "tags"} {
_, exists := props[k]
assert.True(t, exists, "missing property %q", k)
}
}
func TestToInputSchema_StructByPointer(t *testing.T) {
// The function dereferences pointer types before name lookup.
out := ToInputSchema(&sampleStruct{})
require.NotNil(t, out)
_, ok := out["properties"]
assert.True(t, ok, "pointer-to-struct should resolve to the same schema")
}
// -----------------------------------------------------------------------------
// setupMcpProxyServer — error paths
// -----------------------------------------------------------------------------
func mustGJSON(t *testing.T, raw string) gjson.Result {
t.Helper()
r := gjson.Parse(raw)
require.True(t, r.Exists(), "raw must parse: %s", raw)
return r
}
func TestSetupMcpProxyServer_MissingTransport(t *testing.T) {
j := mustGJSON(t, `{"name":"s","type":"mcp-proxy","mcpServerURL":"http://b"}`)
_, err := setupMcpProxyServer("s", j, "")
require.Error(t, err)
assert.Contains(t, err.Error(), "transport")
}
func TestSetupMcpProxyServer_InvalidTransport(t *testing.T) {
j := mustGJSON(t, `{"transport":"grpc","mcpServerURL":"http://b"}`)
_, err := setupMcpProxyServer("s", j, "")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid transport value")
}
func TestSetupMcpProxyServer_MissingMcpServerURL(t *testing.T) {
j := mustGJSON(t, `{"transport":"http"}`)
_, err := setupMcpProxyServer("s", j, "")
require.Error(t, err)
assert.Contains(t, err.Error(), "mcpServerURL is required")
}
func TestSetupMcpProxyServer_InvalidMcpServerURL(t *testing.T) {
j := mustGJSON(t, `{"transport":"http","mcpServerURL":"ws://nope"}`)
_, err := setupMcpProxyServer("s", j, "")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid mcpServerURL")
}
func TestSetupMcpProxyServer_BadSecuritySchemeJson(t *testing.T) {
j := mustGJSON(t, `{
"transport":"http",
"mcpServerURL":"http://b",
"securitySchemes":[123]
}`)
_, err := setupMcpProxyServer("s", j, "")
require.Error(t, err)
assert.Contains(t, err.Error(), "security scheme")
}
func TestSetupMcpProxyServer_BadDefaultDownstreamSecurity(t *testing.T) {
j := mustGJSON(t, `{
"transport":"http",
"mcpServerURL":"http://b",
"defaultDownstreamSecurity": 42
}`)
_, err := setupMcpProxyServer("s", j, "")
require.Error(t, err)
assert.Contains(t, err.Error(), "defaultDownstreamSecurity")
}
func TestSetupMcpProxyServer_BadDefaultUpstreamSecurity(t *testing.T) {
j := mustGJSON(t, `{
"transport":"http",
"mcpServerURL":"http://b",
"defaultUpstreamSecurity": "not-an-object"
}`)
_, err := setupMcpProxyServer("s", j, "")
require.Error(t, err)
assert.Contains(t, err.Error(), "defaultUpstreamSecurity")
}
func TestSetupMcpProxyServer_HappyPath_AppliesAllFields(t *testing.T) {
raw := `{
"transport":"sse",
"mcpServerURL":"https://backend.example/mcp",
"timeout":7777,
"passthroughAuthHeader":true,
"securitySchemes":[
{"id":"K","type":"apiKey","in":"header","name":"X-K","defaultCredential":"d"}
],
"defaultDownstreamSecurity":{"id":"K"},
"defaultUpstreamSecurity":{"id":"K"}
}`
srv, err := setupMcpProxyServer("alpha", mustGJSON(t, raw), `{"cfg":1}`)
require.NoError(t, err)
require.NotNil(t, srv)
assert.Equal(t, "alpha", srv.Name)
assert.Equal(t, TransportSSE, srv.GetTransport())
assert.Equal(t, "https://backend.example/mcp", srv.GetMcpServerURL())
assert.Equal(t, 7777, srv.GetTimeout())
assert.True(t, srv.GetPassthroughAuthHeader())
_, ok := srv.GetSecurityScheme("K")
assert.True(t, ok)
assert.Equal(t, "K", srv.GetDefaultDownstreamSecurity().ID)
assert.Equal(t, "K", srv.GetDefaultUpstreamSecurity().ID)
}
// -----------------------------------------------------------------------------
// parseConfigCore — error / branch paths via ParseConfigCore
// -----------------------------------------------------------------------------
func newValidationOpts() *ConfigOptions {
r := &GlobalToolRegistry{}
r.Initialize()
return &ConfigOptions{
Servers: make(map[string]Server),
ToolRegistry: r,
SkipPreRegisteredServers: false,
}
}
func TestParseConfigCore_NoServerOrToolSet(t *testing.T) {
c := &McpServerConfig{}
err := ParseConfigCore(gjson.Parse(`{}`), c, newValidationOpts())
require.Error(t, err)
assert.Contains(t, err.Error(), "'server' or 'toolSet'")
}
func TestParseConfigCore_SingleServer_MissingName(t *testing.T) {
c := &McpServerConfig{}
err := ParseConfigCore(gjson.Parse(`{"server":{"type":"rest"}}`), c, newValidationOpts())
require.Error(t, err)
assert.Contains(t, err.Error(), "server.name")
}
func TestParseConfigCore_PreRegisteredNotInRegistry(t *testing.T) {
// type=="" defaults to "rest", but with no tools and no entry in opts.Servers,
// falls into the "pre-registered" branch which fails with "not registered".
c := &McpServerConfig{}
err := ParseConfigCore(gjson.Parse(`{"server":{"name":"ghost"}}`), c, newValidationOpts())
require.Error(t, err)
assert.Contains(t, err.Error(), "not registered")
}
func TestParseConfigCore_PreRegisteredSkipped(t *testing.T) {
c := &McpServerConfig{}
opts := newValidationOpts()
opts.SkipPreRegisteredServers = true
err := ParseConfigCore(gjson.Parse(`{"server":{"name":"ghost"}}`), c, opts)
require.NoError(t, err, "skip flag should bypass the not-registered error")
assert.Equal(t, "ghost", c.GetServerName())
assert.Nil(t, c.server, "no server instance is constructed in skip mode")
}
func TestParseConfigCore_PreRegisteredFound(t *testing.T) {
c := &McpServerConfig{}
opts := newValidationOpts()
opts.Servers["found"] = &stubServer{}
err := ParseConfigCore(gjson.Parse(`{"server":{"name":"found"}}`), c, opts)
require.NoError(t, err)
require.NotNil(t, c.server, "Clone() of pre-registered server should be stored")
}
func TestParseConfigCore_McpProxy_BubblesUpSetupError(t *testing.T) {
c := &McpServerConfig{}
err := ParseConfigCore(gjson.Parse(`{
"server":{"name":"p","type":"mcp-proxy","transport":"http"}
}`), c, newValidationOpts())
require.Error(t, err)
assert.Contains(t, err.Error(), "mcpServerURL")
}
func TestParseConfigCore_McpProxy_HappyPath_NoTools(t *testing.T) {
c := &McpServerConfig{}
err := ParseConfigCore(gjson.Parse(`{
"server":{
"name":"p",
"type":"mcp-proxy",
"transport":"http",
"mcpServerURL":"http://b"
}
}`), c, newValidationOpts())
require.NoError(t, err)
require.NotNil(t, c.server)
assert.Equal(t, "p", c.GetServerName())
assert.False(t, c.GetIsComposed())
// Method handlers are populated for all servers.
assert.NotNil(t, c.methodHandlers["ping"])
assert.NotNil(t, c.methodHandlers["initialize"])
assert.NotNil(t, c.methodHandlers["tools/list"])
assert.NotNil(t, c.methodHandlers["tools/call"])
}
func TestParseConfigCore_McpProxy_BadProxyToolJson(t *testing.T) {
c := &McpServerConfig{}
err := ParseConfigCore(gjson.Parse(`{
"server":{
"name":"p",
"type":"mcp-proxy",
"transport":"http",
"mcpServerURL":"http://b"
},
"tools":[42]
}`), c, newValidationOpts())
require.Error(t, err)
assert.Contains(t, err.Error(), "proxy tool")
}
func TestParseConfigCore_REST_BadToolJson(t *testing.T) {
c := &McpServerConfig{}
err := ParseConfigCore(gjson.Parse(`{
"server":{"name":"r"},
"tools":[42]
}`), c, newValidationOpts())
require.Error(t, err)
assert.Contains(t, err.Error(), "tool config")
}
func TestParseConfigCore_REST_BadSecuritySchemeJson(t *testing.T) {
c := &McpServerConfig{}
err := ParseConfigCore(gjson.Parse(`{
"server":{
"name":"r",
"securitySchemes":[42]
},
"tools":[
{"name":"t","description":"d","requestTemplate":{"url":"/x","method":"GET"}}
]
}`), c, newValidationOpts())
require.Error(t, err)
assert.Contains(t, err.Error(), "security scheme")
}
func TestParseConfigCore_REST_BadDefaultDownstreamSecurity(t *testing.T) {
c := &McpServerConfig{}
err := ParseConfigCore(gjson.Parse(`{
"server":{
"name":"r",
"defaultDownstreamSecurity":42
},
"tools":[
{"name":"t","description":"d","requestTemplate":{"url":"/x","method":"GET"}}
]
}`), c, newValidationOpts())
require.Error(t, err)
assert.Contains(t, err.Error(), "defaultDownstreamSecurity")
}
func TestParseConfigCore_REST_BadDefaultUpstreamSecurity(t *testing.T) {
c := &McpServerConfig{}
err := ParseConfigCore(gjson.Parse(`{
"server":{
"name":"r",
"defaultUpstreamSecurity":"oops"
},
"tools":[
{"name":"t","description":"d","requestTemplate":{"url":"/x","method":"GET"}}
]
}`), c, newValidationOpts())
require.Error(t, err)
assert.Contains(t, err.Error(), "defaultUpstreamSecurity")
}
func TestParseConfigCore_REST_AddRestToolError(t *testing.T) {
// Setting two of {argsToJsonBody, argsToUrlParam, argsToFormBody} bubbles
// the parseTemplates error out through AddRestTool.
c := &McpServerConfig{}
err := ParseConfigCore(gjson.Parse(`{
"server":{"name":"r"},
"tools":[
{
"name":"bad",
"description":"d",
"requestTemplate":{
"url":"/x",
"method":"POST",
"argsToJsonBody":true,
"argsToFormBody":true
}
}
]
}`), c, newValidationOpts())
require.Error(t, err)
assert.Contains(t, err.Error(), "argsTo")
}
func TestParseConfigCore_REST_HappyPath_AllowToolsParsed(t *testing.T) {
c := &McpServerConfig{}
err := ParseConfigCore(gjson.Parse(`{
"server":{"name":"r"},
"tools":[
{"name":"t1","description":"d","requestTemplate":{"url":"/x","method":"GET"}},
{"name":"t2","description":"d","requestTemplate":{"url":"/y","method":"GET"}}
],
"allowTools":["t1"]
}`), c, newValidationOpts())
require.NoError(t, err)
require.NotNil(t, c.server)
assert.Equal(t, "r", c.GetServerName())
assert.False(t, c.GetIsComposed())
}
func TestParseConfigCore_ToolSet_BadJson(t *testing.T) {
c := &McpServerConfig{}
err := ParseConfigCore(gjson.Parse(`{"toolSet":42}`), c, newValidationOpts())
require.Error(t, err)
assert.Contains(t, err.Error(), "toolSet")
}
func TestParseConfigCore_ToolSet_HappyPath(t *testing.T) {
opts := newValidationOpts()
// Register one tool so the composed server has something to aggregate.
opts.ToolRegistry.RegisterTool("alpha", "search", &stubTool{
desc: "alpha search",
input: map[string]any{"type": "object"},
})
c := &McpServerConfig{}
err := ParseConfigCore(gjson.Parse(`{
"toolSet":{
"name":"compound",
"serverTools":[{"serverName":"alpha","tools":["search"]}]
}
}`), c, opts)
require.NoError(t, err)
assert.True(t, c.GetIsComposed(), "toolSet must produce a composed server")
assert.Equal(t, "compound", c.GetServerName(), "composed server uses toolSet.name as serverName")
require.NotNil(t, c.server)
}
// -----------------------------------------------------------------------------
// GetServerFromGlobalContext — exercises the package-level singleton
// -----------------------------------------------------------------------------
func TestGetServerFromGlobalContext(t *testing.T) {
// Snapshot and restore globalContext to keep tests independent.
saved := globalContext
defer func() { globalContext = saved }()
globalContext = Context{servers: map[string]Server{
"existing": &stubServer{},
}}
got, ok := GetServerFromGlobalContext("existing")
require.True(t, ok)
assert.NotNil(t, got)
_, miss := GetServerFromGlobalContext("missing")
assert.False(t, miss)
}
// -----------------------------------------------------------------------------
// BaseMCPServer.Clone / CloneBase
// -----------------------------------------------------------------------------
func TestBaseMCPServer_Clone_PanicsByContract(t *testing.T) {
// Derived types must implement Clone; BaseMCPServer panics to enforce that.
b := NewBaseMCPServer()
assert.PanicsWithValue(t,
"Clone method must be implemented by derived types",
func() { _ = b.Clone() })
}
func TestBaseMCPServer_CloneBase_DeepCopiesTools(t *testing.T) {
b := NewBaseMCPServer()
b.SetConfig([]byte(`{"k":1}`))
stub := &stubTool{desc: "t"}
b.AddMCPTool("a", stub)
cloned := b.CloneBase()
// Mutating clone's tools must not bleed into the original.
cloned.AddMCPTool("b", &stubTool{desc: "x"})
assert.Len(t, b.GetMCPTools(), 1, "original tools must remain untouched")
assert.Len(t, cloned.GetMCPTools(), 2)
// Same config bytes are preserved.
got, ok := cloned.GetMCPTools()["a"]
require.True(t, ok)
assert.Same(t, stub, got, "existing tools are shared by reference (no deep clone of Tool itself)")
}

View File

@@ -18,6 +18,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMcpProxyServerBasicInterface tests that McpProxyServer implements the Server interface
@@ -110,3 +111,329 @@ func TestMcpProxyServerSecuritySchemes(t *testing.T) {
assert.Equal(t, scheme.ID, retrievedScheme.ID)
assert.Equal(t, scheme.Type, retrievedScheme.Type)
}
// -----------------------------------------------------------------------------
// SetDefaultDownstreamSecurity / SetDefaultUpstreamSecurity / PassthroughAuth
// -----------------------------------------------------------------------------
func TestMcpProxyServer_SetGetDefaultDownstreamSecurity(t *testing.T) {
s := NewMcpProxyServer("p")
assert.Equal(t, "", s.GetDefaultDownstreamSecurity().ID, "fresh server has empty default")
s.SetDefaultDownstreamSecurity(SecurityRequirement{ID: "K", Passthrough: true})
got := s.GetDefaultDownstreamSecurity()
assert.Equal(t, "K", got.ID)
assert.True(t, got.Passthrough)
}
func TestMcpProxyServer_SetGetDefaultUpstreamSecurity(t *testing.T) {
s := NewMcpProxyServer("p")
s.SetDefaultUpstreamSecurity(SecurityRequirement{ID: "U", Credential: "c"})
got := s.GetDefaultUpstreamSecurity()
assert.Equal(t, "U", got.ID)
assert.Equal(t, "c", got.Credential)
}
func TestMcpProxyServer_PassthroughAuthHeaderGetterAndSetter(t *testing.T) {
s := NewMcpProxyServer("p")
assert.False(t, s.GetPassthroughAuthHeader(), "default is false")
s.SetPassthroughAuthHeader(true)
assert.True(t, s.GetPassthroughAuthHeader())
s.SetPassthroughAuthHeader(false)
assert.False(t, s.GetPassthroughAuthHeader())
}
// -----------------------------------------------------------------------------
// AddSecurityScheme — nil-map branch
// -----------------------------------------------------------------------------
func TestMcpProxyServer_AddSecurityScheme_InitializesNilMap(t *testing.T) {
// Skip the constructor so we can hit the `securitySchemes == nil` branch.
s := &McpProxyServer{Name: "p"}
s.AddSecurityScheme(SecurityScheme{ID: "K", Type: "apiKey", In: "header", Name: "X"})
got, ok := s.GetSecurityScheme("K")
require.True(t, ok)
assert.Equal(t, "K", got.ID)
}
// -----------------------------------------------------------------------------
// AddMCPTool — delegates to BaseMCPServer
// -----------------------------------------------------------------------------
func TestMcpProxyServer_AddMCPTool_StoresInBaseAndReturnsSelf(t *testing.T) {
s := NewMcpProxyServer("p")
stub := &stubTool{desc: "d"}
ret := s.AddMCPTool("custom", stub)
assert.Same(t, s, ret, "AddMCPTool returns receiver for chaining")
tools := s.GetMCPTools()
got, ok := tools["custom"]
require.True(t, ok)
assert.Same(t, stub, got)
}
// -----------------------------------------------------------------------------
// AddProxyTool — overrides on duplicate name
// -----------------------------------------------------------------------------
func TestMcpProxyServer_AddProxyTool_DuplicateNameOverwrites(t *testing.T) {
s := NewMcpProxyServer("p")
require.NoError(t, s.AddProxyTool(McpProxyToolConfig{Name: "t", Description: "first"}))
require.NoError(t, s.AddProxyTool(McpProxyToolConfig{Name: "t", Description: "second"}))
tools := s.GetMCPTools()
assert.Len(t, tools, 1, "duplicate AddProxyTool should overwrite, not duplicate")
cfg, ok := s.GetToolConfig("t")
require.True(t, ok)
assert.Equal(t, "second", cfg.Description, "later AddProxyTool wins")
}
// -----------------------------------------------------------------------------
// GetToolConfig — hit and miss
// -----------------------------------------------------------------------------
func TestMcpProxyServer_GetToolConfig_HitAndMiss(t *testing.T) {
s := NewMcpProxyServer("p")
require.NoError(t, s.AddProxyTool(McpProxyToolConfig{Name: "t", Description: "d"}))
cfg, ok := s.GetToolConfig("t")
require.True(t, ok)
assert.Equal(t, "d", cfg.Description)
_, missOK := s.GetToolConfig("missing")
assert.False(t, missOK)
}
// -----------------------------------------------------------------------------
// Clone — deep copy of toolsConfig and securitySchemes
// -----------------------------------------------------------------------------
func TestMcpProxyServer_Clone_DeepCopiesToolsConfigAndSchemes(t *testing.T) {
orig := NewMcpProxyServer("orig")
orig.SetMcpServerURL("http://b")
orig.SetTimeout(1234)
orig.SetTransport(TransportSSE)
orig.SetPassthroughAuthHeader(true)
orig.SetDefaultDownstreamSecurity(SecurityRequirement{ID: "K"})
orig.AddSecurityScheme(SecurityScheme{ID: "K", Type: "apiKey", In: "header", Name: "X"})
require.NoError(t, orig.AddProxyTool(McpProxyToolConfig{Name: "t", Description: "d"}))
clonedI := orig.Clone()
cloned, ok := clonedI.(*McpProxyServer)
require.True(t, ok)
require.NotSame(t, orig, cloned, "Clone must return a fresh struct")
// Surface fields are copied.
assert.Equal(t, orig.Name, cloned.Name)
// NOTE: Clone does not propagate mcpServerURL/timeout/transport/passthrough
// nor defaultDownstream/upstreamSecurity. That is intentional today (see
// proxy_server.go:188): cloning is used for per-request isolation of
// tool/security registries only. This test pins that contract — if Clone
// starts copying those fields, update here and document the change.
assert.Equal(t, "", cloned.GetMcpServerURL())
// toolsConfig: deep copy — adding to clone doesn't bleed back to orig.
require.NoError(t, cloned.AddProxyTool(McpProxyToolConfig{Name: "extra", Description: "x"}))
_, origHasExtra := orig.GetToolConfig("extra")
assert.False(t, origHasExtra, "tool added to clone must not appear in original")
// securitySchemes: deep copy — replacing scheme on clone doesn't touch orig.
cloned.AddSecurityScheme(SecurityScheme{ID: "K", Type: "http", Scheme: "bearer"})
origScheme, _ := orig.GetSecurityScheme("K")
clonedScheme, _ := cloned.GetSecurityScheme("K")
assert.Equal(t, "apiKey", origScheme.Type, "original scheme must remain apiKey")
assert.Equal(t, "http", clonedScheme.Type, "clone reflects the override")
}
// -----------------------------------------------------------------------------
// McpProxyTool — Description / InputSchema / OutputSchema / Create
// -----------------------------------------------------------------------------
func TestMcpProxyTool_DescriptionAndOutputSchema(t *testing.T) {
tool := &McpProxyTool{
toolConfig: McpProxyToolConfig{
Description: "describe me",
OutputSchema: map[string]any{"type": "string"},
},
}
assert.Equal(t, "describe me", tool.Description())
assert.Equal(t, map[string]any{"type": "string"}, tool.OutputSchema())
}
func TestMcpProxyTool_InputSchema_RequiredAndOptionalAndEnumAndDefault(t *testing.T) {
tool := &McpProxyTool{
toolConfig: McpProxyToolConfig{
Args: []ToolArg{
{Name: "must", Type: "string", Description: "required", Required: true},
{Name: "opt", Type: "integer", Description: "optional", Default: 7},
{Name: "pick", Type: "string", Description: "enum", Enum: []interface{}{"a", "b"}},
},
},
}
schema := tool.InputSchema()
assert.Equal(t, "object", schema["type"])
required, ok := schema["required"].([]string)
require.True(t, ok)
assert.Equal(t, []string{"must"}, required, "only Required:true args land in required[]")
props := schema["properties"].(map[string]any)
mustProp := props["must"].(map[string]any)
assert.Equal(t, "string", mustProp["type"])
optProp := props["opt"].(map[string]any)
assert.Equal(t, 7, optProp["default"])
pickProp := props["pick"].(map[string]any)
assert.Equal(t, []interface{}{"a", "b"}, pickProp["enum"])
}
func TestMcpProxyTool_InputSchema_NoArgs(t *testing.T) {
tool := &McpProxyTool{toolConfig: McpProxyToolConfig{}}
schema := tool.InputSchema()
props, ok := schema["properties"].(map[string]any)
require.True(t, ok)
assert.Empty(t, props)
required, ok := schema["required"].([]string)
require.True(t, ok)
assert.Empty(t, required)
}
func TestMcpProxyTool_Create_NewInstanceWithBoundArgs(t *testing.T) {
orig := &McpProxyTool{
serverName: "srv",
name: "t",
toolConfig: McpProxyToolConfig{Name: "t"},
}
created := orig.Create([]byte(`{"q":"hello","n":7}`))
require.NotSame(t, orig, created, "Create returns a fresh instance")
cloned := created.(*McpProxyTool)
assert.Equal(t, "srv", cloned.serverName)
assert.Equal(t, "t", cloned.name)
assert.Equal(t, "hello", cloned.arguments["q"])
// JSON unmarshals numbers as float64.
assert.Equal(t, float64(7), cloned.arguments["n"])
}
func TestMcpProxyTool_Create_EmptyParamsStillReturnsInstance(t *testing.T) {
orig := &McpProxyTool{serverName: "s", name: "t"}
created := orig.Create(nil)
cloned := created.(*McpProxyTool)
assert.Equal(t, "s", cloned.serverName)
assert.Equal(t, "t", cloned.name)
require.NotNil(t, cloned.arguments)
assert.Empty(t, cloned.arguments, "no params → empty arguments map, not nil")
}
func TestMcpProxyTool_Create_MalformedJSON_PreservesEmptyArgs(t *testing.T) {
orig := &McpProxyTool{serverName: "s", name: "t"}
created := orig.Create([]byte(`{not json`))
cloned := created.(*McpProxyTool)
// json.Unmarshal silently fails; arguments stays empty.
assert.Empty(t, cloned.arguments)
}
// -----------------------------------------------------------------------------
// ValidateSecurityScheme — full matrix
// -----------------------------------------------------------------------------
func TestValidateSecurityScheme(t *testing.T) {
cases := []struct {
name string
scheme SecurityScheme
wantErr string
}{
{"missing ID", SecurityScheme{Type: "apiKey", In: "header", Name: "X"}, "ID is required"},
{"invalid type", SecurityScheme{ID: "k", Type: "oauth2"}, "invalid security scheme type"},
{"apiKey missing name", SecurityScheme{ID: "k", Type: "apiKey", In: "header"}, "name is required"},
{"apiKey invalid in", SecurityScheme{ID: "k", Type: "apiKey", In: "body", Name: "X"}, "invalid security scheme location"},
{"apiKey ok header", SecurityScheme{ID: "k", Type: "apiKey", In: "header", Name: "X"}, ""},
{"apiKey ok query", SecurityScheme{ID: "k", Type: "apiKey", In: "query", Name: "X"}, ""},
{"apiKey ok cookie", SecurityScheme{ID: "k", Type: "apiKey", In: "cookie", Name: "X"}, ""},
{"http missing scheme", SecurityScheme{ID: "k", Type: "http"}, "scheme is required for http"},
{"http bearer ok", SecurityScheme{ID: "k", Type: "http", Scheme: "bearer"}, ""},
{"http basic ok", SecurityScheme{ID: "k", Type: "http", Scheme: "basic"}, ""},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
err := ValidateSecurityScheme(c.scheme)
if c.wantErr == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Contains(t, err.Error(), c.wantErr)
}
})
}
}
// -----------------------------------------------------------------------------
// ValidateToolConfig — full matrix
// -----------------------------------------------------------------------------
func TestValidateToolConfig(t *testing.T) {
cases := []struct {
name string
config McpProxyToolConfig
wantErr string
}{
{
"missing name",
McpProxyToolConfig{Description: "d"},
"tool name is required",
},
{
"missing description",
McpProxyToolConfig{Name: "t"},
"tool description is required",
},
{
"arg missing name",
McpProxyToolConfig{Name: "t", Description: "d", Args: []ToolArg{{Type: "string", Description: "x"}}},
"argument name is required",
},
{
"arg duplicate names",
McpProxyToolConfig{Name: "t", Description: "d", Args: []ToolArg{
{Name: "a", Type: "string", Description: "x"},
{Name: "a", Type: "string", Description: "y"},
}},
"duplicate argument name",
},
{
"arg missing description",
McpProxyToolConfig{Name: "t", Description: "d", Args: []ToolArg{{Name: "a", Type: "string"}}},
"argument description is required",
},
{
"arg invalid type",
McpProxyToolConfig{Name: "t", Description: "d", Args: []ToolArg{{Name: "a", Type: "money", Description: "x"}}},
"invalid argument type",
},
{
"happy path with multiple typed args",
McpProxyToolConfig{Name: "t", Description: "d", Args: []ToolArg{
{Name: "s", Type: "string", Description: "x"},
{Name: "n", Type: "number", Description: "x"},
{Name: "i", Type: "integer", Description: "x"},
{Name: "b", Type: "boolean", Description: "x"},
{Name: "a", Type: "array", Description: "x"},
{Name: "o", Type: "object", Description: "x"},
}},
"",
},
{
"happy path no args",
McpProxyToolConfig{Name: "t", Description: "d"},
"",
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
err := ValidateToolConfig(c.config)
if c.wantErr == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Contains(t, err.Error(), c.wantErr)
}
})
}
}

View File

@@ -0,0 +1,323 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// -----------------------------------------------------------------------------
// NewMcpProtocolHandler
// -----------------------------------------------------------------------------
func TestNewMcpProtocolHandler(t *testing.T) {
h := NewMcpProtocolHandler("http://backend.example/mcp", 5000)
require.NotNil(t, h)
assert.Equal(t, "http://backend.example/mcp", h.backendURL)
assert.Equal(t, 5000, h.timeout)
assert.Empty(t, h.sessionID, "fresh handler has no session id until Initialize runs")
}
// -----------------------------------------------------------------------------
// parseSSEResponse — fill the remaining branches
// -----------------------------------------------------------------------------
func TestParseSSEResponse_OnlyCommentsAndBlanks(t *testing.T) {
// All non-data lines → must surface "no data field found".
_, err := parseSSEResponse([]byte(": only a comment\n\n: another\n"))
require.Error(t, err)
assert.Contains(t, err.Error(), "no data field")
}
func TestParseSSEResponse_TooLongLine(t *testing.T) {
// Single data line larger than the scanner's 32MB max-token cap.
big := strings.Repeat("x", 33*1024*1024)
_, err := parseSSEResponse([]byte("data: " + big + "\n\n"))
require.Error(t, err)
assert.Contains(t, err.Error(), "32MB", "must surface the max-token overflow as a clear error")
}
func TestParseSSEResponse_MultipleDataLinesReturnsFirst(t *testing.T) {
body := "data: first\n\ndata: second\n\n"
out, err := parseSSEResponse([]byte(body))
require.NoError(t, err)
assert.Equal(t, "first", string(out), "the function returns the first data line and stops")
}
// -----------------------------------------------------------------------------
// createInitializeRequest / createToolsListRequest / createToolsCallRequest
// -----------------------------------------------------------------------------
func TestCreateInitializeRequest_StableShape(t *testing.T) {
h := NewMcpProtocolHandler("http://backend.example/mcp", 5000)
req := h.createInitializeRequest()
assert.Equal(t, "2.0", req["jsonrpc"])
assert.Equal(t, 1, req["id"])
assert.Equal(t, "initialize", req["method"])
params, ok := req["params"].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, "2025-03-26", params["protocolVersion"])
clientInfo, ok := params["clientInfo"].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, "Higress-mcp-proxy", clientInfo["name"])
assert.Equal(t, "1.0.0", clientInfo["version"])
_, hasCaps := params["capabilities"]
assert.True(t, hasCaps)
}
func TestCreateToolsListRequest_NoCursor(t *testing.T) {
h := NewMcpProtocolHandler("http://backend.example/mcp", 5000)
req := h.createToolsListRequest(nil)
assert.Equal(t, "2.0", req["jsonrpc"])
assert.Equal(t, 2, req["id"])
assert.Equal(t, "tools/list", req["method"])
params, ok := req["params"].(map[string]interface{})
require.True(t, ok)
_, hasCursor := params["cursor"]
assert.False(t, hasCursor, "nil cursor must produce no cursor field")
}
func TestCreateToolsListRequest_EmptyStringCursor(t *testing.T) {
h := NewMcpProtocolHandler("http://backend.example/mcp", 5000)
empty := ""
req := h.createToolsListRequest(&empty)
params := req["params"].(map[string]interface{})
_, hasCursor := params["cursor"]
assert.False(t, hasCursor, "empty-string cursor is treated as absent")
}
func TestCreateToolsListRequest_WithCursor(t *testing.T) {
h := NewMcpProtocolHandler("http://backend.example/mcp", 5000)
c := "next-page"
req := h.createToolsListRequest(&c)
params := req["params"].(map[string]interface{})
assert.Equal(t, "next-page", params["cursor"])
}
func TestCreateToolsCallRequest_StableShape(t *testing.T) {
h := NewMcpProtocolHandler("http://backend.example/mcp", 5000)
args := map[string]interface{}{"q": "hello", "limit": 5}
req := h.createToolsCallRequest("search", args)
assert.Equal(t, "2.0", req["jsonrpc"])
assert.Equal(t, 3, req["id"])
assert.Equal(t, "tools/call", req["method"])
params, ok := req["params"].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, "search", params["name"])
gotArgs, ok := params["arguments"].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, "hello", gotArgs["q"])
assert.Equal(t, 5, gotArgs["limit"])
}
func TestCreateToolsCallRequest_NilArguments(t *testing.T) {
h := NewMcpProtocolHandler("http://backend.example/mcp", 5000)
req := h.createToolsCallRequest("noop", nil)
params := req["params"].(map[string]interface{})
assert.Equal(t, "noop", params["name"])
args, ok := params["arguments"]
require.True(t, ok)
assert.Nil(t, args)
}
// -----------------------------------------------------------------------------
// ParseBackendResponse / IsBackendError — extra branches
// -----------------------------------------------------------------------------
func TestParseBackendResponse_StringErrorField(t *testing.T) {
// JSON-RPC error field doesn't have to be an object — anything truthy works.
body := []byte(`{"jsonrpc":"2.0","id":1,"error":"some-text"}`)
resp, isErr, etype := ParseBackendResponse(body)
require.NotNil(t, resp)
assert.True(t, isErr)
assert.Equal(t, "jsonrpc_error", etype)
}
func TestParseBackendResponse_NoResultNoError(t *testing.T) {
// Valid JSON without result/error → not an error, but still parsed.
body := []byte(`{"jsonrpc":"2.0","id":2}`)
resp, isErr, etype := ParseBackendResponse(body)
require.NotNil(t, resp)
assert.False(t, isErr)
assert.Empty(t, etype)
}
func TestParseBackendResponse_ResultIsErrorFalseNotAnError(t *testing.T) {
body := []byte(`{"jsonrpc":"2.0","id":3,"result":{"isError":false}}`)
_, isErr, etype := ParseBackendResponse(body)
assert.False(t, isErr)
assert.Empty(t, etype)
}
func TestParseBackendResponse_ResultNotAnObject(t *testing.T) {
// result is a scalar — the isError-extraction branch is skipped.
body := []byte(`{"jsonrpc":"2.0","id":3,"result":"ok"}`)
_, isErr, etype := ParseBackendResponse(body)
assert.False(t, isErr)
assert.Empty(t, etype)
}
func TestIsBackendError_DelegatesToParse(t *testing.T) {
cases := []struct {
body string
isError bool
etype string
}{
{`{"error":{"code":-1}}`, true, "jsonrpc_error"},
{`{"result":{"isError":true}}`, true, "result_isError"},
{`{"result":{"isError":false}}`, false, ""},
{`not json`, false, ""},
}
for _, c := range cases {
isErr, etype := IsBackendError([]byte(c.body))
assert.Equal(t, c.isError, isErr, "body=%s", c.body)
assert.Equal(t, c.etype, etype, "body=%s", c.body)
}
}
// -----------------------------------------------------------------------------
// McpSessionManagerImpl
// -----------------------------------------------------------------------------
func TestNewMcpSessionManagerImpl(t *testing.T) {
m := NewMcpSessionManagerImpl()
require.NotNil(t, m)
require.NotNil(t, m.sessions)
assert.Empty(t, m.sessions)
}
func TestSessionManager_CreateAndGet(t *testing.T) {
m := NewMcpSessionManagerImpl()
id, err := m.CreateSession("http://backend.example/mcp")
require.NoError(t, err)
assert.True(t, strings.HasPrefix(id, "mcp-session-"))
session, ok := m.GetSession(id)
require.True(t, ok)
assert.Equal(t, id, session.ID)
assert.Equal(t, "http://backend.example/mcp", session.BackendURL)
assert.False(t, session.CreatedAt.IsZero())
}
func TestSessionManager_GetSessionUpdatesLastUsed(t *testing.T) {
m := NewMcpSessionManagerImpl()
id, err := m.CreateSession("http://b")
require.NoError(t, err)
// Force a measurable gap so LastUsed changes monotonically.
original := m.sessions[id].LastUsed
time.Sleep(2 * time.Millisecond)
s, ok := m.GetSession(id)
require.True(t, ok)
assert.True(t, s.LastUsed.After(original), "GetSession should refresh LastUsed")
}
func TestSessionManager_GetUnknownSession(t *testing.T) {
m := NewMcpSessionManagerImpl()
s, ok := m.GetSession("missing")
assert.False(t, ok)
assert.Nil(t, s)
}
func TestSessionManager_CleanupSession_Existing(t *testing.T) {
m := NewMcpSessionManagerImpl()
id, _ := m.CreateSession("http://b")
m.CleanupSession(id)
_, ok := m.GetSession(id)
assert.False(t, ok)
}
func TestSessionManager_CleanupSession_NonExistent(t *testing.T) {
m := NewMcpSessionManagerImpl()
// Must not panic on unknown id.
m.CleanupSession("never-existed")
assert.Empty(t, m.sessions)
}
func TestSessionManager_CleanupExpiredSessions(t *testing.T) {
m := NewMcpSessionManagerImpl()
fresh, _ := m.CreateSession("fresh")
stale, _ := m.CreateSession("stale")
// Backdate the stale session.
m.sessions[stale].LastUsed = time.Now().Add(-10 * time.Minute)
m.CleanupExpiredSessions(1 * time.Minute)
_, freshOk := m.sessions[fresh]
_, staleOk := m.sessions[stale]
assert.True(t, freshOk, "fresh session must remain")
assert.False(t, staleOk, "stale session must be removed")
}
func TestSessionManager_CleanupExpiredSessions_EmptyMap(t *testing.T) {
m := NewMcpSessionManagerImpl()
// Must not panic on empty manager.
m.CleanupExpiredSessions(1 * time.Second)
assert.Empty(t, m.sessions)
}
func TestSessionManager_CreateSessionsAreUnique(t *testing.T) {
m := NewMcpSessionManagerImpl()
id1, _ := m.CreateSession("http://b")
// Guarantee a different nanosecond timestamp.
time.Sleep(1 * time.Millisecond)
id2, _ := m.CreateSession("http://b")
assert.NotEqual(t, id1, id2, "session IDs should be unique")
}
// -----------------------------------------------------------------------------
// ensureHeader
// -----------------------------------------------------------------------------
func TestEnsureHeader_AddsWhenMissing(t *testing.T) {
headers := [][2]string{{"X-Other", "v"}}
ensureHeader(&headers, "X-New", "value")
require.Len(t, headers, 2)
assert.Equal(t, [2]string{"X-New", "value"}, headers[1])
}
func TestEnsureHeader_ReplacesCaseInsensitively(t *testing.T) {
headers := [][2]string{{"content-type", "text/plain"}}
ensureHeader(&headers, "Content-Type", "application/json")
require.Len(t, headers, 1)
// Replace path rewrites the original casing too.
assert.Equal(t, "Content-Type", headers[0][0])
assert.Equal(t, "application/json", headers[0][1])
}
func TestEnsureHeader_NoDuplicateOnRepeatedCalls(t *testing.T) {
headers := [][2]string{}
ensureHeader(&headers, "X-K", "1")
ensureHeader(&headers, "X-K", "2")
require.Len(t, headers, 1)
assert.Equal(t, "2", headers[0][1])
}

View File

@@ -20,6 +20,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/sjson"
)
@@ -920,3 +921,377 @@ func TestRestServerSecurityFallback(t *testing.T) {
t.Logf("REST server security fallback test completed successfully")
}
// ---------------------------------------------------------------------------
// parseIP
// ---------------------------------------------------------------------------
func TestParseIP(t *testing.T) {
cases := []struct {
name string
source string
fromHeader bool
want string
}{
{"ipv4 only", "10.0.0.1", false, "10.0.0.1"},
{"ipv4 with port", "10.0.0.1:8080", false, "10.0.0.1"},
{"ipv4 with leading whitespace", " 10.0.0.1:80", false, "10.0.0.1"},
{"ipv4 X-Forwarded-For first hop", "10.0.0.1, 10.0.0.2, 10.0.0.3", true, "10.0.0.1"},
{"ipv4 X-Forwarded-For with spaces", " 10.0.0.1 , 10.0.0.2 ", true, "10.0.0.1"},
{"ipv6 bracketed with port", "[2001:db8::1]:443", false, "2001:db8::1"},
{"ipv6 bracketed no port", "[2001:db8::1]", false, "2001:db8::1"},
{"ipv6 bare passes through", "2001:db8::1", false, "2001:db8::1"},
{"empty string", "", false, ""},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got := parseIP(c.source, c.fromHeader)
assert.Equal(t, c.want, got)
})
}
}
// ---------------------------------------------------------------------------
// parseTemplates — fill remaining error branches
// ---------------------------------------------------------------------------
func TestParseTemplates_DirectResponseMissingBody(t *testing.T) {
// No RequestTemplate.URL → direct-response mode. ResponseTemplate.Body must be set.
tool := RestTool{}
err := tool.parseTemplates()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "direct response mode")
}
}
func TestParseTemplates_DirectResponseWithBodyOk(t *testing.T) {
tool := RestTool{
ResponseTemplate: RestToolResponseTemplate{Body: "{{.}}"},
}
assert.NoError(t, tool.parseTemplates())
assert.True(t, tool.isDirectResponseTool)
}
func TestParseTemplates_URLTemplateParseError(t *testing.T) {
tool := RestTool{
RequestTemplate: RestToolRequestTemplate{
URL: "http://x/{{ .unclosed ", // missing closing braces
Method: "GET",
},
}
err := tool.parseTemplates()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "URL template")
}
}
func TestParseTemplates_HeaderTemplateParseError(t *testing.T) {
tool := RestTool{
RequestTemplate: RestToolRequestTemplate{
URL: "http://x",
Method: "GET",
Headers: []RestToolHeader{
{Key: "X-Bad", Value: "{{ .unclosed "},
},
},
}
err := tool.parseTemplates()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "header template")
}
}
func TestParseTemplates_BodyTemplateParseError(t *testing.T) {
tool := RestTool{
RequestTemplate: RestToolRequestTemplate{
URL: "http://x",
Method: "POST",
Body: "{{ .unclosed ",
},
}
err := tool.parseTemplates()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "body template")
}
}
func TestParseTemplates_ResponseTemplateParseError(t *testing.T) {
tool := RestTool{
RequestTemplate: RestToolRequestTemplate{URL: "http://x", Method: "GET"},
ResponseTemplate: RestToolResponseTemplate{
Body: "{{ .unclosed ",
},
}
err := tool.parseTemplates()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "response template")
}
}
func TestParseTemplates_ErrorResponseTemplateParseError(t *testing.T) {
tool := RestTool{
RequestTemplate: RestToolRequestTemplate{URL: "http://x", Method: "GET"},
ErrorResponseTemplate: "{{ .unclosed ",
}
err := tool.parseTemplates()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "error response template")
}
}
func TestParseTemplates_HeaderWithEmptyKeySkipped(t *testing.T) {
tool := RestTool{
RequestTemplate: RestToolRequestTemplate{
URL: "http://x",
Method: "GET",
Headers: []RestToolHeader{
{Key: "", Value: "ignored"},
{Key: "X-Real", Value: "real"},
},
},
}
assert.NoError(t, tool.parseTemplates())
_, hasReal := tool.parsedHeaderTemplates["X-Real"]
assert.True(t, hasReal)
_, hasEmpty := tool.parsedHeaderTemplates[""]
assert.False(t, hasEmpty)
}
func TestParseTemplates_PopulatesArgPositions(t *testing.T) {
tool := RestTool{
RequestTemplate: RestToolRequestTemplate{URL: "http://x", Method: "GET"},
ResponseTemplate: RestToolResponseTemplate{Body: "{{.}}"},
Args: []RestToolArg{
{Name: "q", Position: "QUERY"}, // lower-cased in argPositions
{Name: "h", Position: "Header"},
{Name: "noPos"}, // no position → not stored
},
}
require := assert.New(t)
require.NoError(tool.parseTemplates())
require.Equal("query", tool.argPositions["q"])
require.Equal("header", tool.argPositions["h"])
_, ok := tool.argPositions["noPos"]
require.False(ok)
}
// ---------------------------------------------------------------------------
// executeTemplate — nil + execution error
// ---------------------------------------------------------------------------
func TestExecuteTemplate_NilReturnsError(t *testing.T) {
_, err := executeTemplate(nil, []byte(`{}`))
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "nil")
}
}
// ---------------------------------------------------------------------------
// RestMCPServer accessors — GetSecurityScheme, GetPassthroughAuthHeader,
// AddMCPTool, GetConfig, GetToolConfig
// ---------------------------------------------------------------------------
func TestRestServer_AddMCPTool_DelegatesToBase(t *testing.T) {
s := NewRestMCPServer("rest")
tool := &stubTool{desc: "x"}
ret := s.AddMCPTool("plain", tool)
assert.Same(t, s, ret)
got, ok := s.GetMCPTools()["plain"]
assert.True(t, ok)
assert.Same(t, tool, got)
}
func TestRestServer_GetSecurityScheme_HitAndMiss(t *testing.T) {
s := NewRestMCPServer("rest")
scheme := SecurityScheme{ID: "K", Type: "apiKey", In: "header", Name: "X-K"}
s.AddSecurityScheme(scheme)
got, ok := s.GetSecurityScheme("K")
assert.True(t, ok)
assert.Equal(t, "K", got.ID)
_, ok = s.GetSecurityScheme("missing")
assert.False(t, ok)
}
func TestRestServer_PassthroughAuthHeader(t *testing.T) {
s := NewRestMCPServer("rest")
assert.False(t, s.GetPassthroughAuthHeader())
s.SetPassthroughAuthHeader(true)
assert.True(t, s.GetPassthroughAuthHeader())
}
func TestRestServer_GetToolConfig(t *testing.T) {
s := NewRestMCPServer("rest")
require.NoError(t, s.AddRestTool(RestTool{
Name: "t",
ResponseTemplate: RestToolResponseTemplate{Body: "{{.}}"},
}))
cfg, ok := s.GetToolConfig("t")
assert.True(t, ok)
assert.Equal(t, "t", cfg.Name)
_, ok = s.GetToolConfig("missing")
assert.False(t, ok)
}
// ---------------------------------------------------------------------------
// RestMCPServer.Clone — independence
// ---------------------------------------------------------------------------
func TestRestServer_Clone_Independence(t *testing.T) {
orig := NewRestMCPServer("rest")
orig.SetPassthroughAuthHeader(true)
orig.SetConfig([]byte(`{"v":1}`))
orig.AddSecurityScheme(SecurityScheme{ID: "K", Type: "apiKey", In: "header", Name: "X"})
require.NoError(t, orig.AddRestTool(RestTool{
Name: "t",
ResponseTemplate: RestToolResponseTemplate{Body: "{{.}}"},
}))
clonedI := orig.Clone()
require.NotNil(t, clonedI)
cloned, ok := clonedI.(*RestMCPServer)
require.True(t, ok)
// Mutate the original: cloned must not see the change.
orig.AddSecurityScheme(SecurityScheme{ID: "K2", Type: "apiKey", In: "header", Name: "Y"})
_, hasK2 := cloned.GetSecurityScheme("K2")
assert.False(t, hasK2, "cloned server must not see security scheme added to original after Clone")
// Tools map was deep-copied at Clone time.
_, hasT := cloned.GetToolConfig("t")
assert.True(t, hasT)
}
// ---------------------------------------------------------------------------
// RestMCPTool.Create — type coercion matrix
// ---------------------------------------------------------------------------
func newRestToolForCreate(t *testing.T) *RestMCPTool {
t.Helper()
tool := RestTool{
Name: "t",
Args: []RestToolArg{
{Name: "b", Type: "boolean"},
{Name: "i", Type: "integer"},
{Name: "n", Type: "number"},
{Name: "s", Type: "string"},
{Name: "d", Type: "integer", Default: 7},
},
ResponseTemplate: RestToolResponseTemplate{Body: "{{.}}"},
}
require.NoError(t, tool.parseTemplates())
return &RestMCPTool{
serverName: "rest",
name: "t",
toolConfig: tool,
}
}
func TestRestMCPTool_Create_BooleanCoercion(t *testing.T) {
tool := newRestToolForCreate(t)
// Boolean from native true, native false, string "true", string "false",
// string with garbage (passthrough), and other types (passthrough).
cases := []struct {
raw any
want any
}{
{true, true},
{false, false},
{"true", true},
{"false", false},
{"yes", "yes"},
// JSON unmarshal turns any number into float64; non-bool/non-string
// hits the default arm and is stored verbatim.
{42, float64(42)},
}
for _, c := range cases {
body, err := json.Marshal(map[string]any{"b": c.raw})
require.NoError(t, err)
created := tool.Create(body).(*RestMCPTool)
assert.Equal(t, c.want, created.arguments["b"], "raw=%v", c.raw)
}
}
func TestRestMCPTool_Create_IntegerCoercion(t *testing.T) {
tool := newRestToolForCreate(t)
cases := []struct {
raw any
want any
}{
{float64(10), 10},
{"42", 42},
{"not-int", "not-int"},
{true, true},
}
for _, c := range cases {
body, err := json.Marshal(map[string]any{"i": c.raw})
require.NoError(t, err)
created := tool.Create(body).(*RestMCPTool)
assert.Equal(t, c.want, created.arguments["i"], "raw=%v", c.raw)
}
}
func TestRestMCPTool_Create_NumberCoercion(t *testing.T) {
tool := newRestToolForCreate(t)
cases := []struct {
raw any
want any
}{
{"3.14", 3.14},
{"abc", "abc"},
{float64(2.5), 2.5}, // default: passthrough
}
for _, c := range cases {
body, err := json.Marshal(map[string]any{"n": c.raw})
require.NoError(t, err)
created := tool.Create(body).(*RestMCPTool)
assert.Equal(t, c.want, created.arguments["n"], "raw=%v", c.raw)
}
}
func TestRestMCPTool_Create_DefaultApplied(t *testing.T) {
tool := newRestToolForCreate(t)
body := []byte(`{}`)
created := tool.Create(body).(*RestMCPTool)
assert.Equal(t, 7, created.arguments["d"])
// Args without defaults are not present when omitted.
_, hasI := created.arguments["i"]
assert.False(t, hasI)
}
func TestRestMCPTool_Create_StringPassthrough(t *testing.T) {
tool := newRestToolForCreate(t)
body, _ := json.Marshal(map[string]any{"s": "hello"})
created := tool.Create(body).(*RestMCPTool)
assert.Equal(t, "hello", created.arguments["s"])
}
func TestRestMCPTool_Create_MalformedJSONStillProducesTool(t *testing.T) {
tool := newRestToolForCreate(t)
// Bad JSON is logged + ignored; defaults still applied.
created := tool.Create([]byte("{not json")).(*RestMCPTool)
assert.Equal(t, 7, created.arguments["d"], "default still applied when params unparseable")
}
// ---------------------------------------------------------------------------
// hasContentType — case + charset suffix
// ---------------------------------------------------------------------------
func TestHasContentType_CaseAndCharsetSuffix(t *testing.T) {
headers := [][2]string{
{"content-type", "Application/JSON; charset=utf-8"},
}
assert.True(t, hasContentType(headers, "application/json"))
assert.True(t, hasContentType(headers, "json"))
assert.False(t, hasContentType(headers, "xml"))
emptyHeaders := [][2]string{}
assert.False(t, hasContentType(emptyHeaders, "application/json"))
}
// pull in `require` for newer tests above without disturbing existing imports.
var _ = url.Parse
var _ = sjson.Set
var _ = strings.TrimSpace

View File

@@ -15,7 +15,11 @@
package server
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestParseSSEMessage tests SSE message parsing
@@ -295,3 +299,176 @@ data: {"id":2}
t.Errorf("Expected no more messages, got: %+v", msg4)
}
}
// -----------------------------------------------------------------------------
// ParseSSEMessage — additional edge cases (multi-line data, retry, empty)
// -----------------------------------------------------------------------------
func TestParseSSEMessage_EmptyInput(t *testing.T) {
msg, remaining, err := ParseSSEMessage([]byte(""))
require.NoError(t, err)
assert.Nil(t, msg)
assert.Len(t, remaining, 0)
}
func TestParseSSEMessage_RetryFieldIgnored(t *testing.T) {
// `retry:` is part of the SSE spec but not implemented — must not break parsing.
input := []byte("retry: 5000\nevent: message\ndata: hi\n\n")
msg, _, err := ParseSSEMessage(input)
require.NoError(t, err)
require.NotNil(t, msg)
assert.Equal(t, "message", msg.Event)
assert.Equal(t, "hi", msg.Data)
}
func TestParseSSEMessage_MultiLineDataConcatenated(t *testing.T) {
// Per SSE spec, multiple `data:` lines in one message join with `\n`.
input := []byte("data: line-one\ndata: line-two\ndata: line-three\n\n")
msg, _, err := ParseSSEMessage(input)
require.NoError(t, err)
require.NotNil(t, msg)
assert.Equal(t, "line-one\nline-two\nline-three", msg.Data)
}
func TestParseSSEMessage_NoFinalBlankLine_NoMessageReturned(t *testing.T) {
// Message without the terminating blank line is treated as incomplete.
input := []byte("event: message\ndata: payload\n")
msg, remaining, err := ParseSSEMessage(input)
require.NoError(t, err)
assert.Nil(t, msg, "incomplete message must not be returned")
assert.Equal(t, input, remaining, "remaining is the entire input")
}
func TestParseSSEMessage_LineWithoutColonSkipped(t *testing.T) {
// SplitN with len<2 → field/value pair can't be formed → skipped, not an error.
input := []byte("a-line-without-colon\nevent: msg\ndata: x\n\n")
msg, _, err := ParseSSEMessage(input)
require.NoError(t, err)
require.NotNil(t, msg)
assert.Equal(t, "msg", msg.Event)
assert.Equal(t, "x", msg.Data)
}
func TestParseSSEMessage_UnknownFieldIgnored(t *testing.T) {
// `random-field:` is parsed but the switch case ignores it.
input := []byte("random-field: stuff\nevent: msg\ndata: x\n\n")
msg, _, err := ParseSSEMessage(input)
require.NoError(t, err)
require.NotNil(t, msg)
assert.Equal(t, "msg", msg.Event)
}
// -----------------------------------------------------------------------------
// ExtractEndpointURL — edge cases not in the table
// -----------------------------------------------------------------------------
func TestExtractEndpointURL_HttpsPassthrough(t *testing.T) {
got, err := ExtractEndpointURL("https://other.example/x", "http://b.example")
require.NoError(t, err)
assert.Equal(t, "https://other.example/x", got, "full https URL must pass through unchanged")
}
func TestExtractEndpointURL_EmptyEndpointData_PathOnlyBase(t *testing.T) {
got, err := ExtractEndpointURL("", "/some/path")
require.NoError(t, err)
assert.Equal(t, "", got, "empty endpointData with path-only base → empty result")
}
func TestExtractEndpointURL_RelativeEndpointWithSchemeBase(t *testing.T) {
got, err := ExtractEndpointURL("messages", "http://b.example/mcp")
require.NoError(t, err)
assert.Equal(t, "http://b.example/messages", got, "leading slash auto-inserted")
}
// -----------------------------------------------------------------------------
// applyProxyAuthenticationForSSE — pure URL+header munging (no proxywasm)
// -----------------------------------------------------------------------------
func TestApplyProxyAuthenticationForSSE_ApiKeyHeader(t *testing.T) {
server := NewMcpProxyServer("p")
server.AddSecurityScheme(SecurityScheme{
ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key",
DefaultCredential: "abc",
})
headers := [][2]string{{"X-Other", "v"}}
got, err := applyProxyAuthenticationForSSE(server, "K", "", &headers, "http://backend/x")
require.NoError(t, err)
assert.Equal(t, "http://backend/x", got, "no query → URL preserved")
found := false
for _, kv := range headers {
if strings.EqualFold(kv[0], "X-Api-Key") {
assert.Equal(t, "abc", kv[1])
found = true
}
}
assert.True(t, found, "API key header must be injected")
}
func TestApplyProxyAuthenticationForSSE_ApiKeyQuery_PreservesExisting(t *testing.T) {
server := NewMcpProxyServer("p")
server.AddSecurityScheme(SecurityScheme{
ID: "K", Type: "apiKey", In: "query", Name: "api_key",
DefaultCredential: "secret",
})
headers := [][2]string{}
got, err := applyProxyAuthenticationForSSE(server, "K", "", &headers, "http://backend/x?existing=1")
require.NoError(t, err)
// Query is rebuilt via url.Values.Encode — both pairs must be present.
assert.Contains(t, got, "api_key=secret")
assert.Contains(t, got, "existing=1")
}
func TestApplyProxyAuthenticationForSSE_PathOnlyURL_PreservesShape(t *testing.T) {
server := NewMcpProxyServer("p")
server.AddSecurityScheme(SecurityScheme{
ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key",
DefaultCredential: "abc",
})
headers := [][2]string{}
got, err := applyProxyAuthenticationForSSE(server, "K", "", &headers, "/relative/path")
require.NoError(t, err)
assert.Equal(t, "/relative/path", got, "path-only URL must come back as path-only")
}
func TestApplyProxyAuthenticationForSSE_HttpBearerPassthrough(t *testing.T) {
server := NewMcpProxyServer("p")
server.AddSecurityScheme(SecurityScheme{ID: "B", Type: "http", Scheme: "bearer"})
headers := [][2]string{}
got, err := applyProxyAuthenticationForSSE(server, "B", "passthrough-token", &headers, "http://backend/x")
require.NoError(t, err)
assert.Equal(t, "http://backend/x", got)
var authValue string
for _, kv := range headers {
if strings.EqualFold(kv[0], "Authorization") {
authValue = kv[1]
}
}
assert.Equal(t, "Bearer passthrough-token", authValue)
}
func TestApplyProxyAuthenticationForSSE_MissingScheme_ReturnsError(t *testing.T) {
server := NewMcpProxyServer("p")
headers := [][2]string{}
_, err := applyProxyAuthenticationForSSE(server, "missing", "", &headers, "http://backend/x")
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
}
func TestApplyProxyAuthenticationForSSE_PreservesFragment(t *testing.T) {
server := NewMcpProxyServer("p")
server.AddSecurityScheme(SecurityScheme{
ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key",
DefaultCredential: "abc",
})
headers := [][2]string{}
got, err := applyProxyAuthenticationForSSE(server, "K", "", &headers, "http://backend/path#section-2")
require.NoError(t, err)
assert.Contains(t, got, "#section-2", "fragment must round-trip")
}