mirror of
https://github.com/alibaba/higress.git
synced 2026-02-25 21:21:01 +08:00
Compare commits
56 Commits
v2.2.0
...
add-releas
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b446651dd3 | ||
|
|
b3fb6324a4 | ||
|
|
8576128e4c | ||
|
|
caa5317723 | ||
|
|
093ef9a2c0 | ||
|
|
9346f1340b | ||
|
|
87c6cc9c9f | ||
|
|
ac29ba6984 | ||
|
|
1c847dd553 | ||
|
|
a07f5024a9 | ||
|
|
814c3307ba | ||
|
|
b76a3aca5e | ||
|
|
28df33c596 | ||
|
|
8e7292c42e | ||
|
|
d03932b3ea | ||
|
|
5a2ff8c836 | ||
|
|
6f8ef2ff69 | ||
|
|
67e2913f3d | ||
|
|
e996194228 | ||
|
|
95f86d7ab5 | ||
|
|
5d5d20df1f | ||
|
|
1ddc07992c | ||
|
|
13ed2284ae | ||
|
|
f9c7527753 | ||
|
|
c2be0e8c9a | ||
|
|
927fb52309 | ||
|
|
c0761c4553 | ||
|
|
4f857597da | ||
|
|
0d45ce755f | ||
|
|
44d688a168 | ||
|
|
0d9354da16 | ||
|
|
65834bff21 | ||
|
|
668c2b3669 | ||
|
|
ff4de901e7 | ||
|
|
a1967adb94 | ||
|
|
f6cb3031fe | ||
|
|
d4a0665957 | ||
|
|
2c7771da42 | ||
|
|
75c6fbe090 | ||
|
|
b153d08610 | ||
|
|
de633d8610 | ||
|
|
f2e4942f00 | ||
|
|
1b3a8b762b | ||
|
|
c885b89d03 | ||
|
|
ce4dff9887 | ||
|
|
6935a44d53 | ||
|
|
b33e2be5e9 | ||
|
|
d2385f1b30 | ||
|
|
ef5e3ee31b | ||
|
|
d2b0885236 | ||
|
|
6cb48247fd | ||
|
|
773f639260 | ||
|
|
fe58ce3943 | ||
|
|
0dbc056ce9 | ||
|
|
3bf39b60ea | ||
|
|
e9bb5d3255 |
@@ -1,435 +0,0 @@
|
||||
---
|
||||
name: higress-clawdbot-integration
|
||||
description: "Deploy and configure Higress AI Gateway for Clawdbot/OpenClaw integration. Use when: (1) User wants to deploy Higress AI Gateway, (2) User wants to configure Clawdbot/OpenClaw to use Higress as a model provider, (3) User mentions 'higress', 'ai gateway', 'model gateway', 'AI网关', (4) User wants to set up model routing or auto-routing, (5) User needs to manage LLM provider API keys, (6) User wants to track token usage and conversation history."
|
||||
---
|
||||
|
||||
# Higress AI Gateway Integration
|
||||
|
||||
Deploy and configure Higress AI Gateway for Clawdbot/OpenClaw integration with one-click deployment, model provider configuration, auto-routing, and session monitoring.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Docker installed and running
|
||||
- Internet access to download setup script
|
||||
- LLM provider API keys (at least one)
|
||||
|
||||
## Workflow
|
||||
|
||||
### Step 1: Download Setup Script
|
||||
|
||||
Download official get-ai-gateway.sh script:
|
||||
|
||||
```bash
|
||||
curl -fsSL https://raw.githubusercontent.com/higress-group/higress-standalone/main/all-in-one/get-ai-gateway.sh -o get-ai-gateway.sh
|
||||
chmod +x get-ai-gateway.sh
|
||||
```
|
||||
|
||||
### Step 2: Gather Configuration
|
||||
|
||||
Ask user for:
|
||||
|
||||
1. **LLM Provider API Keys** (at least one required):
|
||||
|
||||
**Top Commonly Used Providers:**
|
||||
- Aliyun Dashscope (Qwen): `--dashscope-key`
|
||||
- DeepSeek: `--deepseek-key`
|
||||
- Moonshot (Kimi): `--moonshot-key`
|
||||
- Zhipu AI: `--zhipuai-key`
|
||||
- Claude Code (OAuth mode): `--claude-code-key` (run `claude setup-token` to get token)
|
||||
- Claude: `--claude-key`
|
||||
- Minimax: `--minimax-key`
|
||||
- Azure OpenAI: `--azure-key`
|
||||
- AWS Bedrock: `--bedrock-key`
|
||||
- Google Vertex AI: `--vertex-key`
|
||||
- OpenAI: `--openai-key`
|
||||
- OpenRouter: `--openrouter-key`
|
||||
- Grok: `--grok-key`
|
||||
|
||||
To configure additional providers beyond the above, run `./get-ai-gateway.sh --help` to view the complete list of supported models and providers.
|
||||
|
||||
2. **Port Configuration** (optional):
|
||||
- HTTP port: `--http-port` (default: 8080)
|
||||
- HTTPS port: `--https-port` (default: 8443)
|
||||
- Console port: `--console-port` (default: 8001)
|
||||
|
||||
3. **Auto-routing** (optional):
|
||||
- Enable: `--auto-routing`
|
||||
- Default model: `--auto-routing-default-model`
|
||||
|
||||
### Step 3: Run Setup Script
|
||||
|
||||
Run script in non-interactive mode with gathered parameters:
|
||||
|
||||
```bash
|
||||
./get-ai-gateway.sh start --non-interactive \
|
||||
--dashscope-key sk-xxx \
|
||||
--openai-key sk-xxx \
|
||||
--auto-routing \
|
||||
--auto-routing-default-model qwen-turbo
|
||||
```
|
||||
|
||||
**Automatic Repository Selection:**
|
||||
|
||||
The script automatically detects your timezone and selects the geographically closest registry for both:
|
||||
- **Container image** (`IMAGE_REPO`)
|
||||
- **WASM plugins** (`PLUGIN_REGISTRY`)
|
||||
|
||||
| Region | Timezone Examples | Selected Registry |
|
||||
|--------|------------------|-------------------|
|
||||
| China & nearby | Asia/Shanghai, Asia/Hong_Kong, etc. | `higress-registry.cn-hangzhou.cr.aliyuncs.com` |
|
||||
| Southeast Asia | Asia/Singapore, Asia/Jakarta, etc. | `higress-registry.ap-southeast-7.cr.aliyuncs.com` |
|
||||
| North America | America/*, US/*, Canada/* | `higress-registry.us-west-1.cr.aliyuncs.com` |
|
||||
| Others | Default fallback | `higress-registry.cn-hangzhou.cr.aliyuncs.com` |
|
||||
|
||||
**Manual Override (optional):**
|
||||
|
||||
If you want to use a specific registry:
|
||||
|
||||
```bash
|
||||
IMAGE_REPO="higress-registry.ap-southeast-7.cr.aliyuncs.com/higress/all-in-one" \
|
||||
PLUGIN_REGISTRY="higress-registry.ap-southeast-7.cr.aliyuncs.com" \
|
||||
./get-ai-gateway.sh start --non-interactive \
|
||||
--dashscope-key sk-xxx \
|
||||
--openai-key sk-xxx
|
||||
```
|
||||
|
||||
### Step 4: Verify Deployment
|
||||
|
||||
After script completion:
|
||||
|
||||
1. Check container is running:
|
||||
```bash
|
||||
docker ps --filter "name=higress-ai-gateway"
|
||||
```
|
||||
|
||||
2. Test gateway endpoint:
|
||||
```bash
|
||||
curl http://localhost:8080/v1/models
|
||||
```
|
||||
|
||||
3. Access console (optional):
|
||||
```
|
||||
http://localhost:8001
|
||||
```
|
||||
|
||||
### Step 5: Configure Clawdbot/OpenClaw Plugin
|
||||
|
||||
If user wants to use Higress with Clawdbot/OpenClaw, install appropriate plugin:
|
||||
|
||||
#### Automatic Installation
|
||||
|
||||
Detect runtime and install correct plugin version:
|
||||
|
||||
```bash
|
||||
# Detect which runtime is installed
|
||||
if command -v clawdbot &> /dev/null; then
|
||||
RUNTIME="clawdbot"
|
||||
RUNTIME_DIR="$HOME/.clawdbot"
|
||||
PLUGIN_SRC="scripts/plugin-clawdbot"
|
||||
elif command -v openclaw &> /dev/null; then
|
||||
RUNTIME="openclaw"
|
||||
RUNTIME_DIR="$HOME/.openclaw"
|
||||
PLUGIN_SRC="scripts/plugin"
|
||||
else
|
||||
echo "Error: Neither clawdbot nor openclaw is installed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Install plugin
|
||||
PLUGIN_DEST="$RUNTIME_DIR/extensions/higress-ai-gateway"
|
||||
echo "Installing Higress AI Gateway plugin for $RUNTIME..."
|
||||
mkdir -p "$(dirname "$PLUGIN_DEST")"
|
||||
[ -d "$PLUGIN_DEST" ] && rm -rf "$PLUGIN_DEST"
|
||||
cp -r "$PLUGIN_SRC" "$PLUGIN_DEST"
|
||||
echo "✓ Plugin installed at: $PLUGIN_DEST"
|
||||
|
||||
# Configure provider
|
||||
echo ""
|
||||
echo "Configuring provider..."
|
||||
$RUNTIME models auth login --provider higress
|
||||
```
|
||||
|
||||
The plugin will guide you through an interactive setup for:
|
||||
1. Gateway URL (default: `http://localhost:8080`)
|
||||
2. Console URL (default: `http://localhost:8001`)
|
||||
3. API Key (optional for local deployments)
|
||||
4. Model list (auto-detected or manually specified)
|
||||
5. Auto-routing default model (if using `higress/auto`)
|
||||
|
||||
### Step 6: Manage API Keys (optional)
|
||||
|
||||
After deployment, manage API keys without redeploying:
|
||||
|
||||
```bash
|
||||
# View configured API keys
|
||||
./get-ai-gateway.sh config list
|
||||
# Add or update an API key (hot-reload, no restart needed)
|
||||
./get-ai-gateway.sh config add --provider <provider> --key <api-key>
|
||||
# Remove an API key (hot-reload, no restart needed)
|
||||
./get-ai-gateway.sh config remove --provider <provider>
|
||||
```
|
||||
|
||||
**Note:** Changes take effect immediately via hot-reload. No container restart required.
|
||||
|
||||
## CLI Parameters Reference
|
||||
|
||||
### Basic Options
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--non-interactive` | Run without prompts | - |
|
||||
| `--http-port` | Gateway HTTP port | 8080 |
|
||||
| `--https-port` | Gateway HTTPS port | 8443 |
|
||||
| `--console-port` | Console port | 8001 |
|
||||
| `--container-name` | Container name | higress-ai-gateway |
|
||||
| `--data-folder` | Data folder path | ./higress |
|
||||
| `--auto-routing` | Enable auto-routing feature | - |
|
||||
| `--auto-routing-default-model` | Default model when no rule matches | - |
|
||||
|
||||
### Environment Variables
|
||||
|
||||
| Variable | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
| `PLUGIN_REGISTRY` | Registry URL for container images and WASM plugins (auto-selected based on timezone) | `higress-registry.cn-hangzhou.cr.aliyuncs.com` |
|
||||
|
||||
**Auto-Selection Logic:**
|
||||
|
||||
The registry is automatically selected based on your timezone:
|
||||
|
||||
- **China & nearby** (Asia/Shanghai, etc.) → `higress-registry.cn-hangzhou.cr.aliyuncs.com`
|
||||
- **Southeast Asia** (Asia/Singapore, etc.) → `higress-registry.ap-southeast-7.cr.aliyuncs.com`
|
||||
- **North America** (America/*, etc.) → `higress-registry.us-west-1.cr.aliyuncs.com`
|
||||
- **Others** → `higress-registry.cn-hangzhou.cr.aliyuncs.com` (default)
|
||||
|
||||
Both container images and WASM plugins use the same registry for consistency.
|
||||
|
||||
**Manual Override:**
|
||||
|
||||
```bash
|
||||
PLUGIN_REGISTRY="higress-registry.ap-southeast-7.cr.aliyuncs.com" \
|
||||
./get-ai-gateway.sh start --non-interactive ...
|
||||
```
|
||||
|
||||
### LLM Provider API Keys
|
||||
|
||||
**Top Providers:**
|
||||
|
||||
| Parameter | Provider |
|
||||
|-----------|----------|
|
||||
| `--dashscope-key` | Aliyun Dashscope (Qwen) |
|
||||
| `--deepseek-key` | DeepSeek |
|
||||
| `--moonshot-key` | Moonshot (Kimi) |
|
||||
| `--zhipuai-key` | Zhipu AI |
|
||||
| `--claude-code-key` | Claude Code (OAuth mode - run `claude setup-token` to get token) |
|
||||
| `--claude-key` | Claude |
|
||||
| `--openai-key` | OpenAI |
|
||||
| `--openrouter-key` | OpenRouter |
|
||||
| `--gemini-key` | Google Gemini |
|
||||
| `--groq-key` | Groq |
|
||||
|
||||
**Additional Providers:**
|
||||
|
||||
`--doubao-key`, `--baichuan-key`, `--yi-key`, `--stepfun-key`, `--minimax-key`, `--cohere-key`, `--mistral-key`, `--github-key`, `--fireworks-key`, `--togetherai-key`, `--grok-key`, `--azure-key`, `--bedrock-key`, `--vertex-key`
|
||||
|
||||
## Managing Configuration
|
||||
|
||||
### API Keys
|
||||
|
||||
```bash
|
||||
# List all configured API keys
|
||||
./get-ai-gateway.sh config list
|
||||
# Add or update an API key (hot-reload)
|
||||
./get-ai-gateway.sh config add --provider deepseek --key sk-xxx
|
||||
# Remove an API key (hot-reload)
|
||||
./get-ai-gateway.sh config remove --provider deepseek
|
||||
```
|
||||
|
||||
**Supported provider aliases:**
|
||||
|
||||
`dashscope`/`qwen`, `moonshot`/`kimi`, `zhipuai`/`zhipu`, `togetherai`/`together`
|
||||
|
||||
### Routing Rules
|
||||
|
||||
```bash
|
||||
# Add a routing rule
|
||||
./get-ai-gateway.sh route add --model claude-opus-4.5 --trigger "深入思考|deep thinking"
|
||||
# List all rules
|
||||
./get-ai-gateway.sh route list
|
||||
# Remove a rule
|
||||
./get-ai-gateway.sh route remove --rule-id 0
|
||||
```
|
||||
|
||||
See [higress-auto-router](../higress-auto-router/SKILL.md) for detailed documentation.
|
||||
|
||||
## Access Logs
|
||||
|
||||
Gateway access logs are available at:
|
||||
|
||||
```
|
||||
$DATA_FOLDER/logs/access.log
|
||||
```
|
||||
|
||||
These logs can be used with **agent-session-monitor** skill for token tracking and conversation analysis.
|
||||
|
||||
## Related Skills
|
||||
|
||||
- **higress-auto-router**: Configure automatic model routing using CLI commands
|
||||
See: [higress-auto-router](../higress-auto-router/SKILL.md)
|
||||
|
||||
- **agent-session-monitor**: Monitor and track token usage across sessions
|
||||
See: [agent-session-monitor](../agent-session-monitor/SKILL.md)
|
||||
|
||||
## Examples
|
||||
|
||||
### Example 1: Basic Deployment with Dashscope
|
||||
|
||||
**User:** 帮我部署一个Higress AI网关,使用阿里云的通义千问
|
||||
|
||||
**Steps:**
|
||||
1. Download script
|
||||
2. Get Dashscope API key from user
|
||||
3. Run (script auto-detects timezone and selects optimal registry):
|
||||
```bash
|
||||
./get-ai-gateway.sh start --non-interactive \
|
||||
--dashscope-key sk-xxx
|
||||
```
|
||||
|
||||
**Response:**
|
||||
|
||||
```
|
||||
Auto-detected timezone: Asia/Shanghai
|
||||
Selected plugin registry: higress-registry.cn-hangzhou.cr.aliyuncs.com
|
||||
|
||||
✅ Higress AI Gateway 部署完成!
|
||||
|
||||
网关地址: http://localhost:8080/v1/chat/completions
|
||||
控制台: http://localhost:8001
|
||||
日志目录: ./higress/logs
|
||||
|
||||
已配置的模型提供商:
|
||||
- Aliyun Dashscope (Qwen)
|
||||
|
||||
测试命令:
|
||||
curl 'http://localhost:8080/v1/chat/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"model": "qwen-turbo", "messages": [{"role": "user", "content": "Hello!"}]}'
|
||||
```
|
||||
|
||||
### Example 2: Full Integration with Clawdbot
|
||||
|
||||
**User:** 完整配置Higress和Clawdbot的集成
|
||||
|
||||
**Steps:**
|
||||
1. Deploy Higress AI Gateway (auto-detects timezone)
|
||||
2. Install and configure Clawdbot plugin
|
||||
3. Enable auto-routing
|
||||
4. Set up session monitoring
|
||||
|
||||
**Response:**
|
||||
|
||||
```
|
||||
Auto-detected timezone: Asia/Shanghai
|
||||
Selected plugin registry: higress-registry.cn-hangzhou.cr.aliyuncs.com
|
||||
|
||||
✅ Higress AI Gateway 集成完成!
|
||||
|
||||
1. 网关已部署:
|
||||
- HTTP: http://localhost:8080
|
||||
- Console: http://localhost:8001
|
||||
- 容器镜像: Hangzhou (自动选择)
|
||||
- 插件镜像: Hangzhou (自动选择)
|
||||
|
||||
2. Clawdbot 插件配置:
|
||||
Plugin installed at: /root/.clawdbot/extensions/higress-ai-gateway
|
||||
Run: clawdbot models auth login --provider higress
|
||||
|
||||
3. 自动路由:
|
||||
已启用,使用 model="higress/auto"
|
||||
|
||||
4. 会话监控:
|
||||
日志路径: ./higress/logs/access.log
|
||||
|
||||
需要我帮你配置自动路由规则吗?
|
||||
```
|
||||
|
||||
### Example 3: Manage API Keys
|
||||
|
||||
**User:** 帮我查看当前配置的API keys,并添加一个DeepSeek的key
|
||||
|
||||
**Steps:**
|
||||
1. List current API keys:
|
||||
```bash
|
||||
./get-ai-gateway.sh config list
|
||||
```
|
||||
|
||||
2. Add DeepSeek API key:
|
||||
```bash
|
||||
./get-ai-gateway.sh config add --provider deepseek --key sk-xxx
|
||||
```
|
||||
|
||||
**Response:**
|
||||
|
||||
```
|
||||
当前配置的API keys:
|
||||
|
||||
Aliyun Dashscope (Qwen): sk-ab***ef12
|
||||
OpenAI: sk-cd***gh34
|
||||
|
||||
Adding API key for DeepSeek...
|
||||
|
||||
✅ API key updated successfully!
|
||||
|
||||
Provider: DeepSeek
|
||||
Key: sk-xx***yy56
|
||||
|
||||
Configuration has been hot-reloaded (no restart needed).
|
||||
```
|
||||
|
||||
### Example 4: North America Deployment
|
||||
|
||||
**User:** 帮我部署Higress AI网关
|
||||
|
||||
**Context:** User's timezone is America/Los_Angeles
|
||||
|
||||
**Steps:**
|
||||
1. Download script
|
||||
2. Get API keys from user
|
||||
3. Run (script auto-detects timezone and selects North America mirror):
|
||||
```bash
|
||||
./get-ai-gateway.sh start --non-interactive \
|
||||
--openai-key sk-xxx \
|
||||
--openrouter-key sk-xxx
|
||||
```
|
||||
|
||||
**Response:**
|
||||
|
||||
```
|
||||
Auto-detected timezone: America/Los_Angeles
|
||||
Selected plugin registry: higress-registry.us-west-1.cr.aliyuncs.com
|
||||
|
||||
✅ Higress AI Gateway 部署完成!
|
||||
|
||||
网关地址: http://localhost:8080/v1/chat/completions
|
||||
控制台: http://localhost:8001
|
||||
日志目录: ./higress/logs
|
||||
|
||||
镜像优化:
|
||||
- 容器镜像: North America (基于时区自动选择)
|
||||
- 插件镜像: North America (基于时区自动选择)
|
||||
|
||||
已配置的模型提供商:
|
||||
- OpenAI
|
||||
- OpenRouter
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
For detailed troubleshooting guides, see [TROUBLESHOOTING.md](references/TROUBLESHOOTING.md).
|
||||
|
||||
Common issues:
|
||||
- **Container fails to start**: Check Docker status, port availability, and container logs
|
||||
- **"too many open files" error**: Increase `fs.inotify.max_user_instances` to 8192
|
||||
- **Gateway not responding**: Verify container status and port mapping
|
||||
- **Plugin not recognized**: Check installation path and restart runtime
|
||||
- **Auto-routing not working**: Verify model list and routing rules
|
||||
- **Timezone detection fails**: Manually set `IMAGE_REPO` environment variable
|
||||
@@ -1,79 +0,0 @@
|
||||
# Higress AI Gateway Plugin (Clawdbot)
|
||||
|
||||
Clawdbot model provider plugin for Higress AI Gateway with auto-routing support.
|
||||
|
||||
## What is this?
|
||||
|
||||
This is a TypeScript-based provider plugin that enables Clawdbot to use Higress AI Gateway as a model provider. It provides:
|
||||
|
||||
- **Auto-routing support**: Use `higress/auto` to intelligently route requests based on message content
|
||||
- **Dynamic model discovery**: Auto-detect available models from Higress Console
|
||||
- **Smart URL handling**: Automatic URL normalization and validation
|
||||
- **Flexible authentication**: Support for both local and remote gateway deployments
|
||||
|
||||
## Files
|
||||
|
||||
- **index.ts**: Main plugin implementation
|
||||
- **package.json**: NPM package metadata and Clawdbot extension declaration
|
||||
- **clawdbot.plugin.json**: Plugin manifest for Clawdbot
|
||||
|
||||
## Installation
|
||||
|
||||
This plugin is automatically installed when you use the `higress-clawdbot-integration` skill. See the parent SKILL.md for complete installation instructions.
|
||||
|
||||
### Manual Installation
|
||||
|
||||
If you need to install manually:
|
||||
|
||||
```bash
|
||||
# Copy plugin files
|
||||
mkdir -p "$HOME/.clawdbot/extensions/higress-ai-gateway"
|
||||
cp -r ./* "$HOME/.clawdbot/extensions/higress-ai-gateway/"
|
||||
|
||||
# Configure provider
|
||||
clawdbot models auth login --provider higress
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
After installation, configure Higress as a model provider:
|
||||
|
||||
```bash
|
||||
clawdbot models auth login --provider higress
|
||||
```
|
||||
|
||||
The plugin will prompt for:
|
||||
1. Gateway URL (default: http://localhost:8080)
|
||||
2. Console URL (default: http://localhost:8001)
|
||||
3. API Key (optional for local deployments)
|
||||
4. Model list (auto-detected or manually specified)
|
||||
5. Auto-routing default model (if using higress/auto)
|
||||
|
||||
## Auto-routing
|
||||
|
||||
To use auto-routing, include `higress/auto` in your model list during configuration. Then use it in your conversations:
|
||||
|
||||
```bash
|
||||
# Use auto-routing
|
||||
clawdbot chat --model higress/auto "深入思考 这个问题应该怎么解决?"
|
||||
|
||||
# The gateway will automatically route to the appropriate model based on:
|
||||
# - Message content triggers (configured via higress-auto-router skill)
|
||||
# - Fallback to default model if no rule matches
|
||||
```
|
||||
|
||||
## Related Resources
|
||||
|
||||
- **Parent Skill**: [higress-clawdbot-integration](../SKILL.md)
|
||||
- **Auto-routing Configuration**: [higress-auto-router](../../higress-auto-router/SKILL.md)
|
||||
- **Session Monitoring**: [agent-session-monitor](../../agent-session-monitor/SKILL.md)
|
||||
- **Higress AI Gateway**: https://github.com/higress-group/higress-standalone
|
||||
|
||||
## Compatibility
|
||||
|
||||
- **Clawdbot**: v2.0.0+
|
||||
- **Higress AI Gateway**: All versions
|
||||
|
||||
## License
|
||||
|
||||
Apache-2.0
|
||||
@@ -1,10 +0,0 @@
|
||||
{
|
||||
"id": "higress-ai-gateway",
|
||||
"name": "Higress AI Gateway",
|
||||
"description": "Model provider plugin for Higress AI Gateway with auto-routing support",
|
||||
"providers": ["higress"],
|
||||
"configSchema": {
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
}
|
||||
}
|
||||
@@ -1,284 +0,0 @@
|
||||
import { emptyPluginConfigSchema } from "clawdbot/plugin-sdk";
|
||||
|
||||
const DEFAULT_GATEWAY_URL = "http://localhost:8080";
|
||||
const DEFAULT_CONSOLE_URL = "http://localhost:8001";
|
||||
const DEFAULT_CONTEXT_WINDOW = 128_000;
|
||||
const DEFAULT_MAX_TOKENS = 8192;
|
||||
|
||||
// Common models that Higress AI Gateway typically supports
|
||||
const DEFAULT_MODEL_IDS = [
|
||||
// Auto-routing special model
|
||||
"higress/auto",
|
||||
// OpenAI models
|
||||
"gpt-5.2",
|
||||
"gpt-5-mini",
|
||||
"gpt-5-nano",
|
||||
// Anthropic models
|
||||
"claude-opus-4.5",
|
||||
"claude-sonnet-4.5",
|
||||
"claude-haiku-4.5",
|
||||
// Qwen models
|
||||
"qwen3-turbo",
|
||||
"qwen3-plus",
|
||||
"qwen3-max",
|
||||
"qwen3-coder-480b-a35b-instruct",
|
||||
// DeepSeek models
|
||||
"deepseek-chat",
|
||||
"deepseek-reasoner",
|
||||
// Other common models
|
||||
"kimi-k2.5",
|
||||
"glm-4.7",
|
||||
"MiniMax-M2.1",
|
||||
] as const;
|
||||
|
||||
function normalizeBaseUrl(value: string): string {
|
||||
const trimmed = value.trim();
|
||||
if (!trimmed) return DEFAULT_GATEWAY_URL;
|
||||
let normalized = trimmed;
|
||||
while (normalized.endsWith("/")) normalized = normalized.slice(0, -1);
|
||||
if (!normalized.endsWith("/v1")) normalized = `${normalized}/v1`;
|
||||
return normalized;
|
||||
}
|
||||
|
||||
function validateUrl(value: string): string | undefined {
|
||||
const normalized = normalizeBaseUrl(value);
|
||||
try {
|
||||
new URL(normalized);
|
||||
} catch {
|
||||
return "Enter a valid URL";
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function parseModelIds(input: string): string[] {
|
||||
const parsed = input
|
||||
.split(/[\n,]/)
|
||||
.map((model) => model.trim())
|
||||
.filter(Boolean);
|
||||
return Array.from(new Set(parsed));
|
||||
}
|
||||
|
||||
function buildModelDefinition(modelId: string) {
|
||||
const isAutoModel = modelId === "higress/auto";
|
||||
return {
|
||||
id: modelId,
|
||||
name: isAutoModel ? "Higress Auto Router" : modelId,
|
||||
api: "openai-completions",
|
||||
reasoning: false,
|
||||
input: ["text", "image"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: DEFAULT_CONTEXT_WINDOW,
|
||||
maxTokens: DEFAULT_MAX_TOKENS,
|
||||
};
|
||||
}
|
||||
|
||||
async function testGatewayConnection(gatewayUrl: string): Promise<boolean> {
|
||||
try {
|
||||
const response = await fetch(`${gatewayUrl}/v1/models`, {
|
||||
method: "GET",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
signal: AbortSignal.timeout(5000),
|
||||
});
|
||||
return response.ok || response.status === 401; // 401 means gateway is up but needs auth
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchAvailableModels(consoleUrl: string): Promise<string[]> {
|
||||
try {
|
||||
// Try to get models from Higress Console API
|
||||
const response = await fetch(`${consoleUrl}/v1/ai/routes`, {
|
||||
method: "GET",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
signal: AbortSignal.timeout(5000),
|
||||
});
|
||||
if (response.ok) {
|
||||
const data = (await response.json()) as { data?: { model?: string }[] };
|
||||
if (data.data && Array.isArray(data.data)) {
|
||||
return data.data
|
||||
.map((route: { model?: string }) => route.model)
|
||||
.filter((m): m is string => typeof m === "string");
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Ignore errors, use defaults
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
const higressPlugin = {
|
||||
id: "higress-ai-gateway",
|
||||
name: "Higress AI Gateway",
|
||||
description: "Model provider plugin for Higress AI Gateway with auto-routing support",
|
||||
configSchema: emptyPluginConfigSchema(),
|
||||
register(api) {
|
||||
api.registerProvider({
|
||||
id: "higress",
|
||||
label: "Higress AI Gateway",
|
||||
docsPath: "/providers/models",
|
||||
aliases: ["higress-gateway", "higress-ai"],
|
||||
auth: [
|
||||
{
|
||||
id: "api-key",
|
||||
label: "API Key",
|
||||
hint: "Configure Higress AI Gateway endpoint with optional API key",
|
||||
kind: "custom",
|
||||
run: async (ctx) => {
|
||||
// Step 1: Get Gateway URL
|
||||
const gatewayUrlInput = await ctx.prompter.text({
|
||||
message: "Higress AI Gateway URL",
|
||||
initialValue: DEFAULT_GATEWAY_URL,
|
||||
validate: validateUrl,
|
||||
});
|
||||
const gatewayUrl = normalizeBaseUrl(gatewayUrlInput);
|
||||
|
||||
// Step 2: Get Console URL (for auto-router configuration)
|
||||
const consoleUrlInput = await ctx.prompter.text({
|
||||
message: "Higress Console URL (for auto-router config)",
|
||||
initialValue: DEFAULT_CONSOLE_URL,
|
||||
validate: validateUrl,
|
||||
});
|
||||
const consoleUrl = normalizeBaseUrl(consoleUrlInput);
|
||||
|
||||
// Step 3: Test connection (create a new spinner)
|
||||
const spin = ctx.prompter.progress("Testing gateway connection…");
|
||||
const isConnected = await testGatewayConnection(gatewayUrl);
|
||||
if (!isConnected) {
|
||||
spin.stop("Gateway connection failed");
|
||||
await ctx.prompter.note(
|
||||
[
|
||||
"Could not connect to Higress AI Gateway.",
|
||||
"Make sure the gateway is running and the URL is correct.",
|
||||
"",
|
||||
`Tried: ${gatewayUrl}/v1/models`,
|
||||
].join("\n"),
|
||||
"Connection Warning",
|
||||
);
|
||||
} else {
|
||||
spin.stop("Gateway connected");
|
||||
}
|
||||
|
||||
// Step 4: Get API Key (optional for local gateway)
|
||||
const apiKeyInput = await ctx.prompter.text({
|
||||
message: "API Key (leave empty if not required)",
|
||||
initialValue: "",
|
||||
}) || '';
|
||||
const apiKey = apiKeyInput.trim() || "higress-local";
|
||||
|
||||
// Step 5: Fetch available models (create a new spinner)
|
||||
const spin2 = ctx.prompter.progress("Fetching available models…");
|
||||
const fetchedModels = await fetchAvailableModels(consoleUrl);
|
||||
const defaultModels = fetchedModels.length > 0
|
||||
? ["higress/auto", ...fetchedModels]
|
||||
: DEFAULT_MODEL_IDS;
|
||||
spin2.stop();
|
||||
|
||||
// Step 6: Let user customize model list
|
||||
const modelInput = await ctx.prompter.text({
|
||||
message: "Model IDs (comma-separated, higress/auto enables auto-routing)",
|
||||
initialValue: defaultModels.slice(0, 10).join(", "),
|
||||
validate: (value) =>
|
||||
parseModelIds(value).length > 0 ? undefined : "Enter at least one model id",
|
||||
});
|
||||
|
||||
const modelIds = parseModelIds(modelInput);
|
||||
const hasAutoModel = modelIds.includes("higress/auto");
|
||||
|
||||
// FIX: Avoid double prefix - if modelId already starts with provider, don't add prefix again
|
||||
const defaultModelId = hasAutoModel
|
||||
? "higress/auto"
|
||||
: (modelIds[0] ?? "qwen-turbo");
|
||||
const defaultModelRef = defaultModelId.startsWith("higress/")
|
||||
? defaultModelId
|
||||
: `higress/${defaultModelId}`;
|
||||
|
||||
// Step 7: Configure default model for auto-routing
|
||||
let autoRoutingDefaultModel = "qwen-turbo";
|
||||
if (hasAutoModel) {
|
||||
const autoRoutingModelInput = await ctx.prompter.text({
|
||||
message: "Default model for auto-routing (when no rule matches)",
|
||||
initialValue: "qwen-turbo",
|
||||
});
|
||||
autoRoutingDefaultModel = autoRoutingModelInput.trim(); // FIX: Add trim() here
|
||||
}
|
||||
|
||||
return {
|
||||
profiles: [
|
||||
{
|
||||
profileId: `higress:${apiKey === "higress-local" ? "local" : "default"}`,
|
||||
credential: {
|
||||
type: "token",
|
||||
provider: "higress",
|
||||
token: apiKey,
|
||||
},
|
||||
},
|
||||
],
|
||||
configPatch: {
|
||||
models: {
|
||||
providers: {
|
||||
higress: {
|
||||
baseUrl: `${gatewayUrl}/v1`,
|
||||
apiKey: apiKey,
|
||||
api: "openai-completions",
|
||||
authHeader: apiKey !== "higress-local",
|
||||
models: modelIds.map((modelId) => buildModelDefinition(modelId)),
|
||||
},
|
||||
},
|
||||
},
|
||||
agents: {
|
||||
defaults: {
|
||||
models: Object.fromEntries(
|
||||
modelIds.map((modelId) => {
|
||||
// FIX: Avoid double prefix - only add provider prefix if not already present
|
||||
const modelRef = modelId.startsWith("higress/")
|
||||
? modelId
|
||||
: `higress/${modelId}`;
|
||||
return [modelRef, {}];
|
||||
}),
|
||||
),
|
||||
},
|
||||
},
|
||||
plugins: {
|
||||
entries: {
|
||||
"higress-ai-gateway": {
|
||||
enabled: true,
|
||||
config: {
|
||||
gatewayUrl,
|
||||
consoleUrl,
|
||||
autoRoutingDefaultModel,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
defaultModel: defaultModelRef,
|
||||
notes: [
|
||||
"Higress AI Gateway is now configured as a model provider.",
|
||||
hasAutoModel
|
||||
? `Auto-routing enabled: use model "higress/auto" to route based on message content.`
|
||||
: "Add 'higress/auto' to models to enable auto-routing.",
|
||||
`Gateway endpoint: ${gatewayUrl}/v1/chat/completions`,
|
||||
`Console: ${consoleUrl}`,
|
||||
"",
|
||||
"🎯 Recommended Skills (install via Clawdbot conversation):",
|
||||
"",
|
||||
"1. Auto-Routing Skill:",
|
||||
" Configure automatic model routing based on message content",
|
||||
" https://github.com/alibaba/higress/tree/main/.claude/skills/higress-auto-router",
|
||||
' Say: "Install higress-auto-router skill"',
|
||||
"",
|
||||
"2. Agent Session Monitor Skill:",
|
||||
" Track token usage and monitor conversation history",
|
||||
" https://github.com/alibaba/higress/tree/main/.claude/skills/agent-session-monitor",
|
||||
' Say: "Install agent-session-monitor skill"',
|
||||
],
|
||||
};
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
export default higressPlugin;
|
||||
@@ -1,22 +0,0 @@
|
||||
{
|
||||
"name": "@higress/higress-ai-gateway",
|
||||
"version": "1.0.0",
|
||||
"description": "Higress AI Gateway model provider plugin for Clawdbot with auto-routing support",
|
||||
"main": "index.ts",
|
||||
"clawdbot": {
|
||||
"extensions": ["./index.ts"]
|
||||
},
|
||||
"keywords": [
|
||||
"clawdbot",
|
||||
"higress",
|
||||
"ai-gateway",
|
||||
"model-router",
|
||||
"auto-routing"
|
||||
],
|
||||
"author": "Higress Team",
|
||||
"license": "Apache-2.0",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "https://github.com/alibaba/higress"
|
||||
}
|
||||
}
|
||||
@@ -1,92 +0,0 @@
|
||||
# Higress AI Gateway Plugin
|
||||
|
||||
OpenClaw/Clawdbot model provider plugin for Higress AI Gateway with auto-routing support.
|
||||
|
||||
## What is this?
|
||||
|
||||
This is a TypeScript-based provider plugin that enables Clawdbot and OpenClaw to use Higress AI Gateway as a model provider. It provides:
|
||||
|
||||
- **Auto-routing support**: Use `higress/auto` to intelligently route requests based on message content
|
||||
- **Dynamic model discovery**: Auto-detect available models from Higress Console
|
||||
- **Smart URL handling**: Automatic URL normalization and validation
|
||||
- **Flexible authentication**: Support for both local and remote gateway deployments
|
||||
|
||||
## Files
|
||||
|
||||
- **index.ts**: Main plugin implementation
|
||||
- **package.json**: NPM package metadata and OpenClaw extension declaration
|
||||
- **openclaw.plugin.json**: Plugin manifest for OpenClaw
|
||||
|
||||
## Installation
|
||||
|
||||
This plugin is automatically installed when you use the `higress-clawdbot-integration` skill. See the parent SKILL.md for complete installation instructions.
|
||||
|
||||
### Manual Installation
|
||||
|
||||
If you need to install manually:
|
||||
|
||||
```bash
|
||||
# Detect runtime
|
||||
if command -v clawdbot &> /dev/null; then
|
||||
RUNTIME_DIR="$HOME/.clawdbot"
|
||||
elif command -v openclaw &> /dev/null; then
|
||||
RUNTIME_DIR="$HOME/.openclaw"
|
||||
else
|
||||
echo "Error: Neither clawdbot nor openclaw is installed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Copy plugin files
|
||||
mkdir -p "$RUNTIME_DIR/extensions/higress-ai-gateway"
|
||||
cp -r ./* "$RUNTIME_DIR/extensions/higress-ai-gateway/"
|
||||
|
||||
# Configure provider
|
||||
clawdbot models auth login --provider higress
|
||||
# or
|
||||
openclaw models auth login --provider higress
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
After installation, configure Higress as a model provider:
|
||||
|
||||
```bash
|
||||
clawdbot models auth login --provider higress
|
||||
```
|
||||
|
||||
The plugin will prompt for:
|
||||
1. Gateway URL (default: http://localhost:8080)
|
||||
2. Console URL (default: http://localhost:8001)
|
||||
3. API Key (optional for local deployments)
|
||||
4. Model list (auto-detected or manually specified)
|
||||
5. Auto-routing default model (if using higress/auto)
|
||||
|
||||
## Auto-routing
|
||||
|
||||
To use auto-routing, include `higress/auto` in your model list during configuration. Then use it in your conversations:
|
||||
|
||||
```bash
|
||||
# Use auto-routing
|
||||
clawdbot chat --model higress/auto "深入思考 这个问题应该怎么解决?"
|
||||
|
||||
# The gateway will automatically route to the appropriate model based on:
|
||||
# - Message content triggers (configured via higress-auto-router skill)
|
||||
# - Fallback to default model if no rule matches
|
||||
```
|
||||
|
||||
## Related Resources
|
||||
|
||||
- **Parent Skill**: [higress-clawdbot-integration](../SKILL.md)
|
||||
- **Auto-routing Configuration**: [higress-auto-router](../../higress-auto-router/SKILL.md)
|
||||
- **Session Monitoring**: [agent-session-monitor](../../agent-session-monitor/SKILL.md)
|
||||
- **Higress AI Gateway**: https://github.com/higress-group/higress-standalone
|
||||
|
||||
## Compatibility
|
||||
|
||||
- **OpenClaw**: v2.0.0+
|
||||
- **Clawdbot**: v2.0.0+
|
||||
- **Higress AI Gateway**: All versions
|
||||
|
||||
## License
|
||||
|
||||
Apache-2.0
|
||||
259
.claude/skills/higress-openclaw-integration/SKILL.md
Normal file
259
.claude/skills/higress-openclaw-integration/SKILL.md
Normal file
@@ -0,0 +1,259 @@
|
||||
---
|
||||
name: higress-openclaw-integration
|
||||
description: "Deploy and configure Higress AI Gateway for OpenClaw integration. Use when: (1) User wants to deploy Higress AI Gateway, (2) User wants to configure OpenClaw to use more model providers, (3) User mentions 'higress', 'ai gateway', 'model gateway', 'AI网关', (4) User wants to set up model routing or auto-routing, (5) User needs to manage LLM provider API keys."
|
||||
---
|
||||
|
||||
# Higress AI Gateway Integration
|
||||
|
||||
Deploy Higress AI Gateway and configure OpenClaw to use it as a unified model provider.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Step 1: Collect Information from User
|
||||
|
||||
**Ask the user for the following information upfront:**
|
||||
|
||||
1. **Which LLM provider(s) to use?** (at least one required)
|
||||
|
||||
**Commonly Used Providers:**
|
||||
|
||||
| Provider | Parameter | Notes |
|
||||
|----------|-----------|-------|
|
||||
| 智谱 / z.ai | `--zhipuai-key` | Models: glm-*, Code Plan mode enabled by default |
|
||||
| Claude Code | `--claude-code-key` | **Requires OAuth token from `claude setup-token`** |
|
||||
| Moonshot (Kimi) | `--moonshot-key` | Models: moonshot-*, kimi-* |
|
||||
| Minimax | `--minimax-key` | Models: abab-* |
|
||||
| 阿里云通义千问 (Dashscope) | `--dashscope-key` | Models: qwen* |
|
||||
| OpenAI | `--openai-key` | Models: gpt-*, o1-*, o3-* |
|
||||
| DeepSeek | `--deepseek-key` | Models: deepseep-* |
|
||||
| Grok | `--grok-key` | Models: grok-* |
|
||||
|
||||
**Other Providers:**
|
||||
|
||||
| Provider | Parameter | Notes |
|
||||
|----------|-----------|-------|
|
||||
| Claude | `--claude-key` | Models: claude-* |
|
||||
| Google Gemini | `--gemini-key` | Models: gemini-* |
|
||||
| OpenRouter | `--openrouter-key` | Supports all models (catch-all) |
|
||||
| Groq | `--groq-key` | Fast inference |
|
||||
| Doubao (豆包) | `--doubao-key` | Models: doubao-* |
|
||||
| Mistral | `--mistral-key` | Models: mistral-* |
|
||||
| Baichuan (百川) | `--baichuan-key` | Models: Baichuan* |
|
||||
| 01.AI (Yi) | `--yi-key` | Models: yi-* |
|
||||
| Stepfun (阶跃星辰) | `--stepfun-key` | Models: step-* |
|
||||
| Cohere | `--cohere-key` | Models: command* |
|
||||
| Fireworks AI | `--fireworks-key` | - |
|
||||
| Together AI | `--togetherai-key` | - |
|
||||
| GitHub Models | `--github-key` | - |
|
||||
|
||||
**Cloud Providers (require additional config):**
|
||||
- Azure OpenAI: `--azure-key` (requires service URL)
|
||||
- AWS Bedrock: `--bedrock-key` (requires region and access key)
|
||||
- Google Vertex AI: `--vertex-key` (requires project ID and region)
|
||||
|
||||
**Brand Name Display (z.ai / 智谱):**
|
||||
- If user communicates in Chinese: display as "智谱"
|
||||
- If user communicates in English: display as "z.ai"
|
||||
|
||||
2. **Enable auto-routing?** (recommended)
|
||||
- If yes: `--auto-routing --auto-routing-default-model <model-name>`
|
||||
- Auto-routing allows using `model="higress/auto"` to automatically route requests based on message content
|
||||
|
||||
3. **Custom ports?** (optional, defaults: HTTP=8080, HTTPS=8443, Console=8001)
|
||||
|
||||
### Step 2: Deploy Gateway
|
||||
|
||||
**Auto-detect region for z.ai / 智谱 domain configuration:**
|
||||
|
||||
When user selects z.ai / 智谱 provider, detect their region:
|
||||
|
||||
```bash
|
||||
# Run region detection script (scripts/detect-region.sh relative to skill directory)
|
||||
REGION=$(bash scripts/detect-region.sh)
|
||||
# Output: "china" or "international"
|
||||
```
|
||||
|
||||
**Based on detection result:**
|
||||
|
||||
- If `REGION="china"`: use default domain `open.bigmodel.cn`, no extra parameter needed
|
||||
- If `REGION="international"`: automatically add `--zhipuai-domain api.z.ai` to deployment command
|
||||
|
||||
**After deployment (for international users):**
|
||||
Notify user in English: "The z.ai endpoint domain has been set to api.z.ai. If you want to change it, let me know and I can update the configuration."
|
||||
|
||||
```bash
|
||||
# Create installation directory
|
||||
mkdir -p higress-install
|
||||
cd higress-install
|
||||
|
||||
# Download script (if not exists)
|
||||
curl -fsSL https://higress.ai/ai-gateway/install.sh -o get-ai-gateway.sh
|
||||
chmod +x get-ai-gateway.sh
|
||||
|
||||
# Deploy with user's configuration
|
||||
# For z.ai / 智谱: always include --zhipuai-code-plan-mode
|
||||
# For non-China users: include --zhipuai-domain api.z.ai
|
||||
./get-ai-gateway.sh start --non-interactive \
|
||||
--<provider>-key <api-key> \
|
||||
[--auto-routing --auto-routing-default-model <model>]
|
||||
```
|
||||
|
||||
**z.ai / 智谱 Options:**
|
||||
| Option | Description |
|
||||
|--------|-------------|
|
||||
| `--zhipuai-code-plan-mode` | Enable Code Plan mode (enabled by default) |
|
||||
| `--zhipuai-domain <domain>` | Custom domain, default: `open.bigmodel.cn` (China), `api.z.ai` (international) |
|
||||
|
||||
**Example (China user):**
|
||||
```bash
|
||||
./get-ai-gateway.sh start --non-interactive \
|
||||
--zhipuai-key sk-xxx \
|
||||
--zhipuai-code-plan-mode \
|
||||
--auto-routing \
|
||||
--auto-routing-default-model glm-5
|
||||
```
|
||||
|
||||
**Example (International user):**
|
||||
```bash
|
||||
./get-ai-gateway.sh start --non-interactive \
|
||||
--zhipuai-key sk-xxx \
|
||||
--zhipuai-domain api.z.ai \
|
||||
--zhipuai-code-plan-mode \
|
||||
--auto-routing \
|
||||
--auto-routing-default-model glm-5
|
||||
```
|
||||
|
||||
### Step 3: Install OpenClaw Plugin
|
||||
|
||||
Install the Higress provider plugin for OpenClaw:
|
||||
|
||||
```bash
|
||||
# Copy plugin files (PLUGIN_SRC is relative to skill directory: scripts/plugin)
|
||||
PLUGIN_SRC="scripts/plugin"
|
||||
PLUGIN_DEST="$HOME/.openclaw/extensions/higress"
|
||||
|
||||
mkdir -p "$PLUGIN_DEST"
|
||||
cp -r "$PLUGIN_SRC"/* "$PLUGIN_DEST/"
|
||||
```
|
||||
|
||||
**Tell user to run the following commands manually in their terminal (interactive commands, cannot be executed by AI agent):**
|
||||
|
||||
```bash
|
||||
# Step 1: Enable the plugin
|
||||
openclaw plugins enable higress
|
||||
|
||||
# Step 2: Configure provider (interactive - will prompt for Gateway URL, API Key, models, etc.)
|
||||
openclaw models auth login --provider higress --set-default
|
||||
|
||||
# Step 3: Restart OpenClaw gateway to apply changes
|
||||
openclaw gateway restart
|
||||
```
|
||||
|
||||
The `openclaw models auth login` command will interactively prompt for:
|
||||
1. Gateway URL (default: `http://localhost:8080`)
|
||||
2. Console URL (default: `http://localhost:8001`)
|
||||
3. API Key (optional for local deployments)
|
||||
4. Model list (auto-detected or manually specified)
|
||||
5. Auto-routing default model (if using `higress/auto`)
|
||||
|
||||
After configuration and restart, Higress models are available in OpenClaw with `higress/` prefix (e.g., `higress/glm-5`, `higress/auto`).
|
||||
|
||||
**Future Configuration Updates (No Restart Needed)**
|
||||
|
||||
After the initial setup, you can manage your configuration through conversation with OpenClaw:
|
||||
|
||||
- **Add New Providers**: Add new LLM providers (e.g., DeepSeek, OpenAI, Claude) and their models dynamically.
|
||||
- **Update API Keys**: Update existing provider API keys without service restart.
|
||||
- **Configure Auto-routing**: If you've set up multiple models, ask OpenClaw to configure auto-routing rules. Requests will be intelligently routed based on your message content, using the most suitable model automatically.
|
||||
|
||||
All configuration changes are hot-loaded through Higress — no `openclaw gateway restart` required. Iterate on your model provider setup dynamically without service interruption!
|
||||
|
||||
## Post-Deployment Management
|
||||
|
||||
### Add/Update API Keys (Hot-reload)
|
||||
|
||||
```bash
|
||||
./get-ai-gateway.sh config add --provider <provider> --key <api-key>
|
||||
./get-ai-gateway.sh config list
|
||||
./get-ai-gateway.sh config remove --provider <provider>
|
||||
```
|
||||
|
||||
Provider aliases: `dashscope`/`qwen`, `moonshot`/`kimi`, `zhipuai`/`zhipu`
|
||||
|
||||
### Update z.ai Domain (Hot-reload)
|
||||
|
||||
If user wants to change the z.ai domain after deployment:
|
||||
|
||||
```bash
|
||||
# Update domain configuration
|
||||
./get-ai-gateway.sh config add --provider zhipuai --extra-config "zhipuDomain=api.z.ai"
|
||||
# Or revert to China endpoint
|
||||
./get-ai-gateway.sh config add --provider zhipuai --extra-config "zhipuDomain=open.bigmodel.cn"
|
||||
```
|
||||
|
||||
### Add Routing Rules (for auto-routing)
|
||||
|
||||
```bash
|
||||
# Add rule: route to specific model when message starts with trigger
|
||||
./get-ai-gateway.sh route add --model <model> --trigger "keyword1|keyword2"
|
||||
|
||||
# Examples
|
||||
./get-ai-gateway.sh route add --model glm-4-flash --trigger "quick|fast"
|
||||
./get-ai-gateway.sh route add --model claude-opus-4 --trigger "think|complex"
|
||||
./get-ai-gateway.sh route add --model deepseek-coder --trigger "code|debug"
|
||||
|
||||
# List/remove rules
|
||||
./get-ai-gateway.sh route list
|
||||
./get-ai-gateway.sh route remove --rule-id 0
|
||||
```
|
||||
|
||||
### Stop/Delete Gateway
|
||||
|
||||
```bash
|
||||
./get-ai-gateway.sh stop
|
||||
./get-ai-gateway.sh delete
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
| Endpoint | URL |
|
||||
|----------|-----|
|
||||
| Chat Completions | http://localhost:8080/v1/chat/completions |
|
||||
| Console | http://localhost:8001 |
|
||||
| Logs | `./higress-install/logs/access.log` |
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
# Test with specific model
|
||||
curl 'http://localhost:8080/v1/chat/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"model": "<model-name>", "messages": [{"role": "user", "content": "Hello"}]}'
|
||||
|
||||
# Test auto-routing (if enabled)
|
||||
curl 'http://localhost:8080/v1/chat/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"model": "higress/auto", "messages": [{"role": "user", "content": "What is AI?"}]}'
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| Container fails to start | Check `docker logs higress-ai-gateway` |
|
||||
| Port already in use | Use `--http-port`, `--console-port` to change ports |
|
||||
| API key error | Run `./get-ai-gateway.sh config list` to verify keys |
|
||||
| Auto-routing not working | Ensure `--auto-routing` was set during deployment |
|
||||
| Slow image download | Script auto-selects nearest registry based on timezone |
|
||||
|
||||
## Important Notes
|
||||
|
||||
1. **Claude Code Mode**: Requires OAuth token from `claude setup-token` command, not a regular API key
|
||||
2. **z.ai Code Plan Mode**: Enabled by default, uses `/api/coding/paas/v4/chat/completions` endpoint, optimized for coding tasks
|
||||
3. **z.ai Domain Selection**:
|
||||
- China users: `open.bigmodel.cn` (default)
|
||||
- International users: `api.z.ai` (auto-detected based on timezone)
|
||||
- Users can update domain anytime after deployment
|
||||
4. **Auto-routing**: Must be enabled during initial deployment (`--auto-routing`); routing rules can be added later
|
||||
5. **OpenClaw Integration**: The `openclaw models auth login` and `openclaw gateway restart` commands are **interactive** and must be run by the user manually in their terminal
|
||||
6. **Hot-reload**: API key changes take effect immediately; no container restart needed
|
||||
15
.claude/skills/higress-openclaw-integration/scripts/detect-region.sh
Executable file
15
.claude/skills/higress-openclaw-integration/scripts/detect-region.sh
Executable file
@@ -0,0 +1,15 @@
|
||||
#!/bin/bash
|
||||
# Detect if user is in China region based on timezone
|
||||
# Returns: "china" or "international"
|
||||
|
||||
TIMEZONE=$(cat /etc/timezone 2>/dev/null || timedatectl show --property=Timezone --value 2>/dev/null || echo "Unknown")
|
||||
|
||||
# Check if timezone indicates China region (including Hong Kong)
|
||||
if [[ "$TIMEZONE" == "Asia/Shanghai" ]] || \
|
||||
[[ "$TIMEZONE" == "Asia/Hong_Kong" ]] || \
|
||||
[[ "$TIMEZONE" == *"China"* ]] || \
|
||||
[[ "$TIMEZONE" == *"Beijing"* ]]; then
|
||||
echo "china"
|
||||
else
|
||||
echo "international"
|
||||
fi
|
||||
@@ -0,0 +1,61 @@
|
||||
# Higress AI Gateway Plugin
|
||||
|
||||
OpenClaw model provider plugin for Higress AI Gateway with auto-routing support.
|
||||
|
||||
## What is this?
|
||||
|
||||
This is a TypeScript-based provider plugin that enables OpenClaw to use Higress AI Gateway as a model provider. It provides:
|
||||
|
||||
- **Auto-routing support**: Use `higress/auto` to intelligently route requests based on message content
|
||||
- **Dynamic model discovery**: Auto-detect available models from Higress Console
|
||||
- **Smart URL handling**: Automatic URL normalization and validation
|
||||
- **Flexible authentication**: Support for both local and remote gateway deployments
|
||||
|
||||
## Files
|
||||
|
||||
- **index.ts**: Main plugin implementation
|
||||
- **package.json**: NPM package metadata and OpenClaw extension declaration
|
||||
- **openclaw.plugin.json**: Plugin manifest for OpenClaw
|
||||
|
||||
## Installation
|
||||
|
||||
This plugin is automatically installed when you use the `higress-openclaw-integration` skill. See parent SKILL.md for complete installation instructions.
|
||||
|
||||
### Manual Installation
|
||||
|
||||
If you need to install manually:
|
||||
|
||||
```bash
|
||||
# Copy plugin files
|
||||
mkdir -p "$HOME/.openclaw/extensions/higress"
|
||||
cp -r ./* "$HOME/.openclaw/extensions/higress/"
|
||||
|
||||
# Configure provider
|
||||
openclaw plugins enable higress
|
||||
openclaw models auth login --provider higress
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
After installation, configure Higress as a model provider:
|
||||
|
||||
```bash
|
||||
openclaw models auth login --provider higress
|
||||
```
|
||||
|
||||
The plugin will prompt for:
|
||||
1. Gateway URL (default: http://localhost:8080)
|
||||
2. Console URL (default: http://localhost:8001)
|
||||
3. API Key (optional for local deployments)
|
||||
4. Model list (auto-detected or manually specified)
|
||||
5. Auto-routing default model (if using higress/auto)
|
||||
|
||||
|
||||
## Related Resources
|
||||
|
||||
- **Parent Skill**: [higress-openclaw-integration](../SKILL.md)
|
||||
- **Auto-routing Configuration**: [higress-auto-router](../../higress-auto-router/SKILL.md)
|
||||
|
||||
## License
|
||||
|
||||
Apache-2.0
|
||||
@@ -2,33 +2,47 @@ import { emptyPluginConfigSchema } from "openclaw/plugin-sdk";
|
||||
|
||||
const DEFAULT_GATEWAY_URL = "http://localhost:8080";
|
||||
const DEFAULT_CONSOLE_URL = "http://localhost:8001";
|
||||
const DEFAULT_CONTEXT_WINDOW = 128_000;
|
||||
const DEFAULT_MAX_TOKENS = 8192;
|
||||
|
||||
// Model-specific context window and max tokens configurations
|
||||
const MODEL_CONFIG: Record<string, { contextWindow: number; maxTokens: number }> = {
|
||||
"gpt-5.3-codex": { contextWindow: 400_000, maxTokens: 128_000 },
|
||||
"gpt-5-mini": { contextWindow: 400_000, maxTokens: 128_000 },
|
||||
"gpt-5-nano": { contextWindow: 400_000, maxTokens: 128_000 },
|
||||
"claude-opus-4-6": { contextWindow: 1_000_000, maxTokens: 128_000 },
|
||||
"claude-sonnet-4-6": { contextWindow: 1_000_000, maxTokens: 64_000 },
|
||||
"claude-haiku-4-5": { contextWindow: 200_000, maxTokens: 64_000 },
|
||||
"qwen3.5-plus": { contextWindow: 960_000, maxTokens: 64_000 },
|
||||
"deepseek-chat": { contextWindow: 256_000, maxTokens: 128_000 },
|
||||
"deepseek-reasoner": { contextWindow: 256_000, maxTokens: 128_000 },
|
||||
"kimi-k2.5": { contextWindow: 256_000, maxTokens: 128_000 },
|
||||
"glm-5": { contextWindow: 200_000, maxTokens: 128_000 },
|
||||
"MiniMax-M2.5": { contextWindow: 200_000, maxTokens: 128_000 },
|
||||
};
|
||||
|
||||
// Default values for unknown models
|
||||
const DEFAULT_CONTEXT_WINDOW = 200_000;
|
||||
const DEFAULT_MAX_TOKENS = 128_000;
|
||||
|
||||
// Common models that Higress AI Gateway typically supports
|
||||
const DEFAULT_MODEL_IDS = [
|
||||
// Auto-routing special model
|
||||
"higress/auto",
|
||||
// Commonly models
|
||||
"kimi-k2.5",
|
||||
"glm-5",
|
||||
"MiniMax-M2.5",
|
||||
"qwen3.5-plus",
|
||||
// Anthropic models
|
||||
"claude-opus-4-6",
|
||||
"claude-sonnet-4-6",
|
||||
"claude-haiku-4-5",
|
||||
// OpenAI models
|
||||
"gpt-5.2",
|
||||
"gpt-5.3-codex",
|
||||
"gpt-5-mini",
|
||||
"gpt-5-nano",
|
||||
// Anthropic models
|
||||
"claude-opus-4.5",
|
||||
"claude-sonnet-4.5",
|
||||
"claude-haiku-4.5",
|
||||
// Qwen models
|
||||
"qwen3-turbo",
|
||||
"qwen3-plus",
|
||||
"qwen3-max",
|
||||
"qwen3-coder-480b-a35b-instruct",
|
||||
// DeepSeek models
|
||||
"deepseek-chat",
|
||||
"deepseek-reasoner",
|
||||
// Other common models
|
||||
"kimi-k2.5",
|
||||
"glm-4.7",
|
||||
"MiniMax-M2.1",
|
||||
"deepseek-reasoner",
|
||||
] as const;
|
||||
|
||||
function normalizeBaseUrl(value: string): string {
|
||||
@@ -60,26 +74,33 @@ function parseModelIds(input: string): string[] {
|
||||
|
||||
function buildModelDefinition(modelId: string) {
|
||||
const isAutoModel = modelId === "higress/auto";
|
||||
const config = MODEL_CONFIG[modelId] || { contextWindow: DEFAULT_CONTEXT_WINDOW, maxTokens: DEFAULT_MAX_TOKENS };
|
||||
|
||||
return {
|
||||
id: modelId,
|
||||
name: isAutoModel ? "Higress Auto Router" : modelId,
|
||||
api: "openai-completions",
|
||||
reasoning: false,
|
||||
reasoning: true,
|
||||
input: ["text", "image"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: DEFAULT_CONTEXT_WINDOW,
|
||||
maxTokens: DEFAULT_MAX_TOKENS,
|
||||
contextWindow: config.contextWindow,
|
||||
maxTokens: config.maxTokens,
|
||||
};
|
||||
}
|
||||
|
||||
async function testGatewayConnection(gatewayUrl: string): Promise<boolean> {
|
||||
try {
|
||||
const response = await fetch(`${gatewayUrl}/v1/models`, {
|
||||
method: "GET",
|
||||
// gatewayUrl already ends with /v1 from normalizeBaseUrl()
|
||||
// Use chat/completions endpoint with empty body to test connection
|
||||
// Higress doesn't support /models endpoint
|
||||
const response = await fetch(`${gatewayUrl}/chat/completions`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({}),
|
||||
signal: AbortSignal.timeout(5000),
|
||||
});
|
||||
return response.ok || response.status === 401; // 401 means gateway is up but needs auth
|
||||
// Any response (including 400/401/422) means gateway is reachable
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
@@ -108,7 +129,7 @@ async function fetchAvailableModels(consoleUrl: string): Promise<string[]> {
|
||||
}
|
||||
|
||||
const higressPlugin = {
|
||||
id: "higress-ai-gateway",
|
||||
id: "higress",
|
||||
name: "Higress AI Gateway",
|
||||
description: "Model provider plugin for Higress AI Gateway with auto-routing support",
|
||||
configSchema: emptyPluginConfigSchema(),
|
||||
@@ -150,8 +171,6 @@ const higressPlugin = {
|
||||
[
|
||||
"Could not connect to Higress AI Gateway.",
|
||||
"Make sure the gateway is running and the URL is correct.",
|
||||
"",
|
||||
`Tried: ${gatewayUrl}/v1/models`,
|
||||
].join("\n"),
|
||||
"Connection Warning",
|
||||
);
|
||||
@@ -184,21 +203,19 @@ const higressPlugin = {
|
||||
|
||||
const modelIds = parseModelIds(modelInput);
|
||||
const hasAutoModel = modelIds.includes("higress/auto");
|
||||
|
||||
// FIX: Avoid double prefix - if modelId already starts with provider, don't add prefix again
|
||||
const defaultModelId = hasAutoModel
|
||||
? "higress/auto"
|
||||
: (modelIds[0] ?? "qwen-turbo");
|
||||
const defaultModelRef = defaultModelId.startsWith("higress/")
|
||||
? defaultModelId
|
||||
: `higress/${defaultModelId}`;
|
||||
|
||||
// Always add higress/ provider prefix to create model reference
|
||||
const defaultModelId = hasAutoModel
|
||||
? "higress/auto"
|
||||
: (modelIds[0] ?? "glm-5");
|
||||
const defaultModelRef = `higress/${defaultModelId}`;
|
||||
|
||||
// Step 7: Configure default model for auto-routing
|
||||
let autoRoutingDefaultModel = "qwen-turbo";
|
||||
let autoRoutingDefaultModel = "glm-5";
|
||||
if (hasAutoModel) {
|
||||
const autoRoutingModelInput = await ctx.prompter.text({
|
||||
message: "Default model for auto-routing (when no rule matches)",
|
||||
initialValue: "qwen-turbo",
|
||||
initialValue: "glm-5",
|
||||
});
|
||||
autoRoutingDefaultModel = autoRoutingModelInput.trim(); // FIX: Add trim() here
|
||||
}
|
||||
@@ -218,7 +235,8 @@ const higressPlugin = {
|
||||
models: {
|
||||
providers: {
|
||||
higress: {
|
||||
baseUrl: `${gatewayUrl}/v1`,
|
||||
// gatewayUrl already ends with /v1 from normalizeBaseUrl()
|
||||
baseUrl: gatewayUrl,
|
||||
apiKey: apiKey,
|
||||
api: "openai-completions",
|
||||
authHeader: apiKey !== "higress-local",
|
||||
@@ -230,10 +248,8 @@ const higressPlugin = {
|
||||
defaults: {
|
||||
models: Object.fromEntries(
|
||||
modelIds.map((modelId) => {
|
||||
// FIX: Avoid double prefix - only add provider prefix if not already present
|
||||
const modelRef = modelId.startsWith("higress/")
|
||||
? modelId
|
||||
: `higress/${modelId}`;
|
||||
// Always add higress/ provider prefix to create model reference
|
||||
const modelRef = `higress/${modelId}`;
|
||||
return [modelRef, {}];
|
||||
}),
|
||||
),
|
||||
@@ -241,7 +257,7 @@ const higressPlugin = {
|
||||
},
|
||||
plugins: {
|
||||
entries: {
|
||||
"higress-ai-gateway": {
|
||||
"higress": {
|
||||
enabled: true,
|
||||
config: {
|
||||
gatewayUrl,
|
||||
@@ -258,20 +274,22 @@ const higressPlugin = {
|
||||
hasAutoModel
|
||||
? `Auto-routing enabled: use model "higress/auto" to route based on message content.`
|
||||
: "Add 'higress/auto' to models to enable auto-routing.",
|
||||
`Gateway endpoint: ${gatewayUrl}/v1/chat/completions`,
|
||||
// gatewayUrl already ends with /v1 from normalizeBaseUrl()
|
||||
`Gateway endpoint: ${gatewayUrl}/chat/completions`,
|
||||
`Console: ${consoleUrl}`,
|
||||
"",
|
||||
"🎯 Recommended Skills (install via Clawdbot conversation):",
|
||||
"💡 Future Configuration Updates (No Restart Needed):",
|
||||
" • Add New Providers: Add LLM providers (DeepSeek, OpenAI, Claude, etc.) dynamically.",
|
||||
" • Update API Keys: Update existing provider keys without restart.",
|
||||
" • Configure Auto-Routing: Ask OpenClaw to set up intelligent routing rules.",
|
||||
" All changes hot-load via Higress — no gateway restart required!",
|
||||
"",
|
||||
"🎯 Recommended Skills (install via OpenClaw conversation):",
|
||||
"",
|
||||
"1. Auto-Routing Skill:",
|
||||
" Configure automatic model routing based on message content",
|
||||
" https://github.com/alibaba/higress/tree/main/.claude/skills/higress-auto-router",
|
||||
' Say: "Install higress-auto-router skill"',
|
||||
"",
|
||||
"2. Agent Session Monitor Skill:",
|
||||
" Track token usage and monitor conversation history",
|
||||
" https://github.com/alibaba/higress/tree/main/.claude/skills/agent-session-monitor",
|
||||
' Say: "Install agent-session-monitor skill"',
|
||||
],
|
||||
};
|
||||
},
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"id": "higress-ai-gateway",
|
||||
"id": "higress",
|
||||
"name": "Higress AI Gateway",
|
||||
"description": "Model provider plugin for Higress AI Gateway with auto-routing support",
|
||||
"providers": ["higress"],
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"name": "@higress/higress-ai-gateway",
|
||||
"name": "@higress/higress",
|
||||
"version": "1.0.0",
|
||||
"description": "Higress AI Gateway model provider plugin for OpenClaw with auto-routing support",
|
||||
"main": "index.ts",
|
||||
18
.github/workflows/build-and-test-plugin.yaml
vendored
18
.github/workflows/build-and-test-plugin.yaml
vendored
@@ -19,7 +19,7 @@ on:
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
# - run: make lint
|
||||
|
||||
higress-wasmplugin-test:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
strategy:
|
||||
matrix:
|
||||
# TODO(Xunzhuo): Enable C WASM Filters in CI
|
||||
@@ -38,6 +38,18 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Disable containerd image store
|
||||
run: |
|
||||
sudo bash -c 'cat > /etc/docker/daemon.json << EOF
|
||||
{
|
||||
"features": {
|
||||
"containerd-snapshotter": false
|
||||
}
|
||||
}
|
||||
EOF'
|
||||
sudo systemctl restart docker
|
||||
docker info -f '{{ .DriverStatus }}'
|
||||
|
||||
- name: Free Up GitHub Actions Ubuntu Runner Disk Space 🔧
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
with:
|
||||
@@ -79,7 +91,7 @@ jobs:
|
||||
command: GOPROXY="https://proxy.golang.org,direct" PLUGIN_TYPE=${{ matrix.wasmPluginType }} make higress-wasmplugin-test
|
||||
|
||||
publish:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
needs: [higress-wasmplugin-test]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
24
.github/workflows/build-and-test.yaml
vendored
24
.github/workflows/build-and-test.yaml
vendored
@@ -10,7 +10,7 @@ env:
|
||||
GO_VERSION: 1.24
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
@@ -21,7 +21,7 @@ jobs:
|
||||
# - run: make lint
|
||||
|
||||
coverage-test:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
@@ -57,7 +57,7 @@ jobs:
|
||||
|
||||
build:
|
||||
# The type of runner that the job will run on
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
needs: [lint, coverage-test]
|
||||
steps:
|
||||
- name: "Checkout ${{ github.ref }}"
|
||||
@@ -91,17 +91,29 @@ jobs:
|
||||
path: out/
|
||||
|
||||
gateway-conformance-test:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
needs: [build]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
higress-conformance-test:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
needs: [build]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Disable containerd image store
|
||||
run: |
|
||||
sudo bash -c 'cat > /etc/docker/daemon.json << EOF
|
||||
{
|
||||
"features": {
|
||||
"containerd-snapshotter": false
|
||||
}
|
||||
}
|
||||
EOF'
|
||||
sudo systemctl restart docker
|
||||
docker info -f '{{ .DriverStatus }}'
|
||||
|
||||
- name: Free Up GitHub Actions Ubuntu Runner Disk Space 🔧
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
with:
|
||||
@@ -139,7 +151,7 @@ jobs:
|
||||
run: GOPROXY="https://proxy.golang.org,direct" make higress-conformance-test
|
||||
|
||||
publish:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
needs: [higress-conformance-test, gateway-conformance-test]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
50
.github/workflows/sync-skills-to-oss.yaml
vendored
Normal file
50
.github/workflows/sync-skills-to-oss.yaml
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
name: Sync Skills to OSS
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- '.claude/skills/**'
|
||||
workflow_dispatch: ~
|
||||
|
||||
jobs:
|
||||
sync-skills-to-oss:
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: oss
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Download AI Gateway Install Script
|
||||
run: |
|
||||
wget -O install.sh https://raw.githubusercontent.com/higress-group/higress-standalone/main/all-in-one/get-ai-gateway.sh
|
||||
chmod +x install.sh
|
||||
|
||||
- name: Package Skills
|
||||
run: |
|
||||
mkdir -p packaged-skills
|
||||
for skill_dir in .claude/skills/*/; do
|
||||
if [ -d "$skill_dir" ]; then
|
||||
skill_name=$(basename "$skill_dir")
|
||||
echo "Packaging $skill_name..."
|
||||
(cd "$skill_dir" && zip -r "$GITHUB_WORKSPACE/packaged-skills/${skill_name}.zip" .)
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Sync Skills to OSS
|
||||
uses: go-choppy/ossutil-github-action@master
|
||||
with:
|
||||
ossArgs: 'cp -r -u packaged-skills/ oss://higress-ai/skills/'
|
||||
accessKey: ${{ secrets.ACCESS_KEYID }}
|
||||
accessSecret: ${{ secrets.ACCESS_KEYSECRET }}
|
||||
endpoint: oss-cn-hongkong.aliyuncs.com
|
||||
|
||||
- name: Sync Install Script to OSS
|
||||
uses: go-choppy/ossutil-github-action@master
|
||||
with:
|
||||
ossArgs: 'cp -u install.sh oss://higress-ai/ai-gateway/install.sh'
|
||||
accessKey: ${{ secrets.ACCESS_KEYID }}
|
||||
accessSecret: ${{ secrets.ACCESS_KEYSECRET }}
|
||||
endpoint: oss-cn-hongkong.aliyuncs.com
|
||||
12
README.md
12
README.md
@@ -86,6 +86,18 @@ Port descriptions:
|
||||
>
|
||||
> **Southeast Asia**: `higress-registry.ap-southeast-7.cr.aliyuncs.com`
|
||||
|
||||
> **For Kubernetes deployments**, you can configure the `global.hub` parameter in Helm values to use a mirror registry closer to your region. This applies to both Higress component images and built-in Wasm plugin images:
|
||||
>
|
||||
> ```bash
|
||||
> # Example: Using North America mirror
|
||||
> helm install higress -n higress-system higress.io/higress --set global.hub=higress-registry.us-west-1.cr.aliyuncs.com --create-namespace
|
||||
> ```
|
||||
>
|
||||
> Available mirror registries:
|
||||
> - **China (Hangzhou)**: `higress-registry.cn-hangzhou.cr.aliyuncs.com` (default)
|
||||
> - **North America**: `higress-registry.us-west-1.cr.aliyuncs.com`
|
||||
> - **Southeast Asia**: `higress-registry.ap-southeast-7.cr.aliyuncs.com`
|
||||
|
||||
For other installation methods such as Helm deployment under K8s, please refer to the official [Quick Start documentation](https://higress.io/en-us/docs/user/quickstart).
|
||||
|
||||
If you are deploying on the cloud, it is recommended to use the [Enterprise Edition](https://www.aliyun.com/product/apigateway?spm=higress-github.topbar.0.0.0)
|
||||
|
||||
18
README_ZH.md
18
README_ZH.md
@@ -80,6 +80,24 @@ docker run -d --rm --name higress-ai -v ${PWD}:/data \
|
||||
|
||||
**Higress 的所有 Docker 镜像都一直使用自己独享的仓库,不受 Docker Hub 境内访问受限的影响**
|
||||
|
||||
> 如果从 `higress-registry.cn-hangzhou.cr.aliyuncs.com` 拉取镜像超时,可以尝试使用以下镜像加速源:
|
||||
>
|
||||
> **北美**: `higress-registry.us-west-1.cr.aliyuncs.com`
|
||||
>
|
||||
> **东南亚**: `higress-registry.ap-southeast-7.cr.aliyuncs.com`
|
||||
|
||||
> **K8s 部署时**,可以通过 Helm values 配置 `global.hub` 参数来使用距离部署区域更近的镜像仓库,该参数会同时应用于 Higress 组件镜像和内置 Wasm 插件镜像:
|
||||
>
|
||||
> ```bash
|
||||
> # 示例:使用北美镜像源
|
||||
> helm install higress -n higress-system higress.io/higress --set global.hub=higress-registry.us-west-1.cr.aliyuncs.com --create-namespace
|
||||
> ```
|
||||
>
|
||||
> 可用镜像仓库:
|
||||
> - **中国(杭州)**: `higress-registry.cn-hangzhou.cr.aliyuncs.com`(默认)
|
||||
> - **北美**: `higress-registry.us-west-1.cr.aliyuncs.com`
|
||||
> - **东南亚**: `higress-registry.ap-southeast-7.cr.aliyuncs.com`
|
||||
|
||||
K8s 下使用 Helm 部署等其他安装方式可以参考官网 [Quick Start 文档](https://higress.cn/docs/latest/user/quickstart/)。
|
||||
|
||||
如果您是在云上部署,推荐使用[企业版](https://www.aliyun.com/product/apigateway?spm=higress-github.topbar.0.0.0)
|
||||
|
||||
@@ -23,7 +23,7 @@ spec:
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: {{ .Chart.Name }}
|
||||
image: "{{ .Values.global.hub }}/{{ .Values.redis.image | default "redis-stack-server" }}:{{ .Values.redis.tag | default .Chart.AppVersion }}"
|
||||
image: "{{ .Values.global.hub }}/higress/{{ .Values.redis.image | default "redis-stack-server" }}:{{ .Values.redis.tag | default .Chart.AppVersion }}"
|
||||
{{- if .Values.global.imagePullPolicy }}
|
||||
imagePullPolicy: {{ .Values.global.imagePullPolicy }}
|
||||
{{- end }}
|
||||
|
||||
@@ -39,7 +39,7 @@ template:
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: higress-gateway
|
||||
image: "{{ .Values.gateway.hub | default .Values.global.hub }}/{{ .Values.gateway.image | default "gateway" }}:{{ .Values.gateway.tag | default .Chart.AppVersion }}"
|
||||
image: "{{ .Values.gateway.hub | default .Values.global.hub }}/higress/{{ .Values.gateway.image | default "gateway" }}:{{ .Values.gateway.tag | default .Chart.AppVersion }}"
|
||||
args:
|
||||
- proxy
|
||||
- router
|
||||
@@ -205,7 +205,7 @@ template:
|
||||
{{- if $o11y.enabled }}
|
||||
{{- $config := $o11y.promtail }}
|
||||
- name: promtail
|
||||
image: {{ $config.image.repository | default (printf "%s/promtail" .Values.global.hub) }}:{{ $config.image.tag }}
|
||||
image: {{ $config.image.repository | default (printf "%s/higress/promtail" .Values.global.hub) }}:{{ $config.image.tag }}
|
||||
imagePullPolicy: IfNotPresent
|
||||
args:
|
||||
- -config.file=/etc/promtail/promtail.yaml
|
||||
|
||||
@@ -38,7 +38,7 @@ spec:
|
||||
- name: {{ .Chart.Name }}
|
||||
securityContext:
|
||||
{{- toYaml .Values.controller.securityContext | nindent 12 }}
|
||||
image: "{{ .Values.controller.hub | default .Values.global.hub }}/{{ .Values.controller.image | default "higress" }}:{{ .Values.controller.tag | default .Chart.AppVersion }}"
|
||||
image: "{{ .Values.controller.hub | default .Values.global.hub }}/higress/{{ .Values.controller.image | default "higress" }}:{{ .Values.controller.tag | default .Chart.AppVersion }}"
|
||||
args:
|
||||
- "serve"
|
||||
- --gatewaySelectorKey=higress
|
||||
@@ -104,7 +104,7 @@ spec:
|
||||
- name: log
|
||||
mountPath: /var/log
|
||||
- name: discovery
|
||||
image: "{{ .Values.pilot.hub | default .Values.global.hub }}/{{ .Values.pilot.image | default "pilot" }}:{{ .Values.pilot.tag | default .Chart.AppVersion }}"
|
||||
image: "{{ .Values.pilot.hub | default .Values.global.hub }}/higress/{{ .Values.pilot.image | default "pilot" }}:{{ .Values.pilot.tag | default .Chart.AppVersion }}"
|
||||
{{- if .Values.global.imagePullPolicy }}
|
||||
imagePullPolicy: {{ .Values.global.imagePullPolicy }}
|
||||
{{- end }}
|
||||
|
||||
@@ -23,7 +23,7 @@ spec:
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: {{ .Chart.Name }}
|
||||
image: {{ .Values.pluginServer.hub | default .Values.global.hub }}/{{ .Values.pluginServer.image | default "plugin-server" }}:{{ .Values.pluginServer.tag | default "1.0.0" }}
|
||||
image: {{ .Values.pluginServer.hub | default .Values.global.hub }}/higress/{{ .Values.pluginServer.image | default "plugin-server" }}:{{ .Values.pluginServer.tag | default "1.0.0" }}
|
||||
{{- if .Values.global.imagePullPolicy }}
|
||||
imagePullPolicy: {{ .Values.global.imagePullPolicy }}
|
||||
{{- end }}
|
||||
|
||||
@@ -70,10 +70,14 @@ global:
|
||||
# cpu: 100m
|
||||
# memory: 128Mi
|
||||
|
||||
# -- Default hub for Istio images.
|
||||
# Releases are published to docker hub under 'istio' project.
|
||||
# Dev builds from prow are on gcr.io
|
||||
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
|
||||
# -- Default hub (registry) for Higress images.
|
||||
# For Higress deployments, images are pulled from: {hub}/higress/{image}
|
||||
# For built-in plugins, images are pulled from: {hub}/{pluginNamespace}/{plugin-name}
|
||||
# Change this to use a mirror registry closer to your deployment region for faster image pulls.
|
||||
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com
|
||||
# -- Namespace for built-in plugin images. Default is "plugins".
|
||||
# Used by higress-console to configure plugin image path.
|
||||
pluginNamespace: "plugins"
|
||||
|
||||
# -- Specify image pull policy if default behavior isn't desired.
|
||||
# Default behavior: latest images will be Always else IfNotPresent.
|
||||
|
||||
@@ -178,7 +178,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| global.enableStatus | bool | `true` | If true, Higress Controller will update the status field of Ingress resources. When migrating from Nginx Ingress, in order to avoid status field of Ingress objects being overwritten, this parameter needs to be set to false, so Higress won't write the entry IP to the status field of the corresponding Ingress object. |
|
||||
| global.externalIstiod | bool | `false` | Configure a remote cluster data plane controlled by an external istiod. When set to true, istiod is not deployed locally and only a subset of the other discovery charts are enabled. |
|
||||
| global.hostRDSMergeSubset | bool | `false` | |
|
||||
| global.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | Default hub for Istio images. Releases are published to docker hub under 'istio' project. Dev builds from prow are on gcr.io |
|
||||
| global.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com"` | Default hub (registry) for Higress images. For Higress deployments, images are pulled from: {hub}/higress/{image} For built-in plugins, images are pulled from: {hub}/{pluginNamespace}/{plugin-name} Change this to use a mirror registry closer to your deployment region for faster image pulls. |
|
||||
| global.imagePullPolicy | string | `""` | Specify image pull policy if default behavior isn't desired. Default behavior: latest images will be Always else IfNotPresent. |
|
||||
| global.imagePullSecrets | list | `[]` | ImagePullSecrets for all ServiceAccount, list of secrets in the same namespace to use for pulling any images in pods that reference this ServiceAccount. For components that don't use ServiceAccounts (i.e. grafana, servicegraph, tracing) ImagePullSecrets will be added to the corresponding Deployment(StatefulSet) objects. Must be set for any cluster configured with private docker registry. |
|
||||
| global.ingressClass | string | `"higress"` | IngressClass filters which ingress resources the higress controller watches. The default ingress class is higress. There are some special cases for special ingress class. 1. When the ingress class is set as nginx, the higress controller will watch ingress resources with the nginx ingress class or without any ingress class. 2. When the ingress class is set empty, the higress controller will watch all ingress resources in the k8s cluster. |
|
||||
@@ -203,6 +203,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| global.onlyPushRouteCluster | bool | `true` | |
|
||||
| global.operatorManageWebhooks | bool | `false` | Configure whether Operator manages webhook configurations. The current behavior of Istiod is to manage its own webhook configurations. When this option is set as true, Istio Operator, instead of webhooks, manages the webhook configurations. When this option is set as false, webhooks manage their own webhook configurations. |
|
||||
| global.pilotCertProvider | string | `"istiod"` | Configure the certificate provider for control plane communication. Currently, two providers are supported: "kubernetes" and "istiod". As some platforms may not have kubernetes signing APIs, Istiod is the default |
|
||||
| global.pluginNamespace | string | `"plugins"` | Namespace for built-in plugin images. Default is "plugins". Used by higress-console to configure plugin image path. |
|
||||
| global.priorityClassName | string | `""` | Kubernetes >=v1.11.0 will create two PriorityClass, including system-cluster-critical and system-node-critical, it is better to configure this in order to make sure your Istio pods will not be killed because of low priority class. Refer to https://kubernetes.io/docs/concepts/configuration/pod-priority-preemption/#priorityclass for more detail. |
|
||||
| global.proxy.autoInject | string | `"enabled"` | This controls the 'policy' in the sidecar injector. |
|
||||
| global.proxy.clusterDomain | string | `"cluster.local"` | CAUTION: It is important to ensure that all Istio helm charts specify the same clusterDomain value cluster domain. Default value is "cluster.local". |
|
||||
|
||||
@@ -157,3 +157,8 @@ func TestClaude(t *testing.T) {
|
||||
test.RunClaudeOnHttpRequestHeadersTests(t)
|
||||
test.RunClaudeOnHttpRequestBodyTests(t)
|
||||
}
|
||||
|
||||
func TestConsumerAffinity(t *testing.T) {
|
||||
test.RunConsumerAffinityParseConfigTests(t)
|
||||
test.RunConsumerAffinityOnHttpRequestHeadersTests(t)
|
||||
}
|
||||
|
||||
@@ -73,8 +73,8 @@ type claudeChatMessageContent struct {
|
||||
Name string `json:"name,omitempty"` // For tool_use
|
||||
Input map[string]interface{} `json:"input,omitempty"` // For tool_use
|
||||
// Tool result fields
|
||||
ToolUseId string `json:"tool_use_id,omitempty"` // For tool_result
|
||||
Content claudeChatMessageContentWr `json:"content,omitempty"` // For tool_result - can be string or array
|
||||
ToolUseId string `json:"tool_use_id,omitempty"` // For tool_result
|
||||
Content *claudeChatMessageContentWr `json:"content,omitempty"` // For tool_result - can be string or array
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for claudeChatMessageContentWr
|
||||
@@ -237,13 +237,13 @@ type claudeTextGenResponse struct {
|
||||
}
|
||||
|
||||
type claudeTextGenContent struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Id string `json:"id,omitempty"` // For tool_use
|
||||
Name string `json:"name,omitempty"` // For tool_use
|
||||
Input map[string]interface{} `json:"input,omitempty"` // For tool_use
|
||||
Signature string `json:"signature,omitempty"` // For thinking
|
||||
Thinking string `json:"thinking,omitempty"` // For thinking
|
||||
Type string `json:"type,omitempty"`
|
||||
Text *string `json:"text,omitempty"` // Use pointer: empty string outputs "text":"", nil omits field
|
||||
Id string `json:"id,omitempty"` // For tool_use
|
||||
Name string `json:"name,omitempty"` // For tool_use
|
||||
Input *map[string]interface{} `json:"input,omitempty"` // Use pointer: empty map outputs "input":{}, nil omits field
|
||||
Signature *string `json:"signature,omitempty"` // For thinking - use pointer for empty string output
|
||||
Thinking *string `json:"thinking,omitempty"` // For thinking - use pointer for empty string output
|
||||
}
|
||||
|
||||
type claudeTextGenUsage struct {
|
||||
@@ -269,11 +269,12 @@ type claudeTextGenStreamResponse struct {
|
||||
}
|
||||
|
||||
type claudeTextGenDelta struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJson string `json:"partial_json,omitempty"`
|
||||
StopReason *string `json:"stop_reason,omitempty"`
|
||||
StopSequence *string `json:"stop_sequence,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
PartialJson string `json:"partial_json,omitempty"`
|
||||
StopReason *string `json:"stop_reason,omitempty"`
|
||||
StopSequence json.RawMessage `json:"stop_sequence,omitempty"` // Use RawMessage to output explicit null
|
||||
}
|
||||
|
||||
func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
@@ -441,6 +442,34 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
|
||||
claudeRequest.MaxTokens = claudeDefaultMaxTokens
|
||||
}
|
||||
|
||||
// Convert OpenAI reasoning parameters to Claude thinking configuration
|
||||
if origRequest.ReasoningEffort != "" || origRequest.ReasoningMaxTokens > 0 {
|
||||
var budgetTokens int
|
||||
if origRequest.ReasoningMaxTokens > 0 {
|
||||
budgetTokens = origRequest.ReasoningMaxTokens
|
||||
} else {
|
||||
// Convert reasoning_effort to budget_tokens
|
||||
switch origRequest.ReasoningEffort {
|
||||
case "low":
|
||||
budgetTokens = 1024 // Minimum required by Claude
|
||||
case "medium":
|
||||
budgetTokens = 8192
|
||||
case "high":
|
||||
budgetTokens = 16384
|
||||
default:
|
||||
budgetTokens = 8192 // Default to medium
|
||||
}
|
||||
}
|
||||
// Ensure minimum budget_tokens requirement
|
||||
if budgetTokens < 1024 {
|
||||
budgetTokens = 1024
|
||||
}
|
||||
claudeRequest.Thinking = &claudeThinkingConfig{
|
||||
Type: "enabled",
|
||||
BudgetTokens: budgetTokens,
|
||||
}
|
||||
}
|
||||
|
||||
// Track if system message exists in original request
|
||||
hasSystemMessage := false
|
||||
for _, message := range origRequest.Messages {
|
||||
@@ -469,9 +498,102 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle OpenAI "tool" role messages - convert to Claude "user" role with tool_result content
|
||||
if message.Role == roleTool {
|
||||
toolResultContent := claudeChatMessageContent{
|
||||
Type: "tool_result",
|
||||
ToolUseId: message.ToolCallId,
|
||||
}
|
||||
// Tool result content can be string or array
|
||||
if message.IsStringContent() {
|
||||
toolResultContent.Content = &claudeChatMessageContentWr{
|
||||
StringValue: message.StringContent(),
|
||||
IsString: true,
|
||||
}
|
||||
} else {
|
||||
// For array content, extract text parts
|
||||
var textParts []string
|
||||
for _, part := range message.ParseContent() {
|
||||
if part.Type == contentTypeText {
|
||||
textParts = append(textParts, part.Text)
|
||||
}
|
||||
}
|
||||
toolResultContent.Content = &claudeChatMessageContentWr{
|
||||
StringValue: strings.Join(textParts, "\n"),
|
||||
IsString: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the last message is a user message with tool_result, merge if so
|
||||
if len(claudeRequest.Messages) > 0 {
|
||||
lastMsg := &claudeRequest.Messages[len(claudeRequest.Messages)-1]
|
||||
if lastMsg.Role == roleUser && !lastMsg.Content.IsString {
|
||||
// Check if last message contains tool_result
|
||||
hasToolResult := false
|
||||
for _, content := range lastMsg.Content.ArrayValue {
|
||||
if content.Type == "tool_result" {
|
||||
hasToolResult = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasToolResult {
|
||||
// Merge with existing tool_result message
|
||||
lastMsg.Content.ArrayValue = append(lastMsg.Content.ArrayValue, toolResultContent)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create new user message with tool_result
|
||||
claudeMessage := claudeChatMessage{
|
||||
Role: roleUser,
|
||||
Content: NewArrayContent([]claudeChatMessageContent{toolResultContent}),
|
||||
}
|
||||
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
|
||||
continue
|
||||
}
|
||||
|
||||
claudeMessage := claudeChatMessage{
|
||||
Role: message.Role,
|
||||
}
|
||||
|
||||
// Handle assistant messages with tool_calls - convert to Claude tool_use content blocks
|
||||
if message.Role == roleAssistant && len(message.ToolCalls) > 0 {
|
||||
chatMessageContents := make([]claudeChatMessageContent, 0)
|
||||
|
||||
// Add text content if present
|
||||
if message.IsStringContent() && message.StringContent() != "" {
|
||||
chatMessageContents = append(chatMessageContents, claudeChatMessageContent{
|
||||
Type: contentTypeText,
|
||||
Text: message.StringContent(),
|
||||
})
|
||||
}
|
||||
|
||||
// Convert tool_calls to tool_use content blocks
|
||||
for _, tc := range message.ToolCalls {
|
||||
var inputMap map[string]interface{}
|
||||
if tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &inputMap); err != nil {
|
||||
log.Errorf("failed to parse tool call arguments: %v", err)
|
||||
inputMap = make(map[string]interface{})
|
||||
}
|
||||
} else {
|
||||
inputMap = make(map[string]interface{})
|
||||
}
|
||||
|
||||
chatMessageContents = append(chatMessageContents, claudeChatMessageContent{
|
||||
Type: "tool_use",
|
||||
Id: tc.Id,
|
||||
Name: tc.Function.Name,
|
||||
Input: inputMap,
|
||||
})
|
||||
}
|
||||
|
||||
claudeMessage.Content = NewArrayContent(chatMessageContents)
|
||||
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
|
||||
continue
|
||||
}
|
||||
|
||||
if message.IsStringContent() {
|
||||
claudeMessage.Content = NewStringContent(message.StringContent())
|
||||
} else {
|
||||
@@ -562,9 +684,41 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
|
||||
}
|
||||
|
||||
func (c *claudeProvider) responseClaude2OpenAI(ctx wrapper.HttpContext, origResponse *claudeTextGenResponse) *chatCompletionResponse {
|
||||
// Extract text content, thinking content, and tool calls from Claude response
|
||||
var textContent string
|
||||
var reasoningContent string
|
||||
var toolCalls []toolCall
|
||||
for _, content := range origResponse.Content {
|
||||
switch content.Type {
|
||||
case contentTypeText:
|
||||
if content.Text != nil {
|
||||
textContent = *content.Text
|
||||
}
|
||||
case "thinking":
|
||||
if content.Thinking != nil {
|
||||
reasoningContent = *content.Thinking
|
||||
}
|
||||
case "tool_use":
|
||||
var args []byte
|
||||
if content.Input != nil {
|
||||
args, _ = json.Marshal(*content.Input)
|
||||
} else {
|
||||
args = []byte("{}")
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall{
|
||||
Id: content.Id,
|
||||
Type: "function",
|
||||
Function: functionCall{
|
||||
Name: content.Name,
|
||||
Arguments: string(args),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
choice := chatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: &chatMessage{Role: roleAssistant, Content: origResponse.Content[0].Text},
|
||||
Message: &chatMessage{Role: roleAssistant, Content: textContent, ReasoningContent: reasoningContent, ToolCalls: toolCalls},
|
||||
FinishReason: util.Ptr(stopReasonClaude2OpenAI(origResponse.StopReason)),
|
||||
}
|
||||
|
||||
@@ -600,6 +754,8 @@ func stopReasonClaude2OpenAI(reason *string) string {
|
||||
return finishReasonStop
|
||||
case "max_tokens":
|
||||
return finishReasonLength
|
||||
case "tool_use":
|
||||
return finishReasonToolCall
|
||||
default:
|
||||
return *reason
|
||||
}
|
||||
@@ -626,11 +782,64 @@ func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, or
|
||||
}
|
||||
return c.createChatCompletionResponse(ctx, origResponse, choice)
|
||||
|
||||
case "content_block_start":
|
||||
// Handle tool_use content block start
|
||||
if origResponse.ContentBlock != nil && origResponse.ContentBlock.Type == "tool_use" {
|
||||
var index int
|
||||
if origResponse.Index != nil {
|
||||
index = *origResponse.Index
|
||||
}
|
||||
choice := chatCompletionChoice{
|
||||
Index: index,
|
||||
Delta: &chatMessage{
|
||||
ToolCalls: []toolCall{
|
||||
{
|
||||
Index: index,
|
||||
Id: origResponse.ContentBlock.Id,
|
||||
Type: "function",
|
||||
Function: functionCall{
|
||||
Name: origResponse.ContentBlock.Name,
|
||||
Arguments: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return c.createChatCompletionResponse(ctx, origResponse, choice)
|
||||
}
|
||||
return nil
|
||||
|
||||
case "content_block_delta":
|
||||
var index int
|
||||
if origResponse.Index != nil {
|
||||
index = *origResponse.Index
|
||||
}
|
||||
// Handle tool_use input_json_delta
|
||||
if origResponse.Delta != nil && origResponse.Delta.Type == "input_json_delta" {
|
||||
choice := chatCompletionChoice{
|
||||
Index: index,
|
||||
Delta: &chatMessage{
|
||||
ToolCalls: []toolCall{
|
||||
{
|
||||
Index: index,
|
||||
Function: functionCall{
|
||||
Arguments: origResponse.Delta.PartialJson,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return c.createChatCompletionResponse(ctx, origResponse, choice)
|
||||
}
|
||||
// Handle thinking_delta
|
||||
if origResponse.Delta != nil && origResponse.Delta.Type == "thinking_delta" {
|
||||
choice := chatCompletionChoice{
|
||||
Index: index,
|
||||
Delta: &chatMessage{Reasoning: origResponse.Delta.Thinking},
|
||||
}
|
||||
return c.createChatCompletionResponse(ctx, origResponse, choice)
|
||||
}
|
||||
// Handle text_delta
|
||||
choice := chatCompletionChoice{
|
||||
Index: index,
|
||||
Delta: &chatMessage{Content: origResponse.Delta.Text},
|
||||
@@ -667,7 +876,7 @@ func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, or
|
||||
TotalTokens: c.usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
case "content_block_stop", "ping", "content_block_start":
|
||||
case "content_block_stop", "ping":
|
||||
log.Debugf("skip processing response type: %s", origResponse.Type)
|
||||
return nil
|
||||
default:
|
||||
|
||||
@@ -315,3 +315,107 @@ func TestClaudeProvider_GetApiName(t *testing.T) {
|
||||
assert.Equal(t, ApiName(""), provider.GetApiName("/unknown"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeProvider_BuildClaudeTextGenRequest_ToolRoleConversion(t *testing.T) {
|
||||
provider := &claudeProvider{
|
||||
config: ProviderConfig{
|
||||
claudeCodeMode: false,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("converts_single_tool_role_to_user_with_tool_result", func(t *testing.T) {
|
||||
request := &chatCompletionRequest{
|
||||
Model: "claude-sonnet-4-5-20250929",
|
||||
MaxTokens: 1024,
|
||||
Messages: []chatMessage{
|
||||
{Role: roleUser, Content: "What's the weather?"},
|
||||
{Role: roleAssistant, Content: nil, ToolCalls: []toolCall{
|
||||
{Id: "call_123", Type: "function", Function: functionCall{Name: "get_weather", Arguments: `{"city": "Beijing"}`}},
|
||||
}},
|
||||
{Role: roleTool, ToolCallId: "call_123", Content: "Sunny, 25°C"},
|
||||
},
|
||||
}
|
||||
|
||||
claudeReq := provider.buildClaudeTextGenRequest(request)
|
||||
|
||||
// Should have 3 messages: user, assistant with tool_use, user with tool_result
|
||||
require.Len(t, claudeReq.Messages, 3)
|
||||
|
||||
// First message should be user
|
||||
assert.Equal(t, roleUser, claudeReq.Messages[0].Role)
|
||||
|
||||
// Second message should be assistant with tool_use
|
||||
assert.Equal(t, roleAssistant, claudeReq.Messages[1].Role)
|
||||
require.False(t, claudeReq.Messages[1].Content.IsString)
|
||||
require.Len(t, claudeReq.Messages[1].Content.ArrayValue, 1)
|
||||
assert.Equal(t, "tool_use", claudeReq.Messages[1].Content.ArrayValue[0].Type)
|
||||
assert.Equal(t, "call_123", claudeReq.Messages[1].Content.ArrayValue[0].Id)
|
||||
assert.Equal(t, "get_weather", claudeReq.Messages[1].Content.ArrayValue[0].Name)
|
||||
|
||||
// Third message should be user with tool_result
|
||||
assert.Equal(t, roleUser, claudeReq.Messages[2].Role)
|
||||
require.False(t, claudeReq.Messages[2].Content.IsString)
|
||||
require.Len(t, claudeReq.Messages[2].Content.ArrayValue, 1)
|
||||
assert.Equal(t, "tool_result", claudeReq.Messages[2].Content.ArrayValue[0].Type)
|
||||
assert.Equal(t, "call_123", claudeReq.Messages[2].Content.ArrayValue[0].ToolUseId)
|
||||
})
|
||||
|
||||
t.Run("merges_multiple_tool_results_into_single_user_message", func(t *testing.T) {
|
||||
request := &chatCompletionRequest{
|
||||
Model: "claude-sonnet-4-5-20250929",
|
||||
MaxTokens: 1024,
|
||||
Messages: []chatMessage{
|
||||
{Role: roleUser, Content: "What's the weather and time?"},
|
||||
{Role: roleAssistant, Content: nil, ToolCalls: []toolCall{
|
||||
{Id: "call_1", Type: "function", Function: functionCall{Name: "get_weather", Arguments: `{"city": "Beijing"}`}},
|
||||
{Id: "call_2", Type: "function", Function: functionCall{Name: "get_time", Arguments: `{"timezone": "Asia/Shanghai"}`}},
|
||||
}},
|
||||
{Role: roleTool, ToolCallId: "call_1", Content: "Sunny, 25°C"},
|
||||
{Role: roleTool, ToolCallId: "call_2", Content: "3:00 PM"},
|
||||
},
|
||||
}
|
||||
|
||||
claudeReq := provider.buildClaudeTextGenRequest(request)
|
||||
|
||||
// Should have 3 messages: user, assistant with 2 tool_use, user with 2 tool_results
|
||||
require.Len(t, claudeReq.Messages, 3)
|
||||
|
||||
// Assistant message should have 2 tool_use blocks
|
||||
require.Len(t, claudeReq.Messages[1].Content.ArrayValue, 2)
|
||||
assert.Equal(t, "tool_use", claudeReq.Messages[1].Content.ArrayValue[0].Type)
|
||||
assert.Equal(t, "tool_use", claudeReq.Messages[1].Content.ArrayValue[1].Type)
|
||||
|
||||
// User message should have 2 tool_result blocks merged
|
||||
assert.Equal(t, roleUser, claudeReq.Messages[2].Role)
|
||||
require.Len(t, claudeReq.Messages[2].Content.ArrayValue, 2)
|
||||
assert.Equal(t, "tool_result", claudeReq.Messages[2].Content.ArrayValue[0].Type)
|
||||
assert.Equal(t, "call_1", claudeReq.Messages[2].Content.ArrayValue[0].ToolUseId)
|
||||
assert.Equal(t, "tool_result", claudeReq.Messages[2].Content.ArrayValue[1].Type)
|
||||
assert.Equal(t, "call_2", claudeReq.Messages[2].Content.ArrayValue[1].ToolUseId)
|
||||
})
|
||||
|
||||
t.Run("handles_assistant_tool_calls_with_text_content", func(t *testing.T) {
|
||||
request := &chatCompletionRequest{
|
||||
Model: "claude-sonnet-4-5-20250929",
|
||||
MaxTokens: 1024,
|
||||
Messages: []chatMessage{
|
||||
{Role: roleUser, Content: "What's the weather?"},
|
||||
{Role: roleAssistant, Content: "Let me check the weather for you.", ToolCalls: []toolCall{
|
||||
{Id: "call_123", Type: "function", Function: functionCall{Name: "get_weather", Arguments: `{"city": "Beijing"}`}},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
claudeReq := provider.buildClaudeTextGenRequest(request)
|
||||
|
||||
require.Len(t, claudeReq.Messages, 2)
|
||||
|
||||
// Assistant message should have both text and tool_use
|
||||
assert.Equal(t, roleAssistant, claudeReq.Messages[1].Role)
|
||||
require.False(t, claudeReq.Messages[1].Content.IsString)
|
||||
require.Len(t, claudeReq.Messages[1].Content.ArrayValue, 2)
|
||||
assert.Equal(t, contentTypeText, claudeReq.Messages[1].Content.ArrayValue[0].Type)
|
||||
assert.Equal(t, "Let me check the weather for you.", claudeReq.Messages[1].Content.ArrayValue[0].Text)
|
||||
assert.Equal(t, "tool_use", claudeReq.Messages[1].Content.ArrayValue[1].Type)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -119,6 +119,15 @@ func (c *ClaudeToOpenAIConverter) ConvertClaudeRequestToOpenAI(body []byte) ([]b
|
||||
}
|
||||
openaiRequest.Messages = append(openaiRequest.Messages, toolMsg)
|
||||
}
|
||||
// Also add text content if present alongside tool results
|
||||
// This handles cases like: [tool_result, tool_result, text]
|
||||
if len(conversionResult.textParts) > 0 {
|
||||
textMsg := chatMessage{
|
||||
Role: claudeMsg.Role,
|
||||
Content: strings.Join(conversionResult.textParts, "\n\n"),
|
||||
}
|
||||
openaiRequest.Messages = append(openaiRequest.Messages, textMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle regular content if no tool calls or tool results
|
||||
@@ -136,7 +145,8 @@ func (c *ClaudeToOpenAIConverter) ConvertClaudeRequestToOpenAI(body []byte) ([]b
|
||||
if claudeRequest.System != nil {
|
||||
systemMsg := chatMessage{Role: roleSystem}
|
||||
if !claudeRequest.System.IsArray {
|
||||
systemMsg.Content = claudeRequest.System.StringValue
|
||||
// Strip dynamic cch field from billing header to enable caching
|
||||
systemMsg.Content = stripCchFromBillingHeader(claudeRequest.System.StringValue)
|
||||
} else {
|
||||
conversionResult := c.convertContentArray(claudeRequest.System.ArrayValue)
|
||||
systemMsg.Content = conversionResult.openaiContents
|
||||
@@ -183,6 +193,7 @@ func (c *ClaudeToOpenAIConverter) ConvertClaudeRequestToOpenAI(body []byte) ([]b
|
||||
|
||||
if claudeRequest.Thinking.Type == "enabled" {
|
||||
openaiRequest.ReasoningMaxTokens = claudeRequest.Thinking.BudgetTokens
|
||||
openaiRequest.Thinking = &thinkingParam{Type: "enabled", BudgetToken: claudeRequest.Thinking.BudgetTokens}
|
||||
|
||||
// Set ReasoningEffort based on budget_tokens
|
||||
// low: <4096, medium: >=4096 and <16384, high: >=16384
|
||||
@@ -198,7 +209,10 @@ func (c *ClaudeToOpenAIConverter) ConvertClaudeRequestToOpenAI(body []byte) ([]b
|
||||
claudeRequest.Thinking.BudgetTokens, openaiRequest.ReasoningEffort, openaiRequest.ReasoningMaxTokens)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("[Claude->OpenAI] No thinking config found")
|
||||
// Explicitly disable thinking when not configured in Claude request
|
||||
// This prevents providers like ZhipuAI from enabling thinking by default
|
||||
openaiRequest.Thinking = &thinkingParam{Type: "disabled"}
|
||||
log.Debugf("[Claude->OpenAI] No thinking config found, explicitly disabled")
|
||||
}
|
||||
|
||||
result, err := json.Marshal(openaiRequest)
|
||||
@@ -253,19 +267,21 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIResponseToClaude(ctx wrapper.Http
|
||||
}
|
||||
|
||||
if reasoningText != "" {
|
||||
emptySignature := ""
|
||||
contents = append(contents, claudeTextGenContent{
|
||||
Type: "thinking",
|
||||
Signature: "", // OpenAI doesn't provide signature, use empty string
|
||||
Thinking: reasoningText,
|
||||
Signature: &emptySignature, // Use pointer for empty string
|
||||
Thinking: &reasoningText,
|
||||
})
|
||||
log.Debugf("[OpenAI->Claude] Added thinking content: %s", reasoningText)
|
||||
}
|
||||
|
||||
// Add text content if present
|
||||
if choice.Message.StringContent() != "" {
|
||||
textContent := choice.Message.StringContent()
|
||||
contents = append(contents, claudeTextGenContent{
|
||||
Type: "text",
|
||||
Text: choice.Message.StringContent(),
|
||||
Text: &textContent,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -288,7 +304,7 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIResponseToClaude(ctx wrapper.Http
|
||||
Type: "tool_use",
|
||||
Id: toolCall.Id,
|
||||
Name: toolCall.Function.Name,
|
||||
Input: input,
|
||||
Input: &input,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -338,7 +354,7 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrappe
|
||||
Index: &c.thinkingBlockIndex,
|
||||
}
|
||||
stopData, _ := json.Marshal(stopEvent)
|
||||
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
|
||||
result.WriteString(fmt.Sprintf("event: %s\ndata: %s\n\n", stopEvent.Type, stopData))
|
||||
}
|
||||
if c.textBlockStarted && !c.textBlockStopped {
|
||||
c.textBlockStopped = true
|
||||
@@ -348,7 +364,7 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrappe
|
||||
Index: &c.textBlockIndex,
|
||||
}
|
||||
stopData, _ := json.Marshal(stopEvent)
|
||||
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
|
||||
result.WriteString(fmt.Sprintf("event: %s\ndata: %s\n\n", stopEvent.Type, stopData))
|
||||
}
|
||||
// Send final content_block_stop events for any remaining unclosed tool calls
|
||||
for index, toolCall := range c.toolCallStates {
|
||||
@@ -360,7 +376,7 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrappe
|
||||
Index: &toolCall.claudeContentIndex,
|
||||
}
|
||||
stopData, _ := json.Marshal(stopEvent)
|
||||
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
|
||||
result.WriteString(fmt.Sprintf("event: %s\ndata: %s\n\n", stopEvent.Type, stopData))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -370,12 +386,12 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrappe
|
||||
messageDelta := &claudeTextGenStreamResponse{
|
||||
Type: "message_delta",
|
||||
Delta: &claudeTextGenDelta{
|
||||
Type: "message_delta",
|
||||
StopReason: c.pendingStopReason,
|
||||
StopReason: c.pendingStopReason,
|
||||
StopSequence: json.RawMessage("null"),
|
||||
},
|
||||
}
|
||||
stopData, _ := json.Marshal(messageDelta)
|
||||
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
|
||||
result.WriteString(fmt.Sprintf("event: %s\ndata: %s\n\n", messageDelta.Type, stopData))
|
||||
c.pendingStopReason = nil
|
||||
}
|
||||
|
||||
@@ -386,7 +402,7 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrappe
|
||||
Type: "message_stop",
|
||||
}
|
||||
stopData, _ := json.Marshal(messageStopEvent)
|
||||
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
|
||||
result.WriteString(fmt.Sprintf("event: %s\ndata: %s\n\n", messageStopEvent.Type, stopData))
|
||||
}
|
||||
|
||||
// Reset all state for next request
|
||||
@@ -515,13 +531,14 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpCont
|
||||
c.nextContentIndex++
|
||||
c.thinkingBlockStarted = true
|
||||
log.Debugf("[OpenAI->Claude] Generated content_block_start event for thinking at index %d", c.thinkingBlockIndex)
|
||||
emptyStr := ""
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "content_block_start",
|
||||
Index: &c.thinkingBlockIndex,
|
||||
ContentBlock: &claudeTextGenContent{
|
||||
Type: "thinking",
|
||||
Signature: "", // OpenAI doesn't provide signature
|
||||
Thinking: "",
|
||||
Signature: &emptyStr, // Use pointer for empty string output
|
||||
Thinking: &emptyStr, // Use pointer for empty string output
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -532,8 +549,8 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpCont
|
||||
Type: "content_block_delta",
|
||||
Index: &c.thinkingBlockIndex,
|
||||
Delta: &claudeTextGenDelta{
|
||||
Type: "thinking_delta", // Use thinking_delta for reasoning content
|
||||
Text: reasoningText,
|
||||
Type: "thinking_delta",
|
||||
Thinking: reasoningText, // Use Thinking field, not Text
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -564,12 +581,13 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpCont
|
||||
c.nextContentIndex++
|
||||
c.textBlockStarted = true
|
||||
log.Debugf("[OpenAI->Claude] Generated content_block_start event for text at index %d", c.textBlockIndex)
|
||||
emptyText := ""
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "content_block_start",
|
||||
Index: &c.textBlockIndex,
|
||||
ContentBlock: &claudeTextGenContent{
|
||||
Type: "text",
|
||||
Text: "",
|
||||
Text: &emptyText,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -588,6 +606,30 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpCont
|
||||
|
||||
// Handle tool calls in streaming response
|
||||
if choice.Delta != nil && len(choice.Delta.ToolCalls) > 0 {
|
||||
// Ensure message_start is sent before any content blocks
|
||||
if !c.messageStartSent {
|
||||
c.messageId = openaiResponse.Id
|
||||
c.messageStartSent = true
|
||||
message := &claudeTextGenResponse{
|
||||
Id: openaiResponse.Id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: openaiResponse.Model,
|
||||
Content: []claudeTextGenContent{},
|
||||
}
|
||||
if openaiResponse.Usage != nil {
|
||||
message.Usage = claudeTextGenUsage{
|
||||
InputTokens: openaiResponse.Usage.PromptTokens,
|
||||
OutputTokens: 0,
|
||||
}
|
||||
}
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "message_start",
|
||||
Message: message,
|
||||
})
|
||||
log.Debugf("[OpenAI->Claude] Generated message_start event before tool calls for id: %s", openaiResponse.Id)
|
||||
}
|
||||
|
||||
// Initialize toolCallStates if needed
|
||||
if c.toolCallStates == nil {
|
||||
c.toolCallStates = make(map[int]*toolCallInfo)
|
||||
@@ -722,7 +764,9 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpCont
|
||||
}
|
||||
|
||||
// Handle usage information
|
||||
if openaiResponse.Usage != nil && choice.FinishReason == nil {
|
||||
// Note: Some providers may send usage in the same chunk as finish_reason,
|
||||
// so we check for usage regardless of whether finish_reason is present
|
||||
if openaiResponse.Usage != nil {
|
||||
log.Debugf("[OpenAI->Claude] Processing usage info - input: %d, output: %d",
|
||||
openaiResponse.Usage.PromptTokens, openaiResponse.Usage.CompletionTokens)
|
||||
|
||||
@@ -730,7 +774,7 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpCont
|
||||
messageDelta := &claudeTextGenStreamResponse{
|
||||
Type: "message_delta",
|
||||
Delta: &claudeTextGenDelta{
|
||||
Type: "message_delta",
|
||||
StopSequence: json.RawMessage("null"), // Explicit null per Claude spec
|
||||
},
|
||||
Usage: &claudeTextGenUsage{
|
||||
InputTokens: openaiResponse.Usage.PromptTokens,
|
||||
@@ -789,10 +833,12 @@ func (c *ClaudeToOpenAIConverter) convertContentArray(claudeContents []claudeCha
|
||||
switch claudeContent.Type {
|
||||
case "text":
|
||||
if claudeContent.Text != "" {
|
||||
result.textParts = append(result.textParts, claudeContent.Text)
|
||||
// Strip dynamic cch field from billing header to enable caching
|
||||
processedText := stripCchFromBillingHeader(claudeContent.Text)
|
||||
result.textParts = append(result.textParts, processedText)
|
||||
result.openaiContents = append(result.openaiContents, chatMessageContent{
|
||||
Type: contentTypeText,
|
||||
Text: claudeContent.Text,
|
||||
Text: processedText,
|
||||
CacheControl: claudeContent.CacheControl,
|
||||
})
|
||||
}
|
||||
@@ -884,6 +930,7 @@ func (c *ClaudeToOpenAIConverter) startToolCall(toolState *toolCallInfo) []*clau
|
||||
toolState.claudeContentIndex, toolState.id, toolState.name)
|
||||
|
||||
// Send content_block_start
|
||||
emptyInput := map[string]interface{}{}
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "content_block_start",
|
||||
Index: &toolState.claudeContentIndex,
|
||||
@@ -891,7 +938,7 @@ func (c *ClaudeToOpenAIConverter) startToolCall(toolState *toolCallInfo) []*clau
|
||||
Type: "tool_use",
|
||||
Id: toolState.id,
|
||||
Name: toolState.name,
|
||||
Input: map[string]interface{}{}, // Empty input as per Claude spec
|
||||
Input: &emptyInput, // Empty input as per Claude spec
|
||||
},
|
||||
})
|
||||
|
||||
@@ -910,3 +957,42 @@ func (c *ClaudeToOpenAIConverter) startToolCall(toolState *toolCallInfo) []*clau
|
||||
|
||||
return responses
|
||||
}
|
||||
|
||||
// stripCchFromBillingHeader removes the dynamic cch field from x-anthropic-billing-header text
|
||||
// to enable caching. The cch value changes on every request, which would break prompt caching.
|
||||
// Example input: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode; cch=abc123;"
|
||||
// Example output: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode;"
|
||||
func stripCchFromBillingHeader(text string) string {
|
||||
const billingHeaderPrefix = "x-anthropic-billing-header:"
|
||||
|
||||
// Check if this is a billing header
|
||||
if !strings.HasPrefix(text, billingHeaderPrefix) {
|
||||
return text
|
||||
}
|
||||
|
||||
// Remove cch=xxx pattern (may appear with or without trailing semicolon)
|
||||
// Pattern: ; cch=<any-non-semicolon-chars> followed by ; or end of string
|
||||
result := text
|
||||
|
||||
// Try to find and remove ; cch=... pattern
|
||||
// We need to handle both "; cch=xxx;" and "; cch=xxx" (at end)
|
||||
for {
|
||||
cchIdx := strings.Index(result, "; cch=")
|
||||
if cchIdx == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
// Find the end of cch value (next semicolon or end of string)
|
||||
start := cchIdx + 2 // skip "; "
|
||||
end := strings.Index(result[start:], ";")
|
||||
if end == -1 {
|
||||
// cch is at the end, remove from "; cch=" to end
|
||||
result = result[:cchIdx]
|
||||
} else {
|
||||
// cch is followed by more content, remove "; cch=xxx" part
|
||||
result = result[:cchIdx] + result[start+end:]
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -388,6 +388,7 @@ func TestClaudeToOpenAIConverter_ConvertClaudeRequestToOpenAI(t *testing.T) {
|
||||
|
||||
t.Run("convert_tool_result_with_actual_error_data", func(t *testing.T) {
|
||||
// Test using the actual JSON data from the error log to ensure our fix works
|
||||
// This tests the fix for issue #3344 - text content alongside tool_result should be preserved
|
||||
claudeRequest := `{
|
||||
"model": "anthropic/claude-sonnet-4",
|
||||
"messages": [{
|
||||
@@ -415,14 +416,20 @@ func TestClaudeToOpenAIConverter_ConvertClaudeRequestToOpenAI(t *testing.T) {
|
||||
err = json.Unmarshal(result, &openaiRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have one tool message (the text content is included in the same message array)
|
||||
require.Len(t, openaiRequest.Messages, 1)
|
||||
// Should have two messages: tool message + user message with text content
|
||||
// This is the fix for issue #3344 - text content alongside tool_result is preserved
|
||||
require.Len(t, openaiRequest.Messages, 2)
|
||||
|
||||
// Should be tool message
|
||||
// First should be tool message
|
||||
toolMsg := openaiRequest.Messages[0]
|
||||
assert.Equal(t, "tool", toolMsg.Role)
|
||||
assert.Contains(t, toolMsg.Content, "three.js")
|
||||
assert.Equal(t, "toolu_vrtx_01UbCfwoTgoDBqbYEwkVaxd5", toolMsg.ToolCallId)
|
||||
|
||||
// Second should be user message with text content
|
||||
userMsg := openaiRequest.Messages[1]
|
||||
assert.Equal(t, "user", userMsg.Role)
|
||||
assert.Equal(t, "继续", userMsg.Content)
|
||||
})
|
||||
|
||||
t.Run("convert_multiple_tool_calls", func(t *testing.T) {
|
||||
@@ -617,7 +624,7 @@ func TestClaudeToOpenAIConverter_ConvertOpenAIResponseToClaude(t *testing.T) {
|
||||
// First content should be text
|
||||
textContent := claudeResponse.Content[0]
|
||||
assert.Equal(t, "text", textContent.Type)
|
||||
assert.Equal(t, "I'll analyze the README file to understand this project's purpose.", textContent.Text)
|
||||
assert.Equal(t, "I'll analyze the README file to understand this project's purpose.", *textContent.Text)
|
||||
|
||||
// Second content should be tool_use
|
||||
toolContent := claudeResponse.Content[1]
|
||||
@@ -627,7 +634,7 @@ func TestClaudeToOpenAIConverter_ConvertOpenAIResponseToClaude(t *testing.T) {
|
||||
|
||||
// Verify tool arguments
|
||||
require.NotNil(t, toolContent.Input)
|
||||
assert.Equal(t, "/Users/zhangty/git/higress/README.md", toolContent.Input["file_path"])
|
||||
assert.Equal(t, "/Users/zhangty/git/higress/README.md", (*toolContent.Input)["file_path"])
|
||||
})
|
||||
}
|
||||
|
||||
@@ -830,21 +837,147 @@ func TestClaudeToOpenAIConverter_ConvertReasoningResponseToClaude(t *testing.T)
|
||||
// First should be thinking
|
||||
thinkingContent := claudeResponse.Content[0]
|
||||
assert.Equal(t, "thinking", thinkingContent.Type)
|
||||
assert.Equal(t, "", thinkingContent.Signature) // OpenAI doesn't provide signature
|
||||
assert.Contains(t, thinkingContent.Thinking, "Let me think about this step by step")
|
||||
require.NotNil(t, thinkingContent.Signature)
|
||||
assert.Equal(t, "", *thinkingContent.Signature) // OpenAI doesn't provide signature
|
||||
require.NotNil(t, thinkingContent.Thinking)
|
||||
assert.Contains(t, *thinkingContent.Thinking, "Let me think about this step by step")
|
||||
|
||||
// Second should be text
|
||||
textContent := claudeResponse.Content[1]
|
||||
assert.Equal(t, "text", textContent.Type)
|
||||
assert.Equal(t, tt.expectedText, textContent.Text)
|
||||
require.NotNil(t, textContent.Text)
|
||||
assert.Equal(t, tt.expectedText, *textContent.Text)
|
||||
} else {
|
||||
// Should only have text content
|
||||
assert.Len(t, claudeResponse.Content, 1)
|
||||
|
||||
textContent := claudeResponse.Content[0]
|
||||
assert.Equal(t, "text", textContent.Type)
|
||||
assert.Equal(t, tt.expectedText, textContent.Text)
|
||||
require.NotNil(t, textContent.Text)
|
||||
assert.Equal(t, tt.expectedText, *textContent.Text)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeToOpenAIConverter_StripCchFromSystemMessage(t *testing.T) {
|
||||
converter := &ClaudeToOpenAIConverter{}
|
||||
|
||||
t.Run("string_system_with_billing_header", func(t *testing.T) {
|
||||
// Test that cch field is stripped from string format system message
|
||||
claudeRequest := `{
|
||||
"model": "claude-sonnet-4",
|
||||
"max_tokens": 1024,
|
||||
"system": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode; cch=abc123;"
|
||||
}
|
||||
],
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}]
|
||||
}`
|
||||
|
||||
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
|
||||
require.NoError(t, err)
|
||||
|
||||
var openaiRequest chatCompletionRequest
|
||||
err = json.Unmarshal(result, &openaiRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, openaiRequest.Messages, 2)
|
||||
|
||||
// First message should be system with cch stripped
|
||||
systemMsg := openaiRequest.Messages[0]
|
||||
assert.Equal(t, "system", systemMsg.Role)
|
||||
|
||||
// The system content should have cch removed
|
||||
contentArray, ok := systemMsg.Content.([]interface{})
|
||||
require.True(t, ok, "System content should be an array")
|
||||
require.Len(t, contentArray, 1)
|
||||
|
||||
contentMap, ok := contentArray[0].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "text", contentMap["type"])
|
||||
assert.Equal(t, "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode;", contentMap["text"])
|
||||
assert.NotContains(t, contentMap["text"], "cch=")
|
||||
})
|
||||
|
||||
t.Run("plain_string_system_unchanged", func(t *testing.T) {
|
||||
// Test that normal system messages are not modified
|
||||
claudeRequest := `{
|
||||
"model": "claude-sonnet-4",
|
||||
"max_tokens": 1024,
|
||||
"system": "You are a helpful assistant.",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}]
|
||||
}`
|
||||
|
||||
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
|
||||
require.NoError(t, err)
|
||||
|
||||
var openaiRequest chatCompletionRequest
|
||||
err = json.Unmarshal(result, &openaiRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First message should be system with original content
|
||||
systemMsg := openaiRequest.Messages[0]
|
||||
assert.Equal(t, "system", systemMsg.Role)
|
||||
assert.Equal(t, "You are a helpful assistant.", systemMsg.Content)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStripCchFromBillingHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "billing header with cch at end",
|
||||
input: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode; cch=abc123;",
|
||||
expected: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode;",
|
||||
},
|
||||
{
|
||||
name: "billing header with cch at end without trailing semicolon",
|
||||
input: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode; cch=abc123",
|
||||
expected: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode",
|
||||
},
|
||||
{
|
||||
name: "billing header with cch in middle",
|
||||
input: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cch=abc123; cc_entrypoint=claude-vscode;",
|
||||
expected: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode;",
|
||||
},
|
||||
{
|
||||
name: "billing header without cch",
|
||||
input: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode;",
|
||||
expected: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode;",
|
||||
},
|
||||
{
|
||||
name: "non-billing header text unchanged",
|
||||
input: "This is a normal system prompt",
|
||||
expected: "This is a normal system prompt",
|
||||
},
|
||||
{
|
||||
name: "empty string unchanged",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "billing header with multiple cch fields",
|
||||
input: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cch=first; cc_entrypoint=claude-vscode; cch=second;",
|
||||
expected: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode;",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := stripCchFromBillingHeader(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -605,7 +605,7 @@ func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext) {
|
||||
if c.isFailoverEnabled() {
|
||||
apiToken = c.GetGlobalRandomToken()
|
||||
} else {
|
||||
apiToken = c.GetRandomToken()
|
||||
apiToken = c.GetOrSetTokenWithContext(ctx)
|
||||
}
|
||||
log.Debugf("Use apiToken %s to send request", apiToken)
|
||||
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
|
||||
|
||||
@@ -30,7 +30,13 @@ const (
|
||||
)
|
||||
|
||||
type NonOpenAIStyleOptions struct {
|
||||
ReasoningMaxTokens int `json:"reasoning_max_tokens,omitempty"`
|
||||
ReasoningMaxTokens int `json:"reasoning_max_tokens,omitempty"`
|
||||
Thinking *thinkingParam `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
type thinkingParam struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
BudgetToken int `json:"budget_token,omitempty"`
|
||||
}
|
||||
|
||||
type chatCompletionRequest struct {
|
||||
|
||||
@@ -2,8 +2,10 @@ package provider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"path"
|
||||
@@ -151,6 +153,7 @@ const (
|
||||
protocolOriginal = "original"
|
||||
|
||||
roleSystem = "system"
|
||||
roleDeveloper = "developer"
|
||||
roleAssistant = "assistant"
|
||||
roleUser = "user"
|
||||
roleTool = "tool"
|
||||
@@ -193,6 +196,12 @@ type providerInitializer interface {
|
||||
var (
|
||||
errUnsupportedApiName = errors.New("unsupported API name")
|
||||
|
||||
// Providers that support the "developer" role. Other providers will have "developer" roles converted to "system".
|
||||
developerRoleSupportedProviders = map[string]bool{
|
||||
providerTypeOpenAI: true,
|
||||
providerTypeAzure: true,
|
||||
}
|
||||
|
||||
providerInitializers = map[string]providerInitializer{
|
||||
providerTypeMoonshot: &moonshotProviderInitializer{},
|
||||
providerTypeAzure: &azureProviderInitializer{},
|
||||
@@ -445,6 +454,12 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN Claude Code 模式
|
||||
// @Description zh-CN 仅适用于Claude服务。启用后将伪装成Claude Code客户端发起请求,支持使用Claude Code的OAuth Token进行认证。
|
||||
claudeCodeMode bool `required:"false" yaml:"claudeCodeMode" json:"claudeCodeMode"`
|
||||
// @Title zh-CN 智谱AI服务域名
|
||||
// @Description zh-CN 仅适用于智谱AI服务。默认为 open.bigmodel.cn(中国),可配置为 api.z.ai(国际)
|
||||
zhipuDomain string `required:"false" yaml:"zhipuDomain" json:"zhipuDomain"`
|
||||
// @Title zh-CN 智谱AI Code Plan 模式
|
||||
// @Description zh-CN 仅适用于智谱AI服务。启用后将使用 /api/coding/paas/v4/chat/completions 接口
|
||||
zhipuCodePlanMode bool `required:"false" yaml:"zhipuCodePlanMode" json:"zhipuCodePlanMode"`
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetId() string {
|
||||
@@ -650,6 +665,8 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.vllmCustomUrl = json.Get("vllmCustomUrl").String()
|
||||
c.doubaoDomain = json.Get("doubaoDomain").String()
|
||||
c.claudeCodeMode = json.Get("claudeCodeMode").Bool()
|
||||
c.zhipuDomain = json.Get("zhipuDomain").String()
|
||||
c.zhipuCodePlanMode = json.Get("zhipuCodePlanMode").Bool()
|
||||
c.contextCleanupCommands = make([]string, 0)
|
||||
for _, cmd := range json.Get("contextCleanupCommands").Array() {
|
||||
if cmd.String() != "" {
|
||||
@@ -690,12 +707,45 @@ func (c *ProviderConfig) Validate() error {
|
||||
func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) string {
|
||||
ctxApiKey := ctx.GetContext(ctxKeyApiKey)
|
||||
if ctxApiKey == nil {
|
||||
ctxApiKey = c.GetRandomToken()
|
||||
token := c.selectApiToken(ctx)
|
||||
ctxApiKey = token
|
||||
ctx.SetContext(ctxKeyApiKey, ctxApiKey)
|
||||
}
|
||||
return ctxApiKey.(string)
|
||||
}
|
||||
|
||||
// selectApiToken selects an API token based on the request context
|
||||
// For stateful APIs, it uses consumer affinity if available
|
||||
func (c *ProviderConfig) selectApiToken(ctx wrapper.HttpContext) string {
|
||||
// Get API name from context if available
|
||||
ctxApiName := ctx.GetContext(CtxKeyApiName)
|
||||
var apiName string
|
||||
if ctxApiName != nil {
|
||||
// ctxApiName is of type ApiName, need to convert to string
|
||||
apiName = string(ctxApiName.(ApiName))
|
||||
}
|
||||
|
||||
// For stateful APIs, try to use consumer affinity
|
||||
if isStatefulAPI(apiName) {
|
||||
consumer := c.getConsumerFromContext(ctx)
|
||||
if consumer != "" {
|
||||
return c.GetTokenWithConsumerAffinity(ctx, consumer)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to random selection
|
||||
return c.GetRandomToken()
|
||||
}
|
||||
|
||||
// getConsumerFromContext retrieves the consumer identifier from the request context
|
||||
func (c *ProviderConfig) getConsumerFromContext(ctx wrapper.HttpContext) string {
|
||||
consumer, err := proxywasm.GetHttpRequestHeader("x-mse-consumer")
|
||||
if err == nil && consumer != "" {
|
||||
return consumer
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetRandomToken() string {
|
||||
apiTokens := c.apiTokens
|
||||
count := len(apiTokens)
|
||||
@@ -709,6 +759,50 @@ func (c *ProviderConfig) GetRandomToken() string {
|
||||
}
|
||||
}
|
||||
|
||||
// isStatefulAPI checks if the given API name is a stateful API that requires consumer affinity
|
||||
func isStatefulAPI(apiName string) bool {
|
||||
// These APIs maintain session state and should be routed to the same provider consistently
|
||||
statefulAPIs := map[string]bool{
|
||||
string(ApiNameResponses): true, // Response API - uses previous_response_id
|
||||
string(ApiNameFiles): true, // Files API - maintains file state
|
||||
string(ApiNameRetrieveFile): true, // File retrieval - depends on file upload
|
||||
string(ApiNameRetrieveFileContent): true, // File content - depends on file upload
|
||||
string(ApiNameBatches): true, // Batch API - maintains batch state
|
||||
string(ApiNameRetrieveBatch): true, // Batch status - depends on batch creation
|
||||
string(ApiNameCancelBatch): true, // Batch operations - depends on batch state
|
||||
string(ApiNameFineTuningJobs): true, // Fine-tuning - maintains job state
|
||||
string(ApiNameRetrieveFineTuningJob): true, // Fine-tuning job status
|
||||
string(ApiNameFineTuningJobEvents): true, // Fine-tuning events
|
||||
string(ApiNameFineTuningJobCheckpoints): true, // Fine-tuning checkpoints
|
||||
string(ApiNameCancelFineTuningJob): true, // Cancel fine-tuning job
|
||||
string(ApiNameResumeFineTuningJob): true, // Resume fine-tuning job
|
||||
}
|
||||
return statefulAPIs[apiName]
|
||||
}
|
||||
|
||||
// GetTokenWithConsumerAffinity selects an API token based on consumer affinity
|
||||
// If x-mse-consumer header is present and API is stateful, it will consistently select the same token
|
||||
func (c *ProviderConfig) GetTokenWithConsumerAffinity(ctx wrapper.HttpContext, consumer string) string {
|
||||
apiTokens := c.apiTokens
|
||||
count := len(apiTokens)
|
||||
switch count {
|
||||
case 0:
|
||||
return ""
|
||||
case 1:
|
||||
return apiTokens[0]
|
||||
default:
|
||||
// Use FNV-1a hash for consistent token selection
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(consumer))
|
||||
hashValue := h.Sum32()
|
||||
index := int(hashValue) % count
|
||||
if index < 0 {
|
||||
index += count
|
||||
}
|
||||
return apiTokens[index]
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) IsOriginal() bool {
|
||||
return c.protocol == protocolOriginal
|
||||
}
|
||||
@@ -838,6 +932,34 @@ func doGetMappedModel(model string, modelMapping map[string]string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// isDeveloperRoleSupported checks if the provider supports the "developer" role.
|
||||
func isDeveloperRoleSupported(providerType string) bool {
|
||||
return developerRoleSupportedProviders[providerType]
|
||||
}
|
||||
|
||||
// convertDeveloperRoleToSystem converts "developer" roles to "system" role in the request body.
|
||||
// This is used for providers that don't support the "developer" role.
|
||||
func convertDeveloperRoleToSystem(body []byte) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return body, fmt.Errorf("unable to unmarshal request for developer role conversion: %v", err)
|
||||
}
|
||||
|
||||
converted := false
|
||||
for i := range request.Messages {
|
||||
if request.Messages[i].Role == roleDeveloper {
|
||||
request.Messages[i].Role = roleSystem
|
||||
converted = true
|
||||
}
|
||||
}
|
||||
|
||||
if converted {
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte) []StreamEvent {
|
||||
body := chunk
|
||||
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
|
||||
@@ -976,6 +1098,18 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
}
|
||||
}
|
||||
|
||||
// convert developer role to system role for providers that don't support it
|
||||
if apiName == ApiNameChatCompletion && !isDeveloperRoleSupported(c.typ) {
|
||||
body, err = convertDeveloperRoleToSystem(body)
|
||||
if err != nil {
|
||||
log.Warnf("[developerRole] failed to convert developer role to system: %v", err)
|
||||
// Continue processing even if conversion fails
|
||||
err = nil
|
||||
} else {
|
||||
log.Debugf("[developerRole] converted developer role to system for provider: %s", c.typ)
|
||||
}
|
||||
}
|
||||
|
||||
// use openai protocol (either original openai or converted from claude)
|
||||
if handler, ok := provider.(TransformRequestBodyHandler); ok {
|
||||
body, err = handler.TransformRequestBody(ctx, apiName, body)
|
||||
|
||||
275
plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go
Normal file
275
plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsStatefulAPI(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
apiName string
|
||||
expected bool
|
||||
}{
|
||||
// Stateful APIs - should return true
|
||||
{
|
||||
name: "responses_api",
|
||||
apiName: string(ApiNameResponses),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "files_api",
|
||||
apiName: string(ApiNameFiles),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve_file_api",
|
||||
apiName: string(ApiNameRetrieveFile),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve_file_content_api",
|
||||
apiName: string(ApiNameRetrieveFileContent),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "batches_api",
|
||||
apiName: string(ApiNameBatches),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve_batch_api",
|
||||
apiName: string(ApiNameRetrieveBatch),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "cancel_batch_api",
|
||||
apiName: string(ApiNameCancelBatch),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "fine_tuning_jobs_api",
|
||||
apiName: string(ApiNameFineTuningJobs),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve_fine_tuning_job_api",
|
||||
apiName: string(ApiNameRetrieveFineTuningJob),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "fine_tuning_job_events_api",
|
||||
apiName: string(ApiNameFineTuningJobEvents),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "fine_tuning_job_checkpoints_api",
|
||||
apiName: string(ApiNameFineTuningJobCheckpoints),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "cancel_fine_tuning_job_api",
|
||||
apiName: string(ApiNameCancelFineTuningJob),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "resume_fine_tuning_job_api",
|
||||
apiName: string(ApiNameResumeFineTuningJob),
|
||||
expected: true,
|
||||
},
|
||||
// Non-stateful APIs - should return false
|
||||
{
|
||||
name: "chat_completion_api",
|
||||
apiName: string(ApiNameChatCompletion),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "completion_api",
|
||||
apiName: string(ApiNameCompletion),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "embeddings_api",
|
||||
apiName: string(ApiNameEmbeddings),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "models_api",
|
||||
apiName: string(ApiNameModels),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "image_generation_api",
|
||||
apiName: string(ApiNameImageGeneration),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "audio_speech_api",
|
||||
apiName: string(ApiNameAudioSpeech),
|
||||
expected: false,
|
||||
},
|
||||
// Empty/unknown API - should return false
|
||||
{
|
||||
name: "empty_api_name",
|
||||
apiName: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "unknown_api_name",
|
||||
apiName: "unknown/api",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isStatefulAPI(tt.apiName)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTokenWithConsumerAffinity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
apiTokens []string
|
||||
consumer string
|
||||
wantEmpty bool
|
||||
wantToken string // If not empty, expected specific token (for single token case)
|
||||
}{
|
||||
{
|
||||
name: "no_tokens_returns_empty",
|
||||
apiTokens: []string{},
|
||||
consumer: "consumer1",
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "nil_tokens_returns_empty",
|
||||
apiTokens: nil,
|
||||
consumer: "consumer1",
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "single_token_always_returns_same_token",
|
||||
apiTokens: []string{"token1"},
|
||||
consumer: "consumer1",
|
||||
wantToken: "token1",
|
||||
},
|
||||
{
|
||||
name: "single_token_with_different_consumer",
|
||||
apiTokens: []string{"token1"},
|
||||
consumer: "consumer2",
|
||||
wantToken: "token1",
|
||||
},
|
||||
{
|
||||
name: "multiple_tokens_consistent_for_same_consumer",
|
||||
apiTokens: []string{"token1", "token2", "token3"},
|
||||
consumer: "consumer1",
|
||||
wantEmpty: false, // Will get one of the tokens, consistently
|
||||
},
|
||||
{
|
||||
name: "multiple_tokens_different_consumers_may_get_different_tokens",
|
||||
apiTokens: []string{"token1", "token2"},
|
||||
consumer: "consumerA",
|
||||
wantEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
apiTokens: tt.apiTokens,
|
||||
}
|
||||
|
||||
result := config.GetTokenWithConsumerAffinity(nil, tt.consumer)
|
||||
|
||||
if tt.wantEmpty {
|
||||
assert.Empty(t, result)
|
||||
} else if tt.wantToken != "" {
|
||||
assert.Equal(t, tt.wantToken, result)
|
||||
} else {
|
||||
assert.NotEmpty(t, result)
|
||||
assert.Contains(t, tt.apiTokens, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTokenWithConsumerAffinity_Consistency(t *testing.T) {
|
||||
// Test that the same consumer always gets the same token (consistency)
|
||||
config := &ProviderConfig{
|
||||
apiTokens: []string{"token1", "token2", "token3", "token4", "token5"},
|
||||
}
|
||||
|
||||
t.Run("same_consumer_gets_same_token_repeatedly", func(t *testing.T) {
|
||||
consumer := "test-consumer"
|
||||
var firstResult string
|
||||
|
||||
// Call multiple times and verify consistency
|
||||
for i := 0; i < 10; i++ {
|
||||
result := config.GetTokenWithConsumerAffinity(nil, consumer)
|
||||
if i == 0 {
|
||||
firstResult = result
|
||||
}
|
||||
assert.Equal(t, firstResult, result, "Consumer should consistently get the same token")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("different_consumers_distribute_across_tokens", func(t *testing.T) {
|
||||
// Use multiple consumers and verify they distribute across tokens
|
||||
consumers := []string{"consumer1", "consumer2", "consumer3", "consumer4", "consumer5", "consumer6", "consumer7", "consumer8", "consumer9", "consumer10"}
|
||||
tokenCounts := make(map[string]int)
|
||||
|
||||
for _, consumer := range consumers {
|
||||
token := config.GetTokenWithConsumerAffinity(nil, consumer)
|
||||
tokenCounts[token]++
|
||||
}
|
||||
|
||||
// Verify all tokens returned are valid
|
||||
for token := range tokenCounts {
|
||||
assert.Contains(t, config.apiTokens, token)
|
||||
}
|
||||
|
||||
// With 10 consumers and 5 tokens, we expect some distribution
|
||||
// (not necessarily perfect distribution, but should use multiple tokens)
|
||||
assert.GreaterOrEqual(t, len(tokenCounts), 2, "Should use at least 2 different tokens")
|
||||
})
|
||||
|
||||
t.Run("empty_consumer_returns_empty_string", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
apiTokens: []string{"token1", "token2"},
|
||||
}
|
||||
result := config.GetTokenWithConsumerAffinity(nil, "")
|
||||
// Empty consumer still returns a token (hash of empty string)
|
||||
assert.NotEmpty(t, result)
|
||||
assert.Contains(t, []string{"token1", "token2"}, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetTokenWithConsumerAffinity_HashDistribution(t *testing.T) {
|
||||
// Test that the hash function distributes consumers reasonably across tokens
|
||||
config := &ProviderConfig{
|
||||
apiTokens: []string{"token1", "token2", "token3"},
|
||||
}
|
||||
|
||||
// Test specific consumers to verify hash behavior
|
||||
testCases := []struct {
|
||||
consumer string
|
||||
expectValid bool
|
||||
}{
|
||||
{"user-alice", true},
|
||||
{"user-bob", true},
|
||||
{"user-charlie", true},
|
||||
{"service-api-v1", true},
|
||||
{"service-api-v2", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run("consumer_"+tc.consumer, func(t *testing.T) {
|
||||
result := config.GetTokenWithConsumerAffinity(nil, tc.consumer)
|
||||
assert.True(t, tc.expectValid)
|
||||
assert.Contains(t, config.apiTokens, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -8,11 +8,15 @@ import (
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
zhipuAiDomain = "open.bigmodel.cn"
|
||||
zhipuAiDefaultDomain = "open.bigmodel.cn"
|
||||
zhipuAiInternationalDomain = "api.z.ai"
|
||||
zhipuAiChatCompletionPath = "/api/paas/v4/chat/completions"
|
||||
zhipuAiCodePlanPath = "/api/coding/paas/v4/chat/completions"
|
||||
zhipuAiEmbeddingsPath = "/api/paas/v4/embeddings"
|
||||
zhipuAiAnthropicMessagesPath = "/api/anthropic/v1/messages"
|
||||
)
|
||||
@@ -26,16 +30,20 @@ func (m *zhipuAiProviderInitializer) ValidateConfig(config *ProviderConfig) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *zhipuAiProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
func (m *zhipuAiProviderInitializer) DefaultCapabilities(codePlanMode bool) map[string]string {
|
||||
chatPath := zhipuAiChatCompletionPath
|
||||
if codePlanMode {
|
||||
chatPath = zhipuAiCodePlanPath
|
||||
}
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): zhipuAiChatCompletionPath,
|
||||
string(ApiNameChatCompletion): chatPath,
|
||||
string(ApiNameEmbeddings): zhipuAiEmbeddingsPath,
|
||||
// string(ApiNameAnthropicMessages): zhipuAiAnthropicMessagesPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *zhipuAiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities(config.zhipuCodePlanMode))
|
||||
return &zhipuAiProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -65,13 +73,35 @@ func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
|
||||
func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, zhipuAiDomain)
|
||||
// Use configured domain or default to China domain
|
||||
domain := m.config.zhipuDomain
|
||||
if domain == "" {
|
||||
domain = zhipuAiDefaultDomain
|
||||
}
|
||||
util.OverwriteRequestHostHeader(headers, domain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (m *zhipuAiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return m.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
}
|
||||
|
||||
// Check if reasoning_effort is set
|
||||
reasoningEffort := gjson.GetBytes(body, "reasoning_effort").String()
|
||||
if reasoningEffort != "" {
|
||||
// Add thinking config for ZhipuAI
|
||||
body, _ = sjson.SetBytes(body, "thinking", map[string]string{"type": "enabled"})
|
||||
// Remove reasoning_effort field as ZhipuAI doesn't recognize it
|
||||
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
|
||||
}
|
||||
|
||||
return m.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *zhipuAiProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, zhipuAiChatCompletionPath) {
|
||||
if strings.Contains(path, zhipuAiChatCompletionPath) || strings.Contains(path, zhipuAiCodePlanPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
if strings.Contains(path, zhipuAiEmbeddingsPath) {
|
||||
|
||||
292
plugins/wasm-go/extensions/ai-proxy/test/consumer_affinity.go
Normal file
292
plugins/wasm-go/extensions/ai-proxy/test/consumer_affinity.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 测试配置:多 API Token 配置(用于测试 consumer affinity)
|
||||
var multiTokenOpenAIConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-token-1", "sk-token-2", "sk-token-3"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-4",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:单 API Token 配置
|
||||
var singleTokenOpenAIConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-single-token"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-4",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func RunConsumerAffinityParseConfigTests(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
t.Run("multi token config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunConsumerAffinityOnHttpRequestHeadersTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试 stateful API(responses)使用 consumer affinity
|
||||
t.Run("stateful api responses with consumer header", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 使用 x-mse-consumer header
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", "consumer-alice"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// 验证 Authorization header 使用了其中一个 token
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
|
||||
})
|
||||
|
||||
// 测试 stateful API(files)使用 consumer affinity
|
||||
t.Run("stateful api files with consumer header", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/files"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", "consumer-files"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
|
||||
})
|
||||
|
||||
// 测试 stateful API(batches)使用 consumer affinity
|
||||
t.Run("stateful api batches with consumer header", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/batches"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", "consumer-batches"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
|
||||
})
|
||||
|
||||
// 测试 stateful API(fine_tuning)使用 consumer affinity
|
||||
t.Run("stateful api fine_tuning with consumer header", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/fine_tuning/jobs"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", "consumer-finetuning"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
|
||||
})
|
||||
|
||||
// 测试非 stateful API 正常工作
|
||||
t.Run("non stateful api chat completions works normally", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", "consumer-chat"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
|
||||
})
|
||||
|
||||
// 测试无 x-mse-consumer header 时正常工作
|
||||
t.Run("stateful api without consumer header works normally", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
|
||||
})
|
||||
|
||||
// 测试单个 token 时始终使用该 token
|
||||
t.Run("single token always used", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(singleTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", "consumer-test"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
authValue, _ := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.Contains(t, authValue, "sk-single-token", "Single token should always be used")
|
||||
})
|
||||
|
||||
// 测试同一 consumer 多次请求获得相同 token(consumer affinity 一致性)
|
||||
t.Run("same consumer gets consistent token across requests", func(t *testing.T) {
|
||||
consumer := "consumer-consistency-test"
|
||||
var firstToken string
|
||||
|
||||
// 运行 5 次请求,验证同一个 consumer 始终获得相同的 token
|
||||
for i := 0; i < 5; i++ {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", consumer},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Should use one of the configured tokens")
|
||||
|
||||
if i == 0 {
|
||||
firstToken = authValue
|
||||
} else {
|
||||
require.Equal(t, firstToken, authValue, "Same consumer should get same token consistently (consumer affinity)")
|
||||
}
|
||||
|
||||
host.Reset()
|
||||
}
|
||||
})
|
||||
|
||||
// 测试不同 consumer 可能获得不同 token
|
||||
t.Run("different consumers get tokens based on hash", func(t *testing.T) {
|
||||
tokens := make(map[string]string)
|
||||
|
||||
consumers := []string{"consumer-alpha", "consumer-beta", "consumer-gamma", "consumer-delta", "consumer-epsilon"}
|
||||
for _, consumer := range consumers {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", consumer},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
authValue, _ := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
tokens[consumer] = authValue
|
||||
|
||||
host.Reset()
|
||||
}
|
||||
|
||||
// 验证至少使用了多个不同的 token(hash 分布)
|
||||
uniqueTokens := make(map[string]bool)
|
||||
for _, token := range tokens {
|
||||
uniqueTokens[token] = true
|
||||
}
|
||||
require.GreaterOrEqual(t, len(uniqueTokens), 2, "Different consumers should use at least 2 different tokens")
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -24,6 +24,8 @@ description: AI可观测配置参考
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|----------------|-------|------|-----|------------------------|
|
||||
| `use_default_attributes` | bool | 非必填 | false | 是否使用默认完整属性配置,包含 messages、answer、question 等所有字段。适用于调试、审计场景 |
|
||||
| `use_default_response_attributes` | bool | 非必填 | false | 是否使用轻量级默认属性配置(推荐),包含 model 和 token 统计,不缓冲流式响应体。适用于高并发生产环境 |
|
||||
| `attributes` | []Attribute | 非必填 | - | 用户希望记录在log/span中的信息 |
|
||||
| `disable_openai_usage` | bool | 非必填 | false | 非openai兼容协议时,model、token的支持非标,配置为true时可以避免报错 |
|
||||
| `value_length_limit` | int | 非必填 | 4000 | 记录的单个value的长度限制 |
|
||||
@@ -67,6 +69,7 @@ Attribute 配置说明:
|
||||
| 内置属性键 | 说明 | 适用场景 |
|
||||
|---------|------|---------|
|
||||
| `question` | 用户提问内容 | 支持 OpenAI/Claude 消息格式 |
|
||||
| `system` | 系统提示词 | 支持 Claude `/v1/messages` 的顶层 system 字段 |
|
||||
| `answer` | AI 回答内容 | 支持 OpenAI/Claude 消息格式,流式和非流式 |
|
||||
| `tool_calls` | 工具调用信息 | OpenAI/Claude 工具调用 |
|
||||
| `reasoning` | 推理过程 | OpenAI o1 等推理模型 |
|
||||
@@ -332,6 +335,195 @@ attributes:
|
||||
2. **性能分析**:分析推理 token 占比,评估推理模型的实际开销
|
||||
3. **使用统计**:细粒度统计各类 token 的使用情况
|
||||
|
||||
## 流式响应观测能力
|
||||
|
||||
流式(Streaming)响应是 AI 对话的常见场景,插件提供了完善的流式观测支持,能够正确拼接和提取流式响应中的关键信息。
|
||||
|
||||
### 流式响应的挑战
|
||||
|
||||
流式响应将完整内容拆分为多个 SSE chunk 逐步返回,例如:
|
||||
|
||||
```
|
||||
data: {"choices":[{"delta":{"content":"Hello"}}]}
|
||||
data: {"choices":[{"delta":{"content":" 👋"}}]}
|
||||
data: {"choices":[{"delta":{"content":"!"}}]}
|
||||
data: [DONE]
|
||||
```
|
||||
|
||||
要获取完整的回答内容,需要将各个 chunk 中的 `delta.content` 拼接起来。
|
||||
|
||||
### 自动拼接机制
|
||||
|
||||
插件针对不同类型的内容提供了自动拼接能力:
|
||||
|
||||
| 内容类型 | 拼接方式 | 说明 |
|
||||
|---------|---------|------|
|
||||
| `answer` | 文本追加(append) | 将各 chunk 的 `delta.content` 按顺序拼接成完整回答 |
|
||||
| `reasoning` | 文本追加(append) | 将各 chunk 的 `delta.reasoning_content` 按顺序拼接 |
|
||||
| `tool_calls` | 按 index 组装 | 识别每个工具调用的 `index`,分别拼接各自的 `arguments` |
|
||||
|
||||
#### answer 和 reasoning 拼接示例
|
||||
|
||||
流式响应:
|
||||
```
|
||||
data: {"choices":[{"delta":{"content":"你好"}}]}
|
||||
data: {"choices":[{"delta":{"content":",我是"}}]}
|
||||
data: {"choices":[{"delta":{"content":"AI助手"}}]}
|
||||
```
|
||||
|
||||
最终提取的 `answer`:`"你好,我是AI助手"`
|
||||
|
||||
#### tool_calls 拼接示例
|
||||
|
||||
流式响应(多个并行工具调用):
|
||||
```
|
||||
data: {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_001","function":{"name":"get_weather"}}]}}]}
|
||||
data: {"choices":[{"delta":{"tool_calls":[{"index":1,"id":"call_002","function":{"name":"get_time"}}]}}]}
|
||||
data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"city\":"}}]}}]}
|
||||
data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"Beijing\"}"}}]}}]}
|
||||
data: {"choices":[{"delta":{"tool_calls":[{"index":1,"function":{"arguments":"{\"city\":\"Shanghai\"}"}}]}}]}
|
||||
```
|
||||
|
||||
最终提取的 `tool_calls`:
|
||||
```json
|
||||
[
|
||||
{"index":0,"id":"call_001","function":{"name":"get_weather","arguments":"{\"city\":\"Beijing\"}"}},
|
||||
{"index":1,"id":"call_002","function":{"name":"get_time","arguments":"{\"city\":\"Shanghai\"}"}}
|
||||
]
|
||||
```
|
||||
|
||||
### 使用默认配置快速启用
|
||||
|
||||
插件提供两种默认配置模式:
|
||||
|
||||
#### 轻量模式(推荐用于生产环境)
|
||||
|
||||
通过 `use_default_response_attributes: true` 启用轻量模式:
|
||||
|
||||
```yaml
|
||||
use_default_response_attributes: true
|
||||
```
|
||||
|
||||
此配置是**推荐的生产环境配置**,特别适合高并发、高延迟的场景:
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `model` | 模型名称(从请求体提取) |
|
||||
| `reasoning_tokens` | 推理 token 数 |
|
||||
| `cached_tokens` | 缓存命中 token 数 |
|
||||
| `input_token_details` | 输入 token 详情 |
|
||||
| `output_token_details` | 输出 token 详情 |
|
||||
|
||||
**为什么推荐轻量模式?**
|
||||
|
||||
LLM 请求有两个显著特点:**延迟高**(通常数秒到数十秒)和**请求体大**(多轮对话可能达到数百 KB 甚至 MB 级别)。
|
||||
|
||||
在高并发场景下,如果请求体和响应体都被缓存在内存中,积压的请求会占用大量内存:
|
||||
- 假设 QPS=100,平均延迟=10秒,请求体=500KB
|
||||
- 同时在处理的请求数 ≈ 100 × 10 = 1000 个
|
||||
- 如果缓存完整请求体+响应体:1000 × 1.5MB ≈ **1.5GB 内存**
|
||||
|
||||
轻量模式通过以下方式降低内存占用:
|
||||
- **缓冲请求体**:仅用于提取 `model` 字段(很小),不提取 `question`、`system`、`messages` 等大字段
|
||||
- **不缓冲流式响应体**:不提取 `answer`、`reasoning`、`tool_calls` 等需要完整响应的字段
|
||||
- **只统计 token**:从响应的 usage 字段提取 token 信息
|
||||
|
||||
**内存对比:**
|
||||
|
||||
| 场景 | 完整模式 | 轻量模式 |
|
||||
|------|----------|----------|
|
||||
| 单次请求 (1MB 请求 + 500KB 响应) | ~1.5MB | ~1MB(请求体) |
|
||||
| 高并发 (100 QPS, 10s 延迟) | ~1.5GB | ~1GB |
|
||||
| 超高并发 (1000 QPS, 10s 延迟) | ~15GB | ~10GB |
|
||||
|
||||
**注意**:轻量模式下 `chat_round` 字段会正常计算,`model` 会从请求体正常提取。
|
||||
|
||||
#### 完整模式
|
||||
|
||||
通过 `use_default_attributes: true` 可以一键启用完整的流式观测能力:
|
||||
|
||||
```yaml
|
||||
use_default_attributes: true
|
||||
```
|
||||
|
||||
此配置会自动记录以下字段,**但会缓冲完整的请求体和流式响应体**:
|
||||
|
||||
| 字段 | 说明 | 内存影响 |
|
||||
|------|------|----------|
|
||||
| `messages` | 完整对话历史 | ⚠️ 可能很大 |
|
||||
| `question` | 最后一条用户消息 | 需要缓冲请求体 |
|
||||
| `system` | 系统提示词 | 需要缓冲请求体 |
|
||||
| `answer` | AI 回答(自动拼接流式 chunk) | ⚠️ 需要缓冲响应体 |
|
||||
| `reasoning` | 推理过程(自动拼接流式 chunk) | ⚠️ 需要缓冲响应体 |
|
||||
| `tool_calls` | 工具调用(自动按 index 组装) | 需要缓冲响应体 |
|
||||
| `reasoning_tokens` | 推理 token 数 | 无 |
|
||||
| `cached_tokens` | 缓存命中 token 数 | 无 |
|
||||
| `input_token_details` | 输入 token 详情 | 无 |
|
||||
| `output_token_details` | 输出 token 详情 | 无 |
|
||||
|
||||
**注意**:完整模式适用于调试、审计等需要完整对话记录的场景,但在高并发生产环境可能消耗大量内存。
|
||||
|
||||
### 流式日志示例
|
||||
|
||||
启用默认配置后,一个流式请求的日志输出示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"answer": "2 plus 2 equals 4.",
|
||||
"question": "What is 2+2?",
|
||||
"response_type": "stream",
|
||||
"tool_calls": null,
|
||||
"reasoning": null,
|
||||
"model": "glm-4-flash",
|
||||
"input_token": 10,
|
||||
"output_token": 8,
|
||||
"llm_first_token_duration": 425,
|
||||
"llm_service_duration": 985,
|
||||
"chat_id": "chat_abc123"
|
||||
}
|
||||
```
|
||||
|
||||
包含工具调用的流式日志示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"answer": null,
|
||||
"question": "What's the weather in Beijing?",
|
||||
"response_type": "stream",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"location\": \"Beijing\"}"
|
||||
}
|
||||
}
|
||||
],
|
||||
"model": "glm-4-flash",
|
||||
"input_token": 50,
|
||||
"output_token": 15,
|
||||
"llm_first_token_duration": 300,
|
||||
"llm_service_duration": 500
|
||||
}
|
||||
```
|
||||
|
||||
### 流式特有指标
|
||||
|
||||
流式响应会额外记录以下指标:
|
||||
|
||||
- `llm_first_token_duration`:从请求发出到收到首个 token 的时间(首字延迟)
|
||||
- `llm_stream_duration_count`:流式请求次数
|
||||
|
||||
可用于监控流式响应的用户体验:
|
||||
|
||||
```promql
|
||||
# 平均首字延迟
|
||||
irate(route_upstream_model_consumer_metric_llm_first_token_duration[5m])
|
||||
/
|
||||
irate(route_upstream_model_consumer_metric_llm_stream_duration_count[5m])
|
||||
```
|
||||
|
||||
## 调试
|
||||
|
||||
### 验证 ai_log 内容
|
||||
|
||||
@@ -105,6 +105,7 @@ const (
|
||||
BuiltinAnswerKey = "answer"
|
||||
BuiltinToolCallsKey = "tool_calls"
|
||||
BuiltinReasoningKey = "reasoning"
|
||||
BuiltinSystemKey = "system"
|
||||
BuiltinReasoningTokens = "reasoning_tokens"
|
||||
BuiltinCachedTokens = "cached_tokens"
|
||||
BuiltinInputTokenDetails = "input_token_details"
|
||||
@@ -115,6 +116,9 @@ const (
|
||||
QuestionPathOpenAI = "messages.@reverse.0.content"
|
||||
QuestionPathClaude = "messages.@reverse.0.content" // Claude uses same format
|
||||
|
||||
// System prompt paths (from request body)
|
||||
SystemPathClaude = "system" // Claude /v1/messages has system as a top-level field
|
||||
|
||||
// Answer paths (from response body - non-streaming)
|
||||
AnswerPathOpenAINonStreaming = "choices.0.message.content"
|
||||
AnswerPathClaudeNonStreaming = "content.0.text"
|
||||
@@ -123,10 +127,19 @@ const (
|
||||
AnswerPathOpenAIStreaming = "choices.0.delta.content"
|
||||
AnswerPathClaudeStreaming = "delta.text"
|
||||
|
||||
// Tool calls paths
|
||||
// Tool calls paths (OpenAI format)
|
||||
ToolCallsPathNonStreaming = "choices.0.message.tool_calls"
|
||||
ToolCallsPathStreaming = "choices.0.delta.tool_calls"
|
||||
|
||||
// Claude/Anthropic tool calls paths (streaming)
|
||||
ClaudeEventType = "type"
|
||||
ClaudeContentBlockType = "content_block.type"
|
||||
ClaudeContentBlockID = "content_block.id"
|
||||
ClaudeContentBlockName = "content_block.name"
|
||||
ClaudeContentBlockInput = "content_block.input"
|
||||
ClaudeDeltaPartialJSON = "delta.partial_json"
|
||||
ClaudeIndex = "index"
|
||||
|
||||
// Reasoning paths
|
||||
ReasoningPathNonStreaming = "choices.0.message.reasoning_content"
|
||||
ReasoningPathStreaming = "choices.0.delta.reasoning_content"
|
||||
@@ -136,6 +149,7 @@ const (
|
||||
)
|
||||
|
||||
// getDefaultAttributes returns the default attributes configuration for empty config
|
||||
// This includes all attributes but may consume significant memory for large conversations
|
||||
func getDefaultAttributes() []Attribute {
|
||||
return []Attribute{
|
||||
// Extract complete conversation history from request body
|
||||
@@ -150,13 +164,19 @@ func getDefaultAttributes() []Attribute {
|
||||
Key: BuiltinQuestionKey,
|
||||
ApplyToLog: true,
|
||||
},
|
||||
{
|
||||
Key: BuiltinSystemKey,
|
||||
ApplyToLog: true,
|
||||
},
|
||||
{
|
||||
Key: BuiltinAnswerKey,
|
||||
ApplyToLog: true,
|
||||
Rule: RuleAppend, // Streaming responses need to append content from all chunks
|
||||
},
|
||||
{
|
||||
Key: BuiltinReasoningKey,
|
||||
ApplyToLog: true,
|
||||
Rule: RuleAppend, // Streaming responses need to append content from all chunks
|
||||
},
|
||||
{
|
||||
Key: BuiltinToolCallsKey,
|
||||
@@ -183,6 +203,34 @@ func getDefaultAttributes() []Attribute {
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultResponseAttributes returns a lightweight default attributes configuration
|
||||
// for production environments with high concurrency and high latency.
|
||||
// - Buffers request body for model extraction (small, essential field)
|
||||
// - Does NOT extract large fields like question, system, messages
|
||||
// - Does NOT buffer streaming response body (no answer, reasoning, tool_calls)
|
||||
// - Only extracts token statistics from response context
|
||||
func getDefaultResponseAttributes() []Attribute {
|
||||
return []Attribute{
|
||||
// Token statistics (extracted from context, no body buffering needed)
|
||||
{
|
||||
Key: BuiltinReasoningTokens,
|
||||
ApplyToLog: true,
|
||||
},
|
||||
{
|
||||
Key: BuiltinCachedTokens,
|
||||
ApplyToLog: true,
|
||||
},
|
||||
{
|
||||
Key: BuiltinInputTokenDetails,
|
||||
ApplyToLog: true,
|
||||
},
|
||||
{
|
||||
Key: BuiltinOutputTokenDetails,
|
||||
ApplyToLog: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Default session ID headers in priority order
|
||||
var defaultSessionHeaders = []string{
|
||||
"x-openclaw-session-key",
|
||||
@@ -225,14 +273,18 @@ type ToolCallFunction struct {
|
||||
|
||||
// StreamingToolCallsBuffer holds the state for assembling streaming tool calls
|
||||
type StreamingToolCallsBuffer struct {
|
||||
ToolCalls map[int]*ToolCall // keyed by index
|
||||
ToolCalls map[int]*ToolCall // keyed by index (OpenAI format)
|
||||
InToolBlock map[int]bool // tracks which indices are in tool_use blocks (Claude format)
|
||||
ArgumentsBuffer map[int]string // buffers partial JSON arguments (Claude format)
|
||||
}
|
||||
|
||||
// extractStreamingToolCalls extracts and assembles tool calls from streaming response chunks
|
||||
// extractStreamingToolCalls extracts and assembles tool calls from streaming response chunks (OpenAI format)
|
||||
func extractStreamingToolCalls(data []byte, buffer *StreamingToolCallsBuffer) *StreamingToolCallsBuffer {
|
||||
if buffer == nil {
|
||||
buffer = &StreamingToolCallsBuffer{
|
||||
ToolCalls: make(map[int]*ToolCall),
|
||||
ToolCalls: make(map[int]*ToolCall),
|
||||
InToolBlock: make(map[int]bool),
|
||||
ArgumentsBuffer: make(map[int]string),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,6 +325,86 @@ func extractStreamingToolCalls(data []byte, buffer *StreamingToolCallsBuffer) *S
|
||||
return buffer
|
||||
}
|
||||
|
||||
// extractClaudeStreamingToolCalls extracts and assembles tool calls from Claude/Anthropic streaming response chunks
|
||||
// Claude format uses events: content_block_start, content_block_delta, content_block_stop
|
||||
func extractClaudeStreamingToolCalls(data []byte, buffer *StreamingToolCallsBuffer) *StreamingToolCallsBuffer {
|
||||
if buffer == nil {
|
||||
buffer = &StreamingToolCallsBuffer{
|
||||
ToolCalls: make(map[int]*ToolCall),
|
||||
InToolBlock: make(map[int]bool),
|
||||
ArgumentsBuffer: make(map[int]string),
|
||||
}
|
||||
}
|
||||
|
||||
chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n"))
|
||||
for _, chunk := range chunks {
|
||||
// Get event type
|
||||
eventType := gjson.GetBytes(chunk, ClaudeEventType)
|
||||
if !eventType.Exists() {
|
||||
continue
|
||||
}
|
||||
|
||||
switch eventType.String() {
|
||||
case "content_block_start":
|
||||
// Check if this is a tool_use block
|
||||
contentBlockType := gjson.GetBytes(chunk, ClaudeContentBlockType)
|
||||
if contentBlockType.Exists() && contentBlockType.String() == "tool_use" {
|
||||
index := int(gjson.GetBytes(chunk, ClaudeIndex).Int())
|
||||
|
||||
// Create tool call entry
|
||||
tc := &ToolCall{Index: index}
|
||||
|
||||
// Extract id and name
|
||||
if id := gjson.GetBytes(chunk, ClaudeContentBlockID).String(); id != "" {
|
||||
tc.ID = id
|
||||
}
|
||||
if name := gjson.GetBytes(chunk, ClaudeContentBlockName).String(); name != "" {
|
||||
tc.Function.Name = name
|
||||
}
|
||||
tc.Type = "tool_use"
|
||||
|
||||
buffer.ToolCalls[index] = tc
|
||||
buffer.InToolBlock[index] = true
|
||||
buffer.ArgumentsBuffer[index] = ""
|
||||
|
||||
// Try to extract initial input if present
|
||||
if input := gjson.GetBytes(chunk, ClaudeContentBlockInput); input.Exists() {
|
||||
if inputMap, ok := input.Value().(map[string]interface{}); ok {
|
||||
if jsonBytes, err := json.Marshal(inputMap); err == nil {
|
||||
buffer.ArgumentsBuffer[index] = string(jsonBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_delta":
|
||||
// Check if we're in a tool block
|
||||
index := int(gjson.GetBytes(chunk, ClaudeIndex).Int())
|
||||
if buffer.InToolBlock[index] {
|
||||
// Accumulate partial JSON arguments
|
||||
partialJSON := gjson.GetBytes(chunk, ClaudeDeltaPartialJSON)
|
||||
if partialJSON.Exists() {
|
||||
buffer.ArgumentsBuffer[index] += partialJSON.String()
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
// Finalize the tool call if we were in a tool block
|
||||
index := int(gjson.GetBytes(chunk, ClaudeIndex).Int())
|
||||
if buffer.InToolBlock[index] {
|
||||
buffer.InToolBlock[index] = false
|
||||
|
||||
// Parse accumulated arguments and set them
|
||||
if tc, exists := buffer.ToolCalls[index]; exists {
|
||||
tc.Function.Arguments = buffer.ArgumentsBuffer[index]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buffer
|
||||
}
|
||||
|
||||
// getToolCallsFromBuffer converts the buffer to a sorted slice of tool calls
|
||||
func getToolCallsFromBuffer(buffer *StreamingToolCallsBuffer) []ToolCall {
|
||||
if buffer == nil || len(buffer.ToolCalls) == 0 {
|
||||
@@ -317,6 +449,8 @@ type AIStatisticsConfig struct {
|
||||
attributes []Attribute
|
||||
// If there exist attributes extracted from streaming body, chunks should be buffered
|
||||
shouldBufferStreamingBody bool
|
||||
// If there exist attributes extracted from request body, request body should be buffered
|
||||
shouldBufferRequestBody bool
|
||||
// If disableOpenaiUsage is true, model/input_token/output_token logs will be skipped
|
||||
disableOpenaiUsage bool
|
||||
valueLengthLimit int
|
||||
@@ -411,6 +545,8 @@ func isContentTypeEnabled(contentType string, enabledContentTypes []string) bool
|
||||
func parseConfig(configJson gjson.Result, config *AIStatisticsConfig) error {
|
||||
// Check if use_default_attributes is enabled
|
||||
useDefaultAttributes := configJson.Get("use_default_attributes").Bool()
|
||||
// Check if use_default_response_attributes is enabled (lightweight mode)
|
||||
useDefaultResponseAttributes := configJson.Get("use_default_response_attributes").Bool()
|
||||
|
||||
// Parse tracing span attributes setting.
|
||||
attributeConfigs := configJson.Get("attributes").Array()
|
||||
@@ -430,6 +566,13 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig) error {
|
||||
config.valueLengthLimit = 10485760 // 10MB
|
||||
}
|
||||
log.Infof("Using default attributes configuration")
|
||||
} else if useDefaultResponseAttributes {
|
||||
config.attributes = getDefaultResponseAttributes()
|
||||
// Use a reasonable default for lightweight mode
|
||||
if !configJson.Get("value_length_limit").Exists() {
|
||||
config.valueLengthLimit = 4000
|
||||
}
|
||||
log.Infof("Using default response attributes configuration (lightweight mode)")
|
||||
} else {
|
||||
config.attributes = make([]Attribute, len(attributeConfigs))
|
||||
for i, attributeConfig := range attributeConfigs {
|
||||
@@ -439,15 +582,38 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig) error {
|
||||
log.Errorf("parse config failed, %v", err)
|
||||
return err
|
||||
}
|
||||
if attribute.ValueSource == ResponseStreamingBody {
|
||||
config.shouldBufferStreamingBody = true
|
||||
}
|
||||
if attribute.Rule != "" && attribute.Rule != RuleFirst && attribute.Rule != RuleReplace && attribute.Rule != RuleAppend {
|
||||
return errors.New("value of rule must be one of [nil, first, replace, append]")
|
||||
}
|
||||
config.attributes[i] = attribute
|
||||
}
|
||||
}
|
||||
|
||||
// Check if any attribute needs request body or streaming body buffering
|
||||
for _, attribute := range config.attributes {
|
||||
// Check for request body buffering
|
||||
if attribute.ValueSource == RequestBody {
|
||||
config.shouldBufferRequestBody = true
|
||||
}
|
||||
// Check for streaming body buffering (explicitly configured)
|
||||
if attribute.ValueSource == ResponseStreamingBody {
|
||||
config.shouldBufferStreamingBody = true
|
||||
}
|
||||
// For built-in attributes without explicit ValueSource, check default sources
|
||||
if attribute.ValueSource == "" && isBuiltinAttribute(attribute.Key) {
|
||||
defaultSources := getBuiltinAttributeDefaultSources(attribute.Key)
|
||||
for _, src := range defaultSources {
|
||||
if src == RequestBody {
|
||||
config.shouldBufferRequestBody = true
|
||||
}
|
||||
// Only answer/reasoning/tool_calls need actual body buffering
|
||||
// Token-related attributes are extracted from context, not from body
|
||||
if src == ResponseStreamingBody && needsBodyBuffering(attribute.Key) {
|
||||
config.shouldBufferStreamingBody = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Metric settings
|
||||
config.counterMetrics = make(map[string]proxywasm.MetricCounter)
|
||||
|
||||
@@ -458,8 +624,8 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig) error {
|
||||
pathSuffixes := configJson.Get("enable_path_suffixes").Array()
|
||||
config.enablePathSuffixes = make([]string, 0, len(pathSuffixes))
|
||||
|
||||
// If use_default_attributes is enabled and enable_path_suffixes is not configured, use default path suffixes
|
||||
if useDefaultAttributes && !configJson.Get("enable_path_suffixes").Exists() {
|
||||
// If use_default_attributes or use_default_response_attributes is enabled and enable_path_suffixes is not configured, use default path suffixes
|
||||
if (useDefaultAttributes || useDefaultResponseAttributes) && !configJson.Get("enable_path_suffixes").Exists() {
|
||||
config.enablePathSuffixes = []string{"/completions", "/messages"}
|
||||
log.Infof("Using default path suffixes: /completions, /messages")
|
||||
} else {
|
||||
@@ -527,6 +693,8 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig) ty
|
||||
ctx.SetContext(ConsumerKey, consumer)
|
||||
}
|
||||
|
||||
// Always buffer request body to extract model field
|
||||
// This is essential for metrics and logging
|
||||
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
|
||||
|
||||
// Extract session ID from headers
|
||||
@@ -551,13 +719,21 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
// Set user defined log & span attributes.
|
||||
setAttributeBySource(ctx, config, RequestBody, body)
|
||||
// Set span attributes for ARMS.
|
||||
// Only process request body if we need to extract attributes from it
|
||||
if config.shouldBufferRequestBody && len(body) > 0 {
|
||||
// Set user defined log & span attributes.
|
||||
setAttributeBySource(ctx, config, RequestBody, body)
|
||||
}
|
||||
|
||||
// Extract model from request body if available, otherwise try path
|
||||
requestModel := "UNKNOWN"
|
||||
if model := gjson.GetBytes(body, "model"); model.Exists() {
|
||||
requestModel = model.String()
|
||||
} else {
|
||||
if len(body) > 0 {
|
||||
if model := gjson.GetBytes(body, "model"); model.Exists() {
|
||||
requestModel = model.String()
|
||||
}
|
||||
}
|
||||
// If model not found in body, try to extract from path (Gemini style)
|
||||
if requestModel == "UNKNOWN" {
|
||||
requestPath := ctx.GetStringContext(RequestPath, "")
|
||||
if strings.Contains(requestPath, "generateContent") || strings.Contains(requestPath, "streamGenerateContent") { // Google Gemini GenerateContent
|
||||
reg := regexp.MustCompile(`^.*/(?P<api_version>[^/]+)/models/(?P<model>[^:]+):\w+Content$`)
|
||||
@@ -569,21 +745,23 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body
|
||||
}
|
||||
ctx.SetContext(tokenusage.CtxKeyRequestModel, requestModel)
|
||||
setSpanAttribute(ArmsRequestModel, requestModel)
|
||||
// Set the number of conversation rounds
|
||||
|
||||
// Set the number of conversation rounds (only if body is available)
|
||||
userPromptCount := 0
|
||||
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||
// OpenAI and Claude/Anthropic format - both use "messages" array with "role" field
|
||||
for _, msg := range messages.Array() {
|
||||
if msg.Get("role").String() == "user" {
|
||||
userPromptCount += 1
|
||||
if len(body) > 0 {
|
||||
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||
// OpenAI and Claude/Anthropic format - both use "messages" array with "role" field
|
||||
for _, msg := range messages.Array() {
|
||||
if msg.Get("role").String() == "user" {
|
||||
userPromptCount += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if contents := gjson.GetBytes(body, "contents"); contents.Exists() && contents.IsArray() {
|
||||
// Google Gemini GenerateContent
|
||||
for _, content := range contents.Array() {
|
||||
if !content.Get("role").Exists() || content.Get("role").String() == "user" {
|
||||
userPromptCount += 1
|
||||
} else if contents := gjson.GetBytes(body, "contents"); contents.Exists() && contents.IsArray() {
|
||||
// Google Gemini GenerateContent
|
||||
for _, content := range contents.Array() {
|
||||
if !content.Get("role").Exists() || content.Get("role").String() == "user" {
|
||||
userPromptCount += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -680,14 +858,14 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
|
||||
responseEndTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime)
|
||||
|
||||
// Set user defined log & span attributes.
|
||||
// Set user defined log & span attributes from streaming body.
|
||||
// Always call setAttributeBySource even if shouldBufferStreamingBody is false,
|
||||
// because token-related attributes are extracted from context (not buffered body).
|
||||
var streamingBodyBuffer []byte
|
||||
if config.shouldBufferStreamingBody {
|
||||
streamingBodyBuffer, ok := ctx.GetContext(CtxStreamingBodyBuffer).([]byte)
|
||||
if !ok {
|
||||
return data
|
||||
}
|
||||
setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer)
|
||||
streamingBodyBuffer, _ = ctx.GetContext(CtxStreamingBodyBuffer).([]byte)
|
||||
}
|
||||
setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer)
|
||||
|
||||
// Write log
|
||||
debugLogAiLog(ctx)
|
||||
@@ -849,21 +1027,32 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so
|
||||
|
||||
// isBuiltinAttribute checks if the given key is a built-in attribute
|
||||
func isBuiltinAttribute(key string) bool {
|
||||
return key == BuiltinQuestionKey || key == BuiltinAnswerKey || key == BuiltinToolCallsKey || key == BuiltinReasoningKey ||
|
||||
return key == BuiltinQuestionKey || key == BuiltinAnswerKey || key == BuiltinToolCallsKey || key == BuiltinReasoningKey || key == BuiltinSystemKey ||
|
||||
key == BuiltinReasoningTokens || key == BuiltinCachedTokens ||
|
||||
key == BuiltinInputTokenDetails || key == BuiltinOutputTokenDetails
|
||||
}
|
||||
|
||||
// needsBodyBuffering checks if a built-in attribute needs body buffering
|
||||
// Token-related attributes are extracted from context (set by tokenusage.GetTokenUsage),
|
||||
// so they don't require buffering the response body.
|
||||
func needsBodyBuffering(key string) bool {
|
||||
return key == BuiltinAnswerKey || key == BuiltinToolCallsKey || key == BuiltinReasoningKey
|
||||
}
|
||||
|
||||
// getBuiltinAttributeDefaultSources returns the default value_source(s) for a built-in attribute
|
||||
// Returns nil if the key is not a built-in attribute
|
||||
// Note: Token-related attributes are extracted from context (set by tokenusage.GetTokenUsage),
|
||||
// so they don't require body buffering even though they're processed during response phase.
|
||||
func getBuiltinAttributeDefaultSources(key string) []string {
|
||||
switch key {
|
||||
case BuiltinQuestionKey:
|
||||
case BuiltinQuestionKey, BuiltinSystemKey:
|
||||
return []string{RequestBody}
|
||||
case BuiltinAnswerKey, BuiltinToolCallsKey, BuiltinReasoningKey:
|
||||
return []string{ResponseStreamingBody, ResponseBody}
|
||||
case BuiltinReasoningTokens, BuiltinCachedTokens, BuiltinInputTokenDetails, BuiltinOutputTokenDetails:
|
||||
// Token details are only available after response is received
|
||||
// Token details are extracted from context (set by tokenusage.GetTokenUsage),
|
||||
// not from body parsing. We use ResponseStreamingBody/ResponseBody to indicate
|
||||
// they should be processed during response phase, but they don't require body buffering.
|
||||
return []string{ResponseStreamingBody, ResponseBody}
|
||||
default:
|
||||
return nil
|
||||
@@ -896,6 +1085,13 @@ func getBuiltinAttributeFallback(ctx wrapper.HttpContext, config AIStatisticsCon
|
||||
return value
|
||||
}
|
||||
}
|
||||
case BuiltinSystemKey:
|
||||
if source == RequestBody {
|
||||
// Try Claude /v1/messages format (system is a top-level field)
|
||||
if value := gjson.GetBytes(body, SystemPathClaude).Value(); value != nil && value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
case BuiltinAnswerKey:
|
||||
if source == ResponseStreamingBody {
|
||||
// Try OpenAI format first
|
||||
@@ -923,7 +1119,10 @@ func getBuiltinAttributeFallback(ctx wrapper.HttpContext, config AIStatisticsCon
|
||||
if existingBuffer, ok := ctx.GetContext(CtxStreamingToolCallsBuffer).(*StreamingToolCallsBuffer); ok {
|
||||
buffer = existingBuffer
|
||||
}
|
||||
// Try OpenAI format first
|
||||
buffer = extractStreamingToolCalls(body, buffer)
|
||||
// Also try Claude format (both formats can be checked)
|
||||
buffer = extractClaudeStreamingToolCalls(body, buffer)
|
||||
ctx.SetContext(CtxStreamingToolCallsBuffer, buffer)
|
||||
|
||||
// Also set tool_calls to user attributes so they appear in ai_log
|
||||
@@ -1047,6 +1246,9 @@ func debugLogAiLog(ctx wrapper.HttpContext) {
|
||||
if question := ctx.GetUserAttribute("question"); question != nil {
|
||||
userAttrs["question"] = question
|
||||
}
|
||||
if system := ctx.GetUserAttribute("system"); system != nil {
|
||||
userAttrs["system"] = system
|
||||
}
|
||||
if answer := ctx.GetUserAttribute("answer"); answer != nil {
|
||||
userAttrs["answer"] = answer
|
||||
}
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
module jsonrpc-converter
|
||||
|
||||
go 1.24.3
|
||||
go 1.24.1
|
||||
|
||||
replace github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/higress-group/wasm-go v1.0.4
|
||||
github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
)
|
||||
|
||||
@@ -15,6 +19,7 @@ require (
|
||||
github.com/Masterminds/sprig/v3 v3.3.0 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b // indirect
|
||||
github.com/huandu/xstrings v1.5.0 // indirect
|
||||
@@ -22,8 +27,10 @@ require (
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/mitchellh/copystructure v1.2.0 // indirect
|
||||
github.com/mitchellh/reflectwalk v1.0.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/shopspring/decimal v1.4.0 // indirect
|
||||
github.com/spf13/cast v1.7.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
|
||||
@@ -20,10 +20,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b h1:rRI9+ThQbe+nw4jUiYEyOFaREkXCMMW9k1X2gy2d6pE=
|
||||
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.4 h1:/GqbzCw4oWqJc8UbKEfF94E3/+4CPZGbzxpKo2L3Ldk=
|
||||
github.com/higress-group/wasm-go v1.0.4/go.mod h1:B8C6+OlpnyYyZUBEdUXA7tYZYD+uwZTNjfkE5FywA+A=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 h1:sUuUXZwr50l3W1St7MESlFmxmUAu+QUNNfJXx4P6bas=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
|
||||
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
|
||||
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
|
||||
@@ -49,6 +49,8 @@ github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
|
||||
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
|
||||
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/utils"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestTruncateString tests the truncateString function
|
||||
func TestTruncateString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -14,6 +20,8 @@ func TestTruncateString(t *testing.T) {
|
||||
{"Short String", "Higress Is an AI-Native API Gateway", 1000, "Higress Is an AI-Native API Gateway"},
|
||||
{"Exact Length", "Higress Is an AI-Native API Gateway", 35, "Higress Is an AI-Native API Gateway"},
|
||||
{"Truncated String", "Higress Is an AI-Native API Gateway", 20, "Higress Is...(truncated)...PI Gateway"},
|
||||
{"Empty String", "", 10, ""},
|
||||
{"Single Char", "A", 10, "A"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -26,3 +34,248 @@ func TestTruncateString(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsPreRequestStage tests the isPreRequestStage function
|
||||
func TestIsPreRequestStage(t *testing.T) {
|
||||
config := McpConverterConfig{Stage: ProcessRequest}
|
||||
require.True(t, isPreRequestStage(config))
|
||||
|
||||
config = McpConverterConfig{Stage: ProcessResponse}
|
||||
require.False(t, isPreRequestStage(config))
|
||||
}
|
||||
|
||||
// TestIsPreResponseStage tests the isPreResponseStage function
|
||||
func TestIsPreResponseStage(t *testing.T) {
|
||||
config := McpConverterConfig{Stage: ProcessResponse}
|
||||
require.True(t, isPreResponseStage(config))
|
||||
|
||||
config = McpConverterConfig{Stage: ProcessRequest}
|
||||
require.False(t, isPreResponseStage(config))
|
||||
}
|
||||
|
||||
// TestIsMethodAllowed tests the isMethodAllowed function
|
||||
func TestIsMethodAllowed(t *testing.T) {
|
||||
config := McpConverterConfig{AllowedMethods: []string{MethodToolList, MethodToolCall}}
|
||||
|
||||
require.True(t, isMethodAllowed(config, MethodToolList))
|
||||
require.True(t, isMethodAllowed(config, MethodToolCall))
|
||||
require.False(t, isMethodAllowed(config, "invalid/method"))
|
||||
}
|
||||
|
||||
// TestConstants tests the constant values
|
||||
func TestConstants(t *testing.T) {
|
||||
require.Equal(t, "x-envoy-jsonrpc-id", JsonRpcId)
|
||||
require.Equal(t, "x-envoy-jsonrpc-method", JsonRpcMethod)
|
||||
require.Equal(t, "x-envoy-jsonrpc-params", JsonRpcParams)
|
||||
require.Equal(t, "x-envoy-jsonrpc-result", JsonRpcResult)
|
||||
require.Equal(t, "x-envoy-jsonrpc-error", JsonRpcError)
|
||||
require.Equal(t, "x-envoy-mcp-tool-name", McpToolName)
|
||||
require.Equal(t, "x-envoy-mcp-tool-arguments", McpToolArguments)
|
||||
require.Equal(t, "x-envoy-mcp-tool-response", McpToolResponse)
|
||||
require.Equal(t, "x-envoy-mcp-tool-error", McpToolError)
|
||||
require.Equal(t, 4000, DefaultMaxHeaderLength)
|
||||
require.Equal(t, "tools/list", MethodToolList)
|
||||
require.Equal(t, "tools/call", MethodToolCall)
|
||||
require.Equal(t, ProcessStage("request"), ProcessRequest)
|
||||
require.Equal(t, ProcessStage("response"), ProcessResponse)
|
||||
}
|
||||
|
||||
// TestMcpConverterConfigDefaults tests config default values
|
||||
func TestMcpConverterConfigDefaults(t *testing.T) {
|
||||
config := McpConverterConfig{}
|
||||
require.Equal(t, 0, config.MaxHeaderLength)
|
||||
require.Equal(t, ProcessStage(""), config.Stage)
|
||||
require.Nil(t, config.AllowedMethods)
|
||||
}
|
||||
|
||||
// TestProcessStage tests ProcessStage type
|
||||
func TestProcessStage(t *testing.T) {
|
||||
require.Equal(t, ProcessStage("request"), ProcessRequest)
|
||||
require.Equal(t, ProcessStage("response"), ProcessResponse)
|
||||
}
|
||||
|
||||
// TestRemoveJsonRpcHeadersFunction tests removeJsonRpcHeaders function logic
|
||||
func TestRemoveJsonRpcHeadersFunction(t *testing.T) {
|
||||
headersToRemove := []string{
|
||||
JsonRpcId,
|
||||
JsonRpcMethod,
|
||||
JsonRpcParams,
|
||||
JsonRpcResult,
|
||||
McpToolName,
|
||||
McpToolArguments,
|
||||
McpToolResponse,
|
||||
McpToolError,
|
||||
}
|
||||
require.Len(t, headersToRemove, 8)
|
||||
}
|
||||
|
||||
// TestTruncateStringLong tests truncation of very long strings
|
||||
func TestTruncateStringLong(t *testing.T) {
|
||||
longString := ""
|
||||
for i := 0; i < 5000; i++ {
|
||||
longString += "a"
|
||||
}
|
||||
config := McpConverterConfig{MaxHeaderLength: 1000}
|
||||
result := truncateString(longString, config)
|
||||
require.Contains(t, result, "...(truncated)...")
|
||||
require.LessOrEqual(t, len(result), 1020)
|
||||
}
|
||||
|
||||
// TestTruncateStringWithSmallMaxLength tests truncation with small max length
|
||||
func TestTruncateStringWithSmallMaxLength(t *testing.T) {
|
||||
config := McpConverterConfig{MaxHeaderLength: 10}
|
||||
result := truncateString("This is a very long string", config)
|
||||
require.Contains(t, result, "...(truncated)...")
|
||||
}
|
||||
|
||||
// TestPluginInit tests plugin initialization
|
||||
func TestPluginInit(t *testing.T) {
|
||||
configBytes, _ := json.Marshal(McpConverterConfig{
|
||||
Stage: ProcessRequest,
|
||||
MaxHeaderLength: DefaultMaxHeaderLength,
|
||||
AllowedMethods: []string{MethodToolList, MethodToolCall},
|
||||
})
|
||||
|
||||
host, status := test.NewTestHost(configBytes)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
}
|
||||
|
||||
// TestProcessJsonRpcRequest tests processJsonRpcRequest function
|
||||
func TestProcessJsonRpcRequest(t *testing.T) {
|
||||
configBytes, _ := json.Marshal(McpConverterConfig{
|
||||
Stage: ProcessRequest,
|
||||
MaxHeaderLength: DefaultMaxHeaderLength,
|
||||
AllowedMethods: []string{MethodToolList, MethodToolCall},
|
||||
})
|
||||
|
||||
host, status := test.NewTestHost(configBytes)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.InitHttp()
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "mcp-server.example.com"},
|
||||
{":method", "POST"},
|
||||
{":path", "/mcp"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
toolsListRequest := `{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
}`
|
||||
action := host.CallOnHttpRequestBody([]byte(toolsListRequest))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
host.CompleteHttp()
|
||||
}
|
||||
|
||||
// TestProcessToolCallRequest tests processToolCallRequest function
|
||||
func TestProcessToolCallRequest(t *testing.T) {
|
||||
configBytes, _ := json.Marshal(McpConverterConfig{
|
||||
Stage: ProcessRequest,
|
||||
MaxHeaderLength: DefaultMaxHeaderLength,
|
||||
AllowedMethods: []string{MethodToolCall},
|
||||
})
|
||||
|
||||
host, status := test.NewTestHost(configBytes)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.InitHttp()
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "mcp-server.example.com"},
|
||||
{":method", "POST"},
|
||||
{":path", "/mcp"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
toolCallRequest := `{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "test_tool",
|
||||
"arguments": {"arg1": "value1"}
|
||||
}
|
||||
}`
|
||||
action := host.CallOnHttpRequestBody([]byte(toolCallRequest))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
host.CompleteHttp()
|
||||
}
|
||||
|
||||
// TestProcessJsonRpcResponse tests processJsonRpcResponse function
|
||||
func TestProcessJsonRpcResponse(t *testing.T) {
|
||||
configBytes, _ := json.Marshal(McpConverterConfig{
|
||||
Stage: ProcessResponse,
|
||||
MaxHeaderLength: DefaultMaxHeaderLength,
|
||||
AllowedMethods: []string{MethodToolList, MethodToolCall},
|
||||
})
|
||||
|
||||
host, status := test.NewTestHost(configBytes)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.InitHttp()
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "mcp-server.example.com"},
|
||||
{":method", "POST"},
|
||||
{":path", "/mcp"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
responseBody := `{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"result": {
|
||||
"tools": [{"name": "test_tool"}]
|
||||
}
|
||||
}`
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
|
||||
host.CompleteHttp()
|
||||
}
|
||||
|
||||
// TestProcessToolListResponse tests processToolListResponse function
|
||||
func TestProcessToolListResponse(t *testing.T) {
|
||||
configBytes, _ := json.Marshal(McpConverterConfig{
|
||||
Stage: ProcessResponse,
|
||||
MaxHeaderLength: DefaultMaxHeaderLength,
|
||||
AllowedMethods: []string{MethodToolList},
|
||||
})
|
||||
|
||||
host, status := test.NewTestHost(configBytes)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.InitHttp()
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "mcp-server.example.com"},
|
||||
{":method", "POST"},
|
||||
{":path", "/mcp"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
responseBody := `{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"result": {
|
||||
"tools": [{"name": "test_tool"}]
|
||||
}
|
||||
}`
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
|
||||
host.CompleteHttp()
|
||||
}
|
||||
|
||||
@@ -2,9 +2,12 @@ module mcp-router
|
||||
|
||||
go 1.24.1
|
||||
|
||||
replace github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250911113549-cbf1cfcce774
|
||||
github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
)
|
||||
@@ -20,12 +20,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b h1:rRI9+ThQbe+nw4jUiYEyOFaREkXCMMW9k1X2gy2d6pE=
|
||||
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250807064511-eb1cd98e1f57 h1:WhNdnKSDtAQrh4Yil8HAtbl7VW+WC85m7WS8kirnHAA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250807064511-eb1cd98e1f57/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250911113549-cbf1cfcce774 h1:2wlbNpFJCQNbPBFYgswz7Zvxo9O3L0PH0AJxwiCc5lk=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250911113549-cbf1cfcce774/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 h1:sUuUXZwr50l3W1St7MESlFmxmUAu+QUNNfJXx4P6bas=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
|
||||
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
|
||||
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
|
||||
@@ -22,8 +22,8 @@ import (
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/consts"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/consts"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -1,13 +1,16 @@
|
||||
module all-in-one
|
||||
module mcp-server
|
||||
|
||||
go 1.24.1
|
||||
|
||||
replace quark-search => ../quark-search
|
||||
|
||||
replace amap-tools => ../amap-tools
|
||||
replace (
|
||||
amap-tools => ../../mcp-servers/amap-tools
|
||||
github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp
|
||||
quark-search => ../../mcp-servers/quark-search
|
||||
)
|
||||
|
||||
require (
|
||||
amap-tools v0.0.0-00010101000000-000000000000
|
||||
github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9
|
||||
github.com/stretchr/testify v1.9.0
|
||||
@@ -22,10 +22,6 @@ github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b h1:rR
|
||||
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.9-0.20251223122142-eae11e33a500 h1:4BKKZ3BreIaIGub88nlvzihTK1uJmZYYoQ7r7Xkgb5Q=
|
||||
github.com/higress-group/wasm-go v1.0.9-0.20251223122142-eae11e33a500/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115083526-76699a1df2c1 h1:+usoX0B1cwECTA2qf73IaLGyCIMVopIMev5cBWGgEZk=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115083526-76699a1df2c1/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 h1:sUuUXZwr50l3W1St7MESlFmxmUAu+QUNNfJXx4P6bas=
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
amap "amap-tools/tools"
|
||||
quark "quark-search/tools"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
|
||||
)
|
||||
|
||||
func main() {}
|
||||
@@ -1,14 +0,0 @@
|
||||
# Use a minimal base image as we only need to store the wasm file.
|
||||
FROM scratch
|
||||
|
||||
# Add build argument for the filter name. This will be passed by the Makefile.
|
||||
ARG FILTER_NAME
|
||||
|
||||
# Copy the compiled WASM binary into the image's root directory.
|
||||
# The wasm file will be named after the filter.
|
||||
COPY ${FILTER_NAME}/main.wasm /plugin.wasm
|
||||
|
||||
# Metadata
|
||||
LABEL org.opencontainers.image.title="${FILTER_NAME}"
|
||||
LABEL org.opencontainers.image.description="Higress MCP filter - ${FILTER_NAME}"
|
||||
LABEL org.opencontainers.image.source="https://github.com/alibaba/higress"
|
||||
@@ -1,54 +0,0 @@
|
||||
# MCP Filter Makefile
|
||||
|
||||
# Variables
|
||||
FILTER_NAME ?= mcp-router
|
||||
REGISTRY ?= higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/
|
||||
BUILD_TIME := $(shell date "+%Y%m%d-%H%M%S")
|
||||
COMMIT_ID := $(shell git rev-parse --short HEAD 2>/dev/null)
|
||||
IMAGE_TAG = $(if $(strip $(FILTER_VERSION)),${FILTER_VERSION},${BUILD_TIME}-${COMMIT_ID})
|
||||
IMG ?= ${REGISTRY}${FILTER_NAME}:${IMAGE_TAG}
|
||||
|
||||
# Default target
|
||||
.DEFAULT: build
|
||||
|
||||
build:
|
||||
@echo "Building WASM binary for filter: ${FILTER_NAME}..."
|
||||
@if [ ! -d "${FILTER_NAME}" ]; then \
|
||||
echo "Error: Filter directory '${FILTER_NAME}' not found."; \
|
||||
exit 1; \
|
||||
fi
|
||||
cd ${FILTER_NAME} && GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o main.wasm main.go
|
||||
@echo ""
|
||||
@echo "Output WASM file: ${FILTER_NAME}/main.wasm"
|
||||
|
||||
# Build Docker image (depends on build target to ensure WASM binary exists)
|
||||
build-image: build
|
||||
@echo "Building Docker image for ${FILTER_NAME}..."
|
||||
docker build -t ${IMG} \
|
||||
--build-arg FILTER_NAME=${FILTER_NAME} \
|
||||
-f Dockerfile .
|
||||
@echo ""
|
||||
@echo "Image: ${IMG}"
|
||||
|
||||
# Build and push Docker image
|
||||
build-push: build-image
|
||||
docker push ${IMG}
|
||||
|
||||
# Clean build artifacts
|
||||
clean:
|
||||
@echo "Cleaning build artifacts for filter: ${FILTER_NAME}..."
|
||||
rm -f ${FILTER_NAME}/main.wasm
|
||||
|
||||
# Help
|
||||
help:
|
||||
@echo "Available targets:"
|
||||
@echo " build - Build WASM binary for a specific filter"
|
||||
@echo " build-image - Build Docker image"
|
||||
@echo " build-push - Build and push Docker image"
|
||||
@echo " clean - Remove build artifacts for a specific filter"
|
||||
@echo ""
|
||||
@echo "Variables:"
|
||||
@echo " FILTER_NAME - Name of the MCP filter to build (default: ${FILTER_NAME})"
|
||||
@echo " REGISTRY - Docker registry (default: ${REGISTRY})"
|
||||
@echo " FILTER_VERSION - Version tag for the image (default: timestamp-commit)"
|
||||
@echo " IMG - Full image name (default: ${IMG})"
|
||||
@@ -80,8 +80,8 @@ import (
|
||||
"net/http"
|
||||
|
||||
"my-mcp-server/config"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/utils"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
)
|
||||
|
||||
// Define your tool structure with input parameters
|
||||
@@ -145,8 +145,8 @@ For better organization, you can create a separate file to load all your tools:
|
||||
package tools
|
||||
|
||||
import (
|
||||
"github.com/higress-group/wasm-go/pkg/mcp"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
)
|
||||
|
||||
func LoadTools(server *mcp.MCPServer) server.Server {
|
||||
@@ -170,7 +170,7 @@ import (
|
||||
amap "amap-tools/tools"
|
||||
quark "quark-search/tools"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
|
||||
)
|
||||
|
||||
func main() {}
|
||||
@@ -375,7 +375,7 @@ package main
|
||||
import (
|
||||
"my-mcp-server/tools"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
|
||||
)
|
||||
|
||||
func main() {}
|
||||
|
||||
@@ -2,9 +2,12 @@ module amap-tools
|
||||
|
||||
go 1.24.1
|
||||
|
||||
replace github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
|
||||
github.com/higress-group/wasm-go v1.0.0
|
||||
github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -23,6 +26,7 @@ require (
|
||||
github.com/mitchellh/reflectwalk v1.0.2 // indirect
|
||||
github.com/shopspring/decimal v1.4.0 // indirect
|
||||
github.com/spf13/cast v1.7.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
|
||||
@@ -17,7 +17,7 @@ package main
|
||||
import (
|
||||
"amap-tools/tools"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
|
||||
)
|
||||
|
||||
func main() {}
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"github.com/higress-group/wasm-go/pkg/mcp"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
)
|
||||
|
||||
func LoadTools(server *mcp.MCPServer) server.Server {
|
||||
|
||||
@@ -23,8 +23,8 @@ import (
|
||||
|
||||
"amap-tools/config"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/utils"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
)
|
||||
|
||||
var _ server.Tool = AroundSearchRequest{}
|
||||
|
||||
@@ -23,8 +23,8 @@ import (
|
||||
|
||||
"amap-tools/config"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/utils"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
)
|
||||
|
||||
var _ server.Tool = BicyclingRequest{}
|
||||
|
||||
@@ -23,8 +23,8 @@ import (
|
||||
|
||||
"amap-tools/config"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/utils"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
)
|
||||
|
||||
var _ server.Tool = DrivingRequest{}
|
||||
|
||||
@@ -23,8 +23,8 @@ import (
|
||||
|
||||
"amap-tools/config"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/utils"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
)
|
||||
|
||||
var _ server.Tool = TransitIntegratedRequest{}
|
||||
|
||||
@@ -23,8 +23,8 @@ import (
|
||||
|
||||
"amap-tools/config"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/utils"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
)
|
||||
|
||||
var _ server.Tool = WalkingRequest{}
|
||||
|
||||
@@ -23,8 +23,8 @@ import (
|
||||
|
||||
"amap-tools/config"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/utils"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
)
|
||||
|
||||
var _ server.Tool = DistanceRequest{}
|
||||
|
||||
@@ -23,8 +23,8 @@ import (
|
||||
|
||||
"amap-tools/config"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/utils"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
)
|
||||
|
||||
var _ server.Tool = GeoRequest{}
|
||||
|
||||
@@ -24,8 +24,8 @@ import (
|
||||
|
||||
"amap-tools/config"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/utils"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
)
|
||||
|
||||
|
||||
@@ -23,8 +23,8 @@ import (
|
||||
|
||||
"amap-tools/config"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/utils"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
)
|
||||
|
||||
var _ server.Tool = ReGeocodeRequest{}
|
||||
|
||||
@@ -23,8 +23,8 @@ import (
|
||||
|
||||
"amap-tools/config"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/utils"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
)
|
||||
|
||||
var _ server.Tool = SearchDetailRequest{}
|
||||
|
||||
@@ -23,8 +23,8 @@ import (
|
||||
|
||||
"amap-tools/config"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/utils"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
)
|
||||
|
||||
var _ server.Tool = TextSearchRequest{}
|
||||
|
||||
@@ -23,8 +23,8 @@ import (
|
||||
|
||||
"amap-tools/config"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/utils"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
)
|
||||
|
||||
var _ server.Tool = WeatherRequest{}
|
||||
|
||||
@@ -2,8 +2,11 @@ module quark-search
|
||||
|
||||
go 1.24.1
|
||||
|
||||
replace github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp
|
||||
|
||||
require (
|
||||
github.com/higress-group/wasm-go v1.0.0
|
||||
github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
)
|
||||
|
||||
@@ -16,7 +19,7 @@ require (
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b // indirect
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 // indirect
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 // indirect
|
||||
github.com/huandu/xstrings v1.5.0 // indirect
|
||||
github.com/invopop/jsonschema v0.13.0 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
@@ -24,6 +27,7 @@ require (
|
||||
github.com/mitchellh/reflectwalk v1.0.2 // indirect
|
||||
github.com/shopspring/decimal v1.4.0 // indirect
|
||||
github.com/spf13/cast v1.7.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
|
||||
@@ -17,7 +17,7 @@ package main
|
||||
import (
|
||||
"quark-search/tools"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
|
||||
)
|
||||
|
||||
func main() {}
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"github.com/higress-group/wasm-go/pkg/mcp"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
)
|
||||
|
||||
func LoadTools(server *mcp.MCPServer) server.Server {
|
||||
|
||||
@@ -24,8 +24,8 @@ import (
|
||||
|
||||
"quark-search/config"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/server"
|
||||
"github.com/higress-group/wasm-go/pkg/mcp/utils"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package log
|
||||
|
||||
type Log interface {
|
||||
Trace(msg string)
|
||||
Tracef(format string, args ...interface{})
|
||||
Debug(msg string)
|
||||
Debugf(format string, args ...interface{})
|
||||
Info(msg string)
|
||||
Infof(format string, args ...interface{})
|
||||
Warn(msg string)
|
||||
Warnf(format string, args ...interface{})
|
||||
Error(msg string)
|
||||
Errorf(format string, args ...interface{})
|
||||
Critical(msg string)
|
||||
Criticalf(format string, args ...interface{})
|
||||
ResetID(pluginID string)
|
||||
}
|
||||
|
||||
var pluginLog Log
|
||||
|
||||
func SetPluginLog(log Log) {
|
||||
pluginLog = log
|
||||
}
|
||||
|
||||
func Trace(msg string) {
|
||||
pluginLog.Trace(msg)
|
||||
}
|
||||
|
||||
func Tracef(format string, args ...interface{}) {
|
||||
pluginLog.Tracef(format, args...)
|
||||
}
|
||||
|
||||
func Debug(msg string) {
|
||||
pluginLog.Debug(msg)
|
||||
}
|
||||
|
||||
func Debugf(format string, args ...interface{}) {
|
||||
pluginLog.Debugf(format, args...)
|
||||
}
|
||||
|
||||
func Info(msg string) {
|
||||
pluginLog.Info(msg)
|
||||
}
|
||||
|
||||
func Infof(format string, args ...interface{}) {
|
||||
pluginLog.Infof(format, args...)
|
||||
}
|
||||
|
||||
func Warn(msg string) {
|
||||
pluginLog.Warn(msg)
|
||||
}
|
||||
|
||||
func Warnf(format string, args ...interface{}) {
|
||||
pluginLog.Warnf(format, args...)
|
||||
}
|
||||
|
||||
func Error(msg string) {
|
||||
pluginLog.Error(msg)
|
||||
}
|
||||
|
||||
func Errorf(format string, args ...interface{}) {
|
||||
pluginLog.Errorf(format, args...)
|
||||
}
|
||||
|
||||
func Critical(msg string) {
|
||||
pluginLog.Critical(msg)
|
||||
}
|
||||
|
||||
func Criticalf(format string, args ...interface{}) {
|
||||
pluginLog.Criticalf(format, args...)
|
||||
}
|
||||
@@ -1,300 +0,0 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package matcher
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type Category int
|
||||
|
||||
const (
|
||||
Route Category = iota
|
||||
Host
|
||||
Service
|
||||
RoutePrefix
|
||||
)
|
||||
|
||||
type MatchType int
|
||||
|
||||
const (
|
||||
Prefix MatchType = iota
|
||||
Exact
|
||||
Suffix
|
||||
)
|
||||
|
||||
const (
|
||||
RULES_KEY = "_rules_"
|
||||
MATCH_ROUTE_KEY = "_match_route_"
|
||||
MATCH_DOMAIN_KEY = "_match_domain_"
|
||||
MATCH_SERVICE_KEY = "_match_service_"
|
||||
MATCH_ROUTE_PREFIX_KEY = "_match_route_prefix_"
|
||||
)
|
||||
|
||||
type HostMatcher struct {
|
||||
matchType MatchType
|
||||
host string
|
||||
}
|
||||
|
||||
type RuleConfig[PluginConfig any] struct {
|
||||
category Category
|
||||
routes map[string]struct{}
|
||||
services map[string]struct{}
|
||||
routePrefixs map[string]struct{}
|
||||
hosts []HostMatcher
|
||||
config PluginConfig
|
||||
}
|
||||
|
||||
type RuleMatcher[PluginConfig any] struct {
|
||||
ruleConfig []RuleConfig[PluginConfig]
|
||||
globalConfig PluginConfig
|
||||
hasGlobalConfig bool
|
||||
}
|
||||
|
||||
func (m RuleMatcher[PluginConfig]) GetMatchConfig() (*PluginConfig, error) {
|
||||
host, err := proxywasm.GetHttpRequestHeader(":authority")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
routeName, err := proxywasm.GetProperty([]string{"route_name"})
|
||||
if err != nil && err != types.ErrorStatusNotFound {
|
||||
return nil, err
|
||||
}
|
||||
serviceName, err := proxywasm.GetProperty([]string{"cluster_name"})
|
||||
if err != nil && err != types.ErrorStatusNotFound {
|
||||
return nil, err
|
||||
}
|
||||
for _, rule := range m.ruleConfig {
|
||||
// category == Host
|
||||
if rule.category == Host {
|
||||
if m.hostMatch(rule, host) {
|
||||
return &rule.config, nil
|
||||
}
|
||||
}
|
||||
// category == Route
|
||||
if rule.category == Route {
|
||||
if _, ok := rule.routes[string(routeName)]; ok {
|
||||
return &rule.config, nil
|
||||
}
|
||||
}
|
||||
// category == RoutePrefix
|
||||
if rule.category == RoutePrefix {
|
||||
for routePrefix := range rule.routePrefixs {
|
||||
if strings.HasPrefix(string(routeName), routePrefix) {
|
||||
return &rule.config, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
// category == Cluster
|
||||
if m.serviceMatch(rule, string(serviceName)) {
|
||||
return &rule.config, nil
|
||||
}
|
||||
}
|
||||
if m.hasGlobalConfig {
|
||||
return &m.globalConfig, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *RuleMatcher[PluginConfig]) ParseRuleConfig(config gjson.Result,
|
||||
parsePluginConfig func(gjson.Result, *PluginConfig) error,
|
||||
parseOverrideConfig func(gjson.Result, PluginConfig, *PluginConfig) error) error {
|
||||
var rules []gjson.Result
|
||||
obj := config.Map()
|
||||
keyCount := len(obj)
|
||||
if keyCount == 0 {
|
||||
// enable globally for empty config
|
||||
m.hasGlobalConfig = true
|
||||
return parsePluginConfig(config, &m.globalConfig)
|
||||
}
|
||||
if rulesJson, ok := obj[RULES_KEY]; ok {
|
||||
rules = rulesJson.Array()
|
||||
keyCount--
|
||||
}
|
||||
var pluginConfig PluginConfig
|
||||
var globalConfigError error
|
||||
if keyCount > 0 {
|
||||
err := parsePluginConfig(config, &pluginConfig)
|
||||
if err != nil {
|
||||
globalConfigError = err
|
||||
} else {
|
||||
m.globalConfig = pluginConfig
|
||||
m.hasGlobalConfig = true
|
||||
}
|
||||
}
|
||||
if len(rules) == 0 {
|
||||
if m.hasGlobalConfig {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("parse config failed, no valid rules; global config parse error:%v", globalConfigError)
|
||||
}
|
||||
for _, ruleJson := range rules {
|
||||
var (
|
||||
rule RuleConfig[PluginConfig]
|
||||
err error
|
||||
)
|
||||
if parseOverrideConfig != nil {
|
||||
err = parseOverrideConfig(ruleJson, m.globalConfig, &rule.config)
|
||||
} else {
|
||||
err = parsePluginConfig(ruleJson, &rule.config)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rule.routes = m.parseRouteMatchConfig(ruleJson)
|
||||
rule.hosts = m.parseHostMatchConfig(ruleJson)
|
||||
rule.services = m.parseServiceMatchConfig(ruleJson)
|
||||
rule.routePrefixs = m.parseRoutePrefixMatchConfig(ruleJson)
|
||||
noRoute := len(rule.routes) == 0
|
||||
noHosts := len(rule.hosts) == 0
|
||||
noService := len(rule.services) == 0
|
||||
noRoutePrefix := len(rule.routePrefixs) == 0
|
||||
if boolToInt(noRoute)+boolToInt(noService)+boolToInt(noHosts)+boolToInt(noRoutePrefix) != 3 {
|
||||
return errors.New("there is only one of '_match_route_', '_match_domain_', '_match_service_' and '_match_route_prefix_' can present in configuration.")
|
||||
}
|
||||
if !noRoute {
|
||||
rule.category = Route
|
||||
} else if !noHosts {
|
||||
rule.category = Host
|
||||
} else if !noService {
|
||||
rule.category = Service
|
||||
} else {
|
||||
rule.category = RoutePrefix
|
||||
}
|
||||
m.ruleConfig = append(m.ruleConfig, rule)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m RuleMatcher[PluginConfig]) parseRouteMatchConfig(config gjson.Result) map[string]struct{} {
|
||||
keys := config.Get(MATCH_ROUTE_KEY).Array()
|
||||
routes := make(map[string]struct{})
|
||||
for _, item := range keys {
|
||||
routeName := item.String()
|
||||
if routeName != "" {
|
||||
routes[routeName] = struct{}{}
|
||||
}
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
func (m RuleMatcher[PluginConfig]) parseRoutePrefixMatchConfig(config gjson.Result) map[string]struct{} {
|
||||
keys := config.Get(MATCH_ROUTE_PREFIX_KEY).Array()
|
||||
routePrefixs := make(map[string]struct{})
|
||||
for _, item := range keys {
|
||||
routePrefix := item.String()
|
||||
if routePrefix != "" {
|
||||
routePrefixs[routePrefix] = struct{}{}
|
||||
}
|
||||
}
|
||||
return routePrefixs
|
||||
}
|
||||
|
||||
func (m RuleMatcher[PluginConfig]) parseServiceMatchConfig(config gjson.Result) map[string]struct{} {
|
||||
keys := config.Get(MATCH_SERVICE_KEY).Array()
|
||||
clusters := make(map[string]struct{})
|
||||
for _, item := range keys {
|
||||
clusterName := item.String()
|
||||
if clusterName != "" {
|
||||
clusters[clusterName] = struct{}{}
|
||||
}
|
||||
}
|
||||
return clusters
|
||||
}
|
||||
|
||||
func (m RuleMatcher[PluginConfig]) parseHostMatchConfig(config gjson.Result) []HostMatcher {
|
||||
keys := config.Get(MATCH_DOMAIN_KEY).Array()
|
||||
var hostMatchers []HostMatcher
|
||||
for _, item := range keys {
|
||||
host := item.String()
|
||||
var hostMatcher HostMatcher
|
||||
if strings.HasPrefix(host, "*") {
|
||||
hostMatcher.matchType = Suffix
|
||||
hostMatcher.host = host[1:]
|
||||
} else if strings.HasSuffix(host, "*") {
|
||||
hostMatcher.matchType = Prefix
|
||||
hostMatcher.host = host[:len(host)-1]
|
||||
} else {
|
||||
hostMatcher.matchType = Exact
|
||||
hostMatcher.host = host
|
||||
}
|
||||
hostMatchers = append(hostMatchers, hostMatcher)
|
||||
}
|
||||
return hostMatchers
|
||||
}
|
||||
|
||||
func stripPortFromHost(reqHost string) string {
|
||||
// Port removing code is inspired by
|
||||
// https://github.com/envoyproxy/envoy/blob/v1.17.0/source/common/http/header_utility.cc#L219
|
||||
portStart := strings.LastIndexByte(reqHost, ':')
|
||||
if portStart != -1 {
|
||||
// According to RFC3986 v6 address is always enclosed in "[]".
|
||||
// section 3.2.2.
|
||||
v6EndIndex := strings.LastIndexByte(reqHost, ']')
|
||||
if v6EndIndex == -1 || v6EndIndex < portStart {
|
||||
if portStart+1 <= len(reqHost) {
|
||||
return reqHost[:portStart]
|
||||
}
|
||||
}
|
||||
}
|
||||
return reqHost
|
||||
}
|
||||
|
||||
func (m RuleMatcher[PluginConfig]) hostMatch(rule RuleConfig[PluginConfig], reqHost string) bool {
|
||||
reqHost = stripPortFromHost(reqHost)
|
||||
for _, hostMatch := range rule.hosts {
|
||||
switch hostMatch.matchType {
|
||||
case Suffix:
|
||||
if strings.HasSuffix(reqHost, hostMatch.host) {
|
||||
return true
|
||||
}
|
||||
case Prefix:
|
||||
if strings.HasPrefix(reqHost, hostMatch.host) {
|
||||
return true
|
||||
}
|
||||
case Exact:
|
||||
if reqHost == hostMatch.host {
|
||||
return true
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m RuleMatcher[PluginConfig]) serviceMatch(rule RuleConfig[PluginConfig], serviceName string) bool {
|
||||
parts := strings.Split(serviceName, "|")
|
||||
if len(parts) != 4 {
|
||||
return false
|
||||
}
|
||||
port := parts[1]
|
||||
fqdn := parts[3]
|
||||
for configServiceName := range rule.services {
|
||||
colonIndex := strings.LastIndexByte(configServiceName, ':')
|
||||
if colonIndex != -1 && fqdn == string(configServiceName[:colonIndex]) && port == string(configServiceName[colonIndex+1:]) {
|
||||
return true
|
||||
} else if fqdn == string(configServiceName) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,438 +0,0 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package matcher
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type customConfig struct {
|
||||
name string
|
||||
age int64
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *customConfig) error {
|
||||
config.name = json.Get("name").String()
|
||||
config.age = json.Get("age").Int()
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHostMatch(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
config RuleConfig[customConfig]
|
||||
host string
|
||||
result bool
|
||||
}{
|
||||
{
|
||||
name: "prefix",
|
||||
config: RuleConfig[customConfig]{
|
||||
hosts: []HostMatcher{
|
||||
{
|
||||
matchType: Prefix,
|
||||
host: "www.",
|
||||
},
|
||||
},
|
||||
},
|
||||
host: "www.test.com",
|
||||
result: true,
|
||||
},
|
||||
{
|
||||
name: "prefix failed",
|
||||
config: RuleConfig[customConfig]{
|
||||
hosts: []HostMatcher{
|
||||
{
|
||||
matchType: Prefix,
|
||||
host: "www.",
|
||||
},
|
||||
},
|
||||
},
|
||||
host: "test.com",
|
||||
result: false,
|
||||
},
|
||||
{
|
||||
name: "suffix",
|
||||
config: RuleConfig[customConfig]{
|
||||
hosts: []HostMatcher{
|
||||
{
|
||||
matchType: Suffix,
|
||||
host: ".example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
host: "www.example.com",
|
||||
result: true,
|
||||
},
|
||||
{
|
||||
name: "suffix failed",
|
||||
config: RuleConfig[customConfig]{
|
||||
hosts: []HostMatcher{
|
||||
{
|
||||
matchType: Suffix,
|
||||
host: ".example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
host: "example.com",
|
||||
result: false,
|
||||
},
|
||||
{
|
||||
name: "exact",
|
||||
config: RuleConfig[customConfig]{
|
||||
hosts: []HostMatcher{
|
||||
{
|
||||
matchType: Exact,
|
||||
host: "www.example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
host: "www.example.com",
|
||||
result: true,
|
||||
},
|
||||
{
|
||||
name: "exact failed",
|
||||
config: RuleConfig[customConfig]{
|
||||
hosts: []HostMatcher{
|
||||
{
|
||||
matchType: Exact,
|
||||
host: "www.example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
host: "example.com",
|
||||
result: false,
|
||||
},
|
||||
{
|
||||
name: "exact port",
|
||||
config: RuleConfig[customConfig]{
|
||||
hosts: []HostMatcher{
|
||||
{
|
||||
matchType: Exact,
|
||||
host: "www.example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
host: "www.example.com:8080",
|
||||
result: true,
|
||||
},
|
||||
{
|
||||
name: "any",
|
||||
config: RuleConfig[customConfig]{
|
||||
hosts: []HostMatcher{
|
||||
{
|
||||
matchType: Suffix,
|
||||
host: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
host: "www.example.com",
|
||||
result: true,
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
var m RuleMatcher[customConfig]
|
||||
assert.Equal(t, c.result, m.hostMatch(c.config, c.host))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceMatch(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
config RuleConfig[customConfig]
|
||||
service string
|
||||
result bool
|
||||
}{
|
||||
{
|
||||
name: "fqdn",
|
||||
config: RuleConfig[customConfig]{
|
||||
services: map[string]struct{}{
|
||||
"qwen.dns": {},
|
||||
},
|
||||
},
|
||||
service: "outbound|443||qwen.dns",
|
||||
result: true,
|
||||
},
|
||||
{
|
||||
name: "fqdn with port",
|
||||
config: RuleConfig[customConfig]{
|
||||
services: map[string]struct{}{
|
||||
"qwen.dns:443": {},
|
||||
},
|
||||
},
|
||||
service: "outbound|443||qwen.dns",
|
||||
result: true,
|
||||
},
|
||||
{
|
||||
name: "not match",
|
||||
config: RuleConfig[customConfig]{
|
||||
services: map[string]struct{}{
|
||||
"moonshot.dns:443": {},
|
||||
},
|
||||
},
|
||||
service: "outbound|443||qwen.dns",
|
||||
result: false,
|
||||
},
|
||||
{
|
||||
name: "error config format",
|
||||
config: RuleConfig[customConfig]{
|
||||
services: map[string]struct{}{
|
||||
"qwen.dns:": {},
|
||||
},
|
||||
},
|
||||
service: "outbound|443||qwen.dns",
|
||||
result: false,
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
var m RuleMatcher[customConfig]
|
||||
assert.Equal(t, c.result, m.serviceMatch(c.config, c.service))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRuleConfig(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
config string
|
||||
errMsg string
|
||||
expected RuleMatcher[customConfig]
|
||||
}{
|
||||
{
|
||||
name: "global config",
|
||||
config: `{"name":"john", "age":18}`,
|
||||
expected: RuleMatcher[customConfig]{
|
||||
globalConfig: customConfig{
|
||||
name: "john",
|
||||
age: 18,
|
||||
},
|
||||
hasGlobalConfig: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "rules config",
|
||||
config: `{"_rules_":[{"_match_domain_":["*.example.com","www.*","*","www.abc.com"],"name":"john", "age":18},{"_match_route_":["test1","test2"],"name":"ann", "age":16},{"_match_service_":["test1.dns","test2.static:8080"],"name":"ann", "age":16},{"_match_route_prefix_":["api1","api2"],"name":"ann", "age":16}]}`,
|
||||
expected: RuleMatcher[customConfig]{
|
||||
ruleConfig: []RuleConfig[customConfig]{
|
||||
{
|
||||
category: Host,
|
||||
hosts: []HostMatcher{
|
||||
{
|
||||
matchType: Suffix,
|
||||
host: ".example.com",
|
||||
},
|
||||
{
|
||||
matchType: Prefix,
|
||||
host: "www.",
|
||||
},
|
||||
{
|
||||
matchType: Suffix,
|
||||
host: "",
|
||||
},
|
||||
{
|
||||
matchType: Exact,
|
||||
host: "www.abc.com",
|
||||
},
|
||||
},
|
||||
routes: map[string]struct{}{},
|
||||
services: map[string]struct{}{},
|
||||
routePrefixs: map[string]struct{}{},
|
||||
config: customConfig{
|
||||
name: "john",
|
||||
age: 18,
|
||||
},
|
||||
},
|
||||
{
|
||||
category: Route,
|
||||
routes: map[string]struct{}{
|
||||
"test1": {},
|
||||
"test2": {},
|
||||
},
|
||||
services: map[string]struct{}{},
|
||||
routePrefixs: map[string]struct{}{},
|
||||
config: customConfig{
|
||||
name: "ann",
|
||||
age: 16,
|
||||
},
|
||||
},
|
||||
{
|
||||
category: Service,
|
||||
routes: map[string]struct{}{},
|
||||
services: map[string]struct{}{
|
||||
"test1.dns": {},
|
||||
"test2.static:8080": {},
|
||||
},
|
||||
routePrefixs: map[string]struct{}{},
|
||||
config: customConfig{
|
||||
name: "ann",
|
||||
age: 16,
|
||||
},
|
||||
},
|
||||
{
|
||||
category: RoutePrefix,
|
||||
routes: map[string]struct{}{},
|
||||
services: map[string]struct{}{},
|
||||
routePrefixs: map[string]struct{}{
|
||||
"api1": {},
|
||||
"api2": {},
|
||||
},
|
||||
config: customConfig{
|
||||
name: "ann",
|
||||
age: 16,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no rule",
|
||||
config: `{"_rules_":[]}`,
|
||||
errMsg: "parse config failed, no valid rules; global config parse error:<nil>",
|
||||
},
|
||||
{
|
||||
name: "invalid rule",
|
||||
config: `{"_rules_":[{"_match_domain_":["*"],"_match_route_":["test"]}]}`,
|
||||
errMsg: "there is only one of '_match_route_', '_match_domain_', '_match_service_' and '_match_route_prefix_' can present in configuration.",
|
||||
},
|
||||
{
|
||||
name: "invalid rule",
|
||||
config: `{"_rules_":[{"_match_domain_":["*"],"_match_service_":["test.dns"]}]}`,
|
||||
errMsg: "there is only one of '_match_route_', '_match_domain_', '_match_service_' and '_match_route_prefix_' can present in configuration.",
|
||||
},
|
||||
{
|
||||
name: "invalid rule",
|
||||
config: `{"_rules_":[{"age":16}]}`,
|
||||
errMsg: "there is only one of '_match_route_', '_match_domain_', '_match_service_' and '_match_route_prefix_' can present in configuration.",
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
var actual RuleMatcher[customConfig]
|
||||
err := actual.ParseRuleConfig(gjson.Parse(c.config), parseConfig, nil)
|
||||
if err != nil {
|
||||
if c.errMsg == "" {
|
||||
t.Errorf("parse failed: %v", err)
|
||||
}
|
||||
if err.Error() != c.errMsg {
|
||||
t.Errorf("expect err: %s, actual err: %s", c.errMsg,
|
||||
err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
assert.Equal(t, c.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type completeConfig struct {
|
||||
// global config
|
||||
consumers []string
|
||||
// rule config
|
||||
allow []string
|
||||
}
|
||||
|
||||
func parseGlobalConfig(json gjson.Result, global *completeConfig) error {
|
||||
if json.Get("consumers").Exists() && json.Get("allow").Exists() {
|
||||
return errors.New("consumers and allow should not be configured at the same level")
|
||||
}
|
||||
|
||||
for _, item := range json.Get("consumers").Array() {
|
||||
global.consumers = append(global.consumers, item.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseOverrideRuleConfig(json gjson.Result, global completeConfig, config *completeConfig) error {
|
||||
if json.Get("consumers").Exists() && json.Get("allow").Exists() {
|
||||
return errors.New("consumers and allow should not be configured at the same level")
|
||||
}
|
||||
|
||||
// override config via global
|
||||
*config = global
|
||||
|
||||
for _, item := range json.Get("allow").Array() {
|
||||
config.allow = append(config.allow, item.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestParseOverrideConfig(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
config string
|
||||
errMsg string
|
||||
expected RuleMatcher[completeConfig]
|
||||
}{
|
||||
{
|
||||
name: "override rule config",
|
||||
config: `{"consumers":["c1","c2","c3"],"_rules_":[{"_match_route_":["r1","r2"],"allow":["c1","c3"]}]}`,
|
||||
expected: RuleMatcher[completeConfig]{
|
||||
ruleConfig: []RuleConfig[completeConfig]{
|
||||
{
|
||||
category: Route,
|
||||
routes: map[string]struct{}{
|
||||
"r1": {},
|
||||
"r2": {},
|
||||
},
|
||||
services: map[string]struct{}{},
|
||||
routePrefixs: map[string]struct{}{},
|
||||
config: completeConfig{
|
||||
consumers: []string{"c1", "c2", "c3"},
|
||||
allow: []string{"c1", "c3"},
|
||||
},
|
||||
},
|
||||
},
|
||||
globalConfig: completeConfig{
|
||||
consumers: []string{"c1", "c2", "c3"},
|
||||
},
|
||||
hasGlobalConfig: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid config",
|
||||
config: `{"consumers":["c1","c2","c3"],"allow":["c1"]}`,
|
||||
errMsg: "parse config failed, no valid rules; global config parse error:consumers and allow should not be configured at the same level",
|
||||
},
|
||||
{
|
||||
name: "invalid config",
|
||||
config: `{"_rules_":[{"_match_route_":["r1","r2"],"consumers":["c1","c2"],"allow":["c1"]}]}`,
|
||||
errMsg: "consumers and allow should not be configured at the same level",
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
var actual RuleMatcher[completeConfig]
|
||||
err := actual.ParseRuleConfig(gjson.Parse(c.config), parseGlobalConfig, parseOverrideRuleConfig)
|
||||
if err != nil {
|
||||
if c.errMsg == "" {
|
||||
t.Errorf("parse failed: %v", err)
|
||||
}
|
||||
if err.Error() != c.errMsg {
|
||||
t.Errorf("expect err: %s, actual err: %s", c.errMsg, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
assert.Equal(t, c.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package matcher
|
||||
|
||||
func boolToInt(b bool) int {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
@@ -4,7 +4,7 @@
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -12,17 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package wrapper
|
||||
package consts
|
||||
|
||||
import (
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
const (
|
||||
ToolSetNameSplitter = "___"
|
||||
)
|
||||
|
||||
func IsResponseFromUpstream() bool {
|
||||
if codeDetails, err := proxywasm.GetProperty([]string{"response", "code_details"}); err == nil {
|
||||
return string(codeDetails) == "via_upstream"
|
||||
} else {
|
||||
proxywasm.LogErrorf("get response code details failed: %v", err)
|
||||
return false
|
||||
}
|
||||
}
|
||||
353
plugins/wasm-go/pkg/mcp/filter/plugin.go
Normal file
353
plugins/wasm-go/pkg/mcp/filter/plugin.go
Normal file
@@ -0,0 +1,353 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMaxBodyBytes uint32 = 100 * 1024 * 1024
|
||||
)
|
||||
|
||||
type HTTPFilterF func(context wrapper.HttpContext, config any, headers [][2]string, body []byte) types.Action
|
||||
|
||||
type ToolCallRequestFilterF func(context wrapper.HttpContext, config any, toolName string, toolArgs gjson.Result, rawBody []byte) types.Action
|
||||
|
||||
type ToolCallResponseFilterF func(context wrapper.HttpContext, config any, isError bool, content gjson.Result, rawBody []byte) types.Action
|
||||
|
||||
type ToolListResponseFilterF func(context wrapper.HttpContext, config any, tools gjson.Result, rawBody []byte) types.Action
|
||||
|
||||
type JsonRpcRequestFilterF func(context wrapper.HttpContext, config any, id utils.JsonRpcID, method string, params gjson.Result, rawBody []byte) types.Action
|
||||
|
||||
type JsonRpcResponseFilterF func(context wrapper.HttpContext, config any, id utils.JsonRpcID, result, error gjson.Result, rawBody []byte) types.Action
|
||||
|
||||
type Context struct {
|
||||
filterName string
|
||||
httpRequestFilter HTTPFilterF
|
||||
httpResponseFilter HTTPFilterF
|
||||
jsonRpcRequestFilter JsonRpcRequestFilterF
|
||||
jsonRpcResponseFilter JsonRpcResponseFilterF
|
||||
toolCallRequestFilter ToolCallRequestFilterF
|
||||
toolCallResponseFilter ToolCallResponseFilterF
|
||||
toolListResponseFilter ToolListResponseFilterF
|
||||
parseFilterConfig ParseFilterConfigF
|
||||
parseFilterRuleOverrideConfig ParseFilterRuleOverrideConfigF
|
||||
}
|
||||
|
||||
type CtxOption interface {
|
||||
Apply(*Context)
|
||||
}
|
||||
|
||||
var globalContext Context
|
||||
|
||||
type ParseFilterConfigF func(configBytes []byte, filterConfig *any) error
|
||||
|
||||
type ParseFilterRuleOverrideConfigF func(configBytes []byte, filterGlobalConfig any, filterConfig *any) error
|
||||
|
||||
type setConfigParserOption struct {
|
||||
f ParseFilterConfigF
|
||||
g ParseFilterRuleOverrideConfigF
|
||||
}
|
||||
|
||||
func SetConfigParser(f ParseFilterConfigF) CtxOption {
|
||||
return &setConfigParserOption{
|
||||
f: f,
|
||||
}
|
||||
}
|
||||
|
||||
func SetConfigOverrideParser(f ParseFilterConfigF, g ParseFilterRuleOverrideConfigF) CtxOption {
|
||||
return &setConfigParserOption{
|
||||
f: f,
|
||||
g: g,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *setConfigParserOption) Apply(ctx *Context) {
|
||||
ctx.parseFilterConfig = o.f
|
||||
ctx.parseFilterRuleOverrideConfig = o.g
|
||||
}
|
||||
|
||||
type filterNameOption struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func FilterName(name string) CtxOption {
|
||||
return &filterNameOption{name}
|
||||
}
|
||||
|
||||
func (o *filterNameOption) Apply(ctx *Context) {
|
||||
ctx.filterName = o.name
|
||||
}
|
||||
|
||||
type setJsonRpcRequestFilterOption struct {
|
||||
f JsonRpcRequestFilterF
|
||||
}
|
||||
|
||||
func SetJsonRpcRequestFilter(f JsonRpcRequestFilterF) CtxOption {
|
||||
return &setJsonRpcRequestFilterOption{f}
|
||||
}
|
||||
|
||||
func (o *setJsonRpcRequestFilterOption) Apply(ctx *Context) {
|
||||
ctx.jsonRpcRequestFilter = o.f
|
||||
}
|
||||
|
||||
type setJsonRpcResponseFilterOption struct {
|
||||
f JsonRpcResponseFilterF
|
||||
}
|
||||
|
||||
func SetJsonRpcResponseFilter(f JsonRpcResponseFilterF) CtxOption {
|
||||
return &setJsonRpcResponseFilterOption{f}
|
||||
}
|
||||
|
||||
func (o *setJsonRpcResponseFilterOption) Apply(ctx *Context) {
|
||||
ctx.jsonRpcResponseFilter = o.f
|
||||
}
|
||||
|
||||
type setFallbackHTTPRequestFilterOption struct {
|
||||
f HTTPFilterF
|
||||
}
|
||||
|
||||
func SetFallbackHTTPRequestFilter(f HTTPFilterF) CtxOption {
|
||||
return &setFallbackHTTPRequestFilterOption{f}
|
||||
}
|
||||
|
||||
func (o *setFallbackHTTPRequestFilterOption) Apply(ctx *Context) {
|
||||
ctx.httpRequestFilter = o.f
|
||||
}
|
||||
|
||||
type setFallbackHTTPResponseFilterOption struct {
|
||||
f HTTPFilterF
|
||||
}
|
||||
|
||||
func SetFallbackHTTPResponseFilter(f HTTPFilterF) CtxOption {
|
||||
return &setFallbackHTTPResponseFilterOption{f}
|
||||
}
|
||||
|
||||
func (o *setFallbackHTTPResponseFilterOption) Apply(ctx *Context) {
|
||||
ctx.httpResponseFilter = o.f
|
||||
}
|
||||
|
||||
type toolCallRequestFilterOption struct {
|
||||
f ToolCallRequestFilterF
|
||||
}
|
||||
|
||||
func SetToolCallRequestFilter(f ToolCallRequestFilterF) CtxOption {
|
||||
return &toolCallRequestFilterOption{f: f}
|
||||
}
|
||||
|
||||
func (o *toolCallRequestFilterOption) Apply(ctx *Context) {
|
||||
ctx.toolCallRequestFilter = o.f
|
||||
}
|
||||
|
||||
type toolCallResponseFilterOption struct {
|
||||
f ToolCallResponseFilterF
|
||||
}
|
||||
|
||||
func SetToolCallResponseFilter(f ToolCallResponseFilterF) CtxOption {
|
||||
return &toolCallResponseFilterOption{f: f}
|
||||
}
|
||||
|
||||
func (o *toolCallResponseFilterOption) Apply(ctx *Context) {
|
||||
ctx.toolCallResponseFilter = o.f
|
||||
}
|
||||
|
||||
type toolListResponseFilterOption struct {
|
||||
f ToolListResponseFilterF
|
||||
}
|
||||
|
||||
func SetToolListResponseFilter(f ToolListResponseFilterF) CtxOption {
|
||||
return &toolListResponseFilterOption{f: f}
|
||||
}
|
||||
|
||||
func (o *toolListResponseFilterOption) Apply(ctx *Context) {
|
||||
ctx.toolListResponseFilter = o.f
|
||||
}
|
||||
|
||||
func Load(options ...CtxOption) {
|
||||
for _, opt := range options {
|
||||
opt.Apply(&globalContext)
|
||||
}
|
||||
}
|
||||
|
||||
func Initialize() {
|
||||
if globalContext.filterName == "" {
|
||||
panic("FilterName not set")
|
||||
}
|
||||
if globalContext.parseFilterConfig == nil {
|
||||
panic("SetConfigParser not set")
|
||||
}
|
||||
var configOption wrapper.CtxOption[mcpFilterConfig]
|
||||
if globalContext.parseFilterRuleOverrideConfig == nil {
|
||||
configOption = wrapper.ParseRawConfig(parseRawConfig)
|
||||
} else {
|
||||
configOption = wrapper.ParseOverrideRawConfig(parseGlobalConfig, parseOverrideConfig)
|
||||
}
|
||||
wrapper.SetCtx(
|
||||
globalContext.filterName,
|
||||
configOption,
|
||||
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
|
||||
wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
|
||||
wrapper.ProcessRequestBody(onHttpRequestBody),
|
||||
wrapper.ProcessResponseBody(onHttpResponseBody),
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
type mcpFilterConfig struct {
|
||||
config any
|
||||
httpRequestHandler HTTPFilterF
|
||||
httpResponseHandler HTTPFilterF
|
||||
jsonRpcRequestHandler utils.JsonRpcRequestHandler
|
||||
jsonRpcResponseHandler utils.JsonRpcResponseHandler
|
||||
}
|
||||
|
||||
func installHandler(config *mcpFilterConfig) {
|
||||
config.httpRequestHandler = globalContext.httpRequestFilter
|
||||
config.httpResponseHandler = globalContext.httpResponseFilter
|
||||
bizConfig := config.config
|
||||
if globalContext.jsonRpcRequestFilter != nil || globalContext.toolCallRequestFilter != nil {
|
||||
config.jsonRpcRequestHandler = func(context wrapper.HttpContext, id utils.JsonRpcID, method string, params gjson.Result, rawBody []byte) types.Action {
|
||||
if globalContext.jsonRpcRequestFilter != nil {
|
||||
ret := globalContext.jsonRpcRequestFilter(context, bizConfig, id, method, params, rawBody)
|
||||
if ret != types.ActionContinue {
|
||||
return ret
|
||||
}
|
||||
}
|
||||
context.SetContext("JSONRPC_METHOD", method)
|
||||
if method == "tools/call" && globalContext.toolCallRequestFilter != nil {
|
||||
toolName := params.Get("name").String()
|
||||
toolArgs := params.Get("arguments")
|
||||
return globalContext.toolCallRequestFilter(context, bizConfig, toolName, toolArgs, rawBody)
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
if globalContext.jsonRpcResponseFilter != nil || globalContext.toolListResponseFilter != nil || globalContext.toolCallResponseFilter != nil {
|
||||
config.jsonRpcResponseHandler = func(context wrapper.HttpContext, id utils.JsonRpcID, result, error gjson.Result, rawBody []byte) types.Action {
|
||||
if globalContext.jsonRpcResponseFilter != nil {
|
||||
ret := globalContext.jsonRpcResponseFilter(context, bizConfig, id, result, error, rawBody)
|
||||
if ret != types.ActionContinue {
|
||||
return ret
|
||||
}
|
||||
}
|
||||
method := context.GetStringContext("JSONRPC_METHOD", "")
|
||||
if method == "tools/list" && globalContext.toolListResponseFilter != nil {
|
||||
return globalContext.toolListResponseFilter(context, bizConfig, result.Get("tools"), rawBody)
|
||||
}
|
||||
if method == "tools/call" && globalContext.toolCallResponseFilter != nil {
|
||||
return globalContext.toolCallResponseFilter(context, bizConfig, result.Get("isError").Bool(), result.Get("content"), rawBody)
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
log.Debugf("installHandler called, config is: %#v", config)
|
||||
}
|
||||
|
||||
func parseRawConfig(configBytes []byte, config *mcpFilterConfig) error {
|
||||
err := globalContext.parseFilterConfig(configBytes, &config.config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
installHandler(config)
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseGlobalConfig(configBytes []byte, config *mcpFilterConfig) error {
|
||||
err := globalContext.parseFilterConfig(configBytes, &config.config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseOverrideConfig(configBytes []byte, global mcpFilterConfig, config *mcpFilterConfig) error {
|
||||
err := globalContext.parseFilterRuleOverrideConfig(configBytes, global.config, &config.config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
installHandler(config)
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config mcpFilterConfig) types.Action {
|
||||
log.Debugf("onHttpRequestHeaders called")
|
||||
if !ctx.HasRequestBody() || (config.httpRequestHandler == nil && config.jsonRpcRequestHandler == nil) {
|
||||
log.Debugf("no request body or no handler, skip reading body")
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
log.Debugf("has request body and handler, read body")
|
||||
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config mcpFilterConfig, body []byte) types.Action {
|
||||
log.Debugf("onHttpRequestBody called, body size: %d", len(body))
|
||||
if !gjson.GetBytes(body, "jsonrpc").Exists() {
|
||||
if config.httpRequestHandler != nil {
|
||||
log.Debugf("body is not jsonrpc, using httpRequestHandler")
|
||||
headers, err := proxywasm.GetHttpRequestHeaders()
|
||||
if err != nil {
|
||||
log.Errorf("get request headers failed, err:%v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
return config.httpRequestHandler(ctx, config.config, headers, body)
|
||||
}
|
||||
log.Debugf("body is not jsonrpc, but no httpRequestHandler, skip")
|
||||
return types.ActionContinue
|
||||
}
|
||||
log.Debugf("body is jsonrpc, using HandleJsonRpcRequest")
|
||||
return utils.HandleJsonRpcRequest(ctx, body, config.jsonRpcRequestHandler)
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config mcpFilterConfig) types.Action {
|
||||
log.Debugf("onHttpResponseHeaders called")
|
||||
// IsApplicationJson checks if the content type is application/json, so we can skip reading the body if it's application/octet-stream
|
||||
if !ctx.HasResponseBody() || !wrapper.IsApplicationJson() || (config.httpResponseHandler == nil && config.jsonRpcResponseHandler == nil) {
|
||||
log.Debugf("no response body or no handler, skip reading body")
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
log.Debugf("has response body and handler, read body")
|
||||
ctx.SetResponseBodyBufferLimit(defaultMaxBodyBytes)
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config mcpFilterConfig, body []byte) types.Action {
|
||||
log.Debugf("onHttpResponseBody called, body size: %d", len(body))
|
||||
if !gjson.GetBytes(body, "jsonrpc").Exists() {
|
||||
if config.httpResponseHandler != nil {
|
||||
log.Debugf("body is not jsonrpc, using httpResponseHandler")
|
||||
headers, err := proxywasm.GetHttpResponseHeaders()
|
||||
if err != nil {
|
||||
log.Errorf("get response headers failed, err:%v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
return config.httpResponseHandler(ctx, config.config, headers, body)
|
||||
}
|
||||
log.Debugf("body is not jsonrpc, but no httpResponseHandler, skip")
|
||||
return types.ActionContinue
|
||||
}
|
||||
log.Debugf("body is jsonrpc, using HandleJsonRpcResponse")
|
||||
return utils.HandleJsonRpcResponse(ctx, body, config.jsonRpcResponseHandler)
|
||||
}
|
||||
39
plugins/wasm-go/pkg/mcp/go.mod
Normal file
39
plugins/wasm-go/pkg/mcp/go.mod
Normal file
@@ -0,0 +1,39 @@
|
||||
module github.com/alibaba/higress/plugins/wasm-go/pkg/mcp
|
||||
|
||||
go 1.24.1
|
||||
|
||||
require (
|
||||
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9
|
||||
github.com/invopop/jsonschema v0.13.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
)
|
||||
|
||||
require (
|
||||
dario.cat/mergo v1.0.1 // indirect
|
||||
github.com/Masterminds/goutils v1.1.1 // indirect
|
||||
github.com/Masterminds/semver/v3 v3.3.0 // indirect
|
||||
github.com/Masterminds/sprig/v3 v3.3.0 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/huandu/xstrings v1.5.0 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/mitchellh/copystructure v1.2.0 // indirect
|
||||
github.com/mitchellh/reflectwalk v1.0.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/shopspring/decimal v1.4.0 // indirect
|
||||
github.com/spf13/cast v1.7.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
golang.org/x/crypto v0.26.0 // indirect
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
84
plugins/wasm-go/pkg/mcp/mcp.go
Normal file
84
plugins/wasm-go/pkg/mcp/mcp.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/filter"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/server"
|
||||
)
|
||||
|
||||
var _ server.Server = &MCPServer{}
|
||||
|
||||
// MCPServer implements the Server interface using BaseMCPServer
|
||||
type MCPServer struct {
|
||||
base server.BaseMCPServer
|
||||
}
|
||||
|
||||
// NewMCPServer creates a new MCPServer
|
||||
func NewMCPServer() *MCPServer {
|
||||
return &MCPServer{
|
||||
base: server.NewBaseMCPServer(),
|
||||
}
|
||||
}
|
||||
|
||||
// Clone implements Server interface
|
||||
func (s *MCPServer) Clone() server.Server {
|
||||
return &MCPServer{
|
||||
base: s.base.CloneBase(),
|
||||
}
|
||||
}
|
||||
|
||||
// AddMCPTool implements Server interface
|
||||
func (s *MCPServer) AddMCPTool(name string, tool server.Tool) server.Server {
|
||||
s.base.AddMCPTool(name, tool)
|
||||
return s
|
||||
}
|
||||
|
||||
// GetConfig implements Server interface
|
||||
func (s *MCPServer) GetConfig(v any) {
|
||||
s.base.GetConfig(v)
|
||||
}
|
||||
|
||||
// GetMCPTools implements Server interface
|
||||
func (s *MCPServer) GetMCPTools() map[string]server.Tool {
|
||||
return s.base.GetMCPTools()
|
||||
}
|
||||
|
||||
// SetConfig implements Server interface
|
||||
func (s *MCPServer) SetConfig(config []byte) {
|
||||
s.base.SetConfig(config)
|
||||
}
|
||||
|
||||
// mcp server function
|
||||
var (
|
||||
LoadMCPServer = server.Load
|
||||
|
||||
InitMCPServer = server.Initialize
|
||||
|
||||
AddMCPServer = server.AddMCPServer
|
||||
)
|
||||
|
||||
// mcp filter function
|
||||
var (
|
||||
LoadMCPFilter = filter.Load
|
||||
|
||||
InitMCPFilter = filter.Initialize
|
||||
|
||||
SetConfigParser = filter.SetConfigParser
|
||||
|
||||
SetConfigOverrideParser = filter.SetConfigOverrideParser
|
||||
|
||||
FilterName = filter.FilterName
|
||||
|
||||
SetJsonRpcRequestFilter = filter.SetJsonRpcRequestFilter
|
||||
|
||||
SetJsonRpcResponseFilter = filter.SetJsonRpcResponseFilter
|
||||
|
||||
SetFallbackHTTPRequestFilter = filter.SetFallbackHTTPRequestFilter
|
||||
|
||||
SetFallbackHTTPResponseFilter = filter.SetFallbackHTTPResponseFilter
|
||||
|
||||
SetToolCallRequestFilter = filter.SetToolCallRequestFilter
|
||||
|
||||
SetToolCallResponseFilter = filter.SetToolCallResponseFilter
|
||||
|
||||
SetToolListResponseFilter = filter.SetToolListResponseFilter
|
||||
)
|
||||
232
plugins/wasm-go/pkg/mcp/server/auth_utils.go
Normal file
232
plugins/wasm-go/pkg/mcp/server/auth_utils.go
Normal file
@@ -0,0 +1,232 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
)
|
||||
|
||||
// setOrReplaceHeader sets or replaces a header in the headers slice.
|
||||
// If the header exists (case-insensitive comparison), it replaces the value.
|
||||
// If the header doesn't exist, it appends a new header.
|
||||
func setOrReplaceHeader(headers *[][2]string, key, value string) {
|
||||
lowerKey := strings.ToLower(key)
|
||||
|
||||
// Check if header already exists
|
||||
for i, header := range *headers {
|
||||
if strings.ToLower(header[0]) == lowerKey {
|
||||
// Replace existing header value
|
||||
(*headers)[i][1] = value
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Header doesn't exist, append new one
|
||||
*headers = append(*headers, [2]string{key, value})
|
||||
}
|
||||
|
||||
// SecurityScheme defines a security scheme for the REST API
|
||||
type SecurityScheme struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // http, apiKey
|
||||
Scheme string `json:"scheme,omitempty"` // basic, bearer (for type: http)
|
||||
In string `json:"in,omitempty"` // header, query (for type: apiKey)
|
||||
Name string `json:"name,omitempty"` // Header or query parameter name (for type: apiKey)
|
||||
DefaultCredential string `json:"defaultCredential,omitempty"`
|
||||
}
|
||||
|
||||
// SecurityRequirement specifies a security scheme requirement for a tool
|
||||
type SecurityRequirement struct {
|
||||
ID string `json:"id"` // References a security scheme ID
|
||||
Credential string `json:"credential,omitempty"` // Overrides default credential
|
||||
Passthrough bool `json:"passthrough,omitempty"` // If true, credentials from client request will be passed through
|
||||
}
|
||||
|
||||
// AuthRequestContext holds the data needed for applying security schemes.
|
||||
type AuthRequestContext struct {
|
||||
Method string
|
||||
Headers [][2]string // Direct slice, modifications within applySecurity will update this field in the struct instance
|
||||
ParsedURL *url.URL // Pointer to allow modification (e.g., RawQuery)
|
||||
RequestBody []byte // For future security types that might inspect the body
|
||||
PassthroughCredential string // Credential extracted from client request for passthrough
|
||||
}
|
||||
|
||||
// SecuritySchemeProvider provides access to security schemes
|
||||
type SecuritySchemeProvider interface {
|
||||
GetSecurityScheme(id string) (SecurityScheme, bool)
|
||||
}
|
||||
|
||||
// ExtractAndRemoveIncomingCredential extracts a credential from the current incoming HTTP request
|
||||
// and removes it. It uses global proxywasm functions to access request details.
|
||||
// For query parameters, "removal" is conceptual as we build a new request;
|
||||
// this function primarily extracts the value for potential passthrough.
|
||||
func ExtractAndRemoveIncomingCredential(scheme SecurityScheme) (string, error) {
|
||||
credentialValue := ""
|
||||
var err error
|
||||
|
||||
switch scheme.Type {
|
||||
case "http":
|
||||
authHeader, _ := proxywasm.GetHttpRequestHeader("Authorization") // Error ignored, check content
|
||||
if authHeader == "" {
|
||||
// If no header, it's not an error for extraction if not required, but indicates not found.
|
||||
// For removal, there's nothing to remove.
|
||||
return "", nil // Or a specific "not found" error if scheme implies it must be there.
|
||||
}
|
||||
|
||||
if scheme.Scheme == "bearer" {
|
||||
if !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
return "", fmt.Errorf("incoming Authorization header is not Bearer auth: %s", authHeader)
|
||||
}
|
||||
credentialValue = strings.TrimSpace(authHeader[len("Bearer "):])
|
||||
} else if scheme.Scheme == "basic" {
|
||||
if !strings.HasPrefix(strings.ToLower(authHeader), "basic ") {
|
||||
return "", fmt.Errorf("incoming Authorization header is not Basic auth: %s", authHeader)
|
||||
}
|
||||
credentialValue = strings.TrimSpace(authHeader[len("Basic "):])
|
||||
} else {
|
||||
return "", fmt.Errorf("unsupported http scheme for credential extraction/removal: %s", scheme.Scheme)
|
||||
}
|
||||
proxywasm.RemoveHttpRequestHeader("Authorization")
|
||||
log.Debugf("Extracted and removed Authorization header for incoming %s scheme.", scheme.Scheme)
|
||||
|
||||
case "apiKey":
|
||||
if scheme.In == "header" {
|
||||
if scheme.Name == "" {
|
||||
return "", errors.New("apiKey in header requires a name for the header")
|
||||
}
|
||||
headerValue, _ := proxywasm.GetHttpRequestHeader(scheme.Name) // Error ignored, check content
|
||||
if headerValue == "" {
|
||||
return "", nil // Not found, not necessarily an error for extraction.
|
||||
}
|
||||
credentialValue = headerValue
|
||||
proxywasm.RemoveHttpRequestHeader(scheme.Name)
|
||||
log.Debugf("Extracted and removed %s header for incoming apiKey auth.", scheme.Name)
|
||||
} else if scheme.In == "query" {
|
||||
if scheme.Name == "" {
|
||||
return "", errors.New("apiKey in query requires a name for the query parameter")
|
||||
}
|
||||
pathHeader, _ := proxywasm.GetHttpRequestHeader(":path") // Error ignored, check content
|
||||
if pathHeader == "" {
|
||||
// This case might be an error as :path should generally exist.
|
||||
return "", fmt.Errorf("no :path header found in incoming request for apiKey in query")
|
||||
}
|
||||
|
||||
requestURL, parseErr := url.Parse(pathHeader)
|
||||
if parseErr != nil {
|
||||
return "", fmt.Errorf("failed to parse incoming :path header '%s': %v", pathHeader, parseErr)
|
||||
}
|
||||
|
||||
queryValues := requestURL.Query()
|
||||
apiKeyValue := queryValues.Get(scheme.Name)
|
||||
if apiKeyValue == "" {
|
||||
return "", nil // Not found
|
||||
}
|
||||
credentialValue = apiKeyValue
|
||||
log.Debugf("Extracted %s query parameter from incoming request. Removal from original :path is implicit.", scheme.Name)
|
||||
} else {
|
||||
return "", fmt.Errorf("unsupported apiKey 'in' value: %s", scheme.In)
|
||||
}
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported security scheme type for credential extraction/removal: %s", scheme.Type)
|
||||
}
|
||||
|
||||
return credentialValue, err
|
||||
}
|
||||
|
||||
// ApplySecurity applies the configured security scheme to the request.
|
||||
// It modifies reqCtx.Headers and reqCtx.ParsedURL (specifically RawQuery) in place if necessary.
|
||||
func ApplySecurity(securityConfig SecurityRequirement, provider SecuritySchemeProvider, reqCtx *AuthRequestContext) error {
|
||||
if securityConfig.ID == "" {
|
||||
return nil // No security scheme defined
|
||||
}
|
||||
if reqCtx.ParsedURL == nil {
|
||||
return errors.New("ParsedURL in AuthRequestContext cannot be nil for ApplySecurity")
|
||||
}
|
||||
|
||||
upstreamScheme, schemeOk := provider.GetSecurityScheme(securityConfig.ID)
|
||||
if !schemeOk {
|
||||
return fmt.Errorf("upstream security scheme with id '%s' not found", securityConfig.ID)
|
||||
}
|
||||
|
||||
var credentialToUse string
|
||||
if reqCtx.PassthroughCredential != "" {
|
||||
// Use the passthrough credential value.
|
||||
// The upstreamScheme dictates how this value is formatted and applied.
|
||||
credentialToUse = reqCtx.PassthroughCredential
|
||||
log.Debugf("Using passthrough credential for upstream request with scheme %s.", upstreamScheme.ID)
|
||||
} else {
|
||||
// Use configured credential for the upstream request.
|
||||
credentialToUse = upstreamScheme.DefaultCredential
|
||||
if securityConfig.Credential != "" {
|
||||
credentialToUse = securityConfig.Credential
|
||||
}
|
||||
if credentialToUse == "" {
|
||||
return fmt.Errorf("no credential found or configured for upstream security scheme '%s'", upstreamScheme.ID)
|
||||
}
|
||||
log.Debugf("Using configured credential for upstream request with scheme %s.", upstreamScheme.ID)
|
||||
}
|
||||
|
||||
switch upstreamScheme.Type {
|
||||
case "http":
|
||||
authValue := credentialToUse
|
||||
if upstreamScheme.Scheme == "basic" {
|
||||
if !strings.HasPrefix(authValue, "Basic ") {
|
||||
if reqCtx.PassthroughCredential != "" { // Came from passthrough, it's the base64 token part
|
||||
authValue = "Basic " + credentialToUse
|
||||
} else { // Came from config
|
||||
if strings.Contains(credentialToUse, ":") { // Assumed to be "user:pass"
|
||||
authValue = "Basic " + base64.StdEncoding.EncodeToString([]byte(credentialToUse))
|
||||
} else { // Assumed to be already base64 encoded string (token part)
|
||||
authValue = "Basic " + credentialToUse
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if upstreamScheme.Scheme == "bearer" {
|
||||
// Passthrough for Bearer gives the token part. Configured credential is the token.
|
||||
if !strings.HasPrefix(authValue, "Bearer ") {
|
||||
authValue = "Bearer " + credentialToUse
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("unsupported http scheme type for upstream: %s", upstreamScheme.Scheme)
|
||||
}
|
||||
setOrReplaceHeader(&reqCtx.Headers, "Authorization", authValue)
|
||||
case "apiKey":
|
||||
if upstreamScheme.In == "header" {
|
||||
if upstreamScheme.Name == "" {
|
||||
return errors.New("apiKey in header requires a name for the header for upstream")
|
||||
}
|
||||
setOrReplaceHeader(&reqCtx.Headers, upstreamScheme.Name, credentialToUse)
|
||||
} else if upstreamScheme.In == "query" {
|
||||
if upstreamScheme.Name == "" {
|
||||
return errors.New("apiKey in query requires a name for the query parameter for upstream")
|
||||
}
|
||||
queryValues := reqCtx.ParsedURL.Query()
|
||||
queryValues.Set(upstreamScheme.Name, credentialToUse)
|
||||
reqCtx.ParsedURL.RawQuery = queryValues.Encode()
|
||||
} else {
|
||||
return fmt.Errorf("unsupported apiKey 'in' value for upstream: %s", upstreamScheme.In)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported security scheme type: %s", upstreamScheme.Type)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
100
plugins/wasm-go/pkg/mcp/server/base_server.go
Normal file
100
plugins/wasm-go/pkg/mcp/server/base_server.go
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
)
|
||||
|
||||
// BaseMCPServer provides common functionality for MCP servers
|
||||
type BaseMCPServer struct {
|
||||
tools map[string]Tool
|
||||
config []byte
|
||||
}
|
||||
|
||||
// NewBaseMCPServer creates a new BaseMCPServer
|
||||
func NewBaseMCPServer() BaseMCPServer {
|
||||
return BaseMCPServer{
|
||||
tools: make(map[string]Tool),
|
||||
}
|
||||
}
|
||||
|
||||
// AddMCPTool adds a tool to the server
|
||||
func (s *BaseMCPServer) AddMCPTool(name string, tool Tool) Server {
|
||||
if _, exist := s.tools[name]; exist {
|
||||
log.Errorf("Conflict! There is a tool with the same name:%s", name)
|
||||
return s
|
||||
}
|
||||
s.tools[name] = tool
|
||||
return s
|
||||
}
|
||||
|
||||
// GetMCPTools returns all tools registered with the server
|
||||
func (s *BaseMCPServer) GetMCPTools() map[string]Tool {
|
||||
return s.tools
|
||||
}
|
||||
|
||||
// SetConfig sets the server configuration
|
||||
func (s *BaseMCPServer) SetConfig(config []byte) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
// GetConfig gets the server configuration
|
||||
// It first tries to get the config from the request header, then falls back to the stored config
|
||||
func (s *BaseMCPServer) GetConfig(v any) {
|
||||
var config []byte
|
||||
serverConfigBase64, _ := proxywasm.GetHttpRequestHeader("x-higress-mcpserver-config")
|
||||
proxywasm.RemoveHttpRequestHeader("x-higress-mcpserver-config")
|
||||
if serverConfigBase64 != "" {
|
||||
serverConfig, err := base64.StdEncoding.DecodeString(serverConfigBase64)
|
||||
if err != nil {
|
||||
log.Errorf("base64 decode mcp server config failed:%s, bytes:%s", err, serverConfigBase64)
|
||||
} else {
|
||||
config = serverConfig
|
||||
}
|
||||
log.Infof("parse server config from request, config:%s", serverConfig)
|
||||
} else {
|
||||
config = s.config
|
||||
}
|
||||
if len(config) == 0 {
|
||||
return
|
||||
}
|
||||
err := json.Unmarshal(config, v)
|
||||
if err != nil {
|
||||
log.Errorf("json unmarshal server config failed:%v, config:%s", err, config)
|
||||
}
|
||||
}
|
||||
|
||||
// Clone creates a copy of the server
|
||||
// This method should be overridden by derived types
|
||||
func (s *BaseMCPServer) Clone() Server {
|
||||
panic("Clone method must be implemented by derived types")
|
||||
}
|
||||
|
||||
// CloneBase creates a copy of the base server
|
||||
func (s *BaseMCPServer) CloneBase() BaseMCPServer {
|
||||
newServer := BaseMCPServer{
|
||||
tools: make(map[string]Tool),
|
||||
config: s.config,
|
||||
}
|
||||
for k, v := range s.tools {
|
||||
newServer.tools[k] = v
|
||||
}
|
||||
return newServer
|
||||
}
|
||||
127
plugins/wasm-go/pkg/mcp/server/composed_server.go
Normal file
127
plugins/wasm-go/pkg/mcp/server/composed_server.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/consts"
|
||||
)
|
||||
|
||||
// ComposedMCPServer represents a server composed of tools from other servers.
|
||||
type ComposedMCPServer struct {
|
||||
name string // Name of the composed server (from toolSet.name)
|
||||
serverTools []ServerToolConfig // Configuration of which tools to include
|
||||
registry *GlobalToolRegistry // Reference to the global tool registry
|
||||
config []byte // Configuration for the composed server itself (if any)
|
||||
}
|
||||
|
||||
// NewComposedMCPServer creates a new ComposedMCPServer.
|
||||
func NewComposedMCPServer(name string, serverToolsConfig []ServerToolConfig, registry *GlobalToolRegistry) *ComposedMCPServer {
|
||||
return &ComposedMCPServer{
|
||||
name: name,
|
||||
serverTools: serverToolsConfig,
|
||||
registry: registry,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName returns the name of the composed server.
|
||||
func (cs *ComposedMCPServer) GetName() string {
|
||||
return cs.name
|
||||
}
|
||||
|
||||
// AddMCPTool for ComposedMCPServer is a no-op as tools are defined by toolSet.
|
||||
func (cs *ComposedMCPServer) AddMCPTool(name string, tool Tool) Server {
|
||||
log.Warnf("AddMCPTool called on ComposedMCPServer '%s'; this is a no-op.", cs.name)
|
||||
return cs
|
||||
}
|
||||
|
||||
// GetMCPTools constructs and returns the map of tools exposed by this composed server.
|
||||
// The tool names are prefixed with their original server name, e.g., "${originalServer}___${toolName}".
|
||||
// The Tool instances are DescriptiveTool, only providing Description and InputSchema.
|
||||
func (cs *ComposedMCPServer) GetMCPTools() map[string]Tool {
|
||||
composedTools := make(map[string]Tool)
|
||||
for _, stc := range cs.serverTools {
|
||||
originalServerName := stc.ServerName
|
||||
for _, originalToolName := range stc.Tools {
|
||||
toolInfo, found := cs.registry.GetToolInfo(originalServerName, originalToolName)
|
||||
if !found {
|
||||
log.Warnf("Tool %s/%s not found in global registry for composed server %s", originalServerName, originalToolName, cs.name)
|
||||
continue
|
||||
}
|
||||
|
||||
composedToolName := fmt.Sprintf("%s%s%s", originalServerName, consts.ToolSetNameSplitter, originalToolName)
|
||||
composedTools[composedToolName] = &DescriptiveTool{
|
||||
description: toolInfo.Description,
|
||||
inputSchema: toolInfo.InputSchema,
|
||||
outputSchema: toolInfo.OutputSchema, // New field for MCP Protocol Version 2025-06-18
|
||||
}
|
||||
}
|
||||
}
|
||||
return composedTools
|
||||
}
|
||||
|
||||
// SetConfig sets the configuration for the composed server itself.
|
||||
func (cs *ComposedMCPServer) SetConfig(config []byte) {
|
||||
cs.config = config
|
||||
}
|
||||
|
||||
// GetConfig retrieves the configuration of the composed server itself.
|
||||
func (cs *ComposedMCPServer) GetConfig(v any) {
|
||||
if len(cs.config) == 0 {
|
||||
return
|
||||
}
|
||||
if ptrBytes, ok := v.(*[]byte); ok {
|
||||
*ptrBytes = cs.config
|
||||
} else {
|
||||
// If you need to unmarshal to a struct, you'd do it here.
|
||||
// For now, keeping it simple as per previous discussions.
|
||||
log.Warnf("ComposedMCPServer.GetConfig called with unhandled type for v. Config not set.")
|
||||
}
|
||||
}
|
||||
|
||||
// Clone creates a new instance of the ComposedMCPServer with the same configuration.
|
||||
func (cs *ComposedMCPServer) Clone() Server {
|
||||
cloned := NewComposedMCPServer(cs.name, cs.serverTools, cs.registry)
|
||||
cloned.SetConfig(cs.config)
|
||||
return cloned
|
||||
}
|
||||
|
||||
// DescriptiveTool is a placeholder Tool implementation for ComposedMCPServer.
|
||||
// Its Call and Create methods should never be invoked.
|
||||
type DescriptiveTool struct {
|
||||
description string
|
||||
inputSchema map[string]any
|
||||
outputSchema map[string]any // New field for MCP Protocol Version 2025-06-18
|
||||
}
|
||||
|
||||
// Create for DescriptiveTool should not be called.
|
||||
func (dt *DescriptiveTool) Create(params []byte) Tool {
|
||||
log.Errorf("DescriptiveTool.Create called for tool used in ComposedMCPServer. This should not happen.")
|
||||
// Return a new instance to fulfill the interface, though it's an error state.
|
||||
return &DescriptiveTool{
|
||||
description: dt.description,
|
||||
inputSchema: dt.inputSchema,
|
||||
outputSchema: dt.outputSchema,
|
||||
}
|
||||
}
|
||||
|
||||
// Call for DescriptiveTool should not be called.
|
||||
func (dt *DescriptiveTool) Call(httpCtx HttpContext, server Server) error {
|
||||
log.Errorf("DescriptiveTool.Call called for tool used in ComposedMCPServer. This should not happen.")
|
||||
return fmt.Errorf("DescriptiveTool.Call should not be invoked on a ComposedMCPServer's tool")
|
||||
}
|
||||
|
||||
// Description returns the tool's description.
|
||||
func (dt *DescriptiveTool) Description() string {
|
||||
return dt.description
|
||||
}
|
||||
|
||||
// InputSchema returns the tool's input schema.
|
||||
func (dt *DescriptiveTool) InputSchema() map[string]any {
|
||||
return dt.inputSchema
|
||||
}
|
||||
|
||||
// OutputSchema returns the tool's output schema (MCP Protocol Version 2025-06-18).
|
||||
func (dt *DescriptiveTool) OutputSchema() map[string]any {
|
||||
return dt.outputSchema
|
||||
}
|
||||
328
plugins/wasm-go/pkg/mcp/server/config_validator_test.go
Normal file
328
plugins/wasm-go/pkg/mcp/server/config_validator_test.go
Normal file
@@ -0,0 +1,328 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// testLogger is a mock logger for testing to prevent panics
|
||||
type testLogger struct{}
|
||||
|
||||
func (l *testLogger) Trace(msg string) { fmt.Fprintf(os.Stderr, "[TRACE] %s\n", msg) }
|
||||
func (l *testLogger) Tracef(format string, args ...interface{}) {
|
||||
fmt.Fprintf(os.Stderr, "[TRACE] "+format+"\n", args...)
|
||||
}
|
||||
func (l *testLogger) Debug(msg string) { fmt.Fprintf(os.Stderr, "[DEBUG] %s\n", msg) }
|
||||
func (l *testLogger) Debugf(format string, args ...interface{}) {
|
||||
fmt.Fprintf(os.Stderr, "[DEBUG] "+format+"\n", args...)
|
||||
}
|
||||
func (l *testLogger) Info(msg string) { fmt.Fprintf(os.Stderr, "[INFO] %s\n", msg) }
|
||||
func (l *testLogger) Infof(format string, args ...interface{}) {
|
||||
fmt.Fprintf(os.Stderr, "[INFO] "+format+"\n", args...)
|
||||
}
|
||||
func (l *testLogger) Warn(msg string) { fmt.Fprintf(os.Stderr, "[WARN] %s\n", msg) }
|
||||
func (l *testLogger) Warnf(format string, args ...interface{}) {
|
||||
fmt.Fprintf(os.Stderr, "[WARN] "+format+"\n", args...)
|
||||
}
|
||||
func (l *testLogger) Error(msg string) { fmt.Fprintf(os.Stderr, "[ERROR] %s\n", msg) }
|
||||
func (l *testLogger) Errorf(format string, args ...interface{}) {
|
||||
fmt.Fprintf(os.Stderr, "[ERROR] "+format+"\n", args...)
|
||||
}
|
||||
func (l *testLogger) Critical(msg string) { fmt.Fprintf(os.Stderr, "[CRITICAL] %s\n", msg) }
|
||||
func (l *testLogger) Criticalf(format string, args ...interface{}) {
|
||||
fmt.Fprintf(os.Stderr, "[CRITICAL] "+format+"\n", args...)
|
||||
}
|
||||
func (l *testLogger) ResetID(pluginID string) {}
|
||||
|
||||
func init() {
|
||||
// Set a custom logger for testing to prevent panics
|
||||
log.SetPluginLog(&testLogger{})
|
||||
}
|
||||
|
||||
// TestMcpProxyConfigValidation tests configuration validation for mcp-proxy servers
|
||||
func TestMcpProxyConfigValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config string
|
||||
shouldErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid basic proxy config",
|
||||
config: `{
|
||||
"server": {
|
||||
"name": "test-proxy",
|
||||
"type": "mcp-proxy",
|
||||
"transport": "http",
|
||||
"mcpServerURL": "http://backend.example.com/mcp",
|
||||
"timeout": 5000
|
||||
},
|
||||
"tools": [
|
||||
{
|
||||
"name": "test-tool",
|
||||
"description": "Test tool",
|
||||
"args": [
|
||||
{
|
||||
"name": "input",
|
||||
"description": "Input parameter",
|
||||
"type": "string",
|
||||
"required": true
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "proxy config with security schemes",
|
||||
config: `{
|
||||
"server": {
|
||||
"name": "secure-proxy",
|
||||
"type": "mcp-proxy",
|
||||
"transport": "http",
|
||||
"mcpServerURL": "https://secure.example.com/mcp",
|
||||
"timeout": 8000,
|
||||
"securitySchemes": [
|
||||
{
|
||||
"id": "ApiKeyAuth",
|
||||
"type": "apiKey",
|
||||
"in": "header",
|
||||
"name": "X-API-Key",
|
||||
"defaultCredential": "test-key"
|
||||
}
|
||||
]
|
||||
},
|
||||
"tools": [
|
||||
{
|
||||
"name": "secure-tool",
|
||||
"description": "Secure tool",
|
||||
"args": [
|
||||
{
|
||||
"name": "data",
|
||||
"description": "Data parameter",
|
||||
"type": "object",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"requestTemplate": {
|
||||
"security": {
|
||||
"id": "ApiKeyAuth"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing mcpServerURL should fail",
|
||||
config: `{
|
||||
"server": {
|
||||
"name": "invalid-proxy",
|
||||
"type": "mcp-proxy",
|
||||
"transport": "http",
|
||||
"timeout": 5000
|
||||
},
|
||||
"tools": [
|
||||
{
|
||||
"name": "test-tool",
|
||||
"description": "Test tool",
|
||||
"args": []
|
||||
}
|
||||
]
|
||||
}`,
|
||||
shouldErr: true,
|
||||
errMsg: "mcpServerURL is required",
|
||||
},
|
||||
{
|
||||
name: "invalid server type should use default REST handling",
|
||||
config: `{
|
||||
"server": {
|
||||
"name": "rest-server",
|
||||
"type": "rest-api"
|
||||
},
|
||||
"tools": [
|
||||
{
|
||||
"name": "rest-tool",
|
||||
"description": "REST tool",
|
||||
"args": [],
|
||||
"requestTemplate": {
|
||||
"url": "http://example.com/api",
|
||||
"method": "GET"
|
||||
},
|
||||
"responseTemplate": {
|
||||
"body": "$.result"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`,
|
||||
shouldErr: false, // Should fall back to REST server logic
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
configJson := gjson.Parse(tt.config)
|
||||
config := &McpServerConfig{}
|
||||
|
||||
// Create validation options (similar to validator package)
|
||||
toolRegistry := &GlobalToolRegistry{}
|
||||
toolRegistry.Initialize()
|
||||
|
||||
opts := &ConfigOptions{
|
||||
Servers: make(map[string]Server),
|
||||
ToolRegistry: toolRegistry,
|
||||
SkipPreRegisteredServers: true,
|
||||
}
|
||||
|
||||
err := ParseConfigCore(configJson, config, opts)
|
||||
|
||||
if tt.shouldErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, config)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSecuritySchemeValidation tests security scheme configuration validation
|
||||
func TestSecuritySchemeValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scheme SecurityScheme
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid API key scheme",
|
||||
scheme: SecurityScheme{
|
||||
ID: "ApiKeyAuth",
|
||||
Type: "apiKey",
|
||||
In: "header",
|
||||
Name: "X-API-Key",
|
||||
},
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid HTTP bearer scheme",
|
||||
scheme: SecurityScheme{
|
||||
ID: "BearerAuth",
|
||||
Type: "http",
|
||||
Scheme: "bearer",
|
||||
},
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid scheme - missing ID",
|
||||
scheme: SecurityScheme{
|
||||
Type: "apiKey",
|
||||
In: "header",
|
||||
Name: "X-API-Key",
|
||||
},
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid scheme - missing Name for apiKey",
|
||||
scheme: SecurityScheme{
|
||||
ID: "ApiKeyAuth",
|
||||
Type: "apiKey",
|
||||
In: "header",
|
||||
},
|
||||
shouldErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// This will test the validation logic once SecurityScheme validation is implemented
|
||||
err := ValidateSecurityScheme(tt.scheme)
|
||||
|
||||
if tt.shouldErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolConfigValidation tests tool configuration validation
|
||||
func TestToolConfigValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolCfg McpProxyToolConfig
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid tool config",
|
||||
toolCfg: McpProxyToolConfig{
|
||||
Name: "valid-tool",
|
||||
Description: "A valid tool",
|
||||
Args: []ToolArg{
|
||||
{
|
||||
Name: "param1",
|
||||
Description: "Parameter 1",
|
||||
Type: "string",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid tool - missing name",
|
||||
toolCfg: McpProxyToolConfig{
|
||||
Description: "Tool without name",
|
||||
Args: []ToolArg{},
|
||||
},
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid tool - empty description",
|
||||
toolCfg: McpProxyToolConfig{
|
||||
Name: "tool-no-desc",
|
||||
Description: "",
|
||||
Args: []ToolArg{},
|
||||
},
|
||||
shouldErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateToolConfig(tt.toolCfg)
|
||||
|
||||
if tt.shouldErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// These validation functions are now implemented in proxy_server.go
|
||||
817
plugins/wasm-go/pkg/mcp/server/plugin.go
Normal file
817
plugins/wasm-go/pkg/mcp/server/plugin.go
Normal file
@@ -0,0 +1,817 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/invopop/jsonschema"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMaxBodyBytes uint32 = 100 * 1024 * 1024
|
||||
GlobalToolRegistryKey = "GlobalToolRegistry"
|
||||
)
|
||||
|
||||
// SupportedMCPVersions contains all supported MCP protocol versions
|
||||
var SupportedMCPVersions = []string{"2024-11-05", "2025-03-26", "2025-06-18"}
|
||||
|
||||
// validateURL validates that the given string is a valid URL
|
||||
func validateURL(urlStr string) error {
|
||||
if urlStr == "" {
|
||||
return errors.New("url cannot be empty")
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL format: %v", err)
|
||||
}
|
||||
|
||||
// Allow both full URLs (with scheme and host) and path-only URLs
|
||||
// Path-only URLs will be resolved against the cluster's base URL
|
||||
if parsedURL.Scheme != "" {
|
||||
// If scheme is provided, host must also be provided
|
||||
if parsedURL.Host == "" {
|
||||
return errors.New("url with scheme must include a host")
|
||||
}
|
||||
|
||||
// Only allow http and https schemes for security
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
return fmt.Errorf("unsupported URL scheme '%s', only http and https are allowed", parsedURL.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupMcpProxyServer creates and configures an MCP proxy server
|
||||
func setupMcpProxyServer(serverName string, serverJson gjson.Result, serverConfigJsonForInstance string) (*McpProxyServer, error) {
|
||||
proxyServer := NewMcpProxyServer(serverName)
|
||||
proxyServer.SetConfig([]byte(serverConfigJsonForInstance))
|
||||
|
||||
// Parse and validate transport (required for mcp-proxy)
|
||||
transportStr := serverJson.Get("transport").String()
|
||||
if transportStr == "" {
|
||||
return nil, errors.New("transport field is required for mcp-proxy server type")
|
||||
}
|
||||
transport := TransportProtocol(transportStr)
|
||||
if transport != TransportHTTP && transport != TransportSSE {
|
||||
return nil, fmt.Errorf("invalid transport value: %s, must be 'http' or 'sse'", transportStr)
|
||||
}
|
||||
proxyServer.SetTransport(transport)
|
||||
|
||||
// Parse and validate mcpServerURL (required for mcp-proxy)
|
||||
mcpServerURL := serverJson.Get("mcpServerURL").String()
|
||||
if mcpServerURL == "" {
|
||||
return nil, errors.New("mcpServerURL is required for mcp-proxy server type")
|
||||
}
|
||||
if err := validateURL(mcpServerURL); err != nil {
|
||||
return nil, fmt.Errorf("invalid mcpServerURL: %v", err)
|
||||
}
|
||||
proxyServer.SetMcpServerURL(mcpServerURL)
|
||||
|
||||
// Parse timeout (optional)
|
||||
timeout := serverJson.Get("timeout").Int()
|
||||
if timeout > 0 {
|
||||
proxyServer.SetTimeout(int(timeout))
|
||||
}
|
||||
|
||||
// Parse passthroughAuthHeader (optional, defaults to false)
|
||||
passthroughAuthHeader := serverJson.Get("passthroughAuthHeader").Bool()
|
||||
proxyServer.SetPassthroughAuthHeader(passthroughAuthHeader)
|
||||
|
||||
// Parse security schemes
|
||||
securitySchemesJson := serverJson.Get("securitySchemes")
|
||||
if securitySchemesJson.Exists() {
|
||||
for _, schemeJson := range securitySchemesJson.Array() {
|
||||
var scheme SecurityScheme
|
||||
if err := json.Unmarshal([]byte(schemeJson.Raw), &scheme); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse security scheme config: %v", err)
|
||||
}
|
||||
proxyServer.AddSecurityScheme(scheme)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse default downstream security
|
||||
defaultDownstreamSecurityJson := serverJson.Get("defaultDownstreamSecurity")
|
||||
if defaultDownstreamSecurityJson.Exists() {
|
||||
var defaultDownstreamSecurity SecurityRequirement
|
||||
if err := json.Unmarshal([]byte(defaultDownstreamSecurityJson.Raw), &defaultDownstreamSecurity); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse defaultDownstreamSecurity config: %v", err)
|
||||
}
|
||||
proxyServer.SetDefaultDownstreamSecurity(defaultDownstreamSecurity)
|
||||
}
|
||||
|
||||
// Parse default upstream security
|
||||
defaultUpstreamSecurityJson := serverJson.Get("defaultUpstreamSecurity")
|
||||
if defaultUpstreamSecurityJson.Exists() {
|
||||
var defaultUpstreamSecurity SecurityRequirement
|
||||
if err := json.Unmarshal([]byte(defaultUpstreamSecurityJson.Raw), &defaultUpstreamSecurity); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse defaultUpstreamSecurity config: %v", err)
|
||||
}
|
||||
proxyServer.SetDefaultUpstreamSecurity(defaultUpstreamSecurity)
|
||||
}
|
||||
|
||||
return proxyServer, nil
|
||||
}
|
||||
|
||||
type HttpContext wrapper.HttpContext
|
||||
|
||||
type Context struct {
|
||||
servers map[string]Server
|
||||
}
|
||||
|
||||
type CtxOption interface {
|
||||
Apply(*Context)
|
||||
}
|
||||
|
||||
var globalContext Context
|
||||
|
||||
// ToolInfo stores information about a tool for the global registry.
|
||||
type ToolInfo struct {
|
||||
Name string
|
||||
Description string
|
||||
InputSchema map[string]any
|
||||
OutputSchema map[string]any // New field for MCP Protocol Version 2025-06-18
|
||||
ServerName string // Original server name
|
||||
Tool Tool // The actual tool instance for cloning
|
||||
}
|
||||
|
||||
// GlobalToolRegistry holds all tools from all servers.
|
||||
type GlobalToolRegistry struct {
|
||||
// serverName -> toolName -> toolInfo
|
||||
serverTools map[string]map[string]ToolInfo
|
||||
}
|
||||
|
||||
// Initialize initializes the GlobalToolRegistry
|
||||
func (r *GlobalToolRegistry) Initialize() {
|
||||
r.serverTools = make(map[string]map[string]ToolInfo)
|
||||
}
|
||||
|
||||
// RegisterTool registers a tool into the global registry.
|
||||
func (r *GlobalToolRegistry) RegisterTool(serverName string, toolName string, tool Tool) {
|
||||
if _, ok := r.serverTools[serverName]; !ok {
|
||||
r.serverTools[serverName] = make(map[string]ToolInfo)
|
||||
}
|
||||
toolInfo := ToolInfo{
|
||||
Name: toolName,
|
||||
Description: tool.Description(),
|
||||
InputSchema: tool.InputSchema(),
|
||||
ServerName: serverName,
|
||||
Tool: tool,
|
||||
}
|
||||
// Check if tool implements OutputSchema (MCP Protocol Version 2025-06-18)
|
||||
if toolWithSchema, ok := tool.(ToolWithOutputSchema); ok {
|
||||
toolInfo.OutputSchema = toolWithSchema.OutputSchema()
|
||||
}
|
||||
r.serverTools[serverName][toolName] = toolInfo
|
||||
log.Debugf("Registered tool %s/%s", serverName, toolName)
|
||||
}
|
||||
|
||||
// GetToolInfo retrieves tool information from the global registry.
|
||||
func (r *GlobalToolRegistry) GetToolInfo(serverName string, toolName string) (ToolInfo, bool) {
|
||||
if serverTools, ok := r.serverTools[serverName]; ok {
|
||||
toolInfo, found := serverTools[toolName]
|
||||
return toolInfo, found
|
||||
}
|
||||
return ToolInfo{}, false
|
||||
}
|
||||
|
||||
func onPluginStartOrReload(context wrapper.PluginContext) error {
|
||||
toolRegistry := &GlobalToolRegistry{}
|
||||
toolRegistry.Initialize()
|
||||
context.SetContext(GlobalToolRegistryKey, toolRegistry)
|
||||
context.EnableRuleLevelConfigIsolation()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetServer retrieves a server instance from the global context.
|
||||
// This is needed by ComposedMCPServer to get original server instances.
|
||||
func GetServerFromGlobalContext(serverName string) (Server, bool) {
|
||||
server, exist := globalContext.servers[serverName]
|
||||
return server, exist
|
||||
}
|
||||
|
||||
type Server interface {
|
||||
AddMCPTool(name string, tool Tool) Server
|
||||
GetMCPTools() map[string]Tool // For single server, returns its tools. For composed, returns composed tools.
|
||||
SetConfig(config []byte)
|
||||
GetConfig(v any)
|
||||
Clone() Server
|
||||
// GetName() string // Returns the server name - REMOVED
|
||||
}
|
||||
|
||||
type Tool interface {
|
||||
Create(params []byte) Tool
|
||||
Call(httpCtx HttpContext, server Server) error
|
||||
Description() string
|
||||
InputSchema() map[string]any
|
||||
}
|
||||
|
||||
// ToolWithOutputSchema is an optional interface for tools that support output schema
|
||||
// (MCP Protocol Version 2025-06-18). Tools can optionally implement this interface
|
||||
// to provide output schema information.
|
||||
type ToolWithOutputSchema interface {
|
||||
Tool
|
||||
OutputSchema() map[string]any
|
||||
}
|
||||
|
||||
// ToolSetConfig defines the configuration for a toolset.
|
||||
type ToolSetConfig struct {
|
||||
Name string `json:"name"`
|
||||
ServerTools []ServerToolConfig `json:"serverTools"`
|
||||
}
|
||||
|
||||
// ServerToolConfig specifies which tools from a server to include in a toolset.
|
||||
type ServerToolConfig struct {
|
||||
ServerName string `json:"serverName"`
|
||||
Tools []string `json:"tools"`
|
||||
}
|
||||
|
||||
// ConfigOptions contains the dependencies needed for config parsing
|
||||
type ConfigOptions struct {
|
||||
Servers map[string]Server
|
||||
ToolRegistry *GlobalToolRegistry
|
||||
// Skip validation for pre-registered Go-based servers
|
||||
SkipPreRegisteredServers bool
|
||||
}
|
||||
|
||||
type McpServerConfig struct {
|
||||
serverName string // Store the server name directly
|
||||
server Server // Can be a single server or a composed server
|
||||
methodHandlers utils.MethodHandlers
|
||||
toolSet *ToolSetConfig // Parsed toolset configuration
|
||||
isComposed bool
|
||||
}
|
||||
|
||||
// GetServerName returns the server name for external access
|
||||
func (c *McpServerConfig) GetServerName() string {
|
||||
return c.serverName
|
||||
}
|
||||
|
||||
// GetIsComposed returns whether this is a composed server for external access
|
||||
func (c *McpServerConfig) GetIsComposed() bool {
|
||||
return c.isComposed
|
||||
}
|
||||
|
||||
// computeEffectiveAllowTools computes the effective allowTools by taking the intersection
|
||||
// of config allowTools and request header allowTools.
|
||||
// Returns nil if no restrictions (allow all), otherwise returns a pointer to the effective set.
|
||||
func computeEffectiveAllowTools(configAllowTools *map[string]struct{}) *map[string]struct{} {
|
||||
// Get allowTools from request header
|
||||
allowToolsHeaderStr, _ := proxywasm.GetHttpRequestHeader("x-envoy-allow-mcp-tools")
|
||||
proxywasm.RemoveHttpRequestHeader("x-envoy-allow-mcp-tools")
|
||||
// Only consider header as "present" if it has non-empty value
|
||||
// Empty string means header is not set or explicitly empty, both treated as "no restriction"
|
||||
headerExists := allowToolsHeaderStr != ""
|
||||
return computeEffectiveAllowToolsFromHeader(configAllowTools, allowToolsHeaderStr, headerExists)
|
||||
}
|
||||
|
||||
// computeEffectiveAllowToolsFromHeader computes the effective allowTools by taking the intersection
|
||||
// of config allowTools and header allowTools string.
|
||||
// This is useful when the header string is already extracted (e.g., in async callbacks).
|
||||
// Returns nil if no restrictions (allow all), otherwise returns a pointer to the effective set.
|
||||
func computeEffectiveAllowToolsFromHeader(configAllowTools *map[string]struct{}, allowToolsHeaderStr string, headerExists bool) *map[string]struct{} {
|
||||
var allowToolsFromHeader *map[string]struct{}
|
||||
if headerExists {
|
||||
// Header is present (even if empty string), parse it
|
||||
headerMap := make(map[string]struct{})
|
||||
for tool := range strings.SplitSeq(allowToolsHeaderStr, ",") {
|
||||
trimmedTool := strings.TrimSpace(tool)
|
||||
if trimmedTool == "" {
|
||||
continue
|
||||
}
|
||||
headerMap[trimmedTool] = struct{}{}
|
||||
}
|
||||
// Always create pointer even if map is empty, to distinguish from "not configured"
|
||||
allowToolsFromHeader = &headerMap
|
||||
}
|
||||
|
||||
// Compute effective allowTools (intersection of config and header)
|
||||
if configAllowTools == nil && allowToolsFromHeader == nil {
|
||||
// Both not configured, allow all tools
|
||||
return nil
|
||||
} else if configAllowTools == nil {
|
||||
// Only header restrictions
|
||||
return allowToolsFromHeader
|
||||
} else if allowToolsFromHeader == nil {
|
||||
// Only config restrictions
|
||||
return configAllowTools
|
||||
} else {
|
||||
// Both restrictions exist, compute intersection
|
||||
intersection := make(map[string]struct{})
|
||||
for tool := range *configAllowTools {
|
||||
if _, exists := (*allowToolsFromHeader)[tool]; exists {
|
||||
intersection[tool] = struct{}{}
|
||||
}
|
||||
}
|
||||
return &intersection
|
||||
}
|
||||
}
|
||||
|
||||
// parseConfigCore contains the core config parsing logic with dependency injection
|
||||
func parseConfigCore(configJson gjson.Result, config *McpServerConfig, opts *ConfigOptions) error {
|
||||
toolSetJson := configJson.Get("toolSet")
|
||||
serverJson := configJson.Get("server") // This is for single server or REST server definition
|
||||
pluginServerConfigJson := configJson.Get("server.config").Raw // Config for the plugin instance itself, if any.
|
||||
|
||||
// serverConfigJsonForInstance is the config passed to the specific server instance (single or REST)
|
||||
// It's distinct from pluginServerConfigJson which might be for the mcp-server plugin itself.
|
||||
var serverConfigJsonForInstance string
|
||||
|
||||
if toolSetJson.Exists() {
|
||||
config.isComposed = true
|
||||
var tsConfig ToolSetConfig
|
||||
if err := json.Unmarshal([]byte(toolSetJson.Raw), &tsConfig); err != nil {
|
||||
return fmt.Errorf("failed to parse toolSet config: %v", err)
|
||||
}
|
||||
config.toolSet = &tsConfig
|
||||
config.serverName = tsConfig.Name // Use toolSet name as the server name for composed server
|
||||
log.Infof("Parsing toolSet configuration: %s", config.serverName)
|
||||
|
||||
composedServer := NewComposedMCPServer(config.serverName, tsConfig.ServerTools, opts.ToolRegistry)
|
||||
// A composed server itself might have a config block, e.g. for shared settings, though not typical.
|
||||
composedServer.SetConfig([]byte(pluginServerConfigJson))
|
||||
config.server = composedServer
|
||||
} else if serverJson.Exists() {
|
||||
config.isComposed = false
|
||||
config.serverName = serverJson.Get("name").String()
|
||||
if config.serverName == "" {
|
||||
return errors.New("server.name field is missing for single server config")
|
||||
}
|
||||
// This is the config for the specific server being defined (e.g. REST server's own config)
|
||||
serverConfigJsonForInstance = serverJson.Get("config").Raw
|
||||
log.Infof("Parsing single server configuration: %s", config.serverName)
|
||||
|
||||
// Check server type to determine which type of server to create
|
||||
serverType := serverJson.Get("type").String()
|
||||
if serverType == "" {
|
||||
serverType = "rest" // Default to REST server type
|
||||
}
|
||||
|
||||
toolsJson := configJson.Get("tools") // These are REST tools for this server instance or MCP proxy tools
|
||||
|
||||
if serverType == "mcp-proxy" {
|
||||
// Create MCP proxy server
|
||||
proxyServer, err := setupMcpProxyServer(config.serverName, serverJson, serverConfigJsonForInstance)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle tools configuration (optional for MCP proxy)
|
||||
if toolsJson.Exists() && len(toolsJson.Array()) > 0 {
|
||||
for _, toolJson := range toolsJson.Array() {
|
||||
var proxyTool McpProxyToolConfig
|
||||
if err := json.Unmarshal([]byte(toolJson.Raw), &proxyTool); err != nil {
|
||||
return fmt.Errorf("failed to parse proxy tool config: %v", err)
|
||||
}
|
||||
|
||||
if err := proxyServer.AddProxyTool(proxyTool); err != nil {
|
||||
return fmt.Errorf("failed to add proxy tool %s: %v", proxyTool.Name, err)
|
||||
}
|
||||
// Register tool to registry
|
||||
opts.ToolRegistry.RegisterTool(config.serverName, proxyTool.Name, proxyServer.GetMCPTools()[proxyTool.Name])
|
||||
}
|
||||
}
|
||||
// Set the proxy server regardless of whether tools are configured
|
||||
config.server = proxyServer
|
||||
} else if toolsJson.Exists() && len(toolsJson.Array()) > 0 {
|
||||
// Handle REST-to-MCP server (requires tools configuration)
|
||||
// Create REST-to-MCP server (default behavior)
|
||||
restServer := NewRestMCPServer(config.serverName) // Pass the server name
|
||||
restServer.SetConfig([]byte(serverConfigJsonForInstance)) // Pass the server's specific config
|
||||
|
||||
securitySchemesJson := serverJson.Get("securitySchemes")
|
||||
if securitySchemesJson.Exists() {
|
||||
for _, schemeJson := range securitySchemesJson.Array() {
|
||||
var scheme SecurityScheme
|
||||
if err := json.Unmarshal([]byte(schemeJson.Raw), &scheme); err != nil {
|
||||
return fmt.Errorf("failed to parse security scheme config: %v", err)
|
||||
}
|
||||
restServer.AddSecurityScheme(scheme)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse default downstream security
|
||||
defaultDownstreamSecurityJson := serverJson.Get("defaultDownstreamSecurity")
|
||||
if defaultDownstreamSecurityJson.Exists() {
|
||||
var defaultDownstreamSecurity SecurityRequirement
|
||||
if err := json.Unmarshal([]byte(defaultDownstreamSecurityJson.Raw), &defaultDownstreamSecurity); err != nil {
|
||||
return fmt.Errorf("failed to parse defaultDownstreamSecurity config: %v", err)
|
||||
}
|
||||
restServer.SetDefaultDownstreamSecurity(defaultDownstreamSecurity)
|
||||
}
|
||||
|
||||
// Parse default upstream security
|
||||
defaultUpstreamSecurityJson := serverJson.Get("defaultUpstreamSecurity")
|
||||
if defaultUpstreamSecurityJson.Exists() {
|
||||
var defaultUpstreamSecurity SecurityRequirement
|
||||
if err := json.Unmarshal([]byte(defaultUpstreamSecurityJson.Raw), &defaultUpstreamSecurity); err != nil {
|
||||
return fmt.Errorf("failed to parse defaultUpstreamSecurity config: %v", err)
|
||||
}
|
||||
restServer.SetDefaultUpstreamSecurity(defaultUpstreamSecurity)
|
||||
}
|
||||
|
||||
// Parse passthroughAuthHeader (optional, defaults to false)
|
||||
passthroughAuthHeader := serverJson.Get("passthroughAuthHeader").Bool()
|
||||
restServer.SetPassthroughAuthHeader(passthroughAuthHeader)
|
||||
|
||||
for _, toolJson := range toolsJson.Array() {
|
||||
var restTool RestTool
|
||||
if err := json.Unmarshal([]byte(toolJson.Raw), &restTool); err != nil {
|
||||
return fmt.Errorf("failed to parse tool config: %v", err)
|
||||
}
|
||||
|
||||
if err := restServer.AddRestTool(restTool); err != nil {
|
||||
return fmt.Errorf("failed to add tool %s: %v", restTool.Name, err)
|
||||
}
|
||||
// Register tool to registry
|
||||
opts.ToolRegistry.RegisterTool(config.serverName, restTool.Name, restServer.GetMCPTools()[restTool.Name])
|
||||
}
|
||||
config.server = restServer
|
||||
} else {
|
||||
// Logic for pre-registered Go-based servers (non-REST)
|
||||
if opts.SkipPreRegisteredServers {
|
||||
// In validation mode, skip pre-registered servers validation
|
||||
// Just validate the basic structure without actual server instance
|
||||
config.server = nil // Will be handled appropriately in validation context
|
||||
} else {
|
||||
if serverInstance, exist := opts.Servers[config.serverName]; exist {
|
||||
clonedServer := serverInstance.Clone()
|
||||
clonedServer.SetConfig([]byte(serverConfigJsonForInstance)) // Pass the server's specific config
|
||||
config.server = clonedServer
|
||||
// Register tools from this server to registry
|
||||
for toolName, toolInstance := range clonedServer.GetMCPTools() {
|
||||
opts.ToolRegistry.RegisterTool(config.serverName, toolName, toolInstance)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("mcp server type '%s' not registered", config.serverName)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return errors.New("either 'server' or 'toolSet' field must be present in the configuration")
|
||||
}
|
||||
|
||||
// Parse allowTools - this might need adjustment for composed servers
|
||||
// Use pointer to distinguish between "not configured" (nil) and "configured as empty" (empty map)
|
||||
var allowTools *map[string]struct{} // For single server, tool name. For composed, serverName/toolName.
|
||||
allowToolsResult := configJson.Get("allowTools")
|
||||
if allowToolsResult.Exists() {
|
||||
// allowTools is configured, create the map
|
||||
toolsMap := make(map[string]struct{})
|
||||
allowToolsArray := allowToolsResult.Array()
|
||||
for _, toolJson := range allowToolsArray {
|
||||
toolsMap[toolJson.String()] = struct{}{}
|
||||
}
|
||||
allowTools = &toolsMap
|
||||
}
|
||||
// If allowTools is nil, it means not configured (allow all)
|
||||
|
||||
config.methodHandlers = make(utils.MethodHandlers)
|
||||
// Use config.serverName which is now reliably set
|
||||
currentServerNameForHandlers := config.serverName
|
||||
|
||||
config.methodHandlers["ping"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error {
|
||||
utils.OnMCPResponseSuccess(ctx, map[string]any{}, fmt.Sprintf("mcp:%s:ping", currentServerNameForHandlers))
|
||||
return nil
|
||||
}
|
||||
config.methodHandlers["notifications/initialized"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error {
|
||||
proxywasm.SendHttpResponseWithDetail(202, fmt.Sprintf("mcp:%s:notifications/initialized", currentServerNameForHandlers), nil, nil, -1)
|
||||
return nil
|
||||
}
|
||||
config.methodHandlers["notifications/cancelled"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error {
|
||||
proxywasm.SendHttpResponseWithDetail(202, fmt.Sprintf("mcp:%s:notifications/cancelled", currentServerNameForHandlers), nil, nil, -1)
|
||||
return nil
|
||||
}
|
||||
config.methodHandlers["initialize"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error {
|
||||
requestedVersion := params.Get("protocolVersion").String()
|
||||
if requestedVersion == "" {
|
||||
utils.OnMCPResponseError(ctx, errors.New("protocolVersion is required"), utils.ErrInvalidParams, fmt.Sprintf("mcp:%s:initialize:error", currentServerNameForHandlers))
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCP specification compliant version negotiation:
|
||||
// If the server supports the requested protocol version, it MUST respond with the same version.
|
||||
// Otherwise, the server MUST respond with another protocol version it supports.
|
||||
// This SHOULD be the latest version supported by the server.
|
||||
negotiatedVersion := requestedVersion
|
||||
if !slices.Contains(SupportedMCPVersions, requestedVersion) {
|
||||
// Return the latest supported version instead of rejecting the request
|
||||
negotiatedVersion = SupportedMCPVersions[len(SupportedMCPVersions)-1]
|
||||
log.Warnf("Client requested unsupported version %s, responding with latest supported version %s",
|
||||
requestedVersion, negotiatedVersion)
|
||||
}
|
||||
|
||||
utils.OnMCPResponseSuccess(ctx, map[string]any{
|
||||
"protocolVersion": negotiatedVersion,
|
||||
"capabilities": map[string]any{
|
||||
"tools": map[string]any{},
|
||||
},
|
||||
"serverInfo": map[string]any{
|
||||
"name": currentServerNameForHandlers, // Use the actual server name (single or composed)
|
||||
"version": "1.0.0",
|
||||
},
|
||||
}, fmt.Sprintf("mcp:%s:initialize", currentServerNameForHandlers))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Override tools/list and tools/call handlers for MCP proxy servers first
|
||||
if config.server != nil {
|
||||
if proxyServer, ok := config.server.(*McpProxyServer); ok {
|
||||
// Use MCP proxy specific handlers that support ActionPause
|
||||
proxyHandlers := CreateMcpProxyMethodHandlers(proxyServer, allowTools)
|
||||
config.methodHandlers["tools/list"] = proxyHandlers["tools/list"]
|
||||
config.methodHandlers["tools/call"] = proxyHandlers["tools/call"]
|
||||
}
|
||||
}
|
||||
|
||||
// Default tools/list handler for non-proxy servers
|
||||
if config.methodHandlers["tools/list"] == nil {
|
||||
config.methodHandlers["tools/list"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error {
|
||||
var listedTools []map[string]any
|
||||
// GetMCPTools() will return appropriately formatted tools for both single and composed servers
|
||||
allTools := config.server.GetMCPTools() // For composed, keys are "serverName/toolName"
|
||||
|
||||
// Compute effective allowTools using helper function
|
||||
effectiveAllowTools := computeEffectiveAllowTools(allowTools)
|
||||
|
||||
for toolFullName, tool := range allTools {
|
||||
// For composed server, toolFullName is "originalServerName/originalToolName"
|
||||
// For single server, toolFullName is "originalToolName"
|
||||
// The allowTools map should use the same format as toolFullName
|
||||
if effectiveAllowTools != nil {
|
||||
if _, allow := (*effectiveAllowTools)[toolFullName]; !allow {
|
||||
continue
|
||||
}
|
||||
}
|
||||
toolDef := map[string]any{
|
||||
"name": toolFullName,
|
||||
"description": tool.Description(),
|
||||
"inputSchema": tool.InputSchema(),
|
||||
}
|
||||
// Add outputSchema if tool implements ToolWithOutputSchema (MCP Protocol Version 2025-06-18)
|
||||
if toolWithSchema, ok := tool.(ToolWithOutputSchema); ok {
|
||||
if outputSchema := toolWithSchema.OutputSchema(); len(outputSchema) > 0 {
|
||||
toolDef["outputSchema"] = outputSchema
|
||||
}
|
||||
}
|
||||
listedTools = append(listedTools, toolDef)
|
||||
}
|
||||
utils.OnMCPResponseSuccess(ctx, map[string]any{
|
||||
"tools": listedTools,
|
||||
}, fmt.Sprintf("mcp:%s:tools/list", currentServerNameForHandlers))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Default tools/call handler for non-proxy servers
|
||||
if config.methodHandlers["tools/call"] == nil {
|
||||
config.methodHandlers["tools/call"] = func(ctx wrapper.HttpContext, id utils.JsonRpcID, params gjson.Result) error {
|
||||
if config.isComposed {
|
||||
// This endpoint is for a composed server (toolSet).
|
||||
// Actual tool calls should be routed by mcp-router to individual servers.
|
||||
// If a tools/call request reaches here, it's a misconfiguration or unexpected.
|
||||
errMsg := fmt.Sprintf("tools/call is not supported on a composed toolSet endpoint ('%s'). It should be routed by mcp-router to the target server.", currentServerNameForHandlers)
|
||||
log.Errorf(errMsg)
|
||||
utils.OnMCPResponseError(ctx, errors.New(errMsg), utils.ErrMethodNotFound, fmt.Sprintf("mcp:%s:tools/call:not_supported_on_toolset", currentServerNameForHandlers))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Logic for single (non-composed) server
|
||||
toolName := params.Get("name").String() // For single server, this is the direct tool name
|
||||
args := params.Get("arguments")
|
||||
|
||||
// Compute effective allowTools using helper function
|
||||
effectiveAllowTools := computeEffectiveAllowTools(allowTools)
|
||||
|
||||
// Check if tool is allowed
|
||||
if effectiveAllowTools != nil {
|
||||
if _, allow := (*effectiveAllowTools)[toolName]; !allow {
|
||||
utils.OnMCPResponseError(ctx, fmt.Errorf("Tool not allowed: %s", toolName), utils.ErrInvalidParams, fmt.Sprintf("mcp:%s:tools/call:tool_not_allowed", currentServerNameForHandlers))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
proxywasm.SetProperty([]string{"mcp_server_name"}, []byte(currentServerNameForHandlers))
|
||||
proxywasm.SetProperty([]string{"mcp_tool_name"}, []byte(toolName))
|
||||
|
||||
toolToCall, ok := config.server.GetMCPTools()[toolName]
|
||||
if !ok {
|
||||
utils.OnMCPResponseError(ctx, fmt.Errorf("unknown tool: %s", toolName), utils.ErrInvalidParams, fmt.Sprintf("mcp:%s:tools/call:invalid_tool_name", currentServerNameForHandlers))
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("Tool call [%s] on server [%s] with arguments[%s]", toolName, currentServerNameForHandlers, args.Raw)
|
||||
toolInstance := toolToCall.Create([]byte(args.Raw))
|
||||
err := toolInstance.Call(ctx, config.server) // Pass the single server instance
|
||||
if err != nil {
|
||||
utils.OnMCPToolCallError(ctx, err)
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseConfigCore exports the core parsing logic for external use (e.g., validation)
|
||||
func ParseConfigCore(configJson gjson.Result, config *McpServerConfig, opts *ConfigOptions) error {
|
||||
return parseConfigCore(configJson, config, opts)
|
||||
}
|
||||
|
||||
func parseConfig(context wrapper.PluginContext, configJson gjson.Result, config *McpServerConfig) error {
|
||||
registryI := context.GetContext(GlobalToolRegistryKey)
|
||||
if registryI == nil {
|
||||
return errors.New("GlobalToolRegistry not found")
|
||||
}
|
||||
registry, ok := registryI.(*GlobalToolRegistry)
|
||||
if !ok {
|
||||
return errors.New("invalid GlobalToolRegistry")
|
||||
}
|
||||
// Build runtime dependencies using global variables
|
||||
opts := &ConfigOptions{
|
||||
Servers: globalContext.servers,
|
||||
ToolRegistry: registry,
|
||||
}
|
||||
|
||||
// Call the core parsing logic
|
||||
return parseConfigCore(configJson, config, opts)
|
||||
}
|
||||
|
||||
func Load(options ...CtxOption) {
|
||||
for _, opt := range options {
|
||||
opt.Apply(&globalContext)
|
||||
}
|
||||
}
|
||||
|
||||
func Initialize() {
|
||||
if globalContext.servers == nil {
|
||||
panic("At least one mcpserver needs to be added.")
|
||||
}
|
||||
wrapper.SetCtx(
|
||||
"mcp-server",
|
||||
wrapper.PrePluginStartOrReload[McpServerConfig](onPluginStartOrReload),
|
||||
wrapper.ParseConfigWithContext(parseConfig),
|
||||
wrapper.WithLogger[McpServerConfig](&utils.MCPServerLog{}),
|
||||
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBody(onHttpRequestBody),
|
||||
wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
|
||||
wrapper.ProcessStreamingResponseBody(onHttpStreamingResponseBody),
|
||||
wrapper.WithRebuildMaxMemBytes[McpServerConfig](200*1024*1024),
|
||||
)
|
||||
}
|
||||
|
||||
type addMCPServerOption struct {
|
||||
name string
|
||||
server Server
|
||||
}
|
||||
|
||||
func AddMCPServer(name string, server Server) CtxOption {
|
||||
return &addMCPServerOption{
|
||||
name: name,
|
||||
server: server,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *addMCPServerOption) Apply(ctx *Context) {
|
||||
if ctx.servers == nil {
|
||||
ctx.servers = make(map[string]Server)
|
||||
}
|
||||
if _, exist := ctx.servers[o.name]; exist {
|
||||
panic(fmt.Sprintf("Conflict! There is a mcp server with the same name:%s",
|
||||
o.name))
|
||||
}
|
||||
ctx.servers[o.name] = o.server
|
||||
}
|
||||
|
||||
func ToInputSchema(v any) map[string]any {
|
||||
t := reflect.TypeOf(v)
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
inputSchema := jsonschema.Reflect(v).Definitions[t.Name()]
|
||||
inputSchemaBytes, _ := json.Marshal(inputSchema)
|
||||
var result map[string]any
|
||||
json.Unmarshal(inputSchemaBytes, &result)
|
||||
return result
|
||||
}
|
||||
|
||||
func StoreServerState(ctx wrapper.HttpContext, config any) {
|
||||
if utils.IsStatefulSession(ctx) {
|
||||
log.Warnf("There is no session ID, unable to store state.")
|
||||
return
|
||||
}
|
||||
configBytes, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
log.Errorf("Server config marshal failed:%v, config:%s", err, configBytes)
|
||||
return
|
||||
}
|
||||
proxywasm.SetProperty([]string{"mcp_server_config"}, configBytes)
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config McpServerConfig) types.Action {
|
||||
ctx.DisableReroute()
|
||||
ctx.SetRequestBodyBufferLimit(DefaultMaxBodyBytes)
|
||||
ctx.SetResponseBodyBufferLimit(DefaultMaxBodyBytes)
|
||||
|
||||
// Remove accept-encoding header to prevent backend from compressing the response
|
||||
// This ensures we can properly process and modify the response body
|
||||
proxywasm.RemoveHttpRequestHeader("accept-encoding")
|
||||
|
||||
// Parse MCP-Protocol-Version header and store in context
|
||||
// This allows clients to specify the MCP protocol version via HTTP header
|
||||
// instead of only through the JSON-RPC initialize method
|
||||
protocolVersion, _ := proxywasm.GetHttpRequestHeader("MCP-Protocol-Version")
|
||||
if protocolVersion != "" {
|
||||
// Validate the protocol version against supported versions
|
||||
if slices.Contains(SupportedMCPVersions, protocolVersion) {
|
||||
log.Debugf("MCP Protocol Version set from header: %s", protocolVersion)
|
||||
} else {
|
||||
log.Warnf("Unsupported MCP Protocol Version in header: %s", protocolVersion)
|
||||
}
|
||||
|
||||
// Remove the header from the request to prevent it from being forwarded
|
||||
proxywasm.RemoveHttpRequestHeader("MCP-Protocol-Version")
|
||||
}
|
||||
|
||||
if ctx.Method() == "GET" {
|
||||
proxywasm.SendHttpResponseWithDetail(405, "not_support_sse_on_this_endpoint", nil, nil, -1)
|
||||
return types.HeaderStopAllIterationAndWatermark
|
||||
}
|
||||
// Handle DELETE request for session termination (MCP 2025-06-18 spec)
|
||||
// Per spec: "Clients that no longer need a particular session SHOULD send an HTTP DELETE
|
||||
// to the MCP endpoint with the Mcp-Session-Id header, to explicitly terminate the session."
|
||||
// Per spec: "The server MAY respond to this request with HTTP 405 Method Not Allowed,
|
||||
// indicating that the server does not allow clients to terminate sessions."
|
||||
if ctx.Method() == "DELETE" {
|
||||
proxywasm.SendHttpResponseWithDetail(405, "session_termination_not_supported", nil, nil, -1)
|
||||
return types.HeaderStopAllIterationAndWatermark
|
||||
}
|
||||
if !ctx.HasRequestBody() {
|
||||
proxywasm.SendHttpResponseWithDetail(400, "missing_body_in_mcp_request", nil, nil, -1)
|
||||
return types.HeaderStopAllIterationAndWatermark
|
||||
}
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config McpServerConfig, body []byte) types.Action {
|
||||
return utils.HandleJsonRpcMethod(ctx, body, config.methodHandlers)
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config McpServerConfig) types.Action {
|
||||
// Check if this request initiated SSE channel (tools/list or tools/call with SSE transport)
|
||||
// Only these requests need special SSE streaming response processing
|
||||
if ctx.GetContext(CtxSSEProxyState) != nil {
|
||||
// Check if response has a body
|
||||
if ctx.HasResponseBody() {
|
||||
// Pause streaming response for processing
|
||||
// Content-type validation will be done in onHttpStreamingResponseBody
|
||||
ctx.NeedPauseStreamingResponse()
|
||||
return types.HeaderStopIteration
|
||||
} else {
|
||||
// No body, return error
|
||||
utils.OnMCPResponseError(ctx, fmt.Errorf("no response body in SSE response"), utils.ErrInternalError, "mcp-proxy:sse:no_body")
|
||||
return types.HeaderStopAllIterationAndWatermark
|
||||
}
|
||||
}
|
||||
|
||||
// For non-SSE streaming requests, continue normally
|
||||
return types.HeaderContinue
|
||||
}
|
||||
|
||||
func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config McpServerConfig, data []byte, endOfStream bool) []byte {
|
||||
// Check if this request initiated SSE channel (tools/list or tools/call with SSE transport)
|
||||
// Only these requests need special SSE streaming response processing
|
||||
if ctx.GetContext(CtxSSEProxyState) != nil {
|
||||
return handleSSEStreamingResponse(ctx, config, data, endOfStream)
|
||||
}
|
||||
|
||||
// For non-SSE streaming requests, return data as-is
|
||||
return data
|
||||
}
|
||||
429
plugins/wasm-go/pkg/mcp/server/proxy_auth_test.go
Normal file
429
plugins/wasm-go/pkg/mcp/server/proxy_auth_test.go
Normal file
@@ -0,0 +1,429 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestApiKeyAuthentication tests API key authentication forwarding
|
||||
func TestApiKeyAuthentication(t *testing.T) {
|
||||
server := NewMcpProxyServer("auth-test")
|
||||
|
||||
// Configure security scheme
|
||||
scheme := SecurityScheme{
|
||||
ID: "ApiKeyAuth",
|
||||
Type: "apiKey",
|
||||
In: "header",
|
||||
Name: "X-API-Key",
|
||||
DefaultCredential: "default-api-key",
|
||||
}
|
||||
|
||||
server.AddSecurityScheme(scheme)
|
||||
|
||||
// Set server fields directly
|
||||
server.SetMcpServerURL("http://secure-backend.example.com/mcp")
|
||||
server.SetTimeout(5000)
|
||||
|
||||
// Create tool with client-to-gateway and gateway-to-backend security
|
||||
toolConfig := McpProxyToolConfig{
|
||||
Name: "secure_tool",
|
||||
Description: "Tool requiring authentication",
|
||||
Security: SecurityRequirement{
|
||||
ID: "ApiKeyAuth", // Client-to-gateway authentication
|
||||
Passthrough: true, // Extract client credential for backend use
|
||||
},
|
||||
Args: []ToolArg{
|
||||
{
|
||||
Name: "data",
|
||||
Description: "Data parameter",
|
||||
Type: "string",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
OutputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"result": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The result of the operation",
|
||||
},
|
||||
},
|
||||
},
|
||||
RequestTemplate: RequestTemplate{
|
||||
Security: SecurityRequirement{
|
||||
ID: "ApiKeyAuth", // Gateway-to-backend authentication (same scheme for simplicity)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := server.AddProxyTool(toolConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
tool, exists := server.GetMCPTools()["secure_tool"]
|
||||
require.True(t, exists)
|
||||
|
||||
params := map[string]interface{}{
|
||||
"data": "test data",
|
||||
}
|
||||
paramsBytes, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
toolInstance := tool.Create(paramsBytes)
|
||||
require.NotNil(t, toolInstance)
|
||||
|
||||
// Authentication is now handled automatically during tool calls
|
||||
// The actual authentication flow is tested in integration tests
|
||||
}
|
||||
|
||||
// TestBearerAuthentication tests Bearer token authentication
|
||||
func TestBearerAuthentication(t *testing.T) {
|
||||
server := NewMcpProxyServer("bearer-auth-test")
|
||||
|
||||
// Configure Bearer security scheme
|
||||
scheme := SecurityScheme{
|
||||
ID: "BearerAuth",
|
||||
Type: "http",
|
||||
Scheme: "bearer",
|
||||
}
|
||||
|
||||
server.AddSecurityScheme(scheme)
|
||||
|
||||
// Set server fields directly
|
||||
server.SetMcpServerURL("https://secure-backend.example.com/mcp")
|
||||
server.SetTimeout(8000)
|
||||
|
||||
// Create tool with Bearer authentication
|
||||
// Create tool using only gateway-to-backend authentication (no client auth required)
|
||||
toolConfig := McpProxyToolConfig{
|
||||
Name: "bearer_tool",
|
||||
Description: "Tool with Bearer authentication to backend only",
|
||||
Args: []ToolArg{
|
||||
{
|
||||
Name: "query",
|
||||
Description: "Query parameter",
|
||||
Type: "string",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
RequestTemplate: RequestTemplate{
|
||||
Security: SecurityRequirement{
|
||||
ID: "BearerAuth", // Only gateway-to-backend authentication
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := server.AddProxyTool(toolConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
tool, exists := server.GetMCPTools()["bearer_tool"]
|
||||
require.True(t, exists)
|
||||
|
||||
params := map[string]interface{}{
|
||||
"query": "test query",
|
||||
}
|
||||
paramsBytes, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
toolInstance := tool.Create(paramsBytes)
|
||||
require.NotNil(t, toolInstance)
|
||||
|
||||
// Authentication is now handled automatically during tool calls
|
||||
// The actual authentication flow is tested in integration tests
|
||||
|
||||
// Test backward compatibility: this tool uses RequestTemplate.Security (legacy way)
|
||||
// which should still work
|
||||
}
|
||||
|
||||
// TestBasicAuthentication tests Basic authentication
|
||||
func TestBasicAuthentication(t *testing.T) {
|
||||
server := NewMcpProxyServer("basic-auth-test")
|
||||
|
||||
// Configure Basic security scheme
|
||||
scheme := SecurityScheme{
|
||||
ID: "BasicAuth",
|
||||
Type: "http",
|
||||
Scheme: "basic",
|
||||
}
|
||||
|
||||
server.AddSecurityScheme(scheme)
|
||||
|
||||
// Test tool call with Basic authentication
|
||||
toolConfig := McpProxyToolConfig{
|
||||
Name: "basic_tool",
|
||||
Description: "Tool with Basic authentication",
|
||||
Args: []ToolArg{
|
||||
{
|
||||
Name: "resource",
|
||||
Description: "Resource identifier",
|
||||
Type: "string",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
RequestTemplate: RequestTemplate{
|
||||
Security: SecurityRequirement{
|
||||
ID: "BasicAuth",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := server.AddProxyTool(toolConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
tool, exists := server.GetMCPTools()["basic_tool"]
|
||||
require.True(t, exists)
|
||||
|
||||
params := map[string]interface{}{
|
||||
"resource": "test-resource",
|
||||
}
|
||||
paramsBytes, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
toolInstance := tool.Create(paramsBytes)
|
||||
require.NotNil(t, toolInstance)
|
||||
|
||||
// Authentication is now handled automatically during tool calls
|
||||
// The actual authentication flow is tested in integration tests
|
||||
|
||||
// Test OutputSchema functionality (only for tools that have it configured)
|
||||
if toolWithOutputSchema, ok := tool.(ToolWithOutputSchema); ok {
|
||||
outputSchema := toolWithOutputSchema.OutputSchema()
|
||||
if outputSchema != nil {
|
||||
// Only validate if outputSchema is configured
|
||||
assert.Equal(t, "object", outputSchema["type"])
|
||||
properties, hasProperties := outputSchema["properties"].(map[string]any)
|
||||
require.True(t, hasProperties)
|
||||
resultSchema, hasResult := properties["result"].(map[string]any)
|
||||
require.True(t, hasResult)
|
||||
assert.Equal(t, "string", resultSchema["type"])
|
||||
assert.Equal(t, "The result of the operation", resultSchema["description"])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultipleSecuritySchemes tests multiple security schemes in one server
|
||||
func TestMultipleSecuritySchemes(t *testing.T) {
|
||||
server := NewMcpProxyServer("multi-auth-test")
|
||||
|
||||
// Add multiple security schemes
|
||||
schemes := []SecurityScheme{
|
||||
{
|
||||
ID: "ApiKeyAuth",
|
||||
Type: "apiKey",
|
||||
In: "header",
|
||||
Name: "X-API-Key",
|
||||
},
|
||||
{
|
||||
ID: "BearerAuth",
|
||||
Type: "http",
|
||||
Scheme: "bearer",
|
||||
},
|
||||
}
|
||||
|
||||
for _, scheme := range schemes {
|
||||
server.AddSecurityScheme(scheme)
|
||||
}
|
||||
|
||||
// Test that both schemes are available
|
||||
for _, scheme := range schemes {
|
||||
retrievedScheme, exists := server.GetSecurityScheme(scheme.ID)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, scheme.ID, retrievedScheme.ID)
|
||||
assert.Equal(t, scheme.Type, retrievedScheme.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// ProxyAuthContext, RequestTemplate, SecurityConfig and authentication methods
|
||||
// are now implemented in proxy_server.go
|
||||
|
||||
// TestToolsListAuthentication tests authentication configuration for tools/list requests
|
||||
func TestToolsListAuthentication(t *testing.T) {
|
||||
server := NewMcpProxyServer("test-server")
|
||||
|
||||
// Add a security scheme for global authentication
|
||||
scheme := SecurityScheme{
|
||||
ID: "GlobalAuth",
|
||||
Type: "apiKey",
|
||||
In: "header",
|
||||
Name: "X-API-Key",
|
||||
DefaultCredential: "default-global-key",
|
||||
}
|
||||
server.AddSecurityScheme(scheme)
|
||||
|
||||
// Test that we can retrieve the security scheme
|
||||
retrievedScheme, exists := server.GetSecurityScheme("GlobalAuth")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, "GlobalAuth", retrievedScheme.ID)
|
||||
assert.Equal(t, "apiKey", retrievedScheme.Type)
|
||||
assert.Equal(t, "header", retrievedScheme.In)
|
||||
assert.Equal(t, "X-API-Key", retrievedScheme.Name)
|
||||
|
||||
// Test setting default security directly on server
|
||||
defaultDownstreamSecurity := SecurityRequirement{
|
||||
ID: "GlobalAuth",
|
||||
Passthrough: true,
|
||||
}
|
||||
defaultUpstreamSecurity := SecurityRequirement{
|
||||
ID: "GlobalAuth",
|
||||
}
|
||||
|
||||
server.SetDefaultDownstreamSecurity(defaultDownstreamSecurity)
|
||||
server.SetDefaultUpstreamSecurity(defaultUpstreamSecurity)
|
||||
|
||||
// Verify default security settings
|
||||
retrievedDownstream := server.GetDefaultDownstreamSecurity()
|
||||
assert.Equal(t, "GlobalAuth", retrievedDownstream.ID)
|
||||
assert.True(t, retrievedDownstream.Passthrough)
|
||||
|
||||
retrievedUpstream := server.GetDefaultUpstreamSecurity()
|
||||
assert.Equal(t, "GlobalAuth", retrievedUpstream.ID)
|
||||
|
||||
t.Logf("Tools/list authentication configuration test completed successfully")
|
||||
}
|
||||
|
||||
// TestDefaultSecurityFallback tests the fallback mechanism from tool-level to default security
|
||||
func TestDefaultSecurityFallback(t *testing.T) {
|
||||
server := NewMcpProxyServer("test-server")
|
||||
|
||||
// Add security schemes
|
||||
defaultScheme := SecurityScheme{
|
||||
ID: "DefaultAuth",
|
||||
Type: "apiKey",
|
||||
In: "header",
|
||||
Name: "X-Default-Key",
|
||||
DefaultCredential: "default-key",
|
||||
}
|
||||
toolScheme := SecurityScheme{
|
||||
ID: "ToolAuth",
|
||||
Type: "apiKey",
|
||||
In: "header",
|
||||
Name: "X-Tool-Key",
|
||||
DefaultCredential: "tool-key",
|
||||
}
|
||||
server.AddSecurityScheme(defaultScheme)
|
||||
server.AddSecurityScheme(toolScheme)
|
||||
|
||||
// Test tool configuration with tool-level security (should use tool-level, not default)
|
||||
toolConfigWithSecurity := McpProxyToolConfig{
|
||||
Name: "secure_tool",
|
||||
Description: "Tool with its own security",
|
||||
Security: SecurityRequirement{
|
||||
ID: "ToolAuth",
|
||||
Passthrough: true,
|
||||
},
|
||||
RequestTemplate: RequestTemplate{
|
||||
Security: SecurityRequirement{
|
||||
ID: "ToolAuth",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Test tool configuration without tool-level security (should fallback to default)
|
||||
toolConfigWithoutSecurity := McpProxyToolConfig{
|
||||
Name: "fallback_tool",
|
||||
Description: "Tool that falls back to default security",
|
||||
// No Security field configured, should use default
|
||||
RequestTemplate: RequestTemplate{
|
||||
// No Security field configured, should use default
|
||||
},
|
||||
}
|
||||
|
||||
// Set default security directly on server
|
||||
server.SetDefaultDownstreamSecurity(SecurityRequirement{
|
||||
ID: "DefaultAuth",
|
||||
Passthrough: false,
|
||||
})
|
||||
server.SetDefaultUpstreamSecurity(SecurityRequirement{
|
||||
ID: "DefaultAuth",
|
||||
})
|
||||
|
||||
// Set server configuration directly
|
||||
server.SetMcpServerURL("http://backend.example.com")
|
||||
server.SetTimeout(5000)
|
||||
|
||||
// Add tools to server
|
||||
err := server.AddProxyTool(toolConfigWithSecurity)
|
||||
assert.NoError(t, err)
|
||||
err = server.AddProxyTool(toolConfigWithoutSecurity)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify tools were added
|
||||
tools := server.GetMCPTools()
|
||||
assert.Contains(t, tools, "secure_tool")
|
||||
assert.Contains(t, tools, "fallback_tool")
|
||||
|
||||
t.Logf("Default security fallback test completed successfully")
|
||||
}
|
||||
|
||||
// TestURLModificationInAuthentication tests that authentication can modify the URL (e.g., adding query parameters)
|
||||
func TestURLModificationInAuthentication(t *testing.T) {
|
||||
server := NewMcpProxyServer("test-server")
|
||||
|
||||
// Add a security scheme that adds parameters to query (apiKey in query)
|
||||
scheme := SecurityScheme{
|
||||
ID: "QueryApiKey",
|
||||
Type: "apiKey",
|
||||
In: "query",
|
||||
Name: "api_key",
|
||||
DefaultCredential: "test-key-123",
|
||||
}
|
||||
server.AddSecurityScheme(scheme)
|
||||
|
||||
// Verify the security scheme was added correctly
|
||||
retrievedScheme, exists := server.GetSecurityScheme("QueryApiKey")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, "apiKey", retrievedScheme.Type)
|
||||
assert.Equal(t, "query", retrievedScheme.In)
|
||||
assert.Equal(t, "api_key", retrievedScheme.Name)
|
||||
|
||||
t.Logf("URL modification authentication configuration test completed successfully")
|
||||
}
|
||||
|
||||
// TestProxyServerFields tests the server-level field setting and getting
|
||||
func TestProxyServerFields(t *testing.T) {
|
||||
server := NewMcpProxyServer("test-server")
|
||||
|
||||
// Test mcpServerURL
|
||||
testURL := "http://mcp.example.com:8080/mcp"
|
||||
server.SetMcpServerURL(testURL)
|
||||
assert.Equal(t, testURL, server.GetMcpServerURL())
|
||||
|
||||
// Test timeout
|
||||
testTimeout := 10000
|
||||
server.SetTimeout(testTimeout)
|
||||
assert.Equal(t, testTimeout, server.GetTimeout())
|
||||
|
||||
// Test default security settings
|
||||
downstreamSec := SecurityRequirement{
|
||||
ID: "test-downstream",
|
||||
Passthrough: true,
|
||||
}
|
||||
upstreamSec := SecurityRequirement{
|
||||
ID: "test-upstream",
|
||||
}
|
||||
|
||||
server.SetDefaultDownstreamSecurity(downstreamSec)
|
||||
server.SetDefaultUpstreamSecurity(upstreamSec)
|
||||
|
||||
assert.Equal(t, "test-downstream", server.GetDefaultDownstreamSecurity().ID)
|
||||
assert.True(t, server.GetDefaultDownstreamSecurity().Passthrough)
|
||||
assert.Equal(t, "test-upstream", server.GetDefaultUpstreamSecurity().ID)
|
||||
|
||||
t.Logf("Proxy server fields test completed successfully")
|
||||
}
|
||||
302
plugins/wasm-go/pkg/mcp/server/proxy_integration_test.go
Normal file
302
plugins/wasm-go/pkg/mcp/server/proxy_integration_test.go
Normal file
@@ -0,0 +1,302 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// MockHttpContext is a mock implementation for testing - skipping interface implementation for now
|
||||
// Tests that require full HttpContext will be tested in integration tests with real host
|
||||
type MockHttpContext struct {
|
||||
responseBody []byte
|
||||
responseStatus int
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
// TestMcpProtocolInitialization tests the MCP protocol initialization flow
|
||||
func TestMcpProtocolInitialization(t *testing.T) {
|
||||
// Create proxy server
|
||||
server := NewMcpProxyServer("test-proxy")
|
||||
|
||||
// Set server fields directly
|
||||
server.SetMcpServerURL("http://mock-backend.example.com/mcp")
|
||||
server.SetTimeout(5000)
|
||||
|
||||
// Create proxy tool
|
||||
toolConfig := McpProxyToolConfig{
|
||||
Name: "test-tool",
|
||||
Description: "Test tool for initialization",
|
||||
Args: []ToolArg{
|
||||
{
|
||||
Name: "input",
|
||||
Description: "Test input",
|
||||
Type: "string",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := server.AddProxyTool(toolConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
tool, exists := server.GetMCPTools()["test-tool"]
|
||||
require.True(t, exists)
|
||||
|
||||
// Create tool instance with parameters
|
||||
params := map[string]interface{}{
|
||||
"input": "test value",
|
||||
}
|
||||
paramsBytes, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
toolInstance := tool.Create(paramsBytes)
|
||||
require.NotNil(t, toolInstance)
|
||||
|
||||
// Skip HttpContext-dependent test for now - will be tested in integration
|
||||
// mockCtx := &MockHttpContext{}
|
||||
// err = toolInstance.Call(mockCtx, server)
|
||||
// assert.NoError(t, err)
|
||||
|
||||
// Test the tool creation was successful
|
||||
assert.NotNil(t, toolInstance)
|
||||
}
|
||||
|
||||
// TestMcpSessionManagement tests temporary session creation and cleanup
|
||||
func TestMcpSessionManagement(t *testing.T) {
|
||||
_ = NewMcpProxyServer("session-test")
|
||||
|
||||
// Skip session management test until implemented
|
||||
t.Skip("Session management not implemented yet")
|
||||
|
||||
// Test session creation
|
||||
sessionManager := NewMcpSessionManager()
|
||||
sessionID, err := sessionManager.CreateSession("http://backend.example.com/mcp")
|
||||
|
||||
// This will fail until session management is implemented
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, sessionID)
|
||||
|
||||
// Test session retrieval
|
||||
session, exists := sessionManager.GetSession(sessionID)
|
||||
assert.True(t, exists)
|
||||
assert.NotNil(t, session)
|
||||
|
||||
// Test session cleanup
|
||||
sessionManager.CleanupSession(sessionID)
|
||||
_, exists = sessionManager.GetSession(sessionID)
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
// TestMcpProtocolVersionNegotiation tests protocol version handling
|
||||
func TestMcpProtocolVersionNegotiation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedVersion string
|
||||
supportedVersions []string
|
||||
shouldSucceed bool
|
||||
expectedVersion string
|
||||
}{
|
||||
{
|
||||
name: "supported version 2025-03-26",
|
||||
requestedVersion: "2025-03-26",
|
||||
supportedVersions: []string{"2024-11-05", "2025-03-26"},
|
||||
shouldSucceed: true,
|
||||
expectedVersion: "2025-03-26",
|
||||
},
|
||||
{
|
||||
name: "unsupported version",
|
||||
requestedVersion: "2026-01-01",
|
||||
supportedVersions: []string{"2024-11-05", "2025-03-26"},
|
||||
shouldSucceed: false,
|
||||
expectedVersion: "",
|
||||
},
|
||||
{
|
||||
name: "fallback to supported version",
|
||||
requestedVersion: "2025-06-18",
|
||||
supportedVersions: []string{"2024-11-05", "2025-03-26"},
|
||||
shouldSucceed: false,
|
||||
expectedVersion: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Skip until NewMcpVersionNegotiator is implemented
|
||||
t.Skip("Version negotiation not implemented yet")
|
||||
|
||||
negotiator := NewMcpVersionNegotiator(tt.supportedVersions)
|
||||
version, err := negotiator.NegotiateVersion(tt.requestedVersion)
|
||||
|
||||
if tt.shouldSucceed {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedVersion, version)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMcpInitializeRequest tests the initialize request format and handling
|
||||
func TestMcpInitializeRequest(t *testing.T) {
|
||||
_ = NewMcpProxyServer("init-test")
|
||||
|
||||
// Skip until CreateInitializeRequest is implemented
|
||||
t.Skip("MCP protocol initialization not implemented yet")
|
||||
|
||||
// Test initialize request creation
|
||||
initRequest := CreateInitializeRequest()
|
||||
|
||||
assert.Equal(t, "2.0", initRequest.JsonRPC)
|
||||
assert.Equal(t, "initialize", initRequest.Method)
|
||||
assert.NotNil(t, initRequest.Params)
|
||||
|
||||
// Validate client info
|
||||
params := initRequest.Params.(map[string]interface{})
|
||||
clientInfo := params["clientInfo"].(map[string]interface{})
|
||||
assert.Equal(t, "Higress-mcp-proxy", clientInfo["name"])
|
||||
assert.Equal(t, "1.0.0", clientInfo["version"])
|
||||
|
||||
// Test protocol version
|
||||
assert.Equal(t, "2025-03-26", params["protocolVersion"])
|
||||
}
|
||||
|
||||
// TestMcpNotificationsInitialized tests the notifications/initialized message
|
||||
func TestMcpNotificationsInitialized(t *testing.T) {
|
||||
// Skip until CreateInitializedNotification is implemented
|
||||
t.Skip("MCP notifications not implemented yet")
|
||||
|
||||
// Test notifications/initialized request creation
|
||||
notification := CreateInitializedNotification()
|
||||
|
||||
assert.Equal(t, "2.0", notification.JsonRPC)
|
||||
assert.Equal(t, "notifications/initialized", notification.Method)
|
||||
assert.Nil(t, notification.ID) // Notifications don't have IDs
|
||||
}
|
||||
|
||||
// TestMcpErrorHandling tests error response handling and source identification
|
||||
func TestMcpErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errorType string
|
||||
originalError error
|
||||
expectedSource string
|
||||
expectedCode int
|
||||
}{
|
||||
{
|
||||
name: "backend connection error",
|
||||
errorType: "connection",
|
||||
originalError: assert.AnError,
|
||||
expectedSource: "mcp-proxy",
|
||||
expectedCode: -32603,
|
||||
},
|
||||
{
|
||||
name: "backend timeout error",
|
||||
errorType: "timeout",
|
||||
originalError: assert.AnError,
|
||||
expectedSource: "mcp-proxy",
|
||||
expectedCode: -32000,
|
||||
},
|
||||
{
|
||||
name: "protocol version error",
|
||||
errorType: "version",
|
||||
originalError: assert.AnError,
|
||||
expectedSource: "mcp-proxy",
|
||||
expectedCode: -32602,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Skip until CreateMcpErrorResponse is implemented
|
||||
t.Skip("MCP error handling not implemented yet")
|
||||
|
||||
errorResponse := CreateMcpErrorResponse(tt.errorType, tt.originalError, "http://backend.example.com/mcp")
|
||||
|
||||
assert.Equal(t, "2.0", errorResponse.JsonRPC)
|
||||
assert.NotNil(t, errorResponse.Error)
|
||||
assert.Equal(t, tt.expectedCode, errorResponse.Error.Code)
|
||||
assert.Equal(t, tt.expectedSource, errorResponse.Error.Data["source"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper types and functions that will fail until implemented
|
||||
|
||||
type McpSessionManager struct{}
|
||||
|
||||
func NewMcpSessionManager() *McpSessionManager {
|
||||
panic("McpSessionManager not implemented yet")
|
||||
}
|
||||
|
||||
func (m *McpSessionManager) CreateSession(backendURL string) (string, error) {
|
||||
panic("CreateSession not implemented yet")
|
||||
}
|
||||
|
||||
func (m *McpSessionManager) GetSession(sessionID string) (interface{}, bool) {
|
||||
panic("GetSession not implemented yet")
|
||||
}
|
||||
|
||||
func (m *McpSessionManager) CleanupSession(sessionID string) {
|
||||
panic("CleanupSession not implemented yet")
|
||||
}
|
||||
|
||||
type McpVersionNegotiator struct {
|
||||
supportedVersions []string
|
||||
}
|
||||
|
||||
func NewMcpVersionNegotiator(versions []string) *McpVersionNegotiator {
|
||||
panic("McpVersionNegotiator not implemented yet")
|
||||
}
|
||||
|
||||
func (n *McpVersionNegotiator) NegotiateVersion(requested string) (string, error) {
|
||||
panic("NegotiateVersion not implemented yet")
|
||||
}
|
||||
|
||||
type McpRequest struct {
|
||||
JsonRPC string `json:"jsonrpc"`
|
||||
ID interface{} `json:"id,omitempty"`
|
||||
Method string `json:"method"`
|
||||
Params interface{} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type McpErrorResponse struct {
|
||||
JsonRPC string `json:"jsonrpc"`
|
||||
ID interface{} `json:"id,omitempty"`
|
||||
Error *McpError `json:"error"`
|
||||
}
|
||||
|
||||
type McpError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data map[string]interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
func CreateInitializeRequest() *McpRequest {
|
||||
panic("CreateInitializeRequest not implemented yet")
|
||||
}
|
||||
|
||||
func CreateInitializedNotification() *McpRequest {
|
||||
panic("CreateInitializedNotification not implemented yet")
|
||||
}
|
||||
|
||||
func CreateMcpErrorResponse(errorType string, originalError error, backendURL string) *McpErrorResponse {
|
||||
panic("CreateMcpErrorResponse not implemented yet")
|
||||
}
|
||||
500
plugins/wasm-go/pkg/mcp/server/proxy_server.go
Normal file
500
plugins/wasm-go/pkg/mcp/server/proxy_server.go
Normal file
@@ -0,0 +1,500 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
// McpProxyConfig represents the configuration for MCP proxy server
|
||||
// Note: mcpServerURL, timeout, defaultDownstreamSecurity, and defaultUpstreamSecurity
|
||||
// are now direct server fields, not part of this config structure
|
||||
type McpProxyConfig struct {
|
||||
// This structure is kept for any additional server configuration that may be needed in the future
|
||||
// Currently, most configuration is handled as direct server fields
|
||||
}
|
||||
|
||||
// TransportProtocol represents the transport protocol type for MCP proxy
|
||||
type TransportProtocol string
|
||||
|
||||
const (
|
||||
TransportHTTP TransportProtocol = "http" // StreamableHTTP protocol
|
||||
TransportSSE TransportProtocol = "sse" // SSE protocol
|
||||
)
|
||||
|
||||
// ToolArg represents an argument for a proxy tool
|
||||
type ToolArg struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Type string `json:"type"`
|
||||
Required bool `json:"required"`
|
||||
Default interface{} `json:"default,omitempty"`
|
||||
Enum []interface{} `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
// McpProxyToolConfig represents a tool configuration for MCP proxy
|
||||
type McpProxyToolConfig struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Security SecurityRequirement `json:"security,omitempty"` // Tool-level security for MCP Client to MCP Server
|
||||
Args []ToolArg `json:"args"`
|
||||
OutputSchema map[string]any `json:"outputSchema,omitempty"` // Output schema for MCP Protocol Version 2025-06-18
|
||||
RequestTemplate RequestTemplate `json:"requestTemplate,omitempty"`
|
||||
}
|
||||
|
||||
// RequestTemplate defines request template configuration for proxy tools
|
||||
type RequestTemplate struct {
|
||||
Security SecurityRequirement `json:"security,omitempty"`
|
||||
}
|
||||
|
||||
// McpProxyServer implements Server interface for MCP-to-MCP proxy
|
||||
type McpProxyServer struct {
|
||||
Name string
|
||||
base BaseMCPServer
|
||||
toolsConfig map[string]McpProxyToolConfig
|
||||
securitySchemes map[string]SecurityScheme
|
||||
defaultDownstreamSecurity SecurityRequirement // Default client-to-gateway authentication
|
||||
defaultUpstreamSecurity SecurityRequirement // Default gateway-to-backend authentication
|
||||
mcpServerURL string // Backend MCP server URL
|
||||
timeout int // Request timeout in milliseconds
|
||||
transport TransportProtocol // Transport protocol (http or sse)
|
||||
passthroughAuthHeader bool // If true, pass through Authorization header even without downstream security
|
||||
}
|
||||
|
||||
// NewMcpProxyServer creates a new MCP proxy server
|
||||
func NewMcpProxyServer(name string) *McpProxyServer {
|
||||
return &McpProxyServer{
|
||||
Name: name,
|
||||
base: NewBaseMCPServer(),
|
||||
toolsConfig: make(map[string]McpProxyToolConfig),
|
||||
securitySchemes: make(map[string]SecurityScheme),
|
||||
}
|
||||
}
|
||||
|
||||
// AddSecurityScheme adds a security scheme to the server's map
|
||||
func (s *McpProxyServer) AddSecurityScheme(scheme SecurityScheme) {
|
||||
if s.securitySchemes == nil {
|
||||
s.securitySchemes = make(map[string]SecurityScheme)
|
||||
}
|
||||
s.securitySchemes[scheme.ID] = scheme
|
||||
}
|
||||
|
||||
// GetSecurityScheme retrieves a security scheme by its ID from the map
|
||||
func (s *McpProxyServer) GetSecurityScheme(id string) (SecurityScheme, bool) {
|
||||
scheme, ok := s.securitySchemes[id]
|
||||
return scheme, ok
|
||||
}
|
||||
|
||||
// SetDefaultDownstreamSecurity sets the default downstream security configuration
|
||||
func (s *McpProxyServer) SetDefaultDownstreamSecurity(security SecurityRequirement) {
|
||||
s.defaultDownstreamSecurity = security
|
||||
}
|
||||
|
||||
// GetDefaultDownstreamSecurity gets the default downstream security configuration
|
||||
func (s *McpProxyServer) GetDefaultDownstreamSecurity() SecurityRequirement {
|
||||
return s.defaultDownstreamSecurity
|
||||
}
|
||||
|
||||
// SetDefaultUpstreamSecurity sets the default upstream security configuration
|
||||
func (s *McpProxyServer) SetDefaultUpstreamSecurity(security SecurityRequirement) {
|
||||
s.defaultUpstreamSecurity = security
|
||||
}
|
||||
|
||||
// GetDefaultUpstreamSecurity gets the default upstream security configuration
|
||||
func (s *McpProxyServer) GetDefaultUpstreamSecurity() SecurityRequirement {
|
||||
return s.defaultUpstreamSecurity
|
||||
}
|
||||
|
||||
// SetMcpServerURL sets the backend MCP server URL
|
||||
func (s *McpProxyServer) SetMcpServerURL(url string) {
|
||||
s.mcpServerURL = url
|
||||
}
|
||||
|
||||
// GetMcpServerURL gets the backend MCP server URL
|
||||
func (s *McpProxyServer) GetMcpServerURL() string {
|
||||
return s.mcpServerURL
|
||||
}
|
||||
|
||||
// SetTimeout sets the request timeout in milliseconds
|
||||
func (s *McpProxyServer) SetTimeout(timeout int) {
|
||||
s.timeout = timeout
|
||||
}
|
||||
|
||||
// GetTimeout gets the request timeout in milliseconds
|
||||
func (s *McpProxyServer) GetTimeout() int {
|
||||
return s.timeout
|
||||
}
|
||||
|
||||
// SetTransport sets the transport protocol
|
||||
func (s *McpProxyServer) SetTransport(transport TransportProtocol) {
|
||||
s.transport = transport
|
||||
}
|
||||
|
||||
// GetTransport gets the transport protocol
|
||||
func (s *McpProxyServer) GetTransport() TransportProtocol {
|
||||
return s.transport
|
||||
}
|
||||
|
||||
// AddMCPTool implements Server interface
|
||||
func (s *McpProxyServer) AddMCPTool(name string, tool Tool) Server {
|
||||
s.base.AddMCPTool(name, tool)
|
||||
return s
|
||||
}
|
||||
|
||||
// AddProxyTool adds a proxy tool configuration
|
||||
func (s *McpProxyServer) AddProxyTool(toolConfig McpProxyToolConfig) error {
|
||||
s.toolsConfig[toolConfig.Name] = toolConfig
|
||||
s.base.AddMCPTool(toolConfig.Name, &McpProxyTool{
|
||||
serverName: s.Name,
|
||||
name: toolConfig.Name,
|
||||
toolConfig: toolConfig,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMCPTools implements Server interface
|
||||
func (s *McpProxyServer) GetMCPTools() map[string]Tool {
|
||||
return s.base.GetMCPTools()
|
||||
}
|
||||
|
||||
// SetConfig implements Server interface
|
||||
func (s *McpProxyServer) SetConfig(config []byte) {
|
||||
s.base.SetConfig(config)
|
||||
}
|
||||
|
||||
// GetConfig implements Server interface
|
||||
func (s *McpProxyServer) GetConfig(v any) {
|
||||
s.base.GetConfig(v)
|
||||
}
|
||||
|
||||
// Clone implements Server interface
|
||||
func (s *McpProxyServer) Clone() Server {
|
||||
newServer := &McpProxyServer{
|
||||
Name: s.Name,
|
||||
base: s.base.CloneBase(),
|
||||
toolsConfig: make(map[string]McpProxyToolConfig),
|
||||
securitySchemes: make(map[string]SecurityScheme),
|
||||
}
|
||||
for k, v := range s.toolsConfig {
|
||||
newServer.toolsConfig[k] = v
|
||||
}
|
||||
// Deep copy securitySchemes
|
||||
if s.securitySchemes != nil {
|
||||
for k, v := range s.securitySchemes {
|
||||
newServer.securitySchemes[k] = v
|
||||
}
|
||||
}
|
||||
return newServer
|
||||
}
|
||||
|
||||
// GetToolConfig returns the proxy tool configuration for a given tool name
|
||||
func (s *McpProxyServer) GetToolConfig(name string) (McpProxyToolConfig, bool) {
|
||||
config, ok := s.toolsConfig[name]
|
||||
return config, ok
|
||||
}
|
||||
|
||||
// SetPassthroughAuthHeader sets the passthrough auth header flag
|
||||
func (s *McpProxyServer) SetPassthroughAuthHeader(passthrough bool) {
|
||||
s.passthroughAuthHeader = passthrough
|
||||
}
|
||||
|
||||
// GetPassthroughAuthHeader gets the passthrough auth header flag
|
||||
func (s *McpProxyServer) GetPassthroughAuthHeader() bool {
|
||||
return s.passthroughAuthHeader
|
||||
}
|
||||
|
||||
// ForwardToolsList forwards tools/list request to backend MCP server
|
||||
func (s *McpProxyServer) ForwardToolsList(ctx HttpContext, cursor *string) error {
|
||||
wrapperCtx := ctx.(wrapper.HttpContext)
|
||||
|
||||
// Handle default downstream security for tools/list requests
|
||||
// tools/list requests use server-level default authentication configuration
|
||||
passthroughCredential := ""
|
||||
downstreamSecurity := s.GetDefaultDownstreamSecurity()
|
||||
if downstreamSecurity.ID != "" {
|
||||
clientScheme, schemeOk := s.GetSecurityScheme(downstreamSecurity.ID)
|
||||
if !schemeOk {
|
||||
log.Warnf("Default downstream security scheme ID '%s' not found for tools/list request.", downstreamSecurity.ID)
|
||||
} else {
|
||||
// Extract and remove the credential from the incoming request
|
||||
extractedCred, err := ExtractAndRemoveIncomingCredential(clientScheme)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to extract/remove incoming credential for tools/list using scheme %s: %v", clientScheme.ID, err)
|
||||
} else if extractedCred == "" {
|
||||
log.Debugf("No incoming credential found for tools/list using scheme %s for extraction/removal.", clientScheme.ID)
|
||||
}
|
||||
|
||||
// Only use passthrough if explicitly configured
|
||||
if downstreamSecurity.Passthrough && extractedCred != "" {
|
||||
passthroughCredential = extractedCred
|
||||
log.Debugf("Passthrough credential set for tools/list request.")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback: Remove Authorization header if no downstream security is defined
|
||||
// This prevents downstream credentials from being mistakenly passed to upstream
|
||||
// Unless passthroughAuthHeader is explicitly set to true
|
||||
if !s.GetPassthroughAuthHeader() {
|
||||
proxywasm.RemoveHttpRequestHeader("Authorization")
|
||||
}
|
||||
}
|
||||
|
||||
// Create protocol handler using server fields
|
||||
handler := NewMcpProtocolHandler(s.GetMcpServerURL(), s.GetTimeout())
|
||||
|
||||
// Prepare authentication information for gateway-to-backend communication
|
||||
var authInfo *ProxyAuthInfo
|
||||
upstreamSecurity := s.GetDefaultUpstreamSecurity()
|
||||
if upstreamSecurity.ID != "" {
|
||||
authInfo = &ProxyAuthInfo{
|
||||
SecuritySchemeID: upstreamSecurity.ID,
|
||||
PassthroughCredential: passthroughCredential,
|
||||
Server: s,
|
||||
}
|
||||
}
|
||||
|
||||
// This will handle initialization asynchronously if needed and use ActionPause/Resume
|
||||
return handler.ForwardToolsList(wrapperCtx, cursor, authInfo)
|
||||
}
|
||||
|
||||
// McpProxyTool implements Tool interface for MCP-to-MCP proxy
|
||||
type McpProxyTool struct {
|
||||
serverName string
|
||||
name string
|
||||
toolConfig McpProxyToolConfig
|
||||
arguments map[string]interface{}
|
||||
}
|
||||
|
||||
// Create implements Tool interface
|
||||
func (t *McpProxyTool) Create(params []byte) Tool {
|
||||
newTool := &McpProxyTool{
|
||||
serverName: t.serverName,
|
||||
name: t.name,
|
||||
toolConfig: t.toolConfig,
|
||||
arguments: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
if len(params) > 0 {
|
||||
json.Unmarshal(params, &newTool.arguments)
|
||||
}
|
||||
|
||||
return newTool
|
||||
}
|
||||
|
||||
// Call implements Tool interface - this is where the MCP protocol handling happens
|
||||
func (t *McpProxyTool) Call(httpCtx HttpContext, server Server) error {
|
||||
ctx := httpCtx.(wrapper.HttpContext)
|
||||
|
||||
// Get proxy server instance to access configuration
|
||||
proxyServer, ok := server.(*McpProxyServer)
|
||||
if !ok {
|
||||
return fmt.Errorf("server is not a McpProxyServer")
|
||||
}
|
||||
|
||||
// Handle tool-level or default downstream security: extract credential for passthrough if configured
|
||||
// toolConfig.Security represents client-to-gateway authentication, falls back to server's defaultDownstreamSecurity
|
||||
passthroughCredential := ""
|
||||
var downstreamSecurity SecurityRequirement
|
||||
if t.toolConfig.Security.ID != "" {
|
||||
// Use tool-level security if configured
|
||||
downstreamSecurity = t.toolConfig.Security
|
||||
log.Debugf("Using tool-level downstream security for tool %s: %s", t.name, downstreamSecurity.ID)
|
||||
} else {
|
||||
// Fall back to server's default downstream security
|
||||
downstreamSecurity = proxyServer.GetDefaultDownstreamSecurity()
|
||||
if downstreamSecurity.ID != "" {
|
||||
log.Debugf("Using default downstream security for tool %s: %s", t.name, downstreamSecurity.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if downstreamSecurity.ID != "" {
|
||||
clientScheme, schemeOk := proxyServer.GetSecurityScheme(downstreamSecurity.ID)
|
||||
if !schemeOk {
|
||||
log.Warnf("Downstream security scheme ID '%s' not found for tool %s.", downstreamSecurity.ID, t.name)
|
||||
} else {
|
||||
// Extract and remove the credential from the incoming request
|
||||
extractedCred, err := ExtractAndRemoveIncomingCredential(clientScheme)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to extract/remove incoming credential for tool %s using scheme %s: %v", t.name, clientScheme.ID, err)
|
||||
} else if extractedCred == "" {
|
||||
log.Debugf("No incoming credential found for tool %s using scheme %s for extraction/removal.", t.name, clientScheme.ID)
|
||||
}
|
||||
|
||||
// Only use passthrough if explicitly configured
|
||||
if downstreamSecurity.Passthrough && extractedCred != "" {
|
||||
passthroughCredential = extractedCred
|
||||
log.Debugf("Passthrough credential set for tool %s.", t.name)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback: Remove Authorization header if no downstream security is defined
|
||||
// This prevents downstream credentials from being mistakenly passed to upstream
|
||||
// Unless passthroughAuthHeader is explicitly set to true
|
||||
if !proxyServer.GetPassthroughAuthHeader() {
|
||||
proxywasm.RemoveHttpRequestHeader("Authorization")
|
||||
}
|
||||
}
|
||||
|
||||
// Create protocol handler using server fields
|
||||
handler := NewMcpProtocolHandler(proxyServer.GetMcpServerURL(), proxyServer.GetTimeout())
|
||||
|
||||
// Prepare authentication information for gateway-to-backend communication
|
||||
// toolConfig.RequestTemplate.Security represents gateway-to-backend authentication, falls back to server's defaultUpstreamSecurity
|
||||
var authInfo *ProxyAuthInfo
|
||||
var upstreamSecurity SecurityRequirement
|
||||
if t.toolConfig.RequestTemplate.Security.ID != "" {
|
||||
// Use tool-level upstream security if configured
|
||||
upstreamSecurity = t.toolConfig.RequestTemplate.Security
|
||||
log.Debugf("Using tool-level upstream security for tool %s: %s", t.name, upstreamSecurity.ID)
|
||||
} else {
|
||||
// Fall back to server's default upstream security
|
||||
upstreamSecurity = proxyServer.GetDefaultUpstreamSecurity()
|
||||
if upstreamSecurity.ID != "" {
|
||||
log.Debugf("Using default upstream security for tool %s: %s", t.name, upstreamSecurity.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if upstreamSecurity.ID != "" {
|
||||
authInfo = &ProxyAuthInfo{
|
||||
SecuritySchemeID: upstreamSecurity.ID,
|
||||
PassthroughCredential: passthroughCredential,
|
||||
Server: proxyServer,
|
||||
}
|
||||
}
|
||||
|
||||
// This will handle initialization asynchronously if needed and use ActionPause/Resume
|
||||
return handler.ForwardToolsCall(ctx, t.name, t.arguments, authInfo)
|
||||
}
|
||||
|
||||
// Description implements Tool interface
|
||||
func (t *McpProxyTool) Description() string {
|
||||
return t.toolConfig.Description
|
||||
}
|
||||
|
||||
// InputSchema implements Tool interface
|
||||
func (t *McpProxyTool) InputSchema() map[string]any {
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": make(map[string]any),
|
||||
"required": []string{},
|
||||
}
|
||||
|
||||
properties := schema["properties"].(map[string]any)
|
||||
var required []string
|
||||
|
||||
for _, arg := range t.toolConfig.Args {
|
||||
argSchema := map[string]any{
|
||||
"type": arg.Type,
|
||||
"description": arg.Description,
|
||||
}
|
||||
|
||||
if arg.Default != nil {
|
||||
argSchema["default"] = arg.Default
|
||||
}
|
||||
|
||||
if len(arg.Enum) > 0 {
|
||||
argSchema["enum"] = arg.Enum
|
||||
}
|
||||
|
||||
properties[arg.Name] = argSchema
|
||||
|
||||
if arg.Required {
|
||||
required = append(required, arg.Name)
|
||||
}
|
||||
}
|
||||
|
||||
schema["required"] = required
|
||||
return schema
|
||||
}
|
||||
|
||||
// OutputSchema implements Tool interface (MCP Protocol Version 2025-06-18)
|
||||
func (t *McpProxyTool) OutputSchema() map[string]any {
|
||||
return t.toolConfig.OutputSchema
|
||||
}
|
||||
|
||||
// ValidateSecurityScheme validates a security scheme configuration
|
||||
func ValidateSecurityScheme(scheme SecurityScheme) error {
|
||||
if scheme.ID == "" {
|
||||
return fmt.Errorf("security scheme ID is required")
|
||||
}
|
||||
|
||||
if scheme.Type != "apiKey" && scheme.Type != "http" {
|
||||
return fmt.Errorf("invalid security scheme type: %s", scheme.Type)
|
||||
}
|
||||
|
||||
if scheme.Type == "apiKey" {
|
||||
if scheme.Name == "" {
|
||||
return fmt.Errorf("security scheme name is required for apiKey type")
|
||||
}
|
||||
if scheme.In != "header" && scheme.In != "query" && scheme.In != "cookie" {
|
||||
return fmt.Errorf("invalid security scheme location: %s", scheme.In)
|
||||
}
|
||||
}
|
||||
|
||||
if scheme.Type == "http" {
|
||||
if scheme.Scheme == "" {
|
||||
return fmt.Errorf("security scheme scheme is required for http type")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateToolConfig validates a tool configuration
|
||||
func ValidateToolConfig(config McpProxyToolConfig) error {
|
||||
if config.Name == "" {
|
||||
return fmt.Errorf("tool name is required")
|
||||
}
|
||||
|
||||
if config.Description == "" {
|
||||
return fmt.Errorf("tool description is required")
|
||||
}
|
||||
|
||||
// Validate arguments
|
||||
argNames := make(map[string]bool)
|
||||
for _, arg := range config.Args {
|
||||
if arg.Name == "" {
|
||||
return fmt.Errorf("argument name is required")
|
||||
}
|
||||
|
||||
if argNames[arg.Name] {
|
||||
return fmt.Errorf("duplicate argument name: %s", arg.Name)
|
||||
}
|
||||
argNames[arg.Name] = true
|
||||
|
||||
if arg.Description == "" {
|
||||
return fmt.Errorf("argument description is required for %s", arg.Name)
|
||||
}
|
||||
|
||||
validTypes := []string{"string", "number", "integer", "boolean", "array", "object"}
|
||||
validType := false
|
||||
for _, t := range validTypes {
|
||||
if arg.Type == t {
|
||||
validType = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !validType {
|
||||
return fmt.Errorf("invalid argument type %s for %s", arg.Type, arg.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
112
plugins/wasm-go/pkg/mcp/server/proxy_server_test.go
Normal file
112
plugins/wasm-go/pkg/mcp/server/proxy_server_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestMcpProxyServerBasicInterface tests that McpProxyServer implements the Server interface
|
||||
func TestMcpProxyServerBasicInterface(t *testing.T) {
|
||||
// This test will fail until we implement McpProxyServer
|
||||
server := NewMcpProxyServer("test-proxy")
|
||||
|
||||
// Test Server interface implementation
|
||||
assert.NotNil(t, server)
|
||||
assert.Equal(t, "test-proxy", server.Name)
|
||||
|
||||
// Test that it implements all required methods
|
||||
tools := server.GetMCPTools()
|
||||
assert.NotNil(t, tools)
|
||||
assert.Equal(t, 0, len(tools))
|
||||
|
||||
// Test Clone method
|
||||
cloned := server.Clone()
|
||||
assert.NotNil(t, cloned)
|
||||
}
|
||||
|
||||
// TestMcpProxyServerConfiguration tests configuration setting and getting
|
||||
func TestMcpProxyServerConfiguration(t *testing.T) {
|
||||
server := NewMcpProxyServer("test-proxy")
|
||||
|
||||
// Set server fields directly
|
||||
server.SetMcpServerURL("http://backend.example.com/mcp")
|
||||
server.SetTimeout(5000)
|
||||
|
||||
// Add security scheme
|
||||
scheme := SecurityScheme{
|
||||
ID: "test-auth",
|
||||
Type: "apiKey",
|
||||
In: "header",
|
||||
Name: "X-API-Key",
|
||||
}
|
||||
server.AddSecurityScheme(scheme)
|
||||
|
||||
// Verify server fields
|
||||
assert.Equal(t, "http://backend.example.com/mcp", server.GetMcpServerURL())
|
||||
assert.Equal(t, 5000, server.GetTimeout())
|
||||
|
||||
// Verify security scheme
|
||||
retrievedScheme, exists := server.GetSecurityScheme("test-auth")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, "test-auth", retrievedScheme.ID)
|
||||
assert.Equal(t, "apiKey", retrievedScheme.Type)
|
||||
}
|
||||
|
||||
// TestMcpProxyServerAddTool tests adding proxy tools
|
||||
func TestMcpProxyServerAddTool(t *testing.T) {
|
||||
server := NewMcpProxyServer("test-proxy")
|
||||
|
||||
toolConfig := McpProxyToolConfig{
|
||||
Name: "test-tool",
|
||||
Description: "Test tool for proxy",
|
||||
Args: []ToolArg{
|
||||
{
|
||||
Name: "input",
|
||||
Description: "Test input",
|
||||
Type: "string",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := server.AddProxyTool(toolConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
tools := server.GetMCPTools()
|
||||
assert.Len(t, tools, 1)
|
||||
assert.Contains(t, tools, "test-tool")
|
||||
}
|
||||
|
||||
// TestMcpProxyServerSecuritySchemes tests security scheme management
|
||||
func TestMcpProxyServerSecuritySchemes(t *testing.T) {
|
||||
server := NewMcpProxyServer("test-proxy")
|
||||
|
||||
scheme := SecurityScheme{
|
||||
ID: "test-auth",
|
||||
Type: "apiKey",
|
||||
In: "header",
|
||||
Name: "X-API-Key",
|
||||
}
|
||||
|
||||
server.AddSecurityScheme(scheme)
|
||||
|
||||
retrievedScheme, exists := server.GetSecurityScheme("test-auth")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, scheme.ID, retrievedScheme.ID)
|
||||
assert.Equal(t, scheme.Type, retrievedScheme.Type)
|
||||
}
|
||||
1269
plugins/wasm-go/pkg/mcp/server/proxy_tool.go
Normal file
1269
plugins/wasm-go/pkg/mcp/server/proxy_tool.go
Normal file
File diff suppressed because it is too large
Load Diff
485
plugins/wasm-go/pkg/mcp/server/proxy_tools_test.go
Normal file
485
plugins/wasm-go/pkg/mcp/server/proxy_tools_test.go
Normal file
@@ -0,0 +1,485 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestToolsListForwarding tests the tools/list request forwarding
|
||||
func TestToolsListForwarding(t *testing.T) {
|
||||
// Create proxy server with tools
|
||||
server := NewMcpProxyServer("tools-list-test")
|
||||
|
||||
// Set server fields directly
|
||||
server.SetMcpServerURL("http://backend.example.com/mcp")
|
||||
server.SetTimeout(5000)
|
||||
|
||||
// Add test tools
|
||||
toolConfigs := []McpProxyToolConfig{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather information",
|
||||
Args: []ToolArg{
|
||||
{
|
||||
Name: "location",
|
||||
Description: "City name",
|
||||
Type: "string",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "get_news",
|
||||
Description: "Get latest news",
|
||||
Args: []ToolArg{
|
||||
{
|
||||
Name: "category",
|
||||
Description: "News category",
|
||||
Type: "string",
|
||||
Required: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, toolConfig := range toolConfigs {
|
||||
err := server.AddProxyTool(toolConfig)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Skip HttpContext-dependent test for now - will be tested in integration
|
||||
// Test that tools were added to server successfully
|
||||
tools := server.GetMCPTools()
|
||||
assert.Len(t, tools, 2)
|
||||
assert.Contains(t, tools, "get_weather")
|
||||
assert.Contains(t, tools, "get_news")
|
||||
}
|
||||
|
||||
// TestToolsCallForwarding tests the tools/call request forwarding
|
||||
func TestToolsCallForwarding(t *testing.T) {
|
||||
server := NewMcpProxyServer("tools-call-test")
|
||||
|
||||
// Set server fields directly
|
||||
server.SetMcpServerURL("http://backend.example.com/mcp")
|
||||
server.SetTimeout(5000)
|
||||
|
||||
// Add test tool
|
||||
toolConfig := McpProxyToolConfig{
|
||||
Name: "test_tool",
|
||||
Description: "Test tool for call forwarding",
|
||||
Args: []ToolArg{
|
||||
{
|
||||
Name: "input",
|
||||
Description: "Input parameter",
|
||||
Type: "string",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := server.AddProxyTool(toolConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the tool and create instance
|
||||
tool, exists := server.GetMCPTools()["test_tool"]
|
||||
require.True(t, exists)
|
||||
|
||||
params := map[string]interface{}{
|
||||
"input": "test value",
|
||||
}
|
||||
paramsBytes, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
toolInstance := tool.Create(paramsBytes)
|
||||
require.NotNil(t, toolInstance)
|
||||
|
||||
// Skip HttpContext-dependent test for now - will be tested in integration
|
||||
// Test tool instance creation was successful
|
||||
assert.NotNil(t, toolInstance)
|
||||
assert.Equal(t, "test_tool", toolInstance.(*McpProxyTool).name)
|
||||
assert.Equal(t, "test value", toolInstance.(*McpProxyTool).arguments["input"])
|
||||
}
|
||||
|
||||
// TestToolsCallWithParameters tests tool call with various parameter types
|
||||
func TestToolsCallWithParameters(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolConfig McpProxyToolConfig
|
||||
params map[string]interface{}
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
name: "string parameter",
|
||||
toolConfig: McpProxyToolConfig{
|
||||
Name: "string_tool",
|
||||
Description: "Tool with string parameter",
|
||||
Args: []ToolArg{
|
||||
{
|
||||
Name: "text",
|
||||
Description: "Text input",
|
||||
Type: "string",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
params: map[string]interface{}{
|
||||
"text": "hello world",
|
||||
},
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "number parameter",
|
||||
toolConfig: McpProxyToolConfig{
|
||||
Name: "number_tool",
|
||||
Description: "Tool with number parameter",
|
||||
Args: []ToolArg{
|
||||
{
|
||||
Name: "value",
|
||||
Description: "Numeric value",
|
||||
Type: "number",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
params: map[string]interface{}{
|
||||
"value": 42.5,
|
||||
},
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "object parameter",
|
||||
toolConfig: McpProxyToolConfig{
|
||||
Name: "object_tool",
|
||||
Description: "Tool with object parameter",
|
||||
Args: []ToolArg{
|
||||
{
|
||||
Name: "data",
|
||||
Description: "Object data",
|
||||
Type: "object",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
params: map[string]interface{}{
|
||||
"data": map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 123,
|
||||
},
|
||||
},
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing required parameter",
|
||||
toolConfig: McpProxyToolConfig{
|
||||
Name: "required_tool",
|
||||
Description: "Tool with required parameter",
|
||||
Args: []ToolArg{
|
||||
{
|
||||
Name: "required_param",
|
||||
Description: "Required parameter",
|
||||
Type: "string",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
params: map[string]interface{}{},
|
||||
shouldErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := NewMcpProxyServer("param-test")
|
||||
|
||||
// Set server fields directly
|
||||
server.SetMcpServerURL("http://backend.example.com/mcp")
|
||||
server.SetTimeout(5000)
|
||||
|
||||
err := server.AddProxyTool(tt.toolConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
tool, exists := server.GetMCPTools()[tt.toolConfig.Name]
|
||||
require.True(t, exists)
|
||||
|
||||
paramsBytes, err := json.Marshal(tt.params)
|
||||
require.NoError(t, err)
|
||||
|
||||
toolInstance := tool.Create(paramsBytes)
|
||||
require.NotNil(t, toolInstance)
|
||||
|
||||
// Skip HttpContext-dependent test for now - will be tested in integration
|
||||
// Test tool instance creation
|
||||
assert.NotNil(t, toolInstance)
|
||||
if !tt.shouldErr {
|
||||
assert.Equal(t, tt.toolConfig.Name, toolInstance.(*McpProxyTool).name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolsCallWithCursor tests tools/list with pagination cursor
|
||||
func TestToolsCallWithCursor(t *testing.T) {
|
||||
server := NewMcpProxyServer("cursor-test")
|
||||
|
||||
// Set server fields directly
|
||||
server.SetMcpServerURL("http://backend.example.com/mcp")
|
||||
server.SetTimeout(5000)
|
||||
|
||||
// Skip HttpContext-dependent test for now - will be tested in integration
|
||||
// Test cursor parameter handling logic (basic validation)
|
||||
cursor := "page-2-cursor"
|
||||
assert.NotNil(t, cursor)
|
||||
assert.NotEmpty(t, cursor)
|
||||
}
|
||||
|
||||
// TestBackendErrorHandling tests handling of backend MCP server errors
|
||||
func TestBackendErrorHandling(t *testing.T) {
|
||||
server := NewMcpProxyServer("error-test")
|
||||
|
||||
// Set server fields directly
|
||||
server.SetMcpServerURL("http://failing-backend.example.com/mcp")
|
||||
server.SetTimeout(5000)
|
||||
|
||||
toolConfig := McpProxyToolConfig{
|
||||
Name: "failing_tool",
|
||||
Description: "Tool that will fail on backend",
|
||||
Args: []ToolArg{
|
||||
{
|
||||
Name: "input",
|
||||
Description: "Input parameter",
|
||||
Type: "string",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := server.AddProxyTool(toolConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
tool, exists := server.GetMCPTools()["failing_tool"]
|
||||
require.True(t, exists)
|
||||
|
||||
params := map[string]interface{}{
|
||||
"input": "test value",
|
||||
}
|
||||
paramsBytes, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
toolInstance := tool.Create(paramsBytes)
|
||||
require.NotNil(t, toolInstance)
|
||||
|
||||
// Skip HttpContext-dependent test for now - will be tested in integration
|
||||
// Test tool instance creation for error scenario
|
||||
assert.NotNil(t, toolInstance)
|
||||
assert.Equal(t, "failing_tool", toolInstance.(*McpProxyTool).name)
|
||||
}
|
||||
|
||||
// TestParseSSEResponse tests the SSE response parsing functionality
|
||||
func TestParseSSEResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sseData string
|
||||
expectedData string
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid SSE with JSON data",
|
||||
sseData: `event: message
|
||||
data: {"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{"experimental":{},"prompts":{"listChanged":true},"resources":{"subscribe":false,"listChanged":true},"tools":{"listChanged":true}},"serverInfo":{"name":"Echo Server","version":"1.17.0"}}}
|
||||
|
||||
`,
|
||||
expectedData: `{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{"experimental":{},"prompts":{"listChanged":true},"resources":{"subscribe":false,"listChanged":true},"tools":{"listChanged":true}},"serverInfo":{"name":"Echo Server","version":"1.17.0"}}}`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "SSE with multiple lines",
|
||||
sseData: `event: message
|
||||
data: {"jsonrpc":"2.0","id":2,"result":{"success":true}}
|
||||
|
||||
event: close
|
||||
data: {"jsonrpc":"2.0","method":"close"}
|
||||
|
||||
`,
|
||||
expectedData: `{"jsonrpc":"2.0","id":2,"result":{"success":true}}`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "SSE with comments and empty lines",
|
||||
sseData: `: This is a comment
|
||||
event: message
|
||||
|
||||
data: {"jsonrpc":"2.0","id":3,"result":{"test":true}}
|
||||
|
||||
: Another comment
|
||||
`,
|
||||
expectedData: `{"jsonrpc":"2.0","id":3,"result":{"test":true}}`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "SSE with any data content",
|
||||
sseData: `event: message
|
||||
data: {invalid json}
|
||||
|
||||
`,
|
||||
expectedData: `{invalid json}`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "SSE with no data field",
|
||||
sseData: `event: message
|
||||
id: 123
|
||||
|
||||
`,
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty SSE data",
|
||||
sseData: ``,
|
||||
shouldErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := parseSSEResponse([]byte(tt.sseData))
|
||||
|
||||
if tt.shouldErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, tt.expectedData, string(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsBackendError tests detection of backend error responses
|
||||
func TestIsBackendError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response string
|
||||
expectError bool
|
||||
expectErrType string
|
||||
}{
|
||||
{
|
||||
name: "JSON-RPC 2.0 error with unknown tool",
|
||||
response: `{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"error": {
|
||||
"code": -32602,
|
||||
"message": "Unknown tool: invalid_tool_name"
|
||||
}
|
||||
}`,
|
||||
expectError: true,
|
||||
expectErrType: "jsonrpc_error",
|
||||
},
|
||||
{
|
||||
name: "JSON-RPC 2.0 error with method not found",
|
||||
response: `{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"error": {
|
||||
"code": -32601,
|
||||
"message": "Method not found"
|
||||
}
|
||||
}`,
|
||||
expectError: true,
|
||||
expectErrType: "jsonrpc_error",
|
||||
},
|
||||
{
|
||||
name: "result.isError format",
|
||||
response: `{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"result": {
|
||||
"isError": true,
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Tool execution failed: connection timeout"
|
||||
}
|
||||
]
|
||||
}
|
||||
}`,
|
||||
expectError: true,
|
||||
expectErrType: "result_isError",
|
||||
},
|
||||
{
|
||||
name: "successful response with result",
|
||||
response: `{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Success!"
|
||||
}
|
||||
]
|
||||
}
|
||||
}`,
|
||||
expectError: false,
|
||||
expectErrType: "",
|
||||
},
|
||||
{
|
||||
name: "successful response with isError false",
|
||||
response: `{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"result": {
|
||||
"isError": false,
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Success!"
|
||||
}
|
||||
]
|
||||
}
|
||||
}`,
|
||||
expectError: false,
|
||||
expectErrType: "",
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
response: `{invalid json}`,
|
||||
expectError: false,
|
||||
expectErrType: "",
|
||||
},
|
||||
{
|
||||
name: "empty response",
|
||||
response: `{}`,
|
||||
expectError: false,
|
||||
expectErrType: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isError, errType := IsBackendError([]byte(tt.response))
|
||||
assert.Equal(t, tt.expectError, isError, "isError mismatch")
|
||||
assert.Equal(t, tt.expectErrType, errType, "error type mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ForwardToolsList is now implemented in proxy_server.go
|
||||
1027
plugins/wasm-go/pkg/mcp/server/rest_server.go
Normal file
1027
plugins/wasm-go/pkg/mcp/server/rest_server.go
Normal file
File diff suppressed because it is too large
Load Diff
922
plugins/wasm-go/pkg/mcp/server/rest_server_test.go
Normal file
922
plugins/wasm-go/pkg/mcp/server/rest_server_test.go
Normal file
@@ -0,0 +1,922 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
func TestConvertArgToString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "string value",
|
||||
input: "test string",
|
||||
expected: "test string",
|
||||
},
|
||||
{
|
||||
name: "boolean true",
|
||||
input: true,
|
||||
expected: "true",
|
||||
},
|
||||
{
|
||||
name: "boolean false",
|
||||
input: false,
|
||||
expected: "false",
|
||||
},
|
||||
{
|
||||
name: "integer",
|
||||
input: 42,
|
||||
expected: "42",
|
||||
},
|
||||
{
|
||||
name: "float",
|
||||
input: 3.14,
|
||||
expected: "3.14",
|
||||
},
|
||||
{
|
||||
name: "map",
|
||||
input: map[string]interface{}{"key": "value"},
|
||||
expected: `{"key":"value"}`,
|
||||
},
|
||||
{
|
||||
name: "array",
|
||||
input: []interface{}{1, 2, 3},
|
||||
expected: "[1,2,3]",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := convertArgToString(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("convertArgToString(%v) = %v, want %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseTemplatePrependAppend(t *testing.T) {
|
||||
// Test response template with PrependBody and AppendBody
|
||||
sampleResponse := `{"result": "success", "data": {"name": "Test", "value": 42}}`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
template RestToolResponseTemplate
|
||||
expected []string
|
||||
notExpected []string
|
||||
}{
|
||||
{
|
||||
name: "with body template only",
|
||||
template: RestToolResponseTemplate{
|
||||
Body: "# Result\n- Name: {{.data.name}}\n- Value: {{.data.value}}",
|
||||
},
|
||||
expected: []string{
|
||||
"# Result",
|
||||
"- Name: Test",
|
||||
"- Value: 42",
|
||||
},
|
||||
notExpected: []string{
|
||||
"Field Descriptions:",
|
||||
"End of Response",
|
||||
`{"result": "success"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with prepend only",
|
||||
template: RestToolResponseTemplate{
|
||||
PrependBody: "# Field Descriptions:\n- result: Operation result\n- data: Response data\n\n",
|
||||
},
|
||||
expected: []string{
|
||||
"# Field Descriptions:",
|
||||
"- result: Operation result",
|
||||
"- data: Response data",
|
||||
`{"result": "success"`,
|
||||
`"name": "Test"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with append only",
|
||||
template: RestToolResponseTemplate{
|
||||
AppendBody: "\n\n*End of Response*",
|
||||
},
|
||||
expected: []string{
|
||||
`{"result": "success"`,
|
||||
`"name": "Test"`,
|
||||
"*End of Response*",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with both prepend and append",
|
||||
template: RestToolResponseTemplate{
|
||||
PrependBody: "# API Response:\n\n",
|
||||
AppendBody: "\n\n*This is raw JSON data with field 'name' = Test and 'value' = 42*",
|
||||
},
|
||||
expected: []string{
|
||||
"# API Response:",
|
||||
`{"result": "success"`,
|
||||
`"name": "Test"`,
|
||||
"*This is raw JSON data with field 'name' = Test and 'value' = 42*",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a tool with the test template
|
||||
// For tests with only prepend/append (no body), add a RequestTemplate.URL
|
||||
// to avoid direct response mode validation
|
||||
tool := RestTool{
|
||||
ResponseTemplate: tt.template,
|
||||
}
|
||||
if tt.template.Body == "" && (tt.template.PrependBody != "" || tt.template.AppendBody != "") {
|
||||
tool.RequestTemplate.URL = "http://example.com/api"
|
||||
}
|
||||
|
||||
// Parse templates
|
||||
err := tool.parseTemplates()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse templates: %v", err)
|
||||
}
|
||||
|
||||
// Simulate response processing
|
||||
var result string
|
||||
responseBody := []byte(sampleResponse)
|
||||
|
||||
// Case 1: Full response template is provided
|
||||
if tool.parsedResponseTemplate != nil {
|
||||
templateResult, err := executeTemplate(tool.parsedResponseTemplate, responseBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute response template: %v", err)
|
||||
}
|
||||
result = templateResult
|
||||
} else {
|
||||
// Case 2: No template, but prepend/append might be used
|
||||
rawResponse := string(responseBody)
|
||||
|
||||
// Apply prepend/append if specified
|
||||
if tool.ResponseTemplate.PrependBody != "" || tool.ResponseTemplate.AppendBody != "" {
|
||||
result = tool.ResponseTemplate.PrependBody + rawResponse + tool.ResponseTemplate.AppendBody
|
||||
} else {
|
||||
// Case 3: No template and no prepend/append, just use raw response
|
||||
result = rawResponse
|
||||
}
|
||||
}
|
||||
|
||||
// Check that the result contains expected substrings
|
||||
for _, substr := range tt.expected {
|
||||
if !strings.Contains(result, substr) {
|
||||
t.Errorf("Expected substring not found: %s", substr)
|
||||
}
|
||||
}
|
||||
|
||||
// Check that the result does not contain unexpected substrings
|
||||
for _, substr := range tt.notExpected {
|
||||
if strings.Contains(result, substr) {
|
||||
t.Errorf("Unexpected substring found: %s", substr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasContentType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
headers [][2]string
|
||||
contentTypeStr string
|
||||
expectedOutcome bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
headers: [][2]string{
|
||||
{"Content-Type", "application/json"},
|
||||
},
|
||||
contentTypeStr: "application/json",
|
||||
expectedOutcome: true,
|
||||
},
|
||||
{
|
||||
name: "case insensitive match",
|
||||
headers: [][2]string{
|
||||
{"content-type", "application/JSON"},
|
||||
},
|
||||
contentTypeStr: "application/json",
|
||||
expectedOutcome: true,
|
||||
},
|
||||
{
|
||||
name: "substring match",
|
||||
headers: [][2]string{
|
||||
{"Content-Type", "application/json; charset=utf-8"},
|
||||
},
|
||||
contentTypeStr: "application/json",
|
||||
expectedOutcome: true,
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
headers: [][2]string{
|
||||
{"Content-Type", "text/plain"},
|
||||
},
|
||||
contentTypeStr: "application/json",
|
||||
expectedOutcome: false,
|
||||
},
|
||||
{
|
||||
name: "header not present",
|
||||
headers: [][2]string{
|
||||
{"Accept", "application/json"},
|
||||
},
|
||||
contentTypeStr: "application/json",
|
||||
expectedOutcome: false,
|
||||
},
|
||||
{
|
||||
name: "empty headers",
|
||||
headers: [][2]string{},
|
||||
contentTypeStr: "application/json",
|
||||
expectedOutcome: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := hasContentType(tt.headers, tt.contentTypeStr)
|
||||
if result != tt.expectedOutcome {
|
||||
t.Errorf("hasContentType(%v, %v) = %v, want %v", tt.headers, tt.contentTypeStr, result, tt.expectedOutcome)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestToolValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tool RestTool
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "valid tool with no args options",
|
||||
tool: RestTool{
|
||||
RequestTemplate: RestToolRequestTemplate{
|
||||
URL: "https://example.com",
|
||||
Method: "GET",
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "valid tool with argsToJsonBody",
|
||||
tool: RestTool{
|
||||
RequestTemplate: RestToolRequestTemplate{
|
||||
URL: "https://example.com",
|
||||
Method: "POST",
|
||||
ArgsToJsonBody: true,
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "valid tool with argsToUrlParam",
|
||||
tool: RestTool{
|
||||
RequestTemplate: RestToolRequestTemplate{
|
||||
URL: "https://example.com",
|
||||
Method: "GET",
|
||||
ArgsToUrlParam: true,
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "valid tool with argsToFormBody",
|
||||
tool: RestTool{
|
||||
RequestTemplate: RestToolRequestTemplate{
|
||||
URL: "https://example.com",
|
||||
Method: "POST",
|
||||
ArgsToFormBody: true,
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid tool with multiple args options",
|
||||
tool: RestTool{
|
||||
RequestTemplate: RestToolRequestTemplate{
|
||||
URL: "https://example.com",
|
||||
Method: "POST",
|
||||
ArgsToJsonBody: true,
|
||||
ArgsToFormBody: true,
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid tool with all args options",
|
||||
tool: RestTool{
|
||||
RequestTemplate: RestToolRequestTemplate{
|
||||
URL: "https://example.com",
|
||||
Method: "POST",
|
||||
ArgsToJsonBody: true,
|
||||
ArgsToUrlParam: true,
|
||||
ArgsToFormBody: true,
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid tool with both Body and PrependBody",
|
||||
tool: RestTool{
|
||||
RequestTemplate: RestToolRequestTemplate{
|
||||
URL: "https://example.com",
|
||||
Method: "GET",
|
||||
},
|
||||
ResponseTemplate: RestToolResponseTemplate{
|
||||
Body: "# Result\n{{.data}}",
|
||||
PrependBody: "# Field Descriptions:\n",
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid tool with both Body and AppendBody",
|
||||
tool: RestTool{
|
||||
RequestTemplate: RestToolRequestTemplate{
|
||||
URL: "https://example.com",
|
||||
Method: "GET",
|
||||
},
|
||||
ResponseTemplate: RestToolResponseTemplate{
|
||||
Body: "# Result\n{{.data}}",
|
||||
AppendBody: "\n*End of response*",
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid tool with Body, PrependBody, and AppendBody",
|
||||
tool: RestTool{
|
||||
RequestTemplate: RestToolRequestTemplate{
|
||||
URL: "https://example.com",
|
||||
Method: "GET",
|
||||
},
|
||||
ResponseTemplate: RestToolResponseTemplate{
|
||||
Body: "# Result\n{{.data}}",
|
||||
PrependBody: "# Field Descriptions:\n",
|
||||
AppendBody: "\n*End of response*",
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "valid tool with PrependBody and AppendBody but no Body",
|
||||
tool: RestTool{
|
||||
RequestTemplate: RestToolRequestTemplate{
|
||||
URL: "https://example.com",
|
||||
Method: "GET",
|
||||
},
|
||||
ResponseTemplate: RestToolResponseTemplate{
|
||||
PrependBody: "# Field Descriptions:\n",
|
||||
AppendBody: "\n*End of response*",
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.tool.parseTemplates()
|
||||
if (err != nil) != tt.expectedError {
|
||||
t.Errorf("parseTemplates() error = %v, expectedError %v", err, tt.expectedError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputSchemaWithComplexTypes(t *testing.T) {
|
||||
// Create a tool with array and object type arguments
|
||||
tool := RestMCPTool{
|
||||
toolConfig: RestTool{
|
||||
Args: []RestToolArg{
|
||||
{
|
||||
Name: "stringArg",
|
||||
Description: "A string argument",
|
||||
Type: "string",
|
||||
},
|
||||
{
|
||||
Name: "arrayArg",
|
||||
Description: "An array argument",
|
||||
Type: "array",
|
||||
Items: map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "objectArg",
|
||||
Description: "An object argument",
|
||||
Type: "object",
|
||||
Properties: map[string]interface{}{
|
||||
"name": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Name property",
|
||||
},
|
||||
"age": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Age property",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "arrayOfObjects",
|
||||
Description: "An array of objects",
|
||||
Type: "array",
|
||||
Items: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"id": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
"value": map[string]interface{}{
|
||||
"type": "number",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := tool.InputSchema()
|
||||
|
||||
// Check schema structure
|
||||
if schema["type"] != "object" {
|
||||
t.Errorf("Expected schema type to be 'object', got %v", schema["type"])
|
||||
}
|
||||
|
||||
properties, ok := schema["properties"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("Expected properties to be a map, got %T", schema["properties"])
|
||||
}
|
||||
|
||||
// Check individual property types
|
||||
checkProperty := func(name, expectedType string) {
|
||||
prop, ok := properties[name].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("Expected property %s to be a map, got %T", name, properties[name])
|
||||
}
|
||||
if prop["type"] != expectedType {
|
||||
t.Errorf("Expected property %s type to be '%s', got %v", name, expectedType, prop["type"])
|
||||
}
|
||||
}
|
||||
|
||||
checkProperty("stringArg", "string")
|
||||
checkProperty("arrayArg", "array")
|
||||
checkProperty("objectArg", "object")
|
||||
checkProperty("arrayOfObjects", "array")
|
||||
|
||||
// Check array items
|
||||
arrayArg, _ := properties["arrayArg"].(map[string]interface{})
|
||||
if arrayArg["items"] == nil {
|
||||
t.Errorf("Expected arrayArg to have items property")
|
||||
}
|
||||
|
||||
// Check object properties
|
||||
objectArg, _ := properties["objectArg"].(map[string]interface{})
|
||||
if objectArg["properties"] == nil {
|
||||
t.Errorf("Expected objectArg to have properties property")
|
||||
}
|
||||
|
||||
// Check array of objects
|
||||
arrayOfObjects, _ := properties["arrayOfObjects"].(map[string]interface{})
|
||||
items, ok := arrayOfObjects["items"].(map[string]interface{})
|
||||
if !ok || items["type"] != "object" {
|
||||
t.Errorf("Expected arrayOfObjects items to be of type object")
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgsToUrlParamAndFormBody(t *testing.T) {
|
||||
// Test argsToUrlParam
|
||||
t.Run("argsToUrlParam", func(t *testing.T) {
|
||||
args := map[string]interface{}{
|
||||
"string": "value",
|
||||
"int": 42,
|
||||
"bool": true,
|
||||
"array": []interface{}{1, 2, 3},
|
||||
"object": map[string]interface{}{"key": "value"},
|
||||
}
|
||||
|
||||
// Parse URL and add parameters
|
||||
baseURL := "https://example.com/api"
|
||||
parsedURL, _ := url.Parse(baseURL)
|
||||
query := parsedURL.Query()
|
||||
|
||||
for key, value := range args {
|
||||
query.Set(key, convertArgToString(value))
|
||||
}
|
||||
|
||||
parsedURL.RawQuery = query.Encode()
|
||||
result := parsedURL.String()
|
||||
|
||||
// Verify each parameter is in the URL
|
||||
for key, value := range args {
|
||||
strValue := convertArgToString(value)
|
||||
encodedValue := url.QueryEscape(strValue)
|
||||
paramStr := key + "=" + encodedValue
|
||||
|
||||
if !strings.Contains(result, paramStr) {
|
||||
t.Errorf("URL parameter missing: %s", paramStr)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test argsToFormBody
|
||||
t.Run("argsToFormBody", func(t *testing.T) {
|
||||
args := map[string]interface{}{
|
||||
"string": "value",
|
||||
"int": 42,
|
||||
"bool": true,
|
||||
"array": []interface{}{1, 2, 3},
|
||||
"object": map[string]interface{}{"key": "value"},
|
||||
}
|
||||
|
||||
// Create form values
|
||||
formValues := url.Values{}
|
||||
for key, value := range args {
|
||||
formValues.Set(key, convertArgToString(value))
|
||||
}
|
||||
|
||||
formBody := formValues.Encode()
|
||||
|
||||
// Verify each parameter is in the form body
|
||||
for key, value := range args {
|
||||
strValue := convertArgToString(value)
|
||||
encodedValue := url.QueryEscape(strValue)
|
||||
paramStr := key + "=" + encodedValue
|
||||
|
||||
if !strings.Contains(formBody, paramStr) {
|
||||
t.Errorf("Form body missing parameter: %s", paramStr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRestToolConfig(t *testing.T) {
|
||||
// Example REST tool configuration
|
||||
configJSON := `
|
||||
{
|
||||
"server": {
|
||||
"name": "rest-amap-server",
|
||||
"config": {
|
||||
"apiKey": "xxxxx"
|
||||
}
|
||||
},
|
||||
"tools": [
|
||||
{
|
||||
"name": "maps-geo",
|
||||
"description": "将详细的结构化地址转换为经纬度坐标。支持对地标性名胜景区、建筑物名称解析为经纬度坐标",
|
||||
"args": [
|
||||
{
|
||||
"name": "address",
|
||||
"description": "待解析的结构化地址信息",
|
||||
"type": "string",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"name": "city",
|
||||
"description": "指定查询的城市",
|
||||
"required": false
|
||||
},
|
||||
{
|
||||
"name": "output",
|
||||
"description": "输出格式",
|
||||
"type": "string",
|
||||
"enum": ["json", "xml"],
|
||||
"default": "json"
|
||||
},
|
||||
{
|
||||
"name": "options",
|
||||
"description": "高级选项",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"extensions": {
|
||||
"type": "string",
|
||||
"enum": ["base", "all"]
|
||||
},
|
||||
"batch": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "batch_addresses",
|
||||
"description": "批量地址",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
],
|
||||
"requestTemplate": {
|
||||
"url": "https://restapi.amap.com/v3/geocode/geo?key={{.config.apiKey}}&address={{.args.address}}&city={{.args.city}}&output={{.args.output}}&source=ts_mcp",
|
||||
"method": "GET",
|
||||
"headers": [
|
||||
{
|
||||
"key": "Content-Type",
|
||||
"value": "application/json"
|
||||
}
|
||||
]
|
||||
},
|
||||
"responseTemplate": {
|
||||
"body": "# 地理编码信息\n{{- range $index, $geo := .Geocodes }}\n## 地点 {{add $index 1}}\n\n- **国家**: {{ $geo.Country }}\n- **省份**: {{ $geo.Province }}\n- **城市**: {{ $geo.City }}\n- **城市代码**: {{ $geo.Citycode }}\n- **区/县**: {{ $geo.District }}\n- **街道**: {{ $geo.Street }}\n- **门牌号**: {{ $geo.Number }}\n- **行政编码**: {{ $geo.Adcode }}\n- **坐标**: {{ $geo.Location }}\n- **级别**: {{ $geo.Level }}\n{{- end }}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
`
|
||||
|
||||
// Parse the config to verify it's valid JSON
|
||||
var configData map[string]interface{}
|
||||
err := json.Unmarshal([]byte(configJSON), &configData)
|
||||
if err != nil {
|
||||
t.Fatalf("Invalid JSON config: %v", err)
|
||||
}
|
||||
|
||||
// Example tool configuration
|
||||
tool := RestTool{
|
||||
Name: "maps-geo",
|
||||
Description: "将详细的结构化地址转换为经纬度坐标。支持对地标性名胜景区、建筑物名称解析为经纬度坐标",
|
||||
Args: []RestToolArg{
|
||||
{
|
||||
Name: "address",
|
||||
Description: "待解析的结构化地址信息",
|
||||
Type: "string",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "city",
|
||||
Description: "指定查询的城市",
|
||||
Required: false,
|
||||
},
|
||||
{
|
||||
Name: "output",
|
||||
Description: "输出格式",
|
||||
Type: "string",
|
||||
Enum: []interface{}{"json", "xml"},
|
||||
Default: "json",
|
||||
},
|
||||
{
|
||||
Name: "options",
|
||||
Description: "高级选项",
|
||||
Type: "object",
|
||||
Properties: map[string]interface{}{
|
||||
"extensions": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []interface{}{"base", "all"},
|
||||
},
|
||||
"batch": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "batch_addresses",
|
||||
Description: "批量地址",
|
||||
Type: "array",
|
||||
Items: map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
RequestTemplate: RestToolRequestTemplate{
|
||||
URL: "https://restapi.amap.com/v3/geocode/geo?key={{.config.apiKey}}&address={{.args.address}}&city={{.args.city}}&output={{.args.output}}&source=ts_mcp",
|
||||
Method: "GET",
|
||||
Headers: []RestToolHeader{
|
||||
{
|
||||
Key: "Content-Type",
|
||||
Value: "application/json",
|
||||
},
|
||||
},
|
||||
},
|
||||
ResponseTemplate: RestToolResponseTemplate{
|
||||
Body: `# 地理编码信息
|
||||
{{- range $index, $geo := .Geocodes }}
|
||||
## 地点 {{add $index 1}}
|
||||
|
||||
- **国家**: {{ $geo.Country }}
|
||||
- **省份**: {{ $geo.Province }}
|
||||
- **城市**: {{ $geo.City }}
|
||||
- **城市代码**: {{ $geo.Citycode }}
|
||||
- **区/县**: {{ $geo.District }}
|
||||
- **街道**: {{ $geo.Street }}
|
||||
- **门牌号**: {{ $geo.Number }}
|
||||
- **行政编码**: {{ $geo.Adcode }}
|
||||
- **坐标**: {{ $geo.Location }}
|
||||
- **级别**: {{ $geo.Level }}
|
||||
{{- end }}`,
|
||||
},
|
||||
}
|
||||
|
||||
// Parse templates
|
||||
err = tool.parseTemplates()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse templates: %v", err)
|
||||
}
|
||||
|
||||
var templateData []byte
|
||||
templateData, _ = sjson.SetBytes(templateData, "config", map[string]interface{}{"apiKey": "test-api-key"})
|
||||
templateData, _ = sjson.SetBytes(templateData, "args", map[string]interface{}{
|
||||
"address": "北京市朝阳区阜通东大街6号",
|
||||
"city": "北京",
|
||||
"output": "json",
|
||||
})
|
||||
|
||||
// Test URL template
|
||||
url, err := executeTemplate(tool.parsedURLTemplate, templateData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute URL template: %v", err)
|
||||
}
|
||||
|
||||
expectedURL := "https://restapi.amap.com/v3/geocode/geo?key=test-api-key&address=北京市朝阳区阜通东大街6号&city=北京&output=json&source=ts_mcp"
|
||||
if url != expectedURL {
|
||||
t.Errorf("URL template rendering failed. Expected: %s, Got: %s", expectedURL, url)
|
||||
}
|
||||
|
||||
// Test InputSchema for complex types
|
||||
mcpTool := &RestMCPTool{
|
||||
toolConfig: tool,
|
||||
}
|
||||
|
||||
schema := mcpTool.InputSchema()
|
||||
properties := schema["properties"].(map[string]interface{})
|
||||
|
||||
// Check object type
|
||||
options, ok := properties["options"].(map[string]interface{})
|
||||
if !ok || options["type"] != "object" {
|
||||
t.Errorf("Expected options to be of type object")
|
||||
}
|
||||
|
||||
// Check array type
|
||||
batchAddresses, ok := properties["batch_addresses"].(map[string]interface{})
|
||||
if !ok || batchAddresses["type"] != "array" {
|
||||
t.Errorf("Expected batch_addresses to be of type array")
|
||||
}
|
||||
|
||||
// Test response template with sample data
|
||||
sampleResponse := `
|
||||
{"Geocodes": [
|
||||
{
|
||||
"Country": "中国",
|
||||
"Province": "北京市",
|
||||
"City": "北京市",
|
||||
"Citycode": "010",
|
||||
"District": "朝阳区",
|
||||
"Street": "阜通东大街",
|
||||
"Number": "6号",
|
||||
"Adcode": "110105",
|
||||
"Location": "116.483038,39.990633",
|
||||
"Level": "门牌号",
|
||||
}]}`
|
||||
|
||||
result, err := executeTemplate(tool.parsedResponseTemplate, []byte(sampleResponse))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute response template: %v", err)
|
||||
}
|
||||
|
||||
// Just check that the result contains expected substrings
|
||||
expectedSubstrings := []string{
|
||||
"# 地理编码信息",
|
||||
"## 地点 1",
|
||||
"**国家**: 中国",
|
||||
"**省份**: 北京市",
|
||||
"**坐标**: 116.483038,39.990633",
|
||||
}
|
||||
|
||||
for _, substr := range expectedSubstrings {
|
||||
if !strings.Contains(result, substr) {
|
||||
t.Errorf("Response template rendering failed. Expected substring not found: %s", substr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRestServerDefaultSecurity tests the default security configuration for REST MCP server
|
||||
func TestRestServerDefaultSecurity(t *testing.T) {
|
||||
server := NewRestMCPServer("test-rest-server")
|
||||
|
||||
// Add security schemes
|
||||
defaultScheme := SecurityScheme{
|
||||
ID: "DefaultAuth",
|
||||
Type: "apiKey",
|
||||
In: "header",
|
||||
Name: "X-Default-Key",
|
||||
DefaultCredential: "default-key",
|
||||
}
|
||||
toolScheme := SecurityScheme{
|
||||
ID: "ToolAuth",
|
||||
Type: "apiKey",
|
||||
In: "header",
|
||||
Name: "X-Tool-Key",
|
||||
DefaultCredential: "tool-key",
|
||||
}
|
||||
server.AddSecurityScheme(defaultScheme)
|
||||
server.AddSecurityScheme(toolScheme)
|
||||
|
||||
// Test setting default security directly on server
|
||||
server.SetDefaultDownstreamSecurity(SecurityRequirement{
|
||||
ID: "DefaultAuth",
|
||||
Passthrough: false,
|
||||
})
|
||||
server.SetDefaultUpstreamSecurity(SecurityRequirement{
|
||||
ID: "DefaultAuth",
|
||||
})
|
||||
|
||||
// Verify default security settings
|
||||
retrievedDownstream := server.GetDefaultDownstreamSecurity()
|
||||
assert.Equal(t, "DefaultAuth", retrievedDownstream.ID)
|
||||
assert.False(t, retrievedDownstream.Passthrough)
|
||||
|
||||
retrievedUpstream := server.GetDefaultUpstreamSecurity()
|
||||
assert.Equal(t, "DefaultAuth", retrievedUpstream.ID)
|
||||
|
||||
t.Logf("REST server default security configuration test completed successfully")
|
||||
}
|
||||
|
||||
// TestRestServerSecurityFallback tests the fallback mechanism from tool-level to default security
|
||||
func TestRestServerSecurityFallback(t *testing.T) {
|
||||
server := NewRestMCPServer("test-rest-server")
|
||||
|
||||
// Add security schemes
|
||||
defaultScheme := SecurityScheme{
|
||||
ID: "DefaultAuth",
|
||||
Type: "apiKey",
|
||||
In: "header",
|
||||
Name: "X-Default-Key",
|
||||
DefaultCredential: "default-key",
|
||||
}
|
||||
toolScheme := SecurityScheme{
|
||||
ID: "ToolAuth",
|
||||
Type: "apiKey",
|
||||
In: "header",
|
||||
Name: "X-Tool-Key",
|
||||
DefaultCredential: "tool-key",
|
||||
}
|
||||
server.AddSecurityScheme(defaultScheme)
|
||||
server.AddSecurityScheme(toolScheme)
|
||||
|
||||
// Test tool configuration with tool-level security (should use tool-level, not default)
|
||||
toolConfigWithSecurity := RestTool{
|
||||
Name: "secure_tool",
|
||||
Description: "Tool with its own security",
|
||||
Security: SecurityRequirement{
|
||||
ID: "ToolAuth",
|
||||
Passthrough: true,
|
||||
},
|
||||
RequestTemplate: RestToolRequestTemplate{
|
||||
URL: "http://api.example.com/secure",
|
||||
Method: "GET",
|
||||
Security: SecurityRequirement{
|
||||
ID: "ToolAuth",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Test tool configuration without tool-level security (should fallback to default)
|
||||
toolConfigWithoutSecurity := RestTool{
|
||||
Name: "fallback_tool",
|
||||
Description: "Tool that falls back to default security",
|
||||
// No Security field configured, should use default
|
||||
RequestTemplate: RestToolRequestTemplate{
|
||||
URL: "http://api.example.com/fallback",
|
||||
Method: "GET",
|
||||
// No Security field configured, should use default
|
||||
},
|
||||
}
|
||||
|
||||
// Add tools to server
|
||||
err := server.AddRestTool(toolConfigWithSecurity)
|
||||
assert.NoError(t, err)
|
||||
err = server.AddRestTool(toolConfigWithoutSecurity)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify tools were added
|
||||
tools := server.GetMCPTools()
|
||||
assert.Contains(t, tools, "secure_tool")
|
||||
assert.Contains(t, tools, "fallback_tool")
|
||||
|
||||
t.Logf("REST server security fallback test completed successfully")
|
||||
}
|
||||
874
plugins/wasm-go/pkg/mcp/server/sse_proxy.go
Normal file
874
plugins/wasm-go/pkg/mcp/server/sse_proxy.go
Normal file
@@ -0,0 +1,874 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
// Context keys for SSE proxy state management
|
||||
CtxSSEProxyState = "sse_proxy_state"
|
||||
CtxSSEProxyEndpointURL = "sse_proxy_endpoint_url"
|
||||
CtxSSEProxyBuffer = "sse_proxy_buffer"
|
||||
CtxSSEProxyAuthInfo = "sse_proxy_auth_info"
|
||||
CtxSSEProxyRequestBody = "sse_proxy_request_body"
|
||||
CtxSSEProxyRequestID = "sse_proxy_request_id"
|
||||
CtxSSEProxyFirstChunk = "sse_proxy_first_chunk"
|
||||
CtxSSEProxyJsonRpcID = "sse_proxy_jsonrpc_id"
|
||||
|
||||
// SSE proxy state values
|
||||
SSEStateWaitingEndpoint = "waiting_endpoint"
|
||||
SSEStateWaitingInitResp = "waiting_init_resp"
|
||||
SSEStateWaitingNotifyResp = "waiting_notify_resp"
|
||||
SSEStateWaitingToolResp = "waiting_tool_resp"
|
||||
|
||||
// Buffer size limit: 100MB
|
||||
MaxSSEBufferSize = 100 * 1024 * 1024
|
||||
)
|
||||
|
||||
// injectSSEResponseSuccess injects a successful JSON-RPC response in streaming response body phase
|
||||
func injectSSEResponseSuccess(ctx wrapper.HttpContext, result map[string]any) {
|
||||
// Get JSON-RPC ID from context
|
||||
jsonRpcIDRaw := ctx.GetContext(CtxSSEProxyJsonRpcID)
|
||||
if jsonRpcIDRaw == nil {
|
||||
log.Errorf("JSON-RPC ID not found in context for SSE response")
|
||||
return
|
||||
}
|
||||
jsonRpcID := jsonRpcIDRaw.(utils.JsonRpcID)
|
||||
|
||||
var body []byte
|
||||
var err error
|
||||
if jsonRpcID.IsString {
|
||||
body, err = json.Marshal(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": jsonRpcID.StringValue,
|
||||
"result": result,
|
||||
})
|
||||
} else {
|
||||
body, err = json.Marshal(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": jsonRpcID.IntValue,
|
||||
"result": result,
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal JSON-RPC success response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
proxywasm.InjectEncodedDataToFilterChain(body, true)
|
||||
}
|
||||
|
||||
// injectSSEResponseError injects an error JSON-RPC response in streaming response body phase
|
||||
func injectSSEResponseError(ctx wrapper.HttpContext, err error, errorCode int) {
|
||||
// Get JSON-RPC ID from context
|
||||
jsonRpcIDRaw := ctx.GetContext(CtxSSEProxyJsonRpcID)
|
||||
if jsonRpcIDRaw == nil {
|
||||
log.Errorf("JSON-RPC ID not found in context for SSE error response")
|
||||
return
|
||||
}
|
||||
jsonRpcID := jsonRpcIDRaw.(utils.JsonRpcID)
|
||||
|
||||
var body []byte
|
||||
var marshalErr error
|
||||
if jsonRpcID.IsString {
|
||||
body, marshalErr = json.Marshal(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": jsonRpcID.StringValue,
|
||||
"error": map[string]interface{}{
|
||||
"code": errorCode,
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
} else {
|
||||
body, marshalErr = json.Marshal(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": jsonRpcID.IntValue,
|
||||
"error": map[string]interface{}{
|
||||
"code": errorCode,
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if marshalErr != nil {
|
||||
log.Errorf("Failed to marshal JSON-RPC error response: %v", marshalErr)
|
||||
return
|
||||
}
|
||||
|
||||
proxywasm.InjectEncodedDataToFilterChain(body, true)
|
||||
}
|
||||
|
||||
// SSEMessage represents a parsed SSE message
|
||||
type SSEMessage struct {
|
||||
Event string
|
||||
Data string
|
||||
ID string
|
||||
}
|
||||
|
||||
// ParseSSEMessage parses SSE format data and returns complete messages
|
||||
// Returns the parsed message and the remaining unparsed data
|
||||
func ParseSSEMessage(data []byte) (*SSEMessage, []byte, error) {
|
||||
scanner := bufio.NewScanner(bytes.NewReader(data))
|
||||
// Set max token size to 32MB to handle large messages
|
||||
maxTokenSize := 32 * 1024 * 1024 // 32MB
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxTokenSize)
|
||||
msg := &SSEMessage{}
|
||||
lineCount := 0
|
||||
lastPos := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
lineCount++
|
||||
lastPos += len(line) + 1 // +1 for newline
|
||||
|
||||
// Empty line indicates end of message
|
||||
if strings.TrimSpace(line) == "" {
|
||||
if msg.Event != "" || msg.Data != "" || msg.ID != "" {
|
||||
// Found a complete message
|
||||
return msg, data[lastPos:], nil
|
||||
}
|
||||
// Empty message, continue
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip comment lines (lines starting with ':')
|
||||
if strings.HasPrefix(line, ":") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse field
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
field := parts[0]
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
switch field {
|
||||
case "event":
|
||||
msg.Event = value
|
||||
case "data":
|
||||
if msg.Data != "" {
|
||||
msg.Data += "\n" + value
|
||||
} else {
|
||||
msg.Data = value
|
||||
}
|
||||
case "id":
|
||||
msg.ID = value
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if errors.Is(err, bufio.ErrTooLong) {
|
||||
return nil, nil, fmt.Errorf("SSE message line exceeds maximum token size (32MB): %w", err)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("error scanning SSE data: %v", err)
|
||||
}
|
||||
|
||||
// No complete message found, return all data as remaining
|
||||
return nil, data, nil
|
||||
}
|
||||
|
||||
// ExtractEndpointURL extracts the endpoint URL from an SSE endpoint message
|
||||
// It handles two cases:
|
||||
// 1. endpointData is a full URL (e.g., http://example.com/sse) - return as-is
|
||||
// 2. endpointData is a path - if baseURL has scheme and host, combine them; otherwise return the path as-is
|
||||
func ExtractEndpointURL(endpointData string, baseURL string) (string, error) {
|
||||
// Case 1: endpointData is a full URL
|
||||
if strings.HasPrefix(endpointData, "http://") || strings.HasPrefix(endpointData, "https://") {
|
||||
return endpointData, nil
|
||||
}
|
||||
|
||||
// endpointData is a path
|
||||
parsedBase, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse base URL: %v", err)
|
||||
}
|
||||
|
||||
// Case 2: baseURL has scheme and host, combine them
|
||||
if parsedBase.Scheme != "" && parsedBase.Host != "" {
|
||||
// Combine scheme, host, and the new path
|
||||
// Ensure endpointData starts with "/"
|
||||
if !strings.HasPrefix(endpointData, "/") {
|
||||
endpointData = "/" + endpointData
|
||||
}
|
||||
result := parsedBase.Scheme + "://" + parsedBase.Host + endpointData
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Case 3: baseURL is also just a path, return endpointData as-is
|
||||
return endpointData, nil
|
||||
}
|
||||
|
||||
// sendSSEInitialize sends the initialize request for SSE protocol
|
||||
func sendSSEInitialize(ctx wrapper.HttpContext, endpointURL string, authInfo *ProxyAuthInfo, proxyServer *McpProxyServer) error {
|
||||
initRequest := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": map[string]interface{}{
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": map[string]interface{}{
|
||||
"roots": map[string]interface{}{
|
||||
"listChanged": true,
|
||||
},
|
||||
"sampling": map[string]interface{}{},
|
||||
"elicitation": map[string]interface{}{},
|
||||
},
|
||||
"clientInfo": map[string]interface{}{
|
||||
"name": "Higress-mcp-proxy",
|
||||
"title": "Higress MCP Proxy",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(initRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal initialize request: %v", err)
|
||||
}
|
||||
|
||||
// Copy headers from current request (now supported in response phase by Envoy)
|
||||
finalHeaders := copyHeadersForSSERequest(ctx)
|
||||
|
||||
// Override required headers for SSE initialize
|
||||
ensureHeader(&finalHeaders, "Content-Type", "application/json")
|
||||
|
||||
// Apply authentication to headers and URL
|
||||
finalURL := endpointURL
|
||||
if authInfo != nil && authInfo.SecuritySchemeID != "" {
|
||||
modifiedURL, err := applyProxyAuthenticationForSSE(proxyServer, authInfo.SecuritySchemeID, authInfo.PassthroughCredential, &finalHeaders, endpointURL)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to apply authentication for SSE initialize: %v", err)
|
||||
} else {
|
||||
finalURL = modifiedURL
|
||||
}
|
||||
}
|
||||
|
||||
// Note: headers are already copied from the current request (which has server-level headers applied)
|
||||
// via copyHeadersForSSERequest, so no need to apply them again
|
||||
|
||||
// Store state for tracking
|
||||
ctx.SetContext(CtxSSEProxyState, SSEStateWaitingInitResp)
|
||||
ctx.SetContext(CtxSSEProxyRequestID, 1)
|
||||
|
||||
// Use RouteCluster client to send initialize request
|
||||
client := wrapper.NewClusterClient(wrapper.RouteCluster{})
|
||||
timeout := uint32(proxyServer.GetTimeout())
|
||||
if timeout == 0 {
|
||||
timeout = 5000 // Default 5 seconds
|
||||
}
|
||||
|
||||
return client.Post(finalURL, finalHeaders, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode != 200 && statusCode != 202 {
|
||||
log.Errorf("SSE initialize request failed with status %d: %s", statusCode, string(responseBody))
|
||||
// At this point, we're in streaming response phase, must use injectSSEResponseError
|
||||
injectSSEResponseError(ctx, fmt.Errorf("SSE initialize failed with status %d", statusCode), utils.ErrInternalError)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("SSE initialize request sent successfully")
|
||||
// The response will be received through SSE channel and processed in streaming response handler
|
||||
// State has already been set to SSEStateWaitingInitResp before this POST request
|
||||
// No need to change state here
|
||||
}, timeout)
|
||||
}
|
||||
|
||||
// sendSSENotification sends the notifications/initialized message for SSE protocol
|
||||
func sendSSENotification(ctx wrapper.HttpContext, endpointURL string, authInfo *ProxyAuthInfo, proxyServer *McpProxyServer) error {
|
||||
notification := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized",
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(notification)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal notification: %v", err)
|
||||
}
|
||||
|
||||
// Copy headers from current request (now supported in response phase by Envoy)
|
||||
finalHeaders := copyHeadersForSSERequest(ctx)
|
||||
|
||||
// Override required headers for SSE notification
|
||||
ensureHeader(&finalHeaders, "Content-Type", "application/json")
|
||||
|
||||
// Apply authentication to headers and URL
|
||||
finalURL := endpointURL
|
||||
if authInfo != nil && authInfo.SecuritySchemeID != "" {
|
||||
modifiedURL, err := applyProxyAuthenticationForSSE(proxyServer, authInfo.SecuritySchemeID, authInfo.PassthroughCredential, &finalHeaders, endpointURL)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to apply authentication for SSE notification: %v", err)
|
||||
} else {
|
||||
finalURL = modifiedURL
|
||||
}
|
||||
}
|
||||
|
||||
// Note: headers are already copied from the current request (which has server-level headers applied)
|
||||
// via copyHeadersForSSERequest, so no need to apply them again
|
||||
|
||||
// Store state for tracking
|
||||
ctx.SetContext(CtxSSEProxyState, SSEStateWaitingNotifyResp)
|
||||
|
||||
// Use RouteCluster client to send notification
|
||||
client := wrapper.NewClusterClient(wrapper.RouteCluster{})
|
||||
timeout := uint32(proxyServer.GetTimeout())
|
||||
if timeout == 0 {
|
||||
timeout = 5000 // Default 5 seconds
|
||||
}
|
||||
|
||||
return client.Post(finalURL, finalHeaders, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode != 200 && statusCode != 202 {
|
||||
log.Warnf("SSE notification request failed with status %d: %s", statusCode, string(responseBody))
|
||||
// Even if notification fails, we should try to continue
|
||||
// Some servers may not strictly require notification success
|
||||
}
|
||||
|
||||
log.Debugf("SSE notification sent successfully")
|
||||
|
||||
// Now we can send the actual tool request
|
||||
// Get stored context
|
||||
endpointURLRaw := ctx.GetContext(CtxSSEProxyEndpointURL)
|
||||
authInfoRaw := ctx.GetContext(CtxSSEProxyAuthInfo)
|
||||
proxyServerRaw := ctx.GetContext("mcp_proxy_server")
|
||||
requestBodyRaw := ctx.GetContext(CtxSSEProxyRequestBody)
|
||||
|
||||
if endpointURLRaw == nil || proxyServerRaw == nil || requestBodyRaw == nil {
|
||||
log.Errorf("Missing context for sending tool request")
|
||||
// At this point, we're in streaming response phase, must use injectSSEResponseError
|
||||
injectSSEResponseError(ctx, fmt.Errorf("internal error: missing context"), utils.ErrInternalError)
|
||||
return
|
||||
}
|
||||
|
||||
endpointURL := endpointURLRaw.(string)
|
||||
proxyServer := proxyServerRaw.(*McpProxyServer)
|
||||
requestBody := requestBodyRaw.([]byte)
|
||||
|
||||
var authInfo *ProxyAuthInfo
|
||||
if authInfoRaw != nil {
|
||||
authInfo = authInfoRaw.(*ProxyAuthInfo)
|
||||
}
|
||||
|
||||
// Parse to get request ID
|
||||
reqID := gjson.GetBytes(requestBody, "id").Int()
|
||||
if err := sendSSEToolRequest(ctx, endpointURL, authInfo, proxyServer, requestBody, int(reqID)); err != nil {
|
||||
log.Errorf("Failed to send SSE tool request: %v", err)
|
||||
injectSSEResponseError(ctx, err, utils.ErrInternalError)
|
||||
}
|
||||
}, timeout)
|
||||
}
|
||||
|
||||
// sendSSEToolRequest sends the tools/list or tools/call request for SSE protocol
|
||||
func sendSSEToolRequest(ctx wrapper.HttpContext, endpointURL string, authInfo *ProxyAuthInfo, proxyServer *McpProxyServer, requestBody []byte, requestID int) error {
|
||||
// Copy headers from current request (now supported in response phase by Envoy)
|
||||
finalHeaders := copyHeadersForSSERequest(ctx)
|
||||
|
||||
// Override required headers for SSE tool request
|
||||
ensureHeader(&finalHeaders, "Content-Type", "application/json")
|
||||
|
||||
// Apply authentication to headers and URL
|
||||
finalURL := endpointURL
|
||||
if authInfo != nil && authInfo.SecuritySchemeID != "" {
|
||||
modifiedURL, err := applyProxyAuthenticationForSSE(proxyServer, authInfo.SecuritySchemeID, authInfo.PassthroughCredential, &finalHeaders, endpointURL)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to apply authentication for SSE tool request: %v", err)
|
||||
} else {
|
||||
finalURL = modifiedURL
|
||||
}
|
||||
}
|
||||
|
||||
// Note: headers are already copied from the current request (which has server-level headers applied)
|
||||
// via copyHeadersForSSERequest, so no need to apply them again
|
||||
|
||||
// Store state for tracking
|
||||
ctx.SetContext(CtxSSEProxyState, SSEStateWaitingToolResp)
|
||||
ctx.SetContext(CtxSSEProxyRequestID, requestID)
|
||||
|
||||
// Use RouteCluster client to send tool request
|
||||
client := wrapper.NewClusterClient(wrapper.RouteCluster{})
|
||||
timeout := uint32(proxyServer.GetTimeout())
|
||||
if timeout == 0 {
|
||||
timeout = 5000 // Default 5 seconds
|
||||
}
|
||||
|
||||
return client.Post(finalURL, finalHeaders, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode != 200 && statusCode != 202 {
|
||||
log.Errorf("SSE tool request failed with status %d: %s", statusCode, string(responseBody))
|
||||
// At this point, we're in streaming response phase, must use injectSSEResponseError
|
||||
injectSSEResponseError(ctx, fmt.Errorf("SSE tool request failed with status %d", statusCode), utils.ErrInternalError)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("SSE tool request sent successfully")
|
||||
// The response will be received through SSE channel and processed in streaming response handler
|
||||
}, timeout)
|
||||
}
|
||||
|
||||
// copyHeadersForSSERequest copies headers from current request for SSE RouteCluster calls
|
||||
// This leverages Envoy's new capability to access request headers in response phase
|
||||
func copyHeadersForSSERequest(ctx wrapper.HttpContext) [][2]string {
|
||||
headers := make([][2]string, 0)
|
||||
|
||||
// Headers to skip
|
||||
skipHeaders := map[string]bool{
|
||||
"content-length": true, // Will be set by the client
|
||||
"transfer-encoding": true, // Will be set by the client
|
||||
"accept": true, // Will be set explicitly for SSE requests
|
||||
":path": true, // Pseudo-header, not needed
|
||||
":method": true, // Pseudo-header, not needed
|
||||
":scheme": true, // Pseudo-header, not needed
|
||||
":authority": true, // Pseudo-header, not needed
|
||||
}
|
||||
|
||||
// Get all request headers (now supported in response phase by Envoy)
|
||||
headerMap, err := proxywasm.GetHttpRequestHeaders()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get request headers in response phase: %v", err)
|
||||
// Return minimal headers
|
||||
return [][2]string{}
|
||||
}
|
||||
|
||||
// Copy headers, skipping unwanted ones
|
||||
for _, header := range headerMap {
|
||||
headerName := strings.ToLower(header[0])
|
||||
if skipHeaders[headerName] {
|
||||
continue
|
||||
}
|
||||
headers = append(headers, header)
|
||||
}
|
||||
|
||||
log.Debugf("Copied %d headers from request in response phase for SSE", len(headers))
|
||||
return headers
|
||||
}
|
||||
|
||||
// applyProxyAuthenticationForSSE applies authentication for SSE proxy requests
|
||||
func applyProxyAuthenticationForSSE(server *McpProxyServer, schemeID string, passthroughCredential string, headers *[][2]string, targetURL string) (string, error) {
|
||||
// Parse the target URL
|
||||
parsedURL, err := url.Parse(targetURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse target URL: %v", err)
|
||||
}
|
||||
|
||||
// Create authentication context
|
||||
authCtx := AuthRequestContext{
|
||||
Method: "POST",
|
||||
Headers: *headers,
|
||||
ParsedURL: parsedURL,
|
||||
RequestBody: []byte{},
|
||||
PassthroughCredential: passthroughCredential,
|
||||
}
|
||||
|
||||
// Create security config
|
||||
securityConfig := SecurityRequirement{
|
||||
ID: schemeID,
|
||||
Credential: "",
|
||||
Passthrough: passthroughCredential != "",
|
||||
}
|
||||
|
||||
// Apply authentication
|
||||
err = ApplySecurity(securityConfig, server, &authCtx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Update headers
|
||||
*headers = authCtx.Headers
|
||||
|
||||
// Reconstruct URL
|
||||
u := authCtx.ParsedURL
|
||||
encodedPath := u.EscapedPath()
|
||||
var urlStr string
|
||||
if u.Scheme != "" && u.Host != "" {
|
||||
urlStr = u.Scheme + "://" + u.Host + encodedPath
|
||||
} else {
|
||||
urlStr = "/" + strings.TrimPrefix(encodedPath, "/")
|
||||
}
|
||||
if u.RawQuery != "" {
|
||||
urlStr += "?" + u.RawQuery
|
||||
}
|
||||
if u.Fragment != "" {
|
||||
urlStr += "#" + u.Fragment
|
||||
}
|
||||
|
||||
return urlStr, nil
|
||||
}
|
||||
|
||||
// handleSSEStreamingResponse handles the streaming SSE response
|
||||
func handleSSEStreamingResponse(ctx wrapper.HttpContext, config McpServerConfig, data []byte, endOfStream bool) []byte {
|
||||
// Get the first chunk flag
|
||||
isFirstChunk := ctx.GetBoolContext(CtxSSEProxyFirstChunk, true)
|
||||
if isFirstChunk {
|
||||
ctx.SetContext(CtxSSEProxyFirstChunk, false)
|
||||
}
|
||||
log.Debugf("Handling chunk of SSE response, data: %q", string(data))
|
||||
// On first chunk, validate content-type and modify headers
|
||||
if isFirstChunk {
|
||||
// Validate that backend returned text/event-stream
|
||||
contentType, err := proxywasm.GetHttpResponseHeader("content-type")
|
||||
if err != nil || !strings.Contains(strings.ToLower(contentType), "text/event-stream") {
|
||||
log.Errorf("Backend did not return text/event-stream content-type, got: %s", contentType)
|
||||
// Return JSON-RPC error
|
||||
injectSSEResponseError(ctx, fmt.Errorf("invalid content-type, expected text/event-stream but got: %s", contentType), utils.ErrInternalError)
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
// Remove content-length and modify content-type
|
||||
proxywasm.RemoveHttpResponseHeader("content-length")
|
||||
proxywasm.ReplaceHttpResponseHeader("content-type", "application/json; charset=utf-8")
|
||||
proxywasm.ReplaceHttpResponseHeader(":status", "200")
|
||||
}
|
||||
|
||||
// Get or initialize buffer
|
||||
var buffer []byte
|
||||
if bufferRaw := ctx.GetContext(CtxSSEProxyBuffer); bufferRaw != nil {
|
||||
buffer = bufferRaw.([]byte)
|
||||
}
|
||||
|
||||
// Append new data to buffer
|
||||
buffer = append(buffer, data...)
|
||||
|
||||
// Check buffer size limit
|
||||
if len(buffer) > MaxSSEBufferSize {
|
||||
log.Errorf("SSE buffer exceeded maximum size of %d bytes", MaxSSEBufferSize)
|
||||
injectSSEResponseError(ctx, errors.New("response too large"), utils.ErrInternalError)
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
// Store buffer back
|
||||
ctx.SetContext(CtxSSEProxyBuffer, buffer)
|
||||
|
||||
// Get current state
|
||||
state := ctx.GetContext(CtxSSEProxyState)
|
||||
if state == nil {
|
||||
state = SSEStateWaitingEndpoint
|
||||
ctx.SetContext(CtxSSEProxyState, state)
|
||||
}
|
||||
|
||||
log.Debugf("SSE proxy state: %s, now buffering data: %q", state.(string), string(buffer))
|
||||
|
||||
// Process based on state
|
||||
switch state.(string) {
|
||||
case SSEStateWaitingEndpoint:
|
||||
return handleWaitingEndpoint(ctx, config, &buffer)
|
||||
|
||||
case SSEStateWaitingInitResp:
|
||||
return handleWaitingInitResp(ctx, config, &buffer)
|
||||
|
||||
case SSEStateWaitingNotifyResp:
|
||||
return handleWaitingNotifyResp(ctx, config, &buffer)
|
||||
|
||||
case SSEStateWaitingToolResp:
|
||||
return handleWaitingToolResp(ctx, config, &buffer)
|
||||
|
||||
default:
|
||||
log.Warnf("Unknown SSE proxy state: %v", state)
|
||||
return []byte{}
|
||||
}
|
||||
}
|
||||
|
||||
// handleWaitingEndpoint processes SSE messages waiting for endpoint message
|
||||
func handleWaitingEndpoint(ctx wrapper.HttpContext, config McpServerConfig, buffer *[]byte) []byte {
|
||||
for {
|
||||
msg, remaining, err := ParseSSEMessage(*buffer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse SSE message: %v", err)
|
||||
injectSSEResponseError(ctx, err, utils.ErrInternalError)
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
if msg == nil {
|
||||
// No complete message yet
|
||||
*buffer = remaining
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
// Update buffer
|
||||
*buffer = remaining
|
||||
ctx.SetContext(CtxSSEProxyBuffer, *buffer)
|
||||
|
||||
// Check for endpoint message
|
||||
if msg.Event == "endpoint" {
|
||||
// Extract and store endpoint URL
|
||||
proxyServerRaw := ctx.GetContext("mcp_proxy_server")
|
||||
if proxyServerRaw == nil {
|
||||
log.Errorf("mcp_proxy_server not found in context")
|
||||
injectSSEResponseError(ctx, errors.New("internal error"), utils.ErrInternalError)
|
||||
return []byte{}
|
||||
}
|
||||
proxyServer := proxyServerRaw.(*McpProxyServer)
|
||||
|
||||
endpointURL, err := ExtractEndpointURL(msg.Data, proxyServer.GetMcpServerURL())
|
||||
if err != nil {
|
||||
log.Errorf("Failed to extract endpoint URL: %v", err)
|
||||
injectSSEResponseError(ctx, err, utils.ErrInternalError)
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
log.Infof("Received SSE endpoint URL: %s", endpointURL)
|
||||
ctx.SetContext(CtxSSEProxyEndpointURL, endpointURL)
|
||||
|
||||
// Get stored auth info
|
||||
authInfoRaw := ctx.GetContext(CtxSSEProxyAuthInfo)
|
||||
|
||||
var authInfo *ProxyAuthInfo
|
||||
if authInfoRaw != nil {
|
||||
authInfo = authInfoRaw.(*ProxyAuthInfo)
|
||||
}
|
||||
|
||||
// Send initialize request
|
||||
if err := sendSSEInitialize(ctx, endpointURL, authInfo, proxyServer); err != nil {
|
||||
log.Errorf("Failed to send SSE initialize: %v", err)
|
||||
injectSSEResponseError(ctx, err, utils.ErrInternalError)
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
// State has been changed to SSEStateWaitingInitResp in sendSSEInitialize
|
||||
// Return immediately to allow next chunk to be processed in the new state
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
// Skip other message types (like ping) while waiting for endpoint
|
||||
// Continue to process next message in buffer
|
||||
log.Debugf("Skipping SSE message with event '%s' while waiting for endpoint", msg.Event)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// handleWaitingInitResp processes SSE messages waiting for initialize response
|
||||
func handleWaitingInitResp(ctx wrapper.HttpContext, config McpServerConfig, buffer *[]byte) []byte {
|
||||
requestID := ctx.GetContext(CtxSSEProxyRequestID)
|
||||
if requestID == nil {
|
||||
requestID = 1
|
||||
}
|
||||
|
||||
for {
|
||||
msg, remaining, err := ParseSSEMessage(*buffer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse SSE message: %v", err)
|
||||
injectSSEResponseError(ctx, err, utils.ErrInternalError)
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
if msg == nil {
|
||||
// No complete message yet
|
||||
*buffer = remaining
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
// Update buffer
|
||||
*buffer = remaining
|
||||
ctx.SetContext(CtxSSEProxyBuffer, *buffer)
|
||||
|
||||
// Check for message event
|
||||
if msg.Event == "message" {
|
||||
// Parse JSON-RPC response
|
||||
var jsonRpcResp map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(msg.Data), &jsonRpcResp); err != nil {
|
||||
log.Errorf("Failed to parse JSON-RPC response: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this is the initialize response
|
||||
respID := jsonRpcResp["id"]
|
||||
if respID != nil {
|
||||
var idMatch bool
|
||||
switch v := respID.(type) {
|
||||
case float64:
|
||||
idMatch = int(v) == requestID.(int)
|
||||
case int:
|
||||
idMatch = v == requestID.(int)
|
||||
}
|
||||
|
||||
if idMatch {
|
||||
// Check for errors
|
||||
if errorObj, hasError := jsonRpcResp["error"]; hasError {
|
||||
log.Errorf("Backend initialize error: %v", errorObj)
|
||||
injectSSEResponseError(ctx, fmt.Errorf("backend initialize failed"), utils.ErrInternalError)
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
log.Debugf("Received initialize response, sending notification")
|
||||
|
||||
// Get endpoint URL and auth info
|
||||
endpointURL := ctx.GetContext(CtxSSEProxyEndpointURL).(string)
|
||||
authInfoRaw := ctx.GetContext(CtxSSEProxyAuthInfo)
|
||||
proxyServerRaw := ctx.GetContext("mcp_proxy_server")
|
||||
|
||||
var authInfo *ProxyAuthInfo
|
||||
if authInfoRaw != nil {
|
||||
authInfo = authInfoRaw.(*ProxyAuthInfo)
|
||||
}
|
||||
|
||||
proxyServer := proxyServerRaw.(*McpProxyServer)
|
||||
|
||||
// Send notification
|
||||
// The notification callback will send the tool request after notification succeeds
|
||||
if err := sendSSENotification(ctx, endpointURL, authInfo, proxyServer); err != nil {
|
||||
log.Errorf("Failed to send SSE notification: %v", err)
|
||||
injectSSEResponseError(ctx, err, utils.ErrInternalError)
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
// State has been changed to SSEStateWaitingNotifyResp in sendSSENotification
|
||||
// The tool request will be sent in the notification callback
|
||||
// Return immediately to allow next chunk to be processed in the new state
|
||||
return []byte{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Skip other message types (like ping) while waiting for init response
|
||||
// Continue to process next message in buffer
|
||||
log.Debugf("Skipping SSE message with event '%s' while waiting for init response", msg.Event)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// handleWaitingNotifyResp processes SSE messages waiting for notification response
|
||||
func handleWaitingNotifyResp(ctx wrapper.HttpContext, config McpServerConfig, buffer *[]byte) []byte {
|
||||
// For notifications, we don't expect a response in SSE channel
|
||||
// Just continue to send tool request
|
||||
// This state should be very brief
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
// handleWaitingToolResp processes SSE messages waiting for tool response
|
||||
func handleWaitingToolResp(ctx wrapper.HttpContext, config McpServerConfig, buffer *[]byte) []byte {
|
||||
requestID := ctx.GetContext(CtxSSEProxyRequestID)
|
||||
if requestID == nil {
|
||||
log.Errorf("Request ID not found in context")
|
||||
injectSSEResponseError(ctx, errors.New("internal error"), utils.ErrInternalError)
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
for {
|
||||
msg, remaining, err := ParseSSEMessage(*buffer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse SSE message: %v", err)
|
||||
injectSSEResponseError(ctx, err, utils.ErrInternalError)
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
if msg == nil {
|
||||
// No complete message yet
|
||||
*buffer = remaining
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
// Update buffer
|
||||
*buffer = remaining
|
||||
ctx.SetContext(CtxSSEProxyBuffer, *buffer)
|
||||
|
||||
// Check for message event
|
||||
if msg.Event == "message" {
|
||||
// Parse JSON-RPC response
|
||||
var jsonRpcResp map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(msg.Data), &jsonRpcResp); err != nil {
|
||||
log.Errorf("Failed to parse JSON-RPC response: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this is the expected response
|
||||
respID := jsonRpcResp["id"]
|
||||
if respID != nil {
|
||||
var idMatch bool
|
||||
switch v := respID.(type) {
|
||||
case float64:
|
||||
idMatch = int(v) == requestID.(int)
|
||||
case int:
|
||||
idMatch = v == requestID.(int)
|
||||
}
|
||||
|
||||
if idMatch {
|
||||
// Check for errors
|
||||
if errorObj, hasError := jsonRpcResp["error"]; hasError {
|
||||
log.Errorf("Backend tool error: %v", errorObj)
|
||||
injectSSEResponseError(ctx, fmt.Errorf("backend tool call failed"), utils.ErrInternalError)
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
// Extract result and return to client
|
||||
if result, hasResult := jsonRpcResp["result"]; hasResult {
|
||||
if resultMap, ok := result.(map[string]interface{}); ok {
|
||||
// Apply allowTools filtering if this is a tools/list response
|
||||
filteredResult := resultMap
|
||||
if _, hasTools := resultMap["tools"]; hasTools {
|
||||
// Get pre-computed effective allowTools from context
|
||||
if allowToolsCtx := ctx.GetContext("mcp_proxy_effective_allow_tools"); allowToolsCtx != nil {
|
||||
if effectiveAllowTools, ok := allowToolsCtx.(*map[string]struct{}); ok && effectiveAllowTools != nil {
|
||||
// Apply filtering
|
||||
if tools, hasToolsArray := resultMap["tools"]; hasToolsArray {
|
||||
if toolsArray, ok := tools.([]interface{}); ok {
|
||||
filteredTools := make([]interface{}, 0)
|
||||
for _, tool := range toolsArray {
|
||||
if toolMap, ok := tool.(map[string]interface{}); ok {
|
||||
if name, hasName := toolMap["name"]; hasName {
|
||||
if toolName, ok := name.(string); ok {
|
||||
if _, allow := (*effectiveAllowTools)[toolName]; allow {
|
||||
filteredTools = append(filteredTools, tool)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Create filtered result
|
||||
filteredResult = make(map[string]interface{})
|
||||
for k, v := range resultMap {
|
||||
filteredResult[k] = v
|
||||
}
|
||||
filteredResult["tools"] = filteredTools
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
injectSSEResponseSuccess(ctx, filteredResult)
|
||||
// Clear buffer as we've processed the response
|
||||
*buffer = []byte{}
|
||||
ctx.SetContext(CtxSSEProxyBuffer, *buffer)
|
||||
return []byte{}
|
||||
}
|
||||
}
|
||||
|
||||
log.Errorf("Invalid tool response format")
|
||||
injectSSEResponseError(ctx, errors.New("invalid response format"), utils.ErrInternalError)
|
||||
return []byte{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Skip other message types (like ping) while waiting for tool response
|
||||
// Continue to process next message in buffer
|
||||
log.Debugf("Skipping SSE message with event '%s' while waiting for tool response", msg.Event)
|
||||
continue
|
||||
}
|
||||
}
|
||||
297
plugins/wasm-go/pkg/mcp/server/sse_proxy_test.go
Normal file
297
plugins/wasm-go/pkg/mcp/server/sse_proxy_test.go
Normal file
@@ -0,0 +1,297 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestParseSSEMessage tests SSE message parsing
|
||||
func TestParseSSEMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
wantEvent string
|
||||
wantData string
|
||||
wantID string
|
||||
shouldParse bool
|
||||
}{
|
||||
{
|
||||
name: "endpoint message",
|
||||
input: []byte(`event: endpoint
|
||||
data: /messages/?session_id=test123
|
||||
|
||||
`),
|
||||
wantEvent: "endpoint",
|
||||
wantData: "/messages/?session_id=test123",
|
||||
shouldParse: true,
|
||||
},
|
||||
{
|
||||
name: "message with JSON data",
|
||||
input: []byte(`event: message
|
||||
data: {"jsonrpc":"2.0","id":1,"result":{"test":"value"}}
|
||||
|
||||
`),
|
||||
wantEvent: "message",
|
||||
wantData: `{"jsonrpc":"2.0","id":1,"result":{"test":"value"}}`,
|
||||
shouldParse: true,
|
||||
},
|
||||
{
|
||||
name: "incomplete message",
|
||||
input: []byte(`event: message
|
||||
data: {"jsonrpc":"2.0"`),
|
||||
shouldParse: false,
|
||||
},
|
||||
{
|
||||
name: "message with id",
|
||||
input: []byte(`id: 123
|
||||
event: message
|
||||
data: test data
|
||||
|
||||
`),
|
||||
wantEvent: "message",
|
||||
wantData: "test data",
|
||||
wantID: "123",
|
||||
shouldParse: true,
|
||||
},
|
||||
{
|
||||
name: "comment line ignored",
|
||||
input: []byte(`: this is a comment
|
||||
event: message
|
||||
data: test data
|
||||
|
||||
`),
|
||||
wantEvent: "message",
|
||||
wantData: "test data",
|
||||
shouldParse: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msg, remaining, err := ParseSSEMessage(tt.input)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("parseSSEMessage() error = %v", err)
|
||||
}
|
||||
|
||||
if tt.shouldParse {
|
||||
if msg == nil {
|
||||
t.Errorf("parseSSEMessage() expected message but got nil")
|
||||
return
|
||||
}
|
||||
if msg.Event != tt.wantEvent {
|
||||
t.Errorf("parseSSEMessage() Event = %v, want %v", msg.Event, tt.wantEvent)
|
||||
}
|
||||
if msg.Data != tt.wantData {
|
||||
t.Errorf("parseSSEMessage() Data = %v, want %v", msg.Data, tt.wantData)
|
||||
}
|
||||
if msg.ID != tt.wantID {
|
||||
t.Errorf("parseSSEMessage() ID = %v, want %v", msg.ID, tt.wantID)
|
||||
}
|
||||
if len(remaining) != 0 {
|
||||
t.Errorf("parseSSEMessage() expected no remaining bytes, got %d bytes", len(remaining))
|
||||
}
|
||||
} else {
|
||||
if msg != nil {
|
||||
t.Errorf("parseSSEMessage() expected no message but got %v", msg)
|
||||
}
|
||||
if len(remaining) != len(tt.input) {
|
||||
t.Errorf("parseSSEMessage() expected all data as remaining, got %d bytes instead of %d", len(remaining), len(tt.input))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractEndpointURL tests endpoint URL extraction
|
||||
func TestExtractEndpointURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
endpointData string
|
||||
baseURL string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "full URL",
|
||||
endpointData: "http://example.com/messages?session=123",
|
||||
baseURL: "http://backend.com/mcp",
|
||||
want: "http://example.com/messages?session=123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "path only",
|
||||
endpointData: "/messages/?session_id=abc",
|
||||
baseURL: "http://backend.com/mcp",
|
||||
want: "http://backend.com/messages/?session_id=abc",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "https base URL",
|
||||
endpointData: "/sse/endpoint",
|
||||
baseURL: "https://secure.backend.com:8443/api",
|
||||
want: "https://secure.backend.com:8443/sse/endpoint",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "path-only base URL",
|
||||
endpointData: "/messages",
|
||||
baseURL: "/api/v1",
|
||||
want: "/messages",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "path without leading slash",
|
||||
endpointData: "api/v1/messages",
|
||||
baseURL: "http://backend.com",
|
||||
want: "http://backend.com/api/v1/messages",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "path without leading slash with port",
|
||||
endpointData: "sse/endpoint",
|
||||
baseURL: "https://secure.backend.com:8443",
|
||||
want: "https://secure.backend.com:8443/sse/endpoint",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ExtractEndpointURL(tt.endpointData, tt.baseURL)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("extractEndpointURL() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("extractEndpointURL() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTransportProtocolValidation tests transport protocol validation
|
||||
func TestTransportProtocolValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
transport string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid http transport",
|
||||
transport: "http",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "valid sse transport",
|
||||
transport: "sse",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "invalid transport",
|
||||
transport: "websocket",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "empty transport",
|
||||
transport: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
transport := TransportProtocol(tt.transport)
|
||||
isValid := transport == TransportHTTP || transport == TransportSSE
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("TransportProtocol validation = %v, want %v for %s", isValid, tt.wantValid, tt.transport)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMcpProxyServerTransport tests transport getter/setter
|
||||
func TestMcpProxyServerTransport(t *testing.T) {
|
||||
server := NewMcpProxyServer("test-server")
|
||||
|
||||
// Test default transport
|
||||
if server.GetTransport() != "" {
|
||||
t.Errorf("Expected empty default transport, got %v", server.GetTransport())
|
||||
}
|
||||
|
||||
// Test setting HTTP transport
|
||||
server.SetTransport(TransportHTTP)
|
||||
if server.GetTransport() != TransportHTTP {
|
||||
t.Errorf("Expected HTTP transport, got %v", server.GetTransport())
|
||||
}
|
||||
|
||||
// Test setting SSE transport
|
||||
server.SetTransport(TransportSSE)
|
||||
if server.GetTransport() != TransportSSE {
|
||||
t.Errorf("Expected SSE transport, got %v", server.GetTransport())
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEMessageParsing_MultipleMessages tests parsing multiple SSE messages
|
||||
func TestSSEMessageParsing_MultipleMessages(t *testing.T) {
|
||||
data := []byte(`event: endpoint
|
||||
data: /messages/123
|
||||
|
||||
event: message
|
||||
data: {"id":1}
|
||||
|
||||
: comment line
|
||||
event: message
|
||||
data: {"id":2}
|
||||
|
||||
`)
|
||||
|
||||
// First message
|
||||
msg1, remaining, err := ParseSSEMessage(data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse first message: %v", err)
|
||||
}
|
||||
if msg1 == nil || msg1.Event != "endpoint" || msg1.Data != "/messages/123" {
|
||||
t.Errorf("First message incorrect: %+v", msg1)
|
||||
}
|
||||
|
||||
// Second message
|
||||
msg2, remaining, err := ParseSSEMessage(remaining)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse second message: %v", err)
|
||||
}
|
||||
if msg2 == nil || msg2.Event != "message" || msg2.Data != `{"id":1}` {
|
||||
t.Errorf("Second message incorrect: %+v", msg2)
|
||||
}
|
||||
|
||||
// Third message
|
||||
msg3, remaining, err := ParseSSEMessage(remaining)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse third message: %v", err)
|
||||
}
|
||||
if msg3 == nil || msg3.Event != "message" || msg3.Data != `{"id":2}` {
|
||||
t.Errorf("Third message incorrect: %+v", msg3)
|
||||
}
|
||||
|
||||
// Should be no more complete messages
|
||||
msg4, _, err := ParseSSEMessage(remaining)
|
||||
if err != nil {
|
||||
t.Fatalf("Error parsing remaining data: %v", err)
|
||||
}
|
||||
if msg4 != nil {
|
||||
t.Errorf("Expected no more messages, got: %+v", msg4)
|
||||
}
|
||||
}
|
||||
209
plugins/wasm-go/pkg/mcp/utils/json_rpc.go
Normal file
209
plugins/wasm-go/pkg/mcp/utils/json_rpc.go
Normal file
@@ -0,0 +1,209 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/iface"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
pb "github.com/higress-group/wasm-go/pkg/protos"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
const (
|
||||
CtxJsonRpcID = "jsonRpcID"
|
||||
CtxNeedPause = "needPause" // Context key to signal if the handler needs to pause
|
||||
JError = "error"
|
||||
JCode = "code"
|
||||
JMessage = "message"
|
||||
JResult = "result"
|
||||
|
||||
ErrParseError = -32700
|
||||
ErrInvalidRequest = -32600
|
||||
ErrMethodNotFound = -32601
|
||||
ErrInvalidParams = -32602
|
||||
ErrInternalError = -32603
|
||||
)
|
||||
|
||||
// JsonRpcID represents a JSON-RPC ID which can be either a string or a number
|
||||
type JsonRpcID struct {
|
||||
StringValue string
|
||||
IntValue int64
|
||||
IsString bool
|
||||
}
|
||||
|
||||
// NewJsonRpcIDFromGjson creates a JsonRpcID from a gjson.Result
|
||||
func NewJsonRpcIDFromGjson(result gjson.Result) JsonRpcID {
|
||||
if result.Type == gjson.String {
|
||||
return JsonRpcID{
|
||||
StringValue: result.String(),
|
||||
IsString: true,
|
||||
}
|
||||
}
|
||||
return JsonRpcID{
|
||||
IntValue: result.Int(),
|
||||
IsString: false,
|
||||
}
|
||||
}
|
||||
|
||||
type JsonRpcRequestHandler func(context wrapper.HttpContext, id JsonRpcID, method string, params gjson.Result, rawBody []byte) types.Action
|
||||
|
||||
type JsonRpcResponseHandler func(context wrapper.HttpContext, id JsonRpcID, result gjson.Result, error gjson.Result, rawBody []byte) types.Action
|
||||
|
||||
type JsonRpcMethodHandler func(context wrapper.HttpContext, id JsonRpcID, params gjson.Result) error
|
||||
|
||||
type MethodHandlers map[string]JsonRpcMethodHandler
|
||||
|
||||
func makeHttpResponse(ctx wrapper.HttpContext, code uint32, debugInfo string, headers [][2]string, body []byte) {
|
||||
phase := ctx.GetExecutionPhase()
|
||||
if phase < iface.EncodeHeader {
|
||||
proxywasm.SendHttpResponseWithDetail(code, debugInfo, headers, body, -1)
|
||||
return
|
||||
}
|
||||
if debugInfo != "" {
|
||||
log.Infof("response detail info:%s", debugInfo)
|
||||
}
|
||||
proxywasm.RemoveHttpResponseHeader("content-length")
|
||||
proxywasm.ReplaceHttpResponseHeader(":status", strconv.Itoa(int(code)))
|
||||
for _, kv := range headers {
|
||||
proxywasm.ReplaceHttpResponseHeader(kv[0], kv[1])
|
||||
}
|
||||
if phase == iface.EncodeData {
|
||||
proxywasm.ReplaceHttpResponseBody(body)
|
||||
return
|
||||
}
|
||||
// EncodeHeader phase
|
||||
args := &pb.InjectEncodedDataToFilterChainArguments{
|
||||
Body: string(body),
|
||||
Endstream: true,
|
||||
}
|
||||
argsStr, _ := proto.Marshal(args)
|
||||
_, err := proxywasm.CallForeignFunction("inject_encoded_data_to_filter_chain_on_header", argsStr)
|
||||
if err != nil {
|
||||
log.Warnf("call inject_encoded_data_to_filter_chain_on_header failed, err:%v, fallback to send directly", err)
|
||||
proxywasm.SendHttpResponseWithDetail(code, debugInfo, headers, body, -1)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func sendJsonRpcResponse(ctx wrapper.HttpContext, id JsonRpcID, extras map[string]any, debugInfo string) {
|
||||
body := []byte(`{"jsonrpc": "2.0"}`)
|
||||
if id.IsString {
|
||||
body, _ = sjson.SetBytes(body, "id", id.StringValue)
|
||||
} else {
|
||||
body, _ = sjson.SetBytes(body, "id", id.IntValue)
|
||||
}
|
||||
for key, value := range extras {
|
||||
body, _ = sjson.SetBytes(body, key, value)
|
||||
}
|
||||
makeHttpResponse(ctx, 200, debugInfo, [][2]string{{"Content-Type", "application/json; charset=utf-8"}}, body)
|
||||
}
|
||||
|
||||
func OnJsonRpcResponseSuccess(ctx wrapper.HttpContext, result map[string]any, debugInfo ...string) {
|
||||
var (
|
||||
id JsonRpcID
|
||||
ok bool
|
||||
)
|
||||
idRaw := ctx.GetContext(CtxJsonRpcID)
|
||||
if id, ok = idRaw.(JsonRpcID); !ok {
|
||||
makeHttpResponse(ctx, 500, "not_found_json_rpc_id", nil, []byte("not found json rpc id"))
|
||||
return
|
||||
}
|
||||
responseDebugInfo := "json_rpc_success"
|
||||
if len(debugInfo) > 0 {
|
||||
responseDebugInfo = debugInfo[0]
|
||||
}
|
||||
sendJsonRpcResponse(ctx, id, map[string]any{JResult: result}, responseDebugInfo)
|
||||
}
|
||||
|
||||
func OnJsonRpcResponseError(ctx wrapper.HttpContext, err error, errorCode int, debugInfo ...string) {
|
||||
var (
|
||||
id JsonRpcID
|
||||
ok bool
|
||||
)
|
||||
idRaw := ctx.GetContext(CtxJsonRpcID)
|
||||
if id, ok = idRaw.(JsonRpcID); !ok {
|
||||
makeHttpResponse(ctx, 500, "not_found_json_rpc_id", nil, []byte("not found json rpc id"))
|
||||
return
|
||||
}
|
||||
responseDebugInfo := fmt.Sprintf("json_rpc_error(%s)", err)
|
||||
if len(debugInfo) > 0 {
|
||||
responseDebugInfo = debugInfo[0]
|
||||
}
|
||||
sendJsonRpcResponse(ctx, id, map[string]any{JError: map[string]any{
|
||||
JMessage: err.Error(),
|
||||
JCode: errorCode,
|
||||
}}, responseDebugInfo)
|
||||
}
|
||||
|
||||
func HandleJsonRpcMethod(ctx wrapper.HttpContext, body []byte, handles MethodHandlers) types.Action {
|
||||
idResult := gjson.GetBytes(body, "id")
|
||||
id := NewJsonRpcIDFromGjson(idResult)
|
||||
ctx.SetContext(CtxJsonRpcID, id)
|
||||
method := gjson.GetBytes(body, "method").String()
|
||||
params := gjson.GetBytes(body, "params")
|
||||
if method != "" {
|
||||
if handle, ok := handles[method]; ok {
|
||||
log.Debugf("json rpc call method[%s] with params[%s]", method, params.Raw)
|
||||
|
||||
// Clear pause flag before calling handler
|
||||
ctx.SetContext(CtxNeedPause, false)
|
||||
|
||||
err := handle(ctx, id, params)
|
||||
if err != nil {
|
||||
OnJsonRpcResponseError(ctx, err, ErrInvalidRequest)
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
// Check if the handler set the pause flag
|
||||
if needPause := ctx.GetContext(CtxNeedPause); needPause != nil && needPause.(bool) {
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
return types.ActionContinue
|
||||
}
|
||||
OnJsonRpcResponseError(ctx, fmt.Errorf("method not found:%s", method), ErrMethodNotFound)
|
||||
} else {
|
||||
proxywasm.SendHttpResponseWithDetail(202, "json_rpc_ack", nil, nil, -1)
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func HandleJsonRpcRequest(ctx wrapper.HttpContext, body []byte, handle JsonRpcRequestHandler) types.Action {
|
||||
idResult := gjson.GetBytes(body, "id")
|
||||
id := NewJsonRpcIDFromGjson(idResult)
|
||||
ctx.SetContext(CtxJsonRpcID, id)
|
||||
method := gjson.GetBytes(body, "method").String()
|
||||
params := gjson.GetBytes(body, "params")
|
||||
log.Debugf("json rpc call method[%s] with params[%s]", method, params.Raw)
|
||||
return handle(ctx, id, method, params, body)
|
||||
}
|
||||
|
||||
func HandleJsonRpcResponse(ctx wrapper.HttpContext, body []byte, handle JsonRpcResponseHandler) types.Action {
|
||||
idResult := gjson.GetBytes(body, "id")
|
||||
id := NewJsonRpcIDFromGjson(idResult)
|
||||
error := gjson.GetBytes(body, "error")
|
||||
result := gjson.GetBytes(body, "result")
|
||||
log.Debugf("json rpc response error[%s] result[%s]", error.Raw, result.Raw)
|
||||
return handle(ctx, id, result, error, body)
|
||||
}
|
||||
160
plugins/wasm-go/pkg/mcp/utils/json_rpc_test.go
Normal file
160
plugins/wasm-go/pkg/mcp/utils/json_rpc_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestJsonRpcIDFromGjson(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
jsonData string
|
||||
expected JsonRpcID
|
||||
}{
|
||||
{
|
||||
name: "integer id",
|
||||
jsonData: `{"id": 123}`,
|
||||
expected: JsonRpcID{
|
||||
IntValue: 123,
|
||||
IsString: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string id",
|
||||
jsonData: `{"id": "abc-123"}`,
|
||||
expected: JsonRpcID{
|
||||
StringValue: "abc-123",
|
||||
IsString: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "float id treated as int",
|
||||
jsonData: `{"id": 123.45}`,
|
||||
expected: JsonRpcID{
|
||||
IntValue: 123,
|
||||
IsString: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "boolean id treated as int",
|
||||
jsonData: `{"id": true}`,
|
||||
expected: JsonRpcID{
|
||||
IntValue: 1,
|
||||
IsString: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "null id treated as int zero",
|
||||
jsonData: `{"id": null}`,
|
||||
expected: JsonRpcID{
|
||||
IntValue: 0,
|
||||
IsString: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
idResult := gjson.Get(tt.jsonData, "id")
|
||||
result := NewJsonRpcIDFromGjson(idResult)
|
||||
|
||||
if result.IsString != tt.expected.IsString {
|
||||
t.Errorf("IsString = %v, want %v", result.IsString, tt.expected.IsString)
|
||||
}
|
||||
|
||||
if result.IsString {
|
||||
if result.StringValue != tt.expected.StringValue {
|
||||
t.Errorf("StringValue = %v, want %v", result.StringValue, tt.expected.StringValue)
|
||||
}
|
||||
} else {
|
||||
if result.IntValue != tt.expected.IntValue {
|
||||
t.Errorf("IntValue = %v, want %v", result.IntValue, tt.expected.IntValue)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Skip TestSendJsonRpcResponse because it requires proxywasm which is not available in the test environment
|
||||
// This function would normally test that sendJsonRpcResponse correctly handles different ID types
|
||||
func TestSendJsonRpcResponse(t *testing.T) {
|
||||
t.Skip("Skipping test that requires proxywasm")
|
||||
}
|
||||
|
||||
func TestJsonRpcIDMarshaling(t *testing.T) {
|
||||
// Test that JsonRpcID is correctly marshaled in a JSON response
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id JsonRpcID
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "integer id",
|
||||
id: JsonRpcID{
|
||||
IntValue: 123,
|
||||
IsString: false,
|
||||
},
|
||||
expected: `"id":123`,
|
||||
},
|
||||
{
|
||||
name: "string id",
|
||||
id: JsonRpcID{
|
||||
StringValue: "abc-123",
|
||||
IsString: true,
|
||||
},
|
||||
expected: `"id":"abc-123"`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a JSON object with the ID
|
||||
var jsonObj map[string]interface{}
|
||||
if tt.id.IsString {
|
||||
jsonObj = map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": tt.id.StringValue,
|
||||
}
|
||||
} else {
|
||||
jsonObj = map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": tt.id.IntValue,
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
body, err := json.Marshal(jsonObj)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to marshal JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check that the ID is correctly marshaled
|
||||
if !json.Valid(body) {
|
||||
t.Errorf("Invalid JSON: %s", string(body))
|
||||
}
|
||||
|
||||
// Check that the ID is correctly formatted
|
||||
if !strings.Contains(string(body), tt.expected) {
|
||||
t.Errorf("ID not correctly formatted. Expected to contain %s, got %s", tt.expected, string(body))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
130
plugins/wasm-go/pkg/mcp/utils/log.go
Normal file
130
plugins/wasm-go/pkg/mcp/utils/log.go
Normal file
@@ -0,0 +1,130 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
type MCPServerLog struct {
|
||||
}
|
||||
|
||||
func setMCPInfo(msg string) string {
|
||||
requestIDRaw, _ := proxywasm.GetProperty([]string{"x_request_id"})
|
||||
requestID := string(requestIDRaw)
|
||||
if requestID == "" {
|
||||
requestID = "nil"
|
||||
}
|
||||
mcpServerNameRaw, _ := proxywasm.GetProperty([]string{"mcp_server_name"})
|
||||
mcpServerName := string(mcpServerNameRaw)
|
||||
mcpToolNameRaw, _ := proxywasm.GetProperty([]string{"mcp_tool_name"})
|
||||
mcpToolName := string(mcpToolNameRaw)
|
||||
mcpInfo := mcpServerName
|
||||
if mcpToolName != "" {
|
||||
mcpInfo = fmt.Sprintf("%s/%s", mcpServerName, mcpToolName)
|
||||
}
|
||||
return fmt.Sprintf("[mcp-server] [%s] [%s] %s", mcpInfo, requestID, msg)
|
||||
}
|
||||
|
||||
func (l MCPServerLog) log(level wrapper.LogLevel, msg string) {
|
||||
msg = setMCPInfo(msg)
|
||||
switch level {
|
||||
case wrapper.LogLevelTrace:
|
||||
proxywasm.LogTrace(msg)
|
||||
case wrapper.LogLevelDebug:
|
||||
proxywasm.LogDebug(msg)
|
||||
case wrapper.LogLevelInfo:
|
||||
proxywasm.LogInfo(msg)
|
||||
case wrapper.LogLevelWarn:
|
||||
proxywasm.LogWarn(msg)
|
||||
case wrapper.LogLevelError:
|
||||
proxywasm.LogError(msg)
|
||||
case wrapper.LogLevelCritical:
|
||||
proxywasm.LogCritical(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (l MCPServerLog) logFormat(level wrapper.LogLevel, format string, args ...interface{}) {
|
||||
format = setMCPInfo(format)
|
||||
switch level {
|
||||
case wrapper.LogLevelTrace:
|
||||
proxywasm.LogTracef(format, args...)
|
||||
case wrapper.LogLevelDebug:
|
||||
proxywasm.LogDebugf(format, args...)
|
||||
case wrapper.LogLevelInfo:
|
||||
proxywasm.LogInfof(format, args...)
|
||||
case wrapper.LogLevelWarn:
|
||||
proxywasm.LogWarnf(format, args...)
|
||||
case wrapper.LogLevelError:
|
||||
proxywasm.LogErrorf(format, args...)
|
||||
case wrapper.LogLevelCritical:
|
||||
proxywasm.LogCriticalf(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l MCPServerLog) Trace(msg string) {
|
||||
l.log(wrapper.LogLevelTrace, msg)
|
||||
}
|
||||
|
||||
func (l MCPServerLog) Tracef(format string, args ...interface{}) {
|
||||
l.logFormat(wrapper.LogLevelTrace, format, args...)
|
||||
}
|
||||
|
||||
func (l MCPServerLog) Debug(msg string) {
|
||||
l.log(wrapper.LogLevelDebug, msg)
|
||||
}
|
||||
|
||||
func (l MCPServerLog) Debugf(format string, args ...interface{}) {
|
||||
l.logFormat(wrapper.LogLevelDebug, format, args...)
|
||||
}
|
||||
|
||||
func (l MCPServerLog) Info(msg string) {
|
||||
l.log(wrapper.LogLevelInfo, msg)
|
||||
}
|
||||
|
||||
func (l MCPServerLog) Infof(format string, args ...interface{}) {
|
||||
l.logFormat(wrapper.LogLevelInfo, format, args...)
|
||||
}
|
||||
|
||||
func (l MCPServerLog) Warn(msg string) {
|
||||
l.log(wrapper.LogLevelWarn, msg)
|
||||
}
|
||||
|
||||
func (l MCPServerLog) Warnf(format string, args ...interface{}) {
|
||||
l.logFormat(wrapper.LogLevelWarn, format, args...)
|
||||
}
|
||||
|
||||
func (l MCPServerLog) Error(msg string) {
|
||||
l.log(wrapper.LogLevelError, msg)
|
||||
}
|
||||
|
||||
func (l MCPServerLog) Errorf(format string, args ...interface{}) {
|
||||
l.logFormat(wrapper.LogLevelError, format, args...)
|
||||
}
|
||||
|
||||
func (l MCPServerLog) Critical(msg string) {
|
||||
l.log(wrapper.LogLevelCritical, msg)
|
||||
}
|
||||
|
||||
func (l MCPServerLog) Criticalf(format string, args ...interface{}) {
|
||||
l.logFormat(wrapper.LogLevelCritical, format, args...)
|
||||
}
|
||||
|
||||
func (l MCPServerLog) ResetID(pluginID string) {
|
||||
}
|
||||
117
plugins/wasm-go/pkg/mcp/utils/mcp_rpc.go
Normal file
117
plugins/wasm-go/pkg/mcp/utils/mcp_rpc.go
Normal file
@@ -0,0 +1,117 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
func OnMCPResponseSuccess(ctx wrapper.HttpContext, result map[string]any, debugInfo string) {
|
||||
OnJsonRpcResponseSuccess(ctx, result, debugInfo)
|
||||
// TODO: support pub to redis when use POST + SSE
|
||||
}
|
||||
|
||||
func OnMCPResponseError(ctx wrapper.HttpContext, err error, code int, debugInfo string) {
|
||||
OnJsonRpcResponseError(ctx, err, code, debugInfo)
|
||||
// TODO: support pub to redis when use POST + SSE
|
||||
}
|
||||
|
||||
func OnMCPToolCallSuccess(ctx wrapper.HttpContext, content []map[string]any, debugInfo string) {
|
||||
OnMCPResponseSuccess(ctx, map[string]any{
|
||||
"content": content,
|
||||
"isError": false,
|
||||
}, debugInfo)
|
||||
}
|
||||
|
||||
// OnMCPToolCallSuccessWithStructuredContent sends a successful MCP tool response with structured content
|
||||
// According to MCP spec, structuredContent is a field in tool results, not a capability
|
||||
func OnMCPToolCallSuccessWithStructuredContent(ctx wrapper.HttpContext, content []map[string]any, structuredContent json.RawMessage, debugInfo string) {
|
||||
response := map[string]any{
|
||||
"content": content,
|
||||
"isError": false,
|
||||
}
|
||||
if structuredContent != nil && len(structuredContent) > 0 {
|
||||
response["structuredContent"] = structuredContent
|
||||
}
|
||||
OnMCPResponseSuccess(ctx, response, debugInfo)
|
||||
}
|
||||
|
||||
func OnMCPToolCallError(ctx wrapper.HttpContext, err error, debugInfo ...string) {
|
||||
responseDebugInfo := fmt.Sprintf("mcp:tools/call:error(%s)", err)
|
||||
if len(debugInfo) > 0 {
|
||||
responseDebugInfo = debugInfo[0]
|
||||
}
|
||||
OnMCPResponseSuccess(ctx, map[string]any{
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": err.Error(),
|
||||
},
|
||||
},
|
||||
"isError": true,
|
||||
}, responseDebugInfo)
|
||||
}
|
||||
|
||||
func SendMCPToolTextResult(ctx wrapper.HttpContext, result string, debugInfo ...string) {
|
||||
responseDebugInfo := "mcp:tools/call::result"
|
||||
if len(debugInfo) > 0 {
|
||||
responseDebugInfo = debugInfo[0]
|
||||
}
|
||||
OnMCPToolCallSuccess(ctx, []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": result,
|
||||
},
|
||||
}, responseDebugInfo)
|
||||
}
|
||||
|
||||
func SendMCPToolImageResult(ctx wrapper.HttpContext, image []byte, contentType string, debugInfo ...string) {
|
||||
responseDebugInfo := "mcp:tools/call::result"
|
||||
if len(debugInfo) > 0 {
|
||||
responseDebugInfo = debugInfo[0]
|
||||
}
|
||||
|
||||
content := []map[string]any{
|
||||
{
|
||||
"type": "image",
|
||||
"data": base64.StdEncoding.EncodeToString(image),
|
||||
"mimeType": contentType,
|
||||
},
|
||||
}
|
||||
|
||||
// Use traditional response format since no structured data is provided
|
||||
OnMCPToolCallSuccess(ctx, content, responseDebugInfo)
|
||||
}
|
||||
|
||||
// SendMCPToolTextResultWithStructuredContent sends a tool result with both text content and structured content
|
||||
// According to MCP spec, for backward compatibility, tools that return structured content
|
||||
// SHOULD also return the serialized JSON in a TextContent block
|
||||
func SendMCPToolTextResultWithStructuredContent(ctx wrapper.HttpContext, textResult string, structuredContent json.RawMessage, debugInfo ...string) {
|
||||
responseDebugInfo := "mcp:tools/call::result"
|
||||
if len(debugInfo) > 0 {
|
||||
responseDebugInfo = debugInfo[0]
|
||||
}
|
||||
content := []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": textResult,
|
||||
},
|
||||
}
|
||||
OnMCPToolCallSuccessWithStructuredContent(ctx, content, structuredContent, responseDebugInfo)
|
||||
}
|
||||
51
plugins/wasm-go/pkg/mcp/utils/session.go
Normal file
51
plugins/wasm-go/pkg/mcp/utils/session.go
Normal file
@@ -0,0 +1,51 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
func IsStatefulSession(ctx wrapper.HttpContext) bool {
|
||||
parse, err := url.Parse(ctx.Path())
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse request path: %v", err)
|
||||
return false
|
||||
}
|
||||
query, err := url.ParseQuery(parse.RawQuery)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse query params: %v", err)
|
||||
return false
|
||||
}
|
||||
// Protocol version: 2024-11-05
|
||||
if query.Get("sessionId") != "" {
|
||||
return true
|
||||
}
|
||||
// Protocol version: 2025-03-26
|
||||
sessionHeader, err := proxywasm.GetHttpRequestHeader("mcp-session-id")
|
||||
if err != nil {
|
||||
log.Errorf("failed to get request header: %v", err)
|
||||
return false
|
||||
}
|
||||
if sessionHeader != "" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user