// Source listing: higress/plugins/wasm-go/pkg/mcp/server/proxy_server.go
// (501 lines, 17 KiB, Go)

// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"encoding/json"
"fmt"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
// McpProxyConfig is a placeholder for additional MCP proxy server configuration.
//
// It currently has no fields: the backend URL, timeout, and default
// downstream/upstream security live directly on McpProxyServer and are set
// through its setter methods. The type is retained so future server-wide
// options have a home.
type McpProxyConfig struct{}
// TransportProtocol names the wire transport used to talk to the backend MCP server.
type TransportProtocol string

// Supported transport protocols.
const (
	// TransportHTTP selects the StreamableHTTP transport.
	TransportHTTP TransportProtocol = "http"
	// TransportSSE selects the SSE transport.
	TransportSSE TransportProtocol = "sse"
)
// ToolArg describes one input argument of a proxy tool. McpProxyTool.InputSchema
// turns each ToolArg into a JSON-Schema property.
type ToolArg struct {
	Name        string `json:"name"`
	Description string `json:"description"`
	Type        string `json:"type"`     // JSON-Schema type name (e.g. "string", "object")
	Required    bool   `json:"required"` // listed in the schema's "required" array when true
	Default     any    `json:"default,omitempty"`
	Enum        []any  `json:"enum,omitempty"`
}
// McpProxyToolConfig configures a single tool exposed by the MCP proxy.
//
// Security is the client-to-gateway (downstream) requirement for this tool, and
// RequestTemplate.Security is the gateway-to-backend (upstream) requirement; when
// either has an empty ID, McpProxyTool.Call falls back to the server defaults.
//
// NOTE(review): encoding/json ignores "omitempty" on struct-typed fields, so
// RequestTemplate (and Security, if SecurityRequirement is a struct as its
// field access elsewhere suggests) is always encoded.
type McpProxyToolConfig struct {
	Name            string              `json:"name"`
	Description     string              `json:"description"`
	Security        SecurityRequirement `json:"security,omitempty"`
	Args            []ToolArg           `json:"args"`
	OutputSchema    map[string]any      `json:"outputSchema,omitempty"` // MCP protocol version 2025-06-18
	RequestTemplate RequestTemplate     `json:"requestTemplate,omitempty"`
}
// RequestTemplate holds per-tool settings for the request the gateway sends to
// the backend. Security is the tool-level gateway-to-backend requirement; an
// empty ID means the server's default upstream security is used instead.
type RequestTemplate struct {
	Security SecurityRequirement `json:"security,omitempty"`
}
// McpProxyServer is a Server that proxies MCP requests to a backend MCP server.
// Construct it with NewMcpProxyServer so its maps are initialized.
type McpProxyServer struct {
	Name string

	base            BaseMCPServer
	toolsConfig     map[string]McpProxyToolConfig // keyed by tool name
	securitySchemes map[string]SecurityScheme     // keyed by scheme ID

	defaultDownstreamSecurity SecurityRequirement // client -> gateway auth when a tool sets none
	defaultUpstreamSecurity   SecurityRequirement // gateway -> backend auth when a tool sets none

	mcpServerURL string            // backend MCP server URL
	timeout      int               // request timeout, in milliseconds
	transport    TransportProtocol // TransportHTTP or TransportSSE

	// passthroughAuthHeader keeps the client's Authorization header when no
	// downstream security is configured (it is stripped otherwise).
	passthroughAuthHeader bool
}
// NewMcpProxyServer returns an McpProxyServer named name with an empty tool
// registry and no security schemes. Other settings start at their zero values.
func NewMcpProxyServer(name string) *McpProxyServer {
	srv := &McpProxyServer{Name: name, base: NewBaseMCPServer()}
	srv.toolsConfig = map[string]McpProxyToolConfig{}
	srv.securitySchemes = map[string]SecurityScheme{}
	return srv
}
// AddSecurityScheme registers scheme under scheme.ID, replacing any scheme
// previously registered with the same ID.
func (s *McpProxyServer) AddSecurityScheme(scheme SecurityScheme) {
	// Create the map lazily so a zero-value server can still register schemes.
	if s.securitySchemes == nil {
		s.securitySchemes = map[string]SecurityScheme{}
	}
	s.securitySchemes[scheme.ID] = scheme
}
// GetSecurityScheme looks up the security scheme registered under id.
// found is false (and scheme is the zero value) when no such scheme exists.
func (s *McpProxyServer) GetSecurityScheme(id string) (scheme SecurityScheme, found bool) {
	scheme, found = s.securitySchemes[id]
	return scheme, found
}
// SetDefaultDownstreamSecurity sets the client-to-gateway security requirement
// used by tools/list and by any tool without its own Security setting.
func (s *McpProxyServer) SetDefaultDownstreamSecurity(req SecurityRequirement) {
	s.defaultDownstreamSecurity = req
}

