diff --git a/plugins/wasm-go/extensions/mcp-server/main_test.go b/plugins/wasm-go/extensions/mcp-server/main_test.go index 648cf02bb..c1dc34c3f 100644 --- a/plugins/wasm-go/extensions/mcp-server/main_test.go +++ b/plugins/wasm-go/extensions/mcp-server/main_test.go @@ -2764,3 +2764,849 @@ data: {"jsonrpc":"2.0","id":2,"result":{"content":[{"type":"text","text":"Secure }) }) } + +// ----------------------------------------------------------------------------- +// Phase 2.1 — REST Server Call() branch matrix +// +// Each sub-test stands up a minimal REST MCP config whose single tool exercises +// a specific branch of rest_server.go:Call (~lines 523-946). The plugin runs the +// real tool, makes one ctx.RouteCall, and we mock the backend reply via +// CallOnHttpResponseHeaders/Body. We inspect the upstream request via +// GetRequestHeaders/Body and the MCP-shaped response via GetResponseBody. +// ----------------------------------------------------------------------------- + +func TestRestMCPServer_CallBranches(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // ------------------------------------------------------------------- + // argsToFormBody → upstream body is form-urlencoded; Content-Type set + // ------------------------------------------------------------------- + t.Run("argsToFormBody encodes body as form-urlencoded", func(t *testing.T) { + cfg, _ := json.Marshal(map[string]interface{}{ + "server": map[string]interface{}{"name": "rest-form", "type": "rest"}, + "tools": []map[string]interface{}{{ + "name": "submit", + "description": "form submit", + "args": []map[string]interface{}{ + {"name": "user", "description": "u", "type": "string", "required": true}, + {"name": "msg", "description": "m", "type": "string"}, + }, + "requestTemplate": map[string]interface{}{ + "url": "http://backend.example/form", + "method": "POST", + "argsToFormBody": true, + }, + }}, + }) + host, status := test.NewTestHost(cfg) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, + {"content-type", "application/json"}, + }) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"submit","arguments":{"user":"alice 1","msg":"hi&there"}}}`)) + + upstreamHeaders := host.GetRequestHeaders() + upstreamBody := host.GetRequestBody() + + ctHeader, has := test.GetHeaderValue(upstreamHeaders, "Content-Type") + require.True(t, has) + require.Contains(t, ctHeader, "application/x-www-form-urlencoded") + // `&` must be percent-encoded, space encoded as `+`. + require.Contains(t, string(upstreamBody), "user=alice+1") + require.Contains(t, string(upstreamBody), "msg=hi%26there") + host.CompleteHttp() + }) + + // ------------------------------------------------------------------- + // argsToUrlParam → default args merged into URL query + // ------------------------------------------------------------------- + t.Run("argsToUrlParam merges args into query", func(t *testing.T) { + cfg, _ := json.Marshal(map[string]interface{}{ + "server": map[string]interface{}{"name": "rest-qp", "type": "rest"}, + "tools": []map[string]interface{}{{ + "name": "search", + "description": "search", + "args": []map[string]interface{}{ + {"name": "q", "description": "q", "type": "string", "required": true}, + {"name": "limit", "description": "lim", "type": "integer"}, + }, + "requestTemplate": map[string]interface{}{ + "url": "http://backend.example/search?pre=set", + "method": "GET", + "argsToUrlParam": true, + }, + }}, + }) + host, status := test.NewTestHost(cfg) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, + {"content-type", "application/json"}, + }) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"search","arguments":{"q":"hello","limit":10}}}`)) + + upstreamHeaders := host.GetRequestHeaders() + pathVal, has := test.GetHeaderValue(upstreamHeaders, ":path") + require.True(t, has) + require.Contains(t, pathVal, "pre=set") + require.Contains(t, pathVal, "q=hello") + require.Contains(t, pathVal, "limit=10") + host.CompleteHttp() + }) + + // ------------------------------------------------------------------- + // Direct-response tool → no backend call, response template fires + // ------------------------------------------------------------------- + t.Run("direct response tool emits template result", func(t *testing.T) { + cfg, _ := json.Marshal(map[string]interface{}{ + "server": map[string]interface{}{"name": "rest-dr", "type": "rest"}, + "tools": []map[string]interface{}{{ + "name": "ping", + "description": "static ping", + "args": []map[string]interface{}{ + {"name": "name", "description": "n", "type": "string", "required": true}, + }, + // No requestTemplate.url → direct-response mode. + "responseTemplate": map[string]interface{}{ + "body": "hello {{.args.name}}", + }, + }}, + }) + host, status := test.NewTestHost(cfg) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, + {"content-type", "application/json"}, + }) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"ping","arguments":{"name":"world"}}}`)) + + // Direct-response writes via SendLocalResponse, not the streaming response body. + localResp := host.GetLocalResponse() + require.NotNil(t, localResp, "direct-response must emit a local response with no backend call") + require.Contains(t, string(localResp.Data), "hello world") + host.CompleteHttp() + }) + + // ------------------------------------------------------------------- + // Image content-type → SendMCPToolImageResult path (base64 in response) + // ------------------------------------------------------------------- + t.Run("image content-type produces image MCP result", func(t *testing.T) { + cfg, _ := json.Marshal(map[string]interface{}{ + "server": map[string]interface{}{"name": "rest-img", "type": "rest"}, + "tools": []map[string]interface{}{{ + "name": "get_image", + "description": "image fetch", + "args": []map[string]interface{}{{"name": "id", "description": "id", "type": "string"}}, + "requestTemplate": map[string]interface{}{ + "url": "http://backend.example/img/{{.args.id}}", "method": "GET", + }, + }}, + }) + host, status := test.NewTestHost(cfg) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}}) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"get_image","arguments":{"id":"42"}}}`)) + + pngBytes := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A} + host.CallOnHttpResponseHeaders([][2]string{{":status", "200"}, {"Content-Type", "image/png"}}) + host.CallOnHttpResponseBody(pngBytes) + + respBody := host.GetResponseBody() + require.NotEmpty(t, respBody) + require.Contains(t, string(respBody), `"type":"image"`) + require.Contains(t, string(respBody), `"mimeType":"image/png"`) + host.CompleteHttp() + }) + + // ------------------------------------------------------------------- + // outputSchema + JSON backend → structuredContent populated + // ------------------------------------------------------------------- + t.Run("outputSchema with JSON body emits structuredContent", func(t *testing.T) { + cfg, _ := json.Marshal(map[string]interface{}{ + "server": map[string]interface{}{"name": "rest-os", "type": "rest"}, + "tools": []map[string]interface{}{{ + "name": "info", + "description": "info", + "args": []map[string]interface{}{}, + "outputSchema": map[string]interface{}{"type": "object"}, + "requestTemplate": map[string]interface{}{ + "url": "http://backend.example/info", "method": "GET", + }, + }}, + }) + host, status := test.NewTestHost(cfg) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}}) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"info","arguments":{}}}`)) + + host.CallOnHttpResponseHeaders([][2]string{{":status", "200"}, {"Content-Type", "application/json"}}) + host.CallOnHttpResponseBody([]byte(`{"a":1,"b":"two"}`)) + + respBody := host.GetResponseBody() + require.NotEmpty(t, respBody) + require.Contains(t, string(respBody), "structuredContent", "outputSchema must produce structuredContent field") + host.CompleteHttp() + }) + + // ------------------------------------------------------------------- + // errorResponseTemplate fires on 4xx/5xx, _headers is accessible + // ------------------------------------------------------------------- + t.Run("errorResponseTemplate renders on backend error", func(t *testing.T) { + cfg, _ := json.Marshal(map[string]interface{}{ + "server": map[string]interface{}{"name": "rest-err", "type": "rest"}, + "tools": []map[string]interface{}{{ + "name": "fail", + "description": "fail", + "args": []map[string]interface{}{}, + "errorResponseTemplate": `upstream said: {{gjson "_headers.:status"}} - {{gjson "message"}}`, + "requestTemplate": map[string]interface{}{ + "url": "http://backend.example/fail", "method": "GET", + }, + }}, + }) + host, status := test.NewTestHost(cfg) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}}) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"fail","arguments":{}}}`)) + + host.CallOnHttpResponseHeaders([][2]string{{":status", "500"}, {"Content-Type", "application/json"}}) + host.CallOnHttpResponseBody([]byte(`{"message":"boom"}`)) + + respBody := host.GetResponseBody() + require.NotEmpty(t, respBody) + require.Contains(t, string(respBody), "upstream said: 500 - boom") + require.Contains(t, string(respBody), `"isError":true`) + host.CompleteHttp() + }) + + // ------------------------------------------------------------------- + // prependBody + appendBody wrap raw response + // ------------------------------------------------------------------- + t.Run("prependBody/appendBody wrap raw response", func(t *testing.T) { + cfg, _ := json.Marshal(map[string]interface{}{ + "server": map[string]interface{}{"name": "rest-wrap", "type": "rest"}, + "tools": []map[string]interface{}{{ + "name": "say", + "description": "say", + "args": []map[string]interface{}{}, + "requestTemplate": map[string]interface{}{ + "url": "http://backend.example/say", "method": "GET", + }, + "responseTemplate": map[string]interface{}{ + "prependBody": "<<", + "appendBody": ">>", + }, + }}, + }) + host, status := test.NewTestHost(cfg) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}}) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"say","arguments":{}}}`)) + + host.CallOnHttpResponseHeaders([][2]string{{":status", "200"}, {"Content-Type", "text/plain"}}) + host.CallOnHttpResponseBody([]byte("hi")) + + respBody := host.GetResponseBody() + require.NotEmpty(t, respBody) + // Unmarshal to dodge JSON unicode-escape encoding of < and >. + var parsed map[string]interface{} + require.NoError(t, json.Unmarshal(respBody, &parsed)) + result := parsed["result"].(map[string]interface{}) + content := result["content"].([]interface{}) + text := content[0].(map[string]interface{})["text"].(string) + require.Equal(t, "<>", text) + host.CompleteHttp() + }) + + // ------------------------------------------------------------------- + // path arg with reserved chars → substituted into URL template + // ------------------------------------------------------------------- + t.Run("path arg substitutes into URL placeholder", func(t *testing.T) { + cfg, _ := json.Marshal(map[string]interface{}{ + "server": map[string]interface{}{"name": "rest-path", "type": "rest"}, + "tools": []map[string]interface{}{{ + "name": "get_user", + "description": "by id", + "args": []map[string]interface{}{ + {"name": "id", "description": "uid", "type": "string", "required": true, "position": "path"}, + }, + "requestTemplate": map[string]interface{}{ + "url": "http://backend.example/users/{id}", + "method": "GET", + }, + }}, + }) + host, status := test.NewTestHost(cfg) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}}) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"get_user","arguments":{"id":"alice%20smith"}}}`)) + + upstreamHeaders := host.GetRequestHeaders() + pathVal, has := test.GetHeaderValue(upstreamHeaders, ":path") + require.True(t, has) + // Path arg is substituted as-is into the URL template before being parsed. + require.Contains(t, pathVal, "/users/alice") + host.CompleteHttp() + }) + + // ------------------------------------------------------------------- + // header arg from args list lands as upstream header + // ------------------------------------------------------------------- + t.Run("header-position arg becomes upstream header", func(t *testing.T) { + cfg, _ := json.Marshal(map[string]interface{}{ + "server": map[string]interface{}{"name": "rest-hdr", "type": "rest"}, + "tools": []map[string]interface{}{{ + "name": "auth_call", + "description": "with header", + "args": []map[string]interface{}{ + {"name": "X-Trace", "description": "trace", "type": "string", "required": true, "position": "header"}, + }, + "requestTemplate": map[string]interface{}{ + "url": "http://backend.example/protected", + "method": "GET", + }, + }}, + }) + host, status := test.NewTestHost(cfg) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}}) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"auth_call","arguments":{"X-Trace":"abc-123"}}}`)) + + upstreamHeaders := host.GetRequestHeaders() + require.True(t, test.HasHeaderWithValue(upstreamHeaders, "X-Trace", "abc-123")) + host.CompleteHttp() + }) + }) +} + +// ----------------------------------------------------------------------------- +// Phase 2.2 — SSE state machine error / edge paths +// +// Drive sse_proxy.go's handleSSEStreamingResponse + handleWaitingEndpoint / +// handleWaitingInitResp / handleWaitingToolResp through error branches that +// the happy-path TestMcpProxyServerSSE* tests don't cover. +// ----------------------------------------------------------------------------- + +func TestMcpProxyServerSSE_NonSSEContentTypeRejected(t *testing.T) { + // Backend returns text/plain instead of text/event-stream → first-chunk + // content-type validation must inject a JSON-RPC error. + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(mcpProxyServerSSEConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, + {"content-type", "application/json"}, + }) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`)) + + // First-chunk: backend returned a non-SSE content-type. + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "application/json"}, + }) + host.CallOnHttpStreamingResponseBody([]byte(`{"not":"sse"}`), true) + + respBody := host.GetResponseBody() + require.NotEmpty(t, respBody) + require.Contains(t, string(respBody), "error", "rejected non-SSE response must surface a JSON-RPC error") + host.CompleteHttp() + }) +} + +func TestMcpProxyServerSSE_CharsetSuffixAccepted(t *testing.T) { + // content-type: text/event-stream;charset=utf-8 must still be accepted + // (substring match on text/event-stream). + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(mcpProxyServerSSEConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, + {"content-type", "application/json"}, + }) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`)) + + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"content-type", "text/event-stream; charset=utf-8"}, + }) + // Endpoint event triggers initialize — we don't drive past this, + // just confirm content-type-suffix wasn't rejected (no error in body). + host.CallOnHttpStreamingResponseBody([]byte("event: endpoint\ndata: /sse/abc\n\n"), false) + + // No injected error yet — buffer is just being consumed. + // Local response should NOT have been set with an error. + localResp := host.GetLocalResponse() + if localResp != nil { + require.NotContains(t, string(localResp.Data), "invalid content-type", + "text/event-stream with charset suffix must NOT be rejected") + } + host.CompleteHttp() + }) +} + +func TestMcpProxyServerSSE_EndpointSkipsUnrelatedEvents(t *testing.T) { + // Send a `ping` event before `endpoint` — the state machine must skip + // non-endpoint messages while in WaitingEndpoint and not error out. + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(mcpProxyServerSSEConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}}) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`)) + + host.CallOnHttpResponseHeaders([][2]string{{":status", "200"}, {"content-type", "text/event-stream"}}) + + // ping first — must be ignored. + host.CallOnHttpStreamingResponseBody([]byte("event: ping\ndata: keep-alive\n\n"), false) + + // Then the real endpoint event. + host.CallOnHttpStreamingResponseBody([]byte("event: endpoint\ndata: /sse/session-xyz\n\n"), false) + + // Verify the initialize request was sent upstream (proxy moved past WaitingEndpoint). + callouts := host.GetHttpCalloutAttributes() + require.NotEmpty(t, callouts, "endpoint event should trigger an initialize HTTP callout") + var sawInit bool + for _, c := range callouts { + if strings.Contains(string(c.Body), `"method":"initialize"`) { + sawInit = true + break + } + } + require.True(t, sawInit, "an initialize JSON-RPC call must have been sent after endpoint event") + host.CompleteHttp() + }) +} + +// ----------------------------------------------------------------------------- +// Phase 2.3 — mcp-proxy Initialize callback error paths (proxy_tool.go:107-204) +// ----------------------------------------------------------------------------- + +func TestMcpProxyServer_InitializeBackend500(t *testing.T) { + // Initialize HTTP callback receives non-2xx status → must inject error response. + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(mcpProxyServerConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}}) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`)) + + // Backend returns 500 on the initialize call. + host.CallOnHttpCall([][2]string{{":status", "500"}, {"content-type", "application/json"}}, []byte(`{"error":"down"}`)) + + // Errors go through SendLocalResponse, not the streaming body. + localResp := host.GetLocalResponse() + require.NotNil(t, localResp, "non-200 initialize response must inject a local error response") + require.Contains(t, string(localResp.Data), "error") + host.CompleteHttp() + }) +} + +func TestMcpProxyServer_InitializeMalformedJSON(t *testing.T) { + // Initialize response body is not valid JSON → parse error path. + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(mcpProxyServerConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}}) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`)) + + host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "application/json"}}, []byte(`{not valid json`)) + + localResp := host.GetLocalResponse() + require.NotNil(t, localResp, "unparseable initialize response must inject a local error response") + require.Contains(t, string(localResp.Data), "error") + host.CompleteHttp() + }) +} + +func TestMcpProxyServer_InitializeSSEContentType(t *testing.T) { + // Initialize response carries text/event-stream → parseSSEResponse path. + // We send a valid SSE-wrapped JSON-RPC success body so the proxy unwraps and + // progresses past initialize. End-to-end completion is not the point — we + // just need to drive the SSE-unwrap branch. + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(mcpProxyServerConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}}) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`)) + + // SSE-wrapped initialize success response — proxy must extract the data line. + sseInit := "event: message\ndata: " + `{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2025-03-26","capabilities":{"tools":{}},"serverInfo":{"name":"X","version":"1"}}}` + "\n\n" + host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "text/event-stream"}, {"mcp-session-id", "sse-session"}}, []byte(sseInit)) + + // If SSE-unwrap worked, the proxy will have moved to sending the initialized notification. + // Look for a 2nd outbound callout — its presence confirms initialize succeeded. + callouts := host.GetHttpCalloutAttributes() + var sawNotification bool + for _, c := range callouts { + if strings.Contains(string(c.Body), "notifications/initialized") { + sawNotification = true + break + } + } + require.True(t, sawNotification, "SSE-wrapped initialize response must be unwrapped so notification fires") + host.CompleteHttp() + }) +} + +func TestMcpProxyServer_InitializeUnknownErrorCode(t *testing.T) { + // Initialize returns a JSON-RPC error with a code OTHER than -32602 + // → generic "backend initialization failed" path. + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(mcpProxyServerConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}}) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`)) + + // JSON-RPC error with code != -32602. + host.CallOnHttpCall([][2]string{{":status", "200"}, {"content-type", "application/json"}}, + []byte(`{"jsonrpc":"2.0","id":1,"error":{"code":-32603,"message":"internal"}}`)) + + localResp := host.GetLocalResponse() + require.NotNil(t, localResp, "unknown error code must inject a local error response") + require.Contains(t, string(localResp.Data), "error") + host.CompleteHttp() + }) +} + +// ----------------------------------------------------------------------------- +// Phase 2.4 — plugin.go HOST entry edge cases (onHttpRequestHeaders) +// ----------------------------------------------------------------------------- + +func TestPlugin_GetMethodRejected(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(restMCPServerConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "x"}, {":method", "GET"}, {":path", "/mcp"}, + }) + require.Equal(t, types.HeaderStopAllIterationAndWatermark, action) + + localResp := host.GetLocalResponse() + require.NotNil(t, localResp, "GET must produce a local 405 response") + require.Equal(t, uint32(405), localResp.StatusCode) + host.CompleteHttp() + }) +} + +func TestPlugin_DeleteMethodRejected(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(restMCPServerConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "x"}, {":method", "DELETE"}, {":path", "/mcp"}, + }) + require.Equal(t, types.HeaderStopAllIterationAndWatermark, action) + + localResp := host.GetLocalResponse() + require.NotNil(t, localResp, "DELETE must produce a local 405 response") + require.Equal(t, uint32(405), localResp.StatusCode) + host.CompleteHttp() + }) +} + +// Note: the wasm-go test harness always reports a request body as present, so +// the "POST with no body → 400" branch is not exercisable here. The dedicated +// pure-test for ctx.HasRequestBody() handling lives in pkg/mcp/server tests. + +func TestPlugin_McpProtocolVersionHeaderStripped(t *testing.T) { + // MCP-Protocol-Version is parsed and removed; request continues normally. + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(restMCPServerConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, + {"content-type", "application/json"}, + {"MCP-Protocol-Version", "2025-03-26"}, + }) + // Valid version: header is consumed, request continues. + require.Equal(t, types.HeaderStopIteration, action) + // The header should have been removed before forwarding. + _, stillPresent := test.GetHeaderValue(host.GetRequestHeaders(), "MCP-Protocol-Version") + require.False(t, stillPresent, "MCP-Protocol-Version header must be stripped from forwarded request") + host.CompleteHttp() + }) +} + +func TestPlugin_McpProtocolVersionUnsupportedStillStripped(t *testing.T) { + // Unsupported version logs a warning but still strips the header and continues. + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(restMCPServerConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, + {"content-type", "application/json"}, + {"MCP-Protocol-Version", "9999-99-99"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + _, stillPresent := test.GetHeaderValue(host.GetRequestHeaders(), "MCP-Protocol-Version") + require.False(t, stillPresent, "even unsupported MCP-Protocol-Version must be stripped") + host.CompleteHttp() + }) +} + +func TestMcpProxyServerSSE_PartialChunkBuffered(t *testing.T) { + // SSE event arrives split across two chunks — first chunk has no terminator, + // second chunk completes the message. Proxy must wait, not error. + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(mcpProxyServerSSEConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{{":authority", "x"}, {":method", "POST"}, {":path", "/mcp"}, {"content-type", "application/json"}}) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`)) + + host.CallOnHttpResponseHeaders([][2]string{{":status", "200"}, {"content-type", "text/event-stream"}}) + + // First chunk: prefix only, no blank line terminator. + host.CallOnHttpStreamingResponseBody([]byte("event: endpoint\ndata: /sse/sess"), false) + + // No callouts yet — message not yet complete. + require.Empty(t, host.GetHttpCalloutAttributes(), + "incomplete chunk must not trigger any upstream calls") + + // Second chunk: completes data + blank line. + host.CallOnHttpStreamingResponseBody([]byte("ion-split\n\n"), false) + + // Now initialize should have been sent. + callouts := host.GetHttpCalloutAttributes() + require.NotEmpty(t, callouts, "complete endpoint event must trigger initialize") + host.CompleteHttp() + }) +} + +// ----------------------------------------------------------------------------- +// Phase 2.5 — ExtractAndRemoveIncomingCredential (HOST end-to-end) +// +// The pure function lives in pkg/mcp/server/auth_utils.go and is HOST-coupled +// because it calls proxywasm.GetHttpRequestHeader / RemoveHttpRequestHeader. +// We exercise it via the mcp-proxy tools/list path, which is the cleanest entry +// that hits ExtractAndRemoveIncomingCredential before any upstream call. +// +// Server shape: downstreamSecurity with Passthrough=true → the credential +// extracted from the *incoming* request is reused on the *upstream* request. +// Verifying the upstream callout headers proves both the extraction and the +// removal worked. +// ----------------------------------------------------------------------------- + +// mcpProxyPassthroughApiKeyConfig — downstream apiKey/header passthrough to +// upstream apiKey/header. The two schemes have different header names so we +// can independently assert (a) downstream header was removed and (b) upstream +// header carries the passthrough value. +var mcpProxyPassthroughApiKeyConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "server": map[string]interface{}{ + "name": "proxy-passthrough-apikey", + "type": "mcp-proxy", + "transport": "http", + "mcpServerURL": "http://backend-mcp.example.com/mcp", + "timeout": 5000, + "defaultDownstreamSecurity": map[string]interface{}{ + "id": "ClientKey", + "passthrough": true, + }, + "defaultUpstreamSecurity": map[string]interface{}{ + "id": "BackendKey", + }, + "securitySchemes": []map[string]interface{}{ + { + "id": "ClientKey", + "type": "apiKey", + "in": "header", + "name": "X-Client-Key", + }, + { + "id": "BackendKey", + "type": "apiKey", + "in": "header", + "name": "X-Backend-Key", + "defaultCredential": "fallback-default", + }, + }, + }, + "tools": []map[string]interface{}{ + {"name": "noop", "type": "mcp-proxy", "description": "noop"}, + }, + }) + return data +}() + +func TestExtractAndRemoveIncomingCredential_ApiKeyHeaderPassthrough(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(mcpProxyPassthroughApiKeyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "mcp.example.com"}, + {":method", "POST"}, + {":path", "/mcp"}, + {"content-type", "application/json"}, + {"X-Client-Key", "secret-from-client"}, + }) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`)) + + callouts := host.GetHttpCalloutAttributes() + require.NotEmpty(t, callouts, "tools/list must trigger an initialize callout") + + init := callouts[0] + // Passthrough credential rides on the upstream scheme's header name. + require.True(t, test.HasHeaderWithValue(init.Headers, "X-Backend-Key", "secret-from-client"), + "upstream initialize must carry the extracted client credential under the upstream header name") + + // The downstream header itself must NOT leak through to the upstream call. + if v, present := test.GetHeaderValue(init.Headers, "X-Client-Key"); present { + t.Errorf("downstream credential header X-Client-Key must be removed; got %q", v) + } + host.CompleteHttp() + }) +} + +// mcpProxyPassthroughBearerConfig — downstream http/bearer passthrough to +// upstream http/bearer. Stripping `Bearer ` from the incoming Authorization +// is part of ExtractAndRemoveIncomingCredential's contract. +var mcpProxyPassthroughBearerConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "server": map[string]interface{}{ + "name": "proxy-passthrough-bearer", + "type": "mcp-proxy", + "transport": "http", + "mcpServerURL": "http://backend-mcp.example.com/mcp", + "timeout": 5000, + "defaultDownstreamSecurity": map[string]interface{}{ + "id": "ClientBearer", + "passthrough": true, + }, + "defaultUpstreamSecurity": map[string]interface{}{ + "id": "BackendBearer", + }, + "securitySchemes": []map[string]interface{}{ + {"id": "ClientBearer", "type": "http", "scheme": "bearer"}, + {"id": "BackendBearer", "type": "http", "scheme": "bearer", "defaultCredential": "default-token"}, + }, + }, + "tools": []map[string]interface{}{ + {"name": "noop", "type": "mcp-proxy", "description": "noop"}, + }, + }) + return data +}() + +func TestExtractAndRemoveIncomingCredential_HttpBearerPassthrough(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(mcpProxyPassthroughBearerConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "mcp.example.com"}, + {":method", "POST"}, + {":path", "/mcp"}, + {"content-type", "application/json"}, + {"Authorization", "Bearer client-token-xyz"}, + }) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`)) + + callouts := host.GetHttpCalloutAttributes() + require.NotEmpty(t, callouts, "tools/list must trigger an initialize callout") + init := callouts[0] + + // Upstream Authorization must use the EXTRACTED token (Bearer prefix + // stripped on the way in, re-applied on the way out by ApplySecurity). + authValue, present := test.GetHeaderValue(init.Headers, "Authorization") + require.True(t, present, "upstream must carry Authorization for upstream bearer scheme") + require.Equal(t, "Bearer client-token-xyz", authValue, + "passthrough token must round-trip as `Bearer ` to upstream") + host.CompleteHttp() + }) +} + +// Missing downstream header — ExtractAndRemoveIncomingCredential returns "" +// (no error). With Passthrough=true but no incoming credential, passthrough +// is skipped and the upstream scheme falls back to its DefaultCredential. +func TestExtractAndRemoveIncomingCredential_MissingHeaderFallsBackToDefault(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + host, status := test.NewTestHost(mcpProxyPassthroughApiKeyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.InitHttp() + // Note: no X-Client-Key header on the way in. + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "mcp.example.com"}, + {":method", "POST"}, + {":path", "/mcp"}, + {"content-type", "application/json"}, + }) + host.CallOnHttpRequestBody([]byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`)) + + callouts := host.GetHttpCalloutAttributes() + require.NotEmpty(t, callouts, "tools/list must trigger an initialize callout") + init := callouts[0] + + // Missing client credential → not an error, just fall through to default. + require.True(t, test.HasHeaderWithValue(init.Headers, "X-Backend-Key", "fallback-default"), + "missing client credential must NOT cause an error; upstream falls back to DefaultCredential") + host.CompleteHttp() + }) +} diff --git a/plugins/wasm-go/pkg/mcp/go.mod b/plugins/wasm-go/pkg/mcp/go.mod index c4fbd3a16..3cc512849 100644 --- a/plugins/wasm-go/pkg/mcp/go.mod +++ b/plugins/wasm-go/pkg/mcp/go.mod @@ -10,6 +10,8 @@ require ( github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 + google.golang.org/protobuf v1.36.6 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -28,12 +30,9 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/shopspring/decimal v1.4.0 // indirect github.com/spf13/cast v1.7.0 // indirect - github.com/tetratelabs/wazero v1.7.2 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/resp v0.1.1 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect golang.org/x/crypto v0.26.0 // indirect - google.golang.org/protobuf v1.36.6 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/plugins/wasm-go/pkg/mcp/go.sum b/plugins/wasm-go/pkg/mcp/go.sum new file mode 100644 index 000000000..06b699b85 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/go.sum @@ -0,0 +1,73 @@ +dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= +dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= +github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= +github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= +github.com/Masterminds/semver/v3 v3.3.0 h1:B8LGeaivUe71a5qox1ICM/JLl0NqZSW5CHyL+hmvYS0= +github.com/Masterminds/semver/v3 v3.3.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= +github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe3tPhs= +github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b h1:rRI9+ThQbe+nw4jUiYEyOFaREkXCMMW9k1X2gy2d6pE= +github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA= +github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 h1:sUuUXZwr50l3W1St7MESlFmxmUAu+QUNNfJXx4P6bas= +github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8= +github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= +github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= +github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= +github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= +github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= +github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w= +github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= +github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/wasm-go/pkg/mcp/server/auth_utils_test.go b/plugins/wasm-go/pkg/mcp/server/auth_utils_test.go new file mode 100644 index 000000000..5f05cdd5f --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/auth_utils_test.go @@ -0,0 +1,432 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "encoding/base64" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// stubProvider implements SecuritySchemeProvider for ApplySecurity tests. +type stubProvider struct { + schemes map[string]SecurityScheme +} + +func (p *stubProvider) GetSecurityScheme(id string) (SecurityScheme, bool) { + s, ok := p.schemes[id] + return s, ok +} + +func newProvider(schemes ...SecurityScheme) *stubProvider { + m := make(map[string]SecurityScheme, len(schemes)) + for _, s := range schemes { + m[s.ID] = s + } + return &stubProvider{schemes: m} +} + +// mustParseURL helps build the ParsedURL field of AuthRequestContext. +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + u, err := url.Parse(raw) + require.NoError(t, err) + return u +} + +func findHeader(headers [][2]string, key string) (string, bool) { + for _, kv := range headers { + if strings.EqualFold(kv[0], key) { + return kv[1], true + } + } + return "", false +} + +func countHeader(headers [][2]string, key string) int { + c := 0 + for _, kv := range headers { + if strings.EqualFold(kv[0], key) { + c++ + } + } + return c +} + +// ----------------------------------------------------------------------------- +// setOrReplaceHeader +// ----------------------------------------------------------------------------- + +func TestSetOrReplaceHeader_AppendsWhenAbsent(t *testing.T) { + headers := [][2]string{{"X-Other", "1"}} + setOrReplaceHeader(&headers, "X-New", "v") + require.Len(t, headers, 2) + v, ok := findHeader(headers, "X-New") + require.True(t, ok) + assert.Equal(t, "v", v) +} + +func TestSetOrReplaceHeader_ReplacesCaseInsensitively(t *testing.T) { + headers := [][2]string{ + {"Content-Type", "text/plain"}, + {"AUTHORIZATION", "old"}, + } + setOrReplaceHeader(&headers, "authorization", "new") + v, ok := findHeader(headers, "Authorization") + require.True(t, ok) + assert.Equal(t, "new", v) + // Replacement is in-place, no duplicate header inserted. + assert.Equal(t, 1, countHeader(headers, "Authorization")) + assert.Len(t, headers, 2) +} + +func TestSetOrReplaceHeader_PreservesOriginalKeyOnReplace(t *testing.T) { + headers := [][2]string{{"X-Token", "old"}} + setOrReplaceHeader(&headers, "x-token", "new") + // Replacement updates value but keeps the original key casing. + assert.Equal(t, [][2]string{{"X-Token", "new"}}, headers) +} + +func TestSetOrReplaceHeader_FirstMatchWins(t *testing.T) { + headers := [][2]string{ + {"X-Dup", "first"}, + {"x-dup", "second"}, + } + setOrReplaceHeader(&headers, "X-Dup", "new") + // Only the first occurrence is replaced; the second is left alone. + assert.Equal(t, [][2]string{{"X-Dup", "new"}, {"x-dup", "second"}}, headers) +} + +func TestSetOrReplaceHeader_IdempotentOnSecondCall(t *testing.T) { + headers := [][2]string{} + setOrReplaceHeader(&headers, "X-K", "v") + setOrReplaceHeader(&headers, "X-K", "v") + require.Len(t, headers, 1) + assert.Equal(t, "v", headers[0][1]) +} + +// ----------------------------------------------------------------------------- +// ApplySecurity — early returns / preconditions +// ----------------------------------------------------------------------------- + +func TestApplySecurity_EmptyIDIsNoOp(t *testing.T) { + reqCtx := &AuthRequestContext{ + Headers: [][2]string{{"X-Other", "x"}}, + ParsedURL: mustParseURL(t, "/p?a=1"), + } + err := ApplySecurity(SecurityRequirement{}, newProvider(), reqCtx) + require.NoError(t, err) + assert.Equal(t, [][2]string{{"X-Other", "x"}}, reqCtx.Headers) + assert.Equal(t, "a=1", reqCtx.ParsedURL.RawQuery) +} + +func TestApplySecurity_NilParsedURLReturnsError(t *testing.T) { + reqCtx := &AuthRequestContext{} + err := ApplySecurity( + SecurityRequirement{ID: "x"}, + newProvider(SecurityScheme{ID: "x", Type: "apiKey", In: "header", Name: "X"}), + reqCtx, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "ParsedURL") +} + +func TestApplySecurity_SchemeIDNotFound(t *testing.T) { + reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")} + err := ApplySecurity(SecurityRequirement{ID: "missing"}, newProvider(), reqCtx) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +// ----------------------------------------------------------------------------- +// ApplySecurity — apiKey × {header, query} +// ----------------------------------------------------------------------------- + +func TestApplySecurity_ApiKey_Header_DefaultCredential(t *testing.T) { + reqCtx := &AuthRequestContext{ + Headers: [][2]string{{"X-Other", "x"}}, + ParsedURL: mustParseURL(t, "/p"), + } + err := ApplySecurity( + SecurityRequirement{ID: "K"}, + newProvider(SecurityScheme{ + ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key", + DefaultCredential: "def", + }), + reqCtx, + ) + require.NoError(t, err) + v, ok := findHeader(reqCtx.Headers, "X-Api-Key") + require.True(t, ok) + assert.Equal(t, "def", v) +} + +func TestApplySecurity_ApiKey_Header_ExplicitOverridesDefault(t *testing.T) { + reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")} + err := ApplySecurity( + SecurityRequirement{ID: "K", Credential: "override"}, + newProvider(SecurityScheme{ + ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key", + DefaultCredential: "def", + }), + reqCtx, + ) + require.NoError(t, err) + v, _ := findHeader(reqCtx.Headers, "X-Api-Key") + assert.Equal(t, "override", v) +} + +func TestApplySecurity_ApiKey_Header_PassthroughBeatsExplicitAndDefault(t *testing.T) { + reqCtx := &AuthRequestContext{ + ParsedURL: mustParseURL(t, "/p"), + PassthroughCredential: "from-client", + } + err := ApplySecurity( + SecurityRequirement{ID: "K", Credential: "configured"}, + newProvider(SecurityScheme{ + ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key", + DefaultCredential: "def", + }), + reqCtx, + ) + require.NoError(t, err) + v, _ := findHeader(reqCtx.Headers, "X-Api-Key") + assert.Equal(t, "from-client", v, "passthrough wins over configured + default") +} + +func TestApplySecurity_ApiKey_Header_ReplacesExisting(t *testing.T) { + reqCtx := &AuthRequestContext{ + Headers: [][2]string{{"x-api-key", "stale"}}, + ParsedURL: mustParseURL(t, "/p"), + } + err := ApplySecurity( + SecurityRequirement{ID: "K", Credential: "fresh"}, + newProvider(SecurityScheme{ + ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key", + }), + reqCtx, + ) + require.NoError(t, err) + // Case-insensitive replace, no duplicate header. + assert.Equal(t, 1, countHeader(reqCtx.Headers, "X-Api-Key")) + v, _ := findHeader(reqCtx.Headers, "X-Api-Key") + assert.Equal(t, "fresh", v) +} + +func TestApplySecurity_ApiKey_NoCredentialAvailable(t *testing.T) { + reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")} + err := ApplySecurity( + SecurityRequirement{ID: "K"}, + newProvider(SecurityScheme{ID: "K", Type: "apiKey", In: "header", Name: "X"}), + reqCtx, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "no credential") +} + +func TestApplySecurity_ApiKey_Header_MissingName(t *testing.T) { + reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")} + err := ApplySecurity( + SecurityRequirement{ID: "K", Credential: "v"}, + newProvider(SecurityScheme{ID: "K", Type: "apiKey", In: "header", Name: ""}), + reqCtx, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "name") +} + +func TestApplySecurity_ApiKey_Query_AppendsToExistingQuery(t *testing.T) { + reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p?existing=1")} + err := ApplySecurity( + SecurityRequirement{ID: "K", Credential: "secret"}, + newProvider(SecurityScheme{ID: "K", Type: "apiKey", In: "query", Name: "api_key"}), + reqCtx, + ) + require.NoError(t, err) + q := reqCtx.ParsedURL.Query() + assert.Equal(t, "secret", q.Get("api_key")) + assert.Equal(t, "1", q.Get("existing"), "existing query params must be preserved") +} + +func TestApplySecurity_ApiKey_Query_NoExistingQuery(t *testing.T) { + reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")} + err := ApplySecurity( + SecurityRequirement{ID: "K", Credential: "secret"}, + newProvider(SecurityScheme{ID: "K", Type: "apiKey", In: "query", Name: "api_key"}), + reqCtx, + ) + require.NoError(t, err) + assert.Equal(t, "api_key=secret", reqCtx.ParsedURL.RawQuery) +} + +func TestApplySecurity_ApiKey_Query_OverwritesExistingValue(t *testing.T) { + reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p?api_key=stale&keep=me")} + err := ApplySecurity( + SecurityRequirement{ID: "K", Credential: "fresh"}, + newProvider(SecurityScheme{ID: "K", Type: "apiKey", In: "query", Name: "api_key"}), + reqCtx, + ) + require.NoError(t, err) + q := reqCtx.ParsedURL.Query() + assert.Equal(t, "fresh", q.Get("api_key")) + assert.Equal(t, "me", q.Get("keep")) + // Sanity: no duplicate api_key entries. + assert.Len(t, q["api_key"], 1) +} + +func TestApplySecurity_ApiKey_Query_MissingName(t *testing.T) { + reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")} + err := ApplySecurity( + SecurityRequirement{ID: "K", Credential: "v"}, + newProvider(SecurityScheme{ID: "K", Type: "apiKey", In: "query", Name: ""}), + reqCtx, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "name") +} + +func TestApplySecurity_ApiKey_UnsupportedIn(t *testing.T) { + reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")} + err := ApplySecurity( + SecurityRequirement{ID: "K", Credential: "v"}, + newProvider(SecurityScheme{ID: "K", Type: "apiKey", In: "cookie", Name: "X"}), + reqCtx, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported apiKey") +} + +// ----------------------------------------------------------------------------- +// ApplySecurity — http × {bearer, basic} +// ----------------------------------------------------------------------------- + +func TestApplySecurity_HttpBearer_AddsPrefix(t *testing.T) { + reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")} + err := ApplySecurity( + SecurityRequirement{ID: "B", Credential: "raw-token"}, + newProvider(SecurityScheme{ID: "B", Type: "http", Scheme: "bearer"}), + reqCtx, + ) + require.NoError(t, err) + v, _ := findHeader(reqCtx.Headers, "Authorization") + assert.Equal(t, "Bearer raw-token", v) +} + +func TestApplySecurity_HttpBearer_RespectsExistingPrefix(t *testing.T) { + reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")} + err := ApplySecurity( + SecurityRequirement{ID: "B", Credential: "Bearer already-prefixed"}, + newProvider(SecurityScheme{ID: "B", Type: "http", Scheme: "bearer"}), + reqCtx, + ) + require.NoError(t, err) + v, _ := findHeader(reqCtx.Headers, "Authorization") + assert.Equal(t, "Bearer already-prefixed", v, "must not double-prefix") +} + +func TestApplySecurity_HttpBasic_UserPassEncoded(t *testing.T) { + reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")} + err := ApplySecurity( + SecurityRequirement{ID: "B", Credential: "alice:s3cret"}, + newProvider(SecurityScheme{ID: "B", Type: "http", Scheme: "basic"}), + reqCtx, + ) + require.NoError(t, err) + v, _ := findHeader(reqCtx.Headers, "Authorization") + expected := "Basic " + base64.StdEncoding.EncodeToString([]byte("alice:s3cret")) + assert.Equal(t, expected, v) +} + +func TestApplySecurity_HttpBasic_PreEncodedToken(t *testing.T) { + // No colon → treated as already-base64 token. + reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")} + err := ApplySecurity( + SecurityRequirement{ID: "B", Credential: "QWxpY2U6czNjcmV0"}, + newProvider(SecurityScheme{ID: "B", Type: "http", Scheme: "basic"}), + reqCtx, + ) + require.NoError(t, err) + v, _ := findHeader(reqCtx.Headers, "Authorization") + assert.Equal(t, "Basic QWxpY2U6czNjcmV0", v) +} + +func TestApplySecurity_HttpBasic_RespectsExistingPrefix(t *testing.T) { + reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")} + err := ApplySecurity( + SecurityRequirement{ID: "B", Credential: "Basic ZXhpc3Rpbmc="}, + newProvider(SecurityScheme{ID: "B", Type: "http", Scheme: "basic"}), + reqCtx, + ) + require.NoError(t, err) + v, _ := findHeader(reqCtx.Headers, "Authorization") + assert.Equal(t, "Basic ZXhpc3Rpbmc=", v, "must not re-encode already-prefixed value") +} + +func TestApplySecurity_HttpBasic_PassthroughTreatedAsTokenPart(t *testing.T) { + reqCtx := &AuthRequestContext{ + ParsedURL: mustParseURL(t, "/p"), + PassthroughCredential: "QWxpY2U6czNjcmV0", // base64-encoded "alice:s3cret" + } + err := ApplySecurity( + SecurityRequirement{ID: "B"}, + newProvider(SecurityScheme{ID: "B", Type: "http", Scheme: "basic"}), + reqCtx, + ) + require.NoError(t, err) + v, _ := findHeader(reqCtx.Headers, "Authorization") + // Passthrough path must NOT re-base64-encode; only adds the prefix. + assert.Equal(t, "Basic QWxpY2U6czNjcmV0", v) +} + +func TestApplySecurity_HttpBearer_Passthrough(t *testing.T) { + reqCtx := &AuthRequestContext{ + ParsedURL: mustParseURL(t, "/p"), + PassthroughCredential: "client-token", + } + err := ApplySecurity( + SecurityRequirement{ID: "B"}, + newProvider(SecurityScheme{ID: "B", Type: "http", Scheme: "bearer"}), + reqCtx, + ) + require.NoError(t, err) + v, _ := findHeader(reqCtx.Headers, "Authorization") + assert.Equal(t, "Bearer client-token", v) +} + +func TestApplySecurity_HttpUnsupportedScheme(t *testing.T) { + reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")} + err := ApplySecurity( + SecurityRequirement{ID: "B", Credential: "x"}, + newProvider(SecurityScheme{ID: "B", Type: "http", Scheme: "digest"}), + reqCtx, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported http scheme") +} + +func TestApplySecurity_UnsupportedSchemeType(t *testing.T) { + reqCtx := &AuthRequestContext{ParsedURL: mustParseURL(t, "/p")} + err := ApplySecurity( + SecurityRequirement{ID: "B", Credential: "x"}, + newProvider(SecurityScheme{ID: "B", Type: "oauth2"}), + reqCtx, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported security scheme type") +} diff --git a/plugins/wasm-go/pkg/mcp/server/composed_server_test.go b/plugins/wasm-go/pkg/mcp/server/composed_server_test.go new file mode 100644 index 000000000..b11da8d66 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/composed_server_test.go @@ -0,0 +1,251 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// stubTool is a minimal Tool implementation for registry population in tests. +type stubTool struct { + desc string + input map[string]any + output map[string]any +} + +func (s *stubTool) Create(_ []byte) Tool { return s } +func (s *stubTool) Call(_ HttpContext, _ Server) error { return nil } +func (s *stubTool) Description() string { return s.desc } +func (s *stubTool) InputSchema() map[string]any { return s.input } +func (s *stubTool) OutputSchema() map[string]any { return s.output } + +func newPopulatedRegistry(t *testing.T) *GlobalToolRegistry { + t.Helper() + r := &GlobalToolRegistry{} + r.Initialize() + r.RegisterTool("alpha", "search", &stubTool{ + desc: "alpha search", + input: map[string]any{"type": "object", "props": "a"}, + output: map[string]any{"type": "string"}, + }) + r.RegisterTool("alpha", "fetch", &stubTool{ + desc: "alpha fetch", + input: map[string]any{"type": "object", "props": "f"}, + output: nil, + }) + r.RegisterTool("beta", "search", &stubTool{ + desc: "beta search", + input: map[string]any{"type": "object", "props": "bs"}, + output: map[string]any{"type": "array"}, + }) + return r +} + +func TestComposedMCPServer_NewAndGetName(t *testing.T) { + r := newPopulatedRegistry(t) + cs := NewComposedMCPServer("myset", []ServerToolConfig{ + {ServerName: "alpha", Tools: []string{"search"}}, + }, r) + require.NotNil(t, cs) + assert.Equal(t, "myset", cs.GetName()) +} + +func TestComposedMCPServer_AddMCPTool_IsNoOp(t *testing.T) { + r := newPopulatedRegistry(t) + cs := NewComposedMCPServer("set", []ServerToolConfig{ + {ServerName: "alpha", Tools: []string{"search"}}, + }, r) + + // AddMCPTool should not panic and should be a no-op (tool not added). + ret := cs.AddMCPTool("ignored", &stubTool{desc: "x"}) + assert.Same(t, cs, ret, "AddMCPTool should return the server itself") + + tools := cs.GetMCPTools() + _, exists := tools["ignored"] + assert.False(t, exists, "no-op AddMCPTool must not register the tool") + // Only the one from registry should remain. + _, found := tools["alpha___search"] + assert.True(t, found, "registered tool should be present") + assert.Len(t, tools, 1) +} + +func TestComposedMCPServer_GetMCPTools_AggregatesWithPrefix(t *testing.T) { + r := newPopulatedRegistry(t) + cs := NewComposedMCPServer("compound", []ServerToolConfig{ + {ServerName: "alpha", Tools: []string{"search", "fetch"}}, + {ServerName: "beta", Tools: []string{"search"}}, + }, r) + + tools := cs.GetMCPTools() + require.Len(t, tools, 3) + + // All keys must be prefixed with the original server name and the splitter. + want := []string{"alpha___search", "alpha___fetch", "beta___search"} + for _, k := range want { + _, ok := tools[k] + assert.True(t, ok, "expected composed tool key %q", k) + } + + // Descriptions / input schemas are forwarded from the registry's ToolInfo. + dt, ok := tools["alpha___search"].(*DescriptiveTool) + require.True(t, ok) + assert.Equal(t, "alpha search", dt.Description()) + assert.Equal(t, "a", dt.InputSchema()["props"]) + assert.Equal(t, "string", dt.OutputSchema()["type"]) + + // Tool without OutputSchema in registry produces a DescriptiveTool with nil output. + dt2, ok := tools["alpha___fetch"].(*DescriptiveTool) + require.True(t, ok) + assert.Nil(t, dt2.OutputSchema()) +} + +func TestComposedMCPServer_GetMCPTools_MissingToolIsSkipped(t *testing.T) { + r := newPopulatedRegistry(t) + cs := NewComposedMCPServer("set", []ServerToolConfig{ + {ServerName: "alpha", Tools: []string{"search", "nonexistent"}}, + {ServerName: "ghost", Tools: []string{"any"}}, // entire server missing + }, r) + + tools := cs.GetMCPTools() + // Only "alpha___search" survives; missing ones are logged and skipped. + assert.Len(t, tools, 1) + _, ok := tools["alpha___search"] + assert.True(t, ok) +} + +func TestComposedMCPServer_GetMCPTools_SameSimpleNameDifferentServersDoNotCollide(t *testing.T) { + r := newPopulatedRegistry(t) + cs := NewComposedMCPServer("set", []ServerToolConfig{ + {ServerName: "alpha", Tools: []string{"search"}}, + {ServerName: "beta", Tools: []string{"search"}}, + }, r) + + tools := cs.GetMCPTools() + require.Len(t, tools, 2) + assert.Contains(t, tools, "alpha___search") + assert.Contains(t, tools, "beta___search") +} + +func TestComposedMCPServer_GetMCPTools_EmptyConfig(t *testing.T) { + r := newPopulatedRegistry(t) + cs := NewComposedMCPServer("empty", nil, r) + tools := cs.GetMCPTools() + assert.NotNil(t, tools, "should return a non-nil empty map") + assert.Empty(t, tools) +} + +func TestComposedMCPServer_SetGetConfig_BytePointer(t *testing.T) { + r := newPopulatedRegistry(t) + cs := NewComposedMCPServer("set", nil, r) + + // Empty config: GetConfig must not modify the destination. + var dst []byte + dst = []byte("untouched") + cs.GetConfig(&dst) + assert.Equal(t, []byte("untouched"), dst, "GetConfig on empty config must be a no-op") + + // After SetConfig, byte-pointer destinations receive the stored bytes. + cs.SetConfig([]byte(`{"k":"v"}`)) + var out []byte + cs.GetConfig(&out) + assert.Equal(t, []byte(`{"k":"v"}`), out) +} + +func TestComposedMCPServer_GetConfig_UnhandledDestinationType(t *testing.T) { + r := newPopulatedRegistry(t) + cs := NewComposedMCPServer("set", nil, r) + cs.SetConfig([]byte(`{"k":"v"}`)) + + // Non-byte-pointer destinations are logged and left untouched (no panic). + var s string = "untouched" + cs.GetConfig(&s) + assert.Equal(t, "untouched", s) + + type holder struct{ K string } + h := holder{K: "untouched"} + cs.GetConfig(&h) + assert.Equal(t, "untouched", h.K) +} + +func TestComposedMCPServer_Clone_IndependentConfig(t *testing.T) { + r := newPopulatedRegistry(t) + cs := NewComposedMCPServer("orig", []ServerToolConfig{ + {ServerName: "alpha", Tools: []string{"search"}}, + }, r) + cs.SetConfig([]byte(`{"a":1}`)) + + clonedI := cs.Clone() + require.NotNil(t, clonedI) + cloned, ok := clonedI.(*ComposedMCPServer) + require.True(t, ok) + assert.NotSame(t, cs, cloned, "Clone must return a new struct pointer") + assert.Equal(t, cs.GetName(), cloned.GetName()) + + // Confirm both see the same config initially. + var origBytes, clonedBytes []byte + cs.GetConfig(&origBytes) + cloned.GetConfig(&clonedBytes) + assert.Equal(t, origBytes, clonedBytes) + + // Mutating clone's config must not propagate to original. + cloned.SetConfig([]byte(`{"a":2}`)) + cs.GetConfig(&origBytes) + assert.Equal(t, []byte(`{"a":1}`), origBytes, "original config must remain unchanged after cloning") + + // Cloned still resolves tools through the shared registry. + assert.Contains(t, cloned.GetMCPTools(), "alpha___search") +} + +func TestDescriptiveTool_Create_ReturnsNewInstanceWithSameFields(t *testing.T) { + dt := &DescriptiveTool{ + description: "d", + inputSchema: map[string]any{"k": "v"}, + outputSchema: map[string]any{"o": "w"}, + } + created := dt.Create([]byte(`{"ignored":true}`)) + require.NotNil(t, created) + cdt, ok := created.(*DescriptiveTool) + require.True(t, ok) + assert.NotSame(t, dt, cdt, "Create must return a new instance") + assert.Equal(t, dt.Description(), cdt.Description()) + assert.Equal(t, dt.InputSchema(), cdt.InputSchema()) + assert.Equal(t, dt.OutputSchema(), cdt.OutputSchema()) +} + +func TestDescriptiveTool_Call_ReturnsError(t *testing.T) { + dt := &DescriptiveTool{description: "d"} + err := dt.Call(nil, nil) + require.Error(t, err, "DescriptiveTool.Call is a guard rail — must return an error") +} + +func TestDescriptiveTool_Accessors(t *testing.T) { + dt := &DescriptiveTool{ + description: "desc", + inputSchema: map[string]any{"in": 1}, + outputSchema: map[string]any{"out": 2}, + } + assert.Equal(t, "desc", dt.Description()) + assert.Equal(t, map[string]any{"in": 1}, dt.InputSchema()) + assert.Equal(t, map[string]any{"out": 2}, dt.OutputSchema()) + + // Nil schemas must round-trip as nil. + empty := &DescriptiveTool{} + assert.Equal(t, "", empty.Description()) + assert.Nil(t, empty.InputSchema()) + assert.Nil(t, empty.OutputSchema()) +} diff --git a/plugins/wasm-go/pkg/mcp/server/plugin_test.go b/plugins/wasm-go/pkg/mcp/server/plugin_test.go new file mode 100644 index 000000000..a0ea33c78 --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/plugin_test.go @@ -0,0 +1,620 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +// ----------------------------------------------------------------------------- +// validateURL +// ----------------------------------------------------------------------------- + +func TestValidateURL(t *testing.T) { + cases := []struct { + name string + in string + wantErr string // empty = expect no error + }{ + {"empty string", "", "cannot be empty"}, + {"path only", "/api/foo", ""}, + {"http with host", "http://backend.example/mcp", ""}, + {"https with host", "https://backend.example/mcp", ""}, + {"http with userinfo", "http://user:pass@backend.example/mcp", ""}, + {"http with port", "http://backend.example:8080/mcp", ""}, + {"scheme without host", "http://", "must include a host"}, + {"unsupported scheme ftp", "ftp://example/x", "unsupported URL scheme"}, + {"unsupported scheme ws", "ws://example/x", "unsupported URL scheme"}, + {"contains space - parse error", "http://exa mple/x", "invalid URL format"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + err := validateURL(c.in) + if c.wantErr == "" { + assert.NoError(t, err) + } else { + require.Error(t, err) + assert.Contains(t, err.Error(), c.wantErr) + } + }) + } +} + +// ----------------------------------------------------------------------------- +// computeEffectiveAllowToolsFromHeader (4-way matrix + edge cases) +// ----------------------------------------------------------------------------- + +func TestComputeEffectiveAllowToolsFromHeader_BothAbsent(t *testing.T) { + got := computeEffectiveAllowToolsFromHeader(nil, "", false) + assert.Nil(t, got, "no restrictions on either side → allow all (nil)") +} + +func TestComputeEffectiveAllowToolsFromHeader_HeaderOnly(t *testing.T) { + got := computeEffectiveAllowToolsFromHeader(nil, "a,b,c", true) + require.NotNil(t, got) + assert.Len(t, *got, 3) + _, hasA := (*got)["a"] + _, hasB := (*got)["b"] + _, hasC := (*got)["c"] + assert.True(t, hasA && hasB && hasC) +} + +func TestComputeEffectiveAllowToolsFromHeader_ConfigOnly(t *testing.T) { + cfg := map[string]struct{}{"x": {}, "y": {}} + got := computeEffectiveAllowToolsFromHeader(&cfg, "", false) + require.NotNil(t, got) + assert.Equal(t, &cfg, got, "config restrictions returned as-is when header absent") +} + +func TestComputeEffectiveAllowToolsFromHeader_BothPresent_Intersection(t *testing.T) { + cfg := map[string]struct{}{"a": {}, "b": {}, "c": {}} + got := computeEffectiveAllowToolsFromHeader(&cfg, "b,c,d", true) + require.NotNil(t, got) + assert.Len(t, *got, 2) + _, hasB := (*got)["b"] + _, hasC := (*got)["c"] + _, hasD := (*got)["d"] + _, hasA := (*got)["a"] + assert.True(t, hasB && hasC, "intersection keeps common entries") + assert.False(t, hasD, "header-only entries are dropped") + assert.False(t, hasA, "config-only entries are dropped") +} + +func TestComputeEffectiveAllowToolsFromHeader_HeaderEmptyStringButPresent(t *testing.T) { + // headerExists=true with empty string → produces empty map (deny all) + cfg := map[string]struct{}{"x": {}} + got := computeEffectiveAllowToolsFromHeader(&cfg, "", true) + require.NotNil(t, got) + assert.Empty(t, *got, "empty header with headerExists=true intersects to empty set") +} + +func TestComputeEffectiveAllowToolsFromHeader_HeaderWhitespaceAndDuplicates(t *testing.T) { + got := computeEffectiveAllowToolsFromHeader(nil, " a , b , ,a, c ,", true) + require.NotNil(t, got) + // "a", "b", "c" — duplicates and empties dropped + assert.Len(t, *got, 3) + for _, k := range []string{"a", "b", "c"} { + _, ok := (*got)[k] + assert.True(t, ok, "missing %q", k) + } +} + +// ----------------------------------------------------------------------------- +// McpServerConfig accessors +// ----------------------------------------------------------------------------- + +func TestMcpServerConfig_GetServerName(t *testing.T) { + c := &McpServerConfig{serverName: "my-server"} + assert.Equal(t, "my-server", c.GetServerName()) +} + +func TestMcpServerConfig_GetIsComposed(t *testing.T) { + c1 := &McpServerConfig{isComposed: false} + c2 := &McpServerConfig{isComposed: true} + assert.False(t, c1.GetIsComposed()) + assert.True(t, c2.GetIsComposed()) +} + +// ----------------------------------------------------------------------------- +// GlobalToolRegistry — extra branches not covered elsewhere +// ----------------------------------------------------------------------------- + +// stubToolWithOutputSchema exercises the ToolWithOutputSchema dispatch in +// GlobalToolRegistry.RegisterTool. +type stubToolWithOutputSchema struct { + desc string + input map[string]any + output map[string]any +} + +func (s *stubToolWithOutputSchema) Create(_ []byte) Tool { return s } +func (s *stubToolWithOutputSchema) Call(_ HttpContext, _ Server) error { return nil } +func (s *stubToolWithOutputSchema) Description() string { return s.desc } +func (s *stubToolWithOutputSchema) InputSchema() map[string]any { return s.input } +func (s *stubToolWithOutputSchema) OutputSchema() map[string]any { return s.output } + +func TestGlobalToolRegistry_RegisterTool_CapturesOutputSchema(t *testing.T) { + r := &GlobalToolRegistry{} + r.Initialize() + r.RegisterTool("srv", "tool", &stubToolWithOutputSchema{ + desc: "d", + input: map[string]any{"in": 1}, + output: map[string]any{"out": 2}, + }) + info, ok := r.GetToolInfo("srv", "tool") + require.True(t, ok) + assert.Equal(t, "d", info.Description) + assert.Equal(t, map[string]any{"in": 1}, info.InputSchema) + assert.Equal(t, map[string]any{"out": 2}, info.OutputSchema, "OutputSchema must be captured when tool implements ToolWithOutputSchema") +} + +func TestGlobalToolRegistry_RegisterTool_PlainToolHasNoOutputSchema(t *testing.T) { + r := &GlobalToolRegistry{} + r.Initialize() + r.RegisterTool("srv", "tool", &stubTool{desc: "d", input: map[string]any{"in": 1}}) + info, ok := r.GetToolInfo("srv", "tool") + require.True(t, ok) + assert.Nil(t, info.OutputSchema) +} + +func TestGlobalToolRegistry_GetToolInfo_Misses(t *testing.T) { + r := &GlobalToolRegistry{} + r.Initialize() + _, ok := r.GetToolInfo("missing", "any") + assert.False(t, ok, "unknown server → not found") + + r.RegisterTool("srv", "real", &stubTool{desc: "d"}) + _, ok = r.GetToolInfo("srv", "missing") + assert.False(t, ok, "unknown tool on known server → not found") +} + +// ----------------------------------------------------------------------------- +// AddMCPServer / addMCPServerOption.Apply +// ----------------------------------------------------------------------------- + +func TestAddMCPServer_FirstAndSecondAreStored(t *testing.T) { + ctx := &Context{} + AddMCPServer("alpha", &stubServer{}).Apply(ctx) + AddMCPServer("beta", &stubServer{}).Apply(ctx) + require.Len(t, ctx.servers, 2) + _, hasA := ctx.servers["alpha"] + _, hasB := ctx.servers["beta"] + assert.True(t, hasA && hasB) +} + +func TestAddMCPServer_DuplicateNamePanics(t *testing.T) { + ctx := &Context{} + AddMCPServer("dup", &stubServer{}).Apply(ctx) + assert.PanicsWithValue(t, + "Conflict! There is a mcp server with the same name:dup", + func() { AddMCPServer("dup", &stubServer{}).Apply(ctx) }) +} + +// stubServer satisfies the Server interface for AddMCPServer tests. +type stubServer struct { + cfg []byte +} + +func (s *stubServer) AddMCPTool(_ string, _ Tool) Server { return s } +func (s *stubServer) GetMCPTools() map[string]Tool { return map[string]Tool{} } +func (s *stubServer) SetConfig(c []byte) { s.cfg = c } +func (s *stubServer) GetConfig(_ any) {} +func (s *stubServer) Clone() Server { copy := *s; return © } + +// ----------------------------------------------------------------------------- +// ToInputSchema — exercises jsonschema.Reflect dispatch +// ----------------------------------------------------------------------------- + +type sampleStruct struct { + Name string `json:"name"` + Count int `json:"count"` + Tags []string `json:"tags"` +} + +func TestToInputSchema_StructByValue(t *testing.T) { + out := ToInputSchema(sampleStruct{}) + require.NotNil(t, out, "must return a populated schema map") + // Reflected schema always has a top-level "properties" map. + props, ok := out["properties"].(map[string]any) + require.True(t, ok, "schema should have a properties object: %v", out) + for _, k := range []string{"name", "count", "tags"} { + _, exists := props[k] + assert.True(t, exists, "missing property %q", k) + } +} + +func TestToInputSchema_StructByPointer(t *testing.T) { + // The function dereferences pointer types before name lookup. + out := ToInputSchema(&sampleStruct{}) + require.NotNil(t, out) + _, ok := out["properties"] + assert.True(t, ok, "pointer-to-struct should resolve to the same schema") +} + +// ----------------------------------------------------------------------------- +// setupMcpProxyServer — error paths +// ----------------------------------------------------------------------------- + +func mustGJSON(t *testing.T, raw string) gjson.Result { + t.Helper() + r := gjson.Parse(raw) + require.True(t, r.Exists(), "raw must parse: %s", raw) + return r +} + +func TestSetupMcpProxyServer_MissingTransport(t *testing.T) { + j := mustGJSON(t, `{"name":"s","type":"mcp-proxy","mcpServerURL":"http://b"}`) + _, err := setupMcpProxyServer("s", j, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "transport") +} + +func TestSetupMcpProxyServer_InvalidTransport(t *testing.T) { + j := mustGJSON(t, `{"transport":"grpc","mcpServerURL":"http://b"}`) + _, err := setupMcpProxyServer("s", j, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid transport value") +} + +func TestSetupMcpProxyServer_MissingMcpServerURL(t *testing.T) { + j := mustGJSON(t, `{"transport":"http"}`) + _, err := setupMcpProxyServer("s", j, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "mcpServerURL is required") +} + +func TestSetupMcpProxyServer_InvalidMcpServerURL(t *testing.T) { + j := mustGJSON(t, `{"transport":"http","mcpServerURL":"ws://nope"}`) + _, err := setupMcpProxyServer("s", j, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid mcpServerURL") +} + +func TestSetupMcpProxyServer_BadSecuritySchemeJson(t *testing.T) { + j := mustGJSON(t, `{ + "transport":"http", + "mcpServerURL":"http://b", + "securitySchemes":[123] + }`) + _, err := setupMcpProxyServer("s", j, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "security scheme") +} + +func TestSetupMcpProxyServer_BadDefaultDownstreamSecurity(t *testing.T) { + j := mustGJSON(t, `{ + "transport":"http", + "mcpServerURL":"http://b", + "defaultDownstreamSecurity": 42 + }`) + _, err := setupMcpProxyServer("s", j, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "defaultDownstreamSecurity") +} + +func TestSetupMcpProxyServer_BadDefaultUpstreamSecurity(t *testing.T) { + j := mustGJSON(t, `{ + "transport":"http", + "mcpServerURL":"http://b", + "defaultUpstreamSecurity": "not-an-object" + }`) + _, err := setupMcpProxyServer("s", j, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "defaultUpstreamSecurity") +} + +func TestSetupMcpProxyServer_HappyPath_AppliesAllFields(t *testing.T) { + raw := `{ + "transport":"sse", + "mcpServerURL":"https://backend.example/mcp", + "timeout":7777, + "passthroughAuthHeader":true, + "securitySchemes":[ + {"id":"K","type":"apiKey","in":"header","name":"X-K","defaultCredential":"d"} + ], + "defaultDownstreamSecurity":{"id":"K"}, + "defaultUpstreamSecurity":{"id":"K"} + }` + srv, err := setupMcpProxyServer("alpha", mustGJSON(t, raw), `{"cfg":1}`) + require.NoError(t, err) + require.NotNil(t, srv) + + assert.Equal(t, "alpha", srv.Name) + assert.Equal(t, TransportSSE, srv.GetTransport()) + assert.Equal(t, "https://backend.example/mcp", srv.GetMcpServerURL()) + assert.Equal(t, 7777, srv.GetTimeout()) + assert.True(t, srv.GetPassthroughAuthHeader()) + _, ok := srv.GetSecurityScheme("K") + assert.True(t, ok) + assert.Equal(t, "K", srv.GetDefaultDownstreamSecurity().ID) + assert.Equal(t, "K", srv.GetDefaultUpstreamSecurity().ID) +} + +// ----------------------------------------------------------------------------- +// parseConfigCore — error / branch paths via ParseConfigCore +// ----------------------------------------------------------------------------- + +func newValidationOpts() *ConfigOptions { + r := &GlobalToolRegistry{} + r.Initialize() + return &ConfigOptions{ + Servers: make(map[string]Server), + ToolRegistry: r, + SkipPreRegisteredServers: false, + } +} + +func TestParseConfigCore_NoServerOrToolSet(t *testing.T) { + c := &McpServerConfig{} + err := ParseConfigCore(gjson.Parse(`{}`), c, newValidationOpts()) + require.Error(t, err) + assert.Contains(t, err.Error(), "'server' or 'toolSet'") +} + +func TestParseConfigCore_SingleServer_MissingName(t *testing.T) { + c := &McpServerConfig{} + err := ParseConfigCore(gjson.Parse(`{"server":{"type":"rest"}}`), c, newValidationOpts()) + require.Error(t, err) + assert.Contains(t, err.Error(), "server.name") +} + +func TestParseConfigCore_PreRegisteredNotInRegistry(t *testing.T) { + // type=="" defaults to "rest", but with no tools and no entry in opts.Servers, + // falls into the "pre-registered" branch which fails with "not registered". + c := &McpServerConfig{} + err := ParseConfigCore(gjson.Parse(`{"server":{"name":"ghost"}}`), c, newValidationOpts()) + require.Error(t, err) + assert.Contains(t, err.Error(), "not registered") +} + +func TestParseConfigCore_PreRegisteredSkipped(t *testing.T) { + c := &McpServerConfig{} + opts := newValidationOpts() + opts.SkipPreRegisteredServers = true + err := ParseConfigCore(gjson.Parse(`{"server":{"name":"ghost"}}`), c, opts) + require.NoError(t, err, "skip flag should bypass the not-registered error") + assert.Equal(t, "ghost", c.GetServerName()) + assert.Nil(t, c.server, "no server instance is constructed in skip mode") +} + +func TestParseConfigCore_PreRegisteredFound(t *testing.T) { + c := &McpServerConfig{} + opts := newValidationOpts() + opts.Servers["found"] = &stubServer{} + err := ParseConfigCore(gjson.Parse(`{"server":{"name":"found"}}`), c, opts) + require.NoError(t, err) + require.NotNil(t, c.server, "Clone() of pre-registered server should be stored") +} + +func TestParseConfigCore_McpProxy_BubblesUpSetupError(t *testing.T) { + c := &McpServerConfig{} + err := ParseConfigCore(gjson.Parse(`{ + "server":{"name":"p","type":"mcp-proxy","transport":"http"} + }`), c, newValidationOpts()) + require.Error(t, err) + assert.Contains(t, err.Error(), "mcpServerURL") +} + +func TestParseConfigCore_McpProxy_HappyPath_NoTools(t *testing.T) { + c := &McpServerConfig{} + err := ParseConfigCore(gjson.Parse(`{ + "server":{ + "name":"p", + "type":"mcp-proxy", + "transport":"http", + "mcpServerURL":"http://b" + } + }`), c, newValidationOpts()) + require.NoError(t, err) + require.NotNil(t, c.server) + assert.Equal(t, "p", c.GetServerName()) + assert.False(t, c.GetIsComposed()) + // Method handlers are populated for all servers. + assert.NotNil(t, c.methodHandlers["ping"]) + assert.NotNil(t, c.methodHandlers["initialize"]) + assert.NotNil(t, c.methodHandlers["tools/list"]) + assert.NotNil(t, c.methodHandlers["tools/call"]) +} + +func TestParseConfigCore_McpProxy_BadProxyToolJson(t *testing.T) { + c := &McpServerConfig{} + err := ParseConfigCore(gjson.Parse(`{ + "server":{ + "name":"p", + "type":"mcp-proxy", + "transport":"http", + "mcpServerURL":"http://b" + }, + "tools":[42] + }`), c, newValidationOpts()) + require.Error(t, err) + assert.Contains(t, err.Error(), "proxy tool") +} + +func TestParseConfigCore_REST_BadToolJson(t *testing.T) { + c := &McpServerConfig{} + err := ParseConfigCore(gjson.Parse(`{ + "server":{"name":"r"}, + "tools":[42] + }`), c, newValidationOpts()) + require.Error(t, err) + assert.Contains(t, err.Error(), "tool config") +} + +func TestParseConfigCore_REST_BadSecuritySchemeJson(t *testing.T) { + c := &McpServerConfig{} + err := ParseConfigCore(gjson.Parse(`{ + "server":{ + "name":"r", + "securitySchemes":[42] + }, + "tools":[ + {"name":"t","description":"d","requestTemplate":{"url":"/x","method":"GET"}} + ] + }`), c, newValidationOpts()) + require.Error(t, err) + assert.Contains(t, err.Error(), "security scheme") +} + +func TestParseConfigCore_REST_BadDefaultDownstreamSecurity(t *testing.T) { + c := &McpServerConfig{} + err := ParseConfigCore(gjson.Parse(`{ + "server":{ + "name":"r", + "defaultDownstreamSecurity":42 + }, + "tools":[ + {"name":"t","description":"d","requestTemplate":{"url":"/x","method":"GET"}} + ] + }`), c, newValidationOpts()) + require.Error(t, err) + assert.Contains(t, err.Error(), "defaultDownstreamSecurity") +} + +func TestParseConfigCore_REST_BadDefaultUpstreamSecurity(t *testing.T) { + c := &McpServerConfig{} + err := ParseConfigCore(gjson.Parse(`{ + "server":{ + "name":"r", + "defaultUpstreamSecurity":"oops" + }, + "tools":[ + {"name":"t","description":"d","requestTemplate":{"url":"/x","method":"GET"}} + ] + }`), c, newValidationOpts()) + require.Error(t, err) + assert.Contains(t, err.Error(), "defaultUpstreamSecurity") +} + +func TestParseConfigCore_REST_AddRestToolError(t *testing.T) { + // Setting two of {argsToJsonBody, argsToUrlParam, argsToFormBody} bubbles + // the parseTemplates error out through AddRestTool. + c := &McpServerConfig{} + err := ParseConfigCore(gjson.Parse(`{ + "server":{"name":"r"}, + "tools":[ + { + "name":"bad", + "description":"d", + "requestTemplate":{ + "url":"/x", + "method":"POST", + "argsToJsonBody":true, + "argsToFormBody":true + } + } + ] + }`), c, newValidationOpts()) + require.Error(t, err) + assert.Contains(t, err.Error(), "argsTo") +} + +func TestParseConfigCore_REST_HappyPath_AllowToolsParsed(t *testing.T) { + c := &McpServerConfig{} + err := ParseConfigCore(gjson.Parse(`{ + "server":{"name":"r"}, + "tools":[ + {"name":"t1","description":"d","requestTemplate":{"url":"/x","method":"GET"}}, + {"name":"t2","description":"d","requestTemplate":{"url":"/y","method":"GET"}} + ], + "allowTools":["t1"] + }`), c, newValidationOpts()) + require.NoError(t, err) + require.NotNil(t, c.server) + assert.Equal(t, "r", c.GetServerName()) + assert.False(t, c.GetIsComposed()) +} + +func TestParseConfigCore_ToolSet_BadJson(t *testing.T) { + c := &McpServerConfig{} + err := ParseConfigCore(gjson.Parse(`{"toolSet":42}`), c, newValidationOpts()) + require.Error(t, err) + assert.Contains(t, err.Error(), "toolSet") +} + +func TestParseConfigCore_ToolSet_HappyPath(t *testing.T) { + opts := newValidationOpts() + // Register one tool so the composed server has something to aggregate. + opts.ToolRegistry.RegisterTool("alpha", "search", &stubTool{ + desc: "alpha search", + input: map[string]any{"type": "object"}, + }) + c := &McpServerConfig{} + err := ParseConfigCore(gjson.Parse(`{ + "toolSet":{ + "name":"compound", + "serverTools":[{"serverName":"alpha","tools":["search"]}] + } + }`), c, opts) + require.NoError(t, err) + assert.True(t, c.GetIsComposed(), "toolSet must produce a composed server") + assert.Equal(t, "compound", c.GetServerName(), "composed server uses toolSet.name as serverName") + require.NotNil(t, c.server) +} + +// ----------------------------------------------------------------------------- +// GetServerFromGlobalContext — exercises the package-level singleton +// ----------------------------------------------------------------------------- + +func TestGetServerFromGlobalContext(t *testing.T) { + // Snapshot and restore globalContext to keep tests independent. + saved := globalContext + defer func() { globalContext = saved }() + globalContext = Context{servers: map[string]Server{ + "existing": &stubServer{}, + }} + + got, ok := GetServerFromGlobalContext("existing") + require.True(t, ok) + assert.NotNil(t, got) + + _, miss := GetServerFromGlobalContext("missing") + assert.False(t, miss) +} + +// ----------------------------------------------------------------------------- +// BaseMCPServer.Clone / CloneBase +// ----------------------------------------------------------------------------- + +func TestBaseMCPServer_Clone_PanicsByContract(t *testing.T) { + // Derived types must implement Clone; BaseMCPServer panics to enforce that. + b := NewBaseMCPServer() + assert.PanicsWithValue(t, + "Clone method must be implemented by derived types", + func() { _ = b.Clone() }) +} + +func TestBaseMCPServer_CloneBase_DeepCopiesTools(t *testing.T) { + b := NewBaseMCPServer() + b.SetConfig([]byte(`{"k":1}`)) + stub := &stubTool{desc: "t"} + b.AddMCPTool("a", stub) + + cloned := b.CloneBase() + + // Mutating clone's tools must not bleed into the original. + cloned.AddMCPTool("b", &stubTool{desc: "x"}) + assert.Len(t, b.GetMCPTools(), 1, "original tools must remain untouched") + assert.Len(t, cloned.GetMCPTools(), 2) + + // Same config bytes are preserved. + got, ok := cloned.GetMCPTools()["a"] + require.True(t, ok) + assert.Same(t, stub, got, "existing tools are shared by reference (no deep clone of Tool itself)") +} diff --git a/plugins/wasm-go/pkg/mcp/server/proxy_server_test.go b/plugins/wasm-go/pkg/mcp/server/proxy_server_test.go index 6c7079e2e..85efd8f58 100644 --- a/plugins/wasm-go/pkg/mcp/server/proxy_server_test.go +++ b/plugins/wasm-go/pkg/mcp/server/proxy_server_test.go @@ -18,6 +18,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TestMcpProxyServerBasicInterface tests that McpProxyServer implements the Server interface @@ -110,3 +111,329 @@ func TestMcpProxyServerSecuritySchemes(t *testing.T) { assert.Equal(t, scheme.ID, retrievedScheme.ID) assert.Equal(t, scheme.Type, retrievedScheme.Type) } + +// ----------------------------------------------------------------------------- +// SetDefaultDownstreamSecurity / SetDefaultUpstreamSecurity / PassthroughAuth +// ----------------------------------------------------------------------------- + +func TestMcpProxyServer_SetGetDefaultDownstreamSecurity(t *testing.T) { + s := NewMcpProxyServer("p") + assert.Equal(t, "", s.GetDefaultDownstreamSecurity().ID, "fresh server has empty default") + s.SetDefaultDownstreamSecurity(SecurityRequirement{ID: "K", Passthrough: true}) + got := s.GetDefaultDownstreamSecurity() + assert.Equal(t, "K", got.ID) + assert.True(t, got.Passthrough) +} + +func TestMcpProxyServer_SetGetDefaultUpstreamSecurity(t *testing.T) { + s := NewMcpProxyServer("p") + s.SetDefaultUpstreamSecurity(SecurityRequirement{ID: "U", Credential: "c"}) + got := s.GetDefaultUpstreamSecurity() + assert.Equal(t, "U", got.ID) + assert.Equal(t, "c", got.Credential) +} + +func TestMcpProxyServer_PassthroughAuthHeaderGetterAndSetter(t *testing.T) { + s := NewMcpProxyServer("p") + assert.False(t, s.GetPassthroughAuthHeader(), "default is false") + s.SetPassthroughAuthHeader(true) + assert.True(t, s.GetPassthroughAuthHeader()) + s.SetPassthroughAuthHeader(false) + assert.False(t, s.GetPassthroughAuthHeader()) +} + +// ----------------------------------------------------------------------------- +// AddSecurityScheme — nil-map branch +// ----------------------------------------------------------------------------- + +func TestMcpProxyServer_AddSecurityScheme_InitializesNilMap(t *testing.T) { + // Skip the constructor so we can hit the `securitySchemes == nil` branch. + s := &McpProxyServer{Name: "p"} + s.AddSecurityScheme(SecurityScheme{ID: "K", Type: "apiKey", In: "header", Name: "X"}) + got, ok := s.GetSecurityScheme("K") + require.True(t, ok) + assert.Equal(t, "K", got.ID) +} + +// ----------------------------------------------------------------------------- +// AddMCPTool — delegates to BaseMCPServer +// ----------------------------------------------------------------------------- + +func TestMcpProxyServer_AddMCPTool_StoresInBaseAndReturnsSelf(t *testing.T) { + s := NewMcpProxyServer("p") + stub := &stubTool{desc: "d"} + ret := s.AddMCPTool("custom", stub) + assert.Same(t, s, ret, "AddMCPTool returns receiver for chaining") + tools := s.GetMCPTools() + got, ok := tools["custom"] + require.True(t, ok) + assert.Same(t, stub, got) +} + +// ----------------------------------------------------------------------------- +// AddProxyTool — overrides on duplicate name +// ----------------------------------------------------------------------------- + +func TestMcpProxyServer_AddProxyTool_DuplicateNameOverwrites(t *testing.T) { + s := NewMcpProxyServer("p") + require.NoError(t, s.AddProxyTool(McpProxyToolConfig{Name: "t", Description: "first"})) + require.NoError(t, s.AddProxyTool(McpProxyToolConfig{Name: "t", Description: "second"})) + + tools := s.GetMCPTools() + assert.Len(t, tools, 1, "duplicate AddProxyTool should overwrite, not duplicate") + cfg, ok := s.GetToolConfig("t") + require.True(t, ok) + assert.Equal(t, "second", cfg.Description, "later AddProxyTool wins") +} + +// ----------------------------------------------------------------------------- +// GetToolConfig — hit and miss +// ----------------------------------------------------------------------------- + +func TestMcpProxyServer_GetToolConfig_HitAndMiss(t *testing.T) { + s := NewMcpProxyServer("p") + require.NoError(t, s.AddProxyTool(McpProxyToolConfig{Name: "t", Description: "d"})) + + cfg, ok := s.GetToolConfig("t") + require.True(t, ok) + assert.Equal(t, "d", cfg.Description) + + _, missOK := s.GetToolConfig("missing") + assert.False(t, missOK) +} + +// ----------------------------------------------------------------------------- +// Clone — deep copy of toolsConfig and securitySchemes +// ----------------------------------------------------------------------------- + +func TestMcpProxyServer_Clone_DeepCopiesToolsConfigAndSchemes(t *testing.T) { + orig := NewMcpProxyServer("orig") + orig.SetMcpServerURL("http://b") + orig.SetTimeout(1234) + orig.SetTransport(TransportSSE) + orig.SetPassthroughAuthHeader(true) + orig.SetDefaultDownstreamSecurity(SecurityRequirement{ID: "K"}) + orig.AddSecurityScheme(SecurityScheme{ID: "K", Type: "apiKey", In: "header", Name: "X"}) + require.NoError(t, orig.AddProxyTool(McpProxyToolConfig{Name: "t", Description: "d"})) + + clonedI := orig.Clone() + cloned, ok := clonedI.(*McpProxyServer) + require.True(t, ok) + require.NotSame(t, orig, cloned, "Clone must return a fresh struct") + + // Surface fields are copied. + assert.Equal(t, orig.Name, cloned.Name) + // NOTE: Clone does not propagate mcpServerURL/timeout/transport/passthrough + // nor defaultDownstream/upstreamSecurity. That is intentional today (see + // proxy_server.go:188): cloning is used for per-request isolation of + // tool/security registries only. This test pins that contract — if Clone + // starts copying those fields, update here and document the change. + assert.Equal(t, "", cloned.GetMcpServerURL()) + + // toolsConfig: deep copy — adding to clone doesn't bleed back to orig. + require.NoError(t, cloned.AddProxyTool(McpProxyToolConfig{Name: "extra", Description: "x"})) + _, origHasExtra := orig.GetToolConfig("extra") + assert.False(t, origHasExtra, "tool added to clone must not appear in original") + + // securitySchemes: deep copy — replacing scheme on clone doesn't touch orig. + cloned.AddSecurityScheme(SecurityScheme{ID: "K", Type: "http", Scheme: "bearer"}) + origScheme, _ := orig.GetSecurityScheme("K") + clonedScheme, _ := cloned.GetSecurityScheme("K") + assert.Equal(t, "apiKey", origScheme.Type, "original scheme must remain apiKey") + assert.Equal(t, "http", clonedScheme.Type, "clone reflects the override") +} + +// ----------------------------------------------------------------------------- +// McpProxyTool — Description / InputSchema / OutputSchema / Create +// ----------------------------------------------------------------------------- + +func TestMcpProxyTool_DescriptionAndOutputSchema(t *testing.T) { + tool := &McpProxyTool{ + toolConfig: McpProxyToolConfig{ + Description: "describe me", + OutputSchema: map[string]any{"type": "string"}, + }, + } + assert.Equal(t, "describe me", tool.Description()) + assert.Equal(t, map[string]any{"type": "string"}, tool.OutputSchema()) +} + +func TestMcpProxyTool_InputSchema_RequiredAndOptionalAndEnumAndDefault(t *testing.T) { + tool := &McpProxyTool{ + toolConfig: McpProxyToolConfig{ + Args: []ToolArg{ + {Name: "must", Type: "string", Description: "required", Required: true}, + {Name: "opt", Type: "integer", Description: "optional", Default: 7}, + {Name: "pick", Type: "string", Description: "enum", Enum: []interface{}{"a", "b"}}, + }, + }, + } + schema := tool.InputSchema() + assert.Equal(t, "object", schema["type"]) + required, ok := schema["required"].([]string) + require.True(t, ok) + assert.Equal(t, []string{"must"}, required, "only Required:true args land in required[]") + + props := schema["properties"].(map[string]any) + mustProp := props["must"].(map[string]any) + assert.Equal(t, "string", mustProp["type"]) + + optProp := props["opt"].(map[string]any) + assert.Equal(t, 7, optProp["default"]) + + pickProp := props["pick"].(map[string]any) + assert.Equal(t, []interface{}{"a", "b"}, pickProp["enum"]) +} + +func TestMcpProxyTool_InputSchema_NoArgs(t *testing.T) { + tool := &McpProxyTool{toolConfig: McpProxyToolConfig{}} + schema := tool.InputSchema() + props, ok := schema["properties"].(map[string]any) + require.True(t, ok) + assert.Empty(t, props) + required, ok := schema["required"].([]string) + require.True(t, ok) + assert.Empty(t, required) +} + +func TestMcpProxyTool_Create_NewInstanceWithBoundArgs(t *testing.T) { + orig := &McpProxyTool{ + serverName: "srv", + name: "t", + toolConfig: McpProxyToolConfig{Name: "t"}, + } + created := orig.Create([]byte(`{"q":"hello","n":7}`)) + require.NotSame(t, orig, created, "Create returns a fresh instance") + cloned := created.(*McpProxyTool) + assert.Equal(t, "srv", cloned.serverName) + assert.Equal(t, "t", cloned.name) + assert.Equal(t, "hello", cloned.arguments["q"]) + // JSON unmarshals numbers as float64. + assert.Equal(t, float64(7), cloned.arguments["n"]) +} + +func TestMcpProxyTool_Create_EmptyParamsStillReturnsInstance(t *testing.T) { + orig := &McpProxyTool{serverName: "s", name: "t"} + created := orig.Create(nil) + cloned := created.(*McpProxyTool) + assert.Equal(t, "s", cloned.serverName) + assert.Equal(t, "t", cloned.name) + require.NotNil(t, cloned.arguments) + assert.Empty(t, cloned.arguments, "no params → empty arguments map, not nil") +} + +func TestMcpProxyTool_Create_MalformedJSON_PreservesEmptyArgs(t *testing.T) { + orig := &McpProxyTool{serverName: "s", name: "t"} + created := orig.Create([]byte(`{not json`)) + cloned := created.(*McpProxyTool) + // json.Unmarshal silently fails; arguments stays empty. + assert.Empty(t, cloned.arguments) +} + +// ----------------------------------------------------------------------------- +// ValidateSecurityScheme — full matrix +// ----------------------------------------------------------------------------- + +func TestValidateSecurityScheme(t *testing.T) { + cases := []struct { + name string + scheme SecurityScheme + wantErr string + }{ + {"missing ID", SecurityScheme{Type: "apiKey", In: "header", Name: "X"}, "ID is required"}, + {"invalid type", SecurityScheme{ID: "k", Type: "oauth2"}, "invalid security scheme type"}, + {"apiKey missing name", SecurityScheme{ID: "k", Type: "apiKey", In: "header"}, "name is required"}, + {"apiKey invalid in", SecurityScheme{ID: "k", Type: "apiKey", In: "body", Name: "X"}, "invalid security scheme location"}, + {"apiKey ok header", SecurityScheme{ID: "k", Type: "apiKey", In: "header", Name: "X"}, ""}, + {"apiKey ok query", SecurityScheme{ID: "k", Type: "apiKey", In: "query", Name: "X"}, ""}, + {"apiKey ok cookie", SecurityScheme{ID: "k", Type: "apiKey", In: "cookie", Name: "X"}, ""}, + {"http missing scheme", SecurityScheme{ID: "k", Type: "http"}, "scheme is required for http"}, + {"http bearer ok", SecurityScheme{ID: "k", Type: "http", Scheme: "bearer"}, ""}, + {"http basic ok", SecurityScheme{ID: "k", Type: "http", Scheme: "basic"}, ""}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + err := ValidateSecurityScheme(c.scheme) + if c.wantErr == "" { + assert.NoError(t, err) + } else { + require.Error(t, err) + assert.Contains(t, err.Error(), c.wantErr) + } + }) + } +} + +// ----------------------------------------------------------------------------- +// ValidateToolConfig — full matrix +// ----------------------------------------------------------------------------- + +func TestValidateToolConfig(t *testing.T) { + cases := []struct { + name string + config McpProxyToolConfig + wantErr string + }{ + { + "missing name", + McpProxyToolConfig{Description: "d"}, + "tool name is required", + }, + { + "missing description", + McpProxyToolConfig{Name: "t"}, + "tool description is required", + }, + { + "arg missing name", + McpProxyToolConfig{Name: "t", Description: "d", Args: []ToolArg{{Type: "string", Description: "x"}}}, + "argument name is required", + }, + { + "arg duplicate names", + McpProxyToolConfig{Name: "t", Description: "d", Args: []ToolArg{ + {Name: "a", Type: "string", Description: "x"}, + {Name: "a", Type: "string", Description: "y"}, + }}, + "duplicate argument name", + }, + { + "arg missing description", + McpProxyToolConfig{Name: "t", Description: "d", Args: []ToolArg{{Name: "a", Type: "string"}}}, + "argument description is required", + }, + { + "arg invalid type", + McpProxyToolConfig{Name: "t", Description: "d", Args: []ToolArg{{Name: "a", Type: "money", Description: "x"}}}, + "invalid argument type", + }, + { + "happy path with multiple typed args", + McpProxyToolConfig{Name: "t", Description: "d", Args: []ToolArg{ + {Name: "s", Type: "string", Description: "x"}, + {Name: "n", Type: "number", Description: "x"}, + {Name: "i", Type: "integer", Description: "x"}, + {Name: "b", Type: "boolean", Description: "x"}, + {Name: "a", Type: "array", Description: "x"}, + {Name: "o", Type: "object", Description: "x"}, + }}, + "", + }, + { + "happy path no args", + McpProxyToolConfig{Name: "t", Description: "d"}, + "", + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + err := ValidateToolConfig(c.config) + if c.wantErr == "" { + assert.NoError(t, err) + } else { + require.Error(t, err) + assert.Contains(t, err.Error(), c.wantErr) + } + }) + } +} diff --git a/plugins/wasm-go/pkg/mcp/server/proxy_tool_pure_test.go b/plugins/wasm-go/pkg/mcp/server/proxy_tool_pure_test.go new file mode 100644 index 000000000..fa65f998c --- /dev/null +++ b/plugins/wasm-go/pkg/mcp/server/proxy_tool_pure_test.go @@ -0,0 +1,323 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ----------------------------------------------------------------------------- +// NewMcpProtocolHandler +// ----------------------------------------------------------------------------- + +func TestNewMcpProtocolHandler(t *testing.T) { + h := NewMcpProtocolHandler("http://backend.example/mcp", 5000) + require.NotNil(t, h) + assert.Equal(t, "http://backend.example/mcp", h.backendURL) + assert.Equal(t, 5000, h.timeout) + assert.Empty(t, h.sessionID, "fresh handler has no session id until Initialize runs") +} + +// ----------------------------------------------------------------------------- +// parseSSEResponse — fill the remaining branches +// ----------------------------------------------------------------------------- + +func TestParseSSEResponse_OnlyCommentsAndBlanks(t *testing.T) { + // All non-data lines → must surface "no data field found". + _, err := parseSSEResponse([]byte(": only a comment\n\n: another\n")) + require.Error(t, err) + assert.Contains(t, err.Error(), "no data field") +} + +func TestParseSSEResponse_TooLongLine(t *testing.T) { + // Single data line larger than the scanner's 32MB max-token cap. + big := strings.Repeat("x", 33*1024*1024) + _, err := parseSSEResponse([]byte("data: " + big + "\n\n")) + require.Error(t, err) + assert.Contains(t, err.Error(), "32MB", "must surface the max-token overflow as a clear error") +} + +func TestParseSSEResponse_MultipleDataLinesReturnsFirst(t *testing.T) { + body := "data: first\n\ndata: second\n\n" + out, err := parseSSEResponse([]byte(body)) + require.NoError(t, err) + assert.Equal(t, "first", string(out), "the function returns the first data line and stops") +} + +// ----------------------------------------------------------------------------- +// createInitializeRequest / createToolsListRequest / createToolsCallRequest +// ----------------------------------------------------------------------------- + +func TestCreateInitializeRequest_StableShape(t *testing.T) { + h := NewMcpProtocolHandler("http://backend.example/mcp", 5000) + req := h.createInitializeRequest() + + assert.Equal(t, "2.0", req["jsonrpc"]) + assert.Equal(t, 1, req["id"]) + assert.Equal(t, "initialize", req["method"]) + + params, ok := req["params"].(map[string]interface{}) + require.True(t, ok) + assert.Equal(t, "2025-03-26", params["protocolVersion"]) + + clientInfo, ok := params["clientInfo"].(map[string]interface{}) + require.True(t, ok) + assert.Equal(t, "Higress-mcp-proxy", clientInfo["name"]) + assert.Equal(t, "1.0.0", clientInfo["version"]) + + _, hasCaps := params["capabilities"] + assert.True(t, hasCaps) +} + +func TestCreateToolsListRequest_NoCursor(t *testing.T) { + h := NewMcpProtocolHandler("http://backend.example/mcp", 5000) + req := h.createToolsListRequest(nil) + + assert.Equal(t, "2.0", req["jsonrpc"]) + assert.Equal(t, 2, req["id"]) + assert.Equal(t, "tools/list", req["method"]) + + params, ok := req["params"].(map[string]interface{}) + require.True(t, ok) + _, hasCursor := params["cursor"] + assert.False(t, hasCursor, "nil cursor must produce no cursor field") +} + +func TestCreateToolsListRequest_EmptyStringCursor(t *testing.T) { + h := NewMcpProtocolHandler("http://backend.example/mcp", 5000) + empty := "" + req := h.createToolsListRequest(&empty) + + params := req["params"].(map[string]interface{}) + _, hasCursor := params["cursor"] + assert.False(t, hasCursor, "empty-string cursor is treated as absent") +} + +func TestCreateToolsListRequest_WithCursor(t *testing.T) { + h := NewMcpProtocolHandler("http://backend.example/mcp", 5000) + c := "next-page" + req := h.createToolsListRequest(&c) + + params := req["params"].(map[string]interface{}) + assert.Equal(t, "next-page", params["cursor"]) +} + +func TestCreateToolsCallRequest_StableShape(t *testing.T) { + h := NewMcpProtocolHandler("http://backend.example/mcp", 5000) + args := map[string]interface{}{"q": "hello", "limit": 5} + req := h.createToolsCallRequest("search", args) + + assert.Equal(t, "2.0", req["jsonrpc"]) + assert.Equal(t, 3, req["id"]) + assert.Equal(t, "tools/call", req["method"]) + + params, ok := req["params"].(map[string]interface{}) + require.True(t, ok) + assert.Equal(t, "search", params["name"]) + gotArgs, ok := params["arguments"].(map[string]interface{}) + require.True(t, ok) + assert.Equal(t, "hello", gotArgs["q"]) + assert.Equal(t, 5, gotArgs["limit"]) +} + +func TestCreateToolsCallRequest_NilArguments(t *testing.T) { + h := NewMcpProtocolHandler("http://backend.example/mcp", 5000) + req := h.createToolsCallRequest("noop", nil) + params := req["params"].(map[string]interface{}) + assert.Equal(t, "noop", params["name"]) + args, ok := params["arguments"] + require.True(t, ok) + assert.Nil(t, args) +} + +// ----------------------------------------------------------------------------- +// ParseBackendResponse / IsBackendError — extra branches +// ----------------------------------------------------------------------------- + +func TestParseBackendResponse_StringErrorField(t *testing.T) { + // JSON-RPC error field doesn't have to be an object — anything truthy works. + body := []byte(`{"jsonrpc":"2.0","id":1,"error":"some-text"}`) + resp, isErr, etype := ParseBackendResponse(body) + require.NotNil(t, resp) + assert.True(t, isErr) + assert.Equal(t, "jsonrpc_error", etype) +} + +func TestParseBackendResponse_NoResultNoError(t *testing.T) { + // Valid JSON without result/error → not an error, but still parsed. + body := []byte(`{"jsonrpc":"2.0","id":2}`) + resp, isErr, etype := ParseBackendResponse(body) + require.NotNil(t, resp) + assert.False(t, isErr) + assert.Empty(t, etype) +} + +func TestParseBackendResponse_ResultIsErrorFalseNotAnError(t *testing.T) { + body := []byte(`{"jsonrpc":"2.0","id":3,"result":{"isError":false}}`) + _, isErr, etype := ParseBackendResponse(body) + assert.False(t, isErr) + assert.Empty(t, etype) +} + +func TestParseBackendResponse_ResultNotAnObject(t *testing.T) { + // result is a scalar — the isError-extraction branch is skipped. + body := []byte(`{"jsonrpc":"2.0","id":3,"result":"ok"}`) + _, isErr, etype := ParseBackendResponse(body) + assert.False(t, isErr) + assert.Empty(t, etype) +} + +func TestIsBackendError_DelegatesToParse(t *testing.T) { + cases := []struct { + body string + isError bool + etype string + }{ + {`{"error":{"code":-1}}`, true, "jsonrpc_error"}, + {`{"result":{"isError":true}}`, true, "result_isError"}, + {`{"result":{"isError":false}}`, false, ""}, + {`not json`, false, ""}, + } + for _, c := range cases { + isErr, etype := IsBackendError([]byte(c.body)) + assert.Equal(t, c.isError, isErr, "body=%s", c.body) + assert.Equal(t, c.etype, etype, "body=%s", c.body) + } +} + +// ----------------------------------------------------------------------------- +// McpSessionManagerImpl +// ----------------------------------------------------------------------------- + +func TestNewMcpSessionManagerImpl(t *testing.T) { + m := NewMcpSessionManagerImpl() + require.NotNil(t, m) + require.NotNil(t, m.sessions) + assert.Empty(t, m.sessions) +} + +func TestSessionManager_CreateAndGet(t *testing.T) { + m := NewMcpSessionManagerImpl() + id, err := m.CreateSession("http://backend.example/mcp") + require.NoError(t, err) + assert.True(t, strings.HasPrefix(id, "mcp-session-")) + + session, ok := m.GetSession(id) + require.True(t, ok) + assert.Equal(t, id, session.ID) + assert.Equal(t, "http://backend.example/mcp", session.BackendURL) + assert.False(t, session.CreatedAt.IsZero()) +} + +func TestSessionManager_GetSessionUpdatesLastUsed(t *testing.T) { + m := NewMcpSessionManagerImpl() + id, err := m.CreateSession("http://b") + require.NoError(t, err) + + // Force a measurable gap so LastUsed changes monotonically. + original := m.sessions[id].LastUsed + time.Sleep(2 * time.Millisecond) + + s, ok := m.GetSession(id) + require.True(t, ok) + assert.True(t, s.LastUsed.After(original), "GetSession should refresh LastUsed") +} + +func TestSessionManager_GetUnknownSession(t *testing.T) { + m := NewMcpSessionManagerImpl() + s, ok := m.GetSession("missing") + assert.False(t, ok) + assert.Nil(t, s) +} + +func TestSessionManager_CleanupSession_Existing(t *testing.T) { + m := NewMcpSessionManagerImpl() + id, _ := m.CreateSession("http://b") + m.CleanupSession(id) + _, ok := m.GetSession(id) + assert.False(t, ok) +} + +func TestSessionManager_CleanupSession_NonExistent(t *testing.T) { + m := NewMcpSessionManagerImpl() + // Must not panic on unknown id. + m.CleanupSession("never-existed") + assert.Empty(t, m.sessions) +} + +func TestSessionManager_CleanupExpiredSessions(t *testing.T) { + m := NewMcpSessionManagerImpl() + fresh, _ := m.CreateSession("fresh") + stale, _ := m.CreateSession("stale") + + // Backdate the stale session. + m.sessions[stale].LastUsed = time.Now().Add(-10 * time.Minute) + + m.CleanupExpiredSessions(1 * time.Minute) + + _, freshOk := m.sessions[fresh] + _, staleOk := m.sessions[stale] + assert.True(t, freshOk, "fresh session must remain") + assert.False(t, staleOk, "stale session must be removed") +} + +func TestSessionManager_CleanupExpiredSessions_EmptyMap(t *testing.T) { + m := NewMcpSessionManagerImpl() + // Must not panic on empty manager. + m.CleanupExpiredSessions(1 * time.Second) + assert.Empty(t, m.sessions) +} + +func TestSessionManager_CreateSessionsAreUnique(t *testing.T) { + m := NewMcpSessionManagerImpl() + id1, _ := m.CreateSession("http://b") + // Guarantee a different nanosecond timestamp. + time.Sleep(1 * time.Millisecond) + id2, _ := m.CreateSession("http://b") + assert.NotEqual(t, id1, id2, "session IDs should be unique") +} + +// ----------------------------------------------------------------------------- +// ensureHeader +// ----------------------------------------------------------------------------- + +func TestEnsureHeader_AddsWhenMissing(t *testing.T) { + headers := [][2]string{{"X-Other", "v"}} + ensureHeader(&headers, "X-New", "value") + require.Len(t, headers, 2) + assert.Equal(t, [2]string{"X-New", "value"}, headers[1]) +} + +func TestEnsureHeader_ReplacesCaseInsensitively(t *testing.T) { + headers := [][2]string{{"content-type", "text/plain"}} + ensureHeader(&headers, "Content-Type", "application/json") + require.Len(t, headers, 1) + // Replace path rewrites the original casing too. + assert.Equal(t, "Content-Type", headers[0][0]) + assert.Equal(t, "application/json", headers[0][1]) +} + +func TestEnsureHeader_NoDuplicateOnRepeatedCalls(t *testing.T) { + headers := [][2]string{} + ensureHeader(&headers, "X-K", "1") + ensureHeader(&headers, "X-K", "2") + require.Len(t, headers, 1) + assert.Equal(t, "2", headers[0][1]) +} diff --git a/plugins/wasm-go/pkg/mcp/server/rest_server_test.go b/plugins/wasm-go/pkg/mcp/server/rest_server_test.go index 6f77c08b5..db4c0dfe1 100644 --- a/plugins/wasm-go/pkg/mcp/server/rest_server_test.go +++ b/plugins/wasm-go/pkg/mcp/server/rest_server_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tidwall/sjson" ) @@ -920,3 +921,377 @@ func TestRestServerSecurityFallback(t *testing.T) { t.Logf("REST server security fallback test completed successfully") } + +// --------------------------------------------------------------------------- +// parseIP +// --------------------------------------------------------------------------- + +func TestParseIP(t *testing.T) { + cases := []struct { + name string + source string + fromHeader bool + want string + }{ + {"ipv4 only", "10.0.0.1", false, "10.0.0.1"}, + {"ipv4 with port", "10.0.0.1:8080", false, "10.0.0.1"}, + {"ipv4 with leading whitespace", " 10.0.0.1:80", false, "10.0.0.1"}, + {"ipv4 X-Forwarded-For first hop", "10.0.0.1, 10.0.0.2, 10.0.0.3", true, "10.0.0.1"}, + {"ipv4 X-Forwarded-For with spaces", " 10.0.0.1 , 10.0.0.2 ", true, "10.0.0.1"}, + {"ipv6 bracketed with port", "[2001:db8::1]:443", false, "2001:db8::1"}, + {"ipv6 bracketed no port", "[2001:db8::1]", false, "2001:db8::1"}, + {"ipv6 bare passes through", "2001:db8::1", false, "2001:db8::1"}, + {"empty string", "", false, ""}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := parseIP(c.source, c.fromHeader) + assert.Equal(t, c.want, got) + }) + } +} + +// --------------------------------------------------------------------------- +// parseTemplates — fill remaining error branches +// --------------------------------------------------------------------------- + +func TestParseTemplates_DirectResponseMissingBody(t *testing.T) { + // No RequestTemplate.URL → direct-response mode. ResponseTemplate.Body must be set. + tool := RestTool{} + err := tool.parseTemplates() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "direct response mode") + } +} + +func TestParseTemplates_DirectResponseWithBodyOk(t *testing.T) { + tool := RestTool{ + ResponseTemplate: RestToolResponseTemplate{Body: "{{.}}"}, + } + assert.NoError(t, tool.parseTemplates()) + assert.True(t, tool.isDirectResponseTool) +} + +func TestParseTemplates_URLTemplateParseError(t *testing.T) { + tool := RestTool{ + RequestTemplate: RestToolRequestTemplate{ + URL: "http://x/{{ .unclosed ", // missing closing braces + Method: "GET", + }, + } + err := tool.parseTemplates() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "URL template") + } +} + +func TestParseTemplates_HeaderTemplateParseError(t *testing.T) { + tool := RestTool{ + RequestTemplate: RestToolRequestTemplate{ + URL: "http://x", + Method: "GET", + Headers: []RestToolHeader{ + {Key: "X-Bad", Value: "{{ .unclosed "}, + }, + }, + } + err := tool.parseTemplates() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "header template") + } +} + +func TestParseTemplates_BodyTemplateParseError(t *testing.T) { + tool := RestTool{ + RequestTemplate: RestToolRequestTemplate{ + URL: "http://x", + Method: "POST", + Body: "{{ .unclosed ", + }, + } + err := tool.parseTemplates() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "body template") + } +} + +func TestParseTemplates_ResponseTemplateParseError(t *testing.T) { + tool := RestTool{ + RequestTemplate: RestToolRequestTemplate{URL: "http://x", Method: "GET"}, + ResponseTemplate: RestToolResponseTemplate{ + Body: "{{ .unclosed ", + }, + } + err := tool.parseTemplates() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "response template") + } +} + +func TestParseTemplates_ErrorResponseTemplateParseError(t *testing.T) { + tool := RestTool{ + RequestTemplate: RestToolRequestTemplate{URL: "http://x", Method: "GET"}, + ErrorResponseTemplate: "{{ .unclosed ", + } + err := tool.parseTemplates() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "error response template") + } +} + +func TestParseTemplates_HeaderWithEmptyKeySkipped(t *testing.T) { + tool := RestTool{ + RequestTemplate: RestToolRequestTemplate{ + URL: "http://x", + Method: "GET", + Headers: []RestToolHeader{ + {Key: "", Value: "ignored"}, + {Key: "X-Real", Value: "real"}, + }, + }, + } + assert.NoError(t, tool.parseTemplates()) + _, hasReal := tool.parsedHeaderTemplates["X-Real"] + assert.True(t, hasReal) + _, hasEmpty := tool.parsedHeaderTemplates[""] + assert.False(t, hasEmpty) +} + +func TestParseTemplates_PopulatesArgPositions(t *testing.T) { + tool := RestTool{ + RequestTemplate: RestToolRequestTemplate{URL: "http://x", Method: "GET"}, + ResponseTemplate: RestToolResponseTemplate{Body: "{{.}}"}, + Args: []RestToolArg{ + {Name: "q", Position: "QUERY"}, // lower-cased in argPositions + {Name: "h", Position: "Header"}, + {Name: "noPos"}, // no position → not stored + }, + } + require := assert.New(t) + require.NoError(tool.parseTemplates()) + require.Equal("query", tool.argPositions["q"]) + require.Equal("header", tool.argPositions["h"]) + _, ok := tool.argPositions["noPos"] + require.False(ok) +} + +// --------------------------------------------------------------------------- +// executeTemplate — nil + execution error +// --------------------------------------------------------------------------- + +func TestExecuteTemplate_NilReturnsError(t *testing.T) { + _, err := executeTemplate(nil, []byte(`{}`)) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "nil") + } +} + +// --------------------------------------------------------------------------- +// RestMCPServer accessors — GetSecurityScheme, GetPassthroughAuthHeader, +// AddMCPTool, GetConfig, GetToolConfig +// --------------------------------------------------------------------------- + +func TestRestServer_AddMCPTool_DelegatesToBase(t *testing.T) { + s := NewRestMCPServer("rest") + tool := &stubTool{desc: "x"} + ret := s.AddMCPTool("plain", tool) + assert.Same(t, s, ret) + got, ok := s.GetMCPTools()["plain"] + assert.True(t, ok) + assert.Same(t, tool, got) +} + +func TestRestServer_GetSecurityScheme_HitAndMiss(t *testing.T) { + s := NewRestMCPServer("rest") + scheme := SecurityScheme{ID: "K", Type: "apiKey", In: "header", Name: "X-K"} + s.AddSecurityScheme(scheme) + + got, ok := s.GetSecurityScheme("K") + assert.True(t, ok) + assert.Equal(t, "K", got.ID) + + _, ok = s.GetSecurityScheme("missing") + assert.False(t, ok) +} + +func TestRestServer_PassthroughAuthHeader(t *testing.T) { + s := NewRestMCPServer("rest") + assert.False(t, s.GetPassthroughAuthHeader()) + s.SetPassthroughAuthHeader(true) + assert.True(t, s.GetPassthroughAuthHeader()) +} + +func TestRestServer_GetToolConfig(t *testing.T) { + s := NewRestMCPServer("rest") + require.NoError(t, s.AddRestTool(RestTool{ + Name: "t", + ResponseTemplate: RestToolResponseTemplate{Body: "{{.}}"}, + })) + cfg, ok := s.GetToolConfig("t") + assert.True(t, ok) + assert.Equal(t, "t", cfg.Name) + + _, ok = s.GetToolConfig("missing") + assert.False(t, ok) +} + +// --------------------------------------------------------------------------- +// RestMCPServer.Clone — independence +// --------------------------------------------------------------------------- + +func TestRestServer_Clone_Independence(t *testing.T) { + orig := NewRestMCPServer("rest") + orig.SetPassthroughAuthHeader(true) + orig.SetConfig([]byte(`{"v":1}`)) + orig.AddSecurityScheme(SecurityScheme{ID: "K", Type: "apiKey", In: "header", Name: "X"}) + require.NoError(t, orig.AddRestTool(RestTool{ + Name: "t", + ResponseTemplate: RestToolResponseTemplate{Body: "{{.}}"}, + })) + + clonedI := orig.Clone() + require.NotNil(t, clonedI) + cloned, ok := clonedI.(*RestMCPServer) + require.True(t, ok) + + // Mutate the original: cloned must not see the change. + orig.AddSecurityScheme(SecurityScheme{ID: "K2", Type: "apiKey", In: "header", Name: "Y"}) + _, hasK2 := cloned.GetSecurityScheme("K2") + assert.False(t, hasK2, "cloned server must not see security scheme added to original after Clone") + + // Tools map was deep-copied at Clone time. + _, hasT := cloned.GetToolConfig("t") + assert.True(t, hasT) +} + +// --------------------------------------------------------------------------- +// RestMCPTool.Create — type coercion matrix +// --------------------------------------------------------------------------- + +func newRestToolForCreate(t *testing.T) *RestMCPTool { + t.Helper() + tool := RestTool{ + Name: "t", + Args: []RestToolArg{ + {Name: "b", Type: "boolean"}, + {Name: "i", Type: "integer"}, + {Name: "n", Type: "number"}, + {Name: "s", Type: "string"}, + {Name: "d", Type: "integer", Default: 7}, + }, + ResponseTemplate: RestToolResponseTemplate{Body: "{{.}}"}, + } + require.NoError(t, tool.parseTemplates()) + return &RestMCPTool{ + serverName: "rest", + name: "t", + toolConfig: tool, + } +} + +func TestRestMCPTool_Create_BooleanCoercion(t *testing.T) { + tool := newRestToolForCreate(t) + // Boolean from native true, native false, string "true", string "false", + // string with garbage (passthrough), and other types (passthrough). + cases := []struct { + raw any + want any + }{ + {true, true}, + {false, false}, + {"true", true}, + {"false", false}, + {"yes", "yes"}, + // JSON unmarshal turns any number into float64; non-bool/non-string + // hits the default arm and is stored verbatim. + {42, float64(42)}, + } + for _, c := range cases { + body, err := json.Marshal(map[string]any{"b": c.raw}) + require.NoError(t, err) + created := tool.Create(body).(*RestMCPTool) + assert.Equal(t, c.want, created.arguments["b"], "raw=%v", c.raw) + } +} + +func TestRestMCPTool_Create_IntegerCoercion(t *testing.T) { + tool := newRestToolForCreate(t) + cases := []struct { + raw any + want any + }{ + {float64(10), 10}, + {"42", 42}, + {"not-int", "not-int"}, + {true, true}, + } + for _, c := range cases { + body, err := json.Marshal(map[string]any{"i": c.raw}) + require.NoError(t, err) + created := tool.Create(body).(*RestMCPTool) + assert.Equal(t, c.want, created.arguments["i"], "raw=%v", c.raw) + } +} + +func TestRestMCPTool_Create_NumberCoercion(t *testing.T) { + tool := newRestToolForCreate(t) + cases := []struct { + raw any + want any + }{ + {"3.14", 3.14}, + {"abc", "abc"}, + {float64(2.5), 2.5}, // default: passthrough + } + for _, c := range cases { + body, err := json.Marshal(map[string]any{"n": c.raw}) + require.NoError(t, err) + created := tool.Create(body).(*RestMCPTool) + assert.Equal(t, c.want, created.arguments["n"], "raw=%v", c.raw) + } +} + +func TestRestMCPTool_Create_DefaultApplied(t *testing.T) { + tool := newRestToolForCreate(t) + body := []byte(`{}`) + created := tool.Create(body).(*RestMCPTool) + assert.Equal(t, 7, created.arguments["d"]) + // Args without defaults are not present when omitted. + _, hasI := created.arguments["i"] + assert.False(t, hasI) +} + +func TestRestMCPTool_Create_StringPassthrough(t *testing.T) { + tool := newRestToolForCreate(t) + body, _ := json.Marshal(map[string]any{"s": "hello"}) + created := tool.Create(body).(*RestMCPTool) + assert.Equal(t, "hello", created.arguments["s"]) +} + +func TestRestMCPTool_Create_MalformedJSONStillProducesTool(t *testing.T) { + tool := newRestToolForCreate(t) + // Bad JSON is logged + ignored; defaults still applied. + created := tool.Create([]byte("{not json")).(*RestMCPTool) + assert.Equal(t, 7, created.arguments["d"], "default still applied when params unparseable") +} + +// --------------------------------------------------------------------------- +// hasContentType — case + charset suffix +// --------------------------------------------------------------------------- + +func TestHasContentType_CaseAndCharsetSuffix(t *testing.T) { + headers := [][2]string{ + {"content-type", "Application/JSON; charset=utf-8"}, + } + assert.True(t, hasContentType(headers, "application/json")) + assert.True(t, hasContentType(headers, "json")) + assert.False(t, hasContentType(headers, "xml")) + + emptyHeaders := [][2]string{} + assert.False(t, hasContentType(emptyHeaders, "application/json")) +} + +// pull in `require` for newer tests above without disturbing existing imports. +var _ = url.Parse +var _ = sjson.Set +var _ = strings.TrimSpace diff --git a/plugins/wasm-go/pkg/mcp/server/sse_proxy_test.go b/plugins/wasm-go/pkg/mcp/server/sse_proxy_test.go index 7994ec525..4d3a353da 100644 --- a/plugins/wasm-go/pkg/mcp/server/sse_proxy_test.go +++ b/plugins/wasm-go/pkg/mcp/server/sse_proxy_test.go @@ -15,7 +15,11 @@ package server import ( + "strings" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TestParseSSEMessage tests SSE message parsing @@ -295,3 +299,176 @@ data: {"id":2} t.Errorf("Expected no more messages, got: %+v", msg4) } } + +// ----------------------------------------------------------------------------- +// ParseSSEMessage — additional edge cases (multi-line data, retry, empty) +// ----------------------------------------------------------------------------- + +func TestParseSSEMessage_EmptyInput(t *testing.T) { + msg, remaining, err := ParseSSEMessage([]byte("")) + require.NoError(t, err) + assert.Nil(t, msg) + assert.Len(t, remaining, 0) +} + +func TestParseSSEMessage_RetryFieldIgnored(t *testing.T) { + // `retry:` is part of the SSE spec but not implemented — must not break parsing. + input := []byte("retry: 5000\nevent: message\ndata: hi\n\n") + msg, _, err := ParseSSEMessage(input) + require.NoError(t, err) + require.NotNil(t, msg) + assert.Equal(t, "message", msg.Event) + assert.Equal(t, "hi", msg.Data) +} + +func TestParseSSEMessage_MultiLineDataConcatenated(t *testing.T) { + // Per SSE spec, multiple `data:` lines in one message join with `\n`. + input := []byte("data: line-one\ndata: line-two\ndata: line-three\n\n") + msg, _, err := ParseSSEMessage(input) + require.NoError(t, err) + require.NotNil(t, msg) + assert.Equal(t, "line-one\nline-two\nline-three", msg.Data) +} + +func TestParseSSEMessage_NoFinalBlankLine_NoMessageReturned(t *testing.T) { + // Message without the terminating blank line is treated as incomplete. + input := []byte("event: message\ndata: payload\n") + msg, remaining, err := ParseSSEMessage(input) + require.NoError(t, err) + assert.Nil(t, msg, "incomplete message must not be returned") + assert.Equal(t, input, remaining, "remaining is the entire input") +} + +func TestParseSSEMessage_LineWithoutColonSkipped(t *testing.T) { + // SplitN with len<2 → field/value pair can't be formed → skipped, not an error. + input := []byte("a-line-without-colon\nevent: msg\ndata: x\n\n") + msg, _, err := ParseSSEMessage(input) + require.NoError(t, err) + require.NotNil(t, msg) + assert.Equal(t, "msg", msg.Event) + assert.Equal(t, "x", msg.Data) +} + +func TestParseSSEMessage_UnknownFieldIgnored(t *testing.T) { + // `random-field:` is parsed but the switch case ignores it. + input := []byte("random-field: stuff\nevent: msg\ndata: x\n\n") + msg, _, err := ParseSSEMessage(input) + require.NoError(t, err) + require.NotNil(t, msg) + assert.Equal(t, "msg", msg.Event) +} + +// ----------------------------------------------------------------------------- +// ExtractEndpointURL — edge cases not in the table +// ----------------------------------------------------------------------------- + +func TestExtractEndpointURL_HttpsPassthrough(t *testing.T) { + got, err := ExtractEndpointURL("https://other.example/x", "http://b.example") + require.NoError(t, err) + assert.Equal(t, "https://other.example/x", got, "full https URL must pass through unchanged") +} + +func TestExtractEndpointURL_EmptyEndpointData_PathOnlyBase(t *testing.T) { + got, err := ExtractEndpointURL("", "/some/path") + require.NoError(t, err) + assert.Equal(t, "", got, "empty endpointData with path-only base → empty result") +} + +func TestExtractEndpointURL_RelativeEndpointWithSchemeBase(t *testing.T) { + got, err := ExtractEndpointURL("messages", "http://b.example/mcp") + require.NoError(t, err) + assert.Equal(t, "http://b.example/messages", got, "leading slash auto-inserted") +} + +// ----------------------------------------------------------------------------- +// applyProxyAuthenticationForSSE — pure URL+header munging (no proxywasm) +// ----------------------------------------------------------------------------- + +func TestApplyProxyAuthenticationForSSE_ApiKeyHeader(t *testing.T) { + server := NewMcpProxyServer("p") + server.AddSecurityScheme(SecurityScheme{ + ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key", + DefaultCredential: "abc", + }) + + headers := [][2]string{{"X-Other", "v"}} + got, err := applyProxyAuthenticationForSSE(server, "K", "", &headers, "http://backend/x") + require.NoError(t, err) + assert.Equal(t, "http://backend/x", got, "no query → URL preserved") + + found := false + for _, kv := range headers { + if strings.EqualFold(kv[0], "X-Api-Key") { + assert.Equal(t, "abc", kv[1]) + found = true + } + } + assert.True(t, found, "API key header must be injected") +} + +func TestApplyProxyAuthenticationForSSE_ApiKeyQuery_PreservesExisting(t *testing.T) { + server := NewMcpProxyServer("p") + server.AddSecurityScheme(SecurityScheme{ + ID: "K", Type: "apiKey", In: "query", Name: "api_key", + DefaultCredential: "secret", + }) + + headers := [][2]string{} + got, err := applyProxyAuthenticationForSSE(server, "K", "", &headers, "http://backend/x?existing=1") + require.NoError(t, err) + // Query is rebuilt via url.Values.Encode — both pairs must be present. + assert.Contains(t, got, "api_key=secret") + assert.Contains(t, got, "existing=1") +} + +func TestApplyProxyAuthenticationForSSE_PathOnlyURL_PreservesShape(t *testing.T) { + server := NewMcpProxyServer("p") + server.AddSecurityScheme(SecurityScheme{ + ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key", + DefaultCredential: "abc", + }) + + headers := [][2]string{} + got, err := applyProxyAuthenticationForSSE(server, "K", "", &headers, "/relative/path") + require.NoError(t, err) + assert.Equal(t, "/relative/path", got, "path-only URL must come back as path-only") +} + +func TestApplyProxyAuthenticationForSSE_HttpBearerPassthrough(t *testing.T) { + server := NewMcpProxyServer("p") + server.AddSecurityScheme(SecurityScheme{ID: "B", Type: "http", Scheme: "bearer"}) + + headers := [][2]string{} + got, err := applyProxyAuthenticationForSSE(server, "B", "passthrough-token", &headers, "http://backend/x") + require.NoError(t, err) + assert.Equal(t, "http://backend/x", got) + + var authValue string + for _, kv := range headers { + if strings.EqualFold(kv[0], "Authorization") { + authValue = kv[1] + } + } + assert.Equal(t, "Bearer passthrough-token", authValue) +} + +func TestApplyProxyAuthenticationForSSE_MissingScheme_ReturnsError(t *testing.T) { + server := NewMcpProxyServer("p") + headers := [][2]string{} + _, err := applyProxyAuthenticationForSSE(server, "missing", "", &headers, "http://backend/x") + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +func TestApplyProxyAuthenticationForSSE_PreservesFragment(t *testing.T) { + server := NewMcpProxyServer("p") + server.AddSecurityScheme(SecurityScheme{ + ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key", + DefaultCredential: "abc", + }) + + headers := [][2]string{} + got, err := applyProxyAuthenticationForSSE(server, "K", "", &headers, "http://backend/path#section-2") + require.NoError(t, err) + assert.Contains(t, got, "#section-2", "fragment must round-trip") +}