Files
higress/plugins/wasm-go/pkg/mcp/server/config_validator_test.go

329 lines
7.7 KiB
Go

// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"fmt"
"os"
"testing"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
)
// testLogger is a mock logger for the tests in this package. Every logging
// method writes a single level-tagged line to os.Stderr; init installs it via
// log.SetPluginLog so that logging during tests does not panic.
type testLogger struct{}

// emit writes "[LEVEL] <message>\n" to os.Stderr, where <message> is format
// expanded with args. All level-specific methods below funnel through it.
func (l *testLogger) emit(level, format string, args ...interface{}) {
	fmt.Fprintf(os.Stderr, "["+level+"] "+format+"\n", args...)
}

// Trace logs msg verbatim at TRACE level.
func (l *testLogger) Trace(msg string) { l.emit("TRACE", "%s", msg) }

// Tracef logs a formatted message at TRACE level.
func (l *testLogger) Tracef(format string, args ...interface{}) { l.emit("TRACE", format, args...) }

// Debug logs msg verbatim at DEBUG level.
func (l *testLogger) Debug(msg string) { l.emit("DEBUG", "%s", msg) }

// Debugf logs a formatted message at DEBUG level.
func (l *testLogger) Debugf(format string, args ...interface{}) { l.emit("DEBUG", format, args...) }

// Info logs msg verbatim at INFO level.
func (l *testLogger) Info(msg string) { l.emit("INFO", "%s", msg) }

// Infof logs a formatted message at INFO level.
func (l *testLogger) Infof(format string, args ...interface{}) { l.emit("INFO", format, args...) }

// Warn logs msg verbatim at WARN level.
func (l *testLogger) Warn(msg string) { l.emit("WARN", "%s", msg) }

// Warnf logs a formatted message at WARN level.
func (l *testLogger) Warnf(format string, args ...interface{}) { l.emit("WARN", format, args...) }

// Error logs msg verbatim at ERROR level.
func (l *testLogger) Error(msg string) { l.emit("ERROR", "%s", msg) }

// Errorf logs a formatted message at ERROR level.
func (l *testLogger) Errorf(format string, args ...interface{}) { l.emit("ERROR", format, args...) }

// Critical logs msg verbatim at CRITICAL level.
func (l *testLogger) Critical(msg string) { l.emit("CRITICAL", "%s", msg) }

// Criticalf logs a formatted message at CRITICAL level.
func (l *testLogger) Criticalf(format string, args ...interface{}) {
	l.emit("CRITICAL", format, args...)
}

// ResetID is a no-op; the test logger does not track a plugin ID.
func (l *testLogger) ResetID(pluginID string) {}
func init() {
	// Install the stderr-backed testLogger as the plugin logger at package
	// initialisation, so that logging during the tests does not panic.
	log.SetPluginLog(new(testLogger))
}
// TestMcpProxyConfigValidation tests configuration validation for mcp-proxy
// servers. Each table entry is a JSON configuration that is parsed with gjson
// and passed to ParseConfigCore; the entry states whether an error is expected
// and, optionally, a substring the error message must contain.
func TestMcpProxyConfigValidation(t *testing.T) {
	tests := []struct {
		name      string
		config    string
		shouldErr bool
		errMsg    string
	}{
		{
			name: "valid basic proxy config",
			config: `{
"server": {
"name": "test-proxy",
"type": "mcp-proxy",
"transport": "http",
"mcpServerURL": "http://backend.example.com/mcp",
"timeout": 5000
},
"tools": [
{
"name": "test-tool",
"description": "Test tool",
"args": [
{
"name": "input",
"description": "Input parameter",
"type": "string",
"required": true
}
]
}
]
}`,
			shouldErr: false,
		},
		{
			name: "proxy config with security schemes",
			config: `{
"server": {
"name": "secure-proxy",
"type": "mcp-proxy",
"transport": "http",
"mcpServerURL": "https://secure.example.com/mcp",
"timeout": 8000,
"securitySchemes": [
{
"id": "ApiKeyAuth",
"type": "apiKey",
"in": "header",
"name": "X-API-Key",
"defaultCredential": "test-key"
}
]
},
"tools": [
{
"name": "secure-tool",
"description": "Secure tool",
"args": [
{
"name": "data",
"description": "Data parameter",
"type": "object",
"required": true
}
],
"requestTemplate": {
"security": {
"id": "ApiKeyAuth"
}
}
}
]
}`,
			shouldErr: false,
		},
		{
			name: "missing mcpServerURL should fail",
			config: `{
"server": {
"name": "invalid-proxy",
"type": "mcp-proxy",
"transport": "http",
"timeout": 5000
},
"tools": [
{
"name": "test-tool",
"description": "Test tool",
"args": []
}
]
}`,
			shouldErr: true,
			errMsg:    "mcpServerURL is required",
		},
		{
			name: "invalid server type should use default REST handling",
			config: `{
"server": {
"name": "rest-server",
"type": "rest-api"
},
"tools": [
{
"name": "rest-tool",
"description": "REST tool",
"args": [],
"requestTemplate": {
"url": "http://example.com/api",
"method": "GET"
},
"responseTemplate": {
"body": "$.result"
}
}
]
}`,
			shouldErr: false, // Should fall back to REST server logic
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			configJson := gjson.Parse(tt.config)
			config := &McpServerConfig{}
			// Create validation options (similar to validator package). Each
			// subtest gets its own registry and server map so cases do not
			// share state.
			toolRegistry := &GlobalToolRegistry{}
			toolRegistry.Initialize()
			opts := &ConfigOptions{
				Servers:                  make(map[string]Server),
				ToolRegistry:             toolRegistry,
				SkipPreRegisteredServers: true,
			}
			err := ParseConfigCore(configJson, config, opts)
			if !tt.shouldErr {
				assert.NoError(t, err)
				assert.NotNil(t, config)
				return
			}
			// assert.Error only records the failure and returns false; it does
			// not stop the test. Bail out here so err.Error() below is never
			// called on a nil error, which would panic instead of reporting a
			// clean failure.
			if !assert.Error(t, err) {
				return
			}
			if tt.errMsg != "" {
				assert.Contains(t, err.Error(), tt.errMsg)
			}
		})
	}
}
// TestSecuritySchemeValidation runs ValidateSecurityScheme over a table of
// SecurityScheme values and checks, for each one, whether an error is
// returned as expected.
func TestSecuritySchemeValidation(t *testing.T) {
	cases := []struct {
		name    string
		scheme  SecurityScheme
		wantErr bool
	}{
		{
			name: "valid API key scheme",
			scheme: SecurityScheme{
				ID:   "ApiKeyAuth",
				Type: "apiKey",
				In:   "header",
				Name: "X-API-Key",
			},
			wantErr: false,
		},
		{
			name: "valid HTTP bearer scheme",
			scheme: SecurityScheme{
				ID:     "BearerAuth",
				Type:   "http",
				Scheme: "bearer",
			},
			wantErr: false,
		},
		{
			name: "invalid scheme - missing ID",
			scheme: SecurityScheme{
				Type: "apiKey",
				In:   "header",
				Name: "X-API-Key",
			},
			wantErr: true,
		},
		{
			name: "invalid scheme - missing Name for apiKey",
			scheme: SecurityScheme{
				ID:   "ApiKeyAuth",
				Type: "apiKey",
				In:   "header",
			},
			wantErr: true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := ValidateSecurityScheme(tc.scheme)
			if !tc.wantErr {
				assert.NoError(t, err)
				return
			}
			assert.Error(t, err)
		})
	}
}
// TestToolConfigValidation runs ValidateToolConfig over a table of
// McpProxyToolConfig values. A fully populated tool is expected to pass; a
// tool with no name or with an empty description is expected to be rejected.
func TestToolConfigValidation(t *testing.T) {
	cases := []struct {
		name    string
		toolCfg McpProxyToolConfig
		wantErr bool
	}{
		{
			name: "valid tool config",
			toolCfg: McpProxyToolConfig{
				Name:        "valid-tool",
				Description: "A valid tool",
				Args: []ToolArg{
					{
						Name:        "param1",
						Description: "Parameter 1",
						Type:        "string",
						Required:    true,
					},
				},
			},
			wantErr: false,
		},
		{
			name: "invalid tool - missing name",
			toolCfg: McpProxyToolConfig{
				Description: "Tool without name",
				Args:        []ToolArg{},
			},
			wantErr: true,
		},
		{
			name: "invalid tool - empty description",
			toolCfg: McpProxyToolConfig{
				Name:        "tool-no-desc",
				Description: "",
				Args:        []ToolArg{},
			},
			wantErr: true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := ValidateToolConfig(tc.toolCfg)
			if !tc.wantErr {
				assert.NoError(t, err)
				return
			}
			assert.Error(t, err)
		})
	}
}
// The validation functions exercised above (ValidateSecurityScheme and ValidateToolConfig) are now implemented in proxy_server.go.