// GetDefaultDownstreamSecurity returns the server's default client-to-gateway
// security requirement.
func (s *McpProxyServer) GetDefaultDownstreamSecurity() SecurityRequirement {
	return s.defaultDownstreamSecurity
}
// SetDefaultUpstreamSecurity sets the gateway-to-backend security requirement
// used by tools/list and by any tool without RequestTemplate.Security.
func (s *McpProxyServer) SetDefaultUpstreamSecurity(req SecurityRequirement) {
	s.defaultUpstreamSecurity = req
}

// GetDefaultUpstreamSecurity returns the server's default gateway-to-backend
// security requirement.
func (s *McpProxyServer) GetDefaultUpstreamSecurity() SecurityRequirement {
	return s.defaultUpstreamSecurity
}
// SetMcpServerURL sets the URL of the backend MCP server requests are proxied to.
func (s *McpProxyServer) SetMcpServerURL(serverURL string) {
	s.mcpServerURL = serverURL
}

// GetMcpServerURL returns the URL of the backend MCP server.
func (s *McpProxyServer) GetMcpServerURL() string {
	return s.mcpServerURL
}
// SetTimeout sets the backend request timeout, in milliseconds.
func (s *McpProxyServer) SetTimeout(ms int) {
	s.timeout = ms
}

// GetTimeout returns the backend request timeout, in milliseconds.
func (s *McpProxyServer) GetTimeout() int {
	return s.timeout
}
// SetTransport selects the transport protocol used for the backend server.
func (s *McpProxyServer) SetTransport(proto TransportProtocol) {
	s.transport = proto
}

// GetTransport returns the transport protocol used for the backend server.
func (s *McpProxyServer) GetTransport() TransportProtocol {
	return s.transport
}
// AddMCPTool registers tool under name in the embedded base server and returns
// s so calls can be chained. It implements the Server interface.
func (s *McpProxyServer) AddMCPTool(name string, tool Tool) Server {
	base := &s.base
	base.AddMCPTool(name, tool)
	return s
}
// AddProxyTool records toolConfig and registers an McpProxyTool for it under
// toolConfig.Name, replacing any tool with the same name. The error result is
// currently always nil.
func (s *McpProxyServer) AddProxyTool(toolConfig McpProxyToolConfig) error {
	name := toolConfig.Name
	s.toolsConfig[name] = toolConfig

	proxyTool := &McpProxyTool{serverName: s.Name, name: name, toolConfig: toolConfig}
	s.base.AddMCPTool(name, proxyTool)
	return nil
}
// GetMCPTools returns the tools registered in the embedded base server.
// It implements the Server interface.
func (s *McpProxyServer) GetMCPTools() map[string]Tool {
	tools := s.base.GetMCPTools()
	return tools
}
// SetConfig stores the raw configuration bytes in the embedded base server.
// It implements the Server interface.
func (s *McpProxyServer) SetConfig(config []byte) {
	base := &s.base
	base.SetConfig(config)
}
// GetConfig delegates to the embedded base server to populate v from the
// stored configuration. It implements the Server interface.
func (s *McpProxyServer) GetConfig(v any) {
	base := &s.base
	base.GetConfig(v)
}
// Clone returns a copy of the server. It implements the Server interface.
//
// Every configuration field is carried over: backend URL, timeout, transport,
// default downstream/upstream security, and the passthrough-auth flag. The
// tool-config and security-scheme maps are new maps, so adding entries to the
// clone does not affect s. Map values are copied by value; any reference-typed
// fields inside them (not visible from here) would still be shared.
func (s *McpProxyServer) Clone() Server {
	newServer := &McpProxyServer{
		Name:            s.Name,
		base:            s.base.CloneBase(),
		toolsConfig:     make(map[string]McpProxyToolConfig, len(s.toolsConfig)),
		securitySchemes: make(map[string]SecurityScheme, len(s.securitySchemes)),

		// These fields were previously dropped, leaving the clone with an
		// empty backend URL and no auth configuration.
		defaultDownstreamSecurity: s.defaultDownstreamSecurity,
		defaultUpstreamSecurity:   s.defaultUpstreamSecurity,
		mcpServerURL:              s.mcpServerURL,
		timeout:                   s.timeout,
		transport:                 s.transport,
		passthroughAuthHeader:     s.passthroughAuthHeader,
	}
	for name, cfg := range s.toolsConfig {
		newServer.toolsConfig[name] = cfg
	}
	// Ranging over a nil map is a no-op, so no nil check is needed.
	for id, scheme := range s.securitySchemes {
		newServer.securitySchemes[id] = scheme
	}
	return newServer
}
// GetToolConfig returns the proxy tool configuration registered under name.
// found is false (and config is the zero value) when no such tool exists.
func (s *McpProxyServer) GetToolConfig(name string) (config McpProxyToolConfig, found bool) {
	config, found = s.toolsConfig[name]
	return config, found
}
// SetPassthroughAuthHeader controls whether the client's Authorization header
// is kept when no downstream security is configured. When false, the header is
// removed before forwarding.
func (s *McpProxyServer) SetPassthroughAuthHeader(enabled bool) {
	s.passthroughAuthHeader = enabled
}

// GetPassthroughAuthHeader reports whether the Authorization header is passed
// through when no downstream security is configured.
func (s *McpProxyServer) GetPassthroughAuthHeader() bool {
	return s.passthroughAuthHeader
}
// ForwardToolsList forwards a tools/list request to the backend MCP server.
//
// tools/list is not tied to a tool, so only the server-level defaults apply:
//   - Downstream (client -> gateway): if a default downstream requirement is
//     set and its scheme is registered, the incoming credential is extracted
//     and removed from the request. It is forwarded upstream only when the
//     requirement has Passthrough set. If no downstream requirement is set,
//     the Authorization header is removed unless passthroughAuthHeader is true.
//   - Upstream (gateway -> backend): if a default upstream requirement is set,
//     its scheme ID and any passthrough credential are handed to the protocol
//     handler. Otherwise no auth info is passed.
//
// It returns an error if ctx is not a wrapper.HttpContext. Otherwise it returns
// the protocol handler's result.
func (s *McpProxyServer) ForwardToolsList(ctx HttpContext, cursor *string) error {
	// Use the checked form so an unexpected context type yields an error
	// instead of panicking inside the plugin.
	wrapperCtx, ok := ctx.(wrapper.HttpContext)
	if !ok {
		return fmt.Errorf("http context is not a wrapper.HttpContext")
	}

	passthroughCredential := ""
	downstreamSecurity := s.GetDefaultDownstreamSecurity()
	if downstreamSecurity.ID != "" {
		clientScheme, schemeOk := s.GetSecurityScheme(downstreamSecurity.ID)
		if !schemeOk {
			// NOTE(review): in this branch the incoming Authorization header is
			// left untouched — confirm that is intended for a misconfigured scheme ID.
			log.Warnf("Default downstream security scheme ID '%s' not found for tools/list request.", downstreamSecurity.ID)
		} else {
			extractedCred, err := ExtractAndRemoveIncomingCredential(clientScheme)
			if err != nil {
				log.Warnf("Failed to extract/remove incoming credential for tools/list using scheme %s: %v", clientScheme.ID, err)
			} else if extractedCred == "" {
				log.Debugf("No incoming credential found for tools/list using scheme %s for extraction/removal.", clientScheme.ID)
			}
			// Forward the client's credential upstream only when explicitly configured.
			if downstreamSecurity.Passthrough && extractedCred != "" {
				passthroughCredential = extractedCred
				log.Debugf("Passthrough credential set for tools/list request.")
			}
		}
	} else if !s.GetPassthroughAuthHeader() {
		// With no downstream security defined, strip Authorization so client
		// credentials are not mistakenly sent to the backend.
		proxywasm.RemoveHttpRequestHeader("Authorization")
	}

	handler := NewMcpProtocolHandler(s.GetMcpServerURL(), s.GetTimeout())

	var authInfo *ProxyAuthInfo
	if upstreamSecurity := s.GetDefaultUpstreamSecurity(); upstreamSecurity.ID != "" {
		authInfo = &ProxyAuthInfo{
			SecuritySchemeID:      upstreamSecurity.ID,
			PassthroughCredential: passthroughCredential,
			Server:                s,
		}
	}

	// Per the original note, the handler handles backend initialization
	// asynchronously if needed (pausing/resuming the request).
	return handler.ForwardToolsList(wrapperCtx, cursor, authInfo)
}
// McpProxyTool is the Tool that forwards a single tool's calls to the backend
// MCP server. AddProxyTool registers a template instance; Create produces a
// per-call instance carrying the decoded call arguments.
type McpProxyTool struct {
	serverName string
	name       string
	toolConfig McpProxyToolConfig
	arguments  map[string]interface{} // decoded tool-call arguments (set by Create)
}
// Create implements Tool interface
func (t *McpProxyTool) Create(params []byte) Tool {
newTool := &McpProxyTool{
serverName: t.serverName,
name: t.name,
toolConfig: t.toolConfig,
arguments: make(map[string]interface{}),
}
if len(params) > 0 {
json.Unmarshal(params, &newTool.arguments)
}
return newTool
}
// Call forwards this tool invocation to the backend MCP server as a tools/call
// request. It implements the Tool interface.
//
// Downstream (client -> gateway) security is the tool's Security setting,
// falling back to the server's default downstream security. If the resulting
// scheme is registered, the incoming credential is extracted and removed from
// the request. It is forwarded upstream only when the requirement has
// Passthrough set. If no downstream requirement applies, the Authorization
// header is removed unless the server's passthroughAuthHeader flag is set.
//
// Upstream (gateway -> backend) security is the tool's RequestTemplate.Security,
// falling back to the server's default upstream security. Without one, no auth
// info is passed to the protocol handler.
//
// It returns an error if httpCtx is not a wrapper.HttpContext or server is not
// a *McpProxyServer. Otherwise it returns the protocol handler's result.
func (t *McpProxyTool) Call(httpCtx HttpContext, server Server) error {
	// Use the checked form (as for server below) so an unexpected context
	// type yields an error instead of panicking inside the plugin.
	ctx, ok := httpCtx.(wrapper.HttpContext)
	if !ok {
		return fmt.Errorf("http context is not a wrapper.HttpContext")
	}
	proxyServer, ok := server.(*McpProxyServer)
	if !ok {
		return fmt.Errorf("server is not a McpProxyServer")
	}

	// Resolve client-to-gateway security: tool-level first, then server default.
	passthroughCredential := ""
	var downstreamSecurity SecurityRequirement
	if t.toolConfig.Security.ID != "" {
		downstreamSecurity = t.toolConfig.Security
		log.Debugf("Using tool-level downstream security for tool %s: %s", t.name, downstreamSecurity.ID)
	} else {
		downstreamSecurity = proxyServer.GetDefaultDownstreamSecurity()
		if downstreamSecurity.ID != "" {
			log.Debugf("Using default downstream security for tool %s: %s", t.name, downstreamSecurity.ID)
		}
	}

	if downstreamSecurity.ID != "" {
		clientScheme, schemeOk := proxyServer.GetSecurityScheme(downstreamSecurity.ID)
		if !schemeOk {
			// NOTE(review): in this branch the incoming Authorization header is
			// left untouched — confirm that is intended for a misconfigured scheme ID.
			log.Warnf("Downstream security scheme ID '%s' not found for tool %s.", downstreamSecurity.ID, t.name)
		} else {
			extractedCred, err := ExtractAndRemoveIncomingCredential(clientScheme)
			if err != nil {
				log.Warnf("Failed to extract/remove incoming credential for tool %s using scheme %s: %v", t.name, clientScheme.ID, err)
			} else if extractedCred == "" {
				log.Debugf("No incoming credential found for tool %s using scheme %s for extraction/removal.", t.name, clientScheme.ID)
			}
			// Forward the client's credential upstream only when explicitly configured.
			if downstreamSecurity.Passthrough && extractedCred != "" {
				passthroughCredential = extractedCred
				log.Debugf("Passthrough credential set for tool %s.", t.name)
			}
		}
	} else if !proxyServer.GetPassthroughAuthHeader() {
		// With no downstream security defined, strip Authorization so client
		// credentials are not mistakenly sent to the backend.
		proxywasm.RemoveHttpRequestHeader("Authorization")
	}

	handler := NewMcpProtocolHandler(proxyServer.GetMcpServerURL(), proxyServer.GetTimeout())

	// Resolve gateway-to-backend security: tool-level first, then server default.
	var upstreamSecurity SecurityRequirement
	if t.toolConfig.RequestTemplate.Security.ID != "" {
		upstreamSecurity = t.toolConfig.RequestTemplate.Security
		log.Debugf("Using tool-level upstream security for tool %s: %s", t.name, upstreamSecurity.ID)
	} else {
		upstreamSecurity = proxyServer.GetDefaultUpstreamSecurity()
		if upstreamSecurity.ID != "" {
			log.Debugf("Using default upstream security for tool %s: %s", t.name, upstreamSecurity.ID)
		}
	}

	var authInfo *ProxyAuthInfo
	if upstreamSecurity.ID != "" {
		authInfo = &ProxyAuthInfo{
			SecuritySchemeID:      upstreamSecurity.ID,
			PassthroughCredential: passthroughCredential,
			Server:                proxyServer,
		}
	}

	// Per the original note, the handler handles backend initialization
	// asynchronously if needed (pausing/resuming the request).
	return handler.ForwardToolsCall(ctx, t.name, t.arguments, authInfo)
}
// Description returns the tool description from the proxy tool configuration.
// It implements the Tool interface.
func (t *McpProxyTool) Description() string {
	cfg := t.toolConfig
	return cfg.Description
}
// InputSchema builds the tool's JSON-Schema input description from its
// configured arguments. It implements the Tool interface.
//
// The result is an object schema whose "properties" maps each argument name to
// its type and description (plus "default" and "enum" when set). "required"
// lists the names of arguments marked Required.
func (t *McpProxyTool) InputSchema() map[string]any {
	args := t.toolConfig.Args
	properties := make(map[string]any, len(args))
	// Must be non-nil: a nil slice marshals as `"required": null`, but MCP's
	// inputSchema declares "required" as string[]. An empty slice marshals as [].
	required := make([]string, 0, len(args))

	for _, arg := range args {
		argSchema := map[string]any{
			"type":        arg.Type,
			"description": arg.Description,
		}
		if arg.Default != nil {
			argSchema["default"] = arg.Default
		}
		if len(arg.Enum) > 0 {
			argSchema["enum"] = arg.Enum
		}
		properties[arg.Name] = argSchema
		if arg.Required {
			required = append(required, arg.Name)
		}
	}

	return map[string]any{
		"type":       "object",
		"properties": properties,
		"required":   required,
	}
}
// OutputSchema returns the configured output schema for the tool (MCP protocol
// version 2025-06-18). It implements the Tool interface. The result is nil when
// the tool configuration has no outputSchema.
func (t *McpProxyTool) OutputSchema() map[string]any {
	cfg := t.toolConfig
	return cfg.OutputSchema
}
// ValidateSecurityScheme checks that scheme is well-formed.
//
// It requires a non-empty ID and a Type of "apiKey" or "http". An "apiKey"
// scheme needs a Name and an In of "header", "query", or "cookie". An "http"
// scheme needs a non-empty Scheme. The first violation found is returned.
func ValidateSecurityScheme(scheme SecurityScheme) error {
	if scheme.ID == "" {
		return fmt.Errorf("security scheme ID is required")
	}

	switch scheme.Type {
	case "apiKey":
		if scheme.Name == "" {
			return fmt.Errorf("security scheme name is required for apiKey type")
		}
		switch scheme.In {
		case "header", "query", "cookie":
			// valid location
		default:
			return fmt.Errorf("invalid security scheme location: %s", scheme.In)
		}
	case "http":
		if scheme.Scheme == "" {
			return fmt.Errorf("security scheme scheme is required for http type")
		}
	default:
		return fmt.Errorf("invalid security scheme type: %s", scheme.Type)
	}
	return nil
}
// ValidateToolConfig checks that config is a usable proxy tool definition.
//
// The tool needs a name and a description. Each argument needs a unique,
// non-empty name, a description, and a type of "string", "number", "integer",
// "boolean", "array", or "object". The first violation found is returned.
func ValidateToolConfig(config McpProxyToolConfig) error {
	if config.Name == "" {
		return fmt.Errorf("tool name is required")
	}
	if config.Description == "" {
		return fmt.Errorf("tool description is required")
	}

	seen := make(map[string]struct{}, len(config.Args))
	for _, arg := range config.Args {
		if arg.Name == "" {
			return fmt.Errorf("argument name is required")
		}
		if _, dup := seen[arg.Name]; dup {
			return fmt.Errorf("duplicate argument name: %s", arg.Name)
		}
		seen[arg.Name] = struct{}{}

		if arg.Description == "" {
			return fmt.Errorf("argument description is required for %s", arg.Name)
		}

		switch arg.Type {
		case "string", "number", "integer", "boolean", "array", "object":
			// valid JSON-Schema type
		default:
			return fmt.Errorf("invalid argument type %s for %s", arg.Type, arg.Name)
		}
	}
	return nil
}