feat(mcp/sse): support passthourgh the query parameter in sse server to the rest api server (#2460)

This commit is contained in:
xingpiaoliang
2025-06-20 15:07:45 +08:00
committed by GitHub
parent db66df39c4
commit 04cbbfc7e8
3 changed files with 23 additions and 9 deletions

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"sync"
"time"
@@ -94,13 +95,15 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct
defer s.sessions.Delete(sessionID)
channel := GetSSEChannelName(sessionID)
u, err := url.Parse(s.baseURL + s.messageEndpoint)
if err != nil {
api.LogErrorf("Failed to parse base URL: %v", err)
}
messageEndpoint := fmt.Sprintf(
"%s%s?sessionId=%s",
s.baseURL,
s.messageEndpoint,
sessionID,
)
q := u.Query()
q.Set("sessionId", sessionID)
u.RawQuery = q.Encode()
messageEndpoint := u.String()
// go func() {
// for {
@@ -126,7 +129,7 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct
// }
// }()
err := s.redisClient.Subscribe(channel, stopChan, func(message string) {
err = s.redisClient.Subscribe(channel, stopChan, func(message string) {
defer cb.EncoderFilterCallbacks().RecoverPanic()
api.LogDebugf("SSE Send message: %s", message)
cb.EncoderFilterCallbacks().InjectData([]byte(message))
@@ -210,7 +213,7 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
var status int
// Only send response if there is one (not for notifications)
if response != nil {
if sessionID != "" {
if sessionID != "" {
w.WriteHeader(http.StatusAccepted)
status = http.StatusAccepted
} else {

View File

@@ -129,9 +129,15 @@ func (f *filter) processMcpRequestHeadersForRestUpstream(header api.RequestHeade
if method != http.MethodGet {
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
} else {
// to support the query param in Message Endpoint
trimmed := strings.TrimSuffix(requestUrl.Path, GlobalSSEPathSuffix)
if rq := requestUrl.RawQuery; rq != "" {
trimmed += "?" + rq
}
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
common.WithSSEEndpoint(GlobalSSEPathSuffix),
common.WithMessageEndpoint(strings.TrimSuffix(requestUrl.Path, GlobalSSEPathSuffix)),
common.WithMessageEndpoint(trimmed),
common.WithRedisClient(f.config.redisClient))
f.serverName = f.config.defaultServer.GetServerName()
body := "SSE connection create"

View File

@@ -88,6 +88,11 @@ checkDesiredVersion() {
elif [ "${HAS_WGET}" == "true" ]; then
VERSION=$(wget $latest_release_url -O - 2>&1 | grep 'href="/alibaba/higress/releases/tag/v[0-9]*.[0-9]*.[0-9]*\"' | sed -E 's/.*\/alibaba\/higress\/releases\/tag\/(v[0-9\.]+)".*/\1/g' | head -1)
fi
if [ "$VERSION" == "" ]; then
echo "Failed to determine latest version. Please check network or set VERSION manually."
exit 1
fi
fi
}