mirror of
https://github.com/alibaba/higress.git
synced 2026-06-26 10:45:25 +08:00
Signed-off-by: jingze <daijingze.djz@alibaba-inc.com> Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
475 lines
14 KiB
Go
475 lines
14 KiB
Go
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package server
|
|
|
|
import (
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestParseSSEMessage tests SSE message parsing
|
|
func TestParseSSEMessage(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input []byte
|
|
wantEvent string
|
|
wantData string
|
|
wantID string
|
|
shouldParse bool
|
|
}{
|
|
{
|
|
name: "endpoint message",
|
|
input: []byte(`event: endpoint
|
|
data: /messages/?session_id=test123
|
|
|
|
`),
|
|
wantEvent: "endpoint",
|
|
wantData: "/messages/?session_id=test123",
|
|
shouldParse: true,
|
|
},
|
|
{
|
|
name: "message with JSON data",
|
|
input: []byte(`event: message
|
|
data: {"jsonrpc":"2.0","id":1,"result":{"test":"value"}}
|
|
|
|
`),
|
|
wantEvent: "message",
|
|
wantData: `{"jsonrpc":"2.0","id":1,"result":{"test":"value"}}`,
|
|
shouldParse: true,
|
|
},
|
|
{
|
|
name: "incomplete message",
|
|
input: []byte(`event: message
|
|
data: {"jsonrpc":"2.0"`),
|
|
shouldParse: false,
|
|
},
|
|
{
|
|
name: "message with id",
|
|
input: []byte(`id: 123
|
|
event: message
|
|
data: test data
|
|
|
|
`),
|
|
wantEvent: "message",
|
|
wantData: "test data",
|
|
wantID: "123",
|
|
shouldParse: true,
|
|
},
|
|
{
|
|
name: "comment line ignored",
|
|
input: []byte(`: this is a comment
|
|
event: message
|
|
data: test data
|
|
|
|
`),
|
|
wantEvent: "message",
|
|
wantData: "test data",
|
|
shouldParse: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
msg, remaining, err := ParseSSEMessage(tt.input)
|
|
|
|
if err != nil {
|
|
t.Fatalf("parseSSEMessage() error = %v", err)
|
|
}
|
|
|
|
if tt.shouldParse {
|
|
if msg == nil {
|
|
t.Errorf("parseSSEMessage() expected message but got nil")
|
|
return
|
|
}
|
|
if msg.Event != tt.wantEvent {
|
|
t.Errorf("parseSSEMessage() Event = %v, want %v", msg.Event, tt.wantEvent)
|
|
}
|
|
if msg.Data != tt.wantData {
|
|
t.Errorf("parseSSEMessage() Data = %v, want %v", msg.Data, tt.wantData)
|
|
}
|
|
if msg.ID != tt.wantID {
|
|
t.Errorf("parseSSEMessage() ID = %v, want %v", msg.ID, tt.wantID)
|
|
}
|
|
if len(remaining) != 0 {
|
|
t.Errorf("parseSSEMessage() expected no remaining bytes, got %d bytes", len(remaining))
|
|
}
|
|
} else {
|
|
if msg != nil {
|
|
t.Errorf("parseSSEMessage() expected no message but got %v", msg)
|
|
}
|
|
if len(remaining) != len(tt.input) {
|
|
t.Errorf("parseSSEMessage() expected all data as remaining, got %d bytes instead of %d", len(remaining), len(tt.input))
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestExtractEndpointURL tests endpoint URL extraction
|
|
func TestExtractEndpointURL(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
endpointData string
|
|
baseURL string
|
|
want string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "full URL",
|
|
endpointData: "http://example.com/messages?session=123",
|
|
baseURL: "http://backend.com/mcp",
|
|
want: "http://example.com/messages?session=123",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "path only",
|
|
endpointData: "/messages/?session_id=abc",
|
|
baseURL: "http://backend.com/mcp",
|
|
want: "http://backend.com/messages/?session_id=abc",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "https base URL",
|
|
endpointData: "/sse/endpoint",
|
|
baseURL: "https://secure.backend.com:8443/api",
|
|
want: "https://secure.backend.com:8443/sse/endpoint",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "path-only base URL",
|
|
endpointData: "/messages",
|
|
baseURL: "/api/v1",
|
|
want: "/messages",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "path without leading slash",
|
|
endpointData: "api/v1/messages",
|
|
baseURL: "http://backend.com",
|
|
want: "http://backend.com/api/v1/messages",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "path without leading slash with port",
|
|
endpointData: "sse/endpoint",
|
|
baseURL: "https://secure.backend.com:8443",
|
|
want: "https://secure.backend.com:8443/sse/endpoint",
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := ExtractEndpointURL(tt.endpointData, tt.baseURL)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("extractEndpointURL() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if got != tt.want {
|
|
t.Errorf("extractEndpointURL() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestTransportProtocolValidation tests transport protocol validation
|
|
func TestTransportProtocolValidation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
transport string
|
|
wantValid bool
|
|
}{
|
|
{
|
|
name: "valid http transport",
|
|
transport: "http",
|
|
wantValid: true,
|
|
},
|
|
{
|
|
name: "valid sse transport",
|
|
transport: "sse",
|
|
wantValid: true,
|
|
},
|
|
{
|
|
name: "invalid transport",
|
|
transport: "websocket",
|
|
wantValid: false,
|
|
},
|
|
{
|
|
name: "empty transport",
|
|
transport: "",
|
|
wantValid: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
transport := TransportProtocol(tt.transport)
|
|
isValid := transport == TransportHTTP || transport == TransportSSE
|
|
if isValid != tt.wantValid {
|
|
t.Errorf("TransportProtocol validation = %v, want %v for %s", isValid, tt.wantValid, tt.transport)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestMcpProxyServerTransport tests transport getter/setter
|
|
func TestMcpProxyServerTransport(t *testing.T) {
|
|
server := NewMcpProxyServer("test-server")
|
|
|
|
// Test default transport
|
|
if server.GetTransport() != "" {
|
|
t.Errorf("Expected empty default transport, got %v", server.GetTransport())
|
|
}
|
|
|
|
// Test setting HTTP transport
|
|
server.SetTransport(TransportHTTP)
|
|
if server.GetTransport() != TransportHTTP {
|
|
t.Errorf("Expected HTTP transport, got %v", server.GetTransport())
|
|
}
|
|
|
|
// Test setting SSE transport
|
|
server.SetTransport(TransportSSE)
|
|
if server.GetTransport() != TransportSSE {
|
|
t.Errorf("Expected SSE transport, got %v", server.GetTransport())
|
|
}
|
|
}
|
|
|
|
// TestSSEMessageParsing_MultipleMessages tests parsing multiple SSE messages
|
|
func TestSSEMessageParsing_MultipleMessages(t *testing.T) {
|
|
data := []byte(`event: endpoint
|
|
data: /messages/123
|
|
|
|
event: message
|
|
data: {"id":1}
|
|
|
|
: comment line
|
|
event: message
|
|
data: {"id":2}
|
|
|
|
`)
|
|
|
|
// First message
|
|
msg1, remaining, err := ParseSSEMessage(data)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse first message: %v", err)
|
|
}
|
|
if msg1 == nil || msg1.Event != "endpoint" || msg1.Data != "/messages/123" {
|
|
t.Errorf("First message incorrect: %+v", msg1)
|
|
}
|
|
|
|
// Second message
|
|
msg2, remaining, err := ParseSSEMessage(remaining)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse second message: %v", err)
|
|
}
|
|
if msg2 == nil || msg2.Event != "message" || msg2.Data != `{"id":1}` {
|
|
t.Errorf("Second message incorrect: %+v", msg2)
|
|
}
|
|
|
|
// Third message
|
|
msg3, remaining, err := ParseSSEMessage(remaining)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse third message: %v", err)
|
|
}
|
|
if msg3 == nil || msg3.Event != "message" || msg3.Data != `{"id":2}` {
|
|
t.Errorf("Third message incorrect: %+v", msg3)
|
|
}
|
|
|
|
// Should be no more complete messages
|
|
msg4, _, err := ParseSSEMessage(remaining)
|
|
if err != nil {
|
|
t.Fatalf("Error parsing remaining data: %v", err)
|
|
}
|
|
if msg4 != nil {
|
|
t.Errorf("Expected no more messages, got: %+v", msg4)
|
|
}
|
|
}
|
|
|
|
// -----------------------------------------------------------------------------
|
|
// ParseSSEMessage — additional edge cases (multi-line data, retry, empty)
|
|
// -----------------------------------------------------------------------------
|
|
|
|
func TestParseSSEMessage_EmptyInput(t *testing.T) {
|
|
msg, remaining, err := ParseSSEMessage([]byte(""))
|
|
require.NoError(t, err)
|
|
assert.Nil(t, msg)
|
|
assert.Len(t, remaining, 0)
|
|
}
|
|
|
|
func TestParseSSEMessage_RetryFieldIgnored(t *testing.T) {
|
|
// `retry:` is part of the SSE spec but not implemented — must not break parsing.
|
|
input := []byte("retry: 5000\nevent: message\ndata: hi\n\n")
|
|
msg, _, err := ParseSSEMessage(input)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, msg)
|
|
assert.Equal(t, "message", msg.Event)
|
|
assert.Equal(t, "hi", msg.Data)
|
|
}
|
|
|
|
func TestParseSSEMessage_MultiLineDataConcatenated(t *testing.T) {
|
|
// Per SSE spec, multiple `data:` lines in one message join with `\n`.
|
|
input := []byte("data: line-one\ndata: line-two\ndata: line-three\n\n")
|
|
msg, _, err := ParseSSEMessage(input)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, msg)
|
|
assert.Equal(t, "line-one\nline-two\nline-three", msg.Data)
|
|
}
|
|
|
|
func TestParseSSEMessage_NoFinalBlankLine_NoMessageReturned(t *testing.T) {
|
|
// Message without the terminating blank line is treated as incomplete.
|
|
input := []byte("event: message\ndata: payload\n")
|
|
msg, remaining, err := ParseSSEMessage(input)
|
|
require.NoError(t, err)
|
|
assert.Nil(t, msg, "incomplete message must not be returned")
|
|
assert.Equal(t, input, remaining, "remaining is the entire input")
|
|
}
|
|
|
|
func TestParseSSEMessage_LineWithoutColonSkipped(t *testing.T) {
|
|
// SplitN with len<2 → field/value pair can't be formed → skipped, not an error.
|
|
input := []byte("a-line-without-colon\nevent: msg\ndata: x\n\n")
|
|
msg, _, err := ParseSSEMessage(input)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, msg)
|
|
assert.Equal(t, "msg", msg.Event)
|
|
assert.Equal(t, "x", msg.Data)
|
|
}
|
|
|
|
func TestParseSSEMessage_UnknownFieldIgnored(t *testing.T) {
|
|
// `random-field:` is parsed but the switch case ignores it.
|
|
input := []byte("random-field: stuff\nevent: msg\ndata: x\n\n")
|
|
msg, _, err := ParseSSEMessage(input)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, msg)
|
|
assert.Equal(t, "msg", msg.Event)
|
|
}
|
|
|
|
// -----------------------------------------------------------------------------
|
|
// ExtractEndpointURL — edge cases not in the table
|
|
// -----------------------------------------------------------------------------
|
|
|
|
func TestExtractEndpointURL_HttpsPassthrough(t *testing.T) {
|
|
got, err := ExtractEndpointURL("https://other.example/x", "http://b.example")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "https://other.example/x", got, "full https URL must pass through unchanged")
|
|
}
|
|
|
|
func TestExtractEndpointURL_EmptyEndpointData_PathOnlyBase(t *testing.T) {
|
|
got, err := ExtractEndpointURL("", "/some/path")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "", got, "empty endpointData with path-only base → empty result")
|
|
}
|
|
|
|
func TestExtractEndpointURL_RelativeEndpointWithSchemeBase(t *testing.T) {
|
|
got, err := ExtractEndpointURL("messages", "http://b.example/mcp")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "http://b.example/messages", got, "leading slash auto-inserted")
|
|
}
|
|
|
|
// -----------------------------------------------------------------------------
|
|
// applyProxyAuthenticationForSSE — pure URL+header munging (no proxywasm)
|
|
// -----------------------------------------------------------------------------
|
|
|
|
func TestApplyProxyAuthenticationForSSE_ApiKeyHeader(t *testing.T) {
|
|
server := NewMcpProxyServer("p")
|
|
server.AddSecurityScheme(SecurityScheme{
|
|
ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key",
|
|
DefaultCredential: "abc",
|
|
})
|
|
|
|
headers := [][2]string{{"X-Other", "v"}}
|
|
got, err := applyProxyAuthenticationForSSE(server, "K", "", &headers, "http://backend/x")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "http://backend/x", got, "no query → URL preserved")
|
|
|
|
found := false
|
|
for _, kv := range headers {
|
|
if strings.EqualFold(kv[0], "X-Api-Key") {
|
|
assert.Equal(t, "abc", kv[1])
|
|
found = true
|
|
}
|
|
}
|
|
assert.True(t, found, "API key header must be injected")
|
|
}
|
|
|
|
func TestApplyProxyAuthenticationForSSE_ApiKeyQuery_PreservesExisting(t *testing.T) {
|
|
server := NewMcpProxyServer("p")
|
|
server.AddSecurityScheme(SecurityScheme{
|
|
ID: "K", Type: "apiKey", In: "query", Name: "api_key",
|
|
DefaultCredential: "secret",
|
|
})
|
|
|
|
headers := [][2]string{}
|
|
got, err := applyProxyAuthenticationForSSE(server, "K", "", &headers, "http://backend/x?existing=1")
|
|
require.NoError(t, err)
|
|
// Query is rebuilt via url.Values.Encode — both pairs must be present.
|
|
assert.Contains(t, got, "api_key=secret")
|
|
assert.Contains(t, got, "existing=1")
|
|
}
|
|
|
|
func TestApplyProxyAuthenticationForSSE_PathOnlyURL_PreservesShape(t *testing.T) {
|
|
server := NewMcpProxyServer("p")
|
|
server.AddSecurityScheme(SecurityScheme{
|
|
ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key",
|
|
DefaultCredential: "abc",
|
|
})
|
|
|
|
headers := [][2]string{}
|
|
got, err := applyProxyAuthenticationForSSE(server, "K", "", &headers, "/relative/path")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "/relative/path", got, "path-only URL must come back as path-only")
|
|
}
|
|
|
|
func TestApplyProxyAuthenticationForSSE_HttpBearerPassthrough(t *testing.T) {
|
|
server := NewMcpProxyServer("p")
|
|
server.AddSecurityScheme(SecurityScheme{ID: "B", Type: "http", Scheme: "bearer"})
|
|
|
|
headers := [][2]string{}
|
|
got, err := applyProxyAuthenticationForSSE(server, "B", "passthrough-token", &headers, "http://backend/x")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "http://backend/x", got)
|
|
|
|
var authValue string
|
|
for _, kv := range headers {
|
|
if strings.EqualFold(kv[0], "Authorization") {
|
|
authValue = kv[1]
|
|
}
|
|
}
|
|
assert.Equal(t, "Bearer passthrough-token", authValue)
|
|
}
|
|
|
|
func TestApplyProxyAuthenticationForSSE_MissingScheme_ReturnsError(t *testing.T) {
|
|
server := NewMcpProxyServer("p")
|
|
headers := [][2]string{}
|
|
_, err := applyProxyAuthenticationForSSE(server, "missing", "", &headers, "http://backend/x")
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "not found")
|
|
}
|
|
|
|
func TestApplyProxyAuthenticationForSSE_PreservesFragment(t *testing.T) {
|
|
server := NewMcpProxyServer("p")
|
|
server.AddSecurityScheme(SecurityScheme{
|
|
ID: "K", Type: "apiKey", In: "header", Name: "X-Api-Key",
|
|
DefaultCredential: "abc",
|
|
})
|
|
|
|
headers := [][2]string{}
|
|
got, err := applyProxyAuthenticationForSSE(server, "K", "", &headers, "http://backend/path#section-2")
|
|
require.NoError(t, err)
|
|
assert.Contains(t, got, "#section-2", "fragment must round-trip")
|
|
}
|