mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
feat(mcp/sse): support passthourgh the query parameter in sse server to the rest api server (#2460)
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -94,13 +95,15 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct
|
|||||||
defer s.sessions.Delete(sessionID)
|
defer s.sessions.Delete(sessionID)
|
||||||
|
|
||||||
channel := GetSSEChannelName(sessionID)
|
channel := GetSSEChannelName(sessionID)
|
||||||
|
u, err := url.Parse(s.baseURL + s.messageEndpoint)
|
||||||
|
if err != nil {
|
||||||
|
api.LogErrorf("Failed to parse base URL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
messageEndpoint := fmt.Sprintf(
|
q := u.Query()
|
||||||
"%s%s?sessionId=%s",
|
q.Set("sessionId", sessionID)
|
||||||
s.baseURL,
|
u.RawQuery = q.Encode()
|
||||||
s.messageEndpoint,
|
messageEndpoint := u.String()
|
||||||
sessionID,
|
|
||||||
)
|
|
||||||
|
|
||||||
// go func() {
|
// go func() {
|
||||||
// for {
|
// for {
|
||||||
@@ -126,7 +129,7 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct
|
|||||||
// }
|
// }
|
||||||
// }()
|
// }()
|
||||||
|
|
||||||
err := s.redisClient.Subscribe(channel, stopChan, func(message string) {
|
err = s.redisClient.Subscribe(channel, stopChan, func(message string) {
|
||||||
defer cb.EncoderFilterCallbacks().RecoverPanic()
|
defer cb.EncoderFilterCallbacks().RecoverPanic()
|
||||||
api.LogDebugf("SSE Send message: %s", message)
|
api.LogDebugf("SSE Send message: %s", message)
|
||||||
cb.EncoderFilterCallbacks().InjectData([]byte(message))
|
cb.EncoderFilterCallbacks().InjectData([]byte(message))
|
||||||
@@ -210,7 +213,7 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
|
|||||||
var status int
|
var status int
|
||||||
// Only send response if there is one (not for notifications)
|
// Only send response if there is one (not for notifications)
|
||||||
if response != nil {
|
if response != nil {
|
||||||
if sessionID != "" {
|
if sessionID != "" {
|
||||||
w.WriteHeader(http.StatusAccepted)
|
w.WriteHeader(http.StatusAccepted)
|
||||||
status = http.StatusAccepted
|
status = http.StatusAccepted
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -129,9 +129,15 @@ func (f *filter) processMcpRequestHeadersForRestUpstream(header api.RequestHeade
|
|||||||
if method != http.MethodGet {
|
if method != http.MethodGet {
|
||||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||||
} else {
|
} else {
|
||||||
|
// to support the query param in Message Endpoint
|
||||||
|
trimmed := strings.TrimSuffix(requestUrl.Path, GlobalSSEPathSuffix)
|
||||||
|
if rq := requestUrl.RawQuery; rq != "" {
|
||||||
|
trimmed += "?" + rq
|
||||||
|
}
|
||||||
|
|
||||||
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
|
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
|
||||||
common.WithSSEEndpoint(GlobalSSEPathSuffix),
|
common.WithSSEEndpoint(GlobalSSEPathSuffix),
|
||||||
common.WithMessageEndpoint(strings.TrimSuffix(requestUrl.Path, GlobalSSEPathSuffix)),
|
common.WithMessageEndpoint(trimmed),
|
||||||
common.WithRedisClient(f.config.redisClient))
|
common.WithRedisClient(f.config.redisClient))
|
||||||
f.serverName = f.config.defaultServer.GetServerName()
|
f.serverName = f.config.defaultServer.GetServerName()
|
||||||
body := "SSE connection create"
|
body := "SSE connection create"
|
||||||
|
|||||||
@@ -88,6 +88,11 @@ checkDesiredVersion() {
|
|||||||
elif [ "${HAS_WGET}" == "true" ]; then
|
elif [ "${HAS_WGET}" == "true" ]; then
|
||||||
VERSION=$(wget $latest_release_url -O - 2>&1 | grep 'href="/alibaba/higress/releases/tag/v[0-9]*.[0-9]*.[0-9]*\"' | sed -E 's/.*\/alibaba\/higress\/releases\/tag\/(v[0-9\.]+)".*/\1/g' | head -1)
|
VERSION=$(wget $latest_release_url -O - 2>&1 | grep 'href="/alibaba/higress/releases/tag/v[0-9]*.[0-9]*.[0-9]*\"' | sed -E 's/.*\/alibaba\/higress\/releases\/tag\/(v[0-9\.]+)".*/\1/g' | head -1)
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [ "$VERSION" == "" ]; then
|
||||||
|
echo "Failed to determine latest version. Please check network or set VERSION manually."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user