mirror of
https://github.com/alibaba/higress.git
synced 2026-02-27 22:20:57 +08:00
501 lines
17 KiB
Go
501 lines
17 KiB
Go
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package server
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
|
"github.com/higress-group/wasm-go/pkg/log"
|
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
|
)
|
|
|
|
// McpProxyConfig represents the configuration for MCP proxy server
|
|
// Note: mcpServerURL, timeout, defaultDownstreamSecurity, and defaultUpstreamSecurity
|
|
// are now direct server fields, not part of this config structure
|
|
type McpProxyConfig struct {
|
|
// This structure is kept for any additional server configuration that may be needed in the future
|
|
// Currently, most configuration is handled as direct server fields
|
|
}
|
|
|
|
// TransportProtocol represents the transport protocol type for MCP proxy
|
|
type TransportProtocol string
|
|
|
|
const (
|
|
TransportHTTP TransportProtocol = "http" // StreamableHTTP protocol
|
|
TransportSSE TransportProtocol = "sse" // SSE protocol
|
|
)
|
|
|
|
// ToolArg represents an argument for a proxy tool
|
|
type ToolArg struct {
|
|
Name string `json:"name"`
|
|
Description string `json:"description"`
|
|
Type string `json:"type"`
|
|
Required bool `json:"required"`
|
|
Default interface{} `json:"default,omitempty"`
|
|
Enum []interface{} `json:"enum,omitempty"`
|
|
}
|
|
|
|
// McpProxyToolConfig represents a tool configuration for MCP proxy
|
|
type McpProxyToolConfig struct {
|
|
Name string `json:"name"`
|
|
Description string `json:"description"`
|
|
Security SecurityRequirement `json:"security,omitempty"` // Tool-level security for MCP Client to MCP Server
|
|
Args []ToolArg `json:"args"`
|
|
OutputSchema map[string]any `json:"outputSchema,omitempty"` // Output schema for MCP Protocol Version 2025-06-18
|
|
RequestTemplate RequestTemplate `json:"requestTemplate,omitempty"`
|
|
}
|
|
|
|
// RequestTemplate defines request template configuration for proxy tools
|
|
type RequestTemplate struct {
|
|
Security SecurityRequirement `json:"security,omitempty"`
|
|
}
|
|
|
|
// McpProxyServer implements Server interface for MCP-to-MCP proxy
|
|
type McpProxyServer struct {
|
|
Name string
|
|
base BaseMCPServer
|
|
toolsConfig map[string]McpProxyToolConfig
|
|
securitySchemes map[string]SecurityScheme
|
|
defaultDownstreamSecurity SecurityRequirement // Default client-to-gateway authentication
|
|
defaultUpstreamSecurity SecurityRequirement // Default gateway-to-backend authentication
|
|
mcpServerURL string // Backend MCP server URL
|
|
timeout int // Request timeout in milliseconds
|
|
transport TransportProtocol // Transport protocol (http or sse)
|
|
passthroughAuthHeader bool // If true, pass through Authorization header even without downstream security
|
|
}
|
|
|
|
// NewMcpProxyServer creates a new MCP proxy server
|
|
func NewMcpProxyServer(name string) *McpProxyServer {
|
|
return &McpProxyServer{
|
|
Name: name,
|
|
base: NewBaseMCPServer(),
|
|
toolsConfig: make(map[string]McpProxyToolConfig),
|
|
securitySchemes: make(map[string]SecurityScheme),
|
|
}
|
|
}
|
|
|
|
// AddSecurityScheme adds a security scheme to the server's map
|
|
func (s *McpProxyServer) AddSecurityScheme(scheme SecurityScheme) {
|
|
if s.securitySchemes == nil {
|
|
s.securitySchemes = make(map[string]SecurityScheme)
|
|
}
|
|
s.securitySchemes[scheme.ID] = scheme
|
|
}
|
|
|
|
// GetSecurityScheme retrieves a security scheme by its ID from the map
|
|
func (s *McpProxyServer) GetSecurityScheme(id string) (SecurityScheme, bool) {
|
|
scheme, ok := s.securitySchemes[id]
|
|
return scheme, ok
|
|
}
|
|
|
|
// SetDefaultDownstreamSecurity sets the default downstream security configuration
|
|
func (s *McpProxyServer) SetDefaultDownstreamSecurity(security SecurityRequirement) {
|
|
s.defaultDownstreamSecurity = security
|
|
}
|
|
|
|
// GetDefaultDownstreamSecurity gets the default downstream security configuration
|
|
func (s *McpProxyServer) GetDefaultDownstreamSecurity() SecurityRequirement {
|
|
return s.defaultDownstreamSecurity
|
|
}
|
|
|
|
// SetDefaultUpstreamSecurity sets the default upstream security configuration
|
|
func (s *McpProxyServer) SetDefaultUpstreamSecurity(security SecurityRequirement) {
|
|
s.defaultUpstreamSecurity = security
|
|
}
|
|
|
|
// GetDefaultUpstreamSecurity gets the default upstream security configuration
|
|
func (s *McpProxyServer) GetDefaultUpstreamSecurity() SecurityRequirement {
|
|
return s.defaultUpstreamSecurity
|
|
}
|
|
|
|
// SetMcpServerURL sets the backend MCP server URL
|
|
func (s *McpProxyServer) SetMcpServerURL(url string) {
|
|
s.mcpServerURL = url
|
|
}
|
|
|
|
// GetMcpServerURL gets the backend MCP server URL
|
|
func (s *McpProxyServer) GetMcpServerURL() string {
|
|
return s.mcpServerURL
|
|
}
|
|
|
|
// SetTimeout sets the request timeout in milliseconds
|
|
func (s *McpProxyServer) SetTimeout(timeout int) {
|
|
s.timeout = timeout
|
|
}
|
|
|
|
// GetTimeout gets the request timeout in milliseconds
|
|
func (s *McpProxyServer) GetTimeout() int {
|
|
return s.timeout
|
|
}
|
|
|
|
// SetTransport sets the transport protocol
|
|
func (s *McpProxyServer) SetTransport(transport TransportProtocol) {
|
|
s.transport = transport
|
|
}
|
|
|
|
// GetTransport gets the transport protocol
|
|
func (s *McpProxyServer) GetTransport() TransportProtocol {
|
|
return s.transport
|
|
}
|
|
|
|
// AddMCPTool implements Server interface
|
|
func (s *McpProxyServer) AddMCPTool(name string, tool Tool) Server {
|
|
s.base.AddMCPTool(name, tool)
|
|
return s
|
|
}
|
|
|
|
// AddProxyTool adds a proxy tool configuration
|
|
func (s *McpProxyServer) AddProxyTool(toolConfig McpProxyToolConfig) error {
|
|
s.toolsConfig[toolConfig.Name] = toolConfig
|
|
s.base.AddMCPTool(toolConfig.Name, &McpProxyTool{
|
|
serverName: s.Name,
|
|
name: toolConfig.Name,
|
|
toolConfig: toolConfig,
|
|
})
|
|
return nil
|
|
}
|
|
|
|
// GetMCPTools implements Server interface
|
|
func (s *McpProxyServer) GetMCPTools() map[string]Tool {
|
|
return s.base.GetMCPTools()
|
|
}
|
|
|
|
// SetConfig implements Server interface
|
|
func (s *McpProxyServer) SetConfig(config []byte) {
|
|
s.base.SetConfig(config)
|
|
}
|
|
|
|
// GetConfig implements Server interface
|
|
func (s *McpProxyServer) GetConfig(v any) {
|
|
s.base.GetConfig(v)
|
|
}
|
|
|
|
// Clone implements Server interface
|
|
func (s *McpProxyServer) Clone() Server {
|
|
newServer := &McpProxyServer{
|
|
Name: s.Name,
|
|
base: s.base.CloneBase(),
|
|
toolsConfig: make(map[string]McpProxyToolConfig),
|
|
securitySchemes: make(map[string]SecurityScheme),
|
|
}
|
|
for k, v := range s.toolsConfig {
|
|
newServer.toolsConfig[k] = v
|
|
}
|
|
// Deep copy securitySchemes
|
|
if s.securitySchemes != nil {
|
|
for k, v := range s.securitySchemes {
|
|
newServer.securitySchemes[k] = v
|
|
}
|
|
}
|
|
return newServer
|
|
}
|
|
|
|
// GetToolConfig returns the proxy tool configuration for a given tool name
|
|
func (s *McpProxyServer) GetToolConfig(name string) (McpProxyToolConfig, bool) {
|
|
config, ok := s.toolsConfig[name]
|
|
return config, ok
|
|
}
|
|
|
|
// SetPassthroughAuthHeader sets the passthrough auth header flag
|
|
func (s *McpProxyServer) SetPassthroughAuthHeader(passthrough bool) {
|
|
s.passthroughAuthHeader = passthrough
|
|
}
|
|
|
|
// GetPassthroughAuthHeader gets the passthrough auth header flag
|
|
func (s *McpProxyServer) GetPassthroughAuthHeader() bool {
|
|
return s.passthroughAuthHeader
|
|
}
|
|
|
|
// ForwardToolsList forwards tools/list request to backend MCP server
|
|
func (s *McpProxyServer) ForwardToolsList(ctx HttpContext, cursor *string) error {
|
|
wrapperCtx := ctx.(wrapper.HttpContext)
|
|
|
|
// Handle default downstream security for tools/list requests
|
|
// tools/list requests use server-level default authentication configuration
|
|
passthroughCredential := ""
|
|
downstreamSecurity := s.GetDefaultDownstreamSecurity()
|
|
if downstreamSecurity.ID != "" {
|
|
clientScheme, schemeOk := s.GetSecurityScheme(downstreamSecurity.ID)
|
|
if !schemeOk {
|
|
log.Warnf("Default downstream security scheme ID '%s' not found for tools/list request.", downstreamSecurity.ID)
|
|
} else {
|
|
// Extract and remove the credential from the incoming request
|
|
extractedCred, err := ExtractAndRemoveIncomingCredential(clientScheme)
|
|
if err != nil {
|
|
log.Warnf("Failed to extract/remove incoming credential for tools/list using scheme %s: %v", clientScheme.ID, err)
|
|
} else if extractedCred == "" {
|
|
log.Debugf("No incoming credential found for tools/list using scheme %s for extraction/removal.", clientScheme.ID)
|
|
}
|
|
|
|
// Only use passthrough if explicitly configured
|
|
if downstreamSecurity.Passthrough && extractedCred != "" {
|
|
passthroughCredential = extractedCred
|
|
log.Debugf("Passthrough credential set for tools/list request.")
|
|
}
|
|
}
|
|
} else {
|
|
// Fallback: Remove Authorization header if no downstream security is defined
|
|
// This prevents downstream credentials from being mistakenly passed to upstream
|
|
// Unless passthroughAuthHeader is explicitly set to true
|
|
if !s.GetPassthroughAuthHeader() {
|
|
proxywasm.RemoveHttpRequestHeader("Authorization")
|
|
}
|
|
}
|
|
|
|
// Create protocol handler using server fields
|
|
handler := NewMcpProtocolHandler(s.GetMcpServerURL(), s.GetTimeout())
|
|
|
|
// Prepare authentication information for gateway-to-backend communication
|
|
var authInfo *ProxyAuthInfo
|
|
upstreamSecurity := s.GetDefaultUpstreamSecurity()
|
|
if upstreamSecurity.ID != "" {
|
|
authInfo = &ProxyAuthInfo{
|
|
SecuritySchemeID: upstreamSecurity.ID,
|
|
PassthroughCredential: passthroughCredential,
|
|
Server: s,
|
|
}
|
|
}
|
|
|
|
// This will handle initialization asynchronously if needed and use ActionPause/Resume
|
|
return handler.ForwardToolsList(wrapperCtx, cursor, authInfo)
|
|
}
|
|
|
|
// McpProxyTool implements Tool interface for MCP-to-MCP proxy
|
|
type McpProxyTool struct {
|
|
serverName string
|
|
name string
|
|
toolConfig McpProxyToolConfig
|
|
arguments map[string]interface{}
|
|
}
|
|
|
|
// Create implements Tool interface
|
|
func (t *McpProxyTool) Create(params []byte) Tool {
|
|
newTool := &McpProxyTool{
|
|
serverName: t.serverName,
|
|
name: t.name,
|
|
toolConfig: t.toolConfig,
|
|
arguments: make(map[string]interface{}),
|
|
}
|
|
|
|
if len(params) > 0 {
|
|
json.Unmarshal(params, &newTool.arguments)
|
|
}
|
|
|
|
return newTool
|
|
}
|
|
|
|
// Call implements Tool interface - this is where the MCP protocol handling happens
|
|
func (t *McpProxyTool) Call(httpCtx HttpContext, server Server) error {
|
|
ctx := httpCtx.(wrapper.HttpContext)
|
|
|
|
// Get proxy server instance to access configuration
|
|
proxyServer, ok := server.(*McpProxyServer)
|
|
if !ok {
|
|
return fmt.Errorf("server is not a McpProxyServer")
|
|
}
|
|
|
|
// Handle tool-level or default downstream security: extract credential for passthrough if configured
|
|
// toolConfig.Security represents client-to-gateway authentication, falls back to server's defaultDownstreamSecurity
|
|
passthroughCredential := ""
|
|
var downstreamSecurity SecurityRequirement
|
|
if t.toolConfig.Security.ID != "" {
|
|
// Use tool-level security if configured
|
|
downstreamSecurity = t.toolConfig.Security
|
|
log.Debugf("Using tool-level downstream security for tool %s: %s", t.name, downstreamSecurity.ID)
|
|
} else {
|
|
// Fall back to server's default downstream security
|
|
downstreamSecurity = proxyServer.GetDefaultDownstreamSecurity()
|
|
if downstreamSecurity.ID != "" {
|
|
log.Debugf("Using default downstream security for tool %s: %s", t.name, downstreamSecurity.ID)
|
|
}
|
|
}
|
|
|
|
if downstreamSecurity.ID != "" {
|
|
clientScheme, schemeOk := proxyServer.GetSecurityScheme(downstreamSecurity.ID)
|
|
if !schemeOk {
|
|
log.Warnf("Downstream security scheme ID '%s' not found for tool %s.", downstreamSecurity.ID, t.name)
|
|
} else {
|
|
// Extract and remove the credential from the incoming request
|
|
extractedCred, err := ExtractAndRemoveIncomingCredential(clientScheme)
|
|
if err != nil {
|
|
log.Warnf("Failed to extract/remove incoming credential for tool %s using scheme %s: %v", t.name, clientScheme.ID, err)
|
|
} else if extractedCred == "" {
|
|
log.Debugf("No incoming credential found for tool %s using scheme %s for extraction/removal.", t.name, clientScheme.ID)
|
|
}
|
|
|
|
// Only use passthrough if explicitly configured
|
|
if downstreamSecurity.Passthrough && extractedCred != "" {
|
|
passthroughCredential = extractedCred
|
|
log.Debugf("Passthrough credential set for tool %s.", t.name)
|
|
}
|
|
}
|
|
} else {
|
|
// Fallback: Remove Authorization header if no downstream security is defined
|
|
// This prevents downstream credentials from being mistakenly passed to upstream
|
|
// Unless passthroughAuthHeader is explicitly set to true
|
|
if !proxyServer.GetPassthroughAuthHeader() {
|
|
proxywasm.RemoveHttpRequestHeader("Authorization")
|
|
}
|
|
}
|
|
|
|
// Create protocol handler using server fields
|
|
handler := NewMcpProtocolHandler(proxyServer.GetMcpServerURL(), proxyServer.GetTimeout())
|
|
|
|
// Prepare authentication information for gateway-to-backend communication
|
|
// toolConfig.RequestTemplate.Security represents gateway-to-backend authentication, falls back to server's defaultUpstreamSecurity
|
|
var authInfo *ProxyAuthInfo
|
|
var upstreamSecurity SecurityRequirement
|
|
if t.toolConfig.RequestTemplate.Security.ID != "" {
|
|
// Use tool-level upstream security if configured
|
|
upstreamSecurity = t.toolConfig.RequestTemplate.Security
|
|
log.Debugf("Using tool-level upstream security for tool %s: %s", t.name, upstreamSecurity.ID)
|
|
} else {
|
|
// Fall back to server's default upstream security
|
|
upstreamSecurity = proxyServer.GetDefaultUpstreamSecurity()
|
|
if upstreamSecurity.ID != "" {
|
|
log.Debugf("Using default upstream security for tool %s: %s", t.name, upstreamSecurity.ID)
|
|
}
|
|
}
|
|
|
|
if upstreamSecurity.ID != "" {
|
|
authInfo = &ProxyAuthInfo{
|
|
SecuritySchemeID: upstreamSecurity.ID,
|
|
PassthroughCredential: passthroughCredential,
|
|
Server: proxyServer,
|
|
}
|
|
}
|
|
|
|
// This will handle initialization asynchronously if needed and use ActionPause/Resume
|
|
return handler.ForwardToolsCall(ctx, t.name, t.arguments, authInfo)
|
|
}
|
|
|
|
// Description implements Tool interface
|
|
func (t *McpProxyTool) Description() string {
|
|
return t.toolConfig.Description
|
|
}
|
|
|
|
// InputSchema implements Tool interface
|
|
func (t *McpProxyTool) InputSchema() map[string]any {
|
|
schema := map[string]any{
|
|
"type": "object",
|
|
"properties": make(map[string]any),
|
|
"required": []string{},
|
|
}
|
|
|
|
properties := schema["properties"].(map[string]any)
|
|
var required []string
|
|
|
|
for _, arg := range t.toolConfig.Args {
|
|
argSchema := map[string]any{
|
|
"type": arg.Type,
|
|
"description": arg.Description,
|
|
}
|
|
|
|
if arg.Default != nil {
|
|
argSchema["default"] = arg.Default
|
|
}
|
|
|
|
if len(arg.Enum) > 0 {
|
|
argSchema["enum"] = arg.Enum
|
|
}
|
|
|
|
properties[arg.Name] = argSchema
|
|
|
|
if arg.Required {
|
|
required = append(required, arg.Name)
|
|
}
|
|
}
|
|
|
|
schema["required"] = required
|
|
return schema
|
|
}
|
|
|
|
// OutputSchema implements Tool interface (MCP Protocol Version 2025-06-18)
|
|
func (t *McpProxyTool) OutputSchema() map[string]any {
|
|
return t.toolConfig.OutputSchema
|
|
}
|
|
|
|
// ValidateSecurityScheme validates a security scheme configuration
|
|
func ValidateSecurityScheme(scheme SecurityScheme) error {
|
|
if scheme.ID == "" {
|
|
return fmt.Errorf("security scheme ID is required")
|
|
}
|
|
|
|
if scheme.Type != "apiKey" && scheme.Type != "http" {
|
|
return fmt.Errorf("invalid security scheme type: %s", scheme.Type)
|
|
}
|
|
|
|
if scheme.Type == "apiKey" {
|
|
if scheme.Name == "" {
|
|
return fmt.Errorf("security scheme name is required for apiKey type")
|
|
}
|
|
if scheme.In != "header" && scheme.In != "query" && scheme.In != "cookie" {
|
|
return fmt.Errorf("invalid security scheme location: %s", scheme.In)
|
|
}
|
|
}
|
|
|
|
if scheme.Type == "http" {
|
|
if scheme.Scheme == "" {
|
|
return fmt.Errorf("security scheme scheme is required for http type")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ValidateToolConfig validates a tool configuration
|
|
func ValidateToolConfig(config McpProxyToolConfig) error {
|
|
if config.Name == "" {
|
|
return fmt.Errorf("tool name is required")
|
|
}
|
|
|
|
if config.Description == "" {
|
|
return fmt.Errorf("tool description is required")
|
|
}
|
|
|
|
// Validate arguments
|
|
argNames := make(map[string]bool)
|
|
for _, arg := range config.Args {
|
|
if arg.Name == "" {
|
|
return fmt.Errorf("argument name is required")
|
|
}
|
|
|
|
if argNames[arg.Name] {
|
|
return fmt.Errorf("duplicate argument name: %s", arg.Name)
|
|
}
|
|
argNames[arg.Name] = true
|
|
|
|
if arg.Description == "" {
|
|
return fmt.Errorf("argument description is required for %s", arg.Name)
|
|
}
|
|
|
|
validTypes := []string{"string", "number", "integer", "boolean", "array", "object"}
|
|
validType := false
|
|
for _, t := range validTypes {
|
|
if arg.Type == t {
|
|
validType = true
|
|
break
|
|
}
|
|
}
|
|
if !validType {
|
|
return fmt.Errorf("invalid argument type %s for %s", arg.Type, arg.Name)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|