opt: mcp endpoint parser improvements (#2673)

This commit is contained in:
澄潭
2025-07-29 12:41:29 +08:00
committed by GitHub
parent 6a1557f6ac
commit ff9a29c5d9
3 changed files with 551 additions and 42 deletions

View File

@@ -64,7 +64,7 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
redisClient, err := common.NewRedisClient(redisConfig)
if err != nil {
api.LogErrorf("Failed to initialize Redis client: %w", err)
api.LogErrorf("Failed to initialize Redis client: %v", err)
} else {
api.LogDebug("Redis client initialized")
}

View File

@@ -397,48 +397,93 @@ func (f *filter) findNextLineBreak(bufferData string) (error, string) {
}
func (f *filter) findEndpointUrl(bufferData string) (error, string) {
eventIndex := strings.Index(bufferData, "event:")
if eventIndex == -1 {
return nil, ""
// Keep searching for events until we find an endpoint event or run out of data
for {
eventIndex := strings.Index(bufferData, "event:")
if eventIndex == -1 {
// No more events found
return nil, ""
}
// Move to the start of the event
bufferData = bufferData[eventIndex:]
// Find the end of the event line
err, lineBreak := f.findNextLineBreak(bufferData)
if err != nil {
return fmt.Errorf("failed to find endpoint URL in SSE data: %v", err), ""
}
if lineBreak == "" {
// No line break found, which means the data is not enough.
return nil, ""
}
api.LogDebugf("event line break sequence: %v", []byte(lineBreak))
eventEndIndex := strings.Index(bufferData, lineBreak)
if eventEndIndex == -1 {
return nil, ""
}
eventName := strings.TrimSpace(bufferData[len("event:"):eventEndIndex])
// Move past the event line
bufferData = bufferData[eventEndIndex+len(lineBreak):]
if eventName == "endpoint" {
// Found endpoint event, now look for the data field
err, lineBreak = f.findNextLineBreak(bufferData)
if err != nil {
return fmt.Errorf("failed to find endpoint URL in SSE data: %v", err), ""
}
if lineBreak == "" {
// No line break found, which means the data is not enough.
return nil, ""
}
api.LogDebugf("data line break sequence: %v", []byte(lineBreak))
dataEndIndex := strings.Index(bufferData, lineBreak)
if dataEndIndex == -1 {
// Data received not enough.
return nil, ""
}
eventData := bufferData[:dataEndIndex]
if !strings.HasPrefix(eventData, "data:") {
return fmt.Errorf("an unexpected non-data field found in the event. Skip processing. Field: %s", eventData), ""
}
return nil, strings.TrimSpace(eventData[len("data:"):])
} else {
// Not an endpoint event, skip to the next event
api.LogDebugf("Skipping non-endpoint event: %s", eventName)
// First, we need to skip the data field of this event
err, lineBreak = f.findNextLineBreak(bufferData)
if err != nil {
return fmt.Errorf("failed to find endpoint URL in SSE data: %v", err), ""
}
if lineBreak == "" {
// No line break found, which means the data is not enough.
return nil, ""
}
dataEndIndex := strings.Index(bufferData, lineBreak)
if dataEndIndex == -1 {
// Data received not enough.
return nil, ""
}
// Move past the data line
bufferData = bufferData[dataEndIndex+len(lineBreak):]
// Skip any additional empty lines that separate events
for strings.HasPrefix(bufferData, lineBreak) {
bufferData = bufferData[len(lineBreak):]
}
// Continue to look for the next event
}
}
bufferData = bufferData[eventIndex:]
err, lineBreak := f.findNextLineBreak(bufferData)
if err != nil {
return fmt.Errorf("failed to find endpoint URL in SSE data: %v", err), ""
}
if lineBreak == "" {
// No line break found, which means the data is not enough.
return nil, ""
}
api.LogDebugf("event line break sequence: %v", []byte(lineBreak))
eventEndIndex := strings.Index(bufferData, lineBreak)
if eventEndIndex == -1 {
return nil, ""
}
eventName := strings.TrimSpace(bufferData[len("event:"):eventEndIndex])
if eventName != "endpoint" {
return fmt.Errorf("the initial event [%s] is not an endpoint event. Skip processing", eventName), ""
}
bufferData = bufferData[eventEndIndex+len(lineBreak):]
err, lineBreak = f.findNextLineBreak(bufferData)
if err != nil {
return fmt.Errorf("failed to find endpoint URL in SSE data: %v", err), ""
}
if lineBreak == "" {
// No line break found, which means the data is not enough.
return nil, ""
}
api.LogDebugf("data line break sequence: %v", []byte(lineBreak))
dataEndIndex := strings.Index(bufferData, lineBreak)
if dataEndIndex == -1 {
// Data received not enough.
return nil, ""
}
eventData := bufferData[:dataEndIndex]
if !strings.HasPrefix(eventData, "data:") {
return fmt.Errorf("an unexpected non-data field found in the event. Skip processing. Field: %s", eventData), ""
}
return nil, strings.TrimSpace(eventData[len("data:"):])
}
// OnDestroy stops the goroutine

View File

@@ -0,0 +1,464 @@
package mcp_session
import (
"fmt"
"testing"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
)
// Mock implementation of CommonCAPI for testing
type mockCommonCAPI struct {
logs []string
}
func (m *mockCommonCAPI) Log(level api.LogType, message string) {
fmt.Printf("[%s] %s", level, message)
m.logs = append(m.logs, message)
}
func (m *mockCommonCAPI) LogLevel() api.LogType {
return api.Debug
}
// Test helper to create a filter instance for testing
func createTestFilter() *filter {
return &filter{}
}
// Test helper to create a match rule for testing
func createTestMatchRule() common.MatchRule {
return common.MatchRule{
UpstreamType: common.SSEUpstream,
EnablePathRewrite: true,
PathRewritePrefix: "/api/v1",
MatchRulePath: "/mcp",
MatchRuleType: common.PrefixMatch,
MatchRuleDomain: "example.com",
}
}
// TestFindEndpointUrl_ValidEndpointMessage tests the current behavior with a valid endpoint message
func TestFindEndpointUrl_ValidEndpointMessage(t *testing.T) {
// Setup mock API
mockAPI := &mockCommonCAPI{}
api.SetCommonCAPI(mockAPI)
f := createTestFilter()
// Test with valid endpoint message
sseData := "event: endpoint\ndata: https://api.example.com/chat\n\n"
err, endpointUrl := f.findEndpointUrl(sseData)
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
expectedUrl := "https://api.example.com/chat"
if endpointUrl != expectedUrl {
t.Errorf("Expected endpoint URL '%s', got '%s'", expectedUrl, endpointUrl)
}
}
// TestFindEndpointUrl_NonEndpointFirstMessage tests improved behavior with non-endpoint first message
func TestFindEndpointUrl_NonEndpointFirstMessage(t *testing.T) {
// Setup mock API
mockAPI := &mockCommonCAPI{}
api.SetCommonCAPI(mockAPI)
f := createTestFilter()
// Test with ping message first (this should now succeed with improved implementation)
sseData := "event: ping\ndata: alive\n\nevent: endpoint\ndata: https://api.example.com/chat\n\n"
err, endpointUrl := f.findEndpointUrl(sseData)
// Improved implementation should handle non-endpoint first message
if err != nil {
t.Errorf("Expected no error for non-endpoint first message, got: %v", err)
}
expectedUrl := "https://api.example.com/chat"
if endpointUrl != expectedUrl {
t.Errorf("Expected endpoint URL '%s', got '%s'", expectedUrl, endpointUrl)
}
// Check that the non-endpoint event was logged
found := false
for _, log := range mockAPI.logs {
if log == "Skipping non-endpoint event: ping" {
found = true
break
}
}
if !found {
t.Errorf("Expected log message about skipping ping event not found")
}
}
// TestFindEndpointUrl_MultipleNonEndpointMessages tests multiple non-endpoint messages before endpoint
func TestFindEndpointUrl_MultipleNonEndpointMessages(t *testing.T) {
// Setup mock API
mockAPI := &mockCommonCAPI{}
api.SetCommonCAPI(mockAPI)
f := createTestFilter()
// Test with multiple non-endpoint messages before endpoint
sseData := "event: ping\ndata: alive\n\nevent: status\ndata: connecting\n\nevent: info\ndata: ready\n\nevent: endpoint\ndata: https://api.example.com/chat\n\n"
err, endpointUrl := f.findEndpointUrl(sseData)
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
expectedUrl := "https://api.example.com/chat"
if endpointUrl != expectedUrl {
t.Errorf("Expected endpoint URL '%s', got '%s'", expectedUrl, endpointUrl)
}
// Check that all non-endpoint events were logged
expectedLogs := []string{
"Skipping non-endpoint event: ping",
"Skipping non-endpoint event: status",
"Skipping non-endpoint event: info",
}
for _, expectedLog := range expectedLogs {
found := false
for _, log := range mockAPI.logs {
if log == expectedLog {
found = true
break
}
}
if !found {
t.Errorf("Expected log message '%s' not found", expectedLog)
}
}
}
// TestFindEndpointUrl_EndpointInMiddle tests endpoint message in the middle of other messages
func TestFindEndpointUrl_EndpointInMiddle(t *testing.T) {
// Setup mock API
mockAPI := &mockCommonCAPI{}
api.SetCommonCAPI(mockAPI)
f := createTestFilter()
// Test with endpoint message in the middle
sseData := "event: ping\ndata: alive\n\nevent: endpoint\ndata: https://api.example.com/chat\n\nevent: status\ndata: ready\n\n"
err, endpointUrl := f.findEndpointUrl(sseData)
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
expectedUrl := "https://api.example.com/chat"
if endpointUrl != expectedUrl {
t.Errorf("Expected endpoint URL '%s', got '%s'", expectedUrl, endpointUrl)
}
// Check that the ping event was logged as skipped
found := false
for _, log := range mockAPI.logs {
if log == "Skipping non-endpoint event: ping" {
found = true
break
}
}
if !found {
t.Errorf("Expected log message about skipping ping event not found")
}
}
// TestFindEndpointUrl_NoEndpointMessage tests when no endpoint message is present
func TestFindEndpointUrl_NoEndpointMessage(t *testing.T) {
// Setup mock API
mockAPI := &mockCommonCAPI{}
api.SetCommonCAPI(mockAPI)
f := createTestFilter()
// Test with no endpoint message
sseData := "event: ping\ndata: alive\n\nevent: status\ndata: connecting\n\nevent: info\ndata: ready\n\n"
err, endpointUrl := f.findEndpointUrl(sseData)
if err != nil {
t.Errorf("Expected no error when no endpoint found, got: %v", err)
}
if endpointUrl != "" {
t.Errorf("Expected empty endpoint URL when no endpoint found, got '%s'", endpointUrl)
}
// Check that all non-endpoint events were logged
expectedLogs := []string{
"Skipping non-endpoint event: ping",
"Skipping non-endpoint event: status",
"Skipping non-endpoint event: info",
}
for _, expectedLog := range expectedLogs {
found := false
for _, log := range mockAPI.logs {
if log == expectedLog {
found = true
break
}
}
if !found {
t.Errorf("Expected log message '%s' not found", expectedLog)
}
}
}
// TestFindEndpointUrl_IncompleteEndpointMessage tests incomplete endpoint message
func TestFindEndpointUrl_IncompleteEndpointMessage(t *testing.T) {
// Setup mock API
mockAPI := &mockCommonCAPI{}
api.SetCommonCAPI(mockAPI)
f := createTestFilter()
// Test with incomplete endpoint message (missing final line break)
sseData := "event: ping\ndata: alive\n\nevent: endpoint\ndata: https://api.example.com/chat"
err, endpointUrl := f.findEndpointUrl(sseData)
if err != nil {
t.Errorf("Expected no error for incomplete endpoint message, got: %v", err)
}
if endpointUrl != "" {
t.Errorf("Expected empty endpoint URL for incomplete message, got '%s'", endpointUrl)
}
}
// TestFindEndpointUrl_IncompleteNonEndpointMessage tests incomplete non-endpoint message
func TestFindEndpointUrl_IncompleteNonEndpointMessage(t *testing.T) {
// Setup mock API
mockAPI := &mockCommonCAPI{}
api.SetCommonCAPI(mockAPI)
f := createTestFilter()
// Test with incomplete non-endpoint message
sseData := "event: ping\ndata: alive"
err, endpointUrl := f.findEndpointUrl(sseData)
if err != nil {
t.Errorf("Expected no error for incomplete non-endpoint message, got: %v", err)
}
if endpointUrl != "" {
t.Errorf("Expected empty endpoint URL for incomplete message, got '%s'", endpointUrl)
}
}
// TestFindEndpointUrl_MalformedEndpointData tests malformed endpoint data
func TestFindEndpointUrl_MalformedEndpointData(t *testing.T) {
// Setup mock API
mockAPI := &mockCommonCAPI{}
api.SetCommonCAPI(mockAPI)
f := createTestFilter()
// Test with malformed endpoint data (missing data field)
sseData := "event: ping\ndata: alive\n\nevent: endpoint\nnotdata: https://api.example.com/chat\n\n"
err, endpointUrl := f.findEndpointUrl(sseData)
// Should return error for malformed endpoint data
if err == nil {
t.Errorf("Expected error for malformed endpoint data, but got none")
}
if endpointUrl != "" {
t.Errorf("Expected empty endpoint URL when error occurs, got '%s'", endpointUrl)
}
}
// TestFindEndpointUrl_DifferentLineBreaks tests different line break formats with improved version
func TestFindEndpointUrl_DifferentLineBreaks(t *testing.T) {
testCases := []struct {
name string
sseData string
expected string
}{
{
name: "CRLF line breaks with ping first",
sseData: "event: ping\r\ndata: alive\r\n\r\nevent: endpoint\r\ndata: https://api.example.com/chat\r\n\r\n",
expected: "https://api.example.com/chat",
},
{
name: "CR line breaks with status first",
sseData: "event: status\rdata: ready\r\revent: endpoint\rdata: https://api.example.com/chat\r\r",
expected: "https://api.example.com/chat",
},
{
name: "LF line breaks with info first",
sseData: "event: info\ndata: starting\n\nevent: endpoint\ndata: https://api.example.com/chat\n\n",
expected: "https://api.example.com/chat",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Setup mock API
mockAPI := &mockCommonCAPI{}
api.SetCommonCAPI(mockAPI)
f := createTestFilter()
err, endpointUrl := f.findEndpointUrl(tc.sseData)
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
if endpointUrl != tc.expected {
t.Errorf("Expected endpoint URL '%s', got '%s'", tc.expected, endpointUrl)
}
})
}
}
// TestFindEndpointUrl_WithWhitespace tests improved version with whitespace
func TestFindEndpointUrl_WithWhitespace(t *testing.T) {
// Setup mock API
mockAPI := &mockCommonCAPI{}
api.SetCommonCAPI(mockAPI)
f := createTestFilter()
// Test with whitespace around event names and data
sseData := "event: ping \ndata: alive \n\nevent: endpoint \ndata: https://api.example.com/chat \n\n"
err, endpointUrl := f.findEndpointUrl(sseData)
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
expectedUrl := "https://api.example.com/chat"
if endpointUrl != expectedUrl {
t.Errorf("Expected endpoint URL '%s', got '%s'", expectedUrl, endpointUrl)
}
}
// TestFindEndpointUrl_NoEventFound tests behavior when no event is found
func TestFindEndpointUrl_NoEventFound(t *testing.T) {
// Setup mock API
mockAPI := &mockCommonCAPI{}
api.SetCommonCAPI(mockAPI)
f := createTestFilter()
// Test with data that doesn't contain event
sseData := "some random data without event"
err, endpointUrl := f.findEndpointUrl(sseData)
if err != nil {
t.Errorf("Expected no error when no event found, got: %v", err)
}
if endpointUrl != "" {
t.Errorf("Expected empty endpoint URL when no event found, got '%s'", endpointUrl)
}
}
// TestFindEndpointUrl_MalformedData tests behavior with malformed SSE data
func TestFindEndpointUrl_MalformedData(t *testing.T) {
// Setup mock API
mockAPI := &mockCommonCAPI{}
api.SetCommonCAPI(mockAPI)
f := createTestFilter()
// Test with malformed data (missing data field)
sseData := "event: endpoint\nnotdata: https://api.example.com/chat\n\n"
err, endpointUrl := f.findEndpointUrl(sseData)
// Should return error for malformed data
if err == nil {
t.Errorf("Expected error for malformed data, but got none")
}
if endpointUrl != "" {
t.Errorf("Expected empty endpoint URL when error occurs, got '%s'", endpointUrl)
}
}
// TestFindNextLineBreak tests the line break detection functionality
func TestFindNextLineBreak(t *testing.T) {
// Setup mock API
mockAPI := &mockCommonCAPI{}
api.SetCommonCAPI(mockAPI)
f := createTestFilter()
testCases := []struct {
name string
input string
expectedBreak string
expectedError bool
}{
{
name: "LF only",
input: "some text\nmore text",
expectedBreak: "\n",
expectedError: false,
},
{
name: "CR only",
input: "some text\rmore text",
expectedBreak: "\r",
expectedError: false,
},
{
name: "CRLF",
input: "some text\r\nmore text",
expectedBreak: "\r\n",
expectedError: false,
},
{
name: "No line break",
input: "some text without break",
expectedBreak: "",
expectedError: false,
},
{
name: "LF before CR (separate)",
input: "some text\n\rmore text",
expectedBreak: "",
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err, lineBreak := f.findNextLineBreak(tc.input)
if tc.expectedError && err == nil {
t.Errorf("Expected error, but got none")
}
if !tc.expectedError && err != nil {
t.Errorf("Expected no error, got: %v", err)
}
if lineBreak != tc.expectedBreak {
t.Errorf("Expected line break '%v', got '%v'", []byte(tc.expectedBreak), []byte(lineBreak))
}
})
}
}