From ff9a29c5d9befabf19a04bcd9cdcd6c593dafecb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BE=84=E6=BD=AD?= Date: Tue, 29 Jul 2025 12:41:29 +0800 Subject: [PATCH] opt: mcp endpoint parser improvements (#2673) --- plugins/golang-filter/mcp-session/config.go | 2 +- plugins/golang-filter/mcp-session/filter.go | 127 +++-- .../golang-filter/mcp-session/filter_test.go | 464 ++++++++++++++++++ 3 files changed, 551 insertions(+), 42 deletions(-) create mode 100644 plugins/golang-filter/mcp-session/filter_test.go diff --git a/plugins/golang-filter/mcp-session/config.go b/plugins/golang-filter/mcp-session/config.go index 6bb9fcf30..036dbb06f 100644 --- a/plugins/golang-filter/mcp-session/config.go +++ b/plugins/golang-filter/mcp-session/config.go @@ -64,7 +64,7 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int redisClient, err := common.NewRedisClient(redisConfig) if err != nil { - api.LogErrorf("Failed to initialize Redis client: %w", err) + api.LogErrorf("Failed to initialize Redis client: %v", err) } else { api.LogDebug("Redis client initialized") } diff --git a/plugins/golang-filter/mcp-session/filter.go b/plugins/golang-filter/mcp-session/filter.go index 8fc7233e2..9a42770d6 100644 --- a/plugins/golang-filter/mcp-session/filter.go +++ b/plugins/golang-filter/mcp-session/filter.go @@ -397,48 +397,93 @@ func (f *filter) findNextLineBreak(bufferData string) (error, string) { } func (f *filter) findEndpointUrl(bufferData string) (error, string) { - eventIndex := strings.Index(bufferData, "event:") - if eventIndex == -1 { - return nil, "" + // Keep searching for events until we find an endpoint event or run out of data + for { + eventIndex := strings.Index(bufferData, "event:") + if eventIndex == -1 { + // No more events found + return nil, "" + } + + // Move to the start of the event + bufferData = bufferData[eventIndex:] + + // Find the end of the event line + err, lineBreak := f.findNextLineBreak(bufferData) + if err != nil { + return fmt.Errorf("failed to find endpoint URL in SSE data: %v", err), "" + } + if lineBreak == "" { + // No line break found, which means the data is not enough. + return nil, "" + } + + api.LogDebugf("event line break sequence: %v", []byte(lineBreak)) + eventEndIndex := strings.Index(bufferData, lineBreak) + if eventEndIndex == -1 { + return nil, "" + } + + eventName := strings.TrimSpace(bufferData[len("event:"):eventEndIndex]) + + // Move past the event line + bufferData = bufferData[eventEndIndex+len(lineBreak):] + + if eventName == "endpoint" { + // Found endpoint event, now look for the data field + err, lineBreak = f.findNextLineBreak(bufferData) + if err != nil { + return fmt.Errorf("failed to find endpoint URL in SSE data: %v", err), "" + } + if lineBreak == "" { + // No line break found, which means the data is not enough. + return nil, "" + } + + api.LogDebugf("data line break sequence: %v", []byte(lineBreak)) + dataEndIndex := strings.Index(bufferData, lineBreak) + if dataEndIndex == -1 { + // Data received not enough. + return nil, "" + } + + eventData := bufferData[:dataEndIndex] + if !strings.HasPrefix(eventData, "data:") { + return fmt.Errorf("an unexpected non-data field found in the event. Skip processing. Field: %s", eventData), "" + } + + return nil, strings.TrimSpace(eventData[len("data:"):]) + } else { + // Not an endpoint event, skip to the next event + api.LogDebugf("Skipping non-endpoint event: %s", eventName) + + // First, we need to skip the data field of this event + err, lineBreak = f.findNextLineBreak(bufferData) + if err != nil { + return fmt.Errorf("failed to find endpoint URL in SSE data: %v", err), "" + } + if lineBreak == "" { + // No line break found, which means the data is not enough. + return nil, "" + } + + dataEndIndex := strings.Index(bufferData, lineBreak) + if dataEndIndex == -1 { + // Data received not enough. + return nil, "" + } + + // Move past the data line + bufferData = bufferData[dataEndIndex+len(lineBreak):] + + // Skip any additional empty lines that separate events + for strings.HasPrefix(bufferData, lineBreak) { + bufferData = bufferData[len(lineBreak):] + } + + // Continue to look for the next event + } } - bufferData = bufferData[eventIndex:] - err, lineBreak := f.findNextLineBreak(bufferData) - if err != nil { - return fmt.Errorf("failed to find endpoint URL in SSE data: %v", err), "" - } - if lineBreak == "" { - // No line break found, which means the data is not enough. - return nil, "" - } - api.LogDebugf("event line break sequence: %v", []byte(lineBreak)) - eventEndIndex := strings.Index(bufferData, lineBreak) - if eventEndIndex == -1 { - return nil, "" - } - eventName := strings.TrimSpace(bufferData[len("event:"):eventEndIndex]) - if eventName != "endpoint" { - return fmt.Errorf("the initial event [%s] is not an endpoint event. Skip processing", eventName), "" - } - bufferData = bufferData[eventEndIndex+len(lineBreak):] - err, lineBreak = f.findNextLineBreak(bufferData) - if err != nil { - return fmt.Errorf("failed to find endpoint URL in SSE data: %v", err), "" - } - if lineBreak == "" { - // No line break found, which means the data is not enough. - return nil, "" - } - api.LogDebugf("data line break sequence: %v", []byte(lineBreak)) - dataEndIndex := strings.Index(bufferData, lineBreak) - if dataEndIndex == -1 { - // Data received not enough. - return nil, "" - } - eventData := bufferData[:dataEndIndex] - if !strings.HasPrefix(eventData, "data:") { - return fmt.Errorf("an unexpected non-data field found in the event. Skip processing. Field: %s", eventData), "" - } - return nil, strings.TrimSpace(eventData[len("data:"):]) } // OnDestroy stops the goroutine diff --git a/plugins/golang-filter/mcp-session/filter_test.go b/plugins/golang-filter/mcp-session/filter_test.go new file mode 100644 index 000000000..fe548cadb --- /dev/null +++ b/plugins/golang-filter/mcp-session/filter_test.go @@ -0,0 +1,464 @@ +package mcp_session + +import ( + "fmt" + "testing" + + "github.com/alibaba/higress/plugins/golang-filter/mcp-session/common" + "github.com/envoyproxy/envoy/contrib/golang/common/go/api" +) + +// Mock implementation of CommonCAPI for testing +type mockCommonCAPI struct { + logs []string +} + +func (m *mockCommonCAPI) Log(level api.LogType, message string) { + fmt.Printf("[%s] %s", level, message) + m.logs = append(m.logs, message) +} + +func (m *mockCommonCAPI) LogLevel() api.LogType { + return api.Debug +} + +// Test helper to create a filter instance for testing +func createTestFilter() *filter { + return &filter{} +} + +// Test helper to create a match rule for testing +func createTestMatchRule() common.MatchRule { + return common.MatchRule{ + UpstreamType: common.SSEUpstream, + EnablePathRewrite: true, + PathRewritePrefix: "/api/v1", + MatchRulePath: "/mcp", + MatchRuleType: common.PrefixMatch, + MatchRuleDomain: "example.com", + } +} + +// TestFindEndpointUrl_ValidEndpointMessage tests the current behavior with a valid endpoint message +func TestFindEndpointUrl_ValidEndpointMessage(t *testing.T) { + // Setup mock API + mockAPI := &mockCommonCAPI{} + api.SetCommonCAPI(mockAPI) + + f := createTestFilter() + + // Test with valid endpoint message + sseData := "event: endpoint\ndata: https://api.example.com/chat\n\n" + + err, endpointUrl := f.findEndpointUrl(sseData) + + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + + expectedUrl := "https://api.example.com/chat" + if endpointUrl != expectedUrl { + t.Errorf("Expected endpoint URL '%s', got '%s'", expectedUrl, endpointUrl) + } +} + +// TestFindEndpointUrl_NonEndpointFirstMessage tests improved behavior with non-endpoint first message +func TestFindEndpointUrl_NonEndpointFirstMessage(t *testing.T) { + // Setup mock API + mockAPI := &mockCommonCAPI{} + api.SetCommonCAPI(mockAPI) + + f := createTestFilter() + + // Test with ping message first (this should now succeed with improved implementation) + sseData := "event: ping\ndata: alive\n\nevent: endpoint\ndata: https://api.example.com/chat\n\n" + + err, endpointUrl := f.findEndpointUrl(sseData) + + // Improved implementation should handle non-endpoint first message + if err != nil { + t.Errorf("Expected no error for non-endpoint first message, got: %v", err) + } + + expectedUrl := "https://api.example.com/chat" + if endpointUrl != expectedUrl { + t.Errorf("Expected endpoint URL '%s', got '%s'", expectedUrl, endpointUrl) + } + + // Check that the non-endpoint event was logged + found := false + for _, log := range mockAPI.logs { + if log == "Skipping non-endpoint event: ping" { + found = true + break + } + } + if !found { + t.Errorf("Expected log message about skipping ping event not found") + } +} + +// TestFindEndpointUrl_MultipleNonEndpointMessages tests multiple non-endpoint messages before endpoint +func TestFindEndpointUrl_MultipleNonEndpointMessages(t *testing.T) { + // Setup mock API + mockAPI := &mockCommonCAPI{} + api.SetCommonCAPI(mockAPI) + + f := createTestFilter() + + // Test with multiple non-endpoint messages before endpoint + sseData := "event: ping\ndata: alive\n\nevent: status\ndata: connecting\n\nevent: info\ndata: ready\n\nevent: endpoint\ndata: https://api.example.com/chat\n\n" + + err, endpointUrl := f.findEndpointUrl(sseData) + + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + + expectedUrl := "https://api.example.com/chat" + if endpointUrl != expectedUrl { + t.Errorf("Expected endpoint URL '%s', got '%s'", expectedUrl, endpointUrl) + } + + // Check that all non-endpoint events were logged + expectedLogs := []string{ + "Skipping non-endpoint event: ping", + "Skipping non-endpoint event: status", + "Skipping non-endpoint event: info", + } + + for _, expectedLog := range expectedLogs { + found := false + for _, log := range mockAPI.logs { + if log == expectedLog { + found = true + break + } + } + if !found { + t.Errorf("Expected log message '%s' not found", expectedLog) + } + } +} + +// TestFindEndpointUrl_EndpointInMiddle tests endpoint message in the middle of other messages +func TestFindEndpointUrl_EndpointInMiddle(t *testing.T) { + // Setup mock API + mockAPI := &mockCommonCAPI{} + api.SetCommonCAPI(mockAPI) + + f := createTestFilter() + + // Test with endpoint message in the middle + sseData := "event: ping\ndata: alive\n\nevent: endpoint\ndata: https://api.example.com/chat\n\nevent: status\ndata: ready\n\n" + + err, endpointUrl := f.findEndpointUrl(sseData) + + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + + expectedUrl := "https://api.example.com/chat" + if endpointUrl != expectedUrl { + t.Errorf("Expected endpoint URL '%s', got '%s'", expectedUrl, endpointUrl) + } + + // Check that the ping event was logged as skipped + found := false + for _, log := range mockAPI.logs { + if log == "Skipping non-endpoint event: ping" { + found = true + break + } + } + if !found { + t.Errorf("Expected log message about skipping ping event not found") + } +} + +// TestFindEndpointUrl_NoEndpointMessage tests when no endpoint message is present +func TestFindEndpointUrl_NoEndpointMessage(t *testing.T) { + // Setup mock API + mockAPI := &mockCommonCAPI{} + api.SetCommonCAPI(mockAPI) + + f := createTestFilter() + + // Test with no endpoint message + sseData := "event: ping\ndata: alive\n\nevent: status\ndata: connecting\n\nevent: info\ndata: ready\n\n" + + err, endpointUrl := f.findEndpointUrl(sseData) + + if err != nil { + t.Errorf("Expected no error when no endpoint found, got: %v", err) + } + + if endpointUrl != "" { + t.Errorf("Expected empty endpoint URL when no endpoint found, got '%s'", endpointUrl) + } + + // Check that all non-endpoint events were logged + expectedLogs := []string{ + "Skipping non-endpoint event: ping", + "Skipping non-endpoint event: status", + "Skipping non-endpoint event: info", + } + + for _, expectedLog := range expectedLogs { + found := false + for _, log := range mockAPI.logs { + if log == expectedLog { + found = true + break + } + } + if !found { + t.Errorf("Expected log message '%s' not found", expectedLog) + } + } +} + +// TestFindEndpointUrl_IncompleteEndpointMessage tests incomplete endpoint message +func TestFindEndpointUrl_IncompleteEndpointMessage(t *testing.T) { + // Setup mock API + mockAPI := &mockCommonCAPI{} + api.SetCommonCAPI(mockAPI) + + f := createTestFilter() + + // Test with incomplete endpoint message (missing final line break) + sseData := "event: ping\ndata: alive\n\nevent: endpoint\ndata: https://api.example.com/chat" + + err, endpointUrl := f.findEndpointUrl(sseData) + + if err != nil { + t.Errorf("Expected no error for incomplete endpoint message, got: %v", err) + } + + if endpointUrl != "" { + t.Errorf("Expected empty endpoint URL for incomplete message, got '%s'", endpointUrl) + } +} + +// TestFindEndpointUrl_IncompleteNonEndpointMessage tests incomplete non-endpoint message +func TestFindEndpointUrl_IncompleteNonEndpointMessage(t *testing.T) { + // Setup mock API + mockAPI := &mockCommonCAPI{} + api.SetCommonCAPI(mockAPI) + + f := createTestFilter() + + // Test with incomplete non-endpoint message + sseData := "event: ping\ndata: alive" + + err, endpointUrl := f.findEndpointUrl(sseData) + + if err != nil { + t.Errorf("Expected no error for incomplete non-endpoint message, got: %v", err) + } + + if endpointUrl != "" { + t.Errorf("Expected empty endpoint URL for incomplete message, got '%s'", endpointUrl) + } +} + +// TestFindEndpointUrl_MalformedEndpointData tests malformed endpoint data +func TestFindEndpointUrl_MalformedEndpointData(t *testing.T) { + // Setup mock API + mockAPI := &mockCommonCAPI{} + api.SetCommonCAPI(mockAPI) + + f := createTestFilter() + + // Test with malformed endpoint data (missing data field) + sseData := "event: ping\ndata: alive\n\nevent: endpoint\nnotdata: https://api.example.com/chat\n\n" + + err, endpointUrl := f.findEndpointUrl(sseData) + + // Should return error for malformed endpoint data + if err == nil { + t.Errorf("Expected error for malformed endpoint data, but got none") + } + + if endpointUrl != "" { + t.Errorf("Expected empty endpoint URL when error occurs, got '%s'", endpointUrl) + } +} + +// TestFindEndpointUrl_DifferentLineBreaks tests different line break formats with improved version +func TestFindEndpointUrl_DifferentLineBreaks(t *testing.T) { + testCases := []struct { + name string + sseData string + expected string + }{ + { + name: "CRLF line breaks with ping first", + sseData: "event: ping\r\ndata: alive\r\n\r\nevent: endpoint\r\ndata: https://api.example.com/chat\r\n\r\n", + expected: "https://api.example.com/chat", + }, + { + name: "CR line breaks with status first", + sseData: "event: status\rdata: ready\r\revent: endpoint\rdata: https://api.example.com/chat\r\r", + expected: "https://api.example.com/chat", + }, + { + name: "LF line breaks with info first", + sseData: "event: info\ndata: starting\n\nevent: endpoint\ndata: https://api.example.com/chat\n\n", + expected: "https://api.example.com/chat", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup mock API + mockAPI := &mockCommonCAPI{} + api.SetCommonCAPI(mockAPI) + + f := createTestFilter() + + err, endpointUrl := f.findEndpointUrl(tc.sseData) + + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + + if endpointUrl != tc.expected { + t.Errorf("Expected endpoint URL '%s', got '%s'", tc.expected, endpointUrl) + } + }) + } +} + +// TestFindEndpointUrl_WithWhitespace tests improved version with whitespace +func TestFindEndpointUrl_WithWhitespace(t *testing.T) { + // Setup mock API + mockAPI := &mockCommonCAPI{} + api.SetCommonCAPI(mockAPI) + + f := createTestFilter() + + // Test with whitespace around event names and data + sseData := "event: ping \ndata: alive \n\nevent: endpoint \ndata: https://api.example.com/chat \n\n" + + err, endpointUrl := f.findEndpointUrl(sseData) + + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + + expectedUrl := "https://api.example.com/chat" + if endpointUrl != expectedUrl { + t.Errorf("Expected endpoint URL '%s', got '%s'", expectedUrl, endpointUrl) + } +} + +// TestFindEndpointUrl_NoEventFound tests behavior when no event is found +func TestFindEndpointUrl_NoEventFound(t *testing.T) { + // Setup mock API + mockAPI := &mockCommonCAPI{} + api.SetCommonCAPI(mockAPI) + + f := createTestFilter() + + // Test with data that doesn't contain event + sseData := "some random data without event" + + err, endpointUrl := f.findEndpointUrl(sseData) + + if err != nil { + t.Errorf("Expected no error when no event found, got: %v", err) + } + + if endpointUrl != "" { + t.Errorf("Expected empty endpoint URL when no event found, got '%s'", endpointUrl) + } +} + +// TestFindEndpointUrl_MalformedData tests behavior with malformed SSE data +func TestFindEndpointUrl_MalformedData(t *testing.T) { + // Setup mock API + mockAPI := &mockCommonCAPI{} + api.SetCommonCAPI(mockAPI) + + f := createTestFilter() + + // Test with malformed data (missing data field) + sseData := "event: endpoint\nnotdata: https://api.example.com/chat\n\n" + + err, endpointUrl := f.findEndpointUrl(sseData) + + // Should return error for malformed data + if err == nil { + t.Errorf("Expected error for malformed data, but got none") + } + + if endpointUrl != "" { + t.Errorf("Expected empty endpoint URL when error occurs, got '%s'", endpointUrl) + } +} + +// TestFindNextLineBreak tests the line break detection functionality +func TestFindNextLineBreak(t *testing.T) { + // Setup mock API + mockAPI := &mockCommonCAPI{} + api.SetCommonCAPI(mockAPI) + + f := createTestFilter() + + testCases := []struct { + name string + input string + expectedBreak string + expectedError bool + }{ + { + name: "LF only", + input: "some text\nmore text", + expectedBreak: "\n", + expectedError: false, + }, + { + name: "CR only", + input: "some text\rmore text", + expectedBreak: "\r", + expectedError: false, + }, + { + name: "CRLF", + input: "some text\r\nmore text", + expectedBreak: "\r\n", + expectedError: false, + }, + { + name: "No line break", + input: "some text without break", + expectedBreak: "", + expectedError: false, + }, + { + name: "LF before CR (separate)", + input: "some text\n\rmore text", + expectedBreak: "", + expectedError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err, lineBreak := f.findNextLineBreak(tc.input) + + if tc.expectedError && err == nil { + t.Errorf("Expected error, but got none") + } + + if !tc.expectedError && err != nil { + t.Errorf("Expected no error, got: %v", err) + } + + if lineBreak != tc.expectedBreak { + t.Errorf("Expected line break '%v', got '%v'", []byte(tc.expectedBreak), []byte(lineBreak)) + } + }) + } +}