diff --git a/plugins/golang-filter/mcp-session/filter.go b/plugins/golang-filter/mcp-session/filter.go index acc034e9c..13310aa20 100644 --- a/plugins/golang-filter/mcp-session/filter.go +++ b/plugins/golang-filter/mcp-session/filter.go @@ -281,20 +281,7 @@ func (f *filter) encodeDataFromSSEUpstream(buffer api.BufferInstance, endStream bufferBytes := buffer.Bytes() bufferData := string(bufferBytes) - err, lineBreak := f.findSSELineBreak(bufferData) - if err != nil { - api.LogWarnf("Failed to find line break in SSE data: %v", err) - f.needProcess = false - return api.Continue - } - if lineBreak == "" { - // Have not found any line break. Need to buffer and check again. - return api.StopAndBuffer - } - - api.LogDebugf("Line break sequence: %v", []byte(lineBreak)) - - err, endpointUrl := f.findEndpointUrl(bufferData, lineBreak) + err, endpointUrl := f.findEndpointUrl(bufferData) if err != nil { api.LogWarnf("Failed to find endpoint URL in SSE data: %v", err) f.needProcess = false @@ -371,7 +358,7 @@ func (f *filter) rewriteEndpointUrl(endpointUrl string) (bool, string) { return true, endpointUrl } -func (f *filter) findSSELineBreak(bufferData string) (error, string) { +func (f *filter) findNextLineBreak(bufferData string) (error, string) { // See https://html.spec.whatwg.org/multipage/server-sent-events.html crIndex := strings.IndexAny(bufferData, "\r") lfIndex := strings.IndexAny(bufferData, "\n") @@ -381,11 +368,20 @@ func (f *filter) findSSELineBreak(bufferData string) (error, string) { } lineBreak := "" if crIndex != -1 && lfIndex != -1 { - if crIndex+1 != lfIndex { - // Found both line breaks, but they are not adjacent. Skip body processing. - return errors.New("found non-adjacent CR and LF"), "" + if crIndex < lfIndex { + if crIndex+1 == lfIndex { + lineBreak = "\r\n" + } else { + lineBreak = "\r" + } + } else { + if crIndex == lfIndex+1 { + // Found unexpected "\n\r". Skip body processing. + return errors.New("found unexpected LF+CR"), "" + } else { + lineBreak = "\n" + } } - lineBreak = "\r\n" } else if crIndex != -1 { lineBreak = "\r" } else { @@ -394,12 +390,21 @@ func (f *filter) findSSELineBreak(bufferData string) (error, string) { return nil, lineBreak } -func (f *filter) findEndpointUrl(bufferData, lineBreak string) (error, string) { +func (f *filter) findEndpointUrl(bufferData string) (error, string) { eventIndex := strings.Index(bufferData, "event:") if eventIndex == -1 { return nil, "" } bufferData = bufferData[eventIndex:] + err, lineBreak := f.findNextLineBreak(bufferData) + if err != nil { + return fmt.Errorf("failed to find endpoint URL in SSE data: %v", err), "" + } + if lineBreak == "" { + // No line break found, which means the data is not enough. + return nil, "" + } + api.LogDebugf("event line break sequence: %v", []byte(lineBreak)) eventEndIndex := strings.Index(bufferData, lineBreak) if eventEndIndex == -1 { return nil, "" @@ -409,6 +414,15 @@ func (f *filter) findEndpointUrl(bufferData, lineBreak string) (error, string) { return fmt.Errorf("the initial event [%s] is not an endpoint event. Skip processing", eventName), "" } bufferData = bufferData[eventEndIndex+len(lineBreak):] + err, lineBreak = f.findNextLineBreak(bufferData) + if err != nil { + return fmt.Errorf("failed to find endpoint URL in SSE data: %v", err), "" + } + if lineBreak == "" { + // No line break found, which means the data is not enough. + return nil, "" + } + api.LogDebugf("data line break sequence: %v", []byte(lineBreak)) dataEndIndex := strings.Index(bufferData, lineBreak) if dataEndIndex == -1 { // Data received not enough.