Mirror of https://github.com/alibaba/higress.git, synced 2026-04-13 16:17:27 +08:00. Comparing v2.2.0...add-releas (91 commits).
---
name: higress-clawdbot-integration
description: "Deploy and configure Higress AI Gateway for Clawdbot/OpenClaw integration. Use when: (1) User wants to deploy Higress AI Gateway, (2) User wants to configure Clawdbot/OpenClaw to use Higress as a model provider, (3) User mentions 'higress', 'ai gateway', 'model gateway', 'AI网关', (4) User wants to set up model routing or auto-routing, (5) User needs to manage LLM provider API keys, (6) User wants to track token usage and conversation history."
---

# Higress AI Gateway Integration

Deploy and configure Higress AI Gateway for Clawdbot/OpenClaw integration with one-click deployment, model provider configuration, auto-routing, and session monitoring.

## Prerequisites

- Docker installed and running
- Internet access to download the setup script
- LLM provider API keys (at least one)

## Workflow

### Step 1: Download Setup Script

Download the official `get-ai-gateway.sh` script:

```bash
curl -fsSL https://raw.githubusercontent.com/higress-group/higress-standalone/main/all-in-one/get-ai-gateway.sh -o get-ai-gateway.sh
chmod +x get-ai-gateway.sh
```

### Step 2: Gather Configuration

Ask the user for:

1. **LLM Provider API Keys** (at least one required):

   **Top Commonly Used Providers:**
   - Aliyun Dashscope (Qwen): `--dashscope-key`
   - DeepSeek: `--deepseek-key`
   - Moonshot (Kimi): `--moonshot-key`
   - Zhipu AI: `--zhipuai-key`
   - Claude Code (OAuth mode): `--claude-code-key` (run `claude setup-token` to get a token)
   - Claude: `--claude-key`
   - Minimax: `--minimax-key`
   - Azure OpenAI: `--azure-key`
   - AWS Bedrock: `--bedrock-key`
   - Google Vertex AI: `--vertex-key`
   - OpenAI: `--openai-key`
   - OpenRouter: `--openrouter-key`
   - Grok: `--grok-key`

   To configure providers beyond the above, run `./get-ai-gateway.sh --help` to view the complete list of supported models and providers.

2. **Port Configuration** (optional):
   - HTTP port: `--http-port` (default: 8080)
   - HTTPS port: `--https-port` (default: 8443)
   - Console port: `--console-port` (default: 8001)

3. **Auto-routing** (optional):
   - Enable: `--auto-routing`
   - Default model: `--auto-routing-default-model`

### Step 3: Run Setup Script

Run the script in non-interactive mode with the gathered parameters:

```bash
./get-ai-gateway.sh start --non-interactive \
  --dashscope-key sk-xxx \
  --openai-key sk-xxx \
  --auto-routing \
  --auto-routing-default-model qwen-turbo
```

**Automatic Registry Selection:**

The script automatically detects your timezone and selects the geographically closest registry for both:

- **Container image** (`IMAGE_REPO`)
- **WASM plugins** (`PLUGIN_REGISTRY`)

| Region | Timezone Examples | Selected Registry |
|--------|-------------------|-------------------|
| China & nearby | Asia/Shanghai, Asia/Hong_Kong, etc. | `higress-registry.cn-hangzhou.cr.aliyuncs.com` |
| Southeast Asia | Asia/Singapore, Asia/Jakarta, etc. | `higress-registry.ap-southeast-7.cr.aliyuncs.com` |
| North America | America/*, US/*, Canada/* | `higress-registry.us-west-1.cr.aliyuncs.com` |
| Others | Default fallback | `higress-registry.cn-hangzhou.cr.aliyuncs.com` |
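The mapping in the table can be sketched roughly as follows. This is an illustrative approximation only: the function name `select_registry`, the specific timezone lists, and the `/etc/timezone` lookup are assumptions, not the script's actual code.

```shell
#!/usr/bin/env bash
# Approximate sketch of the timezone-based registry selection described above.
select_registry() {
  local tz="$1"
  case "$tz" in
    Asia/Shanghai|Asia/Hong_Kong|Asia/Chongqing|Asia/Urumqi)
      echo "higress-registry.cn-hangzhou.cr.aliyuncs.com" ;;
    Asia/Singapore|Asia/Jakarta|Asia/Bangkok|Asia/Kuala_Lumpur)
      echo "higress-registry.ap-southeast-7.cr.aliyuncs.com" ;;
    America/*|US/*|Canada/*)
      echo "higress-registry.us-west-1.cr.aliyuncs.com" ;;
    *)
      # Default fallback
      echo "higress-registry.cn-hangzhou.cr.aliyuncs.com" ;;
  esac
}

# Example: read the system timezone from one common location and pick a registry.
tz="$(cat /etc/timezone 2>/dev/null || echo "")"
select_registry "$tz"
```

The real script may read the timezone differently (e.g. via `timedatectl`); the point is only that each timezone family maps to one registry, with Hangzhou as the fallback.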
**Manual Override (optional):**

If you want to use a specific registry:

```bash
IMAGE_REPO="higress-registry.ap-southeast-7.cr.aliyuncs.com/higress/all-in-one" \
PLUGIN_REGISTRY="higress-registry.ap-southeast-7.cr.aliyuncs.com" \
./get-ai-gateway.sh start --non-interactive \
  --dashscope-key sk-xxx \
  --openai-key sk-xxx
```

### Step 4: Verify Deployment

After the script completes:

1. Check that the container is running:

   ```bash
   docker ps --filter "name=higress-ai-gateway"
   ```

2. Test the gateway endpoint:

   ```bash
   curl http://localhost:8080/v1/models
   ```

3. Access the console (optional):

   ```
   http://localhost:8001
   ```
### Step 5: Configure Clawdbot/OpenClaw Plugin

If the user wants to use Higress with Clawdbot/OpenClaw, install the appropriate plugin:

#### Automatic Installation

Detect the runtime and install the correct plugin version:

```bash
# Detect which runtime is installed
if command -v clawdbot &> /dev/null; then
  RUNTIME="clawdbot"
  RUNTIME_DIR="$HOME/.clawdbot"
  PLUGIN_SRC="scripts/plugin-clawdbot"
elif command -v openclaw &> /dev/null; then
  RUNTIME="openclaw"
  RUNTIME_DIR="$HOME/.openclaw"
  PLUGIN_SRC="scripts/plugin"
else
  echo "Error: Neither clawdbot nor openclaw is installed"
  exit 1
fi

# Install plugin
PLUGIN_DEST="$RUNTIME_DIR/extensions/higress-ai-gateway"
echo "Installing Higress AI Gateway plugin for $RUNTIME..."
mkdir -p "$(dirname "$PLUGIN_DEST")"
[ -d "$PLUGIN_DEST" ] && rm -rf "$PLUGIN_DEST"
cp -r "$PLUGIN_SRC" "$PLUGIN_DEST"
echo "✓ Plugin installed at: $PLUGIN_DEST"

# Configure provider
echo ""
echo "Configuring provider..."
$RUNTIME models auth login --provider higress
```

The plugin will guide you through an interactive setup for:

1. Gateway URL (default: `http://localhost:8080`)
2. Console URL (default: `http://localhost:8001`)
3. API key (optional for local deployments)
4. Model list (auto-detected or manually specified)
5. Auto-routing default model (if using `higress/auto`)

### Step 6: Manage API Keys (optional)

After deployment, manage API keys without redeploying:

```bash
# View configured API keys
./get-ai-gateway.sh config list

# Add or update an API key (hot-reload, no restart needed)
./get-ai-gateway.sh config add --provider <provider> --key <api-key>

# Remove an API key (hot-reload, no restart needed)
./get-ai-gateway.sh config remove --provider <provider>
```

**Note:** Changes take effect immediately via hot-reload. No container restart is required.

## CLI Parameters Reference

### Basic Options

| Parameter | Description | Default |
|-----------|-------------|---------|
| `--non-interactive` | Run without prompts | - |
| `--http-port` | Gateway HTTP port | 8080 |
| `--https-port` | Gateway HTTPS port | 8443 |
| `--console-port` | Console port | 8001 |
| `--container-name` | Container name | higress-ai-gateway |
| `--data-folder` | Data folder path | ./higress |
| `--auto-routing` | Enable auto-routing feature | - |
| `--auto-routing-default-model` | Default model when no rule matches | - |
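The basic options can be combined in a single invocation. The flag names and defaults below come from the table above; the specific values (ports, container name, data folder) are illustrative choices, not requirements:

```shell
# Illustrative invocation combining the basic options above.
./get-ai-gateway.sh start --non-interactive \
  --http-port 9080 \
  --https-port 9443 \
  --console-port 9001 \
  --container-name my-ai-gateway \
  --data-folder ./my-higress-data \
  --auto-routing \
  --auto-routing-default-model qwen-turbo \
  --dashscope-key sk-xxx
```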
### Environment Variables

| Variable | Description | Default |
|----------|-------------|---------|
| `IMAGE_REPO` | Container image repository (auto-selected based on timezone) | `higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/all-in-one` |
| `PLUGIN_REGISTRY` | Registry URL for WASM plugins (auto-selected based on timezone) | `higress-registry.cn-hangzhou.cr.aliyuncs.com` |

**Auto-Selection Logic:**

The registry is automatically selected based on your timezone:

- **China & nearby** (Asia/Shanghai, etc.) → `higress-registry.cn-hangzhou.cr.aliyuncs.com`
- **Southeast Asia** (Asia/Singapore, etc.) → `higress-registry.ap-southeast-7.cr.aliyuncs.com`
- **North America** (America/*, etc.) → `higress-registry.us-west-1.cr.aliyuncs.com`
- **Others** → `higress-registry.cn-hangzhou.cr.aliyuncs.com` (default)

Both container images and WASM plugins use the same registry for consistency.

**Manual Override:**

```bash
PLUGIN_REGISTRY="higress-registry.ap-southeast-7.cr.aliyuncs.com" \
./get-ai-gateway.sh start --non-interactive ...
```

### LLM Provider API Keys

**Top Providers:**

| Parameter | Provider |
|-----------|----------|
| `--dashscope-key` | Aliyun Dashscope (Qwen) |
| `--deepseek-key` | DeepSeek |
| `--moonshot-key` | Moonshot (Kimi) |
| `--zhipuai-key` | Zhipu AI |
| `--claude-code-key` | Claude Code (OAuth mode; run `claude setup-token` to get a token) |
| `--claude-key` | Claude |
| `--openai-key` | OpenAI |
| `--openrouter-key` | OpenRouter |
| `--gemini-key` | Google Gemini |
| `--groq-key` | Groq |

**Additional Providers:**

`--doubao-key`, `--baichuan-key`, `--yi-key`, `--stepfun-key`, `--minimax-key`, `--cohere-key`, `--mistral-key`, `--github-key`, `--fireworks-key`, `--togetherai-key`, `--grok-key`, `--azure-key`, `--bedrock-key`, `--vertex-key`

## Managing Configuration

### API Keys

```bash
# List all configured API keys
./get-ai-gateway.sh config list

# Add or update an API key (hot-reload)
./get-ai-gateway.sh config add --provider deepseek --key sk-xxx

# Remove an API key (hot-reload)
./get-ai-gateway.sh config remove --provider deepseek
```

**Supported provider aliases:**

`dashscope`/`qwen`, `moonshot`/`kimi`, `zhipuai`/`zhipu`, `togetherai`/`together`

### Routing Rules

```bash
# Add a routing rule ("深入思考" is Chinese for "think deeply")
./get-ai-gateway.sh route add --model claude-opus-4.5 --trigger "深入思考|deep thinking"

# List all rules
./get-ai-gateway.sh route list

# Remove a rule
./get-ai-gateway.sh route remove --rule-id 0
```

See [higress-auto-router](../higress-auto-router/SKILL.md) for detailed documentation.

## Access Logs

Gateway access logs are available at:

```
$DATA_FOLDER/logs/access.log
```

These logs can be used with the **agent-session-monitor** skill for token tracking and conversation analysis.
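Before reaching for the monitoring skill, a quick local sanity check is to count requests per model straight from the log. This sketch assumes JSON-lines entries containing a `"model":"..."` field, which may not match your deployment; verify against your actual access-log format first:

```shell
# Count requests per model in a JSON-lines access log.
# Assumption: each line contains a `"model":"..."` field (verify your log format).
# Usage: count_models ./higress/logs/access.log
count_models() {
  local log_file="$1"
  grep -o '"model":"[^"]*"' "$log_file" \
    | sed 's/"model":"\(.*\)"/\1/' \
    | sort | uniq -c | sort -rn
}
```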
## Related Skills

- **higress-auto-router**: Configure automatic model routing using CLI commands.
  See: [higress-auto-router](../higress-auto-router/SKILL.md)

- **agent-session-monitor**: Monitor and track token usage across sessions.
  See: [agent-session-monitor](../agent-session-monitor/SKILL.md)

## Examples

### Example 1: Basic Deployment with Dashscope

**User:** "Help me deploy a Higress AI Gateway using Alibaba Cloud's Tongyi Qianwen (Qwen)."

**Steps:**

1. Download the script.
2. Get the Dashscope API key from the user.
3. Run (the script auto-detects the timezone and selects the optimal registry):

   ```bash
   ./get-ai-gateway.sh start --non-interactive \
     --dashscope-key sk-xxx
   ```

**Response:**

```
Auto-detected timezone: Asia/Shanghai
Selected plugin registry: higress-registry.cn-hangzhou.cr.aliyuncs.com

✅ Higress AI Gateway deployed!

Gateway URL: http://localhost:8080/v1/chat/completions
Console:     http://localhost:8001
Log folder:  ./higress/logs

Configured model providers:
- Aliyun Dashscope (Qwen)

Test command:
curl 'http://localhost:8080/v1/chat/completions' \
  -H 'Content-Type: application/json' \
  -d '{"model": "qwen-turbo", "messages": [{"role": "user", "content": "Hello!"}]}'
```

### Example 2: Full Integration with Clawdbot

**User:** "Fully configure the integration between Higress and Clawdbot."

**Steps:**

1. Deploy Higress AI Gateway (auto-detects timezone).
2. Install and configure the Clawdbot plugin.
3. Enable auto-routing.
4. Set up session monitoring.

**Response:**

```
Auto-detected timezone: Asia/Shanghai
Selected plugin registry: higress-registry.cn-hangzhou.cr.aliyuncs.com

✅ Higress AI Gateway integration complete!

1. Gateway deployed:
   - HTTP: http://localhost:8080
   - Console: http://localhost:8001
   - Container image: Hangzhou (auto-selected)
   - Plugin registry: Hangzhou (auto-selected)

2. Clawdbot plugin configured:
   Plugin installed at: /root/.clawdbot/extensions/higress-ai-gateway
   Run: clawdbot models auth login --provider higress

3. Auto-routing:
   Enabled; use model="higress/auto"

4. Session monitoring:
   Log path: ./higress/logs/access.log

Would you like help configuring auto-routing rules?
```

### Example 3: Manage API Keys

**User:** "Show me the currently configured API keys, and add a DeepSeek key."

**Steps:**

1. List current API keys:

   ```bash
   ./get-ai-gateway.sh config list
   ```

2. Add the DeepSeek API key:

   ```bash
   ./get-ai-gateway.sh config add --provider deepseek --key sk-xxx
   ```

**Response:**

```
Currently configured API keys:

Aliyun Dashscope (Qwen): sk-ab***ef12
OpenAI:                  sk-cd***gh34

Adding API key for DeepSeek...

✅ API key updated successfully!

Provider: DeepSeek
Key: sk-xx***yy56

Configuration has been hot-reloaded (no restart needed).
```

### Example 4: North America Deployment

**User:** "Help me deploy a Higress AI Gateway."

**Context:** The user's timezone is America/Los_Angeles.

**Steps:**

1. Download the script.
2. Get API keys from the user.
3. Run (the script auto-detects the timezone and selects the North America mirror):

   ```bash
   ./get-ai-gateway.sh start --non-interactive \
     --openai-key sk-xxx \
     --openrouter-key sk-xxx
   ```

**Response:**

```
Auto-detected timezone: America/Los_Angeles
Selected plugin registry: higress-registry.us-west-1.cr.aliyuncs.com

✅ Higress AI Gateway deployed!

Gateway URL: http://localhost:8080/v1/chat/completions
Console:     http://localhost:8001
Log folder:  ./higress/logs

Registry optimization:
- Container image: North America (auto-selected by timezone)
- Plugin registry: North America (auto-selected by timezone)

Configured model providers:
- OpenAI
- OpenRouter
```

## Troubleshooting

For detailed troubleshooting guides, see [TROUBLESHOOTING.md](references/TROUBLESHOOTING.md).

Common issues:

- **Container fails to start**: Check Docker status, port availability, and container logs
- **"too many open files" error**: Increase `fs.inotify.max_user_instances` to 8192
- **Gateway not responding**: Verify container status and port mapping
- **Plugin not recognized**: Check the installation path and restart the runtime
- **Auto-routing not working**: Verify the model list and routing rules
- **Timezone detection fails**: Manually set the `IMAGE_REPO` environment variable
---

# Higress AI Gateway Plugin (Clawdbot)

Clawdbot model provider plugin for Higress AI Gateway with auto-routing support.

## What is this?

This is a TypeScript-based provider plugin that enables Clawdbot to use Higress AI Gateway as a model provider. It provides:

- **Auto-routing support**: Use `higress/auto` to intelligently route requests based on message content
- **Dynamic model discovery**: Auto-detect available models from the Higress Console
- **Smart URL handling**: Automatic URL normalization and validation
- **Flexible authentication**: Support for both local and remote gateway deployments

## Files

- **index.ts**: Main plugin implementation
- **package.json**: NPM package metadata and Clawdbot extension declaration
- **clawdbot.plugin.json**: Plugin manifest for Clawdbot

## Installation

This plugin is installed automatically when you use the `higress-clawdbot-integration` skill. See the parent SKILL.md for complete installation instructions.

### Manual Installation

If you need to install manually:

```bash
# Copy plugin files
mkdir -p "$HOME/.clawdbot/extensions/higress-ai-gateway"
cp -r ./* "$HOME/.clawdbot/extensions/higress-ai-gateway/"

# Configure provider
clawdbot models auth login --provider higress
```

## Usage

After installation, configure Higress as a model provider:

```bash
clawdbot models auth login --provider higress
```

The plugin will prompt for:

1. Gateway URL (default: http://localhost:8080)
2. Console URL (default: http://localhost:8001)
3. API key (optional for local deployments)
4. Model list (auto-detected or manually specified)
5. Auto-routing default model (if using `higress/auto`)

## Auto-routing

To use auto-routing, include `higress/auto` in your model list during configuration, then use it in your conversations:

```bash
# Use auto-routing (the Chinese prompt means "think deeply: how should this problem be solved?")
clawdbot chat --model higress/auto "深入思考 这个问题应该怎么解决?"

# The gateway automatically routes to the appropriate model based on:
# - Message content triggers (configured via the higress-auto-router skill)
# - Fallback to the default model if no rule matches
```
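Auto-routing is not limited to the Clawdbot CLI: any OpenAI-compatible client can request `higress/auto` from the gateway directly. A sketch, assuming a local gateway on the default port with the "deep thinking" trigger rule configured:

```shell
# Call the auto-routing model via the gateway's OpenAI-compatible API.
# Assumes the gateway runs on localhost:8080 and a trigger rule is configured.
curl -s 'http://localhost:8080/v1/chat/completions' \
  -H 'Content-Type: application/json' \
  -d '{"model": "higress/auto", "messages": [{"role": "user", "content": "deep thinking: how should we fix this?"}]}'
```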
## Related Resources

- **Parent Skill**: [higress-clawdbot-integration](../SKILL.md)
- **Auto-routing Configuration**: [higress-auto-router](../../higress-auto-router/SKILL.md)
- **Session Monitoring**: [agent-session-monitor](../../agent-session-monitor/SKILL.md)
- **Higress AI Gateway**: https://github.com/higress-group/higress-standalone

## Compatibility

- **Clawdbot**: v2.0.0+
- **Higress AI Gateway**: All versions

## License

Apache-2.0
---

Plugin manifest (clawdbot.plugin.json):

```json
{
  "id": "higress-ai-gateway",
  "name": "Higress AI Gateway",
  "description": "Model provider plugin for Higress AI Gateway with auto-routing support",
  "providers": ["higress"],
  "configSchema": {
    "type": "object",
    "additionalProperties": true
  }
}
```
---

Main plugin implementation (index.ts). Note: `normalizeBaseUrl` already appends `/v1`, so the fetch paths and `baseUrl` below must not append it again (the original code did, producing `/v1/v1/...` URLs):

```typescript
import { emptyPluginConfigSchema } from "clawdbot/plugin-sdk";

const DEFAULT_GATEWAY_URL = "http://localhost:8080";
const DEFAULT_CONSOLE_URL = "http://localhost:8001";
const DEFAULT_CONTEXT_WINDOW = 128_000;
const DEFAULT_MAX_TOKENS = 8192;

// Common models that Higress AI Gateway typically supports
const DEFAULT_MODEL_IDS = [
  // Auto-routing special model
  "higress/auto",
  // OpenAI models
  "gpt-5.2",
  "gpt-5-mini",
  "gpt-5-nano",
  // Anthropic models
  "claude-opus-4.5",
  "claude-sonnet-4.5",
  "claude-haiku-4.5",
  // Qwen models
  "qwen3-turbo",
  "qwen3-plus",
  "qwen3-max",
  "qwen3-coder-480b-a35b-instruct",
  // DeepSeek models
  "deepseek-chat",
  "deepseek-reasoner",
  // Other common models
  "kimi-k2.5",
  "glm-4.7",
  "MiniMax-M2.1",
] as const;

// Normalizes a base URL: trims whitespace, strips trailing slashes, and
// ensures the result ends with "/v1" exactly once.
function normalizeBaseUrl(value: string): string {
  const trimmed = value.trim();
  let normalized = trimmed || DEFAULT_GATEWAY_URL;
  while (normalized.endsWith("/")) normalized = normalized.slice(0, -1);
  if (!normalized.endsWith("/v1")) normalized = `${normalized}/v1`;
  return normalized;
}

function validateUrl(value: string): string | undefined {
  const normalized = normalizeBaseUrl(value);
  try {
    new URL(normalized);
  } catch {
    return "Enter a valid URL";
  }
  return undefined;
}

function parseModelIds(input: string): string[] {
  const parsed = input
    .split(/[\n,]/)
    .map((model) => model.trim())
    .filter(Boolean);
  return Array.from(new Set(parsed));
}

function buildModelDefinition(modelId: string) {
  const isAutoModel = modelId === "higress/auto";
  return {
    id: modelId,
    name: isAutoModel ? "Higress Auto Router" : modelId,
    api: "openai-completions",
    reasoning: false,
    input: ["text", "image"],
    cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
    contextWindow: DEFAULT_CONTEXT_WINDOW,
    maxTokens: DEFAULT_MAX_TOKENS,
  };
}

// gatewayUrl is already normalized to end with "/v1", so only "/models" is appended.
async function testGatewayConnection(gatewayUrl: string): Promise<boolean> {
  try {
    const response = await fetch(`${gatewayUrl}/models`, {
      method: "GET",
      headers: { "Content-Type": "application/json" },
      signal: AbortSignal.timeout(5000),
    });
    return response.ok || response.status === 401; // 401 means the gateway is up but needs auth
  } catch {
    return false;
  }
}

// consoleUrl is already normalized to end with "/v1", so only "/ai/routes" is appended.
async function fetchAvailableModels(consoleUrl: string): Promise<string[]> {
  try {
    // Try to get models from the Higress Console API
    const response = await fetch(`${consoleUrl}/ai/routes`, {
      method: "GET",
      headers: { "Content-Type": "application/json" },
      signal: AbortSignal.timeout(5000),
    });
    if (response.ok) {
      const data = (await response.json()) as { data?: { model?: string }[] };
      if (data.data && Array.isArray(data.data)) {
        return data.data
          .map((route: { model?: string }) => route.model)
          .filter((m): m is string => typeof m === "string");
      }
    }
  } catch {
    // Ignore errors, use defaults
  }
  return [];
}

const higressPlugin = {
  id: "higress-ai-gateway",
  name: "Higress AI Gateway",
  description: "Model provider plugin for Higress AI Gateway with auto-routing support",
  configSchema: emptyPluginConfigSchema(),
  register(api) {
    api.registerProvider({
      id: "higress",
      label: "Higress AI Gateway",
      docsPath: "/providers/models",
      aliases: ["higress-gateway", "higress-ai"],
      auth: [
        {
          id: "api-key",
          label: "API Key",
          hint: "Configure Higress AI Gateway endpoint with optional API key",
          kind: "custom",
          run: async (ctx) => {
            // Step 1: Get Gateway URL
            const gatewayUrlInput = await ctx.prompter.text({
              message: "Higress AI Gateway URL",
              initialValue: DEFAULT_GATEWAY_URL,
              validate: validateUrl,
            });
            const gatewayUrl = normalizeBaseUrl(gatewayUrlInput);

            // Step 2: Get Console URL (for auto-router configuration)
            const consoleUrlInput = await ctx.prompter.text({
              message: "Higress Console URL (for auto-router config)",
              initialValue: DEFAULT_CONSOLE_URL,
              validate: validateUrl,
            });
            const consoleUrl = normalizeBaseUrl(consoleUrlInput);

            // Step 3: Test connection (create a new spinner)
            const spin = ctx.prompter.progress("Testing gateway connection…");
            const isConnected = await testGatewayConnection(gatewayUrl);
            if (!isConnected) {
              spin.stop("Gateway connection failed");
              await ctx.prompter.note(
                [
                  "Could not connect to Higress AI Gateway.",
                  "Make sure the gateway is running and the URL is correct.",
                  "",
                  `Tried: ${gatewayUrl}/models`,
                ].join("\n"),
                "Connection Warning",
              );
            } else {
              spin.stop("Gateway connected");
            }

            // Step 4: Get API Key (optional for local gateway)
            const apiKeyInput = (await ctx.prompter.text({
              message: "API Key (leave empty if not required)",
              initialValue: "",
            })) || "";
            const apiKey = apiKeyInput.trim() || "higress-local";

            // Step 5: Fetch available models (create a new spinner)
            const spin2 = ctx.prompter.progress("Fetching available models…");
            const fetchedModels = await fetchAvailableModels(consoleUrl);
            const defaultModels = fetchedModels.length > 0
              ? ["higress/auto", ...fetchedModels]
              : DEFAULT_MODEL_IDS;
            spin2.stop();

            // Step 6: Let the user customize the model list
            const modelInput = await ctx.prompter.text({
              message: "Model IDs (comma-separated, higress/auto enables auto-routing)",
              initialValue: defaultModels.slice(0, 10).join(", "),
              validate: (value) =>
                parseModelIds(value).length > 0 ? undefined : "Enter at least one model id",
            });

            const modelIds = parseModelIds(modelInput);
            const hasAutoModel = modelIds.includes("higress/auto");

            // Avoid a double prefix: if the model id already starts with the
            // provider prefix, don't add it again.
            const defaultModelId = hasAutoModel
              ? "higress/auto"
              : (modelIds[0] ?? "qwen-turbo");
            const defaultModelRef = defaultModelId.startsWith("higress/")
              ? defaultModelId
              : `higress/${defaultModelId}`;

            // Step 7: Configure the default model for auto-routing
            let autoRoutingDefaultModel = "qwen-turbo";
            if (hasAutoModel) {
              const autoRoutingModelInput = await ctx.prompter.text({
                message: "Default model for auto-routing (when no rule matches)",
                initialValue: "qwen-turbo",
              });
              autoRoutingDefaultModel = autoRoutingModelInput.trim();
            }

            return {
              profiles: [
                {
                  profileId: `higress:${apiKey === "higress-local" ? "local" : "default"}`,
                  credential: {
                    type: "token",
                    provider: "higress",
                    token: apiKey,
                  },
                },
              ],
              configPatch: {
                models: {
                  providers: {
                    higress: {
                      // gatewayUrl already ends with "/v1"
                      baseUrl: gatewayUrl,
                      apiKey: apiKey,
                      api: "openai-completions",
                      authHeader: apiKey !== "higress-local",
                      models: modelIds.map((modelId) => buildModelDefinition(modelId)),
                    },
                  },
                },
                agents: {
                  defaults: {
                    models: Object.fromEntries(
                      modelIds.map((modelId) => {
                        // Avoid a double prefix: only add the provider prefix if absent
                        const modelRef = modelId.startsWith("higress/")
                          ? modelId
                          : `higress/${modelId}`;
                        return [modelRef, {}];
                      }),
                    ),
                  },
                },
                plugins: {
                  entries: {
                    "higress-ai-gateway": {
                      enabled: true,
                      config: {
                        gatewayUrl,
                        consoleUrl,
                        autoRoutingDefaultModel,
                      },
                    },
                  },
                },
              },
              defaultModel: defaultModelRef,
              notes: [
                "Higress AI Gateway is now configured as a model provider.",
                hasAutoModel
                  ? `Auto-routing enabled: use model "higress/auto" to route based on message content.`
                  : "Add 'higress/auto' to models to enable auto-routing.",
                `Gateway endpoint: ${gatewayUrl}/chat/completions`,
                `Console: ${consoleUrl}`,
                "",
                "🎯 Recommended Skills (install via Clawdbot conversation):",
                "",
                "1. Auto-Routing Skill:",
                "   Configure automatic model routing based on message content",
                "   https://github.com/alibaba/higress/tree/main/.claude/skills/higress-auto-router",
                '   Say: "Install higress-auto-router skill"',
                "",
                "2. Agent Session Monitor Skill:",
                "   Track token usage and monitor conversation history",
                "   https://github.com/alibaba/higress/tree/main/.claude/skills/agent-session-monitor",
                '   Say: "Install agent-session-monitor skill"',
              ],
            };
          },
        },
      ],
    });
  },
};

export default higressPlugin;
```
---

NPM package metadata (package.json):

```json
{
  "name": "@higress/higress-ai-gateway",
  "version": "1.0.0",
  "description": "Higress AI Gateway model provider plugin for Clawdbot with auto-routing support",
  "main": "index.ts",
  "clawdbot": {
    "extensions": ["./index.ts"]
  },
  "keywords": [
    "clawdbot",
    "higress",
    "ai-gateway",
    "model-router",
    "auto-routing"
  ],
  "author": "Higress Team",
  "license": "Apache-2.0",
  "repository": {
    "type": "git",
    "url": "https://github.com/alibaba/higress"
  }
}
```
@@ -1,92 +0,0 @@
# Higress AI Gateway Plugin

OpenClaw/Clawdbot model provider plugin for Higress AI Gateway with auto-routing support.

## What is this?

This is a TypeScript-based provider plugin that enables Clawdbot and OpenClaw to use Higress AI Gateway as a model provider. It provides:

- **Auto-routing support**: Use `higress/auto` to intelligently route requests based on message content
- **Dynamic model discovery**: Auto-detect available models from Higress Console
- **Smart URL handling**: Automatic URL normalization and validation
- **Flexible authentication**: Support for both local and remote gateway deployments

## Files

- **index.ts**: Main plugin implementation
- **package.json**: NPM package metadata and OpenClaw extension declaration
- **openclaw.plugin.json**: Plugin manifest for OpenClaw

## Installation

This plugin is automatically installed when you use the `higress-clawdbot-integration` skill. See the parent SKILL.md for complete installation instructions.

### Manual Installation

If you need to install manually:

```bash
# Detect runtime
if command -v clawdbot &> /dev/null; then
  RUNTIME_DIR="$HOME/.clawdbot"
elif command -v openclaw &> /dev/null; then
  RUNTIME_DIR="$HOME/.openclaw"
else
  echo "Error: Neither clawdbot nor openclaw is installed"
  exit 1
fi

# Copy plugin files
mkdir -p "$RUNTIME_DIR/extensions/higress-ai-gateway"
cp -r ./* "$RUNTIME_DIR/extensions/higress-ai-gateway/"

# Configure provider
clawdbot models auth login --provider higress
# or
openclaw models auth login --provider higress
```

## Usage

After installation, configure Higress as a model provider:

```bash
clawdbot models auth login --provider higress
```

The plugin will prompt for:
1. Gateway URL (default: http://localhost:8080)
2. Console URL (default: http://localhost:8001)
3. API Key (optional for local deployments)
4. Model list (auto-detected or manually specified)
5. Auto-routing default model (if using higress/auto)

## Auto-routing

To use auto-routing, include `higress/auto` in your model list during configuration. Then use it in your conversations:

```bash
# Use auto-routing
clawdbot chat --model higress/auto "Think deeply: how should we solve this problem?"

# The gateway will automatically route to the appropriate model based on:
# - Message content triggers (configured via the higress-auto-router skill)
# - Fallback to the default model if no rule matches
```

## Related Resources

- **Parent Skill**: [higress-clawdbot-integration](../SKILL.md)
- **Auto-routing Configuration**: [higress-auto-router](../../higress-auto-router/SKILL.md)
- **Session Monitoring**: [agent-session-monitor](../../agent-session-monitor/SKILL.md)
- **Higress AI Gateway**: https://github.com/higress-group/higress-standalone

## Compatibility

- **OpenClaw**: v2.0.0+
- **Clawdbot**: v2.0.0+
- **Higress AI Gateway**: All versions

## License

Apache-2.0
259 .claude/skills/higress-openclaw-integration/SKILL.md (Normal file)
@@ -0,0 +1,259 @@
---
name: higress-openclaw-integration
description: "Deploy and configure Higress AI Gateway for OpenClaw integration. Use when: (1) User wants to deploy Higress AI Gateway, (2) User wants to configure OpenClaw to use more model providers, (3) User mentions 'higress', 'ai gateway', 'model gateway', 'AI网关', (4) User wants to set up model routing or auto-routing, (5) User needs to manage LLM provider API keys."
---

# Higress AI Gateway Integration

Deploy Higress AI Gateway and configure OpenClaw to use it as a unified model provider.

## Quick Start

### Step 1: Collect Information from User

**Ask the user for the following information upfront:**

1. **Which LLM provider(s) to use?** (at least one required)

   **Commonly Used Providers:**

   | Provider | Parameter | Notes |
   |----------|-----------|-------|
   | 智谱 / z.ai | `--zhipuai-key` | Models: glm-*, Code Plan mode enabled by default |
   | Claude Code | `--claude-code-key` | **Requires OAuth token from `claude setup-token`** |
   | Moonshot (Kimi) | `--moonshot-key` | Models: moonshot-*, kimi-* |
   | Minimax | `--minimax-key` | Models: abab-* |
   | 阿里云通义千问 (Dashscope) | `--dashscope-key` | Models: qwen* |
   | OpenAI | `--openai-key` | Models: gpt-*, o1-*, o3-* |
   | DeepSeek | `--deepseek-key` | Models: deepseek-* |
   | Grok | `--grok-key` | Models: grok-* |

   **Other Providers:**

   | Provider | Parameter | Notes |
   |----------|-----------|-------|
   | Claude | `--claude-key` | Models: claude-* |
   | Google Gemini | `--gemini-key` | Models: gemini-* |
   | OpenRouter | `--openrouter-key` | Supports all models (catch-all) |
   | Groq | `--groq-key` | Fast inference |
   | Doubao (豆包) | `--doubao-key` | Models: doubao-* |
   | Mistral | `--mistral-key` | Models: mistral-* |
   | Baichuan (百川) | `--baichuan-key` | Models: Baichuan* |
   | 01.AI (Yi) | `--yi-key` | Models: yi-* |
   | Stepfun (阶跃星辰) | `--stepfun-key` | Models: step-* |
   | Cohere | `--cohere-key` | Models: command* |
   | Fireworks AI | `--fireworks-key` | - |
   | Together AI | `--togetherai-key` | - |
   | GitHub Models | `--github-key` | - |

   **Cloud Providers (require additional config):**
   - Azure OpenAI: `--azure-key` (requires service URL)
   - AWS Bedrock: `--bedrock-key` (requires region and access key)
   - Google Vertex AI: `--vertex-key` (requires project ID and region)

   **Brand Name Display (z.ai / 智谱):**
   - If the user communicates in Chinese: display as "智谱"
   - If the user communicates in English: display as "z.ai"

2. **Enable auto-routing?** (recommended)
   - If yes: `--auto-routing --auto-routing-default-model <model-name>`
   - Auto-routing allows using `model="higress/auto"` to automatically route requests based on message content

3. **Custom ports?** (optional; defaults: HTTP=8080, HTTPS=8443, Console=8001)
### Step 2: Deploy Gateway

**Auto-detect region for z.ai / 智谱 domain configuration:**

When the user selects the z.ai / 智谱 provider, detect their region:

```bash
# Run region detection script (scripts/detect-region.sh relative to skill directory)
REGION=$(bash scripts/detect-region.sh)
# Output: "china" or "international"
```

**Based on the detection result:**

- If `REGION="china"`: use the default domain `open.bigmodel.cn`; no extra parameter needed
- If `REGION="international"`: automatically add `--zhipuai-domain api.z.ai` to the deployment command
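The two branches above can be captured in a small helper. This is only a sketch with a hypothetical function name (`zhipu_domain_flags`); the real skill just consumes the plain `china`/`international` string:

```shell
#!/bin/bash
# Hypothetical helper mirroring the branching described above:
# "china" keeps the default open.bigmodel.cn domain (no extra flag),
# anything else switches the endpoint to api.z.ai.
zhipu_domain_flags() {
  local region="$1"
  if [ "$region" = "china" ]; then
    echo ""
  else
    echo "--zhipuai-domain api.z.ai"
  fi
}

zhipu_domain_flags international   # -> --zhipuai-domain api.z.ai
zhipu_domain_flags china           # -> prints an empty line
```

Returning an empty string for China keeps the deployment command free of a spurious flag when the flags are spliced in.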
**After deployment (for international users):**
Notify the user in English: "The z.ai endpoint domain has been set to api.z.ai. If you want to change it, let me know and I can update the configuration."

```bash
# Create installation directory
mkdir -p higress-install
cd higress-install

# Download script (if not exists)
curl -fsSL https://higress.ai/ai-gateway/install.sh -o get-ai-gateway.sh
chmod +x get-ai-gateway.sh

# Deploy with the user's configuration
# For z.ai / 智谱: always include --zhipuai-code-plan-mode
# For non-China users: include --zhipuai-domain api.z.ai
./get-ai-gateway.sh start --non-interactive \
  --<provider>-key <api-key> \
  [--auto-routing --auto-routing-default-model <model>]
```

**z.ai / 智谱 Options:**

| Option | Description |
|--------|-------------|
| `--zhipuai-code-plan-mode` | Enable Code Plan mode (enabled by default) |
| `--zhipuai-domain <domain>` | Custom domain; default: `open.bigmodel.cn` (China), `api.z.ai` (international) |

**Example (China user):**
```bash
./get-ai-gateway.sh start --non-interactive \
  --zhipuai-key sk-xxx \
  --zhipuai-code-plan-mode \
  --auto-routing \
  --auto-routing-default-model glm-5
```

**Example (international user):**
```bash
./get-ai-gateway.sh start --non-interactive \
  --zhipuai-key sk-xxx \
  --zhipuai-domain api.z.ai \
  --zhipuai-code-plan-mode \
  --auto-routing \
  --auto-routing-default-model glm-5
```
### Step 3: Install OpenClaw Plugin

Install the Higress provider plugin for OpenClaw:

```bash
# Copy plugin files (PLUGIN_SRC is relative to the skill directory: scripts/plugin)
PLUGIN_SRC="scripts/plugin"
PLUGIN_DEST="$HOME/.openclaw/extensions/higress"

mkdir -p "$PLUGIN_DEST"
cp -r "$PLUGIN_SRC"/* "$PLUGIN_DEST/"
```

**Tell the user to run the following commands manually in their terminal (they are interactive and cannot be executed by an AI agent):**

```bash
# Step 1: Enable the plugin
openclaw plugins enable higress

# Step 2: Configure the provider (interactive - prompts for Gateway URL, API Key, models, etc.)
openclaw models auth login --provider higress --set-default

# Step 3: Restart the OpenClaw gateway to apply changes
openclaw gateway restart
```

The `openclaw models auth login` command will interactively prompt for:
1. Gateway URL (default: `http://localhost:8080`)
2. Console URL (default: `http://localhost:8001`)
3. API Key (optional for local deployments)
4. Model list (auto-detected or manually specified)
5. Auto-routing default model (if using `higress/auto`)

After configuration and restart, Higress models are available in OpenClaw with the `higress/` prefix (e.g., `higress/glm-5`, `higress/auto`).

**Future Configuration Updates (No Restart Needed)**

After the initial setup, you can manage your configuration through conversation with OpenClaw:

- **Add New Providers**: Add new LLM providers (e.g., DeepSeek, OpenAI, Claude) and their models dynamically.
- **Update API Keys**: Update existing provider API keys without a service restart.
- **Configure Auto-routing**: If you've set up multiple models, ask OpenClaw to configure auto-routing rules. Requests will be intelligently routed based on your message content, using the most suitable model automatically.

All configuration changes are hot-loaded through Higress, so no `openclaw gateway restart` is required. Iterate on your model provider setup dynamically without service interruption!
## Post-Deployment Management

### Add/Update API Keys (Hot-reload)

```bash
./get-ai-gateway.sh config add --provider <provider> --key <api-key>
./get-ai-gateway.sh config list
./get-ai-gateway.sh config remove --provider <provider>
```

Provider aliases: `dashscope`/`qwen`, `moonshot`/`kimi`, `zhipuai`/`zhipu`

### Update z.ai Domain (Hot-reload)

If the user wants to change the z.ai domain after deployment:

```bash
# Update the domain configuration
./get-ai-gateway.sh config add --provider zhipuai --extra-config "zhipuDomain=api.z.ai"
# Or revert to the China endpoint
./get-ai-gateway.sh config add --provider zhipuai --extra-config "zhipuDomain=open.bigmodel.cn"
```

### Add Routing Rules (for auto-routing)

```bash
# Add a rule: route to a specific model when the message starts with a trigger
./get-ai-gateway.sh route add --model <model> --trigger "keyword1|keyword2"

# Examples
./get-ai-gateway.sh route add --model glm-4-flash --trigger "quick|fast"
./get-ai-gateway.sh route add --model claude-opus-4 --trigger "think|complex"
./get-ai-gateway.sh route add --model deepseek-coder --trigger "code|debug"

# List/remove rules
./get-ai-gateway.sh route list
./get-ai-gateway.sh route remove --rule-id 0
```
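The rule semantics (route to a model when the message starts with one of the trigger keywords, otherwise fall back to the default model) can be sketched like this. The helper name is hypothetical and not part of the `get-ai-gateway.sh` CLI:

```shell
#!/bin/bash
# Hypothetical sketch of the trigger matching described above: a message is
# routed to a rule's model when it begins with one of the rule's keywords.
# Triggers use the same "keyword1|keyword2" form passed to `route add`.
matches_trigger() {
  local message="$1" triggers="$2"
  [[ "$message" =~ ^($triggers) ]]
}

matches_trigger "think about this edge case" "think|complex" && echo "routed"
matches_trigger "hello there" "think|complex" || echo "fallback to default model"
```

Running the sketch prints `routed` for the first message and `fallback to default model` for the second.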
### Stop/Delete Gateway

```bash
./get-ai-gateway.sh stop
./get-ai-gateway.sh delete
```

## Endpoints

| Endpoint | URL |
|----------|-----|
| Chat Completions | http://localhost:8080/v1/chat/completions |
| Console | http://localhost:8001 |
| Logs | `./higress-install/logs/access.log` |

## Testing

```bash
# Test with a specific model
curl 'http://localhost:8080/v1/chat/completions' \
  -H 'Content-Type: application/json' \
  -d '{"model": "<model-name>", "messages": [{"role": "user", "content": "Hello"}]}'

# Test auto-routing (if enabled)
curl 'http://localhost:8080/v1/chat/completions' \
  -H 'Content-Type: application/json' \
  -d '{"model": "higress/auto", "messages": [{"role": "user", "content": "What is AI?"}]}'
```
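For scripted smoke tests, the request body used by the curl commands in the Testing section can be generated by a tiny helper (`chat_payload` is a hypothetical name, not part of `get-ai-gateway.sh`):

```shell
#!/bin/bash
# Hypothetical helper: build the JSON body used in the Testing section.
# Usage: chat_payload <model> <user-message>
chat_payload() {
  printf '{"model": "%s", "messages": [{"role": "user", "content": "%s"}]}' "$1" "$2"
}

# Feed it to curl exactly as in the Testing section, e.g.:
#   curl 'http://localhost:8080/v1/chat/completions' \
#     -H 'Content-Type: application/json' \
#     -d "$(chat_payload higress/auto 'What is AI?')"
chat_payload higress/auto 'What is AI?'
```

Note the messages are interpolated verbatim, so this sketch only works for content that needs no JSON escaping.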
## Troubleshooting

| Issue | Solution |
|-------|----------|
| Container fails to start | Check `docker logs higress-ai-gateway` |
| Port already in use | Use `--http-port`, `--console-port` to change ports |
| API key error | Run `./get-ai-gateway.sh config list` to verify keys |
| Auto-routing not working | Ensure `--auto-routing` was set during deployment |
| Slow image download | The script auto-selects the nearest registry based on timezone |

## Important Notes

1. **Claude Code Mode**: Requires an OAuth token from the `claude setup-token` command, not a regular API key
2. **z.ai Code Plan Mode**: Enabled by default; uses the `/api/coding/paas/v4/chat/completions` endpoint, optimized for coding tasks
3. **z.ai Domain Selection**:
   - China users: `open.bigmodel.cn` (default)
   - International users: `api.z.ai` (auto-detected based on timezone)
   - Users can update the domain anytime after deployment
4. **Auto-routing**: Must be enabled during the initial deployment (`--auto-routing`); routing rules can be added later
5. **OpenClaw Integration**: The `openclaw models auth login` and `openclaw gateway restart` commands are **interactive** and must be run manually by the user in their terminal
6. **Hot-reload**: API key changes take effect immediately; no container restart needed
15 .claude/skills/higress-openclaw-integration/scripts/detect-region.sh (Executable file)
@@ -0,0 +1,15 @@
#!/bin/bash
# Detect if user is in China region based on timezone
# Returns: "china" or "international"

TIMEZONE=$(cat /etc/timezone 2>/dev/null || timedatectl show --property=Timezone --value 2>/dev/null || echo "Unknown")

# Check if timezone indicates China region (including Hong Kong)
if [[ "$TIMEZONE" == "Asia/Shanghai" ]] || \
   [[ "$TIMEZONE" == "Asia/Hong_Kong" ]] || \
   [[ "$TIMEZONE" == *"China"* ]] || \
   [[ "$TIMEZONE" == *"Beijing"* ]]; then
  echo "china"
else
  echo "international"
fi
@@ -0,0 +1,61 @@
# Higress AI Gateway Plugin

OpenClaw model provider plugin for Higress AI Gateway with auto-routing support.

## What is this?

This is a TypeScript-based provider plugin that enables OpenClaw to use Higress AI Gateway as a model provider. It provides:

- **Auto-routing support**: Use `higress/auto` to intelligently route requests based on message content
- **Dynamic model discovery**: Auto-detect available models from Higress Console
- **Smart URL handling**: Automatic URL normalization and validation
- **Flexible authentication**: Support for both local and remote gateway deployments

## Files

- **index.ts**: Main plugin implementation
- **package.json**: NPM package metadata and OpenClaw extension declaration
- **openclaw.plugin.json**: Plugin manifest for OpenClaw

## Installation

This plugin is automatically installed when you use the `higress-openclaw-integration` skill. See the parent SKILL.md for complete installation instructions.

### Manual Installation

If you need to install manually:

```bash
# Copy plugin files
mkdir -p "$HOME/.openclaw/extensions/higress"
cp -r ./* "$HOME/.openclaw/extensions/higress/"

# Configure provider
openclaw plugins enable higress
openclaw models auth login --provider higress
```

## Usage

After installation, configure Higress as a model provider:

```bash
openclaw models auth login --provider higress
```

The plugin will prompt for:
1. Gateway URL (default: http://localhost:8080)
2. Console URL (default: http://localhost:8001)
3. API Key (optional for local deployments)
4. Model list (auto-detected or manually specified)
5. Auto-routing default model (if using higress/auto)

## Related Resources

- **Parent Skill**: [higress-openclaw-integration](../SKILL.md)
- **Auto-routing Configuration**: [higress-auto-router](../../higress-auto-router/SKILL.md)

## License

Apache-2.0
@@ -2,33 +2,47 @@ import { emptyPluginConfigSchema } from "openclaw/plugin-sdk";

const DEFAULT_GATEWAY_URL = "http://localhost:8080";
const DEFAULT_CONSOLE_URL = "http://localhost:8001";
const DEFAULT_CONTEXT_WINDOW = 128_000;
const DEFAULT_MAX_TOKENS = 8192;

// Model-specific context window and max tokens configurations
const MODEL_CONFIG: Record<string, { contextWindow: number; maxTokens: number }> = {
  "gpt-5.4": { contextWindow: 1_000_000, maxTokens: 128_000 },
  "gpt-5.4-mini": { contextWindow: 400_000, maxTokens: 128_000 },
  "gpt-5.4-nano": { contextWindow: 400_000, maxTokens: 128_000 },
  "claude-opus-4-6": { contextWindow: 1_000_000, maxTokens: 128_000 },
  "claude-sonnet-4-6": { contextWindow: 1_000_000, maxTokens: 64_000 },
  "claude-haiku-4-5": { contextWindow: 200_000, maxTokens: 64_000 },
  "qwen3.5-plus": { contextWindow: 960_000, maxTokens: 64_000 },
  "deepseek-chat": { contextWindow: 256_000, maxTokens: 128_000 },
  "deepseek-reasoner": { contextWindow: 256_000, maxTokens: 128_000 },
  "kimi-k2.5": { contextWindow: 256_000, maxTokens: 128_000 },
  "glm-5": { contextWindow: 200_000, maxTokens: 128_000 },
  "MiniMax-M2.5": { contextWindow: 200_000, maxTokens: 128_000 },
};

// Default values for unknown models
const DEFAULT_CONTEXT_WINDOW = 200_000;
const DEFAULT_MAX_TOKENS = 128_000;

// Common models that Higress AI Gateway typically supports
const DEFAULT_MODEL_IDS = [
  // Auto-routing special model
  "higress/auto",
  // OpenAI models
  "gpt-5.2",
  "gpt-5-mini",
  "gpt-5-nano",
  // Commonly used models
  "kimi-k2.5",
  "glm-5",
  "MiniMax-M2.5",
  "qwen3.5-plus",
  // Anthropic models
  "claude-opus-4.5",
  "claude-sonnet-4.5",
  "claude-haiku-4.5",
  // Qwen models
  "qwen3-turbo",
  "qwen3-plus",
  "qwen3-max",
  "qwen3-coder-480b-a35b-instruct",
  "claude-opus-4-6",
  "claude-sonnet-4-6",
  "claude-haiku-4-5",
  // OpenAI models
  "gpt-5.4",
  "gpt-5.4-mini",
  "gpt-5.4-nano",
  // DeepSeek models
  "deepseek-chat",
  "deepseek-reasoner",
  // Other common models
  "kimi-k2.5",
  "glm-4.7",
  "MiniMax-M2.1",
  "deepseek-reasoner",
] as const;
function normalizeBaseUrl(value: string): string {
@@ -60,26 +74,33 @@ function parseModelIds(input: string): string[] {

function buildModelDefinition(modelId: string) {
  const isAutoModel = modelId === "higress/auto";
  const config = MODEL_CONFIG[modelId] || { contextWindow: DEFAULT_CONTEXT_WINDOW, maxTokens: DEFAULT_MAX_TOKENS };

  return {
    id: modelId,
    name: isAutoModel ? "Higress Auto Router" : modelId,
    api: "openai-completions",
    reasoning: false,
    reasoning: true,
    input: ["text", "image"],
    cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
    contextWindow: DEFAULT_CONTEXT_WINDOW,
    maxTokens: DEFAULT_MAX_TOKENS,
    contextWindow: config.contextWindow,
    maxTokens: config.maxTokens,
  };
}

async function testGatewayConnection(gatewayUrl: string): Promise<boolean> {
  try {
    const response = await fetch(`${gatewayUrl}/v1/models`, {
      method: "GET",
    // gatewayUrl already ends with /v1 from normalizeBaseUrl()
    // Use chat/completions endpoint with empty body to test connection
    // Higress doesn't support /models endpoint
    const response = await fetch(`${gatewayUrl}/chat/completions`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({}),
      signal: AbortSignal.timeout(5000),
    });
    return response.ok || response.status === 401; // 401 means gateway is up but needs auth
    // Any response (including 400/401/422) means gateway is reachable
    return true;
  } catch {
    return false;
  }
@@ -108,7 +129,7 @@ async function fetchAvailableModels(consoleUrl: string): Promise<string[]> {
}

const higressPlugin = {
  id: "higress-ai-gateway",
  id: "higress",
  name: "Higress AI Gateway",
  description: "Model provider plugin for Higress AI Gateway with auto-routing support",
  configSchema: emptyPluginConfigSchema(),
@@ -150,8 +171,6 @@ const higressPlugin = {
  [
    "Could not connect to Higress AI Gateway.",
    "Make sure the gateway is running and the URL is correct.",
    "",
    `Tried: ${gatewayUrl}/v1/models`,
  ].join("\n"),
  "Connection Warning",
);
@@ -184,21 +203,19 @@ const higressPlugin = {

const modelIds = parseModelIds(modelInput);
const hasAutoModel = modelIds.includes("higress/auto");

// FIX: Avoid double prefix - if modelId already starts with provider, don't add prefix again
const defaultModelId = hasAutoModel
  ? "higress/auto"
  : (modelIds[0] ?? "qwen-turbo");
const defaultModelRef = defaultModelId.startsWith("higress/")
  ? defaultModelId
  : `higress/${defaultModelId}`;

// Always add higress/ provider prefix to create model reference
const defaultModelId = hasAutoModel
  ? "higress/auto"
  : (modelIds[0] ?? "glm-5");
const defaultModelRef = `higress/${defaultModelId}`;

// Step 7: Configure default model for auto-routing
let autoRoutingDefaultModel = "qwen-turbo";
let autoRoutingDefaultModel = "glm-5";
if (hasAutoModel) {
  const autoRoutingModelInput = await ctx.prompter.text({
    message: "Default model for auto-routing (when no rule matches)",
    initialValue: "qwen-turbo",
    initialValue: "glm-5",
  });
  autoRoutingDefaultModel = autoRoutingModelInput.trim(); // FIX: Add trim() here
}
@@ -218,7 +235,8 @@ const higressPlugin = {
models: {
  providers: {
    higress: {
      baseUrl: `${gatewayUrl}/v1`,
      // gatewayUrl already ends with /v1 from normalizeBaseUrl()
      baseUrl: gatewayUrl,
      apiKey: apiKey,
      api: "openai-completions",
      authHeader: apiKey !== "higress-local",
@@ -230,10 +248,8 @@ const higressPlugin = {
defaults: {
  models: Object.fromEntries(
    modelIds.map((modelId) => {
      // FIX: Avoid double prefix - only add provider prefix if not already present
      const modelRef = modelId.startsWith("higress/")
        ? modelId
        : `higress/${modelId}`;
      // Always add higress/ provider prefix to create model reference
      const modelRef = `higress/${modelId}`;
      return [modelRef, {}];
    }),
  ),
@@ -241,7 +257,7 @@ const higressPlugin = {
},
plugins: {
  entries: {
    "higress-ai-gateway": {
    "higress": {
      enabled: true,
      config: {
        gatewayUrl,
@@ -258,20 +274,22 @@ const higressPlugin = {
hasAutoModel
  ? `Auto-routing enabled: use model "higress/auto" to route based on message content.`
  : "Add 'higress/auto' to models to enable auto-routing.",
`Gateway endpoint: ${gatewayUrl}/v1/chat/completions`,
// gatewayUrl already ends with /v1 from normalizeBaseUrl()
`Gateway endpoint: ${gatewayUrl}/chat/completions`,
`Console: ${consoleUrl}`,
"",
"🎯 Recommended Skills (install via Clawdbot conversation):",
"💡 Future Configuration Updates (No Restart Needed):",
"  • Add New Providers: Add LLM providers (DeepSeek, OpenAI, Claude, etc.) dynamically.",
"  • Update API Keys: Update existing provider keys without restart.",
"  • Configure Auto-Routing: Ask OpenClaw to set up intelligent routing rules.",
"  All changes hot-load via Higress — no gateway restart required!",
"",
"🎯 Recommended Skills (install via OpenClaw conversation):",
"",
"1. Auto-Routing Skill:",
"   Configure automatic model routing based on message content",
"   https://github.com/alibaba/higress/tree/main/.claude/skills/higress-auto-router",
'   Say: "Install higress-auto-router skill"',
"",
"2. Agent Session Monitor Skill:",
"   Track token usage and monitor conversation history",
"   https://github.com/alibaba/higress/tree/main/.claude/skills/agent-session-monitor",
'   Say: "Install agent-session-monitor skill"',
],
};
},
@@ -1,5 +1,5 @@
{
  "id": "higress-ai-gateway",
  "id": "higress",
  "name": "Higress AI Gateway",
  "description": "Model provider plugin for Higress AI Gateway with auto-routing support",
  "providers": ["higress"],
@@ -1,5 +1,5 @@
{
  "name": "@higress/higress-ai-gateway",
  "name": "@higress/higress",
  "version": "1.0.0",
  "description": "Higress AI Gateway model provider plugin for OpenClaw with auto-routing support",
  "main": "index.ts",
106 .github/workflows/build-and-push-plugin-server-image.yml (vendored, Normal file)
@@ -0,0 +1,106 @@
name: Build Plugin Server Image and Push

on:
  push:
    tags:
      - "v*.*.*"
  workflow_dispatch:
    inputs:
      plugin_server_ref:
        description: "plugin-server repo ref (branch/tag/commit, default: main)"
        required: false
        default: "main"
        type: string
      version:
        description: "Version tag (optional, without leading v)"
        required: false
        type: string

jobs:
  build-plugin-server-image:
    runs-on: ubuntu-latest
    environment:
      name: image-registry-plugin-server
    env:
      IMAGE_REGISTRY: ${{ vars.IMAGE_REGISTRY || 'higress-registry.cn-hangzhou.cr.aliyuncs.com' }}
      IMAGE_NAME: ${{ vars.PLUGIN_SERVER_IMAGE_NAME || 'higress/plugin-server' }}
    steps:
      - name: "Clone plugin-server repository"
        uses: actions/checkout@v4
        with:
          repository: higress-group/plugin-server
          ref: ${{ github.event.inputs.plugin_server_ref || 'main' }}
          path: plugin-server
          fetch-depth: 1

      - name: Free Up GitHub Actions Ubuntu Runner Disk Space
        uses: jlumbroso/free-disk-space@main
        with:
          tool-cache: false
          android: true
          dotnet: true
          haskell: true
          large-packages: true
          swap-storage: true

      - name: Set up QEMU
        uses: docker/setup-qemu-action@v3
        with:
          image: tonistiigi/binfmt:qemu-v7.0.0

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Cache Docker layers
        uses: actions/cache@v4
        with:
          path: /tmp/.buildx-cache
          key: ${{ runner.os }}-buildx-plugin-server-${{ github.sha }}
          restore-keys: |
            ${{ runner.os }}-buildx-plugin-server-

      - name: Determine version
        id: version
        run: |
          if [[ "${{ github.event_name }}" == "workflow_dispatch" && -n "${{ github.event.inputs.version }}" ]]; then
            echo "manual_version=${{ github.event.inputs.version }}" >> $GITHUB_OUTPUT
          fi

      - name: Calculate Docker metadata
        id: docker-meta
        uses: docker/metadata-action@v5
        with:
          images: |
            ${{ env.IMAGE_REGISTRY }}/${{ env.IMAGE_NAME }}
          tags: |
            type=sha
            type=ref,event=tag
            type=semver,pattern={{version}}
            type=raw,value=${{ steps.version.outputs.manual_version }},enable=${{ steps.version.outputs.manual_version != '' }}
            type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}

      - name: Login to Docker Registry
        uses: docker/login-action@v3
        with:
          registry: ${{ env.IMAGE_REGISTRY }}
          username: ${{ secrets.REGISTRY_USERNAME }}
          password: ${{ secrets.REGISTRY_PASSWORD }}

      - name: Build Docker Image and Push
        run: |
          BUILT_IMAGE=""
          readarray -t IMAGES <<< "${{ steps.docker-meta.outputs.tags }}"
          for image in "${IMAGES[@]}"; do
            echo "Image: $image"
            if [ "$BUILT_IMAGE" == "" ]; then
              docker buildx build \
                --platform linux/amd64,linux/arm64 \
                -t "$image" \
                -f plugin-server/Dockerfile \
                --push \
                plugin-server
              BUILT_IMAGE="$image"
            else
              docker buildx imagetools create "$BUILT_IMAGE" --tag "$image"
            fi
          done
18 .github/workflows/build-and-test-plugin.yaml (vendored)
@@ -19,7 +19,7 @@ on:

jobs:
lint:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
@@ -30,7 +30,7 @@ jobs:
# - run: make lint

higress-wasmplugin-test:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
strategy:
matrix:
# TODO(Xunzhuo): Enable C WASM Filters in CI
@@ -38,6 +38,18 @@
steps:
- uses: actions/checkout@v4

- name: Disable containerd image store
run: |
sudo bash -c 'cat > /etc/docker/daemon.json << EOF
{
"features": {
"containerd-snapshotter": false
}
}
EOF'
sudo systemctl restart docker
docker info -f '{{ .DriverStatus }}'

- name: Free Up GitHub Actions Ubuntu Runner Disk Space 🔧
uses: jlumbroso/free-disk-space@main
with:
@@ -79,7 +91,7 @@ jobs:
command: GOPROXY="https://proxy.golang.org,direct" PLUGIN_TYPE=${{ matrix.wasmPluginType }} make higress-wasmplugin-test

publish:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
needs: [higress-wasmplugin-test]
steps:
- uses: actions/checkout@v4
24 .github/workflows/build-and-test.yaml (vendored)
@@ -10,7 +10,7 @@ env:
GO_VERSION: 1.24
jobs:
lint:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
@@ -21,7 +21,7 @@ jobs:
# - run: make lint

coverage-test:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4

@@ -57,7 +57,7 @@

build:
# The type of runner that the job will run on
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
needs: [lint, coverage-test]
steps:
- name: "Checkout ${{ github.ref }}"
@@ -91,17 +91,29 @@
path: out/

gateway-conformance-test:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
needs: [build]
steps:
- uses: actions/checkout@v3

higress-conformance-test:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
needs: [build]
steps:
- uses: actions/checkout@v4

- name: Disable containerd image store
run: |
sudo bash -c 'cat > /etc/docker/daemon.json << EOF
{
"features": {
"containerd-snapshotter": false
}
}
EOF'
sudo systemctl restart docker
docker info -f '{{ .DriverStatus }}'

- name: Free Up GitHub Actions Ubuntu Runner Disk Space 🔧
uses: jlumbroso/free-disk-space@main
with:
@@ -139,7 +151,7 @@ jobs:
run: GOPROXY="https://proxy.golang.org,direct" make higress-conformance-test

publish:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
needs: [higress-conformance-test, gateway-conformance-test]
steps:
- uses: actions/checkout@v4
50
.github/workflows/sync-skills-to-oss.yaml
vendored
Normal file
50
.github/workflows/sync-skills-to-oss.yaml
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
name: Sync Skills to OSS
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- '.claude/skills/**'
|
||||
workflow_dispatch: ~
|
||||
|
||||
jobs:
|
||||
sync-skills-to-oss:
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: oss
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Download AI Gateway Install Script
|
||||
run: |
|
||||
wget -O install.sh https://raw.githubusercontent.com/higress-group/higress-standalone/main/all-in-one/get-ai-gateway.sh
|
||||
chmod +x install.sh
|
||||
|
||||
- name: Package Skills
|
||||
run: |
|
||||
mkdir -p packaged-skills
|
||||
for skill_dir in .claude/skills/*/; do
|
||||
if [ -d "$skill_dir" ]; then
|
||||
skill_name=$(basename "$skill_dir")
|
||||
echo "Packaging $skill_name..."
|
||||
(cd "$skill_dir" && zip -r "$GITHUB_WORKSPACE/packaged-skills/${skill_name}.zip" .)
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Sync Skills to OSS
|
||||
uses: go-choppy/ossutil-github-action@master
|
||||
with:
|
||||
ossArgs: 'cp -r -u packaged-skills/ oss://higress-ai/skills/'
|
||||
accessKey: ${{ secrets.ACCESS_KEYID }}
|
||||
accessSecret: ${{ secrets.ACCESS_KEYSECRET }}
|
||||
endpoint: oss-cn-hongkong.aliyuncs.com
|
||||
|
||||
- name: Sync Install Script to OSS
|
||||
uses: go-choppy/ossutil-github-action@master
|
||||
with:
|
||||
ossArgs: 'cp -u install.sh oss://higress-ai/ai-gateway/install.sh'
|
||||
accessKey: ${{ secrets.ACCESS_KEYID }}
|
||||
accessSecret: ${{ secrets.ACCESS_KEYSECRET }}
|
||||
endpoint: oss-cn-hongkong.aliyuncs.com
|
||||
7
.github/workflows/wasm-plugin-unit-test.yml
vendored
7
.github/workflows/wasm-plugin-unit-test.yml
vendored
@@ -199,15 +199,14 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go 1.24
|
||||
- name: Set up Go 1.25
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: 1.24
|
||||
go-version: 1.25
|
||||
cache: true
|
||||
|
||||
|
||||
- name: Install required tools
|
||||
run: |
|
||||
go install github.com/wadey/gocovmerge@latest
|
||||
sudo apt-get update && sudo apt-get install -y bc
|
||||
|
||||
- name: Download all test results
|
||||
|
||||
@@ -200,8 +200,8 @@ install: pre-install
|
||||
helm install higress helm/higress -n higress-system --create-namespace --set 'global.local=true'
|
||||
|
||||
HIGRESS_LATEST_IMAGE_TAG ?= latest
|
||||
ENVOY_LATEST_IMAGE_TAG ?= ca6ff3a92e3fa592bff706894b22e0509a69757b
|
||||
ISTIO_LATEST_IMAGE_TAG ?= c482b42b9a14885bd6692c6abd01345d50a372f7
|
||||
ENVOY_LATEST_IMAGE_TAG ?= 36c1d07376bf11295edc40357d74a5ecb50122b1
|
||||
ISTIO_LATEST_IMAGE_TAG ?= 36c1d07376bf11295edc40357d74a5ecb50122b1
|
||||
|
||||
install-dev: pre-install
|
||||
helm install higress helm/core -n higress-system --create-namespace --set 'controller.tag=$(TAG)' --set 'gateway.replicas=1' --set 'pilot.tag=$(ISTIO_LATEST_IMAGE_TAG)' --set 'gateway.tag=$(ENVOY_LATEST_IMAGE_TAG)' --set 'global.local=true'
|
||||
|
||||
14
README.md
14
README.md
@@ -86,7 +86,19 @@ Port descriptions:
|
||||
>
|
||||
> **Southeast Asia**: `higress-registry.ap-southeast-7.cr.aliyuncs.com`
|
||||
|
||||
For other installation methods such as Helm deployment under K8s, please refer to the official [Quick Start documentation](https://higress.io/en-us/docs/user/quickstart).
|
||||
> **For Kubernetes deployments**, you can configure the `global.hub` parameter in Helm values to use a mirror registry closer to your region. This applies to both Higress component images and built-in Wasm plugin images:
|
||||
>
|
||||
> ```bash
|
||||
> # Example: Using North America mirror
|
||||
> helm install higress -n higress-system higress.io/higress --set global.hub=higress-registry.us-west-1.cr.aliyuncs.com --create-namespace
|
||||
> ```
|
||||
>
|
||||
> Available mirror registries:
|
||||
> - **China (Hangzhou)**: `higress-registry.cn-hangzhou.cr.aliyuncs.com` (default)
|
||||
> - **North America**: `higress-registry.us-west-1.cr.aliyuncs.com`
|
||||
> - **Southeast Asia**: `higress-registry.ap-southeast-7.cr.aliyuncs.com`
|
||||
|
||||
For other installation methods such as Helm deployment under K8s, please refer to the official [Quick Start documentation](https://higress.ai/en/docs/latest/user/quickstart/).
|
||||
|
||||
If you are deploying on the cloud, it is recommended to use the [Enterprise Edition](https://www.aliyun.com/product/apigateway?spm=higress-github.topbar.0.0.0)
|
||||
|
||||
|
||||
18
README_ZH.md
18
README_ZH.md
@@ -80,6 +80,24 @@ docker run -d --rm --name higress-ai -v ${PWD}:/data \
|
||||
|
||||
**Higress 的所有 Docker 镜像都一直使用自己独享的仓库,不受 Docker Hub 境内访问受限的影响**
|
||||
|
||||
> 如果从 `higress-registry.cn-hangzhou.cr.aliyuncs.com` 拉取镜像超时,可以尝试使用以下镜像加速源:
|
||||
>
|
||||
> **北美**: `higress-registry.us-west-1.cr.aliyuncs.com`
|
||||
>
|
||||
> **东南亚**: `higress-registry.ap-southeast-7.cr.aliyuncs.com`
|
||||
|
||||
> **K8s 部署时**,可以通过 Helm values 配置 `global.hub` 参数来使用距离部署区域更近的镜像仓库,该参数会同时应用于 Higress 组件镜像和内置 Wasm 插件镜像:
|
||||
>
|
||||
> ```bash
|
||||
> # 示例:使用北美镜像源
|
||||
> helm install higress -n higress-system higress.io/higress --set global.hub=higress-registry.us-west-1.cr.aliyuncs.com --create-namespace
|
||||
> ```
|
||||
>
|
||||
> 可用镜像仓库:
|
||||
> - **中国(杭州)**: `higress-registry.cn-hangzhou.cr.aliyuncs.com`(默认)
|
||||
> - **北美**: `higress-registry.us-west-1.cr.aliyuncs.com`
|
||||
> - **东南亚**: `higress-registry.ap-southeast-7.cr.aliyuncs.com`
|
||||
|
||||
K8s 下使用 Helm 部署等其他安装方式可以参考官网 [Quick Start 文档](https://higress.cn/docs/latest/user/quickstart/)。
|
||||
|
||||
如果您是在云上部署,推荐使用[企业版](https://www.aliyun.com/product/apigateway?spm=higress-github.topbar.0.0.0)
|
||||
|
||||
@@ -14,7 +14,7 @@ Higress Console 是 Higress 网关的管理控制台,主要功能是管理 Hig
|
||||
### 1.1 Higress Admin SDK
|
||||
|
||||
Higress Admin SDK 脱胎于 Higress Console。起初,它作为 Higress Console 的一部分,为前端界面提供实际的功能支持。后来考虑到对接外部系统等需求,将配置管理的部分剥离出来,形成一个独立的逻辑组件,便于和各个系统进行对接。目前支持服务来源管理、服务管理、路由管理、域名管理、证书管理、插件管理等功能。
|
||||
Higress Admin SDK 现在只提供 Java 版本,且要求 JDK 版本不低于 17。具体如何集成请参考 Higress 官方 BLOG [如何使用 Higress Admin SDK 进行配置管理](https://higress.io/zh-cn/blog/admin-sdk-intro)。
|
||||
Higress Admin SDK 现在只提供 Java 版本,且要求 JDK 版本不低于 17。具体如何集成请参考 Higress 官方 BLOG [如何使用 Higress Admin SDK 进行配置管理](https://higress.ai/blog/admin-sdk-intro)。
|
||||
|
||||
## 2 Higress Controller
|
||||
|
||||
|
||||
Submodule envoy/envoy updated: b46236685e...43287ff203
2
go.mod
2
go.mod
@@ -31,7 +31,7 @@ require (
|
||||
github.com/hudl/fargo v1.4.0
|
||||
github.com/mholt/acmez v1.2.0
|
||||
github.com/nacos-group/nacos-sdk-go v1.0.8
|
||||
github.com/nacos-group/nacos-sdk-go/v2 v2.3.2
|
||||
github.com/nacos-group/nacos-sdk-go/v2 v2.3.5
|
||||
github.com/spf13/cobra v1.9.1
|
||||
github.com/spf13/pflag v1.0.7
|
||||
github.com/stretchr/testify v1.11.1
|
||||
|
||||
4
go.sum
4
go.sum
@@ -3688,8 +3688,8 @@ github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J
|
||||
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
|
||||
github.com/nacos-group/nacos-sdk-go v1.0.8 h1:8pEm05Cdav9sQgJSv5kyvlgfz0SzFUUGI3pWX6SiSnM=
|
||||
github.com/nacos-group/nacos-sdk-go v1.0.8/go.mod h1:hlAPn3UdzlxIlSILAyOXKxjFSvDJ9oLzTJ9hLAK1KzA=
|
||||
github.com/nacos-group/nacos-sdk-go/v2 v2.3.2 h1:9QB2nCJzT5wkTVlxNYl3XL/7+G6p2USMi2gQh/ouQQo=
|
||||
github.com/nacos-group/nacos-sdk-go/v2 v2.3.2/go.mod h1:9FKXl6FqOiVmm72i8kADtbeK71egyG9y3uRDBg41tpQ=
|
||||
github.com/nacos-group/nacos-sdk-go/v2 v2.3.5 h1:Hux7C4N4rWhwBF5Zm4yyYskrs9VTgrRTA8DZjoEhQTs=
|
||||
github.com/nacos-group/nacos-sdk-go/v2 v2.3.5/go.mod h1:ygUBdt7eGeYBt6Lz2HO3wx7crKXk25Mp80568emGMWU=
|
||||
github.com/nats-io/jwt v0.3.0/go.mod h1:fRYCDE99xlTsqUzISS1Bi75UBJ6ljOJQOAAu5VglpSg=
|
||||
github.com/nats-io/jwt v0.3.2/go.mod h1:/euKqTS1ZD+zzjYrY7pseZrTtWQSjujC7xjPc8wL6eU=
|
||||
github.com/nats-io/nats-server/v2 v2.1.2/go.mod h1:Afk+wRZqkMQs/p45uXdrVLuab3gwv3Z8C4HTBu8GD/k=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
apiVersion: v2
|
||||
appVersion: 2.2.0
|
||||
appVersion: 2.2.1
|
||||
description: Helm chart for deploying higress gateways
|
||||
icon: https://higress.io/img/higress_logo_small.png
|
||||
home: http://higress.io/
|
||||
@@ -15,4 +15,4 @@ dependencies:
|
||||
repository: "file://../redis"
|
||||
version: 0.0.1
|
||||
type: application
|
||||
version: 2.2.0
|
||||
version: 2.2.1
|
||||
|
||||
@@ -23,7 +23,7 @@ spec:
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: {{ .Chart.Name }}
|
||||
image: "{{ .Values.global.hub }}/{{ .Values.redis.image | default "redis-stack-server" }}:{{ .Values.redis.tag | default .Chart.AppVersion }}"
|
||||
image: "{{ .Values.global.hub }}/higress/{{ .Values.redis.image | default "redis-stack-server" }}:{{ .Values.redis.tag | default .Chart.AppVersion }}"
|
||||
{{- if .Values.global.imagePullPolicy }}
|
||||
imagePullPolicy: {{ .Values.global.imagePullPolicy }}
|
||||
{{- end }}
|
||||
|
||||
@@ -20,6 +20,11 @@ template:
|
||||
{{- end }}
|
||||
{{- include "gateway.selectorLabels" . | nindent 6 }}
|
||||
spec:
|
||||
{{- if .Values.gateway.imagePullPolicy }}
|
||||
imagePullPolicy: {{ .Values.gateway.imagePullPolicy }}
|
||||
{{- else if .Values.global.imagePullPolicy }}
|
||||
imagePullPolicy: {{ .Values.global.imagePullPolicy }}
|
||||
{{- end }}
|
||||
{{- with .Values.gateway.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml . | nindent 6 }}
|
||||
@@ -39,7 +44,7 @@ template:
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: higress-gateway
|
||||
image: "{{ .Values.gateway.hub | default .Values.global.hub }}/{{ .Values.gateway.image | default "gateway" }}:{{ .Values.gateway.tag | default .Chart.AppVersion }}"
|
||||
image: "{{ .Values.gateway.hub | default .Values.global.hub }}/higress/{{ .Values.gateway.image | default "gateway" }}:{{ .Values.gateway.tag | default .Chart.AppVersion }}"
|
||||
args:
|
||||
- proxy
|
||||
- router
|
||||
@@ -205,7 +210,7 @@ template:
|
||||
{{- if $o11y.enabled }}
|
||||
{{- $config := $o11y.promtail }}
|
||||
- name: promtail
|
||||
image: {{ $config.image.repository | default (printf "%s/promtail" .Values.global.hub) }}:{{ $config.image.tag }}
|
||||
image: {{ $config.image.repository | default (printf "%s/higress/promtail" .Values.global.hub) }}:{{ $config.image.tag }}
|
||||
imagePullPolicy: IfNotPresent
|
||||
args:
|
||||
- -config.file=/etc/promtail/promtail.yaml
|
||||
|
||||
@@ -38,7 +38,7 @@ spec:
|
||||
- name: {{ .Chart.Name }}
|
||||
securityContext:
|
||||
{{- toYaml .Values.controller.securityContext | nindent 12 }}
|
||||
image: "{{ .Values.controller.hub | default .Values.global.hub }}/{{ .Values.controller.image | default "higress" }}:{{ .Values.controller.tag | default .Chart.AppVersion }}"
|
||||
image: "{{ .Values.controller.hub | default .Values.global.hub }}/higress/{{ .Values.controller.image | default "higress" }}:{{ .Values.controller.tag | default .Chart.AppVersion }}"
|
||||
args:
|
||||
- "serve"
|
||||
- --gatewaySelectorKey=higress
|
||||
@@ -104,8 +104,10 @@ spec:
|
||||
- name: log
|
||||
mountPath: /var/log
|
||||
- name: discovery
|
||||
image: "{{ .Values.pilot.hub | default .Values.global.hub }}/{{ .Values.pilot.image | default "pilot" }}:{{ .Values.pilot.tag | default .Chart.AppVersion }}"
|
||||
{{- if .Values.global.imagePullPolicy }}
|
||||
image: "{{ .Values.pilot.hub | default .Values.global.hub }}/higress/{{ .Values.pilot.image | default "pilot" }}:{{ .Values.pilot.tag | default .Chart.AppVersion }}"
|
||||
{{- if .Values.controller.imagePullPolicy }}
|
||||
imagePullPolicy: {{ .Values.controller.imagePullPolicy }}
|
||||
{{- else if .Values.global.imagePullPolicy }}
|
||||
imagePullPolicy: {{ .Values.global.imagePullPolicy }}
|
||||
{{- end }}
|
||||
args:
|
||||
|
||||
@@ -23,8 +23,10 @@ spec:
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: {{ .Chart.Name }}
|
||||
image: {{ .Values.pluginServer.hub | default .Values.global.hub }}/{{ .Values.pluginServer.image | default "plugin-server" }}:{{ .Values.pluginServer.tag | default "1.0.0" }}
|
||||
{{- if .Values.global.imagePullPolicy }}
|
||||
image: {{ .Values.pluginServer.hub | default .Values.global.hub }}/higress/{{ .Values.pluginServer.image | default "plugin-server" }}:{{ .Values.pluginServer.tag | default "1.0.0" }}
|
||||
{{- if .Values.pluginServer.imagePullPolicy }}
|
||||
imagePullPolicy: {{ .Values.pluginServer.imagePullPolicy }}
|
||||
{{- else if .Values.global.imagePullPolicy }}
|
||||
imagePullPolicy: {{ .Values.global.imagePullPolicy }}
|
||||
{{- end }}
|
||||
ports:
|
||||
|
||||
@@ -70,10 +70,14 @@ global:
|
||||
# cpu: 100m
|
||||
# memory: 128Mi
|
||||
|
||||
# -- Default hub for Istio images.
|
||||
# Releases are published to docker hub under 'istio' project.
|
||||
# Dev builds from prow are on gcr.io
|
||||
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
|
||||
# -- Default hub (registry) for Higress images.
|
||||
# For Higress deployments, images are pulled from: {hub}/higress/{image}
|
||||
# For built-in plugins, images are pulled from: {hub}/{pluginNamespace}/{plugin-name}
|
||||
# Change this to use a mirror registry closer to your deployment region for faster image pulls.
|
||||
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com
|
||||
# -- Namespace for built-in plugin images. Default is "plugins".
|
||||
# Used by higress-console to configure plugin image path.
|
||||
pluginNamespace: "plugins"
|
||||
|
||||
# -- Specify image pull policy if default behavior isn't desired.
|
||||
# Default behavior: latest images will be Always else IfNotPresent.
|
||||
@@ -419,6 +423,10 @@ gateway:
|
||||
replicas: 2
|
||||
image: gateway
|
||||
|
||||
# -- Specify image pull policy if default behavior isn't desired.
|
||||
# Default behavior: latest images will be Always else IfNotPresent.
|
||||
imagePullPolicy: ""
|
||||
|
||||
# -- Use a `DaemonSet` or `Deployment`
|
||||
kind: Deployment
|
||||
|
||||
@@ -573,6 +581,10 @@ controller:
|
||||
periodSeconds: 3
|
||||
timeoutSeconds: 5
|
||||
|
||||
# -- Specify image pull policy if default behavior isn't desired.
|
||||
# Default behavior: latest images will be Always else IfNotPresent.
|
||||
imagePullPolicy: ""
|
||||
|
||||
imagePullSecrets: []
|
||||
|
||||
rbac:
|
||||
@@ -648,13 +660,6 @@ controller:
|
||||
|
||||
## -- Discovery Settings
|
||||
pilot:
|
||||
autoscaleEnabled: false
|
||||
autoscaleMin: 1
|
||||
autoscaleMax: 5
|
||||
replicaCount: 1
|
||||
rollingMaxSurge: 100%
|
||||
rollingMaxUnavailable: 25%
|
||||
|
||||
hub: "" # Will use global.hub if not set
|
||||
tag: ""
|
||||
|
||||
@@ -682,21 +687,11 @@ pilot:
|
||||
# -- if protocol sniffing is enabled for inbound
|
||||
enableProtocolSniffingForInbound: true
|
||||
|
||||
nodeSelector: {}
|
||||
podAnnotations: {}
|
||||
serviceAnnotations: {}
|
||||
|
||||
# -- You can use jwksResolverExtraRootCA to provide a root certificate
|
||||
# in PEM format. This will then be trusted by pilot when resolving
|
||||
# JWKS URIs.
|
||||
jwksResolverExtraRootCA: ""
|
||||
|
||||
# -- This is used to set the source of configuration for
|
||||
# the associated address in configSource, if nothing is specified
|
||||
# the default MCP is assumed.
|
||||
configSource:
|
||||
subscribedResources: []
|
||||
|
||||
plugins: []
|
||||
|
||||
# -- The following is used to limit how long a sidecar can be connected
|
||||
@@ -704,18 +699,6 @@ pilot:
|
||||
# increasing system churn.
|
||||
keepaliveMaxServerConnectionAge: 30m
|
||||
|
||||
# -- Additional labels to apply to the deployment.
|
||||
deploymentLabels: {}
|
||||
|
||||
## Mesh config settings
|
||||
|
||||
# -- Install the mesh config map, generated from values.yaml.
|
||||
# If false, pilot wil use default values (by default) or user-supplied values.
|
||||
configMap: true
|
||||
|
||||
# -- Additional labels to apply on the pod level for monitoring and logging configuration.
|
||||
podLabels: {}
|
||||
|
||||
# Tracing config settings
|
||||
tracing:
|
||||
enable: false
|
||||
@@ -811,6 +794,10 @@ pluginServer:
|
||||
hub: "" # Will use global.hub if not set
|
||||
tag: ""
|
||||
|
||||
# -- Specify image pull policy if default behavior isn't desired.
|
||||
# Default behavior: latest images will be Always else IfNotPresent.
|
||||
imagePullPolicy: ""
|
||||
|
||||
imagePullSecrets: []
|
||||
|
||||
labels: {}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
dependencies:
|
||||
- name: higress-core
|
||||
repository: file://../core
|
||||
version: 2.2.0
|
||||
version: 2.2.1
|
||||
- name: higress-console
|
||||
repository: https://higress.io/helm-charts/
|
||||
version: 2.2.0
|
||||
digest: sha256:2cb148fa6d52856344e1905d3fea018466c2feb52013e08997c2d5c7d50f2e5d
|
||||
generated: "2026-02-11T17:45:59.187965929+08:00"
|
||||
version: 2.2.1
|
||||
digest: sha256:b74e3b6f0b00364a155532fd825398e0ff856f13ec90a256e05bbd9c6bead653
|
||||
generated: "2026-04-09T17:30:46.726657+08:00"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
apiVersion: v2
|
||||
appVersion: 2.2.0
|
||||
appVersion: 2.2.1
|
||||
description: Helm chart for deploying Higress gateways
|
||||
icon: https://higress.io/img/higress_logo_small.png
|
||||
home: http://higress.io/
|
||||
@@ -12,9 +12,9 @@ sources:
|
||||
dependencies:
|
||||
- name: higress-core
|
||||
repository: "file://../core"
|
||||
version: 2.2.0
|
||||
version: 2.2.1
|
||||
- name: higress-console
|
||||
repository: "https://higress.io/helm-charts/"
|
||||
version: 2.2.0
|
||||
version: 2.2.1
|
||||
type: application
|
||||
version: 2.2.0
|
||||
version: 2.2.1
|
||||
|
||||
@@ -46,6 +46,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| controller.env | object | `{}` | |
|
||||
| controller.hub | string | `""` | |
|
||||
| controller.image | string | `"higress"` | |
|
||||
| controller.imagePullPolicy | string | `""` | Specify image pull policy if default behavior isn't desired. Default behavior: latest images will be Always else IfNotPresent. |
|
||||
| controller.imagePullSecrets | list | `[]` | |
|
||||
| controller.labels | object | `{}` | |
|
||||
| controller.name | string | `"higress-controller"` | |
|
||||
@@ -98,6 +99,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| gateway.httpsPort | int | `443` | |
|
||||
| gateway.hub | string | `""` | |
|
||||
| gateway.image | string | `"gateway"` | |
|
||||
| gateway.imagePullPolicy | string | `""` | Specify image pull policy if default behavior isn't desired. Default behavior: latest images will be Always else IfNotPresent. |
|
||||
| gateway.kind | string | `"Deployment"` | Use a `DaemonSet` or `Deployment` |
|
||||
| gateway.labels | object | `{}` | Labels to apply to all resources |
|
||||
| gateway.metrics.enabled | bool | `false` | If true, create PodMonitor or VMPodScrape for gateway |
|
||||
@@ -178,7 +180,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| global.enableStatus | bool | `true` | If true, Higress Controller will update the status field of Ingress resources. When migrating from Nginx Ingress, in order to avoid status field of Ingress objects being overwritten, this parameter needs to be set to false, so Higress won't write the entry IP to the status field of the corresponding Ingress object. |
|
||||
| global.externalIstiod | bool | `false` | Configure a remote cluster data plane controlled by an external istiod. When set to true, istiod is not deployed locally and only a subset of the other discovery charts are enabled. |
|
||||
| global.hostRDSMergeSubset | bool | `false` | |
|
||||
| global.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | Default hub for Istio images. Releases are published to docker hub under 'istio' project. Dev builds from prow are on gcr.io |
|
||||
| global.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com"` | Default hub (registry) for Higress images. For Higress deployments, images are pulled from: {hub}/higress/{image} For built-in plugins, images are pulled from: {hub}/{pluginNamespace}/{plugin-name} Change this to use a mirror registry closer to your deployment region for faster image pulls. |
|
||||
| global.imagePullPolicy | string | `""` | Specify image pull policy if default behavior isn't desired. Default behavior: latest images will be Always else IfNotPresent. |
|
||||
| global.imagePullSecrets | list | `[]` | ImagePullSecrets for all ServiceAccount, list of secrets in the same namespace to use for pulling any images in pods that reference this ServiceAccount. For components that don't use ServiceAccounts (i.e. grafana, servicegraph, tracing) ImagePullSecrets will be added to the corresponding Deployment(StatefulSet) objects. Must be set for any cluster configured with private docker registry. |
|
||||
| global.ingressClass | string | `"higress"` | IngressClass filters which ingress resources the higress controller watches. The default ingress class is higress. There are some special cases for special ingress class. 1. When the ingress class is set as nginx, the higress controller will watch ingress resources with the nginx ingress class or without any ingress class. 2. When the ingress class is set empty, the higress controller will watch all ingress resources in the k8s cluster. |
|
||||
@@ -203,6 +205,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| global.onlyPushRouteCluster | bool | `true` | |
|
||||
| global.operatorManageWebhooks | bool | `false` | Configure whether Operator manages webhook configurations. The current behavior of Istiod is to manage its own webhook configurations. When this option is set as true, Istio Operator, instead of webhooks, manages the webhook configurations. When this option is set as false, webhooks manage their own webhook configurations. |
|
||||
| global.pilotCertProvider | string | `"istiod"` | Configure the certificate provider for control plane communication. Currently, two providers are supported: "kubernetes" and "istiod". As some platforms may not have kubernetes signing APIs, Istiod is the default |
|
||||
| global.pluginNamespace | string | `"plugins"` | Namespace for built-in plugin images. Default is "plugins". Used by higress-console to configure plugin image path. |
|
||||
| global.priorityClassName | string | `""` | Kubernetes >=v1.11.0 will create two PriorityClass, including system-cluster-critical and system-node-critical, it is better to configure this in order to make sure your Istio pods will not be killed because of low priority class. Refer to https://kubernetes.io/docs/concepts/configuration/pod-priority-preemption/#priorityclass for more detail. |
|
||||
| global.proxy.autoInject | string | `"enabled"` | This controls the 'policy' in the sidecar injector. |
|
||||
| global.proxy.clusterDomain | string | `"cluster.local"` | CAUTION: It is important to ensure that all Istio helm charts specify the same clusterDomain value cluster domain. Default value is "cluster.local". |
|
||||
@@ -252,13 +255,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| meshConfig | object | `{"enablePrometheusMerge":true,"rootNamespace":null,"trustDomain":"cluster.local"}` | meshConfig defines runtime configuration of components, including Istiod and istio-agent behavior See https://istio.io/docs/reference/config/istio.mesh.v1alpha1/ for all available options |
|
||||
| meshConfig.rootNamespace | string | `nil` | The namespace to treat as the administrative root namespace for Istio configuration. When processing a leaf namespace Istio will search for declarations in that namespace first and if none are found it will search in the root namespace. Any matching declaration found in the root namespace is processed as if it were declared in the leaf namespace. |
|
||||
| meshConfig.trustDomain | string | `"cluster.local"` | The trust domain corresponds to the trust root of a system Refer to https://github.com/spiffe/spiffe/blob/master/standards/SPIFFE-ID.md#21-trust-domain |
|
||||
| pilot.autoscaleEnabled | bool | `false` | |
|
||||
| pilot.autoscaleMax | int | `5` | |
|
||||
| pilot.autoscaleMin | int | `1` | |
|
||||
| pilot.configMap | bool | `true` | Install the mesh config map, generated from values.yaml. If false, pilot wil use default values (by default) or user-supplied values. |
|
||||
| pilot.configSource | object | `{"subscribedResources":[]}` | This is used to set the source of configuration for the associated address in configSource, if nothing is specified the default MCP is assumed. |
|
||||
| pilot.cpu.targetAverageUtilization | int | `80` | |
|
||||
| pilot.deploymentLabels | object | `{}` | Additional labels to apply to the deployment. |
|
||||
| pilot.enableProtocolSniffingForInbound | bool | `true` | if protocol sniffing is enabled for inbound |
|
||||
| pilot.enableProtocolSniffingForOutbound | bool | `true` | if protocol sniffing is enabled for outbound |
|
||||
| pilot.env.PILOT_ENABLE_CROSS_CLUSTER_WORKLOAD_ENTRY | string | `"false"` | |
|
||||
@@ -269,19 +266,13 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| pilot.image | string | `"pilot"` | Can be a full hub/image:tag |
|
||||
| pilot.jwksResolverExtraRootCA | string | `""` | You can use jwksResolverExtraRootCA to provide a root certificate in PEM format. This will then be trusted by pilot when resolving JWKS URIs. |
|
||||
| pilot.keepaliveMaxServerConnectionAge | string | `"30m"` | The following is used to limit how long a sidecar can be connected to a pilot. It balances out load across pilot instances at the cost of increasing system churn. |
|
||||
| pilot.nodeSelector | object | `{}` | |
|
||||
| pilot.plugins | list | `[]` | |
|
||||
| pilot.podAnnotations | object | `{}` | |
|
||||
| pilot.podLabels | object | `{}` | Additional labels to apply on the pod level for monitoring and logging configuration. |
|
||||
| pilot.replicaCount | int | `1` | |
|
||||
| pilot.resources | object | `{"requests":{"cpu":"500m","memory":"2048Mi"}}` | Resources for a small pilot install |
|
||||
| pilot.rollingMaxSurge | string | `"100%"` | |
|
||||
| pilot.rollingMaxUnavailable | string | `"25%"` | |
|
||||
| pilot.serviceAnnotations | object | `{}` | |
|
||||
| pilot.tag | string | `""` | |
|
||||
| pilot.traceSampling | float | `1` | |
|
||||
| pluginServer.hub | string | `""` | |
|
||||
| pluginServer.image | string | `"plugin-server"` | |
|
||||
| pluginServer.imagePullPolicy | string | `""` | Specify image pull policy if default behavior isn't desired. Default behavior: latest images will be Always else IfNotPresent. |
|
||||
| pluginServer.imagePullSecrets | list | `[]` | |
|
||||
| pluginServer.labels | object | `{}` | |
|
||||
| pluginServer.name | string | `"higress-plugin-server"` | |
|
||||
|
||||
Submodule istio/api updated: 5b9a222e72...efc0fe428c
Submodule istio/istio updated: 77149ea560...1778761e3d
@@ -53,7 +53,7 @@ func (p *TemplateProcessor) ProcessConfig(cfg *config.Config) error {
|
||||
configStr := string(jsonBytes)
|
||||
// Find all value references in format:
|
||||
// ${type.name.key} or ${type.namespace/name.key}
|
||||
valueRegex := regexp.MustCompile(`\$\{([^.}]+)\.(?:([^/]+)/)?([^.}]+)\.([^}]+)\}`)
|
||||
valueRegex := regexp.MustCompile(`\$\{([^.}/]+)\.(?:([^/}]+)/)?([^.}/]+)\.([^}]+)\}`)
|
||||
matches := valueRegex.FindAllStringSubmatch(configStr, -1)
|
||||
// If there are no value references, return immediately
|
||||
if len(matches) == 0 {
|
||||
|
||||
@@ -114,6 +114,66 @@ func TestTemplateProcessor_ProcessConfig(t *testing.T) {
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "config with default and non-default namespaces (default first)",
|
||||
wasmPlugin: &extensions.WasmPlugin{
|
||||
PluginName: "test-plugin",
|
||||
PluginConfig: makeStructValue(t, map[string]interface{}{
|
||||
"a1": map[string]interface{}{
|
||||
"type": "${secret.auth-secret.auth_config.type}",
|
||||
"credentials": "${secret.auth-secret.auth_config.credentials}",
|
||||
},
|
||||
"a2": map[string]interface{}{
|
||||
"timeout": "${secret.default/test-secret.plugin_conf.timeout}",
|
||||
"max_retries": "${secret.default/test-secret.plugin_conf.max_retries}",
|
||||
},
|
||||
}),
|
||||
},
|
||||
expected: &extensions.WasmPlugin{
|
||||
PluginName: "test-plugin",
|
||||
PluginConfig: makeStructValue(t, map[string]interface{}{
|
||||
"a1": map[string]interface{}{
|
||||
"type": "basic",
|
||||
"credentials": "base64-encoded",
|
||||
},
|
||||
"a2": map[string]interface{}{
|
||||
"timeout": "5000",
|
||||
"max_retries": "3",
|
||||
},
|
||||
}),
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "config with default and non-default namespaces (non-default first)",
|
||||
wasmPlugin: &extensions.WasmPlugin{
|
||||
PluginName: "test-plugin",
|
||||
PluginConfig: makeStructValue(t, map[string]interface{}{
|
||||
"a1": map[string]interface{}{
|
||||
"timeout": "${secret.default/test-secret.plugin_conf.timeout}",
|
||||
"max_retries": "${secret.default/test-secret.plugin_conf.max_retries}",
|
||||
},
|
||||
"a2": map[string]interface{}{
|
||||
"type": "${secret.auth-secret.auth_config.type}",
|
||||
"credentials": "${secret.auth-secret.auth_config.credentials}",
|
||||
},
|
||||
}),
|
||||
},
|
||||
expected: &extensions.WasmPlugin{
|
||||
PluginName: "test-plugin",
|
||||
PluginConfig: makeStructValue(t, map[string]interface{}{
|
||||
"a1": map[string]interface{}{
|
||||
"timeout": "5000",
|
||||
"max_retries": "3",
|
||||
},
|
||||
"a2": map[string]interface{}{
|
||||
"type": "basic",
|
||||
"credentials": "base64-encoded",
|
||||
},
|
||||
}),
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent secret",
|
||||
wasmPlugin: &extensions.WasmPlugin{
|
||||
|
||||
@@ -503,7 +503,7 @@ func (c *controller) ConvertHTTPRoute(convertOptions *common.ConvertOptions, wra
|
||||
|
||||
// Two duplicated rules in the same ingress.
|
||||
if ingressRouteBuilder.Event == common.Normal {
|
||||
pathFormat := wrapperHttpRoute.PathFormat()
|
||||
pathFormat := wrapperHttpRoute.PathFormat() + kingressPathHeadersKey(httpPath.Headers)
|
||||
if definedRules.Contains(pathFormat) {
|
||||
ingressRouteBuilder.PreIngress = cfg
|
||||
ingressRouteBuilder.Event = common.DuplicatedRoute
|
||||
@@ -726,3 +726,25 @@ func isIngressPublic(ingSpec *ingress.IngressSpec) bool {
	}
	return false
}

// kingressPathHeadersKey builds a stable string from path-level headers for use
// in duplicate-route detection. KIngress paths are distinguished by headers
// (not by URL path), so the dedup key must include header information.
func kingressPathHeadersKey(headers map[string]ingress.HeaderMatch) string {
	if len(headers) == 0 {
		return ""
	}
	keys := make([]string, 0, len(headers))
	for k := range headers {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	var sb strings.Builder
	for _, k := range keys {
		sb.WriteByte('\x00')
		sb.WriteString(k)
		sb.WriteByte('=')
		sb.WriteString(headers[k].Exact)
	}
	return sb.String()
}
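As an aside, the dedup-key construction above can be exercised standalone. The sketch below copies the serialization logic with a local `HeaderMatch` stand-in (the real type lives in the `ingress` package and is not importable here); it shows why sorted keys plus a NUL delimiter give a deterministic, collision-resistant key.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// HeaderMatch is a local stand-in for ingress.HeaderMatch (only the
// Exact matcher is needed for this sketch).
type HeaderMatch struct{ Exact string }

// headersKey mirrors the serialization used by kingressPathHeadersKey:
// map iteration order is randomized in Go, so the keys are sorted first,
// and NUL ('\x00') is used as a delimiter that cannot occur in a header name.
func headersKey(headers map[string]HeaderMatch) string {
	if len(headers) == 0 {
		return ""
	}
	keys := make([]string, 0, len(headers))
	for k := range headers {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	var sb strings.Builder
	for _, k := range keys {
		sb.WriteByte('\x00')
		sb.WriteString(k)
		sb.WriteByte('=')
		sb.WriteString(headers[k].Exact)
	}
	return sb.String()
}

func main() {
	v1 := headersKey(map[string]HeaderMatch{"x-version": {Exact: "v1"}})
	v2 := headersKey(map[string]HeaderMatch{"x-version": {Exact: "v2"}})
	// Same URL path "/" plus different header matches → different dedup keys,
	// so the two KIngress paths are no longer mistaken for duplicates.
	fmt.Println("/"+v1 == "/"+v2) // prints false
}
```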
@@ -619,3 +619,182 @@ func TestCreateRuleKey(t *testing.T) {
func buildHigressAnnotationKey(key string) string {
	return annotations.HigressAnnotationsPrefix + "/" + key
}

// TestKingressPathHeadersKey verifies that kingressPathHeadersKey produces
// stable, unique keys for different header combinations.
func TestKingressPathHeadersKey(t *testing.T) {
	tests := []struct {
		name    string
		headers map[string]ingress.HeaderMatch
		want    string
	}{
		{
			name:    "nil headers",
			headers: nil,
			want:    "",
		},
		{
			name:    "empty headers",
			headers: map[string]ingress.HeaderMatch{},
			want:    "",
		},
		{
			name: "single header",
			headers: map[string]ingress.HeaderMatch{
				"x-version": {Exact: "v1"},
			},
			want: "\x00x-version=v1",
		},
		{
			name: "multiple headers are sorted deterministically",
			headers: map[string]ingress.HeaderMatch{
				"x-version": {Exact: "v2"},
				"x-env":     {Exact: "prod"},
			},
			// sorted: x-env, x-version
			want: "\x00x-env=prod\x00x-version=v2",
		},
		{
			name: "same headers different values produce different keys",
			headers: map[string]ingress.HeaderMatch{
				"x-version": {Exact: "v2"},
			},
			want: "\x00x-version=v2",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := kingressPathHeadersKey(tt.headers)
			if got != tt.want {
				t.Errorf("kingressPathHeadersKey() = %q, want %q", got, tt.want)
			}
		})
	}

	// Verify that v1 and v2 keys are distinct.
	keyV1 := kingressPathHeadersKey(map[string]ingress.HeaderMatch{"x-version": {Exact: "v1"}})
	keyV2 := kingressPathHeadersKey(map[string]ingress.HeaderMatch{"x-version": {Exact: "v2"}})
	if keyV1 == keyV2 {
		t.Errorf("expected distinct keys for different header values, got %q == %q", keyV1, keyV2)
	}
}

// TestConvertHTTPRoute_HeaderDistinctPaths verifies that two KIngress paths
// sharing the same URL path but differing only in header-match rules are NOT
// treated as duplicates and both produce VirtualService routes.
//
// KIngress example that triggers the bug (before fix):
//
//	apiVersion: networking.internal.knative.dev/v1alpha1
//	kind: Ingress
//	metadata:
//	  name: hello-header-routing
//	  namespace: default
//	spec:
//	  rules:
//	  - hosts: ["hello.default.example.com"]
//	    http:
//	      paths:
//	      - path: "/"
//	        headers:
//	          x-version:
//	            exact: "v1"
//	        splits:
//	        - serviceName: hello-v1
//	          servicePort: 80
//	          percent: 100
//	      - path: "/"
//	        headers:
//	          x-version:
//	            exact: "v2"
//	        splits:
//	        - serviceName: hello-v2
//	          servicePort: 80
//	          percent: 100
//
// Before the fix, the second path (x-version: v2) was incorrectly marked as
// DuplicatedRoute and dropped, leaving only the v1 route in the VirtualService.
// After the fix, both routes are preserved.
func TestConvertHTTPRoute_HeaderDistinctPaths(t *testing.T) {
	fakeClient := kube.NewFakeClient()
	options := common.Options{IngressClass: "mse", ClusterId: "", EnableStatus: true}
	secretController := secret.NewController(fakeClient, options)
	c := NewController(fakeClient, fakeClient, options, secretController)

	convertOptions := &common.ConvertOptions{
		IngressDomainCache: &common.IngressDomainCache{
			Valid:   make(map[string]*common.IngressDomainBuilder),
			Invalid: make([]model.IngressDomain, 0),
		},
		Route2Ingress:     map[string]*common.WrapperConfigWithRuleKey{},
		VirtualServices:   make(map[string]*common.WrapperVirtualService),
		Gateways:          make(map[string]*common.WrapperGateway),
		IngressRouteCache: common.NewIngressRouteCache(),
		HTTPRoutes:        make(map[string][]*common.WrapperHTTPRoute),
	}

	wrapperConfig := &common.WrapperConfig{
		Config: &config.Config{
			Meta: config.Meta{
				Name:      "hello-header-routing",
				Namespace: "default",
			},
			// Two paths share the same URL "/" but differ by x-version header.
			// Before fix: second path was dropped as DuplicatedRoute.
			// After fix: both paths are kept.
			Spec: ingress.IngressSpec{
				Rules: []ingress.IngressRule{
					{
						Hosts: []string{"hello.default.example.com"},
						HTTP: &ingress.HTTPIngressRuleValue{
							Paths: []ingress.HTTPIngressPath{
								{
									Path: "/",
									Headers: map[string]ingress.HeaderMatch{
										"x-version": {Exact: "v1"},
									},
									Splits: []ingress.IngressBackendSplit{{
										IngressBackend: ingress.IngressBackend{
											ServiceNamespace: "default",
											ServiceName:      "hello-v1",
											ServicePort:      intstr.FromInt(80),
										},
										Percent: 100,
									}},
								},
								{
									Path: "/",
									Headers: map[string]ingress.HeaderMatch{
										"x-version": {Exact: "v2"},
									},
									Splits: []ingress.IngressBackendSplit{{
										IngressBackend: ingress.IngressBackend{
											ServiceNamespace: "default",
											ServiceName:      "hello-v2",
											ServicePort:      intstr.FromInt(80),
										},
										Percent: 100,
									}},
								},
							},
						},
						Visibility: ingress.IngressVisibilityExternalIP,
					},
				},
			},
		},
		AnnotationsConfig: &annotations.Ingress{},
	}

	err := c.ConvertHTTPRoute(convertOptions, wrapperConfig)
	require.NoError(t, err)

	routes, ok := convertOptions.HTTPRoutes["hello.default.example.com"]
	require.True(t, ok, "expected HTTPRoutes entry for hello.default.example.com")

	// Both header-differentiated paths must survive dedup and appear as
	// separate WrapperHTTPRoute entries destined for distinct backends.
	require.Equal(t, 2, len(routes),
		"expected 2 routes (one per header value), got %d; "+
			"the second path was likely dropped as a false duplicate", len(routes))
}
@@ -112,7 +112,7 @@ func (s *statusSyncer) updateStatus(status []coreV1.LoadBalancerIngress) error {
	}
	ingress.Status.MarkNetworkConfigured()
	KIngressStatus := transportLoadBalancerIngress(status)
-	if ingress.Status.PublicLoadBalancer == nil || len(ingress.Status.PublicLoadBalancer.Ingress) != len(KIngressStatus) || reflect.DeepEqual(ingress.Status.PublicLoadBalancer.Ingress, KIngressStatus) {
+	if ingress.Status.PublicLoadBalancer == nil || len(ingress.Status.PublicLoadBalancer.Ingress) != len(KIngressStatus) || !reflect.DeepEqual(ingress.Status.PublicLoadBalancer.Ingress, KIngressStatus) {
		ingress.Status.ObservedGeneration = ingress.Generation
		ingress.Status.MarkLoadBalancerReady(KIngressStatus, KIngressStatus)
		IngressLog.Infof("Update Ingress %v/%v within cluster %s status", ingress.Namespace, ingress.Name, s.controller.options.ClusterId)
186	pkg/ingress/kube/kingress/status_test.go	(new file)
@@ -0,0 +1,186 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package kingress

import (
	"testing"

	coreV1 "k8s.io/api/core/v1"
	"knative.dev/networking/pkg/apis/networking/v1alpha1"
)

// TestTransportLoadBalancerIngress verifies that transportLoadBalancerIngress
// correctly maps k8s LoadBalancerIngress entries to knative LoadBalancerIngressStatus.
func TestTransportLoadBalancerIngress(t *testing.T) {
	tests := []struct {
		name   string
		input  []coreV1.LoadBalancerIngress
		expect []v1alpha1.LoadBalancerIngressStatus
	}{
		{
			name:   "nil input returns nil",
			input:  nil,
			expect: nil,
		},
		{
			name:   "empty input returns nil",
			input:  []coreV1.LoadBalancerIngress{},
			expect: nil,
		},
		{
			name: "ip only entry",
			input: []coreV1.LoadBalancerIngress{
				{IP: "1.2.3.4"},
			},
			expect: []v1alpha1.LoadBalancerIngressStatus{
				{IP: "1.2.3.4", Domain: ""},
			},
		},
		{
			name: "hostname only entry",
			input: []coreV1.LoadBalancerIngress{
				{Hostname: "lb.example.com"},
			},
			expect: []v1alpha1.LoadBalancerIngressStatus{
				{IP: "", Domain: "lb.example.com"},
			},
		},
		{
			name: "multiple entries preserve order",
			input: []coreV1.LoadBalancerIngress{
				{IP: "10.0.0.1"},
				{IP: "10.0.0.2", Hostname: "lb2.example.com"},
			},
			expect: []v1alpha1.LoadBalancerIngressStatus{
				{IP: "10.0.0.1", Domain: ""},
				{IP: "10.0.0.2", Domain: "lb2.example.com"},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := transportLoadBalancerIngress(tt.input)
			if len(got) != len(tt.expect) {
				t.Fatalf("len mismatch: got %d, want %d", len(got), len(tt.expect))
			}
			for i := range got {
				if got[i].IP != tt.expect[i].IP || got[i].Domain != tt.expect[i].Domain {
					t.Errorf("entry[%d]: got {IP:%q Domain:%q}, want {IP:%q Domain:%q}",
						i, got[i].IP, got[i].Domain, tt.expect[i].IP, tt.expect[i].Domain)
				}
			}
		})
	}
}

// TestUpdateStatusCondition tests the update-trigger condition in updateStatus:
//
//	PublicLoadBalancer == nil || len differs || !DeepEqual
//
// Before the fix (commit f04791b4), the condition was:
//
//	PublicLoadBalancer == nil || len differs || DeepEqual ← missing !
//
// This meant the status was updated when the LB list was EQUAL (no-op churn)
// and skipped when it was DIFFERENT (the actual update never happened).
//
// The table below documents each branch so a regression immediately shows
// which invariant was broken.
func TestUpdateStatusCondition(t *testing.T) {
	makeKnative := func(ips ...string) []v1alpha1.LoadBalancerIngressStatus {
		out := make([]v1alpha1.LoadBalancerIngressStatus, len(ips))
		for i, ip := range ips {
			out[i] = v1alpha1.LoadBalancerIngressStatus{IP: ip}
		}
		return out
	}

	tests := []struct {
		name          string
		existing      *v1alpha1.LoadBalancerStatus // PublicLoadBalancer field
		incoming      []v1alpha1.LoadBalancerIngressStatus
		wantShouldUpd bool // true == condition evaluates to true (update needed)
	}{
		{
			name:          "PublicLoadBalancer is nil → always update",
			existing:      nil,
			incoming:      makeKnative("1.2.3.4"),
			wantShouldUpd: true,
		},
		{
			name: "lengths differ → update",
			existing: &v1alpha1.LoadBalancerStatus{
				Ingress: makeKnative("1.2.3.4"),
			},
			incoming:      makeKnative("1.2.3.4", "5.6.7.8"),
			wantShouldUpd: true,
		},
		{
			// Bug scenario: status is DIFFERENT → must update.
			// Before fix: !DeepEqual was missing, so this branch was skipped.
			name: "same length but different IPs → update (was broken before fix)",
			existing: &v1alpha1.LoadBalancerStatus{
				Ingress: makeKnative("1.2.3.4"),
			},
			incoming:      makeKnative("9.9.9.9"),
			wantShouldUpd: true,
		},
		{
			// Idempotency: status is already up-to-date → skip update.
			// Before fix: DeepEqual (without !) was true here, so it wrongly
			// triggered an unnecessary update on every reconcile loop.
			name: "status already up-to-date → no update needed (was broken before fix)",
			existing: &v1alpha1.LoadBalancerStatus{
				Ingress: makeKnative("1.2.3.4"),
			},
			incoming:      makeKnative("1.2.3.4"),
			wantShouldUpd: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Mirror the exact condition from status.go updateStatus:
			// PublicLoadBalancer == nil || len differs || !DeepEqual
			shouldUpdate := tt.existing == nil ||
				len(tt.existing.Ingress) != len(tt.incoming) ||
				!equalLoadBalancerStatus(tt.existing.Ingress, tt.incoming)

			if shouldUpdate != tt.wantShouldUpd {
				t.Errorf("condition = %v, want %v", shouldUpdate, tt.wantShouldUpd)
			}
		})
	}
}

// equalLoadBalancerStatus compares two LoadBalancerIngressStatus slices
// element-by-element (mirrors reflect.DeepEqual for this type).
func equalLoadBalancerStatus(a, b []v1alpha1.LoadBalancerIngressStatus) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if a[i].IP != b[i].IP || a[i].Domain != b[i].Domain {
			return false
		}
	}
	return true
}
@@ -17,7 +17,6 @@ package translation
import (
	"sync"

-	"istio.io/istio/pilot/pkg/model"
	istiomodel "istio.io/istio/pilot/pkg/model"
	"istio.io/istio/pkg/config"
	"istio.io/istio/pkg/config/schema/collection"
@@ -40,8 +39,8 @@ type IngressTranslation struct {
	ingressConfig  *ingressconfig.IngressConfig
	kingressConfig *ingressconfig.KIngressConfig
	mutex          sync.RWMutex
-	higressRouteCache  model.IngressRouteCollection
-	higressDomainCache model.IngressDomainCollection
+	higressRouteCache  istiomodel.IngressRouteCollection
+	higressDomainCache istiomodel.IngressDomainCollection
}

func NewIngressTranslation(localKubeClient kube.Client, xdsUpdater istiomodel.XDSUpdater, namespace string, options common.Options) *IngressTranslation {
@@ -109,11 +108,11 @@ func (m *IngressTranslation) SetWatchErrorHandler(f func(r *cache.Reflector, err
	return nil
}

-func (m *IngressTranslation) GetIngressRoutes() model.IngressRouteCollection {
+func (m *IngressTranslation) GetIngressRoutes() istiomodel.IngressRouteCollection {
	m.mutex.RLock()
	defer m.mutex.RUnlock()
	ingressRouteCache := m.ingressConfig.GetIngressRoutes()
-	m.higressRouteCache = model.IngressRouteCollection{}
+	m.higressRouteCache = istiomodel.IngressRouteCollection{}
	m.higressRouteCache.Invalid = append(m.higressRouteCache.Invalid, ingressRouteCache.Invalid...)
	m.higressRouteCache.Valid = append(m.higressRouteCache.Valid, ingressRouteCache.Valid...)
	if m.kingressConfig != nil {
@@ -125,12 +124,12 @@ func (m *IngressTranslation) GetIngressRoutes() model.IngressRouteCollection {
	return m.higressRouteCache
}

-func (m *IngressTranslation) GetIngressDomains() model.IngressDomainCollection {
+func (m *IngressTranslation) GetIngressDomains() istiomodel.IngressDomainCollection {
	m.mutex.RLock()
	defer m.mutex.RUnlock()
	ingressDomainCache := m.ingressConfig.GetIngressDomains()

-	m.higressDomainCache = model.IngressDomainCollection{}
+	m.higressDomainCache = istiomodel.IngressDomainCollection{}
	m.higressDomainCache.Invalid = append(m.higressDomainCache.Invalid, ingressDomainCache.Invalid...)
	m.higressDomainCache.Valid = append(m.higressDomainCache.Valid, ingressDomainCache.Valid...)
	if m.kingressConfig != nil {
@@ -140,10 +140,16 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct

	// Send the initial endpoint event
	initialEvent := fmt.Sprintf("event: endpoint\ndata: %s\n\n", messageEndpoint)
	err = s.redisClient.Publish(channel, initialEvent)
	if err != nil {
		api.LogErrorf("Failed to send initial event: %v", err)
	}
	go func() {
		defer func() {
			if r := recover(); r != nil {
				api.LogErrorf("Failed to send initial event: %v", r)
			}
		}()
		defer cb.EncoderFilterCallbacks().RecoverPanic()
		api.LogDebugf("SSE Send message: %s", initialEvent)
		cb.EncoderFilterCallbacks().InjectData([]byte(initialEvent))
	}()

	// Start health check handler
	go func() {
@@ -52,6 +52,9 @@ var (
	{provider.PathOpenAICompletions, provider.ApiNameCompletion},
	{provider.PathOpenAIEmbeddings, provider.ApiNameEmbeddings},
	{provider.PathOpenAIAudioSpeech, provider.ApiNameAudioSpeech},
	{provider.PathOpenAIAudioTranscriptions, provider.ApiNameAudioTranscription},
	{provider.PathOpenAIAudioTranslations, provider.ApiNameAudioTranslation},
	{provider.PathOpenAIRealtime, provider.ApiNameRealtime},
	{provider.PathOpenAIImageGeneration, provider.ApiNameImageGeneration},
	{provider.PathOpenAIImageVariation, provider.ApiNameImageVariation},
	{provider.PathOpenAIImageEdit, provider.ApiNameImageEdit},
@@ -225,9 +228,9 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
		}
	}

-	if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !strings.Contains(contentType, util.MimeTypeApplicationJson) {
+	if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !isSupportedRequestContentType(apiName, contentType) {
		ctx.DontReadRequestBody()
-		log.Debugf("[onHttpRequestHeader] unsupported content type: %s, will not process the request body", contentType)
+		log.Debugf("[onHttpRequestHeader] unsupported content type for api %s: %s, will not process the request body", apiName, contentType)
	}

	if apiName == "" {
@@ -297,7 +300,8 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
			log.Errorf("failed to replace request body by custom settings: %v", settingErr)
		}
		// Only the /v1/chat/completions and /v1/completions APIs support the stream_options parameter.
-		if providerConfig.IsOpenAIProtocol() && (apiName == provider.ApiNameChatCompletion || apiName == provider.ApiNameCompletion) {
+		// The generic provider performs no capability mapping, so stream_options is not added for it.
+		if providerConfig.IsOpenAIProtocol() && !providerConfig.IsGeneric() && (apiName == provider.ApiNameChatCompletion || apiName == provider.ApiNameCompletion) {
			newBody = normalizeOpenAiRequestBody(newBody)
		}
		log.Debugf("[onHttpRequestBody] newBody=%s", newBody)
@@ -306,6 +310,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
		if err == nil {
			return action
		}
		log.Errorf("[onHttpRequestBody] failed to process request body, apiName=%s, err=%v", apiName, err)
		_ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
	}
	return types.ActionContinue
@@ -381,6 +386,8 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
		return chunk
	}

	promoteThinking := pluginConfig.GetProviderConfig().GetPromoteThinkingOnEmpty()

	log.Debugf("[onStreamingResponseBody] provider=%s", activeProvider.GetProviderType())
	log.Debugf("[onStreamingResponseBody] isLastChunk=%v chunk: %s", isLastChunk, string(chunk))

@@ -388,6 +395,9 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
		apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
		modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk)
		if err == nil && modifiedChunk != nil {
			if promoteThinking {
				modifiedChunk = promoteThinkingInStreamingChunk(ctx, modifiedChunk, isLastChunk)
			}
			// Convert to Claude format if needed
			claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, modifiedChunk)
			if convertErr != nil {
@@ -431,6 +441,10 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin

		result := []byte(responseBuilder.String())

		if promoteThinking {
			result = promoteThinkingInStreamingChunk(ctx, result, isLastChunk)
		}

		// Convert to Claude format if needed
		claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, result)
		if convertErr != nil {
@@ -439,11 +453,12 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
		return claudeChunk
	}

-	if !needsClaudeResponseConversion(ctx) {
+	if !needsClaudeResponseConversion(ctx) && !promoteThinking {
		return chunk
	}

	// If provider doesn't implement any streaming handlers but we need Claude conversion
	// or thinking promotion
	// First extract complete events from the chunk
	events := provider.ExtractStreamingEvents(ctx, chunk)
	log.Debugf("[onStreamingResponseBody] %d events received (no handler)", len(events))
@@ -460,6 +475,10 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin

	result := []byte(responseBuilder.String())

	if promoteThinking {
		result = promoteThinkingInStreamingChunk(ctx, result, isLastChunk)
	}

	// Convert to Claude format if needed
	claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, result)
	if convertErr != nil {
@@ -492,6 +511,16 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
		finalBody = body
	}

	// Promote thinking/reasoning to content when content is empty
	if pluginConfig.GetProviderConfig().GetPromoteThinkingOnEmpty() {
		promoted, err := provider.PromoteThinkingOnEmptyResponse(finalBody)
		if err != nil {
			log.Warnf("[promoteThinkingOnEmpty] failed: %v", err)
		} else {
			finalBody = promoted
		}
	}

	// Convert to Claude format if needed (applies to both branches)
	convertedBody, err := convertResponseBodyToClaude(ctx, finalBody)
	if err != nil {
@@ -540,6 +569,49 @@ func convertStreamingResponseToClaude(ctx wrapper.HttpContext, data []byte) ([]b
	return claudeChunk, nil
}

// promoteThinkingInStreamingChunk processes SSE-formatted streaming data, buffering
// reasoning deltas and stripping them from chunks. On the last chunk, if no content
// was ever seen, it appends a flush chunk that emits buffered reasoning as content.
func promoteThinkingInStreamingChunk(ctx wrapper.HttpContext, data []byte, isLastChunk bool) []byte {
	// SSE data contains lines like "data: {...}\n\n"
	// We need to find and process each data line
	lines := strings.Split(string(data), "\n")
	modified := false
	for i, line := range lines {
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		payload := strings.TrimPrefix(line, "data: ")
		if payload == "[DONE]" || payload == "" {
			continue
		}
		stripped, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, []byte(payload))
		if err != nil {
			continue
		}
		newLine := "data: " + string(stripped)
		if newLine != line {
			lines[i] = newLine
			modified = true
		}
	}

	result := data
	if modified {
		result = []byte(strings.Join(lines, "\n"))
	}

	// On last chunk, flush buffered reasoning as content if no content was seen
	if isLastChunk {
		flushChunk := provider.PromoteStreamingThinkingFlush(ctx)
		if flushChunk != nil {
			result = append(flushChunk, result...)
		}
	}

	return result
}

// Helper function to convert OpenAI response body to Claude format
func convertResponseBodyToClaude(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
	if !needsClaudeResponseConversion(ctx) {
@@ -594,3 +666,14 @@ func getApiName(path string) provider.ApiName {

	return ""
}

func isSupportedRequestContentType(apiName provider.ApiName, contentType string) bool {
	if strings.Contains(contentType, util.MimeTypeApplicationJson) {
		return true
	}
	contentType = strings.ToLower(contentType)
	if strings.HasPrefix(contentType, "multipart/form-data") {
		return apiName == provider.ApiNameImageEdit || apiName == provider.ApiNameImageVariation
	}
	return false
}
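The SSE parsing loop in `promoteThinkingInStreamingChunk` can be illustrated standalone. In the sketch below, `stripReasoning` is a hypothetical stand-in for `provider.PromoteStreamingThinkingOnEmptyChunk` (it just removes a `"reasoning"` field); the surrounding loop mirrors the real one: only `data: ` lines are touched, `[DONE]` and empty payloads pass through, and the chunk is rebuilt only when something changed.

```go
package main

import (
	"fmt"
	"strings"
)

// stripReasoning is a hypothetical per-event transform standing in for
// provider.PromoteStreamingThinkingOnEmptyChunk: it drops a fixed
// "reasoning" field from the JSON payload for illustration only.
func stripReasoning(payload string) string {
	return strings.ReplaceAll(payload, `"reasoning":"think",`, "")
}

// processSSE mirrors the parsing loop in promoteThinkingInStreamingChunk:
// split on newlines, transform each "data: " payload, and rejoin the
// chunk only if at least one line was actually modified.
func processSSE(chunk string) string {
	lines := strings.Split(chunk, "\n")
	modified := false
	for i, line := range lines {
		if !strings.HasPrefix(line, "data: ") {
			continue // event/comment/blank lines pass through untouched
		}
		payload := strings.TrimPrefix(line, "data: ")
		if payload == "[DONE]" || payload == "" {
			continue // sentinel and empty payloads pass through untouched
		}
		if newLine := "data: " + stripReasoning(payload); newLine != line {
			lines[i] = newLine
			modified = true
		}
	}
	if modified {
		return strings.Join(lines, "\n")
	}
	return chunk
}

func main() {
	chunk := "data: {\"reasoning\":\"think\",\"content\":\"hi\"}\n\ndata: [DONE]\n"
	fmt.Println(processSSE(chunk))
}
```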
@@ -18,6 +18,12 @@ func Test_getApiName(t *testing.T) {
	{"openai completions", "/v1/completions", provider.ApiNameCompletion},
	{"openai embeddings", "/v1/embeddings", provider.ApiNameEmbeddings},
	{"openai audio speech", "/v1/audio/speech", provider.ApiNameAudioSpeech},
	{"openai audio transcriptions", "/v1/audio/transcriptions", provider.ApiNameAudioTranscription},
	{"openai audio transcriptions with prefix", "/proxy/v1/audio/transcriptions", provider.ApiNameAudioTranscription},
	{"openai audio translations", "/v1/audio/translations", provider.ApiNameAudioTranslation},
	{"openai realtime", "/v1/realtime", provider.ApiNameRealtime},
	{"openai realtime with prefix", "/proxy/v1/realtime", provider.ApiNameRealtime},
	{"openai realtime with trailing slash", "/v1/realtime/", ""},
	{"openai image generation", "/v1/images/generations", provider.ApiNameImageGeneration},
	{"openai image variation", "/v1/images/variations", provider.ApiNameImageVariation},
	{"openai image edit", "/v1/images/edits", provider.ApiNameImageEdit},
@@ -63,6 +69,54 @@ func Test_getApiName(t *testing.T) {
	}
}
func Test_isSupportedRequestContentType(t *testing.T) {
	tests := []struct {
		name        string
		apiName     provider.ApiName
		contentType string
		want        bool
	}{
		{
			name:        "json chat completion",
			apiName:     provider.ApiNameChatCompletion,
			contentType: "application/json",
			want:        true,
		},
		{
			name:        "multipart image edit",
			apiName:     provider.ApiNameImageEdit,
			contentType: "multipart/form-data; boundary=----boundary",
			want:        true,
		},
		{
			name:        "multipart image variation",
			apiName:     provider.ApiNameImageVariation,
			contentType: "multipart/form-data; boundary=----boundary",
			want:        true,
		},
		{
			name:        "multipart chat completion",
			apiName:     provider.ApiNameChatCompletion,
			contentType: "multipart/form-data; boundary=----boundary",
			want:        false,
		},
		{
			name:        "text plain image edit",
			apiName:     provider.ApiNameImageEdit,
			contentType: "text/plain",
			want:        false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := isSupportedRequestContentType(tt.apiName, tt.contentType)
			if got != tt.want {
				t.Errorf("isSupportedRequestContentType(%v, %q) = %v, want %v", tt.apiName, tt.contentType, got, tt.want)
			}
		})
	}
}
func TestAi360(t *testing.T) {
	test.RunAi360ParseConfigTests(t)
	test.RunAi360OnHttpRequestHeadersTests(t)
@@ -79,6 +133,8 @@ func TestOpenAI(t *testing.T) {
	test.RunOpenAIOnHttpResponseHeadersTests(t)
	test.RunOpenAIOnHttpResponseBodyTests(t)
	test.RunOpenAIOnStreamingResponseBodyTests(t)
	test.RunOpenAIPromoteThinkingOnEmptyTests(t)
	test.RunOpenAIPromoteThinkingOnEmptyStreamingTests(t)
}

func TestQwen(t *testing.T) {
@@ -102,6 +158,7 @@ func TestGemini(t *testing.T) {

func TestAzure(t *testing.T) {
	test.RunAzureParseConfigTests(t)
	test.RunAzureMultipartHelperTests(t)
	test.RunAzureOnHttpRequestHeadersTests(t)
	test.RunAzureOnHttpRequestBodyTests(t)
	test.RunAzureOnHttpResponseHeadersTests(t)
@@ -123,6 +180,10 @@ func TestUtil(t *testing.T) {
	test.RunMapRequestPathByCapabilityTests(t)
}

func TestApiPathRegression(t *testing.T) {
	test.RunApiPathRegressionTests(t)
}

func TestGeneric(t *testing.T) {
	test.RunGenericParseConfigTests(t)
	test.RunGenericOnHttpRequestHeadersTests(t)
@@ -135,8 +196,12 @@ func TestVertex(t *testing.T) {
	test.RunVertexExpressModeOnHttpRequestBodyTests(t)
	test.RunVertexExpressModeOnHttpResponseBodyTests(t)
	test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
	test.RunVertexOpenAICompatibleModeOnHttpRequestHeadersTests(t)
	test.RunVertexOpenAICompatibleModeOnHttpRequestBodyTests(t)
	test.RunVertexExpressModeImageGenerationRequestBodyTests(t)
	test.RunVertexExpressModeImageGenerationResponseBodyTests(t)
	test.RunVertexExpressModeImageEditVariationRequestBodyTests(t)
	test.RunVertexExpressModeImageEditVariationResponseBodyTests(t)
	// Vertex Raw mode tests
	test.RunVertexRawModeOnHttpRequestHeadersTests(t)
	test.RunVertexRawModeOnHttpRequestBodyTests(t)
@@ -149,6 +214,7 @@ func TestBedrock(t *testing.T) {
	test.RunBedrockOnHttpRequestBodyTests(t)
	test.RunBedrockOnHttpResponseHeadersTests(t)
	test.RunBedrockOnHttpResponseBodyTests(t)
	test.RunBedrockOnStreamingResponseBodyTests(t)
	test.RunBedrockToolCallTests(t)
}

@@ -157,3 +223,16 @@ func TestClaude(t *testing.T) {
	test.RunClaudeOnHttpRequestHeadersTests(t)
	test.RunClaudeOnHttpRequestBodyTests(t)
}

func TestConsumerAffinity(t *testing.T) {
	test.RunConsumerAffinityParseConfigTests(t)
	test.RunConsumerAffinityOnHttpRequestHeadersTests(t)
}

func TestOpenRouter(t *testing.T) {
	test.RunOpenRouterClaudeAutoConversionTests(t)
}

func TestZhipuAI(t *testing.T) {
	test.RunZhipuAIClaudeAutoConversionTests(t)
}
@@ -9,6 +9,7 @@ import (
 	"strings"
 
 	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
+	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
 	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
 	"github.com/higress-group/wasm-go/pkg/log"
 	"github.com/higress-group/wasm-go/pkg/wrapper"
@@ -151,17 +152,44 @@ func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
 	return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
 }
 
+func isAzureMultipartImageRequest(apiName ApiName, contentType string) bool {
+	if apiName != ApiNameImageEdit && apiName != ApiNameImageVariation {
+		return false
+	}
+	return isMultipartFormData(contentType)
+}
+
 func (m *azureProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (transformedBody []byte, err error) {
 	transformedBody = body
 	err = nil
 
+	contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType)
+	isMultipartImageRequest := isAzureMultipartImageRequest(apiName, contentType)
+
 	transformedBody, err = m.config.defaultTransformRequestBody(ctx, apiName, body)
+	if isMultipartImageRequest {
+		if err != nil {
+			log.Debugf("[azure multipart] body transform failed: api=%s, err=%v", apiName, err)
+		} else {
+			log.Debugf("[azure multipart] body transformed: api=%s, originalModel=%s, mappedModel=%s, bodyBytes=%d->%d",
+				apiName,
+				ctx.GetStringContext(ctxKeyOriginalRequestModel, ""),
+				ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
+				len(body),
+				len(transformedBody),
+			)
+		}
+	}
 	if err != nil {
 		return
 	}
 
 	// This must be called after the body is transformed, because it uses the model from the context filled by that call.
 	if path := m.transformRequestPath(ctx, apiName); path != "" {
+		if isMultipartImageRequest {
+			log.Debugf("[azure multipart] body path overwrite: api=%s, path=%s, modelInContext=%s",
+				apiName, path, ctx.GetStringContext(ctxKeyFinalRequestModel, ""))
+		}
 		err = util.OverwriteRequestPath(path)
 		if err == nil {
 			log.Debugf("azureProvider: overwrite request path to %s succeeded", path)
@@ -222,16 +250,30 @@ func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName Ap
 }
 
 func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
+	contentType := headers.Get(util.HeaderContentType)
+	isMultipartImageRequest := isAzureMultipartImageRequest(apiName, contentType)
+
 	// We need to overwrite the request path in the request headers stage,
 	// because for some APIs, we don't read the request body and the path is model irrelevant.
 	if overwrittenPath := m.transformRequestPath(ctx, apiName); overwrittenPath != "" {
 		util.OverwriteRequestPathHeader(headers, overwrittenPath)
+		if isMultipartImageRequest {
+			log.Debugf("[azure multipart] header path overwrite: api=%s, path=%s, modelInContext=%s",
+				apiName, overwrittenPath, ctx.GetStringContext(ctxKeyFinalRequestModel, ""))
+		}
 	}
 	util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
 	headers.Set("api-key", m.config.GetApiTokenInUse(ctx))
 	headers.Del("Content-Length")
 
-	if !m.config.isSupportedAPI(apiName) || !m.config.needToProcessRequestBody(apiName) {
+	supportedAPI := m.config.isSupportedAPI(apiName)
+	needProcessBody := m.config.needToProcessRequestBody(apiName)
+	if isMultipartImageRequest {
+		log.Debugf("[azure multipart] body processing decision: api=%s, supported=%t, needProcessBody=%t",
+			apiName, supportedAPI, needProcessBody)
+	}
+
+	if !supportedAPI || !needProcessBody {
 		// If the API is not supported or there is no need to process the body,
 		// we should not read the request body and keep it as it is.
 		ctx.DontReadRequestBody()
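The azure hunks above gate their multipart debug logging on `isMultipartFormData`, whose definition is not part of this diff. A minimal sketch of such a check, assuming only that the gateway sees a raw `Content-Type` header value (the function name mirrors the diff; the `mime`-based parsing is an assumption, not the plugin's actual code):

```go
package main

import (
	"fmt"
	"mime"
	"strings"
)

// isMultipartFormData sketches the content-type check the azure hunks rely on.
// mime.ParseMediaType strips parameters such as "boundary=..." and lower-cases
// the media type, so "Multipart/Form-Data; boundary=X" still matches.
func isMultipartFormData(contentType string) bool {
	mediaType, _, err := mime.ParseMediaType(contentType)
	if err != nil {
		// Fall back to a prefix check for malformed header values.
		return strings.HasPrefix(strings.ToLower(contentType), "multipart/form-data")
	}
	return mediaType == "multipart/form-data"
}

func main() {
	fmt.Println(isMultipartFormData("multipart/form-data; boundary=XYZ")) // true
	fmt.Println(isMultipartFormData("application/json"))                  // false
}
```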
@@ -35,9 +35,23 @@
 	// converseStream path: /model/{modelId}/converse-stream
 	bedrockStreamChatCompletionPath = "/model/%s/converse-stream"
 	// invoke_model path: /model/{modelId}/invoke
-	bedrockInvokeModelPath = "/model/%s/invoke"
-	bedrockSignedHeaders   = "host;x-amz-date"
-	requestIdHeader        = "X-Amzn-Requestid"
+	bedrockInvokeModelPath  = "/model/%s/invoke"
+	bedrockSignedHeaders    = "host;x-amz-date"
+	requestIdHeader         = "X-Amzn-Requestid"
+	bedrockCacheTypeDefault = "default"
+	bedrockCacheTTL5m       = "5m"
+	bedrockCacheTTL1h       = "1h"
+	bedrockPromptCacheNova   = "amazon.nova"
+	bedrockPromptCacheClaude = "anthropic.claude"
+
+	bedrockCachePointPositionSystemPrompt    = "systemPrompt"
+	bedrockCachePointPositionLastUserMessage = "lastUserMessage"
+	bedrockCachePointPositionLastMessage     = "lastMessage"
 )
 
+var (
+	bedrockConversePathPattern = regexp.MustCompile(`/model/[^/]+/converse(-stream)?$`)
+	bedrockInvokePathPattern   = regexp.MustCompile(`/model/[^/]+/invoke(-with-response-stream)?$`)
+)
+
 type bedrockProviderInitializer struct{}
@@ -164,9 +178,10 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
 	if bedrockEvent.Usage != nil {
 		openAIFormattedChunk.Choices = choices[:0]
 		openAIFormattedChunk.Usage = &usage{
-			CompletionTokens: bedrockEvent.Usage.OutputTokens,
-			PromptTokens:     bedrockEvent.Usage.InputTokens,
-			TotalTokens:      bedrockEvent.Usage.TotalTokens,
+			CompletionTokens:    bedrockEvent.Usage.OutputTokens,
+			PromptTokens:        bedrockEvent.Usage.InputTokens,
+			TotalTokens:         bedrockEvent.Usage.TotalTokens,
+			PromptTokensDetails: buildPromptTokensDetails(bedrockEvent.Usage.CacheReadInputTokens, bedrockEvent.Usage.CacheWriteInputTokens),
 		}
 	}
 	openAIFormattedChunkBytes, _ := json.Marshal(openAIFormattedChunk)
@@ -630,13 +645,24 @@ func (b *bedrockProvider) GetProviderType() string {
 	return providerTypeBedrock
 }
 
+func (b *bedrockProvider) GetApiName(path string) ApiName {
+	switch {
+	case bedrockConversePathPattern.MatchString(path):
+		return ApiNameChatCompletion
+	case bedrockInvokePathPattern.MatchString(path):
+		return ApiNameImageGeneration
+	default:
+		return ""
+	}
+}
+
 func (b *bedrockProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
 	b.config.handleRequestHeaders(b, ctx, apiName)
 	return nil
 }
 
 func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
-	util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion))
+	util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, strings.TrimSpace(b.config.awsRegion)))
 
 	// If apiTokens is configured, set Bearer token authentication here
 	// This follows the same pattern as other providers (qwen, zhipuai, etc.)
@@ -647,6 +673,15 @@ func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
 }
 
 func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
+	// In original protocol mode (e.g. /model/{modelId}/converse-stream), keep the body/path untouched
+	// and only apply auth headers.
+	if b.config.IsOriginal() {
+		headers := util.GetRequestHeaders()
+		b.setAuthHeaders(body, headers)
+		util.ReplaceRequestHeaders(headers)
+		return types.ActionContinue, replaceRequestBody(body)
+	}
+
 	if !b.config.isSupportedAPI(apiName) {
 		return types.ActionContinue, errUnsupportedApiName
 	}
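The new `GetApiName` dispatches on the two regex patterns introduced in the const/var hunk. A self-contained sketch of that routing (the `"chatCompletion"`/`"imageGeneration"` strings stand in for the plugin's `ApiName` constants, whose actual values are not shown in this diff):

```go
package main

import (
	"fmt"
	"regexp"
)

// The same path patterns the diff introduces for bedrock API dispatch:
// converse / converse-stream routes to chat completion, invoke (optionally
// with-response-stream) routes to image generation.
var (
	conversePattern = regexp.MustCompile(`/model/[^/]+/converse(-stream)?$`)
	invokePattern   = regexp.MustCompile(`/model/[^/]+/invoke(-with-response-stream)?$`)
)

func apiNameForPath(path string) string {
	switch {
	case conversePattern.MatchString(path):
		return "chatCompletion"
	case invokePattern.MatchString(path):
		return "imageGeneration"
	default:
		return "" // unknown path: no API mapped
	}
}

func main() {
	fmt.Println(apiNameForPath("/model/anthropic.claude-3-haiku/converse-stream")) // chatCompletion
	fmt.Println(apiNameForPath("/model/amazon.nova-canvas-v1:0/invoke"))           // imageGeneration
	fmt.Println(apiNameForPath("/healthz") == "")                                  // true
}
```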
@@ -654,14 +689,25 @@ func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
 }
 
 func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
+	var transformedBody []byte
+	var err error
 	switch apiName {
 	case ApiNameChatCompletion:
-		return b.onChatCompletionRequestBody(ctx, body, headers)
+		transformedBody, err = b.onChatCompletionRequestBody(ctx, body, headers)
 	case ApiNameImageGeneration:
-		return b.onImageGenerationRequestBody(ctx, body, headers)
+		transformedBody, err = b.onImageGenerationRequestBody(ctx, body, headers)
 	default:
-		return b.config.defaultTransformRequestBody(ctx, apiName, body)
+		transformedBody, err = b.config.defaultTransformRequestBody(ctx, apiName, body)
 	}
+
+	if err != nil {
+		return nil, err
+	}
+
+	// Always apply auth after request body/path are finalized.
+	// For Bearer token mode this is a no-op; for AK/SK mode this generates SigV4 headers.
+	b.setAuthHeaders(transformedBody, headers)
+	return transformedBody, nil
 }
 
 func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
@@ -715,9 +761,7 @@ func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageG
 			Quality: origRequest.Quality,
 		},
 	}
-	requestBytes, err := json.Marshal(request)
-	b.setAuthHeaders(requestBytes, headers)
-	return requestBytes, err
+	return json.Marshal(request)
 }
 
 func (b *bedrockProvider) buildBedrockImageGenerationResponse(bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
@@ -797,6 +841,19 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
 		},
 	}
 
+	effectivePromptCacheRetention := b.resolvePromptCacheRetention(origRequest.PromptCacheRetention)
+
+	if origRequest.PromptCacheKey != "" {
+		log.Warnf("bedrock provider ignores prompt_cache_key because Converse API has no equivalent field")
+	}
+	if isPromptCacheSupportedModel(origRequest.Model) {
+		if cacheTTL, ok := mapPromptCacheRetentionToBedrockTTL(effectivePromptCacheRetention); ok {
+			addPromptCachePointsToBedrockRequest(request, cacheTTL, b.getPromptCachePointPositions())
+		}
+	} else if effectivePromptCacheRetention != "" {
+		log.Warnf("skip prompt cache injection for unsupported model: %s", origRequest.Model)
+	}
+
 	if origRequest.ReasoningEffort != "" {
 		thinkingBudget := 1024 // default
 		switch origRequest.ReasoningEffort {
@@ -847,9 +904,7 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
 		request.AdditionalModelRequestFields[key] = value
 	}
 
-	requestBytes, err := json.Marshal(request)
-	b.setAuthHeaders(requestBytes, headers)
-	return requestBytes, err
+	return json.Marshal(request)
 }
 
 func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockConverseResponse) *chatCompletionResponse {
@@ -900,9 +955,10 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b
 		Object:  objectChatCompletion,
 		Choices: choices,
 		Usage: &usage{
-			PromptTokens:     bedrockResponse.Usage.InputTokens,
-			CompletionTokens: bedrockResponse.Usage.OutputTokens,
-			TotalTokens:      bedrockResponse.Usage.TotalTokens,
+			PromptTokens:        bedrockResponse.Usage.InputTokens,
+			CompletionTokens:    bedrockResponse.Usage.OutputTokens,
+			TotalTokens:         bedrockResponse.Usage.TotalTokens,
+			PromptTokensDetails: buildPromptTokensDetails(bedrockResponse.Usage.CacheReadInputTokens, bedrockResponse.Usage.CacheWriteInputTokens),
 		},
 	}
 }
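Both usage hunks fold Bedrock's separate cache-read and cache-write counters into a single cached-token figure via `buildPromptTokensDetails` (defined later in this diff). The arithmetic it performs can be sketched standalone (`cachedPromptTokens` is an illustrative name, not the plugin's):

```go
package main

import "fmt"

// cachedPromptTokens mirrors the logic of buildPromptTokensDetails: Bedrock
// reports cache reads and writes separately, while the OpenAI-style usage
// block carries one cached-token count, omitted entirely when both are zero.
func cachedPromptTokens(cacheRead, cacheWrite int) (int, bool) {
	total := cacheRead + cacheWrite
	if total <= 0 {
		return 0, false // caller emits no prompt_tokens_details at all
	}
	return total, true
}

func main() {
	if n, ok := cachedPromptTokens(120, 30); ok {
		fmt.Println(n) // 150
	}
	_, ok := cachedPromptTokens(0, 0)
	fmt.Println(ok) // false
}
```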
@@ -933,6 +989,145 @@ func stopReasonBedrock2OpenAI(reason string) string {
 	}
 }
 
+func mapPromptCacheRetentionToBedrockTTL(retention string) (string, bool) {
+	normalizedRetention := normalizePromptCacheRetention(retention)
+	switch normalizedRetention {
+	case "":
+		return "", false
+	case "in_memory":
+		// For the default 5-minute cache, omit ttl and let Bedrock apply its default.
+		// This is more robust for models that are strict about explicit ttl fields.
+		return "", true
+	case "24h":
+		return bedrockCacheTTL1h, true
+	default:
+		log.Warnf("unsupported prompt_cache_retention for bedrock mapping: %s", retention)
+		return "", false
+	}
+}
+
+func normalizePromptCacheRetention(retention string) string {
+	normalized := strings.ToLower(strings.TrimSpace(retention))
+	normalized = strings.ReplaceAll(normalized, "-", "_")
+	normalized = strings.ReplaceAll(normalized, " ", "_")
+	if normalized == "inmemory" {
+		return "in_memory"
+	}
+	return normalized
+}
+
+func isPromptCacheSupportedModel(model string) bool {
+	normalizedModel := strings.ToLower(strings.TrimSpace(model))
+	return strings.Contains(normalizedModel, bedrockPromptCacheNova) ||
+		strings.Contains(normalizedModel, bedrockPromptCacheClaude)
+}
+
+func (b *bedrockProvider) resolvePromptCacheRetention(requestPromptCacheRetention string) string {
+	if requestPromptCacheRetention != "" {
+		return requestPromptCacheRetention
+	}
+	if b.config.promptCacheRetention != "" {
+		return b.config.promptCacheRetention
+	}
+	return ""
+}
+
+func (b *bedrockProvider) getPromptCachePointPositions() map[string]bool {
+	if b.config.bedrockPromptCachePointPositions == nil {
+		return map[string]bool{
+			bedrockCachePointPositionSystemPrompt: true,
+			bedrockCachePointPositionLastMessage:  false,
+		}
+	}
+	positions := map[string]bool{
+		bedrockCachePointPositionSystemPrompt:    false,
+		bedrockCachePointPositionLastUserMessage: false,
+		bedrockCachePointPositionLastMessage:     false,
+	}
+	for rawKey, enabled := range b.config.bedrockPromptCachePointPositions {
+		key := normalizeBedrockCachePointPosition(rawKey)
+		switch key {
+		case bedrockCachePointPositionSystemPrompt, bedrockCachePointPositionLastUserMessage, bedrockCachePointPositionLastMessage:
+			positions[key] = enabled
+		default:
+			log.Warnf("unsupported bedrockPromptCachePointPositions key: %s", rawKey)
+		}
+	}
+	return positions
+}
+
+func normalizeBedrockCachePointPosition(raw string) string {
+	key := strings.ToLower(raw)
+	key = strings.ReplaceAll(key, "_", "")
+	key = strings.ReplaceAll(key, "-", "")
+	switch key {
+	case "systemprompt":
+		return bedrockCachePointPositionSystemPrompt
+	case "lastusermessage":
+		return bedrockCachePointPositionLastUserMessage
+	case "lastmessage":
+		return bedrockCachePointPositionLastMessage
+	default:
+		return raw
+	}
+}
+
+func addPromptCachePointsToBedrockRequest(request *bedrockTextGenRequest, cacheTTL string, positions map[string]bool) {
+	if positions[bedrockCachePointPositionSystemPrompt] && len(request.System) > 0 {
+		request.System = append(request.System, systemContentBlock{
+			CachePoint: &bedrockCachePoint{
+				Type: bedrockCacheTypeDefault,
+				TTL:  cacheTTL,
+			},
+		})
+	}
+
+	lastUserMessageIndex := -1
+	if positions[bedrockCachePointPositionLastUserMessage] {
+		lastUserMessageIndex = findLastMessageIndexByRole(request.Messages, roleUser)
+		if lastUserMessageIndex >= 0 {
+			appendCachePointToBedrockMessage(request, lastUserMessageIndex, cacheTTL)
+		}
+	}
+	if positions[bedrockCachePointPositionLastMessage] && len(request.Messages) > 0 {
+		lastMessageIndex := len(request.Messages) - 1
+		if lastMessageIndex != lastUserMessageIndex {
+			appendCachePointToBedrockMessage(request, lastMessageIndex, cacheTTL)
+		}
+	}
+}
+
+func findLastMessageIndexByRole(messages []bedrockMessage, role string) int {
+	for i := len(messages) - 1; i >= 0; i-- {
+		if messages[i].Role == role {
+			return i
+		}
+	}
+	return -1
+}
+
+func appendCachePointToBedrockMessage(request *bedrockTextGenRequest, messageIndex int, cacheTTL string) {
+	if messageIndex < 0 || messageIndex >= len(request.Messages) {
+		return
+	}
+	request.Messages[messageIndex].Content = append(request.Messages[messageIndex].Content, bedrockMessageContent{
+		CachePoint: &bedrockCachePoint{
+			Type: bedrockCacheTypeDefault,
+			TTL:  cacheTTL,
+		},
+	})
+}
+
+func buildPromptTokensDetails(cacheReadInputTokens int, cacheWriteInputTokens int) *promptTokensDetails {
+	totalCachedTokens := cacheReadInputTokens + cacheWriteInputTokens
+	if totalCachedTokens <= 0 {
+		return nil
+	}
+	return &promptTokensDetails{
+		CachedTokens: totalCachedTokens,
+	}
+}
+
 type bedrockTextGenRequest struct {
 	Messages []bedrockMessage     `json:"messages"`
 	System   []systemContentBlock `json:"system,omitempty"`
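`normalizePromptCacheRetention` above accepts several spellings of the in-memory retention value before the TTL mapping runs. A standalone copy of that normalization, runnable outside the plugin:

```go
package main

import (
	"fmt"
	"strings"
)

// normalizeRetention mirrors normalizePromptCacheRetention from the diff:
// lower-case, trim surrounding whitespace, unify '-' and ' ' separators to
// '_', and accept the "inmemory" alias.
func normalizeRetention(retention string) string {
	n := strings.ToLower(strings.TrimSpace(retention))
	n = strings.ReplaceAll(n, "-", "_")
	n = strings.ReplaceAll(n, " ", "_")
	if n == "inmemory" {
		return "in_memory"
	}
	return n
}

func main() {
	for _, v := range []string{"in-memory", " In Memory ", "inmemory", "24h"} {
		fmt.Printf("%q -> %q\n", v, normalizeRetention(v))
	}
	// "in-memory" -> "in_memory"
	// " In Memory " -> "in_memory"
	// "inmemory" -> "in_memory"
	// "24h" -> "24h"
}
```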
@@ -977,14 +1172,21 @@ type bedrockMessage struct {
 }
 
 type bedrockMessageContent struct {
-	Text       string           `json:"text,omitempty"`
-	Image      *imageBlock      `json:"image,omitempty"`
-	ToolResult *toolResultBlock `json:"toolResult,omitempty"`
-	ToolUse    *toolUseBlock    `json:"toolUse,omitempty"`
+	Text       string             `json:"text,omitempty"`
+	Image      *imageBlock        `json:"image,omitempty"`
+	ToolResult *toolResultBlock   `json:"toolResult,omitempty"`
+	ToolUse    *toolUseBlock      `json:"toolUse,omitempty"`
+	CachePoint *bedrockCachePoint `json:"cachePoint,omitempty"`
 }
 
 type systemContentBlock struct {
-	Text string `json:"text,omitempty"`
+	Text       string             `json:"text,omitempty"`
+	CachePoint *bedrockCachePoint `json:"cachePoint,omitempty"`
 }
 
+type bedrockCachePoint struct {
+	Type string `json:"type"`
+	TTL  string `json:"ttl,omitempty"`
+}
+
 type imageBlock struct {
@@ -1066,6 +1268,10 @@ type tokenUsage struct {
 	OutputTokens int `json:"outputTokens,omitempty"`
 
 	TotalTokens int `json:"totalTokens"`
+
+	CacheReadInputTokens int `json:"cacheReadInputTokens,omitempty"`
+
+	CacheWriteInputTokens int `json:"cacheWriteInputTokens,omitempty"`
 }
 
 func chatToolMessage2BedrockToolResultContent(chatMessage chatMessage) bedrockMessageContent {
@@ -1163,35 +1369,45 @@ func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) {
 	}
 
 	// Use AWS Signature V4 authentication
+	accessKey := strings.TrimSpace(b.config.awsAccessKey)
+	region := strings.TrimSpace(b.config.awsRegion)
 	t := time.Now().UTC()
 	amzDate := t.Format("20060102T150405Z")
 	dateStamp := t.Format("20060102")
 	path := headers.Get(":path")
 	signature := b.generateSignature(path, amzDate, dateStamp, body)
 	headers.Set("X-Amz-Date", amzDate)
-	util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", b.config.awsAccessKey, dateStamp, b.config.awsRegion, awsService, bedrockSignedHeaders, signature))
+	util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", accessKey, dateStamp, region, awsService, bedrockSignedHeaders, signature))
 }
 
 func (b *bedrockProvider) generateSignature(path, amzDate, dateStamp string, body []byte) string {
-	path = encodeSigV4Path(path)
+	canonicalURI := encodeSigV4Path(path)
 	hashedPayload := sha256Hex(body)
+	region := strings.TrimSpace(b.config.awsRegion)
+	secretKey := strings.TrimSpace(b.config.awsSecretKey)
 
-	endpoint := fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion)
+	endpoint := fmt.Sprintf(bedrockDefaultDomain, region)
 	canonicalHeaders := fmt.Sprintf("host:%s\nx-amz-date:%s\n", endpoint, amzDate)
 	canonicalRequest := fmt.Sprintf("%s\n%s\n\n%s\n%s\n%s",
-		httpPostMethod, path, canonicalHeaders, bedrockSignedHeaders, hashedPayload)
+		httpPostMethod, canonicalURI, canonicalHeaders, bedrockSignedHeaders, hashedPayload)
 
-	credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, b.config.awsRegion, awsService)
+	credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, region, awsService)
 	hashedCanonReq := sha256Hex([]byte(canonicalRequest))
 	stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s",
 		amzDate, credentialScope, hashedCanonReq)
 
-	signingKey := getSignatureKey(b.config.awsSecretKey, dateStamp, b.config.awsRegion, awsService)
+	signingKey := getSignatureKey(secretKey, dateStamp, region, awsService)
 	signature := hmacHex(signingKey, stringToSign)
 	return signature
 }
 
+func encodeSigV4Path(path string) string {
+	// Keep only the URI path for canonical URI. Query string is handled separately in SigV4,
+	// and this implementation uses an empty canonical query string.
+	if queryIndex := strings.Index(path, "?"); queryIndex >= 0 {
+		path = path[:queryIndex]
+	}
+
+	segments := strings.Split(path, "/")
+	for i, seg := range segments {
+		if seg == "" {
@@ -0,0 +1,191 @@
+package provider
+
+import (
+	"net/http"
+	"net/url"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestEncodeSigV4Path(t *testing.T) {
+	tests := []struct {
+		name string
+		path string
+		want string
+	}{
+		{
+			name: "raw model id keeps colon",
+			path: "/model/global.amazon.nova-2-lite-v1:0/converse-stream",
+			want: "/model/global.amazon.nova-2-lite-v1:0/converse-stream",
+		},
+		{
+			name: "pre-encoded model id escapes percent to avoid mismatch",
+			path: "/model/global.amazon.nova-2-lite-v1%3A0/converse-stream",
+			want: "/model/global.amazon.nova-2-lite-v1%253A0/converse-stream",
+		},
+		{
+			name: "raw inference profile arn keeps colon and slash delimiters",
+			path: "/model/arn:aws:bedrock:us-east-1:123456789012:inference-profile/global.anthropic.claude-sonnet-4-20250514-v1:0/converse",
+			want: "/model/arn:aws:bedrock:us-east-1:123456789012:inference-profile/global.anthropic.claude-sonnet-4-20250514-v1:0/converse",
+		},
+		{
+			name: "encoded inference profile arn preserves escaped slash as double-escaped percent",
+			path: "/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A123456789012%3Ainference-profile%2Fglobal.anthropic.claude-sonnet-4-20250514-v1%3A0/converse",
+			want: "/model/arn%253Aaws%253Abedrock%253Aus-east-1%253A123456789012%253Ainference-profile%252Fglobal.anthropic.claude-sonnet-4-20250514-v1%253A0/converse",
+		},
+		{
+			name: "query string is stripped before canonical encoding",
+			path: "/model/global.amazon.nova-2-lite-v1%3A0/converse-stream?trace=1&foo=bar",
+			want: "/model/global.amazon.nova-2-lite-v1%253A0/converse-stream",
+		},
+		{
+			name: "invalid percent sequence falls back to escaped percent",
+			path: "/model/abc%ZZxyz/converse",
+			want: "/model/abc%25ZZxyz/converse",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			assert.Equal(t, tt.want, encodeSigV4Path(tt.path))
+		})
+	}
+}
+
+func TestOverwriteRequestPathHeaderPreservesSingleEncodedRequestPath(t *testing.T) {
+	p := &bedrockProvider{}
+	plainModel := "arn:aws:bedrock:us-east-1:123456789012:inference-profile/global.amazon.nova-2-lite-v1:0"
+	preEncodedModel := url.QueryEscape(plainModel)
+
+	t.Run("plain model is encoded once", func(t *testing.T) {
+		headers := http.Header{}
+		p.overwriteRequestPathHeader(headers, bedrockChatCompletionPath, plainModel)
+		assert.Equal(t, "/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A123456789012%3Ainference-profile%2Fglobal.amazon.nova-2-lite-v1%3A0/converse", headers.Get(":path"))
+	})
+
+	t.Run("pre-encoded model is not double encoded", func(t *testing.T) {
+		headers := http.Header{}
+		p.overwriteRequestPathHeader(headers, bedrockChatCompletionPath, preEncodedModel)
+		assert.Equal(t, "/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A123456789012%3Ainference-profile%2Fglobal.amazon.nova-2-lite-v1%3A0/converse", headers.Get(":path"))
+	})
+}
+
+func TestGenerateSignatureIgnoresQueryStringInCanonicalURI(t *testing.T) {
+	p := &bedrockProvider{
+		config: ProviderConfig{
+			awsRegion:    "ap-northeast-3",
+			awsSecretKey: "test-secret",
+		},
+	}
+	body := []byte(`{"messages":[{"role":"user","content":[{"text":"hello"}]}]}`)
+	pathWithoutQuery := "/model/global.amazon.nova-2-lite-v1%3A0/converse-stream"
+	pathWithQuery := pathWithoutQuery + "?trace=1&foo=bar"
+
+	sigWithoutQuery := p.generateSignature(pathWithoutQuery, "20260312T142942Z", "20260312", body)
+	sigWithQuery := p.generateSignature(pathWithQuery, "20260312T142942Z", "20260312", body)
+	assert.Equal(t, sigWithoutQuery, sigWithQuery)
+}
+
+func TestGenerateSignatureDiffersForRawAndPreEncodedModelPath(t *testing.T) {
+	p := &bedrockProvider{
+		config: ProviderConfig{
+			awsRegion:    "ap-northeast-3",
+			awsSecretKey: "test-secret",
+		},
+	}
+	body := []byte(`{"messages":[{"role":"user","content":[{"text":"hello"}]}]}`)
+	rawPath := "/model/global.amazon.nova-2-lite-v1:0/converse-stream"
+	preEncodedPath := "/model/global.amazon.nova-2-lite-v1%3A0/converse-stream"
+
+	rawSignature := p.generateSignature(rawPath, "20260312T142942Z", "20260312", body)
+	preEncodedSignature := p.generateSignature(preEncodedPath, "20260312T142942Z", "20260312", body)
+	assert.NotEqual(t, rawSignature, preEncodedSignature)
+}
+
+func TestNormalizePromptCacheRetention(t *testing.T) {
+	tests := []struct {
+		name      string
+		retention string
+		want      string
+	}{
+		{
+			name:      "inmemory alias maps to in_memory",
+			retention: "inmemory",
+			want:      "in_memory",
+		},
+		{
+			name:      "dash style maps to in_memory",
+			retention: "in-memory",
+			want:      "in_memory",
+		},
+		{
+			name:      "space style with trim maps to in_memory",
+			retention: " in memory ",
+			want:      "in_memory",
+		},
+		{
+			name:      "already normalized remains unchanged",
+			retention: "in_memory",
+			want:      "in_memory",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			assert.Equal(t, tt.want, normalizePromptCacheRetention(tt.retention))
+		})
+	}
+}
+
+func TestAppendCachePointToBedrockMessageInvalidIndexNoop(t *testing.T) {
+	request := &bedrockTextGenRequest{
+		Messages: []bedrockMessage{
+			{
+				Role: roleUser,
+				Content: []bedrockMessageContent{
+					{Text: "hello"},
+				},
+			},
+		},
+	}
+
+	appendCachePointToBedrockMessage(request, -1, bedrockCacheTTL5m)
+	appendCachePointToBedrockMessage(request, len(request.Messages), bedrockCacheTTL5m)
+
+	assert.Len(t, request.Messages[0].Content, 1)
+
+	appendCachePointToBedrockMessage(request, 0, bedrockCacheTTL5m)
+	assert.Len(t, request.Messages[0].Content, 2)
+	assert.NotNil(t, request.Messages[0].Content[1].CachePoint)
+}
+
+func TestIsPromptCacheSupportedModel(t *testing.T) {
+	tests := []struct {
+		name  string
+		model string
+		want  bool
+	}{
+		{
+			name:  "anthropic claude model is supported",
+			model: "anthropic.claude-3-5-haiku-20241022-v1:0",
+			want:  true,
+		},
+		{
+			name:  "amazon nova inference profile is supported",
+			model: "arn:aws:bedrock:us-east-1:123456789012:inference-profile/global.amazon.nova-2-lite-v1:0",
+			want:  true,
+		},
+		{
+			name:  "other model is not supported",
+			model: "meta.llama3-70b-instruct-v1:0",
+			want:  false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			assert.Equal(t, tt.want, isPromptCacheSupportedModel(tt.model))
+		})
+	}
+}
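The path tests above expect an already-encoded segment to have its `%` escaped again (`%3A` becomes `%253A`), so a raw and a pre-encoded model id canonicalize, and therefore sign, differently. The escaping step those cases exercise reduces to:

```go
package main

import (
	"fmt"
	"strings"
)

// escapePercent shows the double-encoding behavior the tests assert: a
// literal '%' already present in a path segment is re-escaped as "%25"
// before the segment enters the SigV4 canonical URI.
func escapePercent(segment string) string {
	return strings.ReplaceAll(segment, "%", "%25")
}

func main() {
	fmt.Println(escapePercent("global.amazon.nova-2-lite-v1%3A0")) // global.amazon.nova-2-lite-v1%253A0
}
```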
@@ -73,8 +73,8 @@ type claudeChatMessageContent struct {
 	Name      string                 `json:"name,omitempty"`  // For tool_use
 	Input     map[string]interface{} `json:"input,omitempty"` // For tool_use
 	// Tool result fields
-	ToolUseId string                     `json:"tool_use_id,omitempty"` // For tool_result
-	Content   claudeChatMessageContentWr `json:"content,omitempty"`     // For tool_result - can be string or array
+	ToolUseId string                      `json:"tool_use_id,omitempty"` // For tool_result
+	Content   *claudeChatMessageContentWr `json:"content,omitempty"`     // For tool_result - can be string or array
 }
 
 // UnmarshalJSON implements custom JSON unmarshaling for claudeChatMessageContentWr
@@ -237,13 +237,13 @@ type claudeTextGenResponse struct {
 }
 
 type claudeTextGenContent struct {
-	Type      string                 `json:"type,omitempty"`
-	Text      string                 `json:"text,omitempty"`
-	Id        string                 `json:"id,omitempty"`        // For tool_use
-	Name      string                 `json:"name,omitempty"`      // For tool_use
-	Input     map[string]interface{} `json:"input,omitempty"`     // For tool_use
-	Signature string                 `json:"signature,omitempty"` // For thinking
-	Thinking  string                 `json:"thinking,omitempty"`  // For thinking
+	Type      string                  `json:"type,omitempty"`
+	Text      *string                 `json:"text,omitempty"` // Use pointer: empty string outputs "text":"", nil omits field
+	Id        string                  `json:"id,omitempty"`   // For tool_use
+	Name      string                  `json:"name,omitempty"` // For tool_use
+	Input     *map[string]interface{} `json:"input,omitempty"`     // Use pointer: empty map outputs "input":{}, nil omits field
+	Signature *string                 `json:"signature,omitempty"` // For thinking - use pointer for empty string output
+	Thinking  *string                 `json:"thinking,omitempty"`  // For thinking - use pointer for empty string output
 }
 
 type claudeTextGenUsage struct {
@@ -269,11 +269,12 @@ type claudeTextGenStreamResponse struct {
 }
 
 type claudeTextGenDelta struct {
-	Type         string  `json:"type"`
-	Text         string  `json:"text,omitempty"`
-	PartialJson  string  `json:"partial_json,omitempty"`
-	StopReason   *string `json:"stop_reason,omitempty"`
-	StopSequence *string `json:"stop_sequence,omitempty"`
+	Type         string          `json:"type,omitempty"`
+	Text         string          `json:"text,omitempty"`
+	Thinking     string          `json:"thinking,omitempty"`
+	PartialJson  string          `json:"partial_json,omitempty"`
+	StopReason   *string         `json:"stop_reason,omitempty"`
+	StopSequence json.RawMessage `json:"stop_sequence,omitempty"` // Use RawMessage to output explicit null
 }
 
 func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error {
@@ -441,6 +442,34 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
        claudeRequest.MaxTokens = claudeDefaultMaxTokens
    }

    // Convert OpenAI reasoning parameters to Claude thinking configuration
    if origRequest.ReasoningEffort != "" || origRequest.ReasoningMaxTokens > 0 {
        var budgetTokens int
        if origRequest.ReasoningMaxTokens > 0 {
            budgetTokens = origRequest.ReasoningMaxTokens
        } else {
            // Convert reasoning_effort to budget_tokens
            switch origRequest.ReasoningEffort {
            case "low":
                budgetTokens = 1024 // Minimum required by Claude
            case "medium":
                budgetTokens = 8192
            case "high":
                budgetTokens = 16384
            default:
                budgetTokens = 8192 // Default to medium
            }
        }
        // Ensure minimum budget_tokens requirement
        if budgetTokens < 1024 {
            budgetTokens = 1024
        }
        claudeRequest.Thinking = &claudeThinkingConfig{
            Type: "enabled",
            BudgetTokens: budgetTokens,
        }
    }

    // Track if system message exists in original request
    hasSystemMessage := false
    for _, message := range origRequest.Messages {
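The reasoning-parameter conversion above can be sketched as a standalone function. This is a hypothetical reconstruction for illustration only (the function name is invented); it mirrors the rules in the hunk: an explicit `reasoning_max_tokens` wins, otherwise `reasoning_effort` is mapped to a budget, and the result is clamped to Claude's 1024-token minimum.

```go
package main

import "fmt"

// effortToBudgetTokens is a hypothetical standalone sketch of the conversion
// above, not the provider's actual helper.
func effortToBudgetTokens(effort string, reasoningMaxTokens int) int {
	var budget int
	if reasoningMaxTokens > 0 {
		budget = reasoningMaxTokens
	} else {
		switch effort {
		case "low":
			budget = 1024
		case "medium":
			budget = 8192
		case "high":
			budget = 16384
		default:
			budget = 8192 // default to medium
		}
	}
	if budget < 1024 {
		budget = 1024 // Claude requires at least 1024 budget tokens
	}
	return budget
}

func main() {
	fmt.Println(effortToBudgetTokens("low", 0))  // 1024
	fmt.Println(effortToBudgetTokens("", 512))   // clamped to 1024
	fmt.Println(effortToBudgetTokens("high", 0)) // 16384
}
```

Note that the clamp applies to caller-supplied `reasoning_max_tokens` too, so a request asking for fewer than 1024 thinking tokens is silently raised to the minimum.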
@@ -469,9 +498,102 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
            continue
        }

        // Handle OpenAI "tool" role messages - convert to Claude "user" role with tool_result content
        if message.Role == roleTool {
            toolResultContent := claudeChatMessageContent{
                Type: "tool_result",
                ToolUseId: message.ToolCallId,
            }
            // Tool result content can be string or array
            if message.IsStringContent() {
                toolResultContent.Content = &claudeChatMessageContentWr{
                    StringValue: message.StringContent(),
                    IsString: true,
                }
            } else {
                // For array content, extract text parts
                var textParts []string
                for _, part := range message.ParseContent() {
                    if part.Type == contentTypeText {
                        textParts = append(textParts, part.Text)
                    }
                }
                toolResultContent.Content = &claudeChatMessageContentWr{
                    StringValue: strings.Join(textParts, "\n"),
                    IsString: true,
                }
            }

            // Check if the last message is a user message with tool_result, merge if so
            if len(claudeRequest.Messages) > 0 {
                lastMsg := &claudeRequest.Messages[len(claudeRequest.Messages)-1]
                if lastMsg.Role == roleUser && !lastMsg.Content.IsString {
                    // Check if last message contains tool_result
                    hasToolResult := false
                    for _, content := range lastMsg.Content.ArrayValue {
                        if content.Type == "tool_result" {
                            hasToolResult = true
                            break
                        }
                    }
                    if hasToolResult {
                        // Merge with existing tool_result message
                        lastMsg.Content.ArrayValue = append(lastMsg.Content.ArrayValue, toolResultContent)
                        continue
                    }
                }
            }

            // Create new user message with tool_result
            claudeMessage := claudeChatMessage{
                Role: roleUser,
                Content: NewArrayContent([]claudeChatMessageContent{toolResultContent}),
            }
            claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
            continue
        }

        claudeMessage := claudeChatMessage{
            Role: message.Role,
        }

        // Handle assistant messages with tool_calls - convert to Claude tool_use content blocks
        if message.Role == roleAssistant && len(message.ToolCalls) > 0 {
            chatMessageContents := make([]claudeChatMessageContent, 0)

            // Add text content if present
            if message.IsStringContent() && message.StringContent() != "" {
                chatMessageContents = append(chatMessageContents, claudeChatMessageContent{
                    Type: contentTypeText,
                    Text: message.StringContent(),
                })
            }

            // Convert tool_calls to tool_use content blocks
            for _, tc := range message.ToolCalls {
                var inputMap map[string]interface{}
                if tc.Function.Arguments != "" {
                    if err := json.Unmarshal([]byte(tc.Function.Arguments), &inputMap); err != nil {
                        log.Errorf("failed to parse tool call arguments: %v", err)
                        inputMap = make(map[string]interface{})
                    }
                } else {
                    inputMap = make(map[string]interface{})
                }

                chatMessageContents = append(chatMessageContents, claudeChatMessageContent{
                    Type: "tool_use",
                    Id: tc.Id,
                    Name: tc.Function.Name,
                    Input: inputMap,
                })
            }

            claudeMessage.Content = NewArrayContent(chatMessageContents)
            claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
            continue
        }

        if message.IsStringContent() {
            claudeMessage.Content = NewStringContent(message.StringContent())
        } else {
@@ -562,9 +684,41 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
}

func (c *claudeProvider) responseClaude2OpenAI(ctx wrapper.HttpContext, origResponse *claudeTextGenResponse) *chatCompletionResponse {
    // Extract text content, thinking content, and tool calls from Claude response
    var textContent string
    var reasoningContent string
    var toolCalls []toolCall
    for _, content := range origResponse.Content {
        switch content.Type {
        case contentTypeText:
            if content.Text != nil {
                textContent = *content.Text
            }
        case "thinking":
            if content.Thinking != nil {
                reasoningContent = *content.Thinking
            }
        case "tool_use":
            var args []byte
            if content.Input != nil {
                args, _ = json.Marshal(*content.Input)
            } else {
                args = []byte("{}")
            }
            toolCalls = append(toolCalls, toolCall{
                Id: content.Id,
                Type: "function",
                Function: functionCall{
                    Name: content.Name,
                    Arguments: string(args),
                },
            })
        }
    }

    choice := chatCompletionChoice{
        Index: 0,
        Message: &chatMessage{Role: roleAssistant, Content: origResponse.Content[0].Text},
        Message: &chatMessage{Role: roleAssistant, Content: textContent, ReasoningContent: reasoningContent, ToolCalls: toolCalls},
        FinishReason: util.Ptr(stopReasonClaude2OpenAI(origResponse.StopReason)),
    }
@@ -600,6 +754,8 @@ func stopReasonClaude2OpenAI(reason *string) string {
        return finishReasonStop
    case "max_tokens":
        return finishReasonLength
    case "tool_use":
        return finishReasonToolCall
    default:
        return *reason
    }
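The hunk above adds only the `tool_use` branch; a self-contained sketch of the full mapping follows. The `"end_turn"`/`"stop_sequence"` case labels and the literal values `"stop"`, `"length"`, and `"tool_calls"` are assumptions standing in for the provider's named constants (`finishReasonStop`, `finishReasonLength`, `finishReasonToolCall`), which are not visible in this diff.

```go
package main

import "fmt"

// stopReasonToFinishReason is a hypothetical standalone sketch of the mapping
// above. Only the max_tokens and tool_use branches appear in the hunk; the
// remaining cases and string values are assumptions.
func stopReasonToFinishReason(reason string) string {
	switch reason {
	case "end_turn", "stop_sequence":
		return "stop"
	case "max_tokens":
		return "length"
	case "tool_use":
		return "tool_calls"
	default:
		return reason // unknown reasons pass through unchanged
	}
}

func main() {
	fmt.Println(stopReasonToFinishReason("tool_use"))   // tool_calls
	fmt.Println(stopReasonToFinishReason("max_tokens")) // length
}
```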
@@ -626,11 +782,64 @@ func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, or
        }
        return c.createChatCompletionResponse(ctx, origResponse, choice)

    case "content_block_start":
        // Handle tool_use content block start
        if origResponse.ContentBlock != nil && origResponse.ContentBlock.Type == "tool_use" {
            var index int
            if origResponse.Index != nil {
                index = *origResponse.Index
            }
            choice := chatCompletionChoice{
                Index: index,
                Delta: &chatMessage{
                    ToolCalls: []toolCall{
                        {
                            Index: index,
                            Id: origResponse.ContentBlock.Id,
                            Type: "function",
                            Function: functionCall{
                                Name: origResponse.ContentBlock.Name,
                                Arguments: "",
                            },
                        },
                    },
                },
            }
            return c.createChatCompletionResponse(ctx, origResponse, choice)
        }
        return nil

    case "content_block_delta":
        var index int
        if origResponse.Index != nil {
            index = *origResponse.Index
        }
        // Handle tool_use input_json_delta
        if origResponse.Delta != nil && origResponse.Delta.Type == "input_json_delta" {
            choice := chatCompletionChoice{
                Index: index,
                Delta: &chatMessage{
                    ToolCalls: []toolCall{
                        {
                            Index: index,
                            Function: functionCall{
                                Arguments: origResponse.Delta.PartialJson,
                            },
                        },
                    },
                },
            }
            return c.createChatCompletionResponse(ctx, origResponse, choice)
        }
        // Handle thinking_delta
        if origResponse.Delta != nil && origResponse.Delta.Type == "thinking_delta" {
            choice := chatCompletionChoice{
                Index: index,
                Delta: &chatMessage{Reasoning: origResponse.Delta.Thinking},
            }
            return c.createChatCompletionResponse(ctx, origResponse, choice)
        }
        // Handle text_delta
        choice := chatCompletionChoice{
            Index: index,
            Delta: &chatMessage{Content: origResponse.Delta.Text},
@@ -667,7 +876,7 @@ func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, or
                TotalTokens: c.usage.TotalTokens,
            },
        }
    case "content_block_stop", "ping", "content_block_start":
    case "content_block_stop", "ping":
        log.Debugf("skip processing response type: %s", origResponse.Type)
        return nil
    default:
@@ -315,3 +315,107 @@ func TestClaudeProvider_GetApiName(t *testing.T) {
        assert.Equal(t, ApiName(""), provider.GetApiName("/unknown"))
    })
}

func TestClaudeProvider_BuildClaudeTextGenRequest_ToolRoleConversion(t *testing.T) {
    provider := &claudeProvider{
        config: ProviderConfig{
            claudeCodeMode: false,
        },
    }

    t.Run("converts_single_tool_role_to_user_with_tool_result", func(t *testing.T) {
        request := &chatCompletionRequest{
            Model: "claude-sonnet-4-5-20250929",
            MaxTokens: 1024,
            Messages: []chatMessage{
                {Role: roleUser, Content: "What's the weather?"},
                {Role: roleAssistant, Content: nil, ToolCalls: []toolCall{
                    {Id: "call_123", Type: "function", Function: functionCall{Name: "get_weather", Arguments: `{"city": "Beijing"}`}},
                }},
                {Role: roleTool, ToolCallId: "call_123", Content: "Sunny, 25°C"},
            },
        }

        claudeReq := provider.buildClaudeTextGenRequest(request)

        // Should have 3 messages: user, assistant with tool_use, user with tool_result
        require.Len(t, claudeReq.Messages, 3)

        // First message should be user
        assert.Equal(t, roleUser, claudeReq.Messages[0].Role)

        // Second message should be assistant with tool_use
        assert.Equal(t, roleAssistant, claudeReq.Messages[1].Role)
        require.False(t, claudeReq.Messages[1].Content.IsString)
        require.Len(t, claudeReq.Messages[1].Content.ArrayValue, 1)
        assert.Equal(t, "tool_use", claudeReq.Messages[1].Content.ArrayValue[0].Type)
        assert.Equal(t, "call_123", claudeReq.Messages[1].Content.ArrayValue[0].Id)
        assert.Equal(t, "get_weather", claudeReq.Messages[1].Content.ArrayValue[0].Name)

        // Third message should be user with tool_result
        assert.Equal(t, roleUser, claudeReq.Messages[2].Role)
        require.False(t, claudeReq.Messages[2].Content.IsString)
        require.Len(t, claudeReq.Messages[2].Content.ArrayValue, 1)
        assert.Equal(t, "tool_result", claudeReq.Messages[2].Content.ArrayValue[0].Type)
        assert.Equal(t, "call_123", claudeReq.Messages[2].Content.ArrayValue[0].ToolUseId)
    })

    t.Run("merges_multiple_tool_results_into_single_user_message", func(t *testing.T) {
        request := &chatCompletionRequest{
            Model: "claude-sonnet-4-5-20250929",
            MaxTokens: 1024,
            Messages: []chatMessage{
                {Role: roleUser, Content: "What's the weather and time?"},
                {Role: roleAssistant, Content: nil, ToolCalls: []toolCall{
                    {Id: "call_1", Type: "function", Function: functionCall{Name: "get_weather", Arguments: `{"city": "Beijing"}`}},
                    {Id: "call_2", Type: "function", Function: functionCall{Name: "get_time", Arguments: `{"timezone": "Asia/Shanghai"}`}},
                }},
                {Role: roleTool, ToolCallId: "call_1", Content: "Sunny, 25°C"},
                {Role: roleTool, ToolCallId: "call_2", Content: "3:00 PM"},
            },
        }

        claudeReq := provider.buildClaudeTextGenRequest(request)

        // Should have 3 messages: user, assistant with 2 tool_use, user with 2 tool_results
        require.Len(t, claudeReq.Messages, 3)

        // Assistant message should have 2 tool_use blocks
        require.Len(t, claudeReq.Messages[1].Content.ArrayValue, 2)
        assert.Equal(t, "tool_use", claudeReq.Messages[1].Content.ArrayValue[0].Type)
        assert.Equal(t, "tool_use", claudeReq.Messages[1].Content.ArrayValue[1].Type)

        // User message should have 2 tool_result blocks merged
        assert.Equal(t, roleUser, claudeReq.Messages[2].Role)
        require.Len(t, claudeReq.Messages[2].Content.ArrayValue, 2)
        assert.Equal(t, "tool_result", claudeReq.Messages[2].Content.ArrayValue[0].Type)
        assert.Equal(t, "call_1", claudeReq.Messages[2].Content.ArrayValue[0].ToolUseId)
        assert.Equal(t, "tool_result", claudeReq.Messages[2].Content.ArrayValue[1].Type)
        assert.Equal(t, "call_2", claudeReq.Messages[2].Content.ArrayValue[1].ToolUseId)
    })

    t.Run("handles_assistant_tool_calls_with_text_content", func(t *testing.T) {
        request := &chatCompletionRequest{
            Model: "claude-sonnet-4-5-20250929",
            MaxTokens: 1024,
            Messages: []chatMessage{
                {Role: roleUser, Content: "What's the weather?"},
                {Role: roleAssistant, Content: "Let me check the weather for you.", ToolCalls: []toolCall{
                    {Id: "call_123", Type: "function", Function: functionCall{Name: "get_weather", Arguments: `{"city": "Beijing"}`}},
                }},
            },
        }

        claudeReq := provider.buildClaudeTextGenRequest(request)

        require.Len(t, claudeReq.Messages, 2)

        // Assistant message should have both text and tool_use
        assert.Equal(t, roleAssistant, claudeReq.Messages[1].Role)
        require.False(t, claudeReq.Messages[1].Content.IsString)
        require.Len(t, claudeReq.Messages[1].Content.ArrayValue, 2)
        assert.Equal(t, contentTypeText, claudeReq.Messages[1].Content.ArrayValue[0].Type)
        assert.Equal(t, "Let me check the weather for you.", claudeReq.Messages[1].Content.ArrayValue[0].Text)
        assert.Equal(t, "tool_use", claudeReq.Messages[1].Content.ArrayValue[1].Type)
    })
}
@@ -119,6 +119,15 @@ func (c *ClaudeToOpenAIConverter) ConvertClaudeRequestToOpenAI(body []byte) ([]b
            }
            openaiRequest.Messages = append(openaiRequest.Messages, toolMsg)
        }
        // Also add text content if present alongside tool results
        // This handles cases like: [tool_result, tool_result, text]
        if len(conversionResult.textParts) > 0 {
            textMsg := chatMessage{
                Role: claudeMsg.Role,
                Content: strings.Join(conversionResult.textParts, "\n\n"),
            }
            openaiRequest.Messages = append(openaiRequest.Messages, textMsg)
        }
    }

    // Handle regular content if no tool calls or tool results
@@ -136,7 +145,8 @@ func (c *ClaudeToOpenAIConverter) ConvertClaudeRequestToOpenAI(body []byte) ([]b
    if claudeRequest.System != nil {
        systemMsg := chatMessage{Role: roleSystem}
        if !claudeRequest.System.IsArray {
            systemMsg.Content = claudeRequest.System.StringValue
            // Strip dynamic cch field from billing header to enable caching
            systemMsg.Content = stripCchFromBillingHeader(claudeRequest.System.StringValue)
        } else {
            conversionResult := c.convertContentArray(claudeRequest.System.ArrayValue)
            systemMsg.Content = conversionResult.openaiContents
@@ -177,13 +187,16 @@ func (c *ClaudeToOpenAIConverter) ConvertClaudeRequestToOpenAI(body []byte) ([]b
    }

    // Convert thinking configuration if present
    // Only set standard OpenAI fields (reasoning_effort). Non-standard fields like
    // "thinking" and "reasoning_max_tokens" are NOT set here because they are not
    // recognized by OpenAI/Azure and will cause 400 errors. Providers that need
    // these non-standard fields (e.g., ZhipuAI) should handle them in their own
    // OnRequestBody implementation.
    if claudeRequest.Thinking != nil {
        log.Debugf("[Claude->OpenAI] Found thinking config: type=%s, budget_tokens=%d",
            claudeRequest.Thinking.Type, claudeRequest.Thinking.BudgetTokens)

        if claudeRequest.Thinking.Type == "enabled" {
            openaiRequest.ReasoningMaxTokens = claudeRequest.Thinking.BudgetTokens

            // Set ReasoningEffort based on budget_tokens
            // low: <4096, medium: >=4096 and <16384, high: >=16384
            if claudeRequest.Thinking.BudgetTokens < 4096 {
@@ -194,11 +207,9 @@ func (c *ClaudeToOpenAIConverter) ConvertClaudeRequestToOpenAI(body []byte) ([]b
                openaiRequest.ReasoningEffort = "high"
            }

            log.Debugf("[Claude->OpenAI] Converted thinking config: budget_tokens=%d, reasoning_effort=%s, reasoning_max_tokens=%d",
                claudeRequest.Thinking.BudgetTokens, openaiRequest.ReasoningEffort, openaiRequest.ReasoningMaxTokens)
            log.Debugf("[Claude->OpenAI] Converted thinking config: budget_tokens=%d, reasoning_effort=%s",
                claudeRequest.Thinking.BudgetTokens, openaiRequest.ReasoningEffort)
        }
    } else {
        log.Debugf("[Claude->OpenAI] No thinking config found")
    }

    result, err := json.Marshal(openaiRequest)
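The reverse mapping in this converter, from a Claude `budget_tokens` value back to an OpenAI `reasoning_effort`, follows the thresholds documented in the comment above (low: <4096, medium: >=4096 and <16384, high: >=16384). A minimal standalone sketch, with a hypothetical function name:

```go
package main

import "fmt"

// budgetToEffort sketches the budget_tokens -> reasoning_effort thresholds
// documented above. Illustration only, not the converter's actual code.
func budgetToEffort(budgetTokens int) string {
	switch {
	case budgetTokens < 4096:
		return "low"
	case budgetTokens < 16384:
		return "medium"
	default:
		return "high"
	}
}

func main() {
	// The same budgets exercised by the converter's thinking-config tests:
	// 2048, 8192, and 20480.
	fmt.Println(budgetToEffort(2048), budgetToEffort(8192), budgetToEffort(20480))
	// low medium high
}
```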
@@ -253,19 +264,21 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIResponseToClaude(ctx wrapper.Http
    }

    if reasoningText != "" {
        emptySignature := ""
        contents = append(contents, claudeTextGenContent{
            Type: "thinking",
            Signature: "", // OpenAI doesn't provide signature, use empty string
            Thinking: reasoningText,
            Signature: &emptySignature, // Use pointer for empty string
            Thinking: &reasoningText,
        })
        log.Debugf("[OpenAI->Claude] Added thinking content: %s", reasoningText)
    }

    // Add text content if present
    if choice.Message.StringContent() != "" {
        textContent := choice.Message.StringContent()
        contents = append(contents, claudeTextGenContent{
            Type: "text",
            Text: choice.Message.StringContent(),
            Text: &textContent,
        })
    }

@@ -288,7 +301,7 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIResponseToClaude(ctx wrapper.Http
            Type: "tool_use",
            Id: toolCall.Id,
            Name: toolCall.Function.Name,
            Input: input,
            Input: &input,
        })
    }
}
@@ -338,7 +351,7 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrappe
            Index: &c.thinkingBlockIndex,
        }
        stopData, _ := json.Marshal(stopEvent)
        result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
        result.WriteString(fmt.Sprintf("event: %s\ndata: %s\n\n", stopEvent.Type, stopData))
    }
    if c.textBlockStarted && !c.textBlockStopped {
        c.textBlockStopped = true
@@ -348,7 +361,7 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrappe
            Index: &c.textBlockIndex,
        }
        stopData, _ := json.Marshal(stopEvent)
        result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
        result.WriteString(fmt.Sprintf("event: %s\ndata: %s\n\n", stopEvent.Type, stopData))
    }
    // Send final content_block_stop events for any remaining unclosed tool calls
    for index, toolCall := range c.toolCallStates {
@@ -360,7 +373,7 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrappe
            Index: &toolCall.claudeContentIndex,
        }
        stopData, _ := json.Marshal(stopEvent)
        result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
        result.WriteString(fmt.Sprintf("event: %s\ndata: %s\n\n", stopEvent.Type, stopData))
    }
}

@@ -370,12 +383,12 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrappe
    messageDelta := &claudeTextGenStreamResponse{
        Type: "message_delta",
        Delta: &claudeTextGenDelta{
            Type: "message_delta",
            StopReason: c.pendingStopReason,
            StopSequence: json.RawMessage("null"),
        },
    }
    stopData, _ := json.Marshal(messageDelta)
    result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
    result.WriteString(fmt.Sprintf("event: %s\ndata: %s\n\n", messageDelta.Type, stopData))
    c.pendingStopReason = nil
}

@@ -386,7 +399,7 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrappe
        Type: "message_stop",
    }
    stopData, _ := json.Marshal(messageStopEvent)
    result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
    result.WriteString(fmt.Sprintf("event: %s\ndata: %s\n\n", messageStopEvent.Type, stopData))
}

// Reset all state for next request
@@ -515,13 +528,14 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpCont
        c.nextContentIndex++
        c.thinkingBlockStarted = true
        log.Debugf("[OpenAI->Claude] Generated content_block_start event for thinking at index %d", c.thinkingBlockIndex)
        emptyStr := ""
        responses = append(responses, &claudeTextGenStreamResponse{
            Type: "content_block_start",
            Index: &c.thinkingBlockIndex,
            ContentBlock: &claudeTextGenContent{
                Type: "thinking",
                Signature: "", // OpenAI doesn't provide signature
                Thinking: "",
                Signature: &emptyStr, // Use pointer for empty string output
                Thinking: &emptyStr, // Use pointer for empty string output
            },
        })
    }
@@ -532,8 +546,8 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpCont
            Type: "content_block_delta",
            Index: &c.thinkingBlockIndex,
            Delta: &claudeTextGenDelta{
                Type: "thinking_delta", // Use thinking_delta for reasoning content
                Text: reasoningText,
                Type: "thinking_delta",
                Thinking: reasoningText, // Use Thinking field, not Text
            },
        })
    }
@@ -564,12 +578,13 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpCont
        c.nextContentIndex++
        c.textBlockStarted = true
        log.Debugf("[OpenAI->Claude] Generated content_block_start event for text at index %d", c.textBlockIndex)
        emptyText := ""
        responses = append(responses, &claudeTextGenStreamResponse{
            Type: "content_block_start",
            Index: &c.textBlockIndex,
            ContentBlock: &claudeTextGenContent{
                Type: "text",
                Text: "",
                Text: &emptyText,
            },
        })
    }
@@ -588,6 +603,30 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpCont

    // Handle tool calls in streaming response
    if choice.Delta != nil && len(choice.Delta.ToolCalls) > 0 {
        // Ensure message_start is sent before any content blocks
        if !c.messageStartSent {
            c.messageId = openaiResponse.Id
            c.messageStartSent = true
            message := &claudeTextGenResponse{
                Id: openaiResponse.Id,
                Type: "message",
                Role: "assistant",
                Model: openaiResponse.Model,
                Content: []claudeTextGenContent{},
            }
            if openaiResponse.Usage != nil {
                message.Usage = claudeTextGenUsage{
                    InputTokens: openaiResponse.Usage.PromptTokens,
                    OutputTokens: 0,
                }
            }
            responses = append(responses, &claudeTextGenStreamResponse{
                Type: "message_start",
                Message: message,
            })
            log.Debugf("[OpenAI->Claude] Generated message_start event before tool calls for id: %s", openaiResponse.Id)
        }

        // Initialize toolCallStates if needed
        if c.toolCallStates == nil {
            c.toolCallStates = make(map[int]*toolCallInfo)
@@ -722,7 +761,9 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpCont
    }

    // Handle usage information
    if openaiResponse.Usage != nil && choice.FinishReason == nil {
    // Note: Some providers may send usage in the same chunk as finish_reason,
    // so we check for usage regardless of whether finish_reason is present
    if openaiResponse.Usage != nil {
        log.Debugf("[OpenAI->Claude] Processing usage info - input: %d, output: %d",
            openaiResponse.Usage.PromptTokens, openaiResponse.Usage.CompletionTokens)

@@ -730,7 +771,7 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpCont
        messageDelta := &claudeTextGenStreamResponse{
            Type: "message_delta",
            Delta: &claudeTextGenDelta{
                Type: "message_delta",
                StopSequence: json.RawMessage("null"), // Explicit null per Claude spec
            },
            Usage: &claudeTextGenUsage{
                InputTokens: openaiResponse.Usage.PromptTokens,
@@ -789,10 +830,12 @@ func (c *ClaudeToOpenAIConverter) convertContentArray(claudeContents []claudeCha
    switch claudeContent.Type {
    case "text":
        if claudeContent.Text != "" {
            result.textParts = append(result.textParts, claudeContent.Text)
            // Strip dynamic cch field from billing header to enable caching
            processedText := stripCchFromBillingHeader(claudeContent.Text)
            result.textParts = append(result.textParts, processedText)
            result.openaiContents = append(result.openaiContents, chatMessageContent{
                Type: contentTypeText,
                Text: claudeContent.Text,
                Text: processedText,
                CacheControl: claudeContent.CacheControl,
            })
        }
@@ -884,6 +927,7 @@ func (c *ClaudeToOpenAIConverter) startToolCall(toolState *toolCallInfo) []*clau
        toolState.claudeContentIndex, toolState.id, toolState.name)

    // Send content_block_start
    emptyInput := map[string]interface{}{}
    responses = append(responses, &claudeTextGenStreamResponse{
        Type: "content_block_start",
        Index: &toolState.claudeContentIndex,
@@ -891,7 +935,7 @@ func (c *ClaudeToOpenAIConverter) startToolCall(toolState *toolCallInfo) []*clau
            Type: "tool_use",
            Id: toolState.id,
            Name: toolState.name,
            Input: map[string]interface{}{}, // Empty input as per Claude spec
            Input: &emptyInput, // Empty input as per Claude spec
        },
    })

@@ -910,3 +954,42 @@ func (c *ClaudeToOpenAIConverter) startToolCall(toolState *toolCallInfo) []*clau

    return responses
}
// stripCchFromBillingHeader removes the dynamic cch field from x-anthropic-billing-header text
// to enable caching. The cch value changes on every request, which would break prompt caching.
// Example input: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode; cch=abc123;"
// Example output: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode;"
func stripCchFromBillingHeader(text string) string {
    const billingHeaderPrefix = "x-anthropic-billing-header:"

    // Check if this is a billing header
    if !strings.HasPrefix(text, billingHeaderPrefix) {
        return text
    }

    // Remove cch=xxx pattern (may appear with or without trailing semicolon)
    // Pattern: ; cch=<any-non-semicolon-chars> followed by ; or end of string
    result := text

    // Try to find and remove ; cch=... pattern
    // We need to handle both "; cch=xxx;" and "; cch=xxx" (at end)
    for {
        cchIdx := strings.Index(result, "; cch=")
        if cchIdx == -1 {
            break
        }

        // Find the end of cch value (next semicolon or end of string)
        start := cchIdx + 2 // skip "; "
        end := strings.Index(result[start:], ";")
        if end == -1 {
            // cch is at the end, remove from "; cch=" to end
            result = result[:cchIdx]
        } else {
            // cch is followed by more content, remove "; cch=xxx" part
            result = result[:cchIdx] + result[start+end:]
        }
    }

    return result
}
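Since the function is shown in full above, its documented input/output pair can be checked in isolation. This sketch copies the same logic under a hypothetical name (`stripCch`) so it runs standalone:

```go
package main

import (
	"fmt"
	"strings"
)

// stripCch reproduces the logic of stripCchFromBillingHeader above, so the
// documented example can be verified outside the converter package.
func stripCch(text string) string {
	const billingHeaderPrefix = "x-anthropic-billing-header:"
	if !strings.HasPrefix(text, billingHeaderPrefix) {
		return text // non-billing text passes through untouched
	}
	result := text
	for {
		cchIdx := strings.Index(result, "; cch=")
		if cchIdx == -1 {
			break
		}
		start := cchIdx + 2 // skip "; "
		end := strings.Index(result[start:], ";")
		if end == -1 {
			result = result[:cchIdx] // cch at end of string
		} else {
			result = result[:cchIdx] + result[start+end:] // splice cch out
		}
	}
	return result
}

func main() {
	in := "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode; cch=abc123;"
	fmt.Println(stripCch(in))
	// x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode;
}
```

Note the early return: only text that actually starts with the billing-header prefix is rewritten, so ordinary message content containing the substring `; cch=` is left alone.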
@@ -388,6 +388,7 @@ func TestClaudeToOpenAIConverter_ConvertClaudeRequestToOpenAI(t *testing.T) {
|
||||
|
||||
t.Run("convert_tool_result_with_actual_error_data", func(t *testing.T) {
|
||||
// Test using the actual JSON data from the error log to ensure our fix works
|
||||
// This tests the fix for issue #3344 - text content alongside tool_result should be preserved
|
||||
claudeRequest := `{
|
||||
"model": "anthropic/claude-sonnet-4",
|
||||
"messages": [{
|
||||
@@ -415,14 +416,20 @@ func TestClaudeToOpenAIConverter_ConvertClaudeRequestToOpenAI(t *testing.T) {
|
||||
err = json.Unmarshal(result, &openaiRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have one tool message (the text content is included in the same message array)
|
||||
require.Len(t, openaiRequest.Messages, 1)
|
||||
// Should have two messages: tool message + user message with text content
|
||||
// This is the fix for issue #3344 - text content alongside tool_result is preserved
|
||||
require.Len(t, openaiRequest.Messages, 2)
|
||||
|
||||
// Should be tool message
|
||||
// First should be tool message
|
||||
toolMsg := openaiRequest.Messages[0]
|
||||
assert.Equal(t, "tool", toolMsg.Role)
|
||||
assert.Contains(t, toolMsg.Content, "three.js")
|
||||
assert.Equal(t, "toolu_vrtx_01UbCfwoTgoDBqbYEwkVaxd5", toolMsg.ToolCallId)
|
||||
|
||||
// Second should be user message with text content
|
||||
userMsg := openaiRequest.Messages[1]
|
||||
assert.Equal(t, "user", userMsg.Role)
|
||||
assert.Equal(t, "继续", userMsg.Content)
|
||||
})
|
||||
|
||||
t.Run("convert_multiple_tool_calls", func(t *testing.T) {
|
||||
@@ -617,7 +624,7 @@ func TestClaudeToOpenAIConverter_ConvertOpenAIResponseToClaude(t *testing.T) {
|
||||
	// First content should be text
	textContent := claudeResponse.Content[0]
	assert.Equal(t, "text", textContent.Type)
-	assert.Equal(t, "I'll analyze the README file to understand this project's purpose.", textContent.Text)
+	assert.Equal(t, "I'll analyze the README file to understand this project's purpose.", *textContent.Text)

	// Second content should be tool_use
	toolContent := claudeResponse.Content[1]

@@ -627,7 +634,7 @@ func TestClaudeToOpenAIConverter_ConvertOpenAIResponseToClaude(t *testing.T) {

	// Verify tool arguments
	require.NotNil(t, toolContent.Input)
-	assert.Equal(t, "/Users/zhangty/git/higress/README.md", toolContent.Input["file_path"])
+	assert.Equal(t, "/Users/zhangty/git/higress/README.md", (*toolContent.Input)["file_path"])
	})
}

@@ -635,11 +642,9 @@ func TestClaudeToOpenAIConverter_ConvertThinkingConfig(t *testing.T) {
	converter := &ClaudeToOpenAIConverter{}

	tests := []struct {
-		name                 string
-		claudeRequest        string
-		expectedMaxTokens    int
-		expectedEffort       string
-		expectThinkingConfig bool
+		name           string
+		claudeRequest  string
+		expectedEffort string
	}{
		{
			name: "thinking_enabled_low",
@@ -649,9 +654,7 @@ func TestClaudeToOpenAIConverter_ConvertThinkingConfig(t *testing.T) {
				"messages": [{"role": "user", "content": "Hello"}],
				"thinking": {"type": "enabled", "budget_tokens": 2048}
			}`,
-			expectedMaxTokens:    2048,
-			expectedEffort:       "low",
-			expectThinkingConfig: true,
+			expectedEffort: "low",
		},
		{
			name: "thinking_enabled_medium",
@@ -661,9 +664,7 @@ func TestClaudeToOpenAIConverter_ConvertThinkingConfig(t *testing.T) {
				"messages": [{"role": "user", "content": "Hello"}],
				"thinking": {"type": "enabled", "budget_tokens": 8192}
			}`,
-			expectedMaxTokens:    8192,
-			expectedEffort:       "medium",
-			expectThinkingConfig: true,
+			expectedEffort: "medium",
		},
		{
			name: "thinking_enabled_high",
@@ -673,9 +674,7 @@ func TestClaudeToOpenAIConverter_ConvertThinkingConfig(t *testing.T) {
				"messages": [{"role": "user", "content": "Hello"}],
				"thinking": {"type": "enabled", "budget_tokens": 20480}
			}`,
-			expectedMaxTokens:    20480,
-			expectedEffort:       "high",
-			expectThinkingConfig: true,
+			expectedEffort: "high",
		},
		{
			name: "thinking_disabled",
@@ -685,9 +684,7 @@ func TestClaudeToOpenAIConverter_ConvertThinkingConfig(t *testing.T) {
				"messages": [{"role": "user", "content": "Hello"}],
				"thinking": {"type": "disabled"}
			}`,
-			expectedMaxTokens:    0,
-			expectedEffort:       "",
-			expectThinkingConfig: false,
+			expectedEffort: "",
		},
		{
			name: "no_thinking",
@@ -696,9 +693,7 @@ func TestClaudeToOpenAIConverter_ConvertThinkingConfig(t *testing.T) {
				"max_tokens": 1000,
				"messages": [{"role": "user", "content": "Hello"}]
			}`,
-			expectedMaxTokens:    0,
-			expectedEffort:       "",
-			expectThinkingConfig: false,
+			expectedEffort: "",
		},
	}
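The test table implies a mapping from Claude's `budget_tokens` to an OpenAI-style `reasoning_effort` tier (2048 → low, 8192 → medium, 20480 → high, disabled/absent → empty). A minimal sketch of such a mapping follows; the threshold values are assumptions inferred from the test cases above, not taken from the converter's actual code:

```go
package main

import "fmt"

// budgetToEffort maps a Claude thinking budget to an OpenAI-style effort tier.
// The thresholds are hypothetical; they merely reproduce the test table above.
func budgetToEffort(budgetTokens int) string {
	switch {
	case budgetTokens <= 0:
		return "" // thinking disabled or absent
	case budgetTokens <= 4096:
		return "low"
	case budgetTokens <= 16384:
		return "medium"
	default:
		return "high"
	}
}

func main() {
	for _, b := range []int{2048, 8192, 20480, 0} {
		fmt.Printf("%d -> %q\n", b, budgetToEffort(b))
	}
}
```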

@@ -712,13 +707,23 @@ func TestClaudeToOpenAIConverter_ConvertThinkingConfig(t *testing.T) {
			err = json.Unmarshal(result, &openaiRequest)
			assert.NoError(t, err)

-			if tt.expectThinkingConfig {
-				assert.Equal(t, tt.expectedMaxTokens, openaiRequest.ReasoningMaxTokens)
-				assert.Equal(t, tt.expectedEffort, openaiRequest.ReasoningEffort)
-			} else {
-				assert.Equal(t, 0, openaiRequest.ReasoningMaxTokens)
-				assert.Equal(t, "", openaiRequest.ReasoningEffort)
-			}
+			assert.Equal(t, tt.expectedEffort, openaiRequest.ReasoningEffort)
+
+			// Verify non-standard fields are NEVER set in the converted request.
+			// These fields are not recognized by OpenAI/Azure and would cause 400 errors.
+			assert.Equal(t, 0, openaiRequest.ReasoningMaxTokens,
+				"reasoning_max_tokens must not be set - it is not a standard OpenAI parameter")
+			assert.Nil(t, openaiRequest.Thinking,
+				"thinking must not be set - it is not a standard OpenAI parameter")
+
+			// Also verify at the raw JSON level to catch any serialization issues
+			var rawJSON map[string]interface{}
+			err = json.Unmarshal(result, &rawJSON)
+			require.NoError(t, err)
+			assert.NotContains(t, rawJSON, "thinking",
+				"raw JSON must not contain 'thinking' field")
+			assert.NotContains(t, rawJSON, "reasoning_max_tokens",
+				"raw JSON must not contain 'reasoning_max_tokens' field")
		})
	}
}

@@ -830,21 +835,146 @@ func TestClaudeToOpenAIConverter_ConvertReasoningResponseToClaude(t *testing.T)
				// First should be thinking
				thinkingContent := claudeResponse.Content[0]
				assert.Equal(t, "thinking", thinkingContent.Type)
-				assert.Equal(t, "", thinkingContent.Signature) // OpenAI doesn't provide signature
-				assert.Contains(t, thinkingContent.Thinking, "Let me think about this step by step")
+				require.NotNil(t, thinkingContent.Signature)
+				assert.Equal(t, "", *thinkingContent.Signature) // OpenAI doesn't provide signature
+				require.NotNil(t, thinkingContent.Thinking)
+				assert.Contains(t, *thinkingContent.Thinking, "Let me think about this step by step")

				// Second should be text
				textContent := claudeResponse.Content[1]
				assert.Equal(t, "text", textContent.Type)
-				assert.Equal(t, tt.expectedText, textContent.Text)
+				require.NotNil(t, textContent.Text)
+				assert.Equal(t, tt.expectedText, *textContent.Text)
			} else {
				// Should only have text content
				assert.Len(t, claudeResponse.Content, 1)

				textContent := claudeResponse.Content[0]
				assert.Equal(t, "text", textContent.Type)
-				assert.Equal(t, tt.expectedText, textContent.Text)
+				require.NotNil(t, textContent.Text)
+				assert.Equal(t, tt.expectedText, *textContent.Text)
			}
		})
	}
}
+
+func TestClaudeToOpenAIConverter_StripCchFromSystemMessage(t *testing.T) {
+	converter := &ClaudeToOpenAIConverter{}
+
+	t.Run("string_system_with_billing_header", func(t *testing.T) {
+		// Test that cch field is stripped from string format system message
+		claudeRequest := `{
+			"model": "claude-sonnet-4",
+			"max_tokens": 1024,
+			"system": [
+				{
+					"type": "text",
+					"text": "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode; cch=abc123;"
+				}
+			],
+			"messages": [{
+				"role": "user",
+				"content": "Hello"
+			}]
+		}`
+
+		result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
+		require.NoError(t, err)
+
+		var openaiRequest chatCompletionRequest
+		err = json.Unmarshal(result, &openaiRequest)
+		require.NoError(t, err)
+
+		require.Len(t, openaiRequest.Messages, 2)
+
+		// First message should be system with cch stripped
+		systemMsg := openaiRequest.Messages[0]
+		assert.Equal(t, "system", systemMsg.Role)
+
+		// The system content should have cch removed
+		contentArray, ok := systemMsg.Content.([]interface{})
+		require.True(t, ok, "System content should be an array")
+		require.Len(t, contentArray, 1)
+
+		contentMap, ok := contentArray[0].(map[string]interface{})
+		require.True(t, ok)
+		assert.Equal(t, "text", contentMap["type"])
+		assert.Equal(t, "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode;", contentMap["text"])
+		assert.NotContains(t, contentMap["text"], "cch=")
+	})
+
+	t.Run("plain_string_system_unchanged", func(t *testing.T) {
+		// Test that normal system messages are not modified
+		claudeRequest := `{
+			"model": "claude-sonnet-4",
+			"max_tokens": 1024,
+			"system": "You are a helpful assistant.",
+			"messages": [{
+				"role": "user",
+				"content": "Hello"
+			}]
+		}`
+
+		result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
+		require.NoError(t, err)
+
+		var openaiRequest chatCompletionRequest
+		err = json.Unmarshal(result, &openaiRequest)
+		require.NoError(t, err)
+
+		// First message should be system with original content
+		systemMsg := openaiRequest.Messages[0]
+		assert.Equal(t, "system", systemMsg.Role)
+		assert.Equal(t, "You are a helpful assistant.", systemMsg.Content)
+	})
+}
+
+func TestStripCchFromBillingHeader(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    string
+		expected string
+	}{
+		{
+			name:     "billing header with cch at end",
+			input:    "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode; cch=abc123;",
+			expected: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode;",
+		},
+		{
+			name:     "billing header with cch at end without trailing semicolon",
+			input:    "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode; cch=abc123",
+			expected: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode",
+		},
+		{
+			name:     "billing header with cch in middle",
+			input:    "x-anthropic-billing-header: cc_version=2.1.37.3a3; cch=abc123; cc_entrypoint=claude-vscode;",
+			expected: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode;",
+		},
+		{
+			name:     "billing header without cch",
+			input:    "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode;",
+			expected: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode;",
+		},
+		{
+			name:     "non-billing header text unchanged",
+			input:    "This is a normal system prompt",
+			expected: "This is a normal system prompt",
+		},
+		{
+			name:     "empty string unchanged",
+			input:    "",
+			expected: "",
+		},
+		{
+			name:     "billing header with multiple cch fields",
+			input:    "x-anthropic-billing-header: cc_version=2.1.37.3a3; cch=first; cc_entrypoint=claude-vscode; cch=second;",
+			expected: "x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode;",
+		},
+	}

+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := stripCchFromBillingHeader(tt.input)
+			assert.Equal(t, tt.expected, result)
+		})
+	}
+}
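The table above fully pins down the helper's behavior, but the diff here shows only the tests, not `stripCchFromBillingHeader` itself. The following is a hypothetical regex-based sketch that satisfies these cases; the function name, guard prefix, and approach are assumptions, not the PR's actual implementation:

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// Hypothetical reimplementation; the PR's real stripCchFromBillingHeader may differ.
var (
	cchSegmentRe  = regexp.MustCompile(`\s*\bcch=[^;]*;`)  // "cch=...;" with leading whitespace
	cchTrailingRe = regexp.MustCompile(`;\s*\bcch=[^;]*$`) // trailing "cch=..." without a semicolon
)

func stripCch(s string) string {
	// Only billing-header strings are touched; normal prompts pass through.
	if !strings.HasPrefix(s, "x-anthropic-billing-header:") {
		return s
	}
	s = cchSegmentRe.ReplaceAllString(s, "")
	return cchTrailingRe.ReplaceAllString(s, "")
}

func main() {
	fmt.Println(stripCch("x-anthropic-billing-header: cc_version=2.1.37.3a3; cch=abc123; cc_entrypoint=claude-vscode;"))
	// x-anthropic-billing-header: cc_version=2.1.37.3a3; cc_entrypoint=claude-vscode;
}
```

The two patterns cover the semicolon-terminated and end-of-string cases separately, which is what the "without trailing semicolon" test case forces.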

@@ -10,11 +10,11 @@ import (
	"time"

	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
-	"github.com/higress-group/wasm-go/pkg/log"
-	"github.com/higress-group/wasm-go/pkg/wrapper"
	"github.com/google/uuid"
	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+	"github.com/higress-group/wasm-go/pkg/log"
+	"github.com/higress-group/wasm-go/pkg/wrapper"
	"github.com/tidwall/gjson"
)

@@ -198,6 +198,11 @@ func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext,
		handler.TransformRequestHeaders(ctx, ApiNameChatCompletion, modifiedHeaders)
	}

+	// Apply providerBasePath if configured
+	if c.providerBasePath != "" {
+		modifiedHeaders.Set(":path", c.applyProviderBasePath(modifiedHeaders.Get(":path")))
+	}
+
	var err error
	if handler, ok := activeProvider.(TransformRequestBodyHandler); ok {
		body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body)

@@ -605,7 +610,7 @@ func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext) {
	if c.isFailoverEnabled() {
		apiToken = c.GetGlobalRandomToken()
	} else {
-		apiToken = c.GetRandomToken()
+		apiToken = c.GetOrSetTokenWithContext(ctx)
	}
	log.Debugf("Use apiToken %s to send request", apiToken)
	ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)

@@ -52,6 +52,8 @@ func (m *genericProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
}

func (m *genericProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
+	// buffer original request body
+	_ = proxywasm.ReplaceHttpRequestBody(body)
	return types.ActionContinue, nil
}

@@ -1,6 +1,7 @@
package provider

import (
	"encoding/json"
+	"fmt"
	"strings"

@@ -30,39 +31,47 @@ const (
)

type NonOpenAIStyleOptions struct {
-	ReasoningMaxTokens int `json:"reasoning_max_tokens,omitempty"`
+	ReasoningMaxTokens int            `json:"reasoning_max_tokens,omitempty"`
+	Thinking           *thinkingParam `json:"thinking,omitempty"`
}

+type thinkingParam struct {
+	Type        string `json:"type,omitempty"`
+	BudgetToken int    `json:"budget_token,omitempty"`
+}
+
type chatCompletionRequest struct {
	NonOpenAIStyleOptions
-	Messages            []chatMessage          `json:"messages"`
-	Model               string                 `json:"model"`
-	Store               bool                   `json:"store,omitempty"`
-	ReasoningEffort     string                 `json:"reasoning_effort,omitempty"`
-	Metadata            map[string]string      `json:"metadata,omitempty"`
-	FrequencyPenalty    float64                `json:"frequency_penalty,omitempty"`
-	LogitBias           map[string]int         `json:"logit_bias,omitempty"`
-	Logprobs            bool                   `json:"logprobs,omitempty"`
-	TopLogprobs         int                    `json:"top_logprobs,omitempty"`
-	MaxTokens           int                    `json:"max_tokens,omitempty"`
-	MaxCompletionTokens int                    `json:"max_completion_tokens,omitempty"`
-	N                   int                    `json:"n,omitempty"`
-	Modalities          []string               `json:"modalities,omitempty"`
-	Prediction          map[string]interface{} `json:"prediction,omitempty"`
-	Audio               map[string]interface{} `json:"audio,omitempty"`
-	PresencePenalty     float64                `json:"presence_penalty,omitempty"`
-	ResponseFormat      map[string]interface{} `json:"response_format,omitempty"`
-	Seed                int                    `json:"seed,omitempty"`
-	ServiceTier         string                 `json:"service_tier,omitempty"`
-	Stop                []string               `json:"stop,omitempty"`
-	Stream              bool                   `json:"stream,omitempty"`
-	StreamOptions       *streamOptions         `json:"stream_options,omitempty"`
-	Temperature         float64                `json:"temperature,omitempty"`
-	TopP                float64                `json:"top_p,omitempty"`
-	Tools               []tool                 `json:"tools,omitempty"`
-	ToolChoice          interface{}            `json:"tool_choice,omitempty"`
-	ParallelToolCalls   bool                   `json:"parallel_tool_calls,omitempty"`
-	User                string                 `json:"user,omitempty"`
+	Messages             []chatMessage          `json:"messages"`
+	Model                string                 `json:"model"`
+	Store                bool                   `json:"store,omitempty"`
+	ReasoningEffort      string                 `json:"reasoning_effort,omitempty"`
+	Metadata             map[string]string      `json:"metadata,omitempty"`
+	FrequencyPenalty     float64                `json:"frequency_penalty,omitempty"`
+	LogitBias            map[string]int         `json:"logit_bias,omitempty"`
+	Logprobs             bool                   `json:"logprobs,omitempty"`
+	TopLogprobs          int                    `json:"top_logprobs,omitempty"`
+	MaxTokens            int                    `json:"max_tokens,omitempty"`
+	MaxCompletionTokens  int                    `json:"max_completion_tokens,omitempty"`
+	N                    int                    `json:"n,omitempty"`
+	Modalities           []string               `json:"modalities,omitempty"`
+	Prediction           map[string]interface{} `json:"prediction,omitempty"`
+	Audio                map[string]interface{} `json:"audio,omitempty"`
+	PresencePenalty      float64                `json:"presence_penalty,omitempty"`
+	ResponseFormat       map[string]interface{} `json:"response_format,omitempty"`
+	Seed                 int                    `json:"seed,omitempty"`
+	ServiceTier          string                 `json:"service_tier,omitempty"`
+	Stop                 []string               `json:"stop,omitempty"`
+	Stream               bool                   `json:"stream,omitempty"`
+	StreamOptions        *streamOptions         `json:"stream_options,omitempty"`
+	PromptCacheRetention string                 `json:"prompt_cache_retention,omitempty"`
+	PromptCacheKey       string                 `json:"prompt_cache_key,omitempty"`
+	Temperature          float64                `json:"temperature,omitempty"`
+	TopP                 float64                `json:"top_p,omitempty"`
+	Tools                []tool                 `json:"tools,omitempty"`
+	ToolChoice           interface{}            `json:"tool_choice,omitempty"`
+	ParallelToolCalls    bool                   `json:"parallel_tool_calls,omitempty"`
+	User                 string                 `json:"user,omitempty"`
}

func (c *chatCompletionRequest) getMaxTokens() int {

@@ -246,6 +255,70 @@ func (m *chatMessage) handleStreamingReasoningContent(ctx wrapper.HttpContext, r
	}
}

+// promoteThinkingOnEmpty promotes reasoning_content to content when content is empty.
+// This handles models that put user-facing replies into thinking blocks instead of text blocks.
+func (r *chatCompletionResponse) promoteThinkingOnEmpty() {
+	for i := range r.Choices {
+		msg := r.Choices[i].Message
+		if msg == nil {
+			continue
+		}
+		if !isContentEmpty(msg.Content) {
+			continue
+		}
+		if msg.ReasoningContent != "" {
+			msg.Content = msg.ReasoningContent
+			msg.ReasoningContent = ""
+		}
+	}
+}
+
+// promoteStreamingThinkingOnEmpty accumulates reasoning content during streaming.
+// It strips reasoning from chunks and buffers it. When content is seen, it marks
+// the stream as having content so no promotion will happen.
+// Call PromoteStreamingThinkingFlush at the end of the stream to emit buffered
+// reasoning as content if no content was ever seen.
+// Returns true if the chunk was modified (reasoning stripped).
+func promoteStreamingThinkingOnEmpty(ctx wrapper.HttpContext, msg *chatMessage) bool {
+	if msg == nil {
+		return false
+	}
+	hasContentDelta, _ := ctx.GetContext(ctxKeyHasContentDelta).(bool)
+	if hasContentDelta {
+		return false
+	}
+
+	if !isContentEmpty(msg.Content) {
+		ctx.SetContext(ctxKeyHasContentDelta, true)
+		return false
+	}
+
+	// Buffer reasoning content and strip it from the chunk
+	reasoning := msg.ReasoningContent
+	if reasoning == "" {
+		reasoning = msg.Reasoning
+	}
+	if reasoning != "" {
+		buffered, _ := ctx.GetContext(ctxKeyBufferedReasoning).(string)
+		ctx.SetContext(ctxKeyBufferedReasoning, buffered+reasoning)
+		msg.ReasoningContent = ""
+		msg.Reasoning = ""
+		return true
+	}
+	return false
+}
+
+func isContentEmpty(content any) bool {
+	switch v := content.(type) {
+	case nil:
+		return true
+	case string:
+		return strings.TrimSpace(v) == ""
+	default:
+		return false
+	}
+}
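The non-streaming promotion logic above can be exercised standalone with trimmed-down stand-ins for the response types (the struct shapes here are simplified sketches, not the plugin's actual definitions):

```go
package main

import (
	"fmt"
	"strings"
)

// Trimmed-down stand-ins for the plugin's response types.
type message struct {
	Content          any
	ReasoningContent string
}

type response struct {
	Choices []struct{ Message *message }
}

func isContentEmpty(content any) bool {
	switch v := content.(type) {
	case nil:
		return true
	case string:
		return strings.TrimSpace(v) == ""
	default:
		return false
	}
}

// promoteThinkingOnEmpty mirrors the diff: move reasoning into content
// only when content is empty, and clear the reasoning field afterwards.
func (r *response) promoteThinkingOnEmpty() {
	for i := range r.Choices {
		msg := r.Choices[i].Message
		if msg == nil || !isContentEmpty(msg.Content) {
			continue
		}
		if msg.ReasoningContent != "" {
			msg.Content = msg.ReasoningContent
			msg.ReasoningContent = ""
		}
	}
}

func main() {
	r := &response{}
	r.Choices = append(r.Choices, struct{ Message *message }{
		Message: &message{Content: "  ", ReasoningContent: "the actual answer"},
	})
	r.promoteThinkingOnEmpty()
	fmt.Println(r.Choices[0].Message.Content) // the actual answer
}
```

Note that whitespace-only content counts as empty, which is why `isContentEmpty` trims before comparing.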

type chatMessageContent struct {
	CacheControl map[string]interface{} `json:"cache_control,omitempty"`
	Type         string                 `json:"type,omitempty"`

@@ -455,6 +528,122 @@ type imageGenerationRequest struct {
	Size string `json:"size,omitempty"`
}

+type imageInputURL struct {
+	URL      string                      `json:"url,omitempty"`
+	ImageURL *chatMessageContentImageUrl `json:"image_url,omitempty"`
+}
+
+func (i *imageInputURL) UnmarshalJSON(data []byte) error {
+	// Support a plain string payload, e.g. "data:image/png;base64,..."
+	var rawURL string
+	if err := json.Unmarshal(data, &rawURL); err == nil {
+		i.URL = rawURL
+		i.ImageURL = nil
+		return nil
+	}
+
+	type alias imageInputURL
+	var value alias
+	if err := json.Unmarshal(data, &value); err != nil {
+		return err
+	}
+	*i = imageInputURL(value)
+	return nil
+}
+
+func (i *imageInputURL) GetURL() string {
+	if i == nil {
+		return ""
+	}
+	if i.ImageURL != nil && i.ImageURL.Url != "" {
+		return i.ImageURL.Url
+	}
+	return i.URL
+}

+type imageEditRequest struct {
+	Model             string          `json:"model"`
+	Prompt            string          `json:"prompt"`
+	Image             *imageInputURL  `json:"image,omitempty"`
+	Images            []imageInputURL `json:"images,omitempty"`
+	ImageURL          *imageInputURL  `json:"image_url,omitempty"`
+	Mask              *imageInputURL  `json:"mask,omitempty"`
+	MaskURL           *imageInputURL  `json:"mask_url,omitempty"`
+	Background        string          `json:"background,omitempty"`
+	Moderation        string          `json:"moderation,omitempty"`
+	OutputCompression int             `json:"output_compression,omitempty"`
+	OutputFormat      string          `json:"output_format,omitempty"`
+	Quality           string          `json:"quality,omitempty"`
+	ResponseFormat    string          `json:"response_format,omitempty"`
+	Style             string          `json:"style,omitempty"`
+	N                 int             `json:"n,omitempty"`
+	Size              string          `json:"size,omitempty"`
+}
+
+func (r *imageEditRequest) GetImageURLs() []string {
+	urls := make([]string, 0, len(r.Images)+2)
+	for _, image := range r.Images {
+		if url := image.GetURL(); url != "" {
+			urls = append(urls, url)
+		}
+	}
+	if r.Image != nil {
+		if url := r.Image.GetURL(); url != "" {
+			urls = append(urls, url)
+		}
+	}
+	if r.ImageURL != nil {
+		if url := r.ImageURL.GetURL(); url != "" {
+			urls = append(urls, url)
+		}
+	}
+	return urls
+}
+
+func (r *imageEditRequest) HasMask() bool {
+	if r.Mask != nil && r.Mask.GetURL() != "" {
+		return true
+	}
+	return r.MaskURL != nil && r.MaskURL.GetURL() != ""
+}
+
+type imageVariationRequest struct {
+	Model             string          `json:"model"`
+	Prompt            string          `json:"prompt,omitempty"`
+	Image             *imageInputURL  `json:"image,omitempty"`
+	Images            []imageInputURL `json:"images,omitempty"`
+	ImageURL          *imageInputURL  `json:"image_url,omitempty"`
+	Background        string          `json:"background,omitempty"`
+	Moderation        string          `json:"moderation,omitempty"`
+	OutputCompression int             `json:"output_compression,omitempty"`
+	OutputFormat      string          `json:"output_format,omitempty"`
+	Quality           string          `json:"quality,omitempty"`
+	ResponseFormat    string          `json:"response_format,omitempty"`
+	Style             string          `json:"style,omitempty"`
+	N                 int             `json:"n,omitempty"`
+	Size              string          `json:"size,omitempty"`
+}
+
+func (r *imageVariationRequest) GetImageURLs() []string {
+	urls := make([]string, 0, len(r.Images)+2)
+	for _, image := range r.Images {
+		if url := image.GetURL(); url != "" {
+			urls = append(urls, url)
+		}
+	}
+	if r.Image != nil {
+		if url := r.Image.GetURL(); url != "" {
+			urls = append(urls, url)
+		}
+	}
+	if r.ImageURL != nil {
+		if url := r.ImageURL.GetURL(); url != "" {
+			urls = append(urls, url)
+		}
+	}
+	return urls
+}

type imageGenerationData struct {
	URL string `json:"url,omitempty"`
	B64 string `json:"b64_json,omitempty"`

@@ -523,3 +712,87 @@ func (r embeddingsRequest) ParseInput() []string {
	}
	return input
}

+// PromoteThinkingOnEmptyResponse promotes reasoning_content to content in a non-streaming
+// response body when content is empty. Returns the original body if no promotion is needed.
+func PromoteThinkingOnEmptyResponse(body []byte) ([]byte, error) {
+	var resp chatCompletionResponse
+	if err := json.Unmarshal(body, &resp); err != nil {
+		return body, fmt.Errorf("unable to unmarshal response for thinking promotion: %v", err)
+	}
+	promoted := false
+	for i := range resp.Choices {
+		msg := resp.Choices[i].Message
+		if msg == nil {
+			continue
+		}
+		if !isContentEmpty(msg.Content) {
+			continue
+		}
+		if msg.ReasoningContent != "" {
+			msg.Content = msg.ReasoningContent
+			msg.ReasoningContent = ""
+			promoted = true
+		}
+	}
+	if !promoted {
+		return body, nil
+	}
+	return json.Marshal(resp)
+}
+
+// PromoteStreamingThinkingOnEmptyChunk buffers reasoning deltas and strips them from
+// the chunk during streaming. Call PromoteStreamingThinkingFlush on the last chunk
+// to emit buffered reasoning as content if no real content was ever seen.
+func PromoteStreamingThinkingOnEmptyChunk(ctx wrapper.HttpContext, data []byte) ([]byte, error) {
+	var resp chatCompletionResponse
+	if err := json.Unmarshal(data, &resp); err != nil {
+		return data, nil // not a valid chat completion chunk, skip
+	}
+	modified := false
+	for i := range resp.Choices {
+		msg := resp.Choices[i].Delta
+		if msg == nil {
+			continue
+		}
+		if promoteStreamingThinkingOnEmpty(ctx, msg) {
+			modified = true
+		}
+	}
+	if !modified {
+		return data, nil
+	}
+	return json.Marshal(resp)
+}
+
+// PromoteStreamingThinkingFlush checks if the stream had no content and returns
+// an SSE chunk that emits the buffered reasoning as content. Returns nil if
+// content was already seen or no reasoning was buffered.
+func PromoteStreamingThinkingFlush(ctx wrapper.HttpContext) []byte {
+	hasContentDelta, _ := ctx.GetContext(ctxKeyHasContentDelta).(bool)
+	if hasContentDelta {
+		return nil
+	}
+	buffered, _ := ctx.GetContext(ctxKeyBufferedReasoning).(string)
+	if buffered == "" {
+		return nil
+	}
+	// Build a minimal chat.completion.chunk with the buffered reasoning as content
+	resp := chatCompletionResponse{
+		Object: objectChatCompletionChunk,
+		Choices: []chatCompletionChoice{
+			{
+				Index: 0,
+				Delta: &chatMessage{
+					Content: buffered,
+				},
+			},
+		},
+	}
+	data, err := json.Marshal(resp)
+	if err != nil {
+		return nil
+	}
+	// Format as SSE
+	return []byte("data: " + string(data) + "\n\n")
+}
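The flush step above emits one final SSE frame whose delta carries the buffered reasoning as ordinary content. A minimal sketch of the wire format it produces, using simplified chunk structs (the type shapes are illustrative, not the plugin's actual definitions):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Simplified stand-ins for the chunk types used in the diff.
type delta struct {
	Content string `json:"content,omitempty"`
}

type choice struct {
	Index int    `json:"index"`
	Delta *delta `json:"delta,omitempty"`
}

type chunk struct {
	Object  string   `json:"object"`
	Choices []choice `json:"choices"`
}

// flushBufferedReasoning builds the final SSE frame that carries the
// buffered reasoning as ordinary content; it returns nil for an empty buffer.
func flushBufferedReasoning(buffered string) []byte {
	if buffered == "" {
		return nil
	}
	c := chunk{
		Object:  "chat.completion.chunk",
		Choices: []choice{{Index: 0, Delta: &delta{Content: buffered}}},
	}
	data, err := json.Marshal(c)
	if err != nil {
		return nil
	}
	// SSE frames are "data: <json>\n\n".
	return []byte("data: " + string(data) + "\n\n")
}

func main() {
	fmt.Print(string(flushBufferedReasoning("the buffered reasoning")))
}
```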

plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go (new file, 273 lines)
@@ -0,0 +1,273 @@
+package provider
+
+import (
+	"bytes"
+	"encoding/base64"
+	"fmt"
+	"io"
+	"mime"
+	"mime/multipart"
+	"net/http"
+	"net/textproto"
+	"strconv"
+	"strings"
+)
+
+var newMultipartWriter = func(w io.Writer) *multipart.Writer {
+	return multipart.NewWriter(w)
+}
+
+type multipartImageRequest struct {
+	Model        string
+	Prompt       string
+	Size         string
+	OutputFormat string
+	N            int
+	ImageURLs    []string
+	HasMask      bool
+}
+
+func isMultipartFormData(contentType string) bool {
+	mediaType, _, err := mime.ParseMediaType(contentType)
+	if err != nil {
+		return false
+	}
+	return strings.EqualFold(mediaType, "multipart/form-data")
+}
+
+func parseMultipartBoundary(contentType string) (string, error) {
+	_, params, err := mime.ParseMediaType(contentType)
+	if err != nil {
+		return "", fmt.Errorf("unable to parse content-type: %v", err)
+	}
+	boundary := params["boundary"]
+	if boundary == "" {
+		return "", fmt.Errorf("missing multipart boundary")
+	}
+	return boundary, nil
+}
+
+func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) {
+	boundary, err := parseMultipartBoundary(contentType)
+	if err != nil {
+		return nil, err
+	}
+
+	req := &multipartImageRequest{
+		ImageURLs: make([]string, 0),
+	}
+	reader := multipart.NewReader(bytes.NewReader(body), boundary)
+	for {
+		part, err := reader.NextPart()
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return nil, fmt.Errorf("unable to read multipart part: %v", err)
+		}
+		fieldName := part.FormName()
+		if fieldName == "" {
+			_ = part.Close()
+			continue
+		}
+		partContentType := strings.TrimSpace(part.Header.Get("Content-Type"))
+
+		partData, err := io.ReadAll(part)
+		_ = part.Close()
+		if err != nil {
+			return nil, fmt.Errorf("unable to read multipart field %s: %v", fieldName, err)
+		}
+
+		value := strings.TrimSpace(string(partData))
+		switch fieldName {
+		case "model":
+			req.Model = value
+			continue
+		case "prompt":
+			req.Prompt = value
+			continue
+		case "size":
+			req.Size = value
+			continue
+		case "output_format":
+			req.OutputFormat = value
+			continue
+		case "n":
+			if value != "" {
+				if parsed, err := strconv.Atoi(value); err == nil {
+					req.N = parsed
+				}
+			}
+			continue
+		}
+
+		if isMultipartImageField(fieldName) {
+			if isMultipartImageURLValue(value) {
+				req.ImageURLs = append(req.ImageURLs, value)
+				continue
+			}
+			if len(partData) == 0 {
+				continue
+			}
+			imageURL := buildMultipartDataURL(partContentType, partData)
+			req.ImageURLs = append(req.ImageURLs, imageURL)
+			continue
+		}
+		if isMultipartMaskField(fieldName) {
+			if len(partData) > 0 || value != "" {
+				req.HasMask = true
+			}
+			continue
+		}
+	}
+
+	return req, nil
+}
+
+func extractMultipartModel(body []byte, contentType string) (string, error) {
+	boundary, err := parseMultipartBoundary(contentType)
+	if err != nil {
+		return "", err
+	}
+
+	reader := multipart.NewReader(bytes.NewReader(body), boundary)
+	model := ""
+	for {
+		part, err := reader.NextPart()
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return "", fmt.Errorf("unable to read multipart part: %v", err)
+		}
+
+		fieldName := part.FormName()
+		var readErr error
+		if fieldName == "model" {
+			var partData []byte
+			partData, readErr = io.ReadAll(part)
+			if readErr == nil {
+				model = strings.TrimSpace(string(partData))
+			}
+		} else {
+			_, readErr = io.Copy(io.Discard, part)
+		}
+		_ = part.Close()
+		if readErr != nil {
+			return "", fmt.Errorf("unable to read multipart field %s: %v", fieldName, readErr)
+		}
+	}
+
+	return model, nil
+}
+
+func rewriteMultipartFormModel(body []byte, contentType string, model string) ([]byte, error) {
+	boundary, err := parseMultipartBoundary(contentType)
+	if err != nil {
+		return nil, err
+	}
+
+	var buffer bytes.Buffer
+	writer := newMultipartWriter(&buffer)
+	if err := writer.SetBoundary(boundary); err != nil {
+		return nil, fmt.Errorf("unable to set multipart boundary: %v", err)
+	}
+
+	reader := multipart.NewReader(bytes.NewReader(body), boundary)
+	modelFound := false
+	for {
+		part, err := reader.NextPart()
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return nil, fmt.Errorf("unable to read multipart part: %v", err)
+		}
+
+		fieldName := part.FormName()
+		newPart, err := writer.CreatePart(cloneMultipartPartHeader(part.Header))
+		if err != nil {
+			_ = part.Close()
+			return nil, fmt.Errorf("unable to create multipart field %s: %v", fieldName, err)
+		}
+
+		var copyErr error
+		if fieldName == "model" {
+			modelFound = true
+			if _, copyErr = io.WriteString(newPart, model); copyErr == nil {
+				_, copyErr = io.Copy(io.Discard, part)
+			}
+		} else {
+			_, copyErr = io.Copy(newPart, part)
+		}
+		_ = part.Close()
+		if copyErr != nil {
+			return nil, fmt.Errorf("unable to write multipart field %s: %v", fieldName, copyErr)
+		}
+	}
+
+	if !modelFound && model != "" {
+		if err := writer.WriteField("model", model); err != nil {
+			return nil, fmt.Errorf("unable to append multipart model field: %v", err)
+		}
+	}
+	if err := writer.Close(); err != nil {
+		return nil, fmt.Errorf("unable to finalize multipart body: %v", err)
+	}
+
+	return buffer.Bytes(), nil
+}

+func cloneMultipartPartHeader(header textproto.MIMEHeader) textproto.MIMEHeader {
+	cloned := make(textproto.MIMEHeader, len(header))
+	for key, values := range header {
+		copied := make([]string, len(values))
+		copy(copied, values)
+		cloned[key] = copied
+	}
+	return cloned
+}
+
+func isMultipartImageField(fieldName string) bool {
+	return fieldName == "image" || fieldName == "image[]" || strings.HasPrefix(fieldName, "image[")
+}
+
+func isMultipartMaskField(fieldName string) bool {
+	return fieldName == "mask" || fieldName == "mask[]" || strings.HasPrefix(fieldName, "mask[")
+}
+
+func isMultipartImageURLValue(value string) bool {
+	if value == "" {
+		return false
+	}
+	loweredValue := strings.ToLower(value)
+	return strings.HasPrefix(loweredValue, "data:") || strings.HasPrefix(loweredValue, "http://") || strings.HasPrefix(loweredValue, "https://")
+}
+
+func buildMultipartDataURL(contentType string, data []byte) string {
+	mimeType := strings.TrimSpace(contentType)
+	if mimeType == "" || strings.EqualFold(mimeType, "application/octet-stream") {
+		mimeType = http.DetectContentType(data)
+	}
+	mimeType = normalizeMultipartMimeType(mimeType)
+	if mimeType == "" {
+		mimeType = "application/octet-stream"
+	}
+	encoded := base64.StdEncoding.EncodeToString(data)
+	return fmt.Sprintf("data:%s;base64,%s", mimeType, encoded)
+}
+
+func normalizeMultipartMimeType(contentType string) string {
+	contentType = strings.TrimSpace(contentType)
+	if contentType == "" {
+		return ""
+	}
+	mediaType, _, err := mime.ParseMediaType(contentType)
+	if err == nil && mediaType != "" {
+		return strings.TrimSpace(mediaType)
+	}
+	if idx := strings.Index(contentType, ";"); idx > 0 {
+		return strings.TrimSpace(contentType[:idx])
+	}
+	return contentType
+}
|
||||
@@ -0,0 +1,363 @@
package provider

import (
	"bytes"
	"errors"
	"io"
	"mime/multipart"
	"strings"
	"testing"

	"github.com/higress-group/wasm-go/pkg/iface"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

type mockMultipartHttpContext struct {
	contextMap map[string]interface{}
}

func newMockMultipartHttpContext() *mockMultipartHttpContext {
	return &mockMultipartHttpContext{contextMap: make(map[string]interface{})}
}

func (m *mockMultipartHttpContext) SetContext(key string, value interface{}) {
	m.contextMap[key] = value
}
func (m *mockMultipartHttpContext) GetContext(key string) interface{}                 { return m.contextMap[key] }
func (m *mockMultipartHttpContext) GetBoolContext(key string, def bool) bool          { return def }
func (m *mockMultipartHttpContext) GetStringContext(key, def string) string           { return def }
func (m *mockMultipartHttpContext) GetByteSliceContext(key string, def []byte) []byte { return def }
func (m *mockMultipartHttpContext) Scheme() string                                    { return "" }
func (m *mockMultipartHttpContext) Host() string                                      { return "" }
func (m *mockMultipartHttpContext) Path() string                                      { return "" }
func (m *mockMultipartHttpContext) Method() string                                    { return "" }
func (m *mockMultipartHttpContext) GetUserAttribute(key string) interface{}           { return nil }
func (m *mockMultipartHttpContext) SetUserAttribute(key string, value interface{})    {}
func (m *mockMultipartHttpContext) SetUserAttributeMap(kvmap map[string]interface{})  {}
func (m *mockMultipartHttpContext) GetUserAttributeMap() map[string]interface{}       { return nil }
func (m *mockMultipartHttpContext) WriteUserAttributeToLog() error                    { return nil }
func (m *mockMultipartHttpContext) WriteUserAttributeToLogWithKey(key string) error   { return nil }
func (m *mockMultipartHttpContext) WriteUserAttributeToTrace() error                  { return nil }
func (m *mockMultipartHttpContext) DontReadRequestBody()                              {}
func (m *mockMultipartHttpContext) DontReadResponseBody()                             {}
func (m *mockMultipartHttpContext) BufferRequestBody()                                {}
func (m *mockMultipartHttpContext) BufferResponseBody()                               {}
func (m *mockMultipartHttpContext) NeedPauseStreamingResponse()                       {}
func (m *mockMultipartHttpContext) PushBuffer(buffer []byte)                          {}
func (m *mockMultipartHttpContext) PopBuffer() []byte                                 { return nil }
func (m *mockMultipartHttpContext) BufferQueueSize() int                              { return 0 }
func (m *mockMultipartHttpContext) DisableReroute()                                   {}
func (m *mockMultipartHttpContext) SetRequestBodyBufferLimit(byteSize uint32)         {}
func (m *mockMultipartHttpContext) SetResponseBodyBufferLimit(byteSize uint32)        {}
func (m *mockMultipartHttpContext) RouteCall(method, url string, headers [][2]string, body []byte, callback iface.RouteResponseCallback) error {
	return nil
}
func (m *mockMultipartHttpContext) GetExecutionPhase() iface.HTTPExecutionPhase { return 0 }
func (m *mockMultipartHttpContext) HasRequestBody() bool                        { return false }
func (m *mockMultipartHttpContext) HasResponseBody() bool                       { return false }
func (m *mockMultipartHttpContext) IsWebsocket() bool                           { return false }
func (m *mockMultipartHttpContext) IsBinaryRequestBody() bool                   { return false }
func (m *mockMultipartHttpContext) IsBinaryResponseBody() bool                  { return false }

func buildProviderMultipartRequestBody(t *testing.T, fields map[string]string, files map[string][]byte) ([]byte, string) {
	t.Helper()

	var buffer bytes.Buffer
	writer := multipart.NewWriter(&buffer)

	for key, value := range fields {
		require.NoError(t, writer.WriteField(key, value))
	}
	for fieldName, data := range files {
		part, err := writer.CreateFormFile(fieldName, "upload-image.png")
		require.NoError(t, err)
		_, err = part.Write(data)
		require.NoError(t, err)
	}

	require.NoError(t, writer.Close())
	return buffer.Bytes(), writer.FormDataContentType()
}

type failAfterNWriteWriter struct {
	target     io.Writer
	failAtCall int
	writeCalls int
}

func (w *failAfterNWriteWriter) Write(p []byte) (int, error) {
	w.writeCalls++
	if w.writeCalls >= w.failAtCall {
		return 0, errors.New("injected write failure")
	}
	return w.target.Write(p)
}

func withInjectedMultipartWriterFactory(t *testing.T, failAtCall int, testFunc func()) {
	t.Helper()

	originalFactory := newMultipartWriter
	newMultipartWriter = func(target io.Writer) *multipart.Writer {
		return multipart.NewWriter(&failAfterNWriteWriter{
			target:     target,
			failAtCall: failAtCall,
		})
	}
	defer func() {
		newMultipartWriter = originalFactory
	}()

	testFunc()
}

func TestRewriteMultipartFormModel(t *testing.T) {
	t.Run("rewrites existing model field", func(t *testing.T) {
		body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
			"model":  "gpt-image-1.5",
			"prompt": "Turn the dog white",
		}, map[string][]byte{
			"image[]": []byte("fake-image-content"),
		})

		transformed, err := rewriteMultipartFormModel(body, contentType, "gpt-image-1")
		require.NoError(t, err)

		req, err := parseMultipartImageRequest(transformed, contentType)
		require.NoError(t, err)
		assert.Equal(t, "gpt-image-1", req.Model)
		assert.Equal(t, "Turn the dog white", req.Prompt)
		assert.Len(t, req.ImageURLs, 1)
		assert.Contains(t, string(transformed), "fake-image-content")
	})

	t.Run("appends model field when missing", func(t *testing.T) {
		body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
			"prompt": "Turn the dog white",
		}, map[string][]byte{
			"image": []byte("fake-image-content"),
		})

		transformed, err := rewriteMultipartFormModel(body, contentType, "gpt-image-1")
		require.NoError(t, err)

		req, err := parseMultipartImageRequest(transformed, contentType)
		require.NoError(t, err)
		assert.Equal(t, "gpt-image-1", req.Model)
		assert.Equal(t, "Turn the dog white", req.Prompt)
		assert.Len(t, req.ImageURLs, 1)
	})

	t.Run("returns error on invalid content type", func(t *testing.T) {
		_, err := rewriteMultipartFormModel([]byte("not-multipart"), "multipart/form-data", "gpt-image-1")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "missing multipart boundary")
	})

	t.Run("returns error when boundary cannot be set", func(t *testing.T) {
		longBoundary := strings.Repeat("a", 71)
		_, err := rewriteMultipartFormModel([]byte(""), "multipart/form-data; boundary="+longBoundary, "gpt-image-1")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "unable to set multipart boundary")
	})

	t.Run("returns error on malformed multipart header", func(t *testing.T) {
		body := []byte("--abc\r\nnot-a-header\r\n\r\nvalue\r\n--abc--\r\n")
		_, err := rewriteMultipartFormModel(body, "multipart/form-data; boundary=abc", "gpt-image-1")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "unable to read multipart part")
	})

	t.Run("returns error when multipart part copy fails", func(t *testing.T) {
		body := []byte("--abc\r\nContent-Disposition: form-data; name=\"image\"; filename=\"a.png\"\r\nContent-Type: image/png\r\n\r\nabc\r\n--ab")
		_, err := rewriteMultipartFormModel(body, "multipart/form-data; boundary=abc", "gpt-image-1")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "unable to write multipart field image")
	})

	t.Run("returns error when creating rewritten multipart part fails", func(t *testing.T) {
		body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
			"prompt": "Turn the dog white",
		}, map[string][]byte{
			"image": []byte("fake-image-content"),
		})

		withInjectedMultipartWriterFactory(t, 1, func() {
			_, err := rewriteMultipartFormModel(body, contentType, "gpt-image-1")
			require.Error(t, err)
			assert.Contains(t, err.Error(), "unable to create multipart field")
		})
	})

	t.Run("returns error when appending model field fails", func(t *testing.T) {
		withInjectedMultipartWriterFactory(t, 1, func() {
			_, err := rewriteMultipartFormModel([]byte("--abc--\r\n"), "multipart/form-data; boundary=abc", "gpt-image-1")
			require.Error(t, err)
			assert.Contains(t, err.Error(), "unable to append multipart model field")
		})
	})

	t.Run("returns error when finalizing multipart body fails", func(t *testing.T) {
		withInjectedMultipartWriterFactory(t, 1, func() {
			_, err := rewriteMultipartFormModel([]byte("--abc--\r\n"), "multipart/form-data; boundary=abc", "")
			require.Error(t, err)
			assert.Contains(t, err.Error(), "unable to finalize multipart body")
		})
	})
}

func TestDefaultTransformMultipartRequestBody(t *testing.T) {
	t.Run("maps multipart model and keeps body valid", func(t *testing.T) {
		body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
			"model":  "gpt-image-1.5",
			"prompt": "Turn the dog white",
			"size":   "1024x1024",
		}, map[string][]byte{
			"image[]": []byte("fake-image-content"),
		})

		config := &ProviderConfig{
			modelMapping: map[string]string{
				"gpt-image-1.5": "gpt-image-1",
			},
		}
		ctx := newMockMultipartHttpContext()

		transformed, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameImageEdit, body, contentType)
		require.NoError(t, err)

		req, err := parseMultipartImageRequest(transformed, contentType)
		require.NoError(t, err)
		assert.Equal(t, "gpt-image-1.5", ctx.GetContext(ctxKeyOriginalRequestModel))
		assert.Equal(t, "gpt-image-1", ctx.GetContext(ctxKeyFinalRequestModel))
		assert.Equal(t, "gpt-image-1", req.Model)
		assert.Equal(t, "Turn the dog white", req.Prompt)
		assert.Len(t, req.ImageURLs, 1)
		assert.Contains(t, string(transformed), "fake-image-content")
	})

	t.Run("appends mapped model when multipart request omits model", func(t *testing.T) {
		body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
			"prompt": "Turn the dog white",
		}, map[string][]byte{
			"image": []byte("fake-image-content"),
		})

		config := &ProviderConfig{
			modelMapping: map[string]string{
				"*": "gpt-image-1",
			},
		}
		ctx := newMockMultipartHttpContext()

		transformed, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameImageVariation, body, contentType)
		require.NoError(t, err)

		req, err := parseMultipartImageRequest(transformed, contentType)
		require.NoError(t, err)
		assert.Equal(t, "", ctx.GetContext(ctxKeyOriginalRequestModel))
		assert.Equal(t, "gpt-image-1", ctx.GetContext(ctxKeyFinalRequestModel))
		assert.Equal(t, "gpt-image-1", req.Model)
	})

	t.Run("returns original body when multipart model is unchanged", func(t *testing.T) {
		body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
			"model":  "gpt-image-1",
			"prompt": "Turn the dog white",
		}, map[string][]byte{
			"image": []byte("fake-image-content"),
		})

		config := &ProviderConfig{}
		ctx := newMockMultipartHttpContext()

		transformed, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameImageEdit, body, contentType)
		require.NoError(t, err)
		assert.Equal(t, body, transformed)
		assert.Equal(t, "gpt-image-1", ctx.GetContext(ctxKeyOriginalRequestModel))
		assert.Equal(t, "gpt-image-1", ctx.GetContext(ctxKeyFinalRequestModel))
	})

	t.Run("ignores non image multipart apis", func(t *testing.T) {
		body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
			"model": "gpt-image-1",
		}, nil)

		config := &ProviderConfig{
			modelMapping: map[string]string{
				"gpt-image-1": "mapped-model",
			},
		}
		ctx := newMockMultipartHttpContext()

		transformed, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameChatCompletion, body, contentType)
		require.NoError(t, err)
		assert.Equal(t, body, transformed)
		assert.Nil(t, ctx.GetContext(ctxKeyOriginalRequestModel))
		assert.Nil(t, ctx.GetContext(ctxKeyFinalRequestModel))
	})

	t.Run("surfaces multipart parse errors", func(t *testing.T) {
		config := &ProviderConfig{}
		ctx := newMockMultipartHttpContext()

		_, err := config.defaultTransformMultipartRequestBody(ctx, ApiNameImageEdit, []byte("bad-body"), "multipart/form-data")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "missing multipart boundary")
	})
}

func TestExtractMultipartModel(t *testing.T) {
	t.Run("extracts model value", func(t *testing.T) {
		body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
			"model":  "gpt-image-1.5",
			"prompt": "Turn the dog white",
		}, nil)

		model, err := extractMultipartModel(body, contentType)
		require.NoError(t, err)
		assert.Equal(t, "gpt-image-1.5", model)
	})

	t.Run("returns empty model when field missing", func(t *testing.T) {
		body, contentType := buildProviderMultipartRequestBody(t, map[string]string{
			"prompt": "Turn the dog white",
		}, nil)

		model, err := extractMultipartModel(body, contentType)
		require.NoError(t, err)
		assert.Equal(t, "", model)
	})

	t.Run("returns parse error for invalid content type", func(t *testing.T) {
		_, err := extractMultipartModel([]byte("bad-body"), "multipart/form-data; boundary=\"")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "unable to parse content-type")
	})

	t.Run("returns parse error for malformed multipart header", func(t *testing.T) {
		body := []byte("--abc\r\nnot-a-header\r\n\r\nvalue\r\n--abc--\r\n")
		_, err := extractMultipartModel(body, "multipart/form-data; boundary=abc")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "unable to read multipart part")
	})

	t.Run("returns field read error on truncated model part", func(t *testing.T) {
		body := []byte("--abc\r\nContent-Disposition: form-data; name=\"model\"\r\n\r\nvalue\r\n--ab")
		_, err := extractMultipartModel(body, "multipart/form-data; boundary=abc")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "unable to read multipart field model")
	})
}

func TestParseMultipartImageRequestContentTypeError(t *testing.T) {
	_, err := parseMultipartImageRequest([]byte("bad-body"), "multipart/form-data; boundary=\"")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "unable to parse content-type")
}

func TestIsMultipartFormData(t *testing.T) {
	assert.True(t, isMultipartFormData("multipart/form-data; boundary=abc"))
	assert.False(t, isMultipartFormData("application/json"))
	assert.False(t, isMultipartFormData("multipart/form-data; boundary=\""))
}
@@ -34,6 +34,9 @@ func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string {
		string(ApiNameImageEdit):          PathOpenAIImageEdit,
		string(ApiNameImageVariation):     PathOpenAIImageVariation,
		string(ApiNameAudioSpeech):        PathOpenAIAudioSpeech,
		string(ApiNameAudioTranscription): PathOpenAIAudioTranscriptions,
		string(ApiNameAudioTranslation):   PathOpenAIAudioTranslations,
		string(ApiNameRealtime):           PathOpenAIRealtime,
		string(ApiNameModels):             PathOpenAIModels,
		string(ApiNameFiles):              PathOpenAIFiles,
		string(ApiNameRetrieveFile):       PathOpenAIRetrieveFile,
@@ -63,6 +66,8 @@ func isDirectPath(path string) bool {
	return strings.HasSuffix(path, "/completions") ||
		strings.HasSuffix(path, "/embeddings") ||
		strings.HasSuffix(path, "/audio/speech") ||
		strings.HasSuffix(path, "/audio/transcriptions") ||
		strings.HasSuffix(path, "/audio/translations") ||
		strings.HasSuffix(path, "/images/generations") ||
		strings.HasSuffix(path, "/images/variations") ||
		strings.HasSuffix(path, "/images/edits") ||
@@ -70,6 +75,7 @@ func isDirectPath(path string) bool {
		strings.HasSuffix(path, "/responses") ||
		strings.HasSuffix(path, "/fine_tuning/jobs") ||
		strings.HasSuffix(path, "/fine_tuning/checkpoints") ||
		strings.HasSuffix(path, "/realtime") ||
		strings.HasSuffix(path, "/videos")
}

@@ -79,6 +79,22 @@ func (o *openrouterProvider) TransformRequestBody(ctx wrapper.HttpContext, apiNa
	// Check if ReasoningMaxTokens exists in the request body
	reasoningMaxTokens := gjson.GetBytes(body, "reasoning_max_tokens")
	if !reasoningMaxTokens.Exists() || reasoningMaxTokens.Int() == 0 {
		// Check if budget_tokens was stored in context (from Claude auto-conversion path)
		// Only use it when thinking was explicitly enabled, to avoid dirty input
		if thinkingType, _ := ctx.GetContext(ctxKeyClaudeThinkingType).(string); thinkingType == "enabled" {
			if budgetTokens, ok := ctx.GetContext(ctxKeyClaudeBudgetTokens).(int); ok && budgetTokens > 0 {
				// Use budget_tokens from Claude thinking config
				modifiedBody, err := sjson.DeleteBytes(body, "reasoning_effort")
				if err != nil {
					modifiedBody = body
				}
				modifiedBody, err = sjson.SetBytes(modifiedBody, "reasoning.max_tokens", budgetTokens)
				if err != nil {
					return nil, err
				}
				return o.config.defaultTransformRequestBody(ctx, apiName, modifiedBody)
			}
		}
		// No reasoning_max_tokens, use default transformation
		return o.config.defaultTransformRequestBody(ctx, apiName, body)
	}

@@ -2,8 +2,10 @@ package provider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"path"
|
||||
@@ -39,6 +41,9 @@ const (
|
||||
ApiNameImageEdit ApiName = "openai/v1/imageedit"
|
||||
ApiNameImageVariation ApiName = "openai/v1/imagevariation"
|
||||
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
|
||||
ApiNameAudioTranscription ApiName = "openai/v1/audiotranscription"
|
||||
ApiNameAudioTranslation ApiName = "openai/v1/audiotranslation"
|
||||
ApiNameRealtime ApiName = "openai/v1/realtime"
|
||||
ApiNameFiles ApiName = "openai/v1/files"
|
||||
ApiNameRetrieveFile ApiName = "openai/v1/retrievefile"
|
||||
ApiNameRetrieveFileContent ApiName = "openai/v1/retrievefilecontent"
|
||||
@@ -88,6 +93,9 @@ const (
|
||||
PathOpenAIImageEdit = "/v1/images/edits"
|
||||
PathOpenAIImageVariation = "/v1/images/variations"
|
||||
PathOpenAIAudioSpeech = "/v1/audio/speech"
|
||||
PathOpenAIAudioTranscriptions = "/v1/audio/transcriptions"
|
||||
PathOpenAIAudioTranslations = "/v1/audio/translations"
|
||||
PathOpenAIRealtime = "/v1/realtime"
|
||||
PathOpenAIResponses = "/v1/responses"
|
||||
PathOpenAIFineTuningJobs = "/v1/fine_tuning/jobs"
|
||||
PathOpenAIRetrieveFineTuningJob = "/v1/fine_tuning/jobs/{fine_tuning_job_id}"
|
||||
@@ -151,6 +159,7 @@ const (
|
||||
protocolOriginal = "original"
|
||||
|
||||
roleSystem = "system"
|
||||
roleDeveloper = "developer"
|
||||
roleAssistant = "assistant"
|
||||
roleUser = "user"
|
||||
roleTool = "tool"
|
||||
@@ -159,6 +168,8 @@ const (
|
||||
finishReasonLength = "length"
|
||||
finishReasonToolCall = "tool_calls"
|
||||
|
||||
ctxKeyClaudeBudgetTokens = "claudeBudgetTokens"
|
||||
ctxKeyClaudeThinkingType = "claudeThinkingType"
|
||||
ctxKeyIncrementalStreaming = "incrementalStreaming"
|
||||
ctxKeyApiKey = "apiKey"
|
||||
CtxKeyApiName = "apiName"
|
||||
@@ -169,6 +180,8 @@ const (
|
||||
ctxKeyPushedMessage = "pushedMessage"
|
||||
ctxKeyContentPushed = "contentPushed"
|
||||
ctxKeyReasoningContentPushed = "reasoningContentPushed"
|
||||
ctxKeyHasContentDelta = "hasContentDelta"
|
||||
ctxKeyBufferedReasoning = "bufferedReasoning"
|
||||
|
||||
objectChatCompletion = "chat.completion"
|
||||
objectChatCompletionChunk = "chat.completion.chunk"
|
||||
@@ -193,6 +206,11 @@ type providerInitializer interface {
|
||||
var (
|
||||
errUnsupportedApiName = errors.New("unsupported API name")
|
||||
|
||||
// Providers that support the "developer" role. Other providers will have "developer" roles converted to "system".
|
||||
developerRoleSupportedProviders = map[string]bool{
|
||||
providerTypeAzure: true,
|
||||
}
|
||||
|
||||
providerInitializers = map[string]providerInitializer{
|
||||
providerTypeMoonshot: &moonshotProviderInitializer{},
|
||||
providerTypeAzure: &azureProviderInitializer{},
|
||||
@@ -346,6 +364,12 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN Amazon Bedrock 额外模型请求参数
|
||||
// @Description zh-CN 仅适用于Amazon Bedrock服务,用于设置模型特定的推理参数
|
||||
bedrockAdditionalFields map[string]interface{} `required:"false" yaml:"bedrockAdditionalFields" json:"bedrockAdditionalFields"`
|
||||
// @Title zh-CN Amazon Bedrock Prompt CachePoint 插入位置
|
||||
// @Description zh-CN 仅适用于Amazon Bedrock服务。用于配置 cachePoint 插入位置,支持多选:systemPrompt、lastUserMessage、lastMessage。值为 true 表示启用该位置。
|
||||
bedrockPromptCachePointPositions map[string]bool `required:"false" yaml:"bedrockPromptCachePointPositions" json:"bedrockPromptCachePointPositions"`
|
||||
// @Title zh-CN Amazon Bedrock Prompt Cache 保留策略(默认值)
|
||||
// @Description zh-CN 仅适用于Amazon Bedrock服务。作为请求中 prompt_cache_retention 缺省时的默认值,支持 in_memory 和 24h。
|
||||
promptCacheRetention string `required:"false" yaml:"promptCacheRetention" json:"promptCacheRetention"`
|
||||
// @Title zh-CN minimax API type
|
||||
// @Description zh-CN 仅适用于 minimax 服务。minimax API 类型,v2 和 pro 中选填一项,默认值为 v2
|
||||
minimaxApiType string `required:"false" yaml:"minimaxApiType" json:"minimaxApiType"`
|
||||
@@ -445,6 +469,27 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN Claude Code 模式
|
||||
// @Description zh-CN 仅适用于Claude服务。启用后将伪装成Claude Code客户端发起请求,支持使用Claude Code的OAuth Token进行认证。
|
||||
claudeCodeMode bool `required:"false" yaml:"claudeCodeMode" json:"claudeCodeMode"`
|
||||
// @Title zh-CN 智谱AI服务域名
|
||||
// @Description zh-CN 仅适用于智谱AI服务。默认为 open.bigmodel.cn(中国),可配置为 api.z.ai(国际)
|
||||
zhipuDomain string `required:"false" yaml:"zhipuDomain" json:"zhipuDomain"`
|
||||
// @Title zh-CN 智谱AI Code Plan 模式
|
||||
// @Description zh-CN 仅适用于智谱AI服务。启用后将使用 /api/coding/paas/v4/chat/completions 接口
|
||||
zhipuCodePlanMode bool `required:"false" yaml:"zhipuCodePlanMode" json:"zhipuCodePlanMode"`
|
||||
// @Title zh-CN 合并连续同角色消息
|
||||
// @Description zh-CN 开启后,若请求的 messages 中存在连续的同角色消息(如连续两条 user 消息),将其内容合并为一条,以满足要求严格轮流交替(user→assistant→user→...)的模型服务商的要求。
|
||||
mergeConsecutiveMessages bool `required:"false" yaml:"mergeConsecutiveMessages" json:"mergeConsecutiveMessages"`
|
||||
// @Title zh-CN 通用 Provider 域名
|
||||
// @Description zh-CN 通用的 Provider 服务域名配置,适用于所有 Provider。当配置此字段时,将优先使用此域名覆盖默认的硬编码域名。常用于代理服务器场景
|
||||
providerDomain string `required:"false" yaml:"providerDomain" json:"providerDomain"`
|
||||
// @Title zh-CN 空内容时提升思考为正文
|
||||
// @Description zh-CN 开启后,若模型响应只包含 reasoning_content/thinking 而没有正文内容,将 reasoning 内容提升为正文内容返回,避免客户端收到空回复。
|
||||
promoteThinkingOnEmpty bool `required:"false" yaml:"promoteThinkingOnEmpty" json:"promoteThinkingOnEmpty"`
|
||||
// @Title zh-CN HiClaw 模式
|
||||
// @Description zh-CN 开启后同时启用 mergeConsecutiveMessages 和 promoteThinkingOnEmpty,适用于 HiClaw 多 Agent 协作场景。
|
||||
hiclawMode bool `required:"false" yaml:"hiclawMode" json:"hiclawMode"`
|
||||
// @Title zh-CN Provider 基础路径
|
||||
// @Description zh-CN 当配置了此值时,各个 Provider 在改写请求路径时会将其添加到路径前面,例如配置"/api/ai"后,请求路径"/v1/chat/completions"会被改写为"/api/ai/v1/chat/completions"
|
||||
providerBasePath string `required:"false" yaml:"providerBasePath" json:"providerBasePath"`
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetId() string {
|
||||
@@ -538,6 +583,13 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
for k, v := range json.Get("bedrockAdditionalFields").Map() {
|
||||
c.bedrockAdditionalFields[k] = v.Value()
|
||||
}
|
||||
c.promptCacheRetention = json.Get("promptCacheRetention").String()
|
||||
if rawPositions := json.Get("bedrockPromptCachePointPositions"); rawPositions.Exists() {
|
||||
c.bedrockPromptCachePointPositions = make(map[string]bool)
|
||||
for k, v := range rawPositions.Map() {
|
||||
c.bedrockPromptCachePointPositions[k] = v.Bool()
|
||||
}
|
||||
}
|
||||
}
|
||||
c.minimaxApiType = json.Get("minimaxApiType").String()
|
||||
c.minimaxGroupId = json.Get("minimaxGroupId").String()
|
||||
@@ -632,6 +684,10 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
string(ApiNameImageVariation),
|
||||
string(ApiNameImageEdit),
|
||||
string(ApiNameAudioSpeech),
|
||||
string(ApiNameAudioTranscription),
|
||||
string(ApiNameAudioTranslation),
|
||||
string(ApiNameRealtime),
|
||||
string(ApiNameResponses),
|
||||
string(ApiNameCohereV1Rerank),
|
||||
string(ApiNameVideos),
|
||||
string(ApiNameRetrieveVideo),
|
||||
@@ -650,12 +706,23 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.vllmCustomUrl = json.Get("vllmCustomUrl").String()
|
||||
c.doubaoDomain = json.Get("doubaoDomain").String()
|
||||
c.claudeCodeMode = json.Get("claudeCodeMode").Bool()
|
||||
c.zhipuDomain = json.Get("zhipuDomain").String()
|
||||
c.zhipuCodePlanMode = json.Get("zhipuCodePlanMode").Bool()
|
||||
c.contextCleanupCommands = make([]string, 0)
|
||||
for _, cmd := range json.Get("contextCleanupCommands").Array() {
|
||||
if cmd.String() != "" {
|
||||
c.contextCleanupCommands = append(c.contextCleanupCommands, cmd.String())
|
||||
}
|
||||
}
|
||||
c.mergeConsecutiveMessages = json.Get("mergeConsecutiveMessages").Bool()
|
||||
c.providerDomain = json.Get("providerDomain").String()
|
||||
c.promoteThinkingOnEmpty = json.Get("promoteThinkingOnEmpty").Bool()
|
||||
c.hiclawMode = json.Get("hiclawMode").Bool()
|
||||
if c.hiclawMode {
|
||||
c.mergeConsecutiveMessages = true
|
||||
c.promoteThinkingOnEmpty = true
|
||||
}
|
||||
c.providerBasePath = json.Get("providerBasePath").String()
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) Validate() error {
|
||||
@@ -690,12 +757,45 @@ func (c *ProviderConfig) Validate() error {
|
||||
func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) string {
|
||||
ctxApiKey := ctx.GetContext(ctxKeyApiKey)
|
||||
if ctxApiKey == nil {
|
||||
ctxApiKey = c.GetRandomToken()
|
||||
token := c.selectApiToken(ctx)
|
||||
ctxApiKey = token
|
||||
ctx.SetContext(ctxKeyApiKey, ctxApiKey)
|
||||
}
|
||||
return ctxApiKey.(string)
|
||||
}
|
||||
|
||||
// selectApiToken selects an API token based on the request context
|
||||
// For stateful APIs, it uses consumer affinity if available
|
||||
func (c *ProviderConfig) selectApiToken(ctx wrapper.HttpContext) string {
|
||||
// Get API name from context if available
|
||||
ctxApiName := ctx.GetContext(CtxKeyApiName)
|
||||
var apiName string
|
||||
if ctxApiName != nil {
|
||||
// ctxApiName is of type ApiName, need to convert to string
|
||||
apiName = string(ctxApiName.(ApiName))
|
||||
}
|
||||
|
||||
// For stateful APIs, try to use consumer affinity
|
||||
if isStatefulAPI(apiName) {
|
||||
consumer := c.getConsumerFromContext(ctx)
|
||||
if consumer != "" {
|
||||
return c.GetTokenWithConsumerAffinity(ctx, consumer)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to random selection
|
||||
return c.GetRandomToken()
|
||||
}
|
||||
|
||||
// getConsumerFromContext retrieves the consumer identifier from the request context
|
||||
func (c *ProviderConfig) getConsumerFromContext(ctx wrapper.HttpContext) string {
|
||||
consumer, err := proxywasm.GetHttpRequestHeader("x-mse-consumer")
|
||||
if err == nil && consumer != "" {
|
||||
return consumer
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetRandomToken() string {
|
||||
apiTokens := c.apiTokens
|
||||
count := len(apiTokens)
|
||||
@@ -709,10 +809,62 @@ func (c *ProviderConfig) GetRandomToken() string {
|
||||
}
|
||||
}
|
||||
|
||||
// isStatefulAPI checks if the given API name is a stateful API that requires consumer affinity
|
||||
func isStatefulAPI(apiName string) bool {
	// These APIs maintain session state and should be routed to the same provider consistently
	statefulAPIs := map[string]bool{
		string(ApiNameResponses):                true, // Response API - uses previous_response_id
		string(ApiNameFiles):                    true, // Files API - maintains file state
		string(ApiNameRetrieveFile):             true, // File retrieval - depends on file upload
		string(ApiNameRetrieveFileContent):      true, // File content - depends on file upload
		string(ApiNameBatches):                  true, // Batch API - maintains batch state
		string(ApiNameRetrieveBatch):            true, // Batch status - depends on batch creation
		string(ApiNameCancelBatch):              true, // Batch operations - depend on batch state
		string(ApiNameFineTuningJobs):           true, // Fine-tuning - maintains job state
		string(ApiNameRetrieveFineTuningJob):    true, // Fine-tuning job status
		string(ApiNameFineTuningJobEvents):      true, // Fine-tuning events
		string(ApiNameFineTuningJobCheckpoints): true, // Fine-tuning checkpoints
		string(ApiNameCancelFineTuningJob):      true, // Cancel fine-tuning job
		string(ApiNameResumeFineTuningJob):      true, // Resume fine-tuning job
	}
	return statefulAPIs[apiName]
}

// GetTokenWithConsumerAffinity selects an API token based on consumer affinity.
// If the x-mse-consumer header is present and the API is stateful, it consistently selects the same token.
func (c *ProviderConfig) GetTokenWithConsumerAffinity(ctx wrapper.HttpContext, consumer string) string {
	apiTokens := c.apiTokens
	count := len(apiTokens)
	switch count {
	case 0:
		return ""
	case 1:
		return apiTokens[0]
	default:
		// Use FNV-1a hash for consistent token selection
		h := fnv.New32a()
		h.Write([]byte(consumer))
		hashValue := h.Sum32()
		index := int(hashValue) % count
		if index < 0 {
			index += count
		}
		return apiTokens[index]
	}
}
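The affinity rule above can be sketched as a standalone program. This is a minimal illustration, not the plugin's code: `pickToken` is a hypothetical helper that mirrors the FNV-1a selection, showing that the same consumer name always maps to the same token.

```go
package main

import (
	"fmt"
	"hash/fnv"
)

// pickToken mirrors the consumer-affinity selection: hash the consumer
// name with FNV-1a and index into the token list deterministically.
func pickToken(tokens []string, consumer string) string {
	if len(tokens) == 0 {
		return ""
	}
	if len(tokens) == 1 {
		return tokens[0]
	}
	h := fnv.New32a()
	h.Write([]byte(consumer))
	return tokens[int(h.Sum32())%len(tokens)]
}

func main() {
	tokens := []string{"token1", "token2", "token3"}
	// The same consumer always maps to the same token across calls.
	first := pickToken(tokens, "consumer-a")
	for i := 0; i < 5; i++ {
		if pickToken(tokens, "consumer-a") != first {
			panic("selection is not stable")
		}
	}
	fmt.Println("stable token for consumer-a:", first)
}
```

Because the hash depends only on the consumer string and the token count, no per-consumer state has to be stored to keep stateful API calls pinned to one token.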
func (c *ProviderConfig) IsOriginal() bool {
	return c.protocol == protocolOriginal
}

func (c *ProviderConfig) IsGeneric() bool {
	return c.typ == providerTypeGeneric
}

func (c *ProviderConfig) GetPromoteThinkingOnEmpty() bool {
	return c.promoteThinkingOnEmpty
}

func (c *ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) {
	return ReplaceByCustomSettings(body, c.customSettings)
}

@@ -725,6 +877,14 @@ func CreateProvider(pc ProviderConfig) (Provider, error) {
	return initializer.CreateProvider(pc)
}

// applyProviderBasePath prepends the ProviderBasePath to the given path if configured.
func (c *ProviderConfig) applyProviderBasePath(path string) string {
	if c.providerBasePath != "" && !strings.HasPrefix(path, c.providerBasePath) {
		return c.providerBasePath + path
	}
	return path
}
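The prefixing rule is small but has a subtle property worth noting: it is idempotent, because a path that already starts with the base path is left untouched. A minimal standalone sketch (`applyBasePath` is a free-function stand-in for the method above):

```go
package main

import (
	"fmt"
	"strings"
)

// applyBasePath prepends basePath unless the request path already carries it,
// so applying it twice yields the same result as applying it once.
func applyBasePath(basePath, path string) string {
	if basePath != "" && !strings.HasPrefix(path, basePath) {
		return basePath + path
	}
	return path
}

func main() {
	fmt.Println(applyBasePath("/api/ai", "/v1/chat/completions")) // /api/ai/v1/chat/completions
	fmt.Println(applyBasePath("/api/ai", "/api/ai/v1/models"))    // /api/ai/v1/models (unchanged)
	fmt.Println(applyBasePath("", "/v1/models"))                  // /v1/models (no base path configured)
}
```

Note that the prefix check is purely textual: a configured base path with a trailing slash (e.g. `/api/ai/`) produces a double slash when prepended, which the tests later in this diff document explicitly.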
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte) error {
	switch req := request.(type) {
	case *chatCompletionRequest:
@@ -751,6 +911,16 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
			return err
		}
		return c.setRequestModel(ctx, req)
	case *imageEditRequest:
		if err := decodeImageEditRequest(body, req); err != nil {
			return err
		}
		return c.setRequestModel(ctx, req)
	case *imageVariationRequest:
		if err := decodeImageVariationRequest(body, req); err != nil {
			return err
		}
		return c.setRequestModel(ctx, req)
	default:
		return errors.New("unsupported request type")
	}
@@ -766,6 +936,10 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf
		model = &req.Model
	case *imageGenerationRequest:
		model = &req.Model
	case *imageEditRequest:
		model = &req.Model
	case *imageVariationRequest:
		model = &req.Model
	default:
		return errors.New("unsupported request type")
	}
@@ -838,6 +1012,34 @@ func doGetMappedModel(model string, modelMapping map[string]string) string {
	return ""
}

// isDeveloperRoleSupported checks if the provider supports the "developer" role.
func isDeveloperRoleSupported(providerType string) bool {
	return developerRoleSupportedProviders[providerType]
}

// convertDeveloperRoleToSystem converts "developer" roles to "system" roles in the request body.
// This is used for providers that don't support the "developer" role.
func convertDeveloperRoleToSystem(body []byte) ([]byte, error) {
	request := &chatCompletionRequest{}
	if err := json.Unmarshal(body, request); err != nil {
		return body, fmt.Errorf("unable to unmarshal request for developer role conversion: %v", err)
	}

	converted := false
	for i := range request.Messages {
		if request.Messages[i].Role == roleDeveloper {
			request.Messages[i].Role = roleSystem
			converted = true
		}
	}

	if converted {
		return json.Marshal(request)
	}

	return body, nil
}
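The conversion pattern above — unmarshal, mutate only when needed, re-marshal only when something changed — can be shown in isolation. This is a simplified sketch, not the plugin's code: the real function unmarshals the full chat-completion request, while this toy `request` type keeps only the messages.

```go
package main

import (
	"encoding/json"
	"fmt"
)

type message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type request struct {
	Messages []message `json:"messages"`
}

// demoteDeveloperRole rewrites any "developer" role to "system" and
// re-marshals only when a message actually changed, so untouched bodies
// pass through byte-for-byte.
func demoteDeveloperRole(body []byte) ([]byte, error) {
	var req request
	if err := json.Unmarshal(body, &req); err != nil {
		return body, err
	}
	changed := false
	for i := range req.Messages {
		if req.Messages[i].Role == "developer" {
			req.Messages[i].Role = "system"
			changed = true
		}
	}
	if !changed {
		return body, nil
	}
	return json.Marshal(req)
}

func main() {
	in := []byte(`{"messages":[{"role":"developer","content":"be terse"}]}`)
	out, _ := demoteDeveloperRole(in)
	fmt.Println(string(out)) // {"messages":[{"role":"system","content":"be terse"}]}
}
```

Returning the original slice when nothing changed avoids a redundant re-serialization, which matters in a WASM filter that runs on every request.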
func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte) []StreamEvent {
	body := chunk
	if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
@@ -898,7 +1100,7 @@ func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte) []StreamEvent
		if lineStartIndex != -1 {
			value := string(body[valueStartIndex:i])
			currentEvent.SetValue(currentKey, value)
		} else {
		} else if eventStartIndex != -1 {
			currentEvent.RawEvent = string(body[eventStartIndex : i+1])
			// Extra new line. The current event is complete.
			events = append(events, *currentEvent)
@@ -957,6 +1159,21 @@ func (c *ProviderConfig) handleRequestBody(
	// If main.go detected a Claude request that needs conversion, convert the body
	needClaudeConversion, _ := ctx.GetContext("needClaudeResponseConversion").(bool)
	if needClaudeConversion {
		// Extract thinking config from original Claude body before conversion,
		// so downstream providers (OpenRouter, ZhipuAI) can access it.
		thinkingType := gjson.GetBytes(body, "thinking.type").String()
		if thinkingType == "" {
			// Claude request had no thinking field at all - treat as disabled
			thinkingType = "disabled"
		}
		ctx.SetContext(ctxKeyClaudeThinkingType, thinkingType)
		// Only extract budget_tokens when thinking is explicitly enabled
		if thinkingType == "enabled" {
			if budgetTokens := gjson.GetBytes(body, "thinking.budget_tokens").Int(); budgetTokens > 0 {
				ctx.SetContext(ctxKeyClaudeBudgetTokens, int(budgetTokens))
			}
		}

		// Convert Claude protocol to OpenAI protocol
		converter := &ClaudeToOpenAIConverter{}
		body, err = converter.ConvertClaudeRequestToOpenAI(body)
@@ -976,12 +1193,39 @@ func (c *ProviderConfig) handleRequestBody(
		}
	}

	// merge consecutive same-role messages for providers that require strict role alternation
	if apiName == ApiNameChatCompletion && c.mergeConsecutiveMessages {
		body, err = mergeConsecutiveMessages(body)
		if err != nil {
			log.Warnf("[mergeConsecutiveMessages] failed to merge messages: %v", err)
			err = nil
		} else {
			log.Debugf("[mergeConsecutiveMessages] merged consecutive messages for provider: %s", c.typ)
		}
	}

	// convert developer role to system role for providers that don't support it
	if apiName == ApiNameChatCompletion && !isDeveloperRoleSupported(c.typ) {
		body, err = convertDeveloperRoleToSystem(body)
		if err != nil {
			log.Warnf("[developerRole] failed to convert developer role to system: %v", err)
			// Continue processing even if conversion fails
			err = nil
		} else {
			log.Debugf("[developerRole] converted developer role to system for provider: %s", c.typ)
		}
	}

	// use openai protocol (either original openai or converted from claude)
	if handler, ok := provider.(TransformRequestBodyHandler); ok {
		body, err = handler.TransformRequestBody(ctx, apiName, body)
	} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
		headers := util.GetRequestHeaders()
		body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers)
		// Apply providerBasePath if configured
		if c.providerBasePath != "" {
			headers.Set(":path", c.applyProviderBasePath(headers.Get(":path")))
		}
		util.ReplaceRequestHeaders(headers)
	} else {
		body, err = c.defaultTransformRequestBody(ctx, apiName, body)
@@ -1038,11 +1282,27 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt
	if c.basePath != "" && c.basePathHandling == basePathHandlingPrepend && !strings.HasPrefix(headers.Get(":path"), c.basePath) {
		headers.Set(":path", path.Join(c.basePath, headers.Get(":path")))
	}

	// Apply providerBasePath if configured
	currentPath := headers.Get(":path")
	if c.providerBasePath != "" {
		headers.Set(":path", c.applyProviderBasePath(currentPath))
	}

	// Apply providerDomain if configured (overrides any domain set by the provider)
	if c.providerDomain != "" {
		util.OverwriteRequestHostHeader(headers, c.providerDomain)
	}

	util.ReplaceRequestHeaders(headers)
}

// defaultTransformRequestBody is the default request-body transformation: it only performs model
// mapping, replacing the model name via sjson without full serialization/deserialization, for better performance.
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
	if contentType, err := proxywasm.GetHttpRequestHeader(util.HeaderContentType); err == nil && isMultipartFormData(contentType) {
		return c.defaultTransformMultipartRequestBody(ctx, apiName, body, contentType)
	}

	switch apiName {
	case ApiNameChatCompletion,
		ApiNameVideos,
@@ -1062,6 +1322,28 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap
	return sjson.SetBytes(body, "model", mappedModel)
}

func (c *ProviderConfig) defaultTransformMultipartRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, contentType string) ([]byte, error) {
	if apiName != ApiNameImageEdit && apiName != ApiNameImageVariation {
		return body, nil
	}

	model, err := extractMultipartModel(body, contentType)
	if err != nil {
		return nil, err
	}

	ctx.SetContext(ctxKeyOriginalRequestModel, model)

	mappedModel := getMappedModel(model, c.modelMapping)
	ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)

	if mappedModel == model || (mappedModel == "" && model == "") {
		return body, nil
	}

	return rewriteMultipartFormModel(body, contentType, mappedModel)
}

func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {
	if c.protocol == protocolOriginal {
		ctx.DontReadResponseBody()
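The default transform's model mapping can be illustrated with a small sketch. This is not the plugin's code: the real implementation patches the `model` field in place with gjson/sjson to avoid re-serializing the body, while this dependency-free version uses stdlib JSON, and the `"*"` wildcard fallback is an assumed semantic for the mapping table.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// mapModel reads the request's "model" field, looks it up in the mapping
// (falling back to a "*" wildcard entry - assumed semantics), and writes
// the mapped name back into the body.
func mapModel(body []byte, mapping map[string]string) ([]byte, error) {
	var req map[string]any
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, err
	}
	model, _ := req["model"].(string)
	mapped, ok := mapping[model]
	if !ok {
		mapped, ok = mapping["*"] // wildcard fallback (assumption)
	}
	if ok && mapped != "" {
		req["model"] = mapped
	}
	return json.Marshal(req)
}

func main() {
	body := []byte(`{"model":"gpt-4o","stream":true}`)
	out, _ := mapModel(body, map[string]string{"gpt-4o": "qwen-max"})
	fmt.Println(string(out)) // {"model":"qwen-max","stream":true}
}
```

Patching one field with sjson instead of a full unmarshal/marshal round trip is the performance point the comment on `defaultTransformRequestBody` is making.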
680
plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go
Normal file
@@ -0,0 +1,680 @@
package provider

import (
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/tidwall/gjson"
)

func TestIsStatefulAPI(t *testing.T) {
	tests := []struct {
		name     string
		apiName  string
		expected bool
	}{
		// Stateful APIs - should return true
		{name: "responses_api", apiName: string(ApiNameResponses), expected: true},
		{name: "files_api", apiName: string(ApiNameFiles), expected: true},
		{name: "retrieve_file_api", apiName: string(ApiNameRetrieveFile), expected: true},
		{name: "retrieve_file_content_api", apiName: string(ApiNameRetrieveFileContent), expected: true},
		{name: "batches_api", apiName: string(ApiNameBatches), expected: true},
		{name: "retrieve_batch_api", apiName: string(ApiNameRetrieveBatch), expected: true},
		{name: "cancel_batch_api", apiName: string(ApiNameCancelBatch), expected: true},
		{name: "fine_tuning_jobs_api", apiName: string(ApiNameFineTuningJobs), expected: true},
		{name: "retrieve_fine_tuning_job_api", apiName: string(ApiNameRetrieveFineTuningJob), expected: true},
		{name: "fine_tuning_job_events_api", apiName: string(ApiNameFineTuningJobEvents), expected: true},
		{name: "fine_tuning_job_checkpoints_api", apiName: string(ApiNameFineTuningJobCheckpoints), expected: true},
		{name: "cancel_fine_tuning_job_api", apiName: string(ApiNameCancelFineTuningJob), expected: true},
		{name: "resume_fine_tuning_job_api", apiName: string(ApiNameResumeFineTuningJob), expected: true},
		// Non-stateful APIs - should return false
		{name: "chat_completion_api", apiName: string(ApiNameChatCompletion), expected: false},
		{name: "completion_api", apiName: string(ApiNameCompletion), expected: false},
		{name: "embeddings_api", apiName: string(ApiNameEmbeddings), expected: false},
		{name: "models_api", apiName: string(ApiNameModels), expected: false},
		{name: "image_generation_api", apiName: string(ApiNameImageGeneration), expected: false},
		{name: "audio_speech_api", apiName: string(ApiNameAudioSpeech), expected: false},
		// Empty/unknown API - should return false
		{name: "empty_api_name", apiName: "", expected: false},
		{name: "unknown_api_name", apiName: "unknown/api", expected: false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := isStatefulAPI(tt.apiName)
			assert.Equal(t, tt.expected, result)
		})
	}
}

func TestGetTokenWithConsumerAffinity(t *testing.T) {
	tests := []struct {
		name      string
		apiTokens []string
		consumer  string
		wantEmpty bool
		wantToken string // If not empty, the expected specific token (for the single-token case)
	}{
		{name: "no_tokens_returns_empty", apiTokens: []string{}, consumer: "consumer1", wantEmpty: true},
		{name: "nil_tokens_returns_empty", apiTokens: nil, consumer: "consumer1", wantEmpty: true},
		{name: "single_token_always_returns_same_token", apiTokens: []string{"token1"}, consumer: "consumer1", wantToken: "token1"},
		{name: "single_token_with_different_consumer", apiTokens: []string{"token1"}, consumer: "consumer2", wantToken: "token1"},
		// Will get one of the tokens, consistently
		{name: "multiple_tokens_consistent_for_same_consumer", apiTokens: []string{"token1", "token2", "token3"}, consumer: "consumer1", wantEmpty: false},
		{name: "multiple_tokens_different_consumers_may_get_different_tokens", apiTokens: []string{"token1", "token2"}, consumer: "consumerA", wantEmpty: false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			config := &ProviderConfig{
				apiTokens: tt.apiTokens,
			}

			result := config.GetTokenWithConsumerAffinity(nil, tt.consumer)

			if tt.wantEmpty {
				assert.Empty(t, result)
			} else if tt.wantToken != "" {
				assert.Equal(t, tt.wantToken, result)
			} else {
				assert.NotEmpty(t, result)
				assert.Contains(t, tt.apiTokens, result)
			}
		})
	}
}

func TestGetTokenWithConsumerAffinity_Consistency(t *testing.T) {
	// Test that the same consumer always gets the same token (consistency)
	config := &ProviderConfig{
		apiTokens: []string{"token1", "token2", "token3", "token4", "token5"},
	}

	t.Run("same_consumer_gets_same_token_repeatedly", func(t *testing.T) {
		consumer := "test-consumer"
		var firstResult string

		// Call multiple times and verify consistency
		for i := 0; i < 10; i++ {
			result := config.GetTokenWithConsumerAffinity(nil, consumer)
			if i == 0 {
				firstResult = result
			}
			assert.Equal(t, firstResult, result, "Consumer should consistently get the same token")
		}
	})

	t.Run("different_consumers_distribute_across_tokens", func(t *testing.T) {
		// Use multiple consumers and verify they distribute across tokens
		consumers := []string{"consumer1", "consumer2", "consumer3", "consumer4", "consumer5", "consumer6", "consumer7", "consumer8", "consumer9", "consumer10"}
		tokenCounts := make(map[string]int)

		for _, consumer := range consumers {
			token := config.GetTokenWithConsumerAffinity(nil, consumer)
			tokenCounts[token]++
		}

		// Verify all tokens returned are valid
		for token := range tokenCounts {
			assert.Contains(t, config.apiTokens, token)
		}

		// With 10 consumers and 5 tokens, we expect some distribution
		// (not necessarily perfect distribution, but should use multiple tokens)
		assert.GreaterOrEqual(t, len(tokenCounts), 2, "Should use at least 2 different tokens")
	})

	t.Run("empty_consumer_still_returns_a_token", func(t *testing.T) {
		config := &ProviderConfig{
			apiTokens: []string{"token1", "token2"},
		}
		result := config.GetTokenWithConsumerAffinity(nil, "")
		// An empty consumer still returns a token (hash of the empty string)
		assert.NotEmpty(t, result)
		assert.Contains(t, []string{"token1", "token2"}, result)
	})
}

func TestGetTokenWithConsumerAffinity_HashDistribution(t *testing.T) {
	// Test that the hash function distributes consumers reasonably across tokens
	config := &ProviderConfig{
		apiTokens: []string{"token1", "token2", "token3"},
	}

	// Test specific consumers to verify hash behavior
	testCases := []struct {
		consumer    string
		expectValid bool
	}{
		{"user-alice", true},
		{"user-bob", true},
		{"user-charlie", true},
		{"service-api-v1", true},
		{"service-api-v2", true},
	}

	for _, tc := range testCases {
		t.Run("consumer_"+tc.consumer, func(t *testing.T) {
			result := config.GetTokenWithConsumerAffinity(nil, tc.consumer)
			assert.True(t, tc.expectValid)
			assert.Contains(t, config.apiTokens, result)
		})
	}
}

func TestProviderDomain_Config(t *testing.T) {
	t.Run("providerDomain_field_exists", func(t *testing.T) {
		config := ProviderConfig{}
		config.FromJson(gjson.Result{})
		assert.Equal(t, "", config.providerDomain)
	})

	t.Run("providerDomain_parsed_from_json", func(t *testing.T) {
		config := ProviderConfig{}
		jsonStr := `{"providerDomain": "universal-proxy.example.com"}`
		config.FromJson(gjson.Parse(jsonStr))
		assert.Equal(t, "universal-proxy.example.com", config.providerDomain)
	})
}

func TestProviderBasePath_Config(t *testing.T) {
	t.Run("providerBasePath_field_exists", func(t *testing.T) {
		config := ProviderConfig{}
		config.FromJson(gjson.Result{})
		assert.Equal(t, "", config.providerBasePath)
	})

	t.Run("providerBasePath_parsed_from_json", func(t *testing.T) {
		config := ProviderConfig{}
		jsonStr := `{"providerBasePath": "/api/ai"}`
		config.FromJson(gjson.Parse(jsonStr))
		assert.Equal(t, "/api/ai", config.providerBasePath)
	})

	t.Run("providerBasePath_with_other_config", func(t *testing.T) {
		config := ProviderConfig{}
		jsonStr := `{
			"type": "openai",
			"apiToken": "sk-test",
			"providerBasePath": "/api/v1",
			"providerDomain": "proxy.example.com"
		}`
		config.FromJson(gjson.Parse(jsonStr))
		assert.Equal(t, "openai", config.typ)
		assert.Equal(t, "/api/v1", config.providerBasePath)
		assert.Equal(t, "proxy.example.com", config.providerDomain)
	})
}

func TestApplyProviderBasePath(t *testing.T) {
	tests := []struct {
		name             string
		providerBasePath string
		originalPath     string
		expectedPath     string
	}{
		{name: "no_base_path_configured", providerBasePath: "", originalPath: "/v1/chat/completions", expectedPath: "/v1/chat/completions"},
		{name: "base_path_prepended", providerBasePath: "/api/ai", originalPath: "/v1/chat/completions", expectedPath: "/api/ai/v1/chat/completions"},
		{name: "path_already_has_base_path", providerBasePath: "/api/ai", originalPath: "/api/ai/v1/chat/completions", expectedPath: "/api/ai/v1/chat/completions"},
		{name: "base_path_with_trailing_slash", providerBasePath: "/api/ai/", originalPath: "/v1/chat/completions", expectedPath: "/api/ai//v1/chat/completions"},
		{name: "deep_base_path", providerBasePath: "/internal/services/ai", originalPath: "/v1/models", expectedPath: "/internal/services/ai/v1/models"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			config := &ProviderConfig{
				providerBasePath: tt.providerBasePath,
			}
			result := config.applyProviderBasePath(tt.originalPath)
			assert.Equal(t, tt.expectedPath, result)
		})
	}
}

func TestHandleRequestHeaders_PathHandling(t *testing.T) {
	// This test verifies the path handling logic in handleRequestHeaders,
	// including basePathHandling and providerBasePath

	t.Run("basePath_removePrefix_only", func(t *testing.T) {
		config := &ProviderConfig{
			basePath:         "/gateway",
			basePathHandling: basePathHandlingRemovePrefix,
		}
		// Simulate the logic - an actual test would need a mock provider
		originPath := "/gateway/v1/chat"
		expectedPath := "/v1/chat"
		result := strings.TrimPrefix(originPath, config.basePath)
		assert.Equal(t, expectedPath, result)
	})

	t.Run("basePath_prepend_only", func(t *testing.T) {
		config := &ProviderConfig{
			basePath:         "/api",
			basePathHandling: basePathHandlingPrepend,
		}
		currentPath := "/v1/chat"
		// basePath prepend + providerBasePath (not set) = just the basePath effect.
		// Note: applyProviderBasePath only handles providerBasePath, not basePath,
		// so this test just verifies that applyProviderBasePath doesn't modify the path when providerBasePath is empty.
		expectedPath := "/v1/chat" // applyProviderBasePath doesn't change the path without providerBasePath configured
		result := config.applyProviderBasePath(currentPath)
		assert.Equal(t, expectedPath, result)
	})

	t.Run("providerBasePath_only", func(t *testing.T) {
		config := &ProviderConfig{
			providerBasePath: "/ai-proxy",
		}
		currentPath := "/v1/chat"
		expectedPath := "/ai-proxy/v1/chat"
		result := config.applyProviderBasePath(currentPath)
		assert.Equal(t, expectedPath, result)
	})

	t.Run("both_basePath_and_providerBasePath", func(t *testing.T) {
		config := &ProviderConfig{
			basePath:         "/gateway",
			basePathHandling: basePathHandlingRemovePrefix,
			providerBasePath: "/ai",
		}
		// First removePrefix, then apply providerBasePath
		originPath := "/gateway/v1/chat"
		afterRemovePrefix := strings.TrimPrefix(originPath, config.basePath)
		finalPath := config.applyProviderBasePath(afterRemovePrefix)
		assert.Equal(t, "/ai/v1/chat", finalPath)
	})
}

func TestProviderConfig_IsOriginal(t *testing.T) {
	tests := []struct {
		name     string
		protocol string
		expected bool
	}{
		{"openai_protocol", protocolOpenAI, false},
		{"original_protocol", protocolOriginal, true},
		{"empty_protocol", "", false},
		{"unknown_protocol", "unknown", false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			config := &ProviderConfig{
				protocol: tt.protocol,
			}
			result := config.IsOriginal()
			assert.Equal(t, tt.expected, result)
		})
	}
}

func TestProviderConfig_GetPromoteThinkingOnEmpty(t *testing.T) {
	tests := []struct {
		name                   string
		promoteThinkingOnEmpty bool
		expected               bool
	}{
		{"enabled", true, true},
		{"disabled", false, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			config := &ProviderConfig{
				promoteThinkingOnEmpty: tt.promoteThinkingOnEmpty,
			}
			result := config.GetPromoteThinkingOnEmpty()
			assert.Equal(t, tt.expected, result)
		})
	}
}

// ============ Failover Tests ============

func TestFailover_FromJson_Defaults(t *testing.T) {
	t.Run("default_failure_threshold", func(t *testing.T) {
		f := &failover{}
		jsonStr := `{"enabled": true}`
		f.FromJson(gjson.Parse(jsonStr))
		assert.Equal(t, int64(3), f.failureThreshold)
	})

	t.Run("default_success_threshold", func(t *testing.T) {
		f := &failover{}
		jsonStr := `{"enabled": true}`
		f.FromJson(gjson.Parse(jsonStr))
		assert.Equal(t, int64(1), f.successThreshold)
	})

	t.Run("default_health_check_interval", func(t *testing.T) {
		f := &failover{}
		jsonStr := `{"enabled": true}`
		f.FromJson(gjson.Parse(jsonStr))
		assert.Equal(t, int64(5000), f.healthCheckInterval)
	})

	t.Run("default_health_check_timeout", func(t *testing.T) {
		f := &failover{}
		jsonStr := `{"enabled": true}`
		f.FromJson(gjson.Parse(jsonStr))
		assert.Equal(t, int64(5000), f.healthCheckTimeout)
	})

	t.Run("custom_values", func(t *testing.T) {
		f := &failover{}
		jsonStr := `{
			"enabled": true,
			"failureThreshold": 5,
			"successThreshold": 3,
			"healthCheckInterval": 10000,
			"healthCheckTimeout": 8000,
			"healthCheckModel": "test-model"
		}`
		f.FromJson(gjson.Parse(jsonStr))
		assert.Equal(t, true, f.enabled)
		assert.Equal(t, int64(5), f.failureThreshold)
		assert.Equal(t, int64(3), f.successThreshold)
		assert.Equal(t, int64(10000), f.healthCheckInterval)
		assert.Equal(t, int64(8000), f.healthCheckTimeout)
		assert.Equal(t, "test-model", f.healthCheckModel)
	})
}

func TestFailover_FromJson_FailoverOnStatus(t *testing.T) {
	t.Run("parse_failoverOnStatus_array", func(t *testing.T) {
		f := &failover{}
		jsonStr := `{
			"enabled": true,
			"failoverOnStatus": ["401", "403", "5[0-9][0-9]"]
		}`
		f.FromJson(gjson.Parse(jsonStr))
		assert.Equal(t, 3, len(f.failoverOnStatus))
		assert.Contains(t, f.failoverOnStatus, "401")
		assert.Contains(t, f.failoverOnStatus, "403")
		assert.Contains(t, f.failoverOnStatus, "5[0-9][0-9]")
	})

	t.Run("empty_failoverOnStatus", func(t *testing.T) {
		f := &failover{}
		jsonStr := `{"enabled": true}`
		f.FromJson(gjson.Parse(jsonStr))
		// When failoverOnStatus is not specified, it keeps default values.
		// Default regex patterns may be set elsewhere.
		assert.True(t, f.enabled)
		assert.Equal(t, int64(3), f.failureThreshold)
	})
}

func TestHealthCheckEndpoint_Struct(t *testing.T) {
	t.Run("health_check_endpoint_fields", func(t *testing.T) {
		endpoint := HealthCheckEndpoint{
			Host:    "api.example.com",
			Path:    "/v1/chat/completions",
			Cluster: "ai-provider-cluster",
		}
		assert.Equal(t, "api.example.com", endpoint.Host)
		assert.Equal(t, "/v1/chat/completions", endpoint.Path)
		assert.Equal(t, "ai-provider-cluster", endpoint.Cluster)
	})
}

func TestLease_Struct(t *testing.T) {
	t.Run("lease_fields", func(t *testing.T) {
		lease := Lease{
			VMID:      "vm-12345",
			Timestamp: 1234567890,
		}
		assert.Equal(t, "vm-12345", lease.VMID)
		assert.Equal(t, int64(1234567890), lease.Timestamp)
	})
}

func TestFailover_Constants(t *testing.T) {
	t.Run("cas_max_retries_value", func(t *testing.T) {
		assert.Equal(t, 10, casMaxRetries)
	})

	t.Run("operation_constants", func(t *testing.T) {
		assert.Equal(t, "addApiToken", addApiTokenOperation)
		assert.Equal(t, "removeApiToken", removeApiTokenOperation)
		assert.Equal(t, "addApiTokenRequestCount", addApiTokenRequestCountOperation)
		assert.Equal(t, "resetApiTokenRequestCount", resetApiTokenRequestCountOperation)
	})

	t.Run("context_key_constants", func(t *testing.T) {
		assert.Equal(t, "requestHost", CtxRequestHost)
		assert.Equal(t, "requestPath", CtxRequestPath)
		assert.Equal(t, "requestBody", CtxRequestBody)
	})
}

func TestProviderConfig_TransformRequestHeadersAndBody_PathHandling(t *testing.T) {
	// Test that providerBasePath is applied in transformRequestHeadersAndBody
	t.Run("providerBasePath_applied", func(t *testing.T) {
		config := &ProviderConfig{
			providerBasePath: "/api/ai",
		}

		// Test the applyProviderBasePath logic used in transformRequestHeadersAndBody
		testPath := "/v1/chat/completions"
		expectedPath := "/api/ai/v1/chat/completions"
		result := config.applyProviderBasePath(testPath)
		assert.Equal(t, expectedPath, result)
	})

	t.Run("providerBasePath_already_present", func(t *testing.T) {
		config := &ProviderConfig{
			providerBasePath: "/api/ai",
		}

		testPath := "/api/ai/v1/chat/completions"
		result := config.applyProviderBasePath(testPath)
		// Should not duplicate the prefix
		assert.Equal(t, "/api/ai/v1/chat/completions", result)
	})
}

func TestProviderConfig_IsSupportedAPI(t *testing.T) {
	t.Run("supported_api", func(t *testing.T) {
		config := &ProviderConfig{
			capabilities: map[string]string{
				string(ApiNameChatCompletion): "/v1/chat/completions",
				string(ApiNameEmbeddings):     "/v1/embeddings",
			},
		}

		result := config.IsSupportedAPI(ApiNameChatCompletion)
		assert.True(t, result)
	})

	t.Run("unsupported_api", func(t *testing.T) {
		config := &ProviderConfig{
			capabilities: map[string]string{
				string(ApiNameChatCompletion): "/v1/chat/completions",
			},
		}

		result := config.IsSupportedAPI(ApiNameEmbeddings)
		assert.False(t, result)
	})

	t.Run("empty_capabilities", func(t *testing.T) {
		config := &ProviderConfig{
			capabilities: map[string]string{},
		}

		result := config.IsSupportedAPI(ApiNameChatCompletion)
		assert.False(t, result)
	})
}

func TestProviderConfig_SetDefaultCapabilities(t *testing.T) {
	t.Run("set_when_nil", func(t *testing.T) {
		config := &ProviderConfig{
			capabilities: nil,
		}

		defaultCaps := map[string]string{
			string(ApiNameChatCompletion): "/v1/chat/completions",
		}
		config.setDefaultCapabilities(defaultCaps)

		assert.NotNil(t, config.capabilities)
		assert.Equal(t, "/v1/chat/completions", config.capabilities[string(ApiNameChatCompletion)])
	})

	t.Run("merge_with_existing", func(t *testing.T) {
		config := &ProviderConfig{
			capabilities: map[string]string{
				string(ApiNameEmbeddings): "/v1/embeddings",
			},
		}

		defaultCaps := map[string]string{
			string(ApiNameChatCompletion): "/v1/chat/completions",
		}
		config.setDefaultCapabilities(defaultCaps)

		assert.Equal(t, "/v1/embeddings", config.capabilities[string(ApiNameEmbeddings)])
		assert.Equal(t, "/v1/chat/completions", config.capabilities[string(ApiNameChatCompletion)])
	})
}
@@ -30,6 +30,7 @@ const (
    qwenCompatibleChatCompletionPath      = "/compatible-mode/v1/chat/completions"
    qwenCompatibleCompletionsPath         = "/compatible-mode/v1/completions"
    qwenCompatibleTextEmbeddingPath       = "/compatible-mode/v1/embeddings"
    qwenCompatibleResponsesPath           = "/api/v2/apps/protocols/compatible-mode/v1/responses"
    qwenCompatibleFilesPath               = "/compatible-mode/v1/files"
    qwenCompatibleRetrieveFilePath        = "/compatible-mode/v1/files/{file_id}"
    qwenCompatibleRetrieveFileContentPath = "/compatible-mode/v1/files/{file_id}/content"
@@ -37,7 +38,7 @@ const (
    qwenCompatibleRetrieveBatchPath = "/compatible-mode/v1/batches/{batch_id}"
    qwenBailianPath                 = "/api/v1/apps"
    qwenMultimodalGenerationPath    = "/api/v1/services/aigc/multimodal-generation/generation"
    qwenAnthropicMessagesPath       = "/api/v2/apps/claude-code-proxy/v1/messages"
    qwenAnthropicMessagesPath       = "/apps/anthropic/v1/messages"

    qwenAsyncAIGCPath = "/api/v1/services/aigc/"
    qwenAsyncTaskPath = "/api/v1/tasks/"
@@ -69,6 +70,7 @@ func (m *qwenProviderInitializer) DefaultCapabilities(qwenEnableCompatible bool)
        string(ApiNameChatCompletion):      qwenCompatibleChatCompletionPath,
        string(ApiNameEmbeddings):          qwenCompatibleTextEmbeddingPath,
        string(ApiNameCompletion):          qwenCompatibleCompletionsPath,
        string(ApiNameResponses):           qwenCompatibleResponsesPath,
        string(ApiNameFiles):               qwenCompatibleFilesPath,
        string(ApiNameRetrieveFile):        qwenCompatibleRetrieveFilePath,
        string(ApiNameRetrieveFileContent): qwenCompatibleRetrieveFileContentPath,
@@ -707,6 +709,8 @@ func (m *qwenProvider) GetApiName(path string) ApiName {
    case strings.Contains(path, qwenTextEmbeddingPath),
        strings.Contains(path, qwenCompatibleTextEmbeddingPath):
        return ApiNameEmbeddings
    case strings.Contains(path, qwenCompatibleResponsesPath):
        return ApiNameResponses
    case strings.Contains(path, qwenAsyncAIGCPath):
        return ApiNameQwenAsyncAIGC
    case strings.Contains(path, qwenAsyncTaskPath):
@@ -4,8 +4,8 @@ import (
    "encoding/json"
    "fmt"

    "github.com/higress-group/wasm-go/pkg/log"
    "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
    "github.com/higress-group/wasm-go/pkg/log"
)

func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) error {
@@ -32,6 +32,20 @@ func decodeImageGenerationRequest(body []byte, request *imageGenerationRequest)
    return nil
}

func decodeImageEditRequest(body []byte, request *imageEditRequest) error {
    if err := json.Unmarshal(body, request); err != nil {
        return fmt.Errorf("unable to unmarshal request: %v", err)
    }
    return nil
}

func decodeImageVariationRequest(body []byte, request *imageVariationRequest) error {
    if err := json.Unmarshal(body, request); err != nil {
        return fmt.Errorf("unable to unmarshal request: %v", err)
    }
    return nil
}

func replaceJsonRequestBody(request interface{}) error {
    body, err := json.Marshal(request)
    if err != nil {
@@ -140,6 +154,54 @@ func cleanupContextMessages(body []byte, cleanupCommands []string) ([]byte, erro
    return json.Marshal(request)
}

// mergeConsecutiveMessages merges consecutive messages of the same role (user or assistant).
// Many LLM providers require strict user↔assistant alternation and reject requests where
// two messages of the same role appear consecutively. When enabled, consecutive same-role
// messages have their content concatenated into a single message.
func mergeConsecutiveMessages(body []byte) ([]byte, error) {
    request := &chatCompletionRequest{}
    if err := json.Unmarshal(body, request); err != nil {
        return body, fmt.Errorf("unable to unmarshal request for message merging: %v", err)
    }
    if len(request.Messages) <= 1 {
        return body, nil
    }

    merged := false
    result := make([]chatMessage, 0, len(request.Messages))
    for _, msg := range request.Messages {
        if len(result) > 0 &&
            result[len(result)-1].Role == msg.Role &&
            (msg.Role == roleUser || msg.Role == roleAssistant) {
            last := &result[len(result)-1]
            last.Content = mergeMessageContent(last.Content, msg.Content)
            merged = true
            continue
        }
        result = append(result, msg)
    }

    if !merged {
        return body, nil
    }
    request.Messages = result
    return json.Marshal(request)
}

// mergeMessageContent concatenates two message content values.
// If both are plain strings they are joined with a blank line.
// Otherwise both are converted to content-block arrays and concatenated.
func mergeMessageContent(prev, curr any) any {
    prevStr, prevIsStr := prev.(string)
    currStr, currIsStr := curr.(string)
    if prevIsStr && currIsStr {
        return prevStr + "\n\n" + currStr
    }
    prevParts := (&chatMessage{Content: prev}).ParseContent()
    currParts := (&chatMessage{Content: curr}).ParseContent()
    return append(prevParts, currParts...)
}

func ReplaceResponseBody(body []byte) error {
    log.Debugf("response body: %s", string(body))
    err := proxywasm.ReplaceHttpResponseBody(body)
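Stripped of the wasm-plugin types, the merging rule in `mergeConsecutiveMessages` above can be sketched as a standalone program. The `Msg` type and `mergeConsecutive` helper here are illustrative stand-ins, not the plugin's actual API:

```go
package main

import "fmt"

// Msg is a minimal stand-in for the plugin's chatMessage type.
type Msg struct {
	Role    string
	Content string
}

// mergeConsecutive joins consecutive user/assistant messages of the same
// role with a blank line, mirroring the loop in mergeConsecutiveMessages.
func mergeConsecutive(msgs []Msg) []Msg {
	out := make([]Msg, 0, len(msgs))
	for _, m := range msgs {
		if len(out) > 0 && out[len(out)-1].Role == m.Role &&
			(m.Role == "user" || m.Role == "assistant") {
			out[len(out)-1].Content += "\n\n" + m.Content
			continue
		}
		out = append(out, m)
	}
	return out
}

func main() {
	msgs := []Msg{
		{Role: "user", Content: "A"},
		{Role: "user", Content: "B"},
		{Role: "assistant", Content: "ok"},
	}
	merged := mergeConsecutive(msgs)
	fmt.Println(len(merged)) // prints 2: the two user messages collapse into one
}
```

This mirrors why the real implementation can return the original body untouched when no merge happened: the loop only rewrites the slice when it actually collapses a pair.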
@@ -8,6 +8,131 @@ import (
    "github.com/stretchr/testify/require"
)

func TestMergeConsecutiveMessages(t *testing.T) {
    t.Run("no_consecutive_messages", func(t *testing.T) {
        input := chatCompletionRequest{
            Messages: []chatMessage{
                {Role: "user", Content: "你好"},
                {Role: "assistant", Content: "你好!"},
                {Role: "user", Content: "再见"},
            },
        }
        body, err := json.Marshal(input)
        require.NoError(t, err)

        result, err := mergeConsecutiveMessages(body)
        assert.NoError(t, err)
        // No merging needed, returned body should be identical
        assert.Equal(t, body, result)
    })

    t.Run("merges_consecutive_user_messages", func(t *testing.T) {
        input := chatCompletionRequest{
            Messages: []chatMessage{
                {Role: "user", Content: "第一条"},
                {Role: "user", Content: "第二条"},
                {Role: "assistant", Content: "回复"},
            },
        }
        body, err := json.Marshal(input)
        require.NoError(t, err)

        result, err := mergeConsecutiveMessages(body)
        assert.NoError(t, err)

        var output chatCompletionRequest
        require.NoError(t, json.Unmarshal(result, &output))

        assert.Len(t, output.Messages, 2)
        assert.Equal(t, "user", output.Messages[0].Role)
        assert.Equal(t, "第一条\n\n第二条", output.Messages[0].Content)
        assert.Equal(t, "assistant", output.Messages[1].Role)
    })

    t.Run("merges_consecutive_assistant_messages", func(t *testing.T) {
        input := chatCompletionRequest{
            Messages: []chatMessage{
                {Role: "user", Content: "问题"},
                {Role: "assistant", Content: "第一段"},
                {Role: "assistant", Content: "第二段"},
            },
        }
        body, err := json.Marshal(input)
        require.NoError(t, err)

        result, err := mergeConsecutiveMessages(body)
        assert.NoError(t, err)

        var output chatCompletionRequest
        require.NoError(t, json.Unmarshal(result, &output))

        assert.Len(t, output.Messages, 2)
        assert.Equal(t, "user", output.Messages[0].Role)
        assert.Equal(t, "assistant", output.Messages[1].Role)
        assert.Equal(t, "第一段\n\n第二段", output.Messages[1].Content)
    })

    t.Run("merges_multiple_consecutive_same_role", func(t *testing.T) {
        input := chatCompletionRequest{
            Messages: []chatMessage{
                {Role: "user", Content: "A"},
                {Role: "user", Content: "B"},
                {Role: "user", Content: "C"},
                {Role: "assistant", Content: "回复"},
            },
        }
        body, err := json.Marshal(input)
        require.NoError(t, err)

        result, err := mergeConsecutiveMessages(body)
        assert.NoError(t, err)

        var output chatCompletionRequest
        require.NoError(t, json.Unmarshal(result, &output))

        assert.Len(t, output.Messages, 2)
        assert.Equal(t, "A\n\nB\n\nC", output.Messages[0].Content)
    })

    t.Run("system_messages_not_merged", func(t *testing.T) {
        input := chatCompletionRequest{
            Messages: []chatMessage{
                {Role: "system", Content: "系统提示1"},
                {Role: "system", Content: "系统提示2"},
                {Role: "user", Content: "问题"},
            },
        }
        body, err := json.Marshal(input)
        require.NoError(t, err)

        result, err := mergeConsecutiveMessages(body)
        assert.NoError(t, err)
        // system messages are not merged, body unchanged
        assert.Equal(t, body, result)
    })

    t.Run("single_message_unchanged", func(t *testing.T) {
        input := chatCompletionRequest{
            Messages: []chatMessage{
                {Role: "user", Content: "只有一条"},
            },
        }
        body, err := json.Marshal(input)
        require.NoError(t, err)

        result, err := mergeConsecutiveMessages(body)
        assert.NoError(t, err)
        assert.Equal(t, body, result)
    })

    t.Run("invalid_json_body", func(t *testing.T) {
        body := []byte(`invalid json`)
        result, err := mergeConsecutiveMessages(body)
        assert.Error(t, err)
        assert.Equal(t, body, result)
    })
}

func TestCleanupContextMessages(t *testing.T) {
    t.Run("empty_cleanup_commands", func(t *testing.T) {
        body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
@@ -45,7 +45,9 @@ const (
    contextClaudeMarker               = "isClaudeRequest"
    contextOpenAICompatibleMarker     = "isOpenAICompatibleRequest"
    contextVertexRawMarker            = "isVertexRawRequest"
    contextVertexStreamDoneMarker     = "vertexStreamDoneSent"
    vertexAnthropicVersion            = "vertex-2023-10-16"
    vertexImageVariationDefaultPrompt = "Create variations of the provided image."
)

// vertexRawPathRegex matches native Vertex AI REST API paths
@@ -98,6 +100,8 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
        string(ApiNameChatCompletion):  vertexPathTemplate,
        string(ApiNameEmbeddings):      vertexPathTemplate,
        string(ApiNameImageGeneration): vertexPathTemplate,
        string(ApiNameImageEdit):       vertexPathTemplate,
        string(ApiNameImageVariation):  vertexPathTemplate,
        string(ApiNameVertexRaw):       "", // an empty string means the original path is kept without rewriting
    }
}
@@ -255,12 +259,12 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
    if v.isOpenAICompatibleMode() {
        ctx.SetContext(contextOpenAICompatibleMarker, true)
        body, err := v.onOpenAICompatibleRequestBody(ctx, apiName, body, headers)
        headers.Set("Content-Length", fmt.Sprint(len(body)))
        util.ReplaceRequestHeaders(headers)
        _ = proxywasm.ReplaceHttpRequestBody(body)
        if err != nil {
            return types.ActionContinue, err
        }
        headers.Set("Content-Length", fmt.Sprint(len(body)))
        util.ReplaceRequestHeaders(headers)
        _ = proxywasm.ReplaceHttpRequestBody(body)
        // OpenAI compatible mode requires an OAuth token
        cached, err := v.getToken()
        if cached {
@@ -273,6 +277,9 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
    }

    body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
    if err != nil {
        return types.ActionContinue, err
    }
    headers.Set("Content-Length", fmt.Sprint(len(body)))

    if v.isExpressMode() {
@@ -280,15 +287,12 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
        headers.Del("Authorization")
        util.ReplaceRequestHeaders(headers)
        _ = proxywasm.ReplaceHttpRequestBody(body)
        return types.ActionContinue, err
        return types.ActionContinue, nil
    }

    // Standard mode: an OAuth token must be obtained
    util.ReplaceRequestHeaders(headers)
    _ = proxywasm.ReplaceHttpRequestBody(body)
    if err != nil {
        return types.ActionContinue, err
    }
    cached, err := v.getToken()
    if cached {
        return types.ActionContinue, nil
@@ -307,6 +311,10 @@ func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap
        return v.onEmbeddingsRequestBody(ctx, body, headers)
    case ApiNameImageGeneration:
        return v.onImageGenerationRequestBody(ctx, body, headers)
    case ApiNameImageEdit:
        return v.onImageEditRequestBody(ctx, body, headers)
    case ApiNameImageVariation:
        return v.onImageVariationRequestBody(ctx, body, headers)
    default:
        return body, nil
    }
@@ -361,7 +369,10 @@ func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo
    path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
    util.OverwriteRequestPathHeader(headers, path)

    vertexRequest := v.buildVertexChatRequest(request)
    vertexRequest, err := v.buildVertexChatRequest(request)
    if err != nil {
        return nil, err
    }
    return json.Marshal(vertexRequest)
}
}
@@ -387,11 +398,108 @@ func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, b
    path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
    util.OverwriteRequestPathHeader(headers, path)

    vertexRequest := v.buildVertexImageGenerationRequest(request)
    vertexRequest, err := v.buildVertexImageGenerationRequest(request)
    if err != nil {
        return nil, err
    }
    return json.Marshal(vertexRequest)
}

func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest {
func (v *vertexProvider) onImageEditRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
    request := &imageEditRequest{}
    imageURLs := make([]string, 0)
    contentType := headers.Get("Content-Type")
    if isMultipartFormData(contentType) {
        parsedRequest, err := parseMultipartImageRequest(body, contentType)
        if err != nil {
            return nil, err
        }
        request.Model = parsedRequest.Model
        request.Prompt = parsedRequest.Prompt
        request.Size = parsedRequest.Size
        request.OutputFormat = parsedRequest.OutputFormat
        request.N = parsedRequest.N
        imageURLs = parsedRequest.ImageURLs
        if err := v.config.mapModel(ctx, &request.Model); err != nil {
            return nil, err
        }
        if parsedRequest.HasMask {
            return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
        }
    } else {
        if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
            return nil, err
        }
        if request.HasMask() {
            return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
        }
        imageURLs = request.GetImageURLs()
    }
    if len(imageURLs) == 0 {
        return nil, fmt.Errorf("missing image_url in request")
    }
    if request.Prompt == "" {
        return nil, fmt.Errorf("missing prompt in request")
    }

    path := v.getRequestPath(ApiNameImageEdit, request.Model, false)
    util.OverwriteRequestPathHeader(headers, path)
    headers.Set("Content-Type", util.MimeTypeApplicationJson)
    vertexRequest, err := v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, imageURLs)
    if err != nil {
        return nil, err
    }
    return json.Marshal(vertexRequest)
}

func (v *vertexProvider) onImageVariationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
    request := &imageVariationRequest{}
    imageURLs := make([]string, 0)
    contentType := headers.Get("Content-Type")
    if isMultipartFormData(contentType) {
        parsedRequest, err := parseMultipartImageRequest(body, contentType)
        if err != nil {
            return nil, err
        }
        request.Model = parsedRequest.Model
        request.Prompt = parsedRequest.Prompt
        request.Size = parsedRequest.Size
        request.OutputFormat = parsedRequest.OutputFormat
        request.N = parsedRequest.N
        imageURLs = parsedRequest.ImageURLs
        if err := v.config.mapModel(ctx, &request.Model); err != nil {
            return nil, err
        }
    } else {
        if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
            return nil, err
        }
        imageURLs = request.GetImageURLs()
    }
    if len(imageURLs) == 0 {
        return nil, fmt.Errorf("missing image_url in request")
    }

    prompt := request.Prompt
    if prompt == "" {
        prompt = vertexImageVariationDefaultPrompt
    }

    path := v.getRequestPath(ApiNameImageVariation, request.Model, false)
    util.OverwriteRequestPathHeader(headers, path)
    headers.Set("Content-Type", util.MimeTypeApplicationJson)
    vertexRequest, err := v.buildVertexImageRequest(prompt, request.Size, request.OutputFormat, imageURLs)
    if err != nil {
        return nil, err
    }
    return json.Marshal(vertexRequest)
}

func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) (*vertexChatRequest, error) {
    return v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, nil)
}

func (v *vertexProvider) buildVertexImageRequest(prompt string, size string, outputFormat string, imageURLs []string) (*vertexChatRequest, error) {
    // Build safety settings
    safetySettings := make([]vertexChatSafetySetting, 0)
    for category, threshold := range v.config.geminiSafetySetting {
|
||||
}
|
||||
|
||||
// 解析尺寸参数
|
||||
aspectRatio, imageSize := v.parseImageSize(request.Size)
|
||||
aspectRatio, imageSize := v.parseImageSize(size)
|
||||
|
||||
// 确定输出 MIME 类型
|
||||
mimeType := "image/png"
|
||||
if request.OutputFormat != "" {
|
||||
switch request.OutputFormat {
|
||||
if outputFormat != "" {
|
||||
switch outputFormat {
|
||||
case "jpeg", "jpg":
|
||||
mimeType = "image/jpeg"
|
||||
case "webp":
|
||||
@@ -417,12 +525,27 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
||||
}
|
||||
}
|
||||
|
||||
parts := make([]vertexPart, 0, len(imageURLs)+1)
|
||||
for _, imageURL := range imageURLs {
|
||||
part, err := convertMediaContent(imageURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
if prompt != "" {
|
||||
parts = append(parts, vertexPart{
|
||||
Text: prompt,
|
||||
})
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return nil, fmt.Errorf("missing prompt and image_url in request")
|
||||
}
|
||||
|
||||
vertexRequest := &vertexChatRequest{
|
||||
Contents: []vertexChatContent{{
|
||||
Role: roleUser,
|
||||
Parts: []vertexPart{{
|
||||
Text: request.Prompt,
|
||||
}},
|
||||
Role: roleUser,
|
||||
Parts: parts,
|
||||
}},
|
||||
SafetySettings: safetySettings,
|
||||
GenerationConfig: vertexChatGenerationConfig{
|
||||
@@ -440,7 +563,7 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
||||
},
|
||||
}
|
||||
|
||||
return vertexRequest
|
||||
return vertexRequest, nil
|
||||
}
|
||||
|
||||
// parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize
|
||||
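The hunk above only shows the doc comment of `parseImageSize`, not its body. A plausible sketch of such a "WxH"-string-to-aspect-ratio conversion, purely illustrative and not the plugin's actual implementation, could look like this:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// gcd returns the greatest common divisor of two positive integers.
func gcd(a, b int) int {
	for b != 0 {
		a, b = b, a%b
	}
	return a
}

// parseSize converts an OpenAI-style "WxH" size string into a reduced
// aspect-ratio string like "1:1" or "7:4". Malformed input falls back
// to the square default.
func parseSize(size string) string {
	parts := strings.SplitN(size, "x", 2)
	if len(parts) != 2 {
		return "1:1"
	}
	w, err1 := strconv.Atoi(parts[0])
	h, err2 := strconv.Atoi(parts[1])
	if err1 != nil || err2 != nil || w <= 0 || h <= 0 {
		return "1:1"
	}
	g := gcd(w, h)
	return fmt.Sprintf("%d:%d", w/g, h/g)
}

func main() {
	fmt.Println(parseSize("1024x1024"), parseSize("1792x1024")) // prints "1:1 7:4"
}
```

The real function additionally returns an `imageSize`; this sketch covers only the aspect-ratio half of that contract.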
@@ -502,23 +625,46 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
        return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk)
    }
    log.Infof("[vertexProvider] receive chunk body: %s", string(chunk))
    if isLastChunk {
        return []byte(ssePrefix + "[DONE]\n\n"), nil
    }
    if len(chunk) == 0 {
    if len(chunk) == 0 && !isLastChunk {
        return nil, nil
    }
    if name != ApiNameChatCompletion {
        if isLastChunk {
            return []byte(ssePrefix + "[DONE]\n\n"), nil
        }
        return chunk, nil
    }

    responseBuilder := &strings.Builder{}
    lines := strings.Split(string(chunk), "\n")
    for _, data := range lines {
        if len(data) < 6 {
            // ignore blank line or wrong format
    // Flush a trailing event when upstream closes stream without a final blank line.
    chunkForParsing := chunk
    if isLastChunk {
        trailingNewLineCount := 0
        for i := len(chunkForParsing) - 1; i >= 0 && chunkForParsing[i] == '\n'; i-- {
            trailingNewLineCount++
        }
        if trailingNewLineCount < 2 {
            chunkForParsing = append([]byte(nil), chunk...)
            for i := 0; i < 2-trailingNewLineCount; i++ {
                chunkForParsing = append(chunkForParsing, '\n')
            }
        }
    }
    streamEvents := ExtractStreamingEvents(ctx, chunkForParsing)
    doneSent, _ := ctx.GetContext(contextVertexStreamDoneMarker).(bool)
    appendDone := isLastChunk && !doneSent
    for _, event := range streamEvents {
        data := event.Data
        if data == "" {
            continue
        }
        if data == streamEndDataValue {
            if !doneSent {
                appendDone = true
                doneSent = true
            }
            continue
        }
        data = data[6:]
        var vertexResp vertexChatResponse
        if err := json.Unmarshal([]byte(data), &vertexResp); err != nil {
            log.Errorf("unable to unmarshal vertex response: %v", err)
@@ -532,7 +678,17 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
        }
        v.appendResponse(responseBuilder, string(responseBody))
    }
    if appendDone {
        responseBuilder.WriteString(ssePrefix + "[DONE]\n\n")
        doneSent = true
    }
    ctx.SetContext(contextVertexStreamDoneMarker, doneSent)
    modifiedResponseChunk := responseBuilder.String()
    if modifiedResponseChunk == "" {
        // Returning an empty payload prevents main.go from falling back to
        // forwarding the original raw chunk, which may contain partial JSON.
        return []byte(""), nil
    }
    log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
    return []byte(modifiedResponseChunk), nil
}
@@ -553,7 +709,7 @@ func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
        return v.onChatCompletionResponseBody(ctx, body)
    case ApiNameEmbeddings:
        return v.onEmbeddingsResponseBody(ctx, body)
    case ApiNameImageGeneration:
    case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
        return v.onImageGenerationResponseBody(ctx, body)
    default:
        return body, nil
@@ -784,7 +940,7 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
    switch apiName {
    case ApiNameEmbeddings:
        action = vertexEmbeddingAction
    case ApiNameImageGeneration:
    case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
        // Image generation uses the non-streaming endpoint; the full response is required
        action = vertexChatCompletionAction
    default:
@@ -818,7 +974,7 @@ func (v *vertexProvider) getOpenAICompatibleRequestPath() string {
    return fmt.Sprintf(vertexOpenAICompatiblePathTemplate, v.config.vertexProjectId, v.config.vertexRegion)
}

func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest {
func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) (*vertexChatRequest, error) {
    safetySettings := make([]vertexChatSafetySetting, 0)
    for category, threshold := range v.config.geminiSafetySetting {
        safetySettings = append(safetySettings, vertexChatSafetySetting{
@@ -853,6 +1009,9 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest)
        }
        vertexRequest.GenerationConfig.ThinkingConfig = thinkingConfig
    }
    if err := v.applyResponseFormatToGenerationConfig(request.ResponseFormat, &vertexRequest.GenerationConfig, request.Model); err != nil {
        return nil, err
    }
    if request.Tools != nil {
        functions := make([]function, 0, len(request.Tools))
        for _, tool := range request.Tools {
@@ -938,7 +1097,130 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest)
        }
    }

    return &vertexRequest
    return &vertexRequest, nil
}

// applyResponseFormatToGenerationConfig maps OpenAI response_format into Vertex generationConfig.
// The mapping is strict for type=json_schema to avoid silently breaking structured-output contracts.
func (v *vertexProvider) applyResponseFormatToGenerationConfig(responseFormat map[string]interface{}, generationConfig *vertexChatGenerationConfig, model string) error {
    if generationConfig == nil || len(responseFormat) == 0 {
        return nil
    }

    // NOTE: Gemini 2.0 structured output requires propertyOrdering.
    // Because gemini-2.0-* is legacy and rarely used, we intentionally do not implement
    // propertyOrdering synthesis here; instead we ignore response_format and keep request
    // as non-structured output for stability and minimal conversion behavior.
    if requiresPropertyOrderingForModel(model) {
        return nil
    }

    responseFormatType, _ := responseFormat["type"].(string)
    responseFormatType = strings.ToLower(responseFormatType)

    switch responseFormatType {
    case "":
        // Be tolerant for non-standard clients that pass schema directly in response_format.
        if isJSONSchemaMap(responseFormat) {
            generationConfig.ResponseMimeType = util.MimeTypeApplicationJson
            generationConfig.ResponseSchema = responseFormat
        }
    case "json_object":
        generationConfig.ResponseMimeType = util.MimeTypeApplicationJson
    case "json_schema":
        schema := extractOpenAIJSONSchema(responseFormat)
        if len(schema) == 0 {
            return fmt.Errorf("invalid response_format.json_schema: missing schema object")
        }
        generationConfig.ResponseMimeType = util.MimeTypeApplicationJson
        generationConfig.ResponseSchema = schema
    case "text":
        // Vertex defaults to text output when no response mime/schema is provided.
    default:
        // Be tolerant for non-standard usage where response_format itself is a JSON schema.
        if isJSONSchemaType(responseFormatType) && isJSONSchemaMap(responseFormat) {
            generationConfig.ResponseMimeType = util.MimeTypeApplicationJson
            generationConfig.ResponseSchema = responseFormat
        }
    }
    return nil
}

func extractOpenAIJSONSchema(responseFormat map[string]interface{}) map[string]interface{} {
    jsonSchemaValue, ok := responseFormat["json_schema"]
    if !ok {
        return nil
    }

    jsonSchemaMap, ok := jsonSchemaValue.(map[string]interface{})
    if !ok {
        return nil
    }

    // OpenAI canonical format:
    // {
    //   "type":"json_schema",
    //   "json_schema":{"name":"...","strict":true,"schema":{...}}
    // }
    if nestedSchemaValue, ok := jsonSchemaMap["schema"]; ok {
        if nestedSchema, ok := nestedSchemaValue.(map[string]interface{}); ok {
            return nestedSchema
        }
    }

    // Tolerate non-standard format where json_schema itself is the schema.
    if isJSONSchemaMap(jsonSchemaMap) {
        return jsonSchemaMap
    }
    return nil
}

func isJSONSchemaType(value string) bool {
    switch strings.ToLower(value) {
    case "object", "array", "string", "number", "integer", "boolean", "null":
        return true
    default:
        return false
    }
}

func isJSONSchemaMap(schema map[string]interface{}) bool {
    if len(schema) == 0 {
        return false
    }

    if typeValue, ok := schema["type"].(string); ok && isJSONSchemaType(typeValue) {
        return true
    }

    // Schema might omit "type" and still be valid for specific cases.
    schemaKeys := []string{
        "anyOf",
        "enum",
        "format",
        "items",
        "maximum",
        "maxItems",
        "minimum",
        "minItems",
        "nullable",
        "properties",
        "description",
        "propertyOrdering",
        "required",
    }
    for _, key := range schemaKeys {
        if _, ok := schema[key]; ok {
            return true
        }
    }

    return false
}

func requiresPropertyOrderingForModel(model string) bool {
    model = strings.ToLower(model)
    return strings.HasPrefix(model, "gemini-2.0-")
}

func (v *vertexProvider) buildEmbeddingRequest(request *embeddingsRequest) *vertexEmbeddingRequest {
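The response_format mapping above can be illustrated in isolation. `GenCfg` is a simplified stand-in for `vertexChatGenerationConfig`, and only the `json_object` and canonical `json_schema` branches are sketched; the tolerant fallbacks are omitted:

```go
package main

import "fmt"

// GenCfg is a simplified stand-in for vertexChatGenerationConfig.
type GenCfg struct {
	ResponseMimeType string
	ResponseSchema   map[string]interface{}
}

// applyFormat mirrors the json_object / json_schema branches above:
// json_object only sets the MIME type; json_schema additionally extracts
// the nested "schema" object from OpenAI's canonical envelope.
func applyFormat(rf map[string]interface{}, cfg *GenCfg) error {
	switch rf["type"] {
	case "json_object":
		cfg.ResponseMimeType = "application/json"
	case "json_schema":
		env, _ := rf["json_schema"].(map[string]interface{})
		schema, _ := env["schema"].(map[string]interface{})
		if len(schema) == 0 {
			return fmt.Errorf("invalid response_format.json_schema: missing schema object")
		}
		cfg.ResponseMimeType = "application/json"
		cfg.ResponseSchema = schema
	}
	return nil
}

func main() {
	cfg := &GenCfg{}
	rf := map[string]interface{}{
		"type": "json_schema",
		"json_schema": map[string]interface{}{
			"name":   "resp",
			"schema": map[string]interface{}{"type": "object"},
		},
	}
	if err := applyFormat(rf, cfg); err != nil {
		panic(err)
	}
	fmt.Println(cfg.ResponseMimeType) // prints "application/json"
}
```

The key design point carried over from the real code: `json_object` never sets a schema, while a `json_schema` envelope without a usable `schema` object is a hard error rather than a silent no-op.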
@@ -1017,14 +1299,16 @@ type vertexChatSafetySetting struct {
}

type vertexChatGenerationConfig struct {
    Temperature        float64              `json:"temperature,omitempty"`
    TopP               float64              `json:"topP,omitempty"`
    TopK               int                  `json:"topK,omitempty"`
    CandidateCount     int                  `json:"candidateCount,omitempty"`
    MaxOutputTokens    int                  `json:"maxOutputTokens,omitempty"`
    ThinkingConfig     vertexThinkingConfig `json:"thinkingConfig,omitempty"`
    ResponseModalities []string             `json:"responseModalities,omitempty"`
    ImageConfig        *vertexImageConfig   `json:"imageConfig,omitempty"`
    Temperature        float64                `json:"temperature,omitempty"`
    TopP               float64                `json:"topP,omitempty"`
    TopK               int                    `json:"topK,omitempty"`
    CandidateCount     int                    `json:"candidateCount,omitempty"`
    MaxOutputTokens    int                    `json:"maxOutputTokens,omitempty"`
    ThinkingConfig     vertexThinkingConfig   `json:"thinkingConfig,omitempty"`
    ResponseMimeType   string                 `json:"responseMimeType,omitempty"`
    ResponseSchema     map[string]interface{} `json:"responseSchema,omitempty"`
    ResponseModalities []string               `json:"responseModalities,omitempty"`
    ImageConfig        *vertexImageConfig     `json:"imageConfig,omitempty"`
}

type vertexImageConfig struct {
258 plugins/wasm-go/extensions/ai-proxy/provider/vertex_test.go (Normal file)
@@ -0,0 +1,258 @@
package provider

import (
    "testing"

    "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
)

func TestVertexProviderBuildChatRequestStructuredOutputMapping(t *testing.T) {
    t.Run("json_object response format", func(t *testing.T) {
        v := &vertexProvider{}
        req := &chatCompletionRequest{
            Model: "gemini-2.5-flash",
            Messages: []chatMessage{
                {Role: roleUser, Content: "hello"},
            },
            ResponseFormat: map[string]interface{}{
                "type": "json_object",
            },
        }

        vertexReq, err := v.buildVertexChatRequest(req)
        require.NoError(t, err)
        require.NotNil(t, vertexReq)

        assert.Equal(t, util.MimeTypeApplicationJson, vertexReq.GenerationConfig.ResponseMimeType)
        assert.Nil(t, vertexReq.GenerationConfig.ResponseSchema)
    })

    t.Run("json_schema response format with nested schema", func(t *testing.T) {
        v := &vertexProvider{}
        schema := map[string]interface{}{
            "type": "object",
            "properties": map[string]interface{}{
                "answer": map[string]interface{}{
                    "type": "string",
                },
            },
            "required": []interface{}{"answer"},
        }
        req := &chatCompletionRequest{
            Model: "gemini-2.5-flash",
            Messages: []chatMessage{
                {Role: roleUser, Content: "hello"},
            },
            ResponseFormat: map[string]interface{}{
                "type": "json_schema",
                "json_schema": map[string]interface{}{
                    "name":   "response",
                    "strict": true,
                    "schema": schema,
                },
            },
        }

        vertexReq, err := v.buildVertexChatRequest(req)
        require.NoError(t, err)
        require.NotNil(t, vertexReq)

        assert.Equal(t, util.MimeTypeApplicationJson, vertexReq.GenerationConfig.ResponseMimeType)
        assert.Equal(t, schema, vertexReq.GenerationConfig.ResponseSchema)
    })

    t.Run("json_schema response format with direct schema object", func(t *testing.T) {
        v := &vertexProvider{}
        schema := map[string]interface{}{
            "type": "object",
            "properties": map[string]interface{}{
                "city": map[string]interface{}{
                    "type": "string",
                },
            },
            "required": []interface{}{"city"},
        }
        req := &chatCompletionRequest{
            Model: "gemini-2.5-flash",
            Messages: []chatMessage{
                {Role: roleUser, Content: "hello"},
            },
            ResponseFormat: map[string]interface{}{
                "type":        "json_schema",
                "json_schema": schema,
            },
        }

        vertexReq, err := v.buildVertexChatRequest(req)
        require.NoError(t, err)
        require.NotNil(t, vertexReq)

        assert.Equal(t, util.MimeTypeApplicationJson, vertexReq.GenerationConfig.ResponseMimeType)
        assert.Equal(t, schema, vertexReq.GenerationConfig.ResponseSchema)
|
||||
})
|
||||
|
||||
t.Run("json_schema response format without valid schema should return error", func(t *testing.T) {
|
||||
v := &vertexProvider{}
|
||||
req := &chatCompletionRequest{
|
||||
Model: "gemini-2.5-flash",
|
||||
Messages: []chatMessage{
|
||||
{Role: roleUser, Content: "hello"},
|
||||
},
|
||||
ResponseFormat: map[string]interface{}{
|
||||
"type": "json_schema",
|
||||
"json_schema": "invalid",
|
||||
},
|
||||
}
|
||||
|
||||
vertexReq, err := v.buildVertexChatRequest(req)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, vertexReq)
|
||||
assert.Contains(t, err.Error(), "invalid response_format.json_schema")
|
||||
})
|
||||
|
||||
t.Run("direct schema in response_format for compatibility", func(t *testing.T) {
|
||||
v := &vertexProvider{}
|
||||
schema := map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"result": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
}
|
||||
req := &chatCompletionRequest{
|
||||
Model: "gemini-2.5-flash",
|
||||
Messages: []chatMessage{
|
||||
{Role: roleUser, Content: "hello"},
|
||||
},
|
||||
ResponseFormat: schema,
|
||||
}
|
||||
|
||||
vertexReq, err := v.buildVertexChatRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, vertexReq)
|
||||
|
||||
assert.Equal(t, util.MimeTypeApplicationJson, vertexReq.GenerationConfig.ResponseMimeType)
|
||||
assert.Equal(t, schema, vertexReq.GenerationConfig.ResponseSchema)
|
||||
})
|
||||
|
||||
t.Run("text response format keeps default text output", func(t *testing.T) {
|
||||
v := &vertexProvider{}
|
||||
req := &chatCompletionRequest{
|
||||
Model: "gemini-2.5-flash",
|
||||
Messages: []chatMessage{
|
||||
{Role: roleUser, Content: "hello"},
|
||||
},
|
||||
ResponseFormat: map[string]interface{}{
|
||||
"type": "text",
|
||||
},
|
||||
}
|
||||
|
||||
vertexReq, err := v.buildVertexChatRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, vertexReq)
|
||||
|
||||
assert.Empty(t, vertexReq.GenerationConfig.ResponseMimeType)
|
||||
assert.Nil(t, vertexReq.GenerationConfig.ResponseSchema)
|
||||
})
|
||||
|
||||
t.Run("unknown response format does not inject schema config", func(t *testing.T) {
|
||||
v := &vertexProvider{}
|
||||
req := &chatCompletionRequest{
|
||||
Model: "gemini-2.5-flash",
|
||||
Messages: []chatMessage{
|
||||
{Role: roleUser, Content: "hello"},
|
||||
},
|
||||
ResponseFormat: map[string]interface{}{
|
||||
"type": "xml",
|
||||
},
|
||||
}
|
||||
|
||||
vertexReq, err := v.buildVertexChatRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, vertexReq)
|
||||
|
||||
assert.Empty(t, vertexReq.GenerationConfig.ResponseMimeType)
|
||||
assert.Nil(t, vertexReq.GenerationConfig.ResponseSchema)
|
||||
})
|
||||
|
||||
t.Run("gemini 2.0 json_schema is ignored for stability", func(t *testing.T) {
|
||||
v := &vertexProvider{}
|
||||
schema := map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"answer": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
}
|
||||
req := &chatCompletionRequest{
|
||||
Model: "gemini-2.0-flash",
|
||||
Messages: []chatMessage{
|
||||
{Role: roleUser, Content: "hello"},
|
||||
},
|
||||
ResponseFormat: map[string]interface{}{
|
||||
"type": "json_schema",
|
||||
"json_schema": map[string]interface{}{
|
||||
"name": "response",
|
||||
"strict": true,
|
||||
"schema": schema,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
vertexReq, err := v.buildVertexChatRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, vertexReq)
|
||||
assert.Empty(t, vertexReq.GenerationConfig.ResponseMimeType)
|
||||
assert.Nil(t, vertexReq.GenerationConfig.ResponseSchema)
|
||||
})
|
||||
|
||||
t.Run("gemini 2.0 malformed json_schema is also ignored", func(t *testing.T) {
|
||||
v := &vertexProvider{}
|
||||
req := &chatCompletionRequest{
|
||||
Model: "gemini-2.0-flash",
|
||||
Messages: []chatMessage{
|
||||
{Role: roleUser, Content: "hello"},
|
||||
},
|
||||
ResponseFormat: map[string]interface{}{
|
||||
"type": "json_schema",
|
||||
"json_schema": "invalid",
|
||||
},
|
||||
}
|
||||
|
||||
vertexReq, err := v.buildVertexChatRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, vertexReq)
|
||||
assert.Empty(t, vertexReq.GenerationConfig.ResponseMimeType)
|
||||
assert.Nil(t, vertexReq.GenerationConfig.ResponseSchema)
|
||||
})
|
||||
|
||||
t.Run("gemini 2.0 json_object is ignored", func(t *testing.T) {
|
||||
v := &vertexProvider{}
|
||||
req := &chatCompletionRequest{
|
||||
Model: "gemini-2.0-flash",
|
||||
Messages: []chatMessage{
|
||||
{Role: roleUser, Content: "hello"},
|
||||
},
|
||||
ResponseFormat: map[string]interface{}{
|
||||
"type": "json_object",
|
||||
},
|
||||
}
|
||||
|
||||
vertexReq, err := v.buildVertexChatRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, vertexReq)
|
||||
assert.Empty(t, vertexReq.GenerationConfig.ResponseMimeType)
|
||||
assert.Nil(t, vertexReq.GenerationConfig.ResponseSchema)
|
||||
})
|
||||
}
|
||||
|
||||
func TestVertexProviderApplyResponseFormatNilSafety(t *testing.T) {
|
||||
v := &vertexProvider{}
|
||||
require.NoError(t, v.applyResponseFormatToGenerationConfig(map[string]interface{}{"type": "json_object"}, nil, "gemini-2.5-flash"))
|
||||
require.NoError(t, v.applyResponseFormatToGenerationConfig(nil, &vertexChatGenerationConfig{}, "gemini-2.5-flash"))
|
||||
require.NoError(t, v.applyResponseFormatToGenerationConfig(map[string]interface{}{}, &vertexChatGenerationConfig{}, "gemini-2.5-flash"))
|
||||
}
|
||||
@@ -8,11 +8,15 @@ import (
	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
	"github.com/higress-group/wasm-go/pkg/wrapper"
+	"github.com/tidwall/gjson"
+	"github.com/tidwall/sjson"
)

const (
-	zhipuAiDomain                = "open.bigmodel.cn"
+	zhipuAiDefaultDomain         = "open.bigmodel.cn"
+	zhipuAiInternationalDomain   = "api.z.ai"
	zhipuAiChatCompletionPath    = "/api/paas/v4/chat/completions"
+	zhipuAiCodePlanPath          = "/api/coding/paas/v4/chat/completions"
	zhipuAiEmbeddingsPath        = "/api/paas/v4/embeddings"
	zhipuAiAnthropicMessagesPath = "/api/anthropic/v1/messages"
)
@@ -26,16 +30,20 @@ func (m *zhipuAiProviderInitializer) ValidateConfig(config *ProviderConfig) erro
	return nil
}

-func (m *zhipuAiProviderInitializer) DefaultCapabilities() map[string]string {
+func (m *zhipuAiProviderInitializer) DefaultCapabilities(codePlanMode bool) map[string]string {
+	chatPath := zhipuAiChatCompletionPath
+	if codePlanMode {
+		chatPath = zhipuAiCodePlanPath
+	}
	return map[string]string{
-		string(ApiNameChatCompletion): zhipuAiChatCompletionPath,
+		string(ApiNameChatCompletion): chatPath,
		string(ApiNameEmbeddings):     zhipuAiEmbeddingsPath,
		// string(ApiNameAnthropicMessages): zhipuAiAnthropicMessagesPath,
	}
}

func (m *zhipuAiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
-	config.setDefaultCapabilities(m.DefaultCapabilities())
+	config.setDefaultCapabilities(m.DefaultCapabilities(config.zhipuCodePlanMode))
	return &zhipuAiProvider{
		config:       config,
		contextCache: createContextCache(&config),
@@ -65,13 +73,39 @@ func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName

func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
	util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
-	util.OverwriteRequestHostHeader(headers, zhipuAiDomain)
+	// Use configured domain or default to China domain
+	domain := m.config.zhipuDomain
+	if domain == "" {
+		domain = zhipuAiDefaultDomain
+	}
+	util.OverwriteRequestHostHeader(headers, domain)
	util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
	headers.Del("Content-Length")
}

+func (m *zhipuAiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
+	if apiName != ApiNameChatCompletion {
+		return m.config.defaultTransformRequestBody(ctx, apiName, body)
+	}
+
+	// Check if reasoning_effort is set
+	reasoningEffort := gjson.GetBytes(body, "reasoning_effort").String()
+	if reasoningEffort != "" {
+		// Add thinking config for ZhipuAI
+		body, _ = sjson.SetBytes(body, "thinking", map[string]string{"type": "enabled"})
+		// Remove reasoning_effort field as ZhipuAI doesn't recognize it
+		body, _ = sjson.DeleteBytes(body, "reasoning_effort")
+	} else if thinkingType, ok := ctx.GetContext(ctxKeyClaudeThinkingType).(string); ok && thinkingType != "enabled" {
+		// Request came from Claude auto-conversion with thinking explicitly disabled or absent.
+		// Explicitly set thinking=disabled to prevent ZhipuAI from enabling it by default.
+		body, _ = sjson.SetBytes(body, "thinking", map[string]string{"type": "disabled"})
+	}
+
+	return m.config.defaultTransformRequestBody(ctx, apiName, body)
+}
+
func (m *zhipuAiProvider) GetApiName(path string) ApiName {
-	if strings.Contains(path, zhipuAiChatCompletionPath) {
+	if strings.Contains(path, zhipuAiChatCompletionPath) || strings.Contains(path, zhipuAiCodePlanPath) {
		return ApiNameChatCompletion
	}
	if strings.Contains(path, zhipuAiEmbeddingsPath) {
116	plugins/wasm-go/extensions/ai-proxy/test/api_paths.go	Normal file
@@ -0,0 +1,116 @@
package test

import (
	"encoding/json"
	"testing"

	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
	wasmtest "github.com/higress-group/wasm-go/pkg/test"
	"github.com/stretchr/testify/require"
)

func openAICustomEndpointConfig(customURL string) json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type":      "openai",
			"apiTokens": []string{"sk-openai-test-custom-endpoint"},
			"modelMapping": map[string]string{
				"*": "gpt-4o-mini",
			},
			"openaiCustomUrl": customURL,
		},
	})
	return data
}

var openAICustomAudioTranscriptionsEndpointConfig = openAICustomEndpointConfig("https://custom.openai.com/v1/audio/transcriptions")
var openAICustomAudioTranslationsEndpointConfig = openAICustomEndpointConfig("https://custom.openai.com/v1/audio/translations")
var openAICustomRealtimeEndpointConfig = openAICustomEndpointConfig("https://custom.openai.com/v1/realtime")
var openAICustomRealtimeSessionsEndpointConfig = openAICustomEndpointConfig("https://custom.openai.com/v1/realtime/sessions")

func RunApiPathRegressionTests(t *testing.T) {
	wasmtest.RunTest(t, func(t *testing.T) {
		t.Run("openai direct custom endpoint audio transcriptions", func(t *testing.T) {
			host, status := wasmtest.NewTestHost(openAICustomAudioTranscriptionsEndpointConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/audio/transcriptions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)

			requestHeaders := host.GetRequestHeaders()
			pathValue, hasPath := wasmtest.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath)
			require.Equal(t, "/v1/audio/transcriptions", pathValue)
		})

		t.Run("openai direct custom endpoint audio translations", func(t *testing.T) {
			host, status := wasmtest.NewTestHost(openAICustomAudioTranslationsEndpointConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/audio/translations"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)

			requestHeaders := host.GetRequestHeaders()
			pathValue, hasPath := wasmtest.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath)
			require.Equal(t, "/v1/audio/translations", pathValue)
		})

		t.Run("openai direct custom endpoint realtime", func(t *testing.T) {
			host, status := wasmtest.NewTestHost(openAICustomRealtimeEndpointConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/realtime"},
				{":method", "GET"},
				{"Connection", "Upgrade"},
				{"Upgrade", "websocket"},
				{"Sec-WebSocket-Version", "13"},
				{"Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ=="},
			})
			require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration)

			requestHeaders := host.GetRequestHeaders()
			pathValue, hasPath := wasmtest.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath)
			require.Equal(t, "/v1/realtime", pathValue)
		})

		t.Run("openai non-direct endpoint appends mapped realtime suffix", func(t *testing.T) {
			host, status := wasmtest.NewTestHost(openAICustomRealtimeSessionsEndpointConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/realtime"},
				{":method", "GET"},
				{"Connection", "Upgrade"},
				{"Upgrade", "websocket"},
				{"Sec-WebSocket-Version", "13"},
				{"Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ=="},
			})
			require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration)

			requestHeaders := host.GetRequestHeaders()
			pathValue, hasPath := wasmtest.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath)
			require.Equal(t, "/v1/realtime/sessions/realtime", pathValue)
		})

	})
}
@@ -1,7 +1,11 @@
package test

import (
+	"bytes"
	"encoding/json"
+	"io"
+	"mime"
+	"mime/multipart"
	"strings"
	"testing"

@@ -80,6 +84,22 @@ var azureDomainOnlyConfig = func() json.RawMessage {
	return data
}()

var azureDomainOnlyImageMultipartConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type": "azure",
			"apiTokens": []string{
				"sk-azure-image-multipart",
			},
			"azureServiceUrl": "https://domain-resource.openai.azure.com?api-version=2024-02-15-preview",
			"modelMapping": map[string]string{
				"gpt-image-1.5": "gpt-image-1",
			},
		},
	})
	return data
}()

// Test config: Azure OpenAI multi-model configuration
var azureMultiModelConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
@@ -99,6 +119,74 @@ var azureMultiModelConfig = func() json.RawMessage {
	return data
}()

func getMultipartTextField(body []byte, contentType string, fieldName string) (string, bool, error) {
	_, params, err := mime.ParseMediaType(contentType)
	if err != nil {
		return "", false, err
	}
	boundary := params["boundary"]
	if boundary == "" {
		return "", false, nil
	}

	reader := multipart.NewReader(bytes.NewReader(body), boundary)
	for {
		part, err := reader.NextPart()
		if err == io.EOF {
			return "", false, nil
		}
		if err != nil {
			return "", false, err
		}

		data, err := io.ReadAll(part)
		_ = part.Close()
		if err != nil {
			return "", false, err
		}
		if part.FormName() == fieldName {
			return string(data), true, nil
		}
	}
}

func RunAzureMultipartHelperTests(t *testing.T) {
	t.Run("multipart text field returns error for invalid content type", func(t *testing.T) {
		_, _, err := getMultipartTextField([]byte("bad-body"), "multipart/form-data; boundary=\"", "model")
		require.Error(t, err)
	})

	t.Run("multipart text field returns not found for missing boundary", func(t *testing.T) {
		value, found, err := getMultipartTextField([]byte("bad-body"), "multipart/form-data", "model")
		require.NoError(t, err)
		require.False(t, found)
		require.Equal(t, "", value)
	})

	t.Run("multipart text field returns not found on eof", func(t *testing.T) {
		body, contentType := buildMultipartRequestBody(t, map[string]string{
			"model": "gpt-image-1.5",
		}, nil)

		value, found, err := getMultipartTextField(body, contentType, "prompt")
		require.NoError(t, err)
		require.False(t, found)
		require.Equal(t, "", value)
	})

	t.Run("multipart text field returns next part error on malformed body", func(t *testing.T) {
		body := []byte("--abc\r\nnot-a-header\r\n\r\nvalue\r\n--abc--\r\n")
		_, _, err := getMultipartTextField(body, "multipart/form-data; boundary=abc", "model")
		require.Error(t, err)
	})

	t.Run("multipart text field returns read error on truncated part", func(t *testing.T) {
		body := []byte("--abc\r\nContent-Disposition: form-data; name=\"model\"\r\n\r\nvalue\r\n--ab")
		_, _, err := getMultipartTextField(body, "multipart/form-data; boundary=abc", "model")
		require.Error(t, err)
	})
}
// Test config: invalid Azure OpenAI configuration (missing azureServiceUrl)
var azureInvalidConfigMissingUrl = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
@@ -615,6 +703,121 @@ func RunAzureOnHttpRequestBodyTests(t *testing.T) {
		require.Equal(t, pathValue, "/openai/deployments/gpt-3.5-turbo/chat/completions?api-version=2024-02-15-preview", "Path should use model from request body")
	})

	t.Run("azure domain only multipart image edit request body", func(t *testing.T) {
		host, status := test.NewTestHost(azureDomainOnlyImageMultipartConfig)
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)

		body, contentType := buildMultipartRequestBody(t, map[string]string{
			"model":  "gpt-image-1.5",
			"prompt": "把小狗换成白色",
			"size":   "1024x1024",
			"n":      "1",
		}, map[string][]byte{
			"image[]": []byte("fake-image-content"),
		})

		action := host.CallOnHttpRequestHeaders([][2]string{
			{":authority", "example.com"},
			{":path", "/v1/images/edits"},
			{":method", "POST"},
			{"Content-Type", contentType},
		})
		require.Equal(t, types.HeaderStopIteration, action)

		action = host.CallOnHttpRequestBody(body)
		require.Equal(t, types.ActionContinue, action)

		transformedBody := host.GetRequestBody()
		require.NotNil(t, transformedBody)

		modelValue, found, err := getMultipartTextField(transformedBody, contentType, "model")
		require.NoError(t, err)
		require.True(t, found, "Model field should exist in multipart body")
		require.Equal(t, "gpt-image-1", modelValue, "Model field should be mapped in multipart body")
		require.Contains(t, string(transformedBody), "fake-image-content", "Image file content should remain in multipart body")

		requestHeaders := host.GetRequestHeaders()
		pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
		require.True(t, hasPath, "Path header should exist")
		require.Equal(t, "/openai/deployments/gpt-image-1/images/edits?api-version=2024-02-15-preview", pathValue, "Path should use mapped multipart model")

		contentTypeValue, hasContentType := test.GetHeaderValue(requestHeaders, "Content-Type")
		require.True(t, hasContentType, "Content-Type header should exist")
		require.Equal(t, contentType, contentTypeValue, "Multipart Content-Type should remain unchanged")
	})

	t.Run("azure domain only multipart image variation request body", func(t *testing.T) {
		host, status := test.NewTestHost(azureDomainOnlyImageMultipartConfig)
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)

		body, contentType := buildMultipartRequestBody(t, map[string]string{
			"model":  "gpt-image-1.5",
			"prompt": "生成类似风格",
			"size":   "1024x1024",
			"n":      "1",
		}, map[string][]byte{
			"image": []byte("fake-image-content"),
		})

		action := host.CallOnHttpRequestHeaders([][2]string{
			{":authority", "example.com"},
			{":path", "/v1/images/variations"},
			{":method", "POST"},
			{"Content-Type", contentType},
		})
		require.Equal(t, types.HeaderStopIteration, action)

		action = host.CallOnHttpRequestBody(body)
		require.Equal(t, types.ActionContinue, action)

		transformedBody := host.GetRequestBody()
		require.NotNil(t, transformedBody)

		modelValue, found, err := getMultipartTextField(transformedBody, contentType, "model")
		require.NoError(t, err)
		require.True(t, found, "Model field should exist in multipart body")
		require.Equal(t, "gpt-image-1", modelValue, "Model field should be mapped in multipart body")
		require.Contains(t, string(transformedBody), "fake-image-content", "Image file content should remain in multipart body")

		requestHeaders := host.GetRequestHeaders()
		pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
		require.True(t, hasPath, "Path header should exist")
		require.Equal(t, "/openai/deployments/gpt-image-1/images/variations?api-version=2024-02-15-preview", pathValue, "Path should use mapped multipart model")

		contentTypeValue, hasContentType := test.GetHeaderValue(requestHeaders, "Content-Type")
		require.True(t, hasContentType, "Content-Type header should exist")
		require.Equal(t, contentType, contentTypeValue, "Multipart Content-Type should remain unchanged")
	})

	t.Run("azure domain only multipart malformed body logs transform failure", func(t *testing.T) {
		host, status := test.NewTestHost(azureDomainOnlyImageMultipartConfig)
		defer host.Reset()
		require.Equal(t, types.OnPluginStartStatusOK, status)

		action := host.CallOnHttpRequestHeaders([][2]string{
			{":authority", "example.com"},
			{":path", "/v1/images/edits"},
			{":method", "POST"},
			{"Content-Type", "multipart/form-data"},
		})
		require.Equal(t, types.HeaderStopIteration, action)

		action = host.CallOnHttpRequestBody([]byte("bad-multipart-body"))
		require.Equal(t, types.ActionContinue, action)

		debugLogs := host.GetDebugLogs()
		hasMultipartTransformFailureLog := false
		for _, debugLog := range debugLogs {
			if strings.Contains(debugLog, "[azure multipart] body transform failed") {
				hasMultipartTransformFailureLog = true
				break
			}
		}
		require.True(t, hasMultipartTransformFailureLog, "Should log azure multipart transform failure")
	})

	// Test Azure OpenAI model-independent request handling (domain-only configuration)
	t.Run("azure domain only model independent", func(t *testing.T) {
		host, status := test.NewTestHost(azureDomainOnlyConfig)

File diff suppressed because it is too large

292	plugins/wasm-go/extensions/ai-proxy/test/consumer_affinity.go	Normal file
@@ -0,0 +1,292 @@
package test

import (
	"encoding/json"
	"strings"
	"testing"

	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
	"github.com/higress-group/wasm-go/pkg/test"
	"github.com/stretchr/testify/require"
)

// Test config: multiple API tokens (for consumer affinity tests)
var multiTokenOpenAIConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type":      "openai",
			"apiTokens": []string{"sk-token-1", "sk-token-2", "sk-token-3"},
			"modelMapping": map[string]string{
				"*": "gpt-4",
			},
		},
	})
	return data
}()

// Test config: single API token
var singleTokenOpenAIConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type":      "openai",
			"apiTokens": []string{"sk-single-token"},
			"modelMapping": map[string]string{
				"*": "gpt-4",
			},
		},
	})
	return data
}()

func RunConsumerAffinityParseConfigTests(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		t.Run("multi token config", func(t *testing.T) {
			host, status := test.NewTestHost(multiTokenOpenAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})
	})
}

func RunConsumerAffinityOnHttpRequestHeadersTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Stateful API (responses) should use consumer affinity
		t.Run("stateful api responses with consumer header", func(t *testing.T) {
			host, status := test.NewTestHost(multiTokenOpenAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Send the x-mse-consumer header
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/responses"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
				{"x-mse-consumer", "consumer-alice"},
			})

			require.Equal(t, types.HeaderStopIteration, action)

			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)

			// Verify the Authorization header uses one of the configured tokens
			authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
			require.True(t, hasAuth, "Authorization header should exist")
			require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
		})

		// Stateful API (files) should use consumer affinity
		t.Run("stateful api files with consumer header", func(t *testing.T) {
			host, status := test.NewTestHost(multiTokenOpenAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/files"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
				{"x-mse-consumer", "consumer-files"},
			})

			require.Equal(t, types.HeaderStopIteration, action)

			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)

			authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
			require.True(t, hasAuth, "Authorization header should exist")
			require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
		})

		// Stateful API (batches) should use consumer affinity
		t.Run("stateful api batches with consumer header", func(t *testing.T) {
			host, status := test.NewTestHost(multiTokenOpenAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/batches"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
				{"x-mse-consumer", "consumer-batches"},
			})

			require.Equal(t, types.HeaderStopIteration, action)

			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)

			authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
			require.True(t, hasAuth, "Authorization header should exist")
			require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
		})

		// Stateful API (fine_tuning) should use consumer affinity
		t.Run("stateful api fine_tuning with consumer header", func(t *testing.T) {
			host, status := test.NewTestHost(multiTokenOpenAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/fine_tuning/jobs"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
				{"x-mse-consumer", "consumer-finetuning"},
			})

			require.Equal(t, types.HeaderStopIteration, action)

			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)

			authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
			require.True(t, hasAuth, "Authorization header should exist")
			require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
		})

		// Non-stateful APIs should keep working normally
		t.Run("non stateful api chat completions works normally", func(t *testing.T) {
			host, status := test.NewTestHost(multiTokenOpenAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
				{"x-mse-consumer", "consumer-chat"},
			})

			require.Equal(t, types.HeaderStopIteration, action)

			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)

			authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
			require.True(t, hasAuth, "Authorization header should exist")
			require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
		})

		// Requests without the x-mse-consumer header should work normally
		t.Run("stateful api without consumer header works normally", func(t *testing.T) {
			host, status := test.NewTestHost(multiTokenOpenAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/responses"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			require.Equal(t, types.HeaderStopIteration, action)

			requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
|
||||
})
|
||||
|
||||
// 测试单个 token 时始终使用该 token
|
||||
t.Run("single token always used", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(singleTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", "consumer-test"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
authValue, _ := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.Contains(t, authValue, "sk-single-token", "Single token should always be used")
|
||||
})
|
||||
|
||||
// 测试同一 consumer 多次请求获得相同 token(consumer affinity 一致性)
|
||||
t.Run("same consumer gets consistent token across requests", func(t *testing.T) {
|
||||
consumer := "consumer-consistency-test"
|
||||
var firstToken string
|
||||
|
||||
// 运行 5 次请求,验证同一个 consumer 始终获得相同的 token
|
||||
for i := 0; i < 5; i++ {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", consumer},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Should use one of the configured tokens")
|
||||
|
||||
if i == 0 {
|
||||
firstToken = authValue
|
||||
} else {
|
||||
require.Equal(t, firstToken, authValue, "Same consumer should get same token consistently (consumer affinity)")
|
||||
}
|
||||
|
||||
host.Reset()
|
||||
}
|
||||
})
|
||||
|
||||
// 测试不同 consumer 可能获得不同 token
|
||||
t.Run("different consumers get tokens based on hash", func(t *testing.T) {
|
||||
tokens := make(map[string]string)
|
||||
|
||||
consumers := []string{"consumer-alpha", "consumer-beta", "consumer-gamma", "consumer-delta", "consumer-epsilon"}
|
||||
for _, consumer := range consumers {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", consumer},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
authValue, _ := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
tokens[consumer] = authValue
|
||||
|
||||
host.Reset()
|
||||
}
|
||||
|
||||
// 验证至少使用了多个不同的 token(hash 分布)
|
||||
uniqueTokens := make(map[string]bool)
|
||||
for _, token := range tokens {
|
||||
uniqueTokens[token] = true
|
||||
}
|
||||
require.GreaterOrEqual(t, len(uniqueTokens), 2, "Different consumers should use at least 2 different tokens")
|
||||
})
|
||||
})
|
||||
}
50
plugins/wasm-go/extensions/ai-proxy/test/mock_context.go
Normal file
@@ -0,0 +1,50 @@
package test

import "github.com/higress-group/wasm-go/pkg/iface"

// MockHttpContext is a minimal mock for wrapper.HttpContext used in unit tests
// that call provider functions directly (e.g. streaming thinking promotion).
type MockHttpContext struct {
	contextMap map[string]interface{}
}

func NewMockHttpContext() *MockHttpContext {
	return &MockHttpContext{contextMap: make(map[string]interface{})}
}

func (m *MockHttpContext) SetContext(key string, value interface{}) { m.contextMap[key] = value }
func (m *MockHttpContext) GetContext(key string) interface{} { return m.contextMap[key] }
func (m *MockHttpContext) GetBoolContext(key string, def bool) bool { return def }
func (m *MockHttpContext) GetStringContext(key, def string) string { return def }
func (m *MockHttpContext) GetByteSliceContext(key string, def []byte) []byte { return def }
func (m *MockHttpContext) Scheme() string { return "" }
func (m *MockHttpContext) Host() string { return "" }
func (m *MockHttpContext) Path() string { return "" }
func (m *MockHttpContext) Method() string { return "" }
func (m *MockHttpContext) GetUserAttribute(key string) interface{} { return nil }
func (m *MockHttpContext) SetUserAttribute(key string, value interface{}) {}
func (m *MockHttpContext) SetUserAttributeMap(kvmap map[string]interface{}) {}
func (m *MockHttpContext) GetUserAttributeMap() map[string]interface{} { return nil }
func (m *MockHttpContext) WriteUserAttributeToLog() error { return nil }
func (m *MockHttpContext) WriteUserAttributeToLogWithKey(key string) error { return nil }
func (m *MockHttpContext) WriteUserAttributeToTrace() error { return nil }
func (m *MockHttpContext) DontReadRequestBody() {}
func (m *MockHttpContext) DontReadResponseBody() {}
func (m *MockHttpContext) BufferRequestBody() {}
func (m *MockHttpContext) BufferResponseBody() {}
func (m *MockHttpContext) NeedPauseStreamingResponse() {}
func (m *MockHttpContext) PushBuffer(buffer []byte) {}
func (m *MockHttpContext) PopBuffer() []byte { return nil }
func (m *MockHttpContext) BufferQueueSize() int { return 0 }
func (m *MockHttpContext) DisableReroute() {}
func (m *MockHttpContext) SetRequestBodyBufferLimit(byteSize uint32) {}
func (m *MockHttpContext) SetResponseBodyBufferLimit(byteSize uint32) {}
func (m *MockHttpContext) RouteCall(method, url string, headers [][2]string, body []byte, callback iface.RouteResponseCallback) error {
	return nil
}
func (m *MockHttpContext) GetExecutionPhase() iface.HTTPExecutionPhase { return 0 }
func (m *MockHttpContext) HasRequestBody() bool { return false }
func (m *MockHttpContext) HasResponseBody() bool { return false }
func (m *MockHttpContext) IsWebsocket() bool { return false }
func (m *MockHttpContext) IsBinaryRequestBody() bool { return false }
func (m *MockHttpContext) IsBinaryResponseBody() bool { return false }
@@ -5,6 +5,7 @@ import (
	"strings"
	"testing"

	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
	"github.com/higress-group/wasm-go/pkg/test"
	"github.com/stretchr/testify/require"
@@ -243,6 +244,84 @@ func RunOpenAIOnHttpRequestHeadersTests(t *testing.T) {
			require.Contains(t, authValue, "sk-openai-test123456789", "Authorization should contain OpenAI API token")
		})

		// Test OpenAI request header handling (audio transcriptions endpoint)
		t.Run("openai audio transcriptions request headers", func(t *testing.T) {
			host, status := test.NewTestHost(basicOpenAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/audio/transcriptions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			require.Equal(t, types.HeaderStopIteration, action)

			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)

			hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
			require.True(t, hasHost)
			require.Equal(t, "api.openai.com", hostValue)

			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath)
			require.Contains(t, pathValue, "/v1/audio/transcriptions", "Path should contain audio transcriptions endpoint")
		})

		// Test OpenAI request header handling (audio translations endpoint)
		t.Run("openai audio translations request headers", func(t *testing.T) {
			host, status := test.NewTestHost(basicOpenAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/audio/translations"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			require.Equal(t, types.HeaderStopIteration, action)

			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)

			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath)
			require.Contains(t, pathValue, "/v1/audio/translations", "Path should contain audio translations endpoint")
		})

		// Test OpenAI request header handling (realtime endpoint, WebSocket handshake)
		t.Run("openai realtime websocket handshake request headers", func(t *testing.T) {
			host, status := test.NewTestHost(basicOpenAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/realtime?model=gpt-4o-realtime-preview"},
				{":method", "GET"},
				{"Connection", "Upgrade"},
				{"Upgrade", "websocket"},
				{"Sec-WebSocket-Version", "13"},
				{"Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ=="},
			})

			// The WebSocket handshake itself should not depend on the request body.
			// Due to test framework limitations, some scenarios may still return HeaderStopIteration.
			require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration)

			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)

			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath)
			require.Contains(t, pathValue, "/v1/realtime", "Path should contain realtime endpoint")
			require.Contains(t, pathValue, "model=gpt-4o-realtime-preview", "Query parameters should be preserved")
		})

		// Test OpenAI request header handling (image generation endpoint)
		t.Run("openai image generation request headers", func(t *testing.T) {
			host, status := test.NewTestHost(basicOpenAIConfig)
@@ -305,6 +384,61 @@ func RunOpenAIOnHttpRequestHeadersTests(t *testing.T) {
			// For a direct path, the original path should be preserved
			require.Contains(t, pathValue, "/v1/chat/completions", "Path should be preserved for direct custom path")
		})

		// Test OpenAI custom domain request header handling (indirect path, audio transcriptions)
		t.Run("openai custom domain indirect path audio transcriptions request headers", func(t *testing.T) {
			host, status := test.NewTestHost(openAICustomDomainIndirectPathConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/audio/transcriptions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			require.Equal(t, types.HeaderStopIteration, action)

			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)

			hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
			require.True(t, hasHost)
			require.Equal(t, "custom.openai.com", hostValue, "Host should be changed to custom domain")

			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath)
			require.Contains(t, pathValue, "/api/audio/transcriptions", "Path should be rewritten with indirect custom prefix")
		})

		// Test OpenAI custom domain request header handling (indirect path realtime, WebSocket handshake)
		t.Run("openai custom domain indirect path realtime websocket handshake request headers", func(t *testing.T) {
			host, status := test.NewTestHost(openAICustomDomainIndirectPathConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/realtime?model=gpt-4o-realtime-preview"},
				{":method", "GET"},
				{"Connection", "Upgrade"},
				{"Upgrade", "websocket"},
				{"Sec-WebSocket-Version", "13"},
				{"Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ=="},
			})

			// The WebSocket handshake itself should not depend on the request body.
			// Due to test framework limitations, some scenarios may still return HeaderStopIteration.
			require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration)

			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)

			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath)
			require.Contains(t, pathValue, "/api/realtime", "Path should be rewritten with indirect custom prefix")
			require.Contains(t, pathValue, "model=gpt-4o-realtime-preview", "Query parameters should be preserved")
		})
	})
}

@@ -864,3 +998,158 @@ func RunOpenAIOnStreamingResponseBodyTests(t *testing.T) {
		})
	})
}

// Test config: OpenAI config + promoteThinkingOnEmpty
var openAIPromoteThinkingConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type": "openai",
			"apiTokens": []string{"sk-openai-test123456789"},
			"promoteThinkingOnEmpty": true,
		},
	})
	return data
}()

// Test config: OpenAI config + hiclawMode
var openAIHiclawModeConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type": "openai",
			"apiTokens": []string{"sk-openai-test123456789"},
			"hiclawMode": true,
		},
	})
	return data
}()

func RunOpenAIPromoteThinkingOnEmptyTests(t *testing.T) {
	// Config parsing tests via host framework
	test.RunGoTest(t, func(t *testing.T) {
		t.Run("promoteThinkingOnEmpty config parses", func(t *testing.T) {
			host, status := test.NewTestHost(openAIPromoteThinkingConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})

		t.Run("hiclawMode config parses", func(t *testing.T) {
			host, status := test.NewTestHost(openAIHiclawModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})
	})

	// Non-streaming promote logic tests via provider functions directly
	t.Run("promotes reasoning_content when content is empty string", func(t *testing.T) {
		body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"","reasoning_content":"这是思考内容"},"finish_reason":"stop"}]}`)
		result, err := provider.PromoteThinkingOnEmptyResponse(body)
		require.NoError(t, err)
		require.Contains(t, string(result), `"content":"这是思考内容"`)
		require.NotContains(t, string(result), `"reasoning_content":"这是思考内容"`)
	})

	t.Run("promotes reasoning_content when content is nil", func(t *testing.T) {
		body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","reasoning_content":"思考结果"},"finish_reason":"stop"}]}`)
		result, err := provider.PromoteThinkingOnEmptyResponse(body)
		require.NoError(t, err)
		require.Contains(t, string(result), `"content":"思考结果"`)
	})

	t.Run("no promotion when content is present", func(t *testing.T) {
		body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"正常回复","reasoning_content":"思考过程"},"finish_reason":"stop"}]}`)
		result, err := provider.PromoteThinkingOnEmptyResponse(body)
		require.NoError(t, err)
		require.Equal(t, string(body), string(result))
	})

	t.Run("no promotion when no reasoning", func(t *testing.T) {
		body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"正常回复"},"finish_reason":"stop"}]}`)
		result, err := provider.PromoteThinkingOnEmptyResponse(body)
		require.NoError(t, err)
		require.Equal(t, string(body), string(result))
	})

	t.Run("no promotion when both empty", func(t *testing.T) {
		body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}]}`)
		result, err := provider.PromoteThinkingOnEmptyResponse(body)
		require.NoError(t, err)
		require.Equal(t, string(body), string(result))
	})

	t.Run("invalid json returns error", func(t *testing.T) {
		body := []byte(`not json`)
		result, err := provider.PromoteThinkingOnEmptyResponse(body)
		require.Error(t, err)
		require.Equal(t, string(body), string(result))
	})
}

func RunOpenAIPromoteThinkingOnEmptyStreamingTests(t *testing.T) {
	// Streaming tests use provider functions directly since the test framework
	// does not expose GetStreamingResponseBody.
	t.Run("streaming: buffers reasoning and flushes on end when no content", func(t *testing.T) {
		ctx := NewMockHttpContext()
		// Chunk with only reasoning_content
		data := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"流式思考"}}]}`)
		result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data)
		require.NoError(t, err)
		// Reasoning should be stripped (not promoted inline)
		require.NotContains(t, string(result), `"content":"流式思考"`)

		// Flush should emit buffered reasoning as content
		flush := provider.PromoteStreamingThinkingFlush(ctx)
		require.NotNil(t, flush)
		require.Contains(t, string(flush), `"content":"流式思考"`)
	})

	t.Run("streaming: no flush when content was seen", func(t *testing.T) {
		ctx := NewMockHttpContext()
		// First chunk: content delta
		data1 := []byte(`{"choices":[{"index":0,"delta":{"content":"正文"}}]}`)
		_, _ = provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data1)

		// Second chunk: reasoning only
		data2 := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"后续思考"}}]}`)
		result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data2)
		require.NoError(t, err)
		// Should be unchanged since content was already seen
		require.Equal(t, string(data2), string(result))

		// Flush should return nil since content was seen
		flush := provider.PromoteStreamingThinkingFlush(ctx)
		require.Nil(t, flush)
	})

	t.Run("streaming: accumulates multiple reasoning chunks", func(t *testing.T) {
		ctx := NewMockHttpContext()
		data1 := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"第一段"}}]}`)
		_, _ = provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data1)

		data2 := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"第二段"}}]}`)
		_, _ = provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data2)

		flush := provider.PromoteStreamingThinkingFlush(ctx)
		require.NotNil(t, flush)
		require.Contains(t, string(flush), `"content":"第一段第二段"`)
	})

	t.Run("streaming: no flush when no reasoning buffered", func(t *testing.T) {
		ctx := NewMockHttpContext()
		flush := provider.PromoteStreamingThinkingFlush(ctx)
		require.Nil(t, flush)
	})

	t.Run("streaming: invalid json returns original", func(t *testing.T) {
		ctx := NewMockHttpContext()
		data := []byte(`not json`)
		result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data)
		require.NoError(t, err)
		require.Equal(t, string(data), string(result))
	})
}
148
plugins/wasm-go/extensions/ai-proxy/test/openrouter.go
Normal file
@@ -0,0 +1,148 @@
package test

import (
	"encoding/json"
	"testing"

	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
	"github.com/higress-group/wasm-go/pkg/test"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

var basicOpenRouterConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type": "openrouter",
			"apiTokens": []string{"sk-openrouter-test"},
		},
	})
	return data
}()

func RunOpenRouterClaudeAutoConversionTests(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		t.Run("claude thinking budget_tokens is converted to reasoning.max_tokens", func(t *testing.T) {
			host, status := test.NewTestHost(basicOpenRouterConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Send request with Claude /v1/messages path to trigger auto-conversion
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/messages"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)

			// Claude request body with thinking enabled
			requestBody := `{
				"model": "anthropic/claude-sonnet-4",
				"max_tokens": 8000,
				"messages": [{"role": "user", "content": "Hello"}],
				"thinking": {"type": "enabled", "budget_tokens": 10000}
			}`

			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			transformedBody := host.GetRequestBody()
			require.NotNil(t, transformedBody)

			var bodyMap map[string]interface{}
			err := json.Unmarshal(transformedBody, &bodyMap)
			require.NoError(t, err)

			// reasoning.max_tokens should be set from budget_tokens
			reasoning, ok := bodyMap["reasoning"].(map[string]interface{})
			require.True(t, ok, "reasoning field should be present")
			assert.Equal(t, float64(10000), reasoning["max_tokens"],
				"reasoning.max_tokens should preserve the original budget_tokens value")

			// reasoning_effort should be removed (OpenRouter uses reasoning.max_tokens instead)
			assert.NotContains(t, bodyMap, "reasoning_effort",
				"reasoning_effort should be removed")

			// Non-standard fields should not be present
			assert.NotContains(t, bodyMap, "thinking",
				"thinking should not be in the final request")
			assert.NotContains(t, bodyMap, "reasoning_max_tokens",
				"reasoning_max_tokens should not be in the final request")
		})

		t.Run("claude without thinking uses default transformation", func(t *testing.T) {
			host, status := test.NewTestHost(basicOpenRouterConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/messages"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)

			requestBody := `{
				"model": "anthropic/claude-sonnet-4",
				"max_tokens": 1000,
				"messages": [{"role": "user", "content": "Hello"}]
			}`

			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			transformedBody := host.GetRequestBody()
			require.NotNil(t, transformedBody)

			var bodyMap map[string]interface{}
			err := json.Unmarshal(transformedBody, &bodyMap)
			require.NoError(t, err)

			// No reasoning fields should be present
			assert.NotContains(t, bodyMap, "reasoning")
			assert.NotContains(t, bodyMap, "reasoning_effort")
			assert.NotContains(t, bodyMap, "thinking")
			assert.NotContains(t, bodyMap, "reasoning_max_tokens")
		})

		t.Run("claude thinking disabled does not set reasoning", func(t *testing.T) {
			host, status := test.NewTestHost(basicOpenRouterConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/messages"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)

			// thinking disabled with budget_tokens (dirty input)
			requestBody := `{
				"model": "anthropic/claude-sonnet-4",
				"max_tokens": 1000,
				"messages": [{"role": "user", "content": "Hello"}],
				"thinking": {"type": "disabled", "budget_tokens": 5000}
			}`

			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			transformedBody := host.GetRequestBody()
			require.NotNil(t, transformedBody)

			var bodyMap map[string]interface{}
			err := json.Unmarshal(transformedBody, &bodyMap)
			require.NoError(t, err)

			// Should NOT have reasoning.max_tokens since thinking was disabled
			assert.NotContains(t, bodyMap, "reasoning",
				"reasoning should not be set when thinking is disabled")
			assert.NotContains(t, bodyMap, "thinking")
			assert.NotContains(t, bodyMap, "reasoning_max_tokens")
		})
	})
}
@@ -100,6 +100,22 @@ var qwenEnableCompatibleConfig = func() json.RawMessage {
	return data
}()

// Test config: qwen original protocol + compatible mode (covers the provider.GetApiName branch)
var qwenOriginalCompatibleConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type": "qwen",
			"apiTokens": []string{"sk-qwen-original-compatible"},
			"modelMapping": map[string]string{
				"*": "qwen-turbo",
			},
			"qwenEnableCompatible": true,
			"protocol": "original",
		},
	})
	return data
}()

// Test config: qwen file IDs config
var qwenFileIdsConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
@@ -159,6 +175,15 @@ var qwenConflictConfig = func() json.RawMessage {
	return data
}()

func hasUnsupportedAPINameError(errorLogs []string) bool {
	for _, log := range errorLogs {
		if strings.Contains(log, "unsupported API name") {
			return true
		}
	}
	return false
}

func RunQwenParseConfigTests(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		// Test basic qwen config parsing
@@ -403,6 +428,29 @@ func RunQwenOnHttpRequestHeadersTests(t *testing.T) {
			require.True(t, hasPath)
			require.Contains(t, pathValue, "/compatible-mode/v1/chat/completions", "Path should use compatible mode path")
		})

		// Test qwen compatible mode request header handling (responses endpoint)
		t.Run("qwen compatible mode responses request headers", func(t *testing.T) {
			host, status := test.NewTestHost(qwenEnableCompatibleConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/responses"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			require.Equal(t, types.HeaderStopIteration, action)

			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)

			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath)
			require.Contains(t, pathValue, "/api/v2/apps/protocols/compatible-mode/v1/responses", "Path should use compatible mode responses path")
		})
	})
}
|
||||
|
||||
@@ -651,6 +699,112 @@ func RunQwenOnHttpRequestBodyTests(t *testing.T) {
			}
			require.True(t, hasCompatibleLogs, "Should have compatible mode processing logs")
		})

		// Test qwen request body handling (compatible mode responses API)
		t.Run("qwen compatible mode responses request body", func(t *testing.T) {
			host, status := test.NewTestHost(qwenEnableCompatibleConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/responses"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{"model":"qwen-turbo","input":"test"}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))

			require.Equal(t, types.ActionContinue, action)

			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)
			require.Contains(t, string(processedBody), "qwen-turbo", "Model name should be preserved in responses request")
		})

		// Test qwen request body handling (non-compatible mode responses API should report unsupported)
		t.Run("qwen non-compatible mode responses request body unsupported", func(t *testing.T) {
			host, status := test.NewTestHost(basicQwenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/responses"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)

			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)

			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath)
			require.Contains(t, pathValue, "/v1/responses", "Path should remain unchanged when responses is unsupported")

			requestBody := `{"model":"qwen-turbo","input":"test"}`
			bodyAction := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, bodyAction)

			hasUnsupportedErr := hasUnsupportedAPINameError(host.GetErrorLogs())
			require.True(t, hasUnsupportedErr, "Should log unsupported API name for non-compatible responses")
		})

		// Cover the following branches in qwen.GetApiName:
		// - qwenCompatibleTextEmbeddingPath => ApiNameEmbeddings
		// - qwenCompatibleResponsesPath => ApiNameResponses
		// - qwenAsyncAIGCPath => ApiNameQwenAsyncAIGC
		// - qwenAsyncTaskPath => ApiNameQwenAsyncTask
		t.Run("qwen original protocol get api name coverage for compatible embeddings responses and async paths", func(t *testing.T) {
			cases := []struct {
				name string
				path string
			}{
				{
					name: "compatible embeddings path",
					path: "/compatible-mode/v1/embeddings",
				},
				{
					name: "compatible responses path",
					path: "/api/v2/apps/protocols/compatible-mode/v1/responses",
				},
				{
					name: "async aigc path",
					path: "/api/v1/services/aigc/custom-async-endpoint",
				},
				{
					name: "async task path",
					path: "/api/v1/tasks/task-123",
				},
			}

			for _, tc := range cases {
				t.Run(tc.name, func(t *testing.T) {
					host, status := test.NewTestHost(qwenOriginalCompatibleConfig)
					defer host.Reset()
					require.Equal(t, types.OnPluginStartStatusOK, status)

					action := host.CallOnHttpRequestHeaders([][2]string{
						{":authority", "example.com"},
						{":path", tc.path},
						{":method", "POST"},
						{"Content-Type", "application/json"},
					})
					// In the test framework the action may appear as Continue or HeaderStopIteration;
					// what matters here is that no unsupported API name error appears in the body phase.
					require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration)

					requestBody := `{"model":"qwen-turbo","input":"test"}`
					bodyAction := host.CallOnHttpRequestBody([]byte(requestBody))
					require.Equal(t, types.ActionContinue, bodyAction)

					hasUnsupportedErr := hasUnsupportedAPINameError(host.GetErrorLogs())
					require.False(t, hasUnsupportedErr, "Path should be recognized by qwen.GetApiName in original protocol")
				})
			}
		})
	})
}

@@ -986,6 +1140,51 @@ func RunQwenOnHttpResponseBodyTests(t *testing.T) {
			require.Contains(t, responseStr, "chat.completion", "Response should contain chat completion object")
			require.Contains(t, responseStr, "qwen-turbo", "Response should contain model name")
		})

		// Test qwen response body handling (compatible mode responses API passthrough)
		t.Run("qwen compatible mode responses response body", func(t *testing.T) {
			host, status := test.NewTestHost(qwenEnableCompatibleConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/responses"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{"model":"qwen-turbo","input":"test"}`
			host.CallOnHttpRequestBody([]byte(requestBody))

			responseHeaders := [][2]string{
				{":status", "200"},
				{"Content-Type", "application/json"},
			}
			host.CallOnHttpResponseHeaders(responseHeaders)

			responseBody := `{
				"id": "resp-123",
				"object": "response",
				"status": "completed",
				"output": [{
					"type": "message",
					"role": "assistant",
					"content": [{
						"type": "output_text",
						"text": "hello"
					}]
				}]
			}`
			action := host.CallOnHttpResponseBody([]byte(responseBody))
			require.Equal(t, types.ActionContinue, action)

			processedResponseBody := host.GetResponseBody()
			require.NotNil(t, processedResponseBody)
			responseStr := string(processedResponseBody)
			require.Contains(t, responseStr, "\"object\": \"response\"", "Responses API payload should be passthrough in compatible mode")
			require.Contains(t, responseStr, "\"text\": \"hello\"", "Assistant content should be preserved")
		})
	})
}

@@ -1,7 +1,9 @@
package test

import (
	"bytes"
	"encoding/json"
	"mime/multipart"
	"strings"
	"testing"

@@ -378,6 +380,273 @@ func RunVertexExpressModeOnHttpRequestBodyTests(t *testing.T) {
			require.True(t, hasVertexLogs, "Should have vertex processing logs")
		})

		// Test Vertex Express Mode structured outputs: json_schema mapping
		t.Run("vertex express mode structured outputs json_schema request body mapping", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{
				"model":"gemini-2.5-flash",
				"messages":[{"role":"user","content":"return structured output"}],
				"response_format":{
					"type":"json_schema",
					"json_schema":{
						"name":"demo_schema",
						"strict":true,
						"schema":{
							"type":"object",
							"properties":{
								"answer":{"type":"string"}
							},
							"required":["answer"]
						}
					}
				}
			}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)

			var transformed map[string]interface{}
			require.NoError(t, json.Unmarshal(processedBody, &transformed))

			generationConfig, ok := transformed["generationConfig"].(map[string]interface{})
			require.True(t, ok, "generationConfig should exist")
			require.Equal(t, "application/json", generationConfig["responseMimeType"], "responseMimeType should be mapped for json_schema")

			responseSchema, ok := generationConfig["responseSchema"].(map[string]interface{})
			require.True(t, ok, "responseSchema should be mapped from response_format.json_schema.schema")
			require.Equal(t, "object", responseSchema["type"])

			properties, ok := responseSchema["properties"].(map[string]interface{})
			require.True(t, ok, "responseSchema.properties should exist")
			_, hasAnswer := properties["answer"]
			require.True(t, hasAnswer, "responseSchema.properties.answer should exist")
		})

		// Test Gemini 2.0 structured outputs: ignore response_format and treat as unstructured output
		t.Run("vertex express mode structured outputs gemini 2.0 ignore response format", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{
				"model":"gemini-2.0-flash",
				"messages":[{"role":"user","content":"return structured output"}],
				"response_format":{
					"type":"json_schema",
					"json_schema":{
						"name":"demo_schema",
						"strict":true,
						"schema":{
							"type":"object",
							"properties":{
								"beta":{"type":"string"},
								"alpha":{
									"type":"object",
									"properties":{
										"z":{"type":"string"},
										"a":{"type":"string"}
									}
								}
							}
						}
					}
				}
			}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)

			var transformed map[string]interface{}
			require.NoError(t, json.Unmarshal(processedBody, &transformed))

			generationConfig, ok := transformed["generationConfig"].(map[string]interface{})
			require.True(t, ok, "generationConfig should exist")
			_, hasMimeType := generationConfig["responseMimeType"]
			_, hasSchema := generationConfig["responseSchema"]
			require.False(t, hasMimeType, "gemini-2.0 should ignore response_format and not set responseMimeType")
			require.False(t, hasSchema, "gemini-2.0 should ignore response_format and not set responseSchema")
		})

		// Test Vertex Express Mode structured outputs: json_object mapping
		t.Run("vertex express mode structured outputs json_object request body mapping", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{
				"model":"gemini-2.5-flash",
				"messages":[{"role":"user","content":"return json"}],
				"response_format":{"type":"json_object"}
			}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)

			var transformed map[string]interface{}
			require.NoError(t, json.Unmarshal(processedBody, &transformed))

			generationConfig, ok := transformed["generationConfig"].(map[string]interface{})
			require.True(t, ok, "generationConfig should exist")
			require.Equal(t, "application/json", generationConfig["responseMimeType"], "responseMimeType should be mapped for json_object")

			_, hasSchema := generationConfig["responseSchema"]
			require.False(t, hasSchema, "json_object should not inject responseSchema")
		})

		// Test Vertex Express Mode structured outputs: compatibility with a direct schema
		t.Run("vertex express mode structured outputs direct schema response_format mapping", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{
				"model":"gemini-2.5-flash",
				"messages":[{"role":"user","content":"return structured output"}],
				"response_format":{
					"type":"object",
					"properties":{"city":{"type":"string"}}
				}
			}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)

			var transformed map[string]interface{}
			require.NoError(t, json.Unmarshal(processedBody, &transformed))

			generationConfig, ok := transformed["generationConfig"].(map[string]interface{})
			require.True(t, ok, "generationConfig should exist")
			require.Equal(t, "application/json", generationConfig["responseMimeType"], "direct schema should be mapped to JSON mime type")

			responseSchema, ok := generationConfig["responseSchema"].(map[string]interface{})
			require.True(t, ok, "direct schema should be mapped to responseSchema")
			require.Equal(t, "object", responseSchema["type"])
		})

		// Test Vertex Express Mode structured outputs: a malformed json_schema should return an error (no silent fallback)
		t.Run("vertex express mode structured outputs malformed json_schema mapping", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{
				"model":"gemini-2.5-flash",
				"messages":[{"role":"user","content":"return structured output"}],
				"response_format":{
					"type":"json_schema",
					"json_schema":"invalid"
				}
			}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			errorLogs := host.GetErrorLogs()
			hasInvalidSchemaError := false
			for _, log := range errorLogs {
				if strings.Contains(log, "invalid response_format.json_schema") {
					hasInvalidSchemaError = true
					break
				}
			}
			require.True(t, hasInvalidSchemaError, "malformed json_schema should produce explicit validation error")

			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)
			require.Contains(t, string(processedBody), `"response_format"`, "failed request should keep original body")
			require.NotContains(t, string(processedBody), `"generationConfig"`, "failed request should not be rewritten into Vertex format")

			requestHeaders := host.GetRequestHeaders()
			pathHeader := ""
			for _, header := range requestHeaders {
				if header[0] == ":path" {
					pathHeader = header[1]
					break
				}
			}
			require.Equal(t, "/v1/chat/completions", pathHeader, "failed validation should not rewrite upstream path")
		})

		// Test Vertex Express Mode structured outputs: an unknown type is not mapped
		t.Run("vertex express mode structured outputs unknown response format type", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{
				"model":"gemini-2.5-flash",
				"messages":[{"role":"user","content":"return xml"}],
				"response_format":{"type":"xml"}
			}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)

			var transformed map[string]interface{}
			require.NoError(t, json.Unmarshal(processedBody, &transformed))

			generationConfig, ok := transformed["generationConfig"].(map[string]interface{})
			require.True(t, ok, "generationConfig should exist")
			_, hasMime := generationConfig["responseMimeType"]
			_, hasSchema := generationConfig["responseSchema"]
			require.False(t, hasMime, "unknown response_format type should not inject responseMimeType")
			require.False(t, hasSchema, "unknown response_format type should not inject responseSchema")
		})

		// Test Vertex Express Mode request body handling (embeddings API)
		t.Run("vertex express mode embeddings request body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
@@ -611,8 +880,8 @@ func RunVertexOpenAICompatibleModeOnHttpRequestBodyTests(t *testing.T) {
			requestBody := `{"model":"gemini-2.0-flash","messages":[{"role":"user","content":"test"}]}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))

-			// OpenAI compatible mode needs to wait for the OAuth token, so ActionPause is returned
-			require.Equal(t, types.ActionPause, action)
+			// The test environment uses a fake key, so OAuth token retrieval fails; expect ActionContinue with the error logged
+			require.Equal(t, types.ActionContinue, action)

			// Verify the request body stays in OpenAI format (not converted to Vertex native format)
			processedBody := host.GetRequestBody()
@@ -635,6 +904,47 @@ func RunVertexOpenAICompatibleModeOnHttpRequestBodyTests(t *testing.T) {
			require.Contains(t, pathHeader, "/endpoints/openapi/chat/completions", "Path should contain openapi chat completions endpoint")
		})

		// Test Vertex OpenAI compatible mode structured outputs request body passthrough
		t.Run("vertex openai compatible mode structured outputs passthrough", func(t *testing.T) {
			host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{
				"model":"gemini-2.0-flash",
				"messages":[{"role":"user","content":"test"}],
				"response_format":{
					"type":"json_schema",
					"json_schema":{
						"name":"demo_schema",
						"strict":true,
						"schema":{
							"type":"object",
							"properties":{"answer":{"type":"string"}},
							"required":["answer"]
						}
					}
				}
			}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)
			bodyStr := string(processedBody)

			require.Contains(t, bodyStr, `"response_format"`, "OpenAI compatible mode should preserve response_format")
			require.Contains(t, bodyStr, `"json_schema"`, "OpenAI compatible mode should preserve json_schema")
			require.NotContains(t, bodyStr, `"generationConfig"`, "OpenAI compatible mode should not convert to Vertex native generationConfig")
		})

		// Test Vertex OpenAI compatible mode request body handling (with model mapping)
		t.Run("vertex openai compatible mode with model mapping", func(t *testing.T) {
			host, status := test.NewTestHost(vertexOpenAICompatibleModeWithModelMappingConfig)
@@ -653,7 +963,7 @@ func RunVertexOpenAICompatibleModeOnHttpRequestBodyTests(t *testing.T) {
			requestBody := `{"model":"gpt-4","messages":[{"role":"user","content":"test"}]}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))

-			require.Equal(t, types.ActionPause, action)
+			require.Equal(t, types.ActionContinue, action)

			// Verify the model name in the request body is mapped
			processedBody := host.GetRequestBody()
@@ -689,7 +999,7 @@ func RunVertexOpenAICompatibleModeOnHttpRequestBodyTests(t *testing.T) {

func RunVertexExpressModeOnStreamingResponseBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
-		// Test Vertex Express Mode streaming response handling
+		// Test Vertex Express Mode streaming response handling: the last chunk must not be lost
		t.Run("vertex express mode streaming response body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
@@ -707,6 +1017,9 @@ func RunVertexExpressModeOnStreamingResponseBodyTests(t *testing.T) {
			requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
			host.CallOnHttpRequestBody([]byte(requestBody))

+			// Set the response property to ensure IsResponseFromUpstream() returns true
+			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
+
			// Set the streaming response headers
			responseHeaders := [][2]string{
				{":status", "200"},
@@ -715,8 +1028,8 @@ func RunVertexExpressModeOnStreamingResponseBodyTests(t *testing.T) {
			host.CallOnHttpResponseHeaders(responseHeaders)

			// Simulate the streaming response body
-			chunk1 := `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"finishReason":"","index":0}],"usageMetadata":{"promptTokenCount":9,"candidatesTokenCount":5,"totalTokenCount":14}}`
-			chunk2 := `data: {"candidates":[{"content":{"parts":[{"text":"Hello! How can I help you today?"}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":9,"candidatesTokenCount":12,"totalTokenCount":21}}`
+			chunk1 := "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Hello\"}],\"role\":\"model\"},\"finishReason\":\"\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":9,\"candidatesTokenCount\":5,\"totalTokenCount\":14}}\n\n"
+			chunk2 := "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Hello! How can I help you today?\"}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":9,\"candidatesTokenCount\":12,\"totalTokenCount\":21}}\n\n"

			// Process the streaming response body
			action1 := host.CallOnHttpStreamingResponseBody([]byte(chunk1), false)
@@ -725,16 +1038,194 @@ func RunVertexExpressModeOnStreamingResponseBodyTests(t *testing.T) {
			action2 := host.CallOnHttpStreamingResponseBody([]byte(chunk2), true)
			require.Equal(t, types.ActionContinue, action2)

-			// Verify streaming response processing
-			debugLogs := host.GetDebugLogs()
-			hasStreamingLogs := false
-			for _, log := range debugLogs {
-				if strings.Contains(log, "streaming") || strings.Contains(log, "chunk") || strings.Contains(log, "vertex") {
-					hasStreamingLogs = true
+			// Verify the last chunk's content is not overwritten by [DONE]
+			transformedResponseBody := host.GetResponseBody()
+			require.NotNil(t, transformedResponseBody)
+			responseStr := string(transformedResponseBody)
+			require.Contains(t, responseStr, "Hello! How can I help you today?", "last chunk content should be preserved")
+			require.Contains(t, responseStr, "data: [DONE]", "stream should end with [DONE]")
		})

		// Test Vertex Express Mode streaming response handling: a single SSE event split across chunks is reassembled correctly
		t.Run("vertex express mode streaming response body with split sse event", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
			host.CallOnHttpRequestBody([]byte(requestBody))

			// Set the response property to ensure IsResponseFromUpstream() returns true
			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))

			responseHeaders := [][2]string{
				{":status", "200"},
				{"Content-Type", "text/event-stream"},
			}
			host.CallOnHttpResponseHeaders(responseHeaders)

			fullEvent := "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"split chunk\"}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":1,\"candidatesTokenCount\":2,\"totalTokenCount\":3}}\n\n"
			splitIdx := strings.Index(fullEvent, "chunk")
			require.Greater(t, splitIdx, 0, "split marker should exist in test payload")
			chunkPart1 := fullEvent[:splitIdx]
			chunkPart2 := fullEvent[splitIdx:]

			action1 := host.CallOnHttpStreamingResponseBody([]byte(chunkPart1), false)
			require.Equal(t, types.ActionContinue, action1)
			action2 := host.CallOnHttpStreamingResponseBody([]byte(chunkPart2), true)
			require.Equal(t, types.ActionContinue, action2)

			transformedResponseBody := host.GetResponseBody()
			require.NotNil(t, transformedResponseBody)
			responseStr := string(transformedResponseBody)
			require.Contains(t, responseStr, "split chunk", "split SSE event should be reassembled and parsed")
			require.Contains(t, responseStr, "data: [DONE]", "stream should end with [DONE]")
		})

		// Test: with a very large thoughtSignature, a single SSE event split into multiple segments can still be reassembled and parsed
		t.Run("vertex express mode streaming response body with huge thought signature split across chunks", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
			host.CallOnHttpRequestBody([]byte(requestBody))
			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
			host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"Content-Type", "text/event-stream"},
			})

			hugeThoughtSignature := strings.Repeat("CmMBjz1rX4j+TQjtDy2rZxSdYOE1jUqDbRhWetraLlQNrkyaRNQZ/", 180)
			fullEvent := "data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"thought-signature-merge-ok\",\"thoughtSignature\":\"" +
				hugeThoughtSignature +
				"\"}]},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":28,\"candidatesTokenCount\":3589,\"totalTokenCount\":5240,\"thoughtsTokenCount\":1623}}\n\n"

			signatureStart := strings.Index(fullEvent, "\"thoughtSignature\":\"")
			require.Greater(t, signatureStart, 0, "thoughtSignature field should exist in test payload")
			splitAt1 := signatureStart + len("\"thoughtSignature\":\"") + 700
			splitAt2 := splitAt1 + 1600
			require.Less(t, splitAt2, len(fullEvent)-1, "split indexes should keep payload in three chunks")

			chunkPart1 := fullEvent[:splitAt1]
			chunkPart2 := fullEvent[splitAt1:splitAt2]
			chunkPart3 := fullEvent[splitAt2:]

			action1 := host.CallOnHttpStreamingResponseBody([]byte(chunkPart1), false)
			require.Equal(t, types.ActionContinue, action1)
			firstBody := host.GetResponseBody()
			require.Equal(t, 0, len(firstBody), "partial chunk should not be forwarded to client")

			action2 := host.CallOnHttpStreamingResponseBody([]byte(chunkPart2), false)
			require.Equal(t, types.ActionContinue, action2)
			secondBody := host.GetResponseBody()
			require.Equal(t, 0, len(secondBody), "partial chunk should not be forwarded to client")

			action3 := host.CallOnHttpStreamingResponseBody([]byte(chunkPart3), true)
			require.Equal(t, types.ActionContinue, action3)

			transformedResponseBody := host.GetResponseBody()
			require.NotNil(t, transformedResponseBody)
			responseStr := string(transformedResponseBody)
			require.Contains(t, responseStr, "thought-signature-merge-ok", "split huge thoughtSignature event should be reassembled and parsed")
			require.Contains(t, responseStr, "data: [DONE]", "stream should end with [DONE]")

			errorLogs := host.GetErrorLogs()
			hasUnmarshalError := false
			for _, log := range errorLogs {
				if strings.Contains(log, "unable to unmarshal vertex response") {
					hasUnmarshalError = true
					break
				}
			}
-			require.True(t, hasStreamingLogs, "Should have streaming response processing logs")
+			require.False(t, hasUnmarshalError, "should not have vertex unmarshal errors for split huge thoughtSignature event")
		})

		// Test: when upstream has already sent [DONE] and the framework triggers an empty final callback, [DONE] must not be emitted twice
		t.Run("vertex express mode streaming response body with upstream done and empty final callback", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
			host.CallOnHttpRequestBody([]byte(requestBody))
			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
			host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"Content-Type", "text/event-stream"},
			})

			doneChunk := "data: [DONE]\n\n"
			action1 := host.CallOnHttpStreamingResponseBody([]byte(doneChunk), false)
			require.Equal(t, types.ActionContinue, action1)
			firstBody := host.GetResponseBody()
			require.NotNil(t, firstBody)
			require.Contains(t, string(firstBody), "data: [DONE]", "first callback should output [DONE]")

			action2 := host.CallOnHttpStreamingResponseBody([]byte{}, true)
			require.Equal(t, types.ActionContinue, action2)

			debugLogs := host.GetDebugLogs()
			doneChunkLogCount := 0
			for _, log := range debugLogs {
				if strings.Contains(log, "=== modified response chunk: data: [DONE]") {
					doneChunkLogCount++
				}
			}
			require.Equal(t, 1, doneChunkLogCount, "[DONE] should only be emitted once when upstream already sent it")
		})

		// Test: when the last chunk lacks the SSE terminating blank line, isLastChunk=true should still parse and emit it correctly
		t.Run("vertex express mode streaming response body last chunk without terminator", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
			host.CallOnHttpRequestBody([]byte(requestBody))
			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
			host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"Content-Type", "text/event-stream"},
			})

			lastChunkWithoutTerminator := "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"no terminator\"}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":2,\"candidatesTokenCount\":3,\"totalTokenCount\":5}}"
			action := host.CallOnHttpStreamingResponseBody([]byte(lastChunkWithoutTerminator), true)
			require.Equal(t, types.ActionContinue, action)

			transformedResponseBody := host.GetResponseBody()
			require.NotNil(t, transformedResponseBody)
			responseStr := string(transformedResponseBody)
			require.Contains(t, responseStr, "no terminator", "last chunk without terminator should still be parsed")
			require.Contains(t, responseStr, "data: [DONE]", "stream should end with [DONE]")
		})
	})
}
@@ -1273,6 +1764,324 @@ func RunVertexExpressModeImageGenerationResponseBodyTests(t *testing.T) {
	})
}

func buildMultipartRequestBody(t *testing.T, fields map[string]string, files map[string][]byte) ([]byte, string) {
	var buffer bytes.Buffer
	writer := multipart.NewWriter(&buffer)

	for key, value := range fields {
		require.NoError(t, writer.WriteField(key, value))
	}

	for fieldName, data := range files {
		part, err := writer.CreateFormFile(fieldName, "upload-image.png")
		require.NoError(t, err)
		_, err = part.Write(data)
		require.NoError(t, err)
	}

	require.NoError(t, writer.Close())
	return buffer.Bytes(), writer.FormDataContentType()
}

func RunVertexExpressModeImageEditVariationRequestBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		const testDataURL = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="

		t.Run("vertex express mode image edit request body with image_url", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/edits"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add sunglasses to the cat","image":{"image_url":{"url":"` + testDataURL + `"}},"size":"1024x1024"}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)

			bodyStr := string(processedBody)
			require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image_url")
			require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
			require.NotContains(t, bodyStr, "image_url", "OpenAI image_url field should be converted to Vertex format")

			requestHeaders := host.GetRequestHeaders()
			pathHeader := ""
			for _, header := range requestHeaders {
				if header[0] == ":path" {
					pathHeader = header[1]
					break
				}
			}
			require.Contains(t, pathHeader, "generateContent", "Image edit should use generateContent action")
			require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key")
		})

		t.Run("vertex express mode image edit request body with image string", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/edits"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add sunglasses to the cat","image":"` + testDataURL + `","size":"1024x1024"}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)

			bodyStr := string(processedBody)
			require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image string")
			require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
		})

		t.Run("vertex express mode image edit multipart request body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			body, contentType := buildMultipartRequestBody(t, map[string]string{
				"model":  "gemini-2.0-flash-exp",
				"prompt": "Add sunglasses to the cat",
				"size":   "1024x1024",
			}, map[string][]byte{
				"image": []byte("fake-image-content"),
			})

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/edits"},
				{":method", "POST"},
				{"Content-Type", contentType},
			})

			action := host.CallOnHttpRequestBody(body)
			require.Equal(t, types.ActionContinue, action)

			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)
			bodyStr := string(processedBody)
			require.Contains(t, bodyStr, "inlineData", "Multipart image should be converted to inlineData")
			require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")

			requestHeaders := host.GetRequestHeaders()
			require.True(t, test.HasHeaderWithValue(requestHeaders, "Content-Type", "application/json"), "Content-Type should be rewritten to application/json")
		})

		t.Run("vertex express mode image variation multipart request body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			body, contentType := buildMultipartRequestBody(t, map[string]string{
				"model": "gemini-2.0-flash-exp",
				"size":  "1024x1024",
			}, map[string][]byte{
				"image": []byte("fake-image-content"),
			})

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/variations"},
				{":method", "POST"},
				{"Content-Type", contentType},
			})

			action := host.CallOnHttpRequestBody(body)
			require.Equal(t, types.ActionContinue, action)

			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)
			bodyStr := string(processedBody)
			require.Contains(t, bodyStr, "inlineData", "Multipart image should be converted to inlineData")
			require.Contains(t, bodyStr, "Create variations of the provided image.", "Variation request should inject a default prompt")

			requestHeaders := host.GetRequestHeaders()
			require.True(t, test.HasHeaderWithValue(requestHeaders, "Content-Type", "application/json"), "Content-Type should be rewritten to application/json")
		})

		t.Run("vertex express mode image edit with model mapping", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/edits"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{"model":"gpt-4","prompt":"Turn it into watercolor","image_url":{"url":"` + testDataURL + `"}}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			requestHeaders := host.GetRequestHeaders()
			pathHeader := ""
			for _, header := range requestHeaders {
				if header[0] == ":path" {
					pathHeader = header[1]
					break
				}
			}
			require.Contains(t, pathHeader, "gemini-2.5-flash", "Path should contain mapped model name")
		})

		t.Run("vertex express mode image variation request body with image_url", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/variations"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{"model":"gemini-2.0-flash-exp","image_url":{"url":"` + testDataURL + `"}}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)

			bodyStr := string(processedBody)
			require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image_url")
			require.Contains(t, bodyStr, "Create variations of the provided image.", "Variation request should inject a default prompt")

			requestHeaders := host.GetRequestHeaders()
			pathHeader := ""
			for _, header := range requestHeaders {
				if header[0] == ":path" {
					pathHeader = header[1]
					break
				}
			}
			require.Contains(t, pathHeader, "generateContent", "Image variation should use generateContent action")
			require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key")
		})
	})
}

func RunVertexExpressModeImageEditVariationResponseBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		const testDataURL = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="

		t.Run("vertex express mode image edit response body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/edits"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add glasses","image_url":{"url":"` + testDataURL + `"}}`
			host.CallOnHttpRequestBody([]byte(requestBody))

			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
			host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"Content-Type", "application/json"},
			})

			responseBody := `{
				"candidates": [{
					"content": {
						"role": "model",
						"parts": [{
							"inlineData": {
								"mimeType": "image/png",
								"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
							}
						}]
					}
				}],
				"usageMetadata": {
					"promptTokenCount": 12,
					"candidatesTokenCount": 1024,
					"totalTokenCount": 1036
				}
			}`
			action := host.CallOnHttpResponseBody([]byte(responseBody))
			require.Equal(t, types.ActionContinue, action)

			processedResponseBody := host.GetResponseBody()
			require.NotNil(t, processedResponseBody)

			responseStr := string(processedResponseBody)
			require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field")
			require.Contains(t, responseStr, "usage", "Response should contain usage field")
		})

		t.Run("vertex express mode image variation response body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/variations"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			requestBody := `{"model":"gemini-2.0-flash-exp","image_url":{"url":"` + testDataURL + `"}}`
			host.CallOnHttpRequestBody([]byte(requestBody))

			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
			host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"Content-Type", "application/json"},
			})

			responseBody := `{
				"candidates": [{
					"content": {
						"role": "model",
						"parts": [{
							"inlineData": {
								"mimeType": "image/png",
								"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
							}
						}]
					}
				}],
				"usageMetadata": {
					"promptTokenCount": 8,
					"candidatesTokenCount": 768,
					"totalTokenCount": 776
				}
			}`
			action := host.CallOnHttpResponseBody([]byte(responseBody))
			require.Equal(t, types.ActionContinue, action)

			processedResponseBody := host.GetResponseBody()
			require.NotNil(t, processedResponseBody)

			responseStr := string(processedResponseBody)
			require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field")
			require.Contains(t, responseStr, "usage", "Response should contain usage field")
		})
	})
}

// ==================== Vertex Raw mode tests ====================

func RunVertexRawModeOnHttpRequestHeadersTests(t *testing.T) {

plugins/wasm-go/extensions/ai-proxy/test/zhipuai.go (new file, 138 lines)
@@ -0,0 +1,138 @@
package test

import (
	"encoding/json"
	"testing"

	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
	"github.com/higress-group/wasm-go/pkg/test"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

var basicZhipuAIConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type":      "zhipuai",
			"apiTokens": []string{"sk-zhipuai-test"},
		},
	})
	return data
}()

func RunZhipuAIClaudeAutoConversionTests(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		t.Run("claude thinking enabled sets thinking enabled for zhipuai", func(t *testing.T) {
			host, status := test.NewTestHost(basicZhipuAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/messages"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)

			requestBody := `{
				"model": "glm-4",
				"max_tokens": 1000,
				"messages": [{"role": "user", "content": "Hello"}],
				"thinking": {"type": "enabled", "budget_tokens": 8192}
			}`

			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			transformedBody := host.GetRequestBody()
			require.NotNil(t, transformedBody)

			var bodyMap map[string]interface{}
			err := json.Unmarshal(transformedBody, &bodyMap)
			require.NoError(t, err)

			// ZhipuAI should have thinking=enabled (converted from reasoning_effort)
			thinking, ok := bodyMap["thinking"].(map[string]interface{})
			require.True(t, ok, "thinking field should be present")
			assert.Equal(t, "enabled", thinking["type"])

			// reasoning_effort should be removed (ZhipuAI doesn't recognize it)
			assert.NotContains(t, bodyMap, "reasoning_effort")
		})

		t.Run("claude without thinking sets thinking disabled for zhipuai", func(t *testing.T) {
			host, status := test.NewTestHost(basicZhipuAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/messages"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)

			requestBody := `{
				"model": "glm-4",
				"max_tokens": 1000,
				"messages": [{"role": "user", "content": "Hello"}]
			}`

			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			transformedBody := host.GetRequestBody()
			require.NotNil(t, transformedBody)

			var bodyMap map[string]interface{}
			err := json.Unmarshal(transformedBody, &bodyMap)
			require.NoError(t, err)

			// ZhipuAI should explicitly set thinking=disabled
			thinking, ok := bodyMap["thinking"].(map[string]interface{})
			require.True(t, ok, "thinking field should be present for disabled state")
			assert.Equal(t, "disabled", thinking["type"])
		})

		t.Run("claude thinking disabled sets thinking disabled for zhipuai", func(t *testing.T) {
			host, status := test.NewTestHost(basicZhipuAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/messages"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)

			requestBody := `{
				"model": "glm-4",
				"max_tokens": 1000,
				"messages": [{"role": "user", "content": "Hello"}],
				"thinking": {"type": "disabled"}
			}`

			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)

			transformedBody := host.GetRequestBody()
			require.NotNil(t, transformedBody)

			var bodyMap map[string]interface{}
			err := json.Unmarshal(transformedBody, &bodyMap)
			require.NoError(t, err)

			// ZhipuAI should explicitly set thinking=disabled
			thinking, ok := bodyMap["thinking"].(map[string]interface{})
			require.True(t, ok, "thinking field should be present for disabled state")
			assert.Equal(t, "disabled", thinking["type"])

			// No reasoning fields
			assert.NotContains(t, bodyMap, "reasoning_effort")
		})
	})
}
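The three cases above all assert the same transformation: a Claude-style request gets an explicit `thinking.type` of `"enabled"` or `"disabled"` for ZhipuAI, and `reasoning_effort` is dropped. The sketch below illustrates that rule; `convertThinking` is a hypothetical helper written for this summary, not the plugin's actual implementation:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// convertThinking sketches what the tests above verify: keep thinking
// enabled only when the incoming Claude request explicitly enables it,
// otherwise set it to disabled, and always drop reasoning_effort.
func convertThinking(body map[string]interface{}) {
	enabled := false
	if t, ok := body["thinking"].(map[string]interface{}); ok {
		enabled = t["type"] == "enabled"
	}
	if enabled {
		body["thinking"] = map[string]interface{}{"type": "enabled"}
	} else {
		body["thinking"] = map[string]interface{}{"type": "disabled"}
	}
	delete(body, "reasoning_effort")
}

func main() {
	var body map[string]interface{}
	_ = json.Unmarshal([]byte(`{"model":"glm-4","thinking":{"type":"enabled","budget_tokens":8192}}`), &body)
	convertThinking(body)
	out, _ := json.Marshal(body["thinking"])
	fmt.Println(string(out)) // {"type":"enabled"}
}
```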
@@ -41,15 +41,43 @@ description: Alibaba Cloud content safety detection
| `consumerResponseCheckService` | map | optional | - | Specify a dedicated response-check service for individual consumers |
| `consumerRiskLevel` | map | optional | - | Specify per-dimension blocking risk levels for individual consumers |

A note on `denyMessage`. Illegal requests are handled as follows:
- If `denyMessage` is configured, its content is returned, as an OpenAI-format streaming/non-streaming response
- If `denyMessage` is not configured, the answer suggested by Alibaba Cloud content safety is returned first, as an OpenAI-format streaming/non-streaming response
- If Alibaba Cloud content safety returns no suggested answer, the built-in fallback answer `"很抱歉,我无法回答您的问题"` ("Sorry, I cannot answer your question") is returned, as an OpenAI-format streaming/non-streaming response

### Deny Response Structure

If the client uses a non-OpenAI protocol, illegal requests are handled as follows:
- If `denyMessage` is configured, its content is returned as a non-streaming response
- If `denyMessage` is not configured, the answer suggested by Alibaba Cloud content safety is returned first, as a non-streaming response
- If Alibaba Cloud content safety returns no suggested answer, the built-in fallback answer `"很抱歉,我无法回答您的问题"` is returned as a non-streaming response

When content is blocked, the plugin (`MultiModalGuard` action) uniformly returns the following structured JSON object; where it is embedded in each protocol is listed below:

```json
{
  "blockedDetails": [
    {
      "Type": "contentModeration",
      "Level": "high",
      "Suggestion": "block"
    }
  ],
  "requestId": "AAAAAA-BBBB-CCCC-DDDD-EEEEEEE****",
  "guardCode": 200
}
```

Field descriptions:

| Field | Type | Description |
| --- | --- | --- |
| `blockedDetails` | array | Details of the blocked dimensions; synthesised from top-level risk signals when the security service returns no detail entries |
| `blockedDetails[].Type` | string | Risk type: `contentModeration` / `promptAttack` / `sensitiveData` / `maliciousUrl` / `modelHallucination` |
| `blockedDetails[].Level` | string | Risk level: `high` / `medium` / `low`, etc. |
| `blockedDetails[].Suggestion` | string | Action suggested by the security service, usually `block` |
| `requestId` | string | Request ID from the security service, for tracing |
| `guardCode` | int | Business code returned by the security service (not an HTTP status code; `200` on a successful check) |

Per-protocol embedding locations:

- **`text_generation` (OpenAI non-streaming)**: the structure above is serialised as a JSON string and placed in `choices[0].message.content`
- **`text_generation` (OpenAI streaming SSE)**: same, placed in `delta.content` of the first chunk
- **`text_generation` (`protocol=original`)**: the structure above is returned directly as the JSON response body
- **`image_generation`**: the structure above is returned directly as the JSON response body (HTTP 403)
- **`mcp` (JSON-RPC)**: the structure above is serialised as a JSON string and placed in `error.message`
- **`mcp` (SSE)**: same, returned via an SSE event

A note on the four risk levels for the three risk types (content moderation, prompt-attack detection, and sensitive-content detection):

@@ -41,6 +41,43 @@ Plugin Priority: `300`
| `consumerResponseCheckService` | map | optional | - | Specify specific response detection services for different consumers |
| `consumerRiskLevel` | map | optional | - | Specify interception risk levels for different consumers in different dimensions |

### Deny Response Body

When content is blocked, the plugin (`MultiModalGuard` action) returns the following structured JSON object. The location in the response depends on the protocol:

```json
{
  "blockedDetails": [
    {
      "Type": "contentModeration",
      "Level": "high",
      "Suggestion": "block"
    }
  ],
  "requestId": "AAAAAA-BBBB-CCCC-DDDD-EEEEEEE****",
  "guardCode": 200
}
```

Field descriptions:

| Field | Type | Description |
| --- | --- | --- |
| `blockedDetails` | array | Details of the triggered blocking dimensions. Synthesised from top-level risk signals when the security service returns no detail entries. |
| `blockedDetails[].Type` | string | Risk type: `contentModeration` / `promptAttack` / `sensitiveData` / `maliciousUrl` / `modelHallucination` |
| `blockedDetails[].Level` | string | Risk level: `high` / `medium` / `low` etc. |
| `blockedDetails[].Suggestion` | string | Action recommended by the security service, usually `block` |
| `requestId` | string | Request ID from the security service, for tracing |
| `guardCode` | int | Business code returned by the security service (not an HTTP status code; `200` indicates a successful check that detected a risk) |

How the body is embedded per protocol:

- **`text_generation` (OpenAI non-streaming)**: serialised as a JSON string and placed in `choices[0].message.content`
- **`text_generation` (OpenAI streaming SSE)**: same, placed in `delta.content` of the first chunk
- **`text_generation` (`protocol=original`)**: returned directly as the JSON response body
- **`image_generation`**: returned directly as the JSON response body (HTTP 403)
- **`mcp` (JSON-RPC)**: serialised as a JSON string and placed in `error.message`
- **`mcp` (SSE)**: same, returned via SSE event

## Examples of configuration
|
||||
### Check if the input is legal
|
||||
|
||||
@@ -1,6 +1,7 @@
package config

import (
	"encoding/json"
	"errors"
	"fmt"
	"regexp"
@@ -584,3 +585,68 @@ func IsRiskLevelAcceptable(action string, data Data, config AISecurityConfig, co
		return LevelToInt(data.RiskLevel) < LevelToInt(config.GetRiskLevelBar(consumer))
	}
}

type DenyResponseBody struct {
	BlockedDetails []Detail `json:"blockedDetails"`
	RequestId      string   `json:"requestId"`
	// GuardCode is the business code returned by the security service (typically 200 when the check
	// succeeded and a risk was detected). It is NOT an HTTP status code.
	GuardCode int `json:"guardCode"`
}

func BuildDenyResponseBody(response Response, config AISecurityConfig, consumer string) ([]byte, error) {
	body := DenyResponseBody{
		BlockedDetails: GetUnacceptableDetail(response.Data, config, consumer),
		RequestId:      response.RequestId,
		GuardCode:      response.Code,
	}
	return json.Marshal(body)
}

func GetUnacceptableDetail(data Data, config AISecurityConfig, consumer string) []Detail {
	result := []Detail{}
	for _, detail := range data.Detail {
		switch detail.Type {
		case ContentModerationType:
			if LevelToInt(detail.Level) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) {
				result = append(result, detail)
			}
		case PromptAttackType:
			if LevelToInt(detail.Level) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) {
				result = append(result, detail)
			}
		case SensitiveDataType:
			if LevelToInt(detail.Level) >= LevelToInt(config.GetSensitiveDataLevelBar(consumer)) {
				result = append(result, detail)
			}
		case MaliciousUrlDataType:
			if LevelToInt(detail.Level) >= LevelToInt(config.GetMaliciousUrlLevelBar(consumer)) {
				result = append(result, detail)
			}
		case ModelHallucinationDataType:
			if LevelToInt(detail.Level) >= LevelToInt(config.GetModelHallucinationLevelBar(consumer)) {
				result = append(result, detail)
			}
		}
	}
	// Fallback: when the security service returns a top-level risk signal but no Detail entries,
	// synthesise detail items from RiskLevel/AttackLevel so blockedDetails is never empty on a
	// real block event.
	if len(result) == 0 {
		if LevelToInt(data.RiskLevel) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) {
			result = append(result, Detail{
				Type:       ContentModerationType,
				Level:      data.RiskLevel,
				Suggestion: "block",
			})
		}
		if LevelToInt(data.AttackLevel) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) {
			result = append(result, Detail{
				Type:       PromptAttackType,
				Level:      data.AttackLevel,
				Suggestion: "block",
			})
		}
	}
	return result
}

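The per-type switch above applies the same threshold gate everywhere: a detail is kept only when its level is at or above the configured bar. The standalone sketch below illustrates that gate; `levelToInt` is a hypothetical stand-in for the plugin's `cfg.LevelToInt`, whose real ordering is defined elsewhere in the config package.

```go
package main

import "fmt"

// levelToInt is an assumed ordering of risk levels, standing in for
// the plugin's cfg.LevelToInt.
func levelToInt(level string) int {
	switch level {
	case "high":
		return 3
	case "medium":
		return 2
	case "low":
		return 1
	default:
		return 0
	}
}

type detail struct{ Type, Level, Suggestion string }

// filterDetails keeps only details at or above the configured bar,
// mirroring the comparison used in GetUnacceptableDetail.
func filterDetails(details []detail, bar string) []detail {
	out := []detail{}
	for _, d := range details {
		if levelToInt(d.Level) >= levelToInt(bar) {
			out = append(out, d)
		}
	}
	return out
}

func main() {
	details := []detail{
		{"contentModeration", "low", "pass"},
		{"promptAttack", "high", "block"},
	}
	kept := filterDetails(details, "medium")
	fmt.Println(len(kept), kept[0].Type) // 1 promptAttack
}
```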
@@ -65,13 +65,19 @@ func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config c
		return
	}
	if !cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
		denyMessage := cfg.DefaultDenyMessage
		if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
			denyMessage = "\n" + response.Data.Advice[0].Answer
		} else if config.DenyMessage != "" {
			denyMessage = config.DenyMessage
		denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
		if err != nil {
			log.Errorf("failed to build deny response body: %v", err)
			endStream := ctx.GetContext("end_of_stream_received").(bool) && ctx.BufferQueueSize() == 0
			proxywasm.InjectEncodedDataToFilterChain(bytes.Join(bufferQueue, []byte("")), endStream)
			bufferQueue = [][]byte{}
			if !endStream {
				ctx.SetContext("during_call", false)
				singleCall()
			}
			return
		}
		marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
		marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
		randomID := utils.GenerateRandomChatID()
		jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
		proxywasm.InjectEncodedDataToFilterChain(jsonData, true)
@@ -199,21 +205,22 @@ func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecu
		}
		return
	}
	denyMessage := cfg.DefaultDenyMessage
	if config.DenyMessage != "" {
		denyMessage = config.DenyMessage
	} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
		denyMessage = response.Data.Advice[0].Answer
	denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
	if err != nil {
		log.Errorf("failed to build deny response body: %v", err)
		proxywasm.ResumeHttpResponse()
		return
	}
	marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
	if config.ProtocolOriginal {
		proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
		proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1)
	} else if isStreamingResponse {
		randomID := utils.GenerateRandomChatID()
		marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
		jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
		proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
	} else {
		randomID := utils.GenerateRandomChatID()
		marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
		jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
		proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
	}

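On the streaming path, the deny body ends up inside a single OpenAI-style SSE chunk followed by a `[DONE]` sentinel. The sketch below assumes a chunk template resembling the plugin's `cfg.OpenAIStreamResponseFormat`; the exact template is defined elsewhere in the plugin, so both the template string and `buildDenySSE` are assumptions for illustration:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// buildDenySSE wraps a deny payload into one OpenAI-style SSE chunk
// plus the [DONE] terminator. The chunk shape is an assumption modelled
// on the OpenAI chat.completion.chunk format.
func buildDenySSE(id, denyBody string) string {
	// JSON-escape the deny payload so it can sit inside delta.content.
	content, _ := json.Marshal(denyBody)
	chunk := fmt.Sprintf(`{"id":"%s","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":%s},"finish_reason":"stop"}]}`, id, content)
	return "data: " + chunk + "\n\ndata: [DONE]\n\n"
}

func main() {
	out := buildDenySSE("chatcmpl-test", `{"guardCode":200}`)
	fmt.Print(out)
}
```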
@@ -85,14 +85,13 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.
		}
		return
	}
	denyMessage := cfg.DefaultDenyMessage
	if config.DenyMessage != "" {
		denyMessage = config.DenyMessage
	} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
		denyMessage = response.Data.Advice[0].Answer
	denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
	if err != nil {
		log.Errorf("failed to build deny response body: %v", err)
		proxywasm.ResumeHttpRequest()
		return
	}
	marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
	proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
	proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, denyBody, -1)
	ctx.DontReadResponseBody()
	config.IncrementCounter("ai_sec_request_deny", 1)
	endTime := time.Now().UnixMilli()
@@ -157,14 +156,13 @@ func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.
		return
	}

	denyMessage := cfg.DefaultDenyMessage
	if config.DenyMessage != "" {
		denyMessage = config.DenyMessage
	} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
		denyMessage = response.Data.Advice[0].Answer
	denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
	if err != nil {
		log.Errorf("failed to build deny response body: %v", err)
		proxywasm.ResumeHttpRequest()
		return
	}
	marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
	proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
	proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, denyBody, -1)
	ctx.DontReadResponseBody()
	config.IncrementCounter("ai_sec_request_deny", 1)
	ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
@@ -244,7 +242,13 @@ func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg
	}
	return
	}
	proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte("illegal image"), -1)
	denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
	if err != nil {
		log.Errorf("failed to build deny response body: %v", err)
		proxywasm.ResumeHttpResponse()
		return
	}
	proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, denyBody, -1)
	config.IncrementCounter("ai_sec_request_deny", 1)
	ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
	ctx.SetUserAttribute("safecheck_status", "reqeust deny")

@@ -243,14 +243,13 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI
|
||||
}
|
||||
return
|
||||
}
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
|
||||
if err != nil {
|
||||
log.Errorf("failed to build deny response body: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, denyBody, -1)
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
@@ -315,14 +314,13 @@ func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AI
|
||||
return
|
||||
}
|
||||
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
|
||||
if err != nil {
|
||||
log.Errorf("failed to build deny response body: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, denyBody, -1)
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
@@ -402,14 +400,13 @@ func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.A
|
||||
}
|
||||
return
|
||||
}
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
|
||||
if err != nil {
|
||||
log.Errorf("failed to build deny response body: %v", err)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
return
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, denyBody, -1)
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
|
||||
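Every hunk above applies the same refactor: the inline denyMessage fallback (operator-configured message first, then the moderation service's advice, then a default) is replaced by a structured body from cfg.BuildDenyResponseBody. A minimal, self-contained sketch of that fallback precedence — the function name here is illustrative, not the plugin's actual API:

```go
package main

import "fmt"

// pickDenyMessage mirrors the precedence the old inline code used:
// an operator-configured message wins, then the moderation service's
// advice, then a default. (Illustrative stand-in, not the plugin API.)
func pickDenyMessage(configured, advice, fallback string) string {
	if configured != "" {
		return configured
	}
	if advice != "" {
		return advice
	}
	return fallback
}

func main() {
	fmt.Println(pickDenyMessage("", "content blocked by policy", "request denied"))
	fmt.Println(pickDenyMessage("custom", "content blocked by policy", "request denied"))
}
```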
@@ -2,6 +2,7 @@ package mcp

import (
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
@@ -18,9 +19,9 @@ import (

const (
MethodToolCall = "tools/call"
DenyResponse = `{"jsonrpc":"2.0","id":0,"error":{"code":403,"message":"blocked by security guard"}}`
DenyResponse = `{"jsonrpc":"2.0","id":0,"error":{"code":403,"message":"%s"}}`
DenySSEResponse = `event: message
data: {"jsonrpc":"2.0","id":0,"error":{"code":403,"message":"blocked by security guard"}}
data: {"jsonrpc":"2.0","id":0,"error":{"code":403,"message":"%s"}}

`
)
@@ -78,7 +79,15 @@ func HandleMcpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig,
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(DenyResponse), -1)
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
proxywasm.ResumeHttpRequest()
return
}
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
denyResponse := fmt.Sprintf(DenyResponse, marshalledDenyMessage)
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(denyResponse), -1)
}
singleCall = func() {
var nextContentIndex int
@@ -124,7 +133,15 @@ func HandleMcpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecuri
return
}
if !cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
proxywasm.InjectEncodedDataToFilterChain([]byte(DenySSEResponse), true)
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
proxywasm.InjectEncodedDataToFilterChain(frontBuffer, false)
return
}
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
denySSEResponse := fmt.Sprintf(DenySSEResponse, marshalledDenyMessage)
proxywasm.InjectEncodedDataToFilterChain([]byte(denySSEResponse), true)
} else {
proxywasm.InjectEncodedDataToFilterChain(frontBuffer, false)
}
@@ -212,8 +229,16 @@ func HandleMcpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig,
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
proxywasm.ResumeHttpResponse()
return
}
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
denyResponseBody := fmt.Sprintf(DenyResponse, marshalledDenyMessage)
proxywasm.RemoveHttpResponseHeader("content-length")
proxywasm.ReplaceHttpResponseBody([]byte(DenyResponse))
proxywasm.ReplaceHttpResponseBody([]byte(denyResponseBody))
proxywasm.ResumeHttpResponse()
// proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(DenyResponse), -1)
}

@@ -96,21 +96,22 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
}
return
}
denyMessage := cfg.DefaultDenyMessage
if config.DenyMessage != "" {
denyMessage = config.DenyMessage
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
proxywasm.ResumeHttpRequest()
return
}
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
@@ -178,21 +179,22 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
return
}

denyMessage := cfg.DefaultDenyMessage
if config.DenyMessage != "" {
denyMessage = config.DenyMessage
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
proxywasm.ResumeHttpRequest()
return
}
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}

@@ -53,21 +53,22 @@ func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecur
}
return
}
denyMessage := cfg.DefaultDenyMessage
if config.DenyMessage != "" {
denyMessage = config.DenyMessage
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
denyBody, err := cfg.BuildDenyResponseBody(response, config, consumer)
if err != nil {
log.Errorf("failed to build deny response body: %v", err)
proxywasm.ResumeHttpRequest()
return
}
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
if config.ProtocolOriginal {
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, denyBody, -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := utils.GenerateRandomChatID()
marshalledDenyMessage := wrapper.MarshalStr(string(denyBody))
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}

@@ -156,6 +156,90 @@ var mcpConfig = func() json.RawMessage {
return data
}()

// Test config: MCP config (request check enabled)
var mcpRequestConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkRequest": true,
"checkResponse": false,
"action": "MultiModalGuard",
"apiType": "mcp",
"requestContentJsonPath": "params.arguments",
"contentModerationLevelBar": "high",
"promptAttackLevelBar": "high",
"sensitiveDataLevelBar": "S3",
"timeout": 2000,
})
return data
}()

// Test config: MultiModalGuard text generation
var multiModalGuardTextConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkRequest": true,
"checkResponse": true,
"action": "MultiModalGuard",
"apiType": "text_generation",
"contentModerationLevelBar": "high",
"promptAttackLevelBar": "high",
"sensitiveDataLevelBar": "S3",
"timeout": 2000,
"bufferLimit": 1000,
})
return data
}()

// Test config: MultiModalGuard OpenAI image generation
var multiModalGuardImageOpenAIConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkRequest": true,
"checkResponse": true,
"action": "MultiModalGuard",
"apiType": "image_generation",
"providerType": "openai",
"contentModerationLevelBar": "high",
"promptAttackLevelBar": "high",
"sensitiveDataLevelBar": "S3",
"timeout": 2000,
})
return data
}()

// Test config: MultiModalGuard Qwen image generation
var multiModalGuardImageQwenConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "security-service",
"servicePort": 8080,
"serviceHost": "security.example.com",
"accessKey": "test-ak",
"secretKey": "test-sk",
"checkRequest": true,
"checkResponse": true,
"action": "MultiModalGuard",
"apiType": "image_generation",
"providerType": "qwen",
"contentModerationLevelBar": "high",
"promptAttackLevelBar": "high",
"sensitiveDataLevelBar": "S3",
"timeout": 2000,
})
return data
}()

func TestParseConfig(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// Test basic config parsing
@@ -330,6 +414,51 @@ func TestOnHttpRequestBody(t *testing.T) {
// Empty content should pass through directly
require.Equal(t, types.ActionContinue, action)
})

// TextModerationPlus (the default action, including the agent/OpenAI shape): a request deny should return the blockedDetails JSON inside choices[0].message.content
t.Run("text moderation plus request deny returns blockedDetails in openai completion shape", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)

host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})

body := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-tmp-deny", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))

local := host.GetLocalResponse()
require.NotNil(t, local, "expected SendHttpResponse for request deny")
require.Contains(t, string(local.Data), "blockedDetails")
require.Contains(t, string(local.Data), "req-tmp-deny")

type openAIChatCompletion struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
var outer openAIChatCompletion
require.NoError(t, json.Unmarshal(local.Data, &outer))
require.Len(t, outer.Choices, 1)

var deny cfg.DenyResponseBody
require.NoError(t, json.Unmarshal([]byte(outer.Choices[0].Message.Content), &deny))
require.Equal(t, "req-tmp-deny", deny.RequestId)
require.Equal(t, 200, deny.GuardCode)
require.NotEmpty(t, deny.BlockedDetails)
require.Equal(t, cfg.ContentModerationType, deny.BlockedDetails[0].Type)
})
})
}

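The blockedDetails assertions in these tests depend on level-bar filtering: a detail from the moderation response is reported only when its level meets or exceeds the configured bar. A standalone sketch of that rule; the level ordering and names here are assumptions for illustration, not the cfg package's actual definitions:

```go
package main

import "fmt"

// Risk levels ordered from lowest to highest; a detail counts as
// "blocked" when its level meets or exceeds the configured bar.
// Hypothetical ordering — the plugin's cfg package defines the real one.
var levelRank = map[string]int{"none": 0, "low": 1, "medium": 2, "high": 3, "max": 4}

type detail struct{ Type, Level string }

// filterBlocked keeps only the details at or above the bar.
func filterBlocked(details []detail, bar string) []detail {
	var out []detail
	for _, d := range details {
		if levelRank[d.Level] >= levelRank[bar] {
			out = append(out, d)
		}
	}
	return out
}

func main() {
	got := filterBlocked([]detail{
		{"contentModeration", "high"},
		{"promptAttack", "low"},
	}, "high")
	fmt.Println(len(got), got[0].Type) // 1 contentModeration
}
```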
@@ -649,3 +778,444 @@ func TestUtilityFunctions(t *testing.T) {
require.Len(t, id, 38) // "chatcmpl-" + 29 random chars
})
}

func TestMultiModalGuardTextGenerationDeny(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// MultiModalGuard text_generation request deny → exercises multi_modal_guard/text/openai.go BuildDenyResponseBody path
t.Run("multi modal guard text request deny returns blockedDetails", func(t *testing.T) {
host, status := test.NewTestHost(multiModalGuardTextConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)

host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})

body := `{"messages": [{"role": "user", "content": "trigger deny"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-mmg-text-deny", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))

local := host.GetLocalResponse()
require.NotNil(t, local, "expected SendHttpResponse for request deny")
require.Contains(t, string(local.Data), "blockedDetails")
require.Contains(t, string(local.Data), "req-mmg-text-deny")
})

// MultiModalGuard text_generation response deny → exercises common/text/openai.go HandleTextGenerationResponseBody BuildDenyResponseBody path
t.Run("multi modal guard text response deny returns blockedDetails", func(t *testing.T) {
host, status := test.NewTestHost(multiModalGuardTextConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)

host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})

host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "application/json"},
})

body := `{"choices": [{"message": {"role": "assistant", "content": "bad response content"}}]}`
action := host.CallOnHttpResponseBody([]byte(body))
require.Equal(t, types.ActionPause, action)

securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-mmg-resp-deny", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))

local := host.GetLocalResponse()
require.NotNil(t, local, "expected SendHttpResponse for response deny")
require.Contains(t, string(local.Data), "blockedDetails")
require.Contains(t, string(local.Data), "req-mmg-resp-deny")
})

// MultiModalGuard text_generation request pass
t.Run("multi modal guard text request pass", func(t *testing.T) {
host, status := test.NewTestHost(multiModalGuardTextConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)

host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})

body := `{"messages": [{"role": "user", "content": "Hello"}]}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-mmg-pass", "Data": {"RiskLevel": "low"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))

action := host.GetHttpStreamAction()
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})
})
}

func TestMultiModalGuardImageGenerationDeny(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// OpenAI image generation request deny → exercises multi_modal_guard/image/openai.go BuildDenyResponseBody path
t.Run("openai image request deny returns blockedDetails", func(t *testing.T) {
host, status := test.NewTestHost(multiModalGuardImageOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)

host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/generations"},
{":method", "POST"},
})

body := `{"prompt": "generate bad image"}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-img-openai-deny", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))

local := host.GetLocalResponse()
require.NotNil(t, local, "expected SendHttpResponse for OpenAI image request deny")
require.Contains(t, string(local.Data), "blockedDetails")
require.Contains(t, string(local.Data), "req-img-openai-deny")
})

// OpenAI image generation request pass
t.Run("openai image request pass", func(t *testing.T) {
host, status := test.NewTestHost(multiModalGuardImageOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)

host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/generations"},
{":method", "POST"},
})

body := `{"prompt": "a cute cat"}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-img-pass", "Data": {"RiskLevel": "low"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))

action := host.GetHttpStreamAction()
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})

// Qwen image generation request deny → exercises multi_modal_guard/image/qwen.go BuildDenyResponseBody path
t.Run("qwen image request deny returns blockedDetails", func(t *testing.T) {
host, status := test.NewTestHost(multiModalGuardImageQwenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)

host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/generations"},
{":method", "POST"},
})

body := `{"input": {"prompt": "generate bad image"}}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-img-qwen-deny", "Data": {"RiskLevel": "high"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))

local := host.GetLocalResponse()
require.NotNil(t, local, "expected SendHttpResponse for Qwen image request deny")
require.Contains(t, string(local.Data), "blockedDetails")
require.Contains(t, string(local.Data), "req-img-qwen-deny")
})

// Qwen image generation request pass
t.Run("qwen image request pass", func(t *testing.T) {
host, status := test.NewTestHost(multiModalGuardImageQwenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)

host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/generations"},
{":method", "POST"},
})

body := `{"input": {"prompt": "a cute cat"}}`
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))

securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-qwen-pass", "Data": {"RiskLevel": "low"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))

action := host.GetHttpStreamAction()
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})
})
}

func TestMCPRequestDeny(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// MCP request deny → exercises multi_modal_guard/mcp/mcp.go HandleMcpRequestBody BuildDenyResponseBody path
|
||||
t.Run("mcp request deny returns blockedDetails", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(mcpRequestConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/mcp/call"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
body := `{"method": "tools/call", "params": {"arguments": "bad request content"}}`
|
||||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||||
|
||||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-mcp-deny", "Data": {"RiskLevel": "high"}}`
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
}, []byte(securityResponse))
|
||||
|
||||
local := host.GetLocalResponse()
|
||||
require.NotNil(t, local, "expected SendHttpResponse for MCP request deny")
|
||||
require.Contains(t, string(local.Data), "blockedDetails")
|
||||
require.Contains(t, string(local.Data), "req-mcp-deny")
|
||||
})
|
||||
|
||||
// MCP request pass
|
||||
t.Run("mcp request pass", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(mcpRequestConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/mcp/call"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
body := `{"method": "tools/call", "params": {"arguments": "safe content"}}`
|
||||
require.Equal(t, types.ActionPause, host.CallOnHttpRequestBody([]byte(body)))
|
||||
|
||||
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-mcp-pass", "Data": {"RiskLevel": "low"}}`
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
}, []byte(securityResponse))
|
||||
|
||||
action := host.GetHttpStreamAction()
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// MCP request skip non-tool-call method
|
||||
t.Run("mcp request skip non-tool-call", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(mcpRequestConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/mcp/call"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
body := `{"method": "resources/list", "params": {}}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestTextModerationPlusResponseDeny(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// TextModerationPlus response deny → exercises text_moderation_plus/text (via common/text) BuildDenyResponseBody response path
|
||||
		t.Run("text moderation plus response deny returns blockedDetails", func(t *testing.T) {
			host, status := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
			})

			host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			})

			body := `{"choices": [{"message": {"role": "assistant", "content": "bad response"}}]}`
			action := host.CallOnHttpResponseBody([]byte(body))
			require.Equal(t, types.ActionPause, action)

			securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-tmp-resp-deny", "Data": {"RiskLevel": "high"}}`
			host.CallOnHttpCall([][2]string{
				{":status", "200"},
				{"content-type", "application/json"},
			}, []byte(securityResponse))

			local := host.GetLocalResponse()
			require.NotNil(t, local, "expected SendHttpResponse for response deny")
			require.Contains(t, string(local.Data), "blockedDetails")
			require.Contains(t, string(local.Data), "req-tmp-resp-deny")

			// Verify OpenAI completion shape wrapper
			type openAIChatCompletion struct {
				Choices []struct {
					Message struct {
						Content string `json:"content"`
					} `json:"message"`
				} `json:"choices"`
			}
			var outer openAIChatCompletion
			require.NoError(t, json.Unmarshal(local.Data, &outer))
			require.Len(t, outer.Choices, 1)

			var deny cfg.DenyResponseBody
			require.NoError(t, json.Unmarshal([]byte(outer.Choices[0].Message.Content), &deny))
			require.Equal(t, "req-tmp-resp-deny", deny.RequestId)
			require.Equal(t, 200, deny.GuardCode)
			require.NotEmpty(t, deny.BlockedDetails)
		})
	})
}

func TestBuildDenyResponseBody(t *testing.T) {
	makeConfig := func(contentBar, promptBar string) cfg.AISecurityConfig {
		return cfg.AISecurityConfig{
			ContentModerationLevelBar:  contentBar,
			PromptAttackLevelBar:       promptBar,
			SensitiveDataLevelBar:      "S4",
			MaliciousUrlLevelBar:       "max",
			ModelHallucinationLevelBar: "max",
			Action:                     cfg.MultiModalGuard,
		}
	}

	t.Run("guardCode equals response.Code", func(t *testing.T) {
		resp := cfg.Response{
			Code:      200,
			RequestId: "req-123",
			Data:      cfg.Data{},
		}
		body, err := cfg.BuildDenyResponseBody(resp, makeConfig("high", "high"), "")
		require.NoError(t, err)

		var result cfg.DenyResponseBody
		require.NoError(t, json.Unmarshal(body, &result))
		require.Equal(t, 200, result.GuardCode)
		require.Equal(t, "req-123", result.RequestId)
	})

	t.Run("blockedDetails from Data.Detail", func(t *testing.T) {
		resp := cfg.Response{
			Code:      200,
			RequestId: "req-456",
			Data: cfg.Data{
				Detail: []cfg.Detail{
					{Type: cfg.ContentModerationType, Level: "high", Suggestion: "block"},
					{Type: cfg.PromptAttackType, Level: "low", Suggestion: "block"},
				},
			},
		}
		config := makeConfig("high", "high")
		body, err := cfg.BuildDenyResponseBody(resp, config, "")
		require.NoError(t, err)

		var result cfg.DenyResponseBody
		require.NoError(t, json.Unmarshal(body, &result))
		// only the contentModeration entry meets the "high" bar; promptAttack at "low" does not
		require.Len(t, result.BlockedDetails, 1)
		require.Equal(t, cfg.ContentModerationType, result.BlockedDetails[0].Type)
		require.Equal(t, "high", result.BlockedDetails[0].Level)
	})

	t.Run("blockedDetails fallback from RiskLevel when Detail is empty", func(t *testing.T) {
		resp := cfg.Response{
			Code:      200,
			RequestId: "req-789",
			Data: cfg.Data{
				RiskLevel: "high",
				// Detail deliberately empty
			},
		}
		config := makeConfig("high", "high")
		body, err := cfg.BuildDenyResponseBody(resp, config, "")
		require.NoError(t, err)

		var result cfg.DenyResponseBody
		require.NoError(t, json.Unmarshal(body, &result))
		require.NotEmpty(t, result.BlockedDetails, "expected fallback detail from RiskLevel")
		require.Equal(t, cfg.ContentModerationType, result.BlockedDetails[0].Type)
		require.Equal(t, "high", result.BlockedDetails[0].Level)
		require.Equal(t, "block", result.BlockedDetails[0].Suggestion)
	})

	t.Run("blockedDetails fallback from AttackLevel when Detail is empty", func(t *testing.T) {
		resp := cfg.Response{
			Code:      200,
			RequestId: "req-abc",
			Data: cfg.Data{
				AttackLevel: "high",
				// Detail deliberately empty
			},
		}
		config := makeConfig("high", "high")
		body, err := cfg.BuildDenyResponseBody(resp, config, "")
		require.NoError(t, err)

		var result cfg.DenyResponseBody
		require.NoError(t, json.Unmarshal(body, &result))
		require.NotEmpty(t, result.BlockedDetails, "expected fallback detail from AttackLevel")
		require.Equal(t, cfg.PromptAttackType, result.BlockedDetails[0].Type)
		require.Equal(t, "high", result.BlockedDetails[0].Level)
		require.Equal(t, "block", result.BlockedDetails[0].Suggestion)
	})

	t.Run("blockedDetails empty when risk levels below threshold", func(t *testing.T) {
		resp := cfg.Response{
			Code:      200,
			RequestId: "req-def",
			Data: cfg.Data{
				RiskLevel:   "low",
				AttackLevel: "low",
			},
		}
		// threshold is "high", so "low" must not produce fallback entries
		config := makeConfig("high", "high")
		body, err := cfg.BuildDenyResponseBody(resp, config, "")
		require.NoError(t, err)

		var result cfg.DenyResponseBody
		require.NoError(t, json.Unmarshal(body, &result))
		require.Empty(t, result.BlockedDetails)
	})
}

@@ -24,6 +24,8 @@ description: AI observability configuration reference

| Name | Type | Required | Default | Description |
|----------------|-------|------|-----|------------------------|
| `use_default_attributes` | bool | No | false | Whether to use the full default attribute set, covering `messages`, `answer`, `question`, and all other fields. Suited to debugging and auditing |
| `use_default_response_attributes` | bool | No | false | Whether to use the lightweight default attribute set (recommended): `model` plus token statistics, without buffering streaming response bodies. Suited to high-concurrency production |
| `attributes` | []Attribute | No | - | Information to record in logs/spans |
| `disable_openai_usage` | bool | No | false | For protocols that are not OpenAI-compatible, `model`/token support is non-standard; set to true to avoid errors |
| `value_length_limit` | int | No | 4000 | Length limit for each recorded value |
@@ -67,6 +69,7 @@ Attribute configuration notes:

| Built-in attribute key | Description | Applicable scenarios |
|---------|------|---------|
| `question` | User question content | Supports OpenAI/Claude message formats |
| `system` | System prompt | Supports the top-level `system` field of Claude `/v1/messages` |
| `answer` | AI answer content | Supports OpenAI/Claude message formats, streaming and non-streaming |
| `tool_calls` | Tool call information | OpenAI/Claude tool calls |
| `reasoning` | Reasoning process | Reasoning models such as OpenAI o1 |
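For illustration only, built-in keys can be referenced directly in the `attributes` list without a `value_source`. The YAML spellings below (`key`, `rule`, `apply_to_log`) are assumptions based on the plugin's attribute fields, not confirmed names from this document:

```yaml
attributes:
  # hypothetical sketch: a built-in key is auto-extracted, no value_source needed
  - key: question
    apply_to_log: true
  # streamed answers would be assembled with the append rule
  - key: answer
    rule: append
    apply_to_log: true
```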
@@ -332,6 +335,195 @@ attributes:
2. **Performance analysis**: analyze the share of reasoning tokens to gauge the real overhead of reasoning models
3. **Usage statistics**: fine-grained accounting of each token category

## Streaming Response Observability

Streaming responses are a common pattern in AI chat. The plugin provides comprehensive streaming observability, correctly assembling and extracting the key information in a streamed response.

### The Challenge of Streaming

A streaming response splits the complete content across multiple SSE chunks, for example:

```
data: {"choices":[{"delta":{"content":"Hello"}}]}
data: {"choices":[{"delta":{"content":" 👋"}}]}
data: {"choices":[{"delta":{"content":"!"}}]}
data: [DONE]
```

To recover the complete answer, the `delta.content` of every chunk must be concatenated in order.
### Automatic Assembly

The plugin assembles each content type automatically:

| Content type | Assembly method | Description |
|---------|---------|------|
| `answer` | Text append | Concatenates each chunk's `delta.content` in order into the full answer |
| `reasoning` | Text append | Concatenates each chunk's `delta.reasoning_content` in order |
| `tool_calls` | Assembled by index | Identifies each tool call's `index` and concatenates its `arguments` separately |

#### answer and reasoning assembly example

Streaming response:
```
data: {"choices":[{"delta":{"content":"你好"}}]}
data: {"choices":[{"delta":{"content":",我是"}}]}
data: {"choices":[{"delta":{"content":"AI助手"}}]}
```

Final extracted `answer`: `"你好,我是AI助手"`

#### tool_calls assembly example

Streaming response (multiple parallel tool calls):
```
data: {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_001","function":{"name":"get_weather"}}]}}]}
data: {"choices":[{"delta":{"tool_calls":[{"index":1,"id":"call_002","function":{"name":"get_time"}}]}}]}
data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"city\":"}}]}}]}
data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"Beijing\"}"}}]}}]}
data: {"choices":[{"delta":{"tool_calls":[{"index":1,"function":{"arguments":"{\"city\":\"Shanghai\"}"}}]}}]}
```

Final extracted `tool_calls`:
```json
[
  {"index":0,"id":"call_001","function":{"name":"get_weather","arguments":"{\"city\":\"Beijing\"}"}},
  {"index":1,"id":"call_002","function":{"name":"get_time","arguments":"{\"city\":\"Shanghai\"}"}}
]
```
### Enabling via the Default Configurations

The plugin ships two default configuration modes:

#### Lightweight mode (recommended for production)

Enable lightweight mode with `use_default_response_attributes: true`:

```yaml
use_default_response_attributes: true
```

This is the **recommended production configuration**, especially for high-concurrency, high-latency workloads:

| Field | Description |
|------|------|
| `model` | Model name (extracted from the request body) |
| `reasoning_tokens` | Number of reasoning tokens |
| `cached_tokens` | Number of cache-hit tokens |
| `input_token_details` | Input token breakdown |
| `output_token_details` | Output token breakdown |

**Why is lightweight mode recommended?**

LLM requests have two defining traits: **high latency** (typically seconds to tens of seconds) and **large request bodies** (multi-turn conversations can reach hundreds of KB or even MB).

Under high concurrency, if both request and response bodies are buffered in memory, the backlog of in-flight requests consumes a large amount of it:
- Assume QPS = 100, average latency = 10 s, request body = 500 KB
- Requests in flight ≈ 100 × 10 = 1000
- Buffering the full request + response bodies: 1000 × 1.5 MB ≈ **1.5 GB of memory**

Lightweight mode cuts memory usage in the following ways:
- **Buffers the request body** only to extract the (small) `model` field, skipping large fields such as `question`, `system`, and `messages`
- **Does not buffer the streaming response body**: skips `answer`, `reasoning`, `tool_calls`, and other fields that require the complete response
- **Counts tokens only**, taken from the usage field of the response

**Memory comparison:**

| Scenario | Full mode | Lightweight mode |
|------|----------|----------|
| Single request (1 MB request + 500 KB response) | ~1.5 MB | ~1 MB (request body) |
| High concurrency (100 QPS, 10 s latency) | ~1.5 GB | ~1 GB |
| Very high concurrency (1000 QPS, 10 s latency) | ~15 GB | ~10 GB |

**Note**: in lightweight mode the `chat_round` field is still computed, and `model` is still extracted from the request body.
#### Full mode

`use_default_attributes: true` enables full streaming observability in one step:

```yaml
use_default_attributes: true
```

This configuration automatically records the fields below, **but buffers the complete request body and streaming response body**:

| Field | Description | Memory impact |
|------|------|----------|
| `messages` | Full conversation history | ⚠️ potentially large |
| `question` | Last user message | requires buffering the request body |
| `system` | System prompt | requires buffering the request body |
| `answer` | AI answer (streamed chunks auto-appended) | ⚠️ requires buffering the response body |
| `reasoning` | Reasoning process (streamed chunks auto-appended) | ⚠️ requires buffering the response body |
| `tool_calls` | Tool calls (auto-assembled by index) | requires buffering the response body |
| `reasoning_tokens` | Number of reasoning tokens | none |
| `cached_tokens` | Number of cache-hit tokens | none |
| `input_token_details` | Input token breakdown | none |
| `output_token_details` | Output token breakdown | none |

**Note**: full mode suits debugging, auditing, and other scenarios that need the complete conversation record, but can consume significant memory in high-concurrency production.
### Streaming log examples

With the default configuration enabled, the log output for a streaming request looks like:

```json
{
  "answer": "2 plus 2 equals 4.",
  "question": "What is 2+2?",
  "response_type": "stream",
  "tool_calls": null,
  "reasoning": null,
  "model": "glm-4-flash",
  "input_token": 10,
  "output_token": 8,
  "llm_first_token_duration": 425,
  "llm_service_duration": 985,
  "chat_id": "chat_abc123"
}
```

A streaming log that includes tool calls:

```json
{
  "answer": null,
  "question": "What's the weather in Beijing?",
  "response_type": "stream",
  "tool_calls": [
    {
      "id": "call_abc123",
      "type": "function",
      "function": {
        "name": "get_weather",
        "arguments": "{\"location\": \"Beijing\"}"
      }
    }
  ],
  "model": "glm-4-flash",
  "input_token": 50,
  "output_token": 15,
  "llm_first_token_duration": 300,
  "llm_service_duration": 500
}
```
### Streaming-specific metrics

Streaming responses additionally record the following metrics:

- `llm_first_token_duration`: time from sending the request to receiving the first token (first-token latency)
- `llm_stream_duration_count`: number of streaming requests

They can be used to monitor the user experience of streaming responses:

```promql
# Average first-token latency
irate(route_upstream_model_consumer_metric_llm_first_token_duration[5m])
/
irate(route_upstream_model_consumer_metric_llm_stream_duration_count[5m])
```

## Debugging

### Verifying the ai_log content

@@ -105,6 +105,7 @@ const (
	BuiltinAnswerKey    = "answer"
	BuiltinToolCallsKey = "tool_calls"
	BuiltinReasoningKey = "reasoning"
	BuiltinSystemKey    = "system"
	BuiltinReasoningTokens   = "reasoning_tokens"
	BuiltinCachedTokens      = "cached_tokens"
	BuiltinInputTokenDetails = "input_token_details"
@@ -115,6 +116,9 @@ const (
	QuestionPathOpenAI = "messages.@reverse.0.content"
	QuestionPathClaude = "messages.@reverse.0.content" // Claude uses same format

	// System prompt paths (from request body)
	SystemPathClaude = "system" // Claude /v1/messages has system as a top-level field

	// Answer paths (from response body - non-streaming)
	AnswerPathOpenAINonStreaming = "choices.0.message.content"
	AnswerPathClaudeNonStreaming = "content.0.text"
@@ -123,10 +127,19 @@ const (
	AnswerPathOpenAIStreaming = "choices.0.delta.content"
	AnswerPathClaudeStreaming = "delta.text"

	// Tool calls paths (OpenAI format)
	ToolCallsPathNonStreaming = "choices.0.message.tool_calls"
	ToolCallsPathStreaming    = "choices.0.delta.tool_calls"

	// Claude/Anthropic tool calls paths (streaming)
	ClaudeEventType         = "type"
	ClaudeContentBlockType  = "content_block.type"
	ClaudeContentBlockID    = "content_block.id"
	ClaudeContentBlockName  = "content_block.name"
	ClaudeContentBlockInput = "content_block.input"
	ClaudeDeltaPartialJSON  = "delta.partial_json"
	ClaudeIndex             = "index"

	// Reasoning paths
	ReasoningPathNonStreaming = "choices.0.message.reasoning_content"
	ReasoningPathStreaming    = "choices.0.delta.reasoning_content"
@@ -136,27 +149,34 @@ const (
)

// getDefaultAttributes returns the default attributes configuration for empty config
// This includes all attributes but may consume significant memory for large conversations
func getDefaultAttributes() []Attribute {
	return []Attribute{
		// Extract complete conversation history from request body
		{
			Key:         "messages",
			ValueSource: RequestBody,
			Value:       "messages",
			ApplyToLog:  true,
		},
		// Built-in attributes (no value_source needed, will be auto-extracted)
		{
			Key:        BuiltinQuestionKey,
			ApplyToLog: true,
		},
		{
			Key:        BuiltinSystemKey,
			ApplyToLog: true,
		},
		{
			Key:        BuiltinAnswerKey,
			ApplyToLog: true,
			Rule:       RuleAppend, // Streaming responses need to append content from all chunks
		},
		{
			Key:        BuiltinReasoningKey,
			ApplyToLog: true,
			Rule:       RuleAppend, // Streaming responses need to append content from all chunks
		},
		{
			Key: BuiltinToolCallsKey,
@@ -183,6 +203,34 @@ func getDefaultAttributes() []Attribute {
	}
}

// getDefaultResponseAttributes returns a lightweight default attributes configuration
// for production environments with high concurrency and high latency.
// - Buffers request body for model extraction (small, essential field)
// - Does NOT extract large fields like question, system, messages
// - Does NOT buffer streaming response body (no answer, reasoning, tool_calls)
// - Only extracts token statistics from response context
func getDefaultResponseAttributes() []Attribute {
	return []Attribute{
		// Token statistics (extracted from context, no body buffering needed)
		{
			Key:        BuiltinReasoningTokens,
			ApplyToLog: true,
		},
		{
			Key:        BuiltinCachedTokens,
			ApplyToLog: true,
		},
		{
			Key:        BuiltinInputTokenDetails,
			ApplyToLog: true,
		},
		{
			Key:        BuiltinOutputTokenDetails,
			ApplyToLog: true,
		},
	}
}

// Default session ID headers in priority order
var defaultSessionHeaders = []string{
	"x-openclaw-session-key",
@@ -211,10 +259,10 @@ func extractSessionId(customHeader string) string {

// ToolCall represents a single tool call in the response
type ToolCall struct {
	Index    int              `json:"index,omitempty"`
	ID       string           `json:"id,omitempty"`
	Type     string           `json:"type,omitempty"`
	Function ToolCallFunction `json:"function,omitempty"`
}

// ToolCallFunction represents the function details in a tool call
@@ -225,14 +273,18 @@ type ToolCallFunction struct {

// StreamingToolCallsBuffer holds the state for assembling streaming tool calls
type StreamingToolCallsBuffer struct {
	ToolCalls       map[int]*ToolCall // keyed by index (OpenAI format)
	InToolBlock     map[int]bool      // tracks which indices are in tool_use blocks (Claude format)
	ArgumentsBuffer map[int]string    // buffers partial JSON arguments (Claude format)
}

// extractStreamingToolCalls extracts and assembles tool calls from streaming response chunks (OpenAI format)
func extractStreamingToolCalls(data []byte, buffer *StreamingToolCallsBuffer) *StreamingToolCallsBuffer {
	if buffer == nil {
		buffer = &StreamingToolCallsBuffer{
			ToolCalls:       make(map[int]*ToolCall),
			InToolBlock:     make(map[int]bool),
			ArgumentsBuffer: make(map[int]string),
		}
	}

@@ -245,7 +297,7 @@ func extractStreamingToolCalls(data []byte, buffer *StreamingToolCallsBuffer) *S

	for _, tcResult := range toolCallsResult.Array() {
		index := int(tcResult.Get("index").Int())

		// Get or create tool call entry
		tc, exists := buffer.ToolCalls[index]
		if !exists {
@@ -273,6 +325,86 @@ func extractStreamingToolCalls(data []byte, buffer *StreamingToolCallsBuffer) *S
	return buffer
}

// extractClaudeStreamingToolCalls extracts and assembles tool calls from Claude/Anthropic streaming response chunks
// Claude format uses events: content_block_start, content_block_delta, content_block_stop
func extractClaudeStreamingToolCalls(data []byte, buffer *StreamingToolCallsBuffer) *StreamingToolCallsBuffer {
	if buffer == nil {
		buffer = &StreamingToolCallsBuffer{
			ToolCalls:       make(map[int]*ToolCall),
			InToolBlock:     make(map[int]bool),
			ArgumentsBuffer: make(map[int]string),
		}
	}

	chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n"))
	for _, chunk := range chunks {
		// Get event type
		eventType := gjson.GetBytes(chunk, ClaudeEventType)
		if !eventType.Exists() {
			continue
		}

		switch eventType.String() {
		case "content_block_start":
			// Check if this is a tool_use block
			contentBlockType := gjson.GetBytes(chunk, ClaudeContentBlockType)
			if contentBlockType.Exists() && contentBlockType.String() == "tool_use" {
				index := int(gjson.GetBytes(chunk, ClaudeIndex).Int())

				// Create tool call entry
				tc := &ToolCall{Index: index}

				// Extract id and name
				if id := gjson.GetBytes(chunk, ClaudeContentBlockID).String(); id != "" {
					tc.ID = id
				}
				if name := gjson.GetBytes(chunk, ClaudeContentBlockName).String(); name != "" {
					tc.Function.Name = name
				}
				tc.Type = "tool_use"

				buffer.ToolCalls[index] = tc
				buffer.InToolBlock[index] = true
				buffer.ArgumentsBuffer[index] = ""

				// Try to extract initial input if present
				if input := gjson.GetBytes(chunk, ClaudeContentBlockInput); input.Exists() {
					if inputMap, ok := input.Value().(map[string]interface{}); ok {
						if jsonBytes, err := json.Marshal(inputMap); err == nil {
							buffer.ArgumentsBuffer[index] = string(jsonBytes)
						}
					}
				}
			}

		case "content_block_delta":
			// Check if we're in a tool block
			index := int(gjson.GetBytes(chunk, ClaudeIndex).Int())
			if buffer.InToolBlock[index] {
				// Accumulate partial JSON arguments
				partialJSON := gjson.GetBytes(chunk, ClaudeDeltaPartialJSON)
				if partialJSON.Exists() {
					buffer.ArgumentsBuffer[index] += partialJSON.String()
				}
			}

		case "content_block_stop":
			// Finalize the tool call if we were in a tool block
			index := int(gjson.GetBytes(chunk, ClaudeIndex).Int())
			if buffer.InToolBlock[index] {
				buffer.InToolBlock[index] = false

				// Parse accumulated arguments and set them
				if tc, exists := buffer.ToolCalls[index]; exists {
					tc.Function.Arguments = buffer.ArgumentsBuffer[index]
				}
			}
		}
	}

	return buffer
}

// getToolCallsFromBuffer converts the buffer to a sorted slice of tool calls
func getToolCallsFromBuffer(buffer *StreamingToolCallsBuffer) []ToolCall {
	if buffer == nil || len(buffer.ToolCalls) == 0 {
@@ -317,6 +449,8 @@ type AIStatisticsConfig struct {
	attributes []Attribute
	// If there exist attributes extracted from streaming body, chunks should be buffered
	shouldBufferStreamingBody bool
	// If there exist attributes extracted from request body, request body should be buffered
	shouldBufferRequestBody bool
	// If disableOpenaiUsage is true, model/input_token/output_token logs will be skipped
	disableOpenaiUsage bool
	valueLengthLimit   int
@@ -411,6 +545,8 @@ func isContentTypeEnabled(contentType string, enabledContentTypes []string) bool
func parseConfig(configJson gjson.Result, config *AIStatisticsConfig) error {
	// Check if use_default_attributes is enabled
	useDefaultAttributes := configJson.Get("use_default_attributes").Bool()
	// Check if use_default_response_attributes is enabled (lightweight mode)
	useDefaultResponseAttributes := configJson.Get("use_default_response_attributes").Bool()

	// Parse tracing span attributes setting.
	attributeConfigs := configJson.Get("attributes").Array()
@@ -419,7 +555,7 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig) error {
	if configJson.Get("value_length_limit").Exists() {
		config.valueLengthLimit = int(configJson.Get("value_length_limit").Int())
	} else {
		config.valueLengthLimit = 32000
	}

	// Parse attributes or use defaults
@@ -430,6 +566,13 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig) error {
			config.valueLengthLimit = 10485760 // 10MB
		}
		log.Infof("Using default attributes configuration")
	} else if useDefaultResponseAttributes {
		config.attributes = getDefaultResponseAttributes()
		// Use a reasonable default for lightweight mode
		if !configJson.Get("value_length_limit").Exists() {
			config.valueLengthLimit = 4000
		}
		log.Infof("Using default response attributes configuration (lightweight mode)")
	} else {
		config.attributes = make([]Attribute, len(attributeConfigs))
		for i, attributeConfig := range attributeConfigs {
@@ -439,15 +582,38 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig) error {
				log.Errorf("parse config failed, %v", err)
				return err
			}
			if attribute.Rule != "" && attribute.Rule != RuleFirst && attribute.Rule != RuleReplace && attribute.Rule != RuleAppend {
				return errors.New("value of rule must be one of [nil, first, replace, append]")
			}
			config.attributes[i] = attribute
		}
	}

	// Check if any attribute needs request body or streaming body buffering
	for _, attribute := range config.attributes {
		// Check for request body buffering
		if attribute.ValueSource == RequestBody {
			config.shouldBufferRequestBody = true
		}
		// Check for streaming body buffering (explicitly configured)
		if attribute.ValueSource == ResponseStreamingBody {
			config.shouldBufferStreamingBody = true
		}
		// For built-in attributes without explicit ValueSource, check default sources
		if attribute.ValueSource == "" && isBuiltinAttribute(attribute.Key) {
			defaultSources := getBuiltinAttributeDefaultSources(attribute.Key)
			for _, src := range defaultSources {
				if src == RequestBody {
					config.shouldBufferRequestBody = true
				}
				// Only answer/reasoning/tool_calls need actual body buffering
				// Token-related attributes are extracted from context, not from body
				if src == ResponseStreamingBody && needsBodyBuffering(attribute.Key) {
					config.shouldBufferStreamingBody = true
				}
			}
		}
	}
	// Metric settings
	config.counterMetrics = make(map[string]proxywasm.MetricCounter)

@@ -458,8 +624,8 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig) error {
	pathSuffixes := configJson.Get("enable_path_suffixes").Array()
	config.enablePathSuffixes = make([]string, 0, len(pathSuffixes))

	// If use_default_attributes or use_default_response_attributes is enabled and enable_path_suffixes is not configured, use default path suffixes
	if (useDefaultAttributes || useDefaultResponseAttributes) && !configJson.Get("enable_path_suffixes").Exists() {
		config.enablePathSuffixes = []string{"/completions", "/messages"}
		log.Infof("Using default path suffixes: /completions, /messages")
	} else {
@@ -527,6 +693,8 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig) ty
		ctx.SetContext(ConsumerKey, consumer)
	}

	// Always buffer request body to extract model field
	// This is essential for metrics and logging
	ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)

	// Extract session ID from headers
@@ -551,13 +719,21 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body
		return types.ActionContinue
	}

	// Only process request body if we need to extract attributes from it
	if config.shouldBufferRequestBody && len(body) > 0 {
		// Set user defined log & span attributes.
		setAttributeBySource(ctx, config, RequestBody, body)
	}

	// Extract model from request body if available, otherwise try path
	requestModel := "UNKNOWN"
	if len(body) > 0 {
		if model := gjson.GetBytes(body, "model"); model.Exists() {
			requestModel = model.String()
		}
	}
	// If model not found in body, try to extract from path (Gemini style)
	if requestModel == "UNKNOWN" {
		requestPath := ctx.GetStringContext(RequestPath, "")
		if strings.Contains(requestPath, "generateContent") || strings.Contains(requestPath, "streamGenerateContent") { // Google Gemini GenerateContent
			reg := regexp.MustCompile(`^.*/(?P<api_version>[^/]+)/models/(?P<model>[^:]+):\w+Content$`)
@@ -569,21 +745,23 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body
	}
	ctx.SetContext(tokenusage.CtxKeyRequestModel, requestModel)
	setSpanAttribute(ArmsRequestModel, requestModel)

	// Set the number of conversation rounds (only if body is available)
	userPromptCount := 0
	if len(body) > 0 {
		if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
			// OpenAI and Claude/Anthropic format - both use "messages" array with "role" field
			for _, msg := range messages.Array() {
				if msg.Get("role").String() == "user" {
					userPromptCount += 1
				}
			}
		} else if contents := gjson.GetBytes(body, "contents"); contents.Exists() && contents.IsArray() {
			// Google Gemini GenerateContent
			for _, content := range contents.Array() {
				if !content.Get("role").Exists() || content.Get("role").String() == "user" {
					userPromptCount += 1
				}
			}
		}
	}
@@ -665,7 +843,7 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
			setSpanAttribute(ArmsModelName, usage.Model)
			setSpanAttribute(ArmsInputToken, usage.InputToken)
			setSpanAttribute(ArmsOutputToken, usage.OutputToken)

			// Set token details to context for later use in attributes
			if len(usage.InputTokenDetails) > 0 {
				ctx.SetContext(tokenusage.CtxKeyInputTokenDetails, usage.InputTokenDetails)
@@ -673,6 +851,9 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
			if len(usage.OutputTokenDetails) > 0 {
				ctx.SetContext(tokenusage.CtxKeyOutputTokenDetails, usage.OutputTokenDetails)
			}

			// Write once
			_ = ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
		}
	}
	// If the end of the stream is reached, record metrics/logs/spans.
@@ -680,14 +861,14 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
		responseEndTime := time.Now().UnixMilli()
		ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime)

		// Set user defined log & span attributes from streaming body.
		// Always call setAttributeBySource even if shouldBufferStreamingBody is false,
		// because token-related attributes are extracted from context (not buffered body).
		var streamingBodyBuffer []byte
		if config.shouldBufferStreamingBody {
			streamingBodyBuffer, _ = ctx.GetContext(CtxStreamingBodyBuffer).([]byte)
		}
		setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer)

		// Write log
		debugLogAiLog(ctx)
@@ -729,7 +910,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body
	setSpanAttribute(ArmsInputToken, usage.InputToken)
	setSpanAttribute(ArmsOutputToken, usage.OutputToken)
	setSpanAttribute(ArmsTotalToken, usage.TotalToken)

	// Set token details to context for later use in attributes
	if len(usage.InputTokenDetails) > 0 {
		ctx.SetContext(tokenusage.CtxKeyInputTokenDetails, usage.InputTokenDetails)
@@ -797,7 +978,7 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so
	if (value == nil || value == "") && attribute.DefaultValue != "" {
		value = attribute.DefaultValue
	}

	// Format value for logging/span
	var formattedValue interface{}
	switch v := value.(type) {
@@ -816,7 +997,7 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so
			formattedValue = fmt.Sprint(value)[:config.valueLengthLimit/2] + " [truncated] " + fmt.Sprint(value)[len(fmt.Sprint(value))-config.valueLengthLimit/2:]
		}
	}

	log.Debugf("[attribute] source type: %s, key: %s, value: %+v", source, key, formattedValue)
	if attribute.ApplyToLog {
		if attribute.AsSeparateLogField {
@@ -849,21 +1030,32 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so
|
||||
|
||||
// isBuiltinAttribute checks if the given key is a built-in attribute
|
||||
func isBuiltinAttribute(key string) bool {
|
||||
return key == BuiltinQuestionKey || key == BuiltinAnswerKey || key == BuiltinToolCallsKey || key == BuiltinReasoningKey ||
|
||||
return key == BuiltinQuestionKey || key == BuiltinAnswerKey || key == BuiltinToolCallsKey || key == BuiltinReasoningKey || key == BuiltinSystemKey ||
|
||||
key == BuiltinReasoningTokens || key == BuiltinCachedTokens ||
|
||||
key == BuiltinInputTokenDetails || key == BuiltinOutputTokenDetails
|
||||
}
|
||||
|
||||
// needsBodyBuffering checks if a built-in attribute needs body buffering
|
||||
// Token-related attributes are extracted from context (set by tokenusage.GetTokenUsage),
|
||||
// so they don't require buffering the response body.
|
||||
func needsBodyBuffering(key string) bool {
|
||||
return key == BuiltinAnswerKey || key == BuiltinToolCallsKey || key == BuiltinReasoningKey
|
||||
}
|
||||
|
||||
// getBuiltinAttributeDefaultSources returns the default value_source(s) for a built-in attribute
|
||||
// Returns nil if the key is not a built-in attribute
|
||||
// Note: Token-related attributes are extracted from context (set by tokenusage.GetTokenUsage),
|
||||
// so they don't require body buffering even though they're processed during response phase.
|
||||
func getBuiltinAttributeDefaultSources(key string) []string {
|
||||
switch key {
|
||||
case BuiltinQuestionKey:
|
||||
case BuiltinQuestionKey, BuiltinSystemKey:
|
||||
return []string{RequestBody}
|
||||
case BuiltinAnswerKey, BuiltinToolCallsKey, BuiltinReasoningKey:
|
||||
return []string{ResponseStreamingBody, ResponseBody}
|
||||
case BuiltinReasoningTokens, BuiltinCachedTokens, BuiltinInputTokenDetails, BuiltinOutputTokenDetails:
|
||||
// Token details are only available after response is received
|
||||
// Token details are extracted from context (set by tokenusage.GetTokenUsage),
|
||||
// not from body parsing. We use ResponseStreamingBody/ResponseBody to indicate
|
||||
// they should be processed during response phase, but they don't require body buffering.
|
||||
return []string{ResponseStreamingBody, ResponseBody}
|
||||
default:
|
||||
return nil
|
||||
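The two helpers interact during config load: an attribute whose default sources include the response phase can still avoid body buffering when its value comes from context. A minimal sketch, assuming the plugin derives its `shouldBufferStreamingBody` flag by scanning the configured attribute keys (the plain string keys below are illustrative stand-ins for the `Builtin*` constants):

```go
package main

import "fmt"

// needsBodyBuffering: only attributes parsed out of the response body
// require buffering; token-detail attributes come from context.
func needsBodyBuffering(key string) bool {
	return key == "answer" || key == "tool_calls" || key == "reasoning"
}

// defaultSources mirrors the switch above, with string literals standing in
// for the named constants.
func defaultSources(key string) []string {
	switch key {
	case "question", "system":
		return []string{"request_body"}
	case "answer", "tool_calls", "reasoning",
		"reasoning_tokens", "cached_tokens",
		"input_token_details", "output_token_details":
		return []string{"response_streaming_body", "response_body"}
	default:
		return nil
	}
}

// shouldBuffer is an assumed derivation of the buffering flag: token-detail
// keys share the response-phase sources but never force buffering.
func shouldBuffer(keys []string) bool {
	for _, k := range keys {
		if needsBodyBuffering(k) {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(shouldBuffer([]string{"question", "cached_tokens"})) // prints false: token details alone need no buffering
	fmt.Println(shouldBuffer([]string{"question", "answer"}))        // prints true: "answer" is parsed from the body
}
```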
@@ -896,6 +1088,13 @@ func getBuiltinAttributeFallback(ctx wrapper.HttpContext, config AIStatisticsCon
                return value
            }
        }
    case BuiltinSystemKey:
        if source == RequestBody {
            // Try Claude /v1/messages format (system is a top-level field)
            if value := gjson.GetBytes(body, SystemPathClaude).Value(); value != nil && value != "" {
                return value
            }
        }
    case BuiltinAnswerKey:
        if source == ResponseStreamingBody {
            // Try OpenAI format first
@@ -923,9 +1122,12 @@
        if existingBuffer, ok := ctx.GetContext(CtxStreamingToolCallsBuffer).(*StreamingToolCallsBuffer); ok {
            buffer = existingBuffer
        }
        // Try OpenAI format first
        buffer = extractStreamingToolCalls(body, buffer)
        // Also try Claude format (both formats can be checked)
        buffer = extractClaudeStreamingToolCalls(body, buffer)
        ctx.SetContext(CtxStreamingToolCallsBuffer, buffer)

        // Also set tool_calls to user attributes so they appear in ai_log
        toolCalls := getToolCallsFromBuffer(buffer)
        if len(toolCalls) > 0 {
@@ -1047,6 +1249,9 @@ func debugLogAiLog(ctx wrapper.HttpContext) {
    if question := ctx.GetUserAttribute("question"); question != nil {
        userAttrs["question"] = question
    }
    if system := ctx.GetUserAttribute("system"); system != nil {
        userAttrs["system"] = system
    }
    if answer := ctx.GetUserAttribute("answer"); answer != nil {
        userAttrs["answer"] = answer
    }
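The Claude stream assembly this hunk relies on (open a `tool_use` call on `content_block_start`, then append `partial_json` fragments from `input_json_delta` events) can be shown with a self-contained parser. This is a simplified sketch using `encoding/json`, not the plugin's actual `extractClaudeStreamingToolCalls`; the `toolCall` struct and `assemble` helper are illustrative.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// toolCall is a reduced stand-in for the plugin's buffered tool-call entry.
type toolCall struct {
	ID, Name, Args string
}

// claudeEvent covers the fields used by the three event types handled below.
type claudeEvent struct {
	Type         string `json:"type"`
	Index        int    `json:"index"`
	ContentBlock struct {
		Type  string          `json:"type"`
		ID    string          `json:"id"`
		Name  string          `json:"name"`
		Input json.RawMessage `json:"input"`
	} `json:"content_block"`
	Delta struct {
		Type        string `json:"type"`
		PartialJSON string `json:"partial_json"`
	} `json:"delta"`
}

// assemble replays "data:" lines: content_block_start opens a tool_use call
// (keeping any non-empty initial input), input_json_delta appends argument
// fragments, and every other event type is ignored.
func assemble(lines []string, calls map[int]*toolCall) map[int]*toolCall {
	if calls == nil {
		calls = map[int]*toolCall{}
	}
	for _, line := range lines {
		payload, ok := strings.CutPrefix(line, "data: ")
		if !ok {
			continue
		}
		var ev claudeEvent
		if json.Unmarshal([]byte(payload), &ev) != nil {
			continue
		}
		switch ev.Type {
		case "content_block_start":
			if ev.ContentBlock.Type == "tool_use" {
				args := string(ev.ContentBlock.Input)
				if args == "{}" || args == "null" {
					args = "" // empty initial input: wait for deltas
				}
				calls[ev.Index] = &toolCall{ID: ev.ContentBlock.ID, Name: ev.ContentBlock.Name, Args: args}
			}
		case "content_block_delta":
			if c := calls[ev.Index]; c != nil && ev.Delta.Type == "input_json_delta" {
				c.Args += ev.Delta.PartialJSON
			}
		}
	}
	return calls
}

func main() {
	calls := assemble([]string{
		`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"tool_123","name":"get_weather"}}`,
		`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\"loc"}}`,
		`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"ation\":\"Beijing\"}"}}`,
	}, nil)
	fmt.Println(calls[0].Name, calls[0].Args) // prints: get_weather {"location":"Beijing"}
}
```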
@@ -1712,3 +1712,305 @@ func TestTokenDetails(t *testing.T) {
        host.CompleteHttp()
    })
}

func TestUnmatchedPathsAndContentTypes(t *testing.T) {
    test.RunTest(t, func(t *testing.T) {
        restrictiveConfig := func() json.RawMessage {
            data, _ := json.Marshal(map[string]interface{}{
                "enable_path_suffixes": []string{"/allowed_path"},
                "enable_content_types": []string{"application/json"},
                "attributes": []map[string]interface{}{
                    {
                        "key":          "test_attr",
                        "value_source": "response_body",
                        "value":        "data",
                        "apply_to_log": true,
                    },
                },
                "disable_openai_usage": true,
            })
            return data
        }()

        t.Run("skip request for unenabled path", func(t *testing.T) {
            host, status := test.NewTestHost(restrictiveConfig)
            defer host.Reset()
            require.Equal(t, types.OnPluginStartStatusOK, status)

            action := host.CallOnHttpRequestHeaders([][2]string{
                {":authority", "example.com"},
                {":path", "/disallowed_path"},
                {":method", "POST"},
            })
            require.Equal(t, types.ActionContinue, action)
            host.CompleteHttp()
        })

        t.Run("skip response for unenabled content type", func(t *testing.T) {
            host, status := test.NewTestHost(restrictiveConfig)
            defer host.Reset()
            require.Equal(t, types.OnPluginStartStatusOK, status)

            host.CallOnHttpRequestHeaders([][2]string{
                {":authority", "example.com"},
                {":path", "/allowed_path"},
                {":method", "POST"},
            })

            action := host.CallOnHttpResponseHeaders([][2]string{
                {":status", "200"},
                {"content-type", "text/plain"},
            })
            require.Equal(t, types.ActionContinue, action)
            host.CompleteHttp()
        })
    })
}

func TestSetSpanAttributeAndLoggingEdgeCases(t *testing.T) {
    test.RunTest(t, func(t *testing.T) {
        configBytes := []byte(`{
            "attributes": [
                {
                    "key": "test_attr1",
                    "value_source": "fixed_value",
                    "value": "",
                    "apply_to_span": true
                },
                {
                    "key": "test_attr2",
                    "value_source": "fixed_value",
                    "value": "long_value_that_exceeds_limit_long_value_that_exceeds_limit_long_value_that_exceeds_limit",
                    "apply_to_log": true
                }
            ],
            "value_length_limit": 20
        }`)

        t.Run("span attribute edge cases", func(t *testing.T) {
            host, status := test.NewTestHost(configBytes)
            defer host.Reset()
            require.Equal(t, types.OnPluginStartStatusOK, status)

            // Setting fixed value attribute to empty should just print a debug log and skip setting span
            action := host.CallOnHttpRequestHeaders([][2]string{
                {":authority", "example.com"},
                {":path", "/api/chat"},
            })
            require.Equal(t, types.ActionContinue, action)
            host.CompleteHttp()
        })
    })
}

func TestGetRouteAndClusterNameEdgeCases(t *testing.T) {
    test.RunTest(t, func(t *testing.T) {
        t.Run("properties absence", func(t *testing.T) {
            host, status := test.NewTestHost([]byte(`{}`))
            defer host.Reset()
            require.Equal(t, types.OnPluginStartStatusOK, status)

            // Host doesn't have route_name implicitly by default without SetRouteName, but getRouteName handles err
            host.CallOnHttpRequestHeaders([][2]string{
                {":authority", "example.com"},
                {":path", "/api/chat"},
            })
            host.CompleteHttp()
        })

        t.Run("api name with @", func(t *testing.T) {
            host, status := test.NewTestHost([]byte(`{}`))
            defer host.Reset()
            require.Equal(t, types.OnPluginStartStatusOK, status)

            host.SetRouteName("api@v1@service@extra") // @ has special handling in getAPIName
            host.CallOnHttpRequestHeaders([][2]string{
                {":authority", "example.com"},
                {":path", "/api/chat"},
            })
            host.CompleteHttp()
        })
    })
}

func TestExtractClaudeStreamingToolCallsMissingInput(t *testing.T) {
    t.Run("claude missing partial_json", func(t *testing.T) {
        chunks := [][]byte{
            []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"tool_123","name":"get_weather","input":{}}}`),
            []byte(`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta"}}`),
            []byte(`data: {"type":"content_block_stop","index":0}`),
        }

        var buffer *StreamingToolCallsBuffer
        for _, chunk := range chunks {
            buffer = extractClaudeStreamingToolCalls(chunk, buffer)
        }

        toolCalls := getToolCallsFromBuffer(buffer)
        require.Len(t, toolCalls, 1)
        require.Equal(t, "tool_123", toolCalls[0].ID)
        require.Equal(t, "tool_use", toolCalls[0].Type)
        require.Equal(t, "get_weather", toolCalls[0].Function.Name)
        // partial_json absence means arguments might be empty
    })
}

func TestWriteMetricEdgeCases(t *testing.T) {
    test.RunTest(t, func(t *testing.T) {
        t.Run("disable_openai_usage true", func(t *testing.T) {
            configBytes := []byte(`{
                "disable_openai_usage": true
            }`)
            host, status := test.NewTestHost(configBytes)
            defer host.Reset()
            require.Equal(t, types.OnPluginStartStatusOK, status)

            host.SetRouteName("api-v1")
            host.SetClusterName("cluster-1")

            host.CallOnHttpRequestHeaders([][2]string{
                {":authority", "example.com"},
                {":path", "/api/chat"},
            })

            host.CallOnHttpResponseHeaders([][2]string{
                {":status", "200"},
                {"content-type", "application/json"},
            })

            responseBody := []byte(`{
                "usage": {"prompt_tokens": 5, "completion_tokens": 8, "total_tokens": 13},
                "model": "gpt-3.5-turbo"
            }`)
            host.CallOnHttpResponseBody(responseBody)
            host.CompleteHttp()
        })
    })
}

func TestIsPathEnabled(t *testing.T) {
    require.True(t, isPathEnabled("/v1/chat/completions", nil))
    require.True(t, isPathEnabled("/v1/chat/completions", []string{}))
    require.True(t, isPathEnabled("/v1/chat/completions", []string{"/completions", "/messages"}))
    require.True(t, isPathEnabled("/v1/messages", []string{"/completions", "/messages"}))
    require.False(t, isPathEnabled("/v1/embeddings", []string{"/completions", "/messages"}))

    // test query params
    require.True(t, isPathEnabled("/v1/chat/completions?stream=true", []string{"/completions"}))
    require.False(t, isPathEnabled("/v1/embeddings?stream=true", []string{"/completions"}))
}

func TestIsContentTypeEnabled(t *testing.T) {
    require.True(t, isContentTypeEnabled("application/json", nil))
    require.True(t, isContentTypeEnabled("application/json", []string{}))
    require.True(t, isContentTypeEnabled("application/json", []string{"application/json", "text/event-stream"}))
    require.True(t, isContentTypeEnabled("text/event-stream; charset=utf-8", []string{"application/json", "text/event-stream"}))
    require.False(t, isContentTypeEnabled("text/html", []string{"application/json", "text/event-stream"}))
}
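The two tests above pin down the matching semantics: an empty or nil allow-list means everything is enabled, path matching ignores the query string and compares suffixes, and content-type matching ignores parameters such as `charset`. One plausible implementation consistent with those tests (the plugin's real helpers may differ in detail) looks like this:

```go
package main

import (
	"fmt"
	"strings"
)

// isPathEnabled: nil/empty suffix list enables all paths; otherwise strip the
// query string and do a suffix match.
func isPathEnabled(path string, suffixes []string) bool {
	if len(suffixes) == 0 {
		return true
	}
	if i := strings.Index(path, "?"); i >= 0 {
		path = path[:i]
	}
	for _, s := range suffixes {
		if strings.HasSuffix(path, s) {
			return true
		}
	}
	return false
}

// isContentTypeEnabled: nil/empty list enables all types; otherwise compare
// the media type with any "; charset=..." style parameters dropped.
func isContentTypeEnabled(ct string, allowed []string) bool {
	if len(allowed) == 0 {
		return true
	}
	ct = strings.TrimSpace(strings.Split(ct, ";")[0])
	for _, a := range allowed {
		if ct == a {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(isPathEnabled("/v1/chat/completions?stream=true", []string{"/completions"})) // prints true
	fmt.Println(isContentTypeEnabled("text/event-stream; charset=utf-8", []string{"text/event-stream"})) // prints true
	fmt.Println(isPathEnabled("/v1/embeddings", []string{"/completions", "/messages"})) // prints false
}
```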
func TestConvertToUInt(t *testing.T) {
    val, ok := convertToUInt(int32(10))
    require.True(t, ok)
    require.Equal(t, uint64(10), val)

    val, ok = convertToUInt(int64(10))
    require.True(t, ok)
    require.Equal(t, uint64(10), val)

    val, ok = convertToUInt(uint32(10))
    require.True(t, ok)
    require.Equal(t, uint64(10), val)

    val, ok = convertToUInt(uint64(10))
    require.True(t, ok)
    require.Equal(t, uint64(10), val)

    val, ok = convertToUInt(float32(10.0))
    require.True(t, ok)
    require.Equal(t, uint64(10), val)

    val, ok = convertToUInt(float64(10.0))
    require.True(t, ok)
    require.Equal(t, uint64(10), val)

    _, ok = convertToUInt("10")
    require.False(t, ok)
}
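A type-switch implementation consistent with `TestConvertToUInt` might look like the following; the negative-value guard is an assumption the test does not exercise, and the plugin's real function may handle that case differently.

```go
package main

import "fmt"

// convertToUInt converts common numeric JSON/usage types to uint64.
// Strings and any other types fail; negative values are rejected here
// as a defensive assumption.
func convertToUInt(v interface{}) (uint64, bool) {
	switch n := v.(type) {
	case int32:
		if n < 0 {
			return 0, false
		}
		return uint64(n), true
	case int64:
		if n < 0 {
			return 0, false
		}
		return uint64(n), true
	case uint32:
		return uint64(n), true
	case uint64:
		return n, true
	case float32:
		if n < 0 {
			return 0, false
		}
		return uint64(n), true
	case float64:
		if n < 0 {
			return 0, false
		}
		return uint64(n), true
	default:
		return 0, false
	}
}

func main() {
	v, ok := convertToUInt(int64(10))
	fmt.Println(v, ok) // prints: 10 true
	_, ok = convertToUInt("10")
	fmt.Println(ok) // prints: false
}
```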
func TestExtractClaudeStreamingToolCalls(t *testing.T) {
    t.Run("claude tool use assembly", func(t *testing.T) {
        chunks := [][]byte{
            []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"tool_123","name":"get_weather"}}`),
            []byte(`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\"loc"}}}`),
            []byte(`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"ation\":\"Bei"}}}`),
            []byte(`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"jing\"}"}}}`),
            []byte(`data: {"type":"content_block_stop","index":0}`),
        }

        var buffer *StreamingToolCallsBuffer
        for _, chunk := range chunks {
            buffer = extractClaudeStreamingToolCalls(chunk, buffer)
        }

        toolCalls := getToolCallsFromBuffer(buffer)
        require.Len(t, toolCalls, 1)
        require.Equal(t, "tool_123", toolCalls[0].ID)
        require.Equal(t, "tool_use", toolCalls[0].Type)
        require.Equal(t, "get_weather", toolCalls[0].Function.Name)
        require.Equal(t, `{"location":"Beijing"}`, toolCalls[0].Function.Arguments)
    })

    t.Run("claude empty chunks", func(t *testing.T) {
        chunks := [][]byte{
            []byte(`data: {"type":"ping"}`),
            []byte(`data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`),
        }
        var buffer *StreamingToolCallsBuffer
        for _, chunk := range chunks {
            buffer = extractClaudeStreamingToolCalls(chunk, buffer)
        }
        toolCalls := getToolCallsFromBuffer(buffer)
        require.Len(t, toolCalls, 0)
    })

    t.Run("claude tool use with initial input", func(t *testing.T) {
        chunks := [][]byte{
            []byte(`data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"tool_456","name":"get_time","input":{"timezone":"UTC+8"}}}`),
            []byte(`data: {"type":"content_block_stop","index":1}`),
        }

        var buffer *StreamingToolCallsBuffer
        for _, chunk := range chunks {
            buffer = extractClaudeStreamingToolCalls(chunk, buffer)
        }

        toolCalls := getToolCallsFromBuffer(buffer)
        require.Len(t, toolCalls, 1)
        require.Equal(t, "tool_456", toolCalls[0].ID)
        require.Equal(t, "tool_use", toolCalls[0].Type)
        require.Equal(t, "get_time", toolCalls[0].Function.Name)
        require.Equal(t, `{"timezone":"UTC+8"}`, toolCalls[0].Function.Arguments)
    })
}

func TestConfigWithDefaultAttributes(t *testing.T) {
    test.RunTest(t, func(t *testing.T) {
        t.Run("use default attributes config", func(t *testing.T) {
            defaultConfig := []byte(`{
                "use_default_attributes": true
            }`)
            host, status := test.NewTestHost(defaultConfig)
            defer host.Reset()
            require.Equal(t, types.OnPluginStartStatusOK, status)
        })

        t.Run("use default response attributes config", func(t *testing.T) {
            defaultRespConfig := []byte(`{
                "use_default_response_attributes": true
            }`)
            host, status := test.NewTestHost(defaultRespConfig)
            defer host.Reset()
            require.Equal(t, types.OnPluginStartStatusOK, status)
        })
    })
}
@@ -1,10 +1,14 @@
module jsonrpc-converter

go 1.24.3
go 1.24.1

replace github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp

require (
    github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
    github.com/higress-group/wasm-go v1.0.4
    github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0
    github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
    github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9
    github.com/stretchr/testify v1.9.0
    github.com/tidwall/gjson v1.18.0
)

@@ -15,6 +19,7 @@ require (
    github.com/Masterminds/sprig/v3 v3.3.0 // indirect
    github.com/bahlo/generic-list-go v0.2.0 // indirect
    github.com/buger/jsonparser v1.1.1 // indirect
    github.com/davecgh/go-spew v1.1.1 // indirect
    github.com/google/uuid v1.6.0 // indirect
    github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b // indirect
    github.com/huandu/xstrings v1.5.0 // indirect
@@ -22,8 +27,10 @@ require (
    github.com/mailru/easyjson v0.7.7 // indirect
    github.com/mitchellh/copystructure v1.2.0 // indirect
    github.com/mitchellh/reflectwalk v1.0.2 // indirect
    github.com/pmezard/go-difflib v1.0.0 // indirect
    github.com/shopspring/decimal v1.4.0 // indirect
    github.com/spf13/cast v1.7.0 // indirect
    github.com/tetratelabs/wazero v1.7.2 // indirect
    github.com/tidwall/match v1.1.1 // indirect
    github.com/tidwall/pretty v1.2.1 // indirect
    github.com/tidwall/resp v0.1.1 // indirect

@@ -20,10 +20,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b h1:rRI9+ThQbe+nw4jUiYEyOFaREkXCMMW9k1X2gy2d6pE=
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.4 h1:/GqbzCw4oWqJc8UbKEfF94E3/+4CPZGbzxpKo2L3Ldk=
github.com/higress-group/wasm-go v1.0.4/go.mod h1:B8C6+OlpnyYyZUBEdUXA7tYZYD+uwZTNjfkE5FywA+A=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 h1:sUuUXZwr50l3W1St7MESlFmxmUAu+QUNNfJXx4P6bas=
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
@@ -49,6 +49,8 @@ github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=

@@ -9,8 +9,8 @@ import (
    "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
    "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
    "github.com/higress-group/wasm-go/pkg/log"
    "github.com/higress-group/wasm-go/pkg/mcp"
    "github.com/higress-group/wasm-go/pkg/mcp/utils"
    "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp"
    "github.com/alibaba/higress/plugins/wasm-go/pkg/mcp/utils"
    "github.com/higress-group/wasm-go/pkg/wrapper"
    "github.com/tidwall/gjson"
)
@@ -1,9 +1,15 @@
package main

import (
    "encoding/json"
    "testing"

    "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
    "github.com/higress-group/wasm-go/pkg/test"
    "github.com/stretchr/testify/require"
)

// TestTruncateString tests the truncateString function
func TestTruncateString(t *testing.T) {
    tests := []struct {
        name string
@@ -14,6 +20,8 @@ func TestTruncateString(t *testing.T) {
        {"Short String", "Higress Is an AI-Native API Gateway", 1000, "Higress Is an AI-Native API Gateway"},
        {"Exact Length", "Higress Is an AI-Native API Gateway", 35, "Higress Is an AI-Native API Gateway"},
        {"Truncated String", "Higress Is an AI-Native API Gateway", 20, "Higress Is...(truncated)...PI Gateway"},
        {"Empty String", "", 10, ""},
        {"Single Char", "A", 10, "A"},
    }

    for _, tt := range tests {
@@ -26,3 +34,248 @@ func TestTruncateString(t *testing.T) {
        })
    }
}

// TestIsPreRequestStage tests the isPreRequestStage function
func TestIsPreRequestStage(t *testing.T) {
    config := McpConverterConfig{Stage: ProcessRequest}
    require.True(t, isPreRequestStage(config))

    config = McpConverterConfig{Stage: ProcessResponse}
    require.False(t, isPreRequestStage(config))
}

// TestIsPreResponseStage tests the isPreResponseStage function
func TestIsPreResponseStage(t *testing.T) {
    config := McpConverterConfig{Stage: ProcessResponse}
    require.True(t, isPreResponseStage(config))

    config = McpConverterConfig{Stage: ProcessRequest}
    require.False(t, isPreResponseStage(config))
}

// TestIsMethodAllowed tests the isMethodAllowed function
func TestIsMethodAllowed(t *testing.T) {
    config := McpConverterConfig{AllowedMethods: []string{MethodToolList, MethodToolCall}}

    require.True(t, isMethodAllowed(config, MethodToolList))
    require.True(t, isMethodAllowed(config, MethodToolCall))
    require.False(t, isMethodAllowed(config, "invalid/method"))
}

// TestConstants tests the constant values
func TestConstants(t *testing.T) {
    require.Equal(t, "x-envoy-jsonrpc-id", JsonRpcId)
    require.Equal(t, "x-envoy-jsonrpc-method", JsonRpcMethod)
    require.Equal(t, "x-envoy-jsonrpc-params", JsonRpcParams)
    require.Equal(t, "x-envoy-jsonrpc-result", JsonRpcResult)
    require.Equal(t, "x-envoy-jsonrpc-error", JsonRpcError)
    require.Equal(t, "x-envoy-mcp-tool-name", McpToolName)
    require.Equal(t, "x-envoy-mcp-tool-arguments", McpToolArguments)
    require.Equal(t, "x-envoy-mcp-tool-response", McpToolResponse)
    require.Equal(t, "x-envoy-mcp-tool-error", McpToolError)
    require.Equal(t, 4000, DefaultMaxHeaderLength)
    require.Equal(t, "tools/list", MethodToolList)
    require.Equal(t, "tools/call", MethodToolCall)
    require.Equal(t, ProcessStage("request"), ProcessRequest)
    require.Equal(t, ProcessStage("response"), ProcessResponse)
}

// TestMcpConverterConfigDefaults tests config default values
func TestMcpConverterConfigDefaults(t *testing.T) {
    config := McpConverterConfig{}
    require.Equal(t, 0, config.MaxHeaderLength)
    require.Equal(t, ProcessStage(""), config.Stage)
    require.Nil(t, config.AllowedMethods)
}

// TestProcessStage tests ProcessStage type
func TestProcessStage(t *testing.T) {
    require.Equal(t, ProcessStage("request"), ProcessRequest)
    require.Equal(t, ProcessStage("response"), ProcessResponse)
}

// TestRemoveJsonRpcHeadersFunction tests removeJsonRpcHeaders function logic
func TestRemoveJsonRpcHeadersFunction(t *testing.T) {
    headersToRemove := []string{
        JsonRpcId,
        JsonRpcMethod,
        JsonRpcParams,
        JsonRpcResult,
        McpToolName,
        McpToolArguments,
        McpToolResponse,
        McpToolError,
    }
    require.Len(t, headersToRemove, 8)
}

// TestTruncateStringLong tests truncation of very long strings
func TestTruncateStringLong(t *testing.T) {
    longString := ""
    for i := 0; i < 5000; i++ {
        longString += "a"
    }
    config := McpConverterConfig{MaxHeaderLength: 1000}
    result := truncateString(longString, config)
    require.Contains(t, result, "...(truncated)...")
    require.LessOrEqual(t, len(result), 1020)
}

// TestTruncateStringWithSmallMaxLength tests truncation with small max length
func TestTruncateStringWithSmallMaxLength(t *testing.T) {
    config := McpConverterConfig{MaxHeaderLength: 10}
    result := truncateString("This is a very long string", config)
    require.Contains(t, result, "...(truncated)...")
}
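The expected strings in these tests imply a keep-head-and-tail scheme: up to `maxLen/2` bytes from each end around an `...(truncated)...` marker, which also explains the `len(result) <= 1020` bound (500 + 17 + 500 = 1017 for a 5000-byte string with limit 1000). A sketch consistent with the test table, taking the limit as a plain int rather than the plugin's `McpConverterConfig.MaxHeaderLength` field:

```go
package main

import "fmt"

// truncateString keeps strings within maxLen by retaining the first and last
// maxLen/2 bytes around a marker; shorter strings pass through unchanged.
// This is a reconstruction from the test expectations, not the plugin source.
func truncateString(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	half := maxLen / 2
	return s[:half] + "...(truncated)..." + s[len(s)-half:]
}

func main() {
	fmt.Println(truncateString("Higress Is an AI-Native API Gateway", 20)) // prints: Higress Is...(truncated)...PI Gateway
	fmt.Println(truncateString("short", 20))                               // prints: short
}
```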
// TestPluginInit tests plugin initialization
func TestPluginInit(t *testing.T) {
    configBytes, _ := json.Marshal(McpConverterConfig{
        Stage:           ProcessRequest,
        MaxHeaderLength: DefaultMaxHeaderLength,
        AllowedMethods:  []string{MethodToolList, MethodToolCall},
    })

    host, status := test.NewTestHost(configBytes)
    defer host.Reset()
    require.Equal(t, types.OnPluginStartStatusOK, status)
}

// TestProcessJsonRpcRequest tests processJsonRpcRequest function
func TestProcessJsonRpcRequest(t *testing.T) {
    configBytes, _ := json.Marshal(McpConverterConfig{
        Stage:           ProcessRequest,
        MaxHeaderLength: DefaultMaxHeaderLength,
        AllowedMethods:  []string{MethodToolList, MethodToolCall},
    })

    host, status := test.NewTestHost(configBytes)
    defer host.Reset()
    require.Equal(t, types.OnPluginStartStatusOK, status)

    host.InitHttp()
    host.CallOnHttpRequestHeaders([][2]string{
        {":authority", "mcp-server.example.com"},
        {":method", "POST"},
        {":path", "/mcp"},
        {"content-type", "application/json"},
    })

    toolsListRequest := `{
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tools/list",
        "params": {}
    }`
    action := host.CallOnHttpRequestBody([]byte(toolsListRequest))
    require.Equal(t, types.ActionContinue, action)

    host.CompleteHttp()
}

// TestProcessToolCallRequest tests processToolCallRequest function
func TestProcessToolCallRequest(t *testing.T) {
    configBytes, _ := json.Marshal(McpConverterConfig{
        Stage:           ProcessRequest,
        MaxHeaderLength: DefaultMaxHeaderLength,
        AllowedMethods:  []string{MethodToolCall},
    })

    host, status := test.NewTestHost(configBytes)
    defer host.Reset()
    require.Equal(t, types.OnPluginStartStatusOK, status)

    host.InitHttp()
    host.CallOnHttpRequestHeaders([][2]string{
        {":authority", "mcp-server.example.com"},
        {":method", "POST"},
        {":path", "/mcp"},
        {"content-type", "application/json"},
    })

    toolCallRequest := `{
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tools/call",
        "params": {
            "name": "test_tool",
            "arguments": {"arg1": "value1"}
        }
    }`
    action := host.CallOnHttpRequestBody([]byte(toolCallRequest))
    require.Equal(t, types.ActionContinue, action)

    host.CompleteHttp()
}

// TestProcessJsonRpcResponse tests processJsonRpcResponse function
func TestProcessJsonRpcResponse(t *testing.T) {
    configBytes, _ := json.Marshal(McpConverterConfig{
        Stage:           ProcessResponse,
        MaxHeaderLength: DefaultMaxHeaderLength,
        AllowedMethods:  []string{MethodToolList, MethodToolCall},
    })

    host, status := test.NewTestHost(configBytes)
    defer host.Reset()
    require.Equal(t, types.OnPluginStartStatusOK, status)

    host.InitHttp()
    host.CallOnHttpRequestHeaders([][2]string{
        {":authority", "mcp-server.example.com"},
        {":method", "POST"},
        {":path", "/mcp"},
        {"content-type", "application/json"},
    })

    responseBody := `{
        "jsonrpc": "2.0",
        "id": 1,
        "result": {
            "tools": [{"name": "test_tool"}]
        }
    }`
    host.CallOnHttpResponseHeaders([][2]string{
        {":status", "200"},
        {"content-type", "application/json"},
    })
    host.CallOnHttpResponseBody([]byte(responseBody))

    host.CompleteHttp()
}

// TestProcessToolListResponse tests processToolListResponse function
func TestProcessToolListResponse(t *testing.T) {
    configBytes, _ := json.Marshal(McpConverterConfig{
        Stage:           ProcessResponse,
        MaxHeaderLength: DefaultMaxHeaderLength,
        AllowedMethods:  []string{MethodToolList},
    })

    host, status := test.NewTestHost(configBytes)
    defer host.Reset()
    require.Equal(t, types.OnPluginStartStatusOK, status)

    host.InitHttp()
    host.CallOnHttpRequestHeaders([][2]string{
        {":authority", "mcp-server.example.com"},
        {":method", "POST"},
        {":path", "/mcp"},
        {"content-type", "application/json"},
    })

    responseBody := `{
        "jsonrpc": "2.0",
        "id": 1,
        "result": {
            "tools": [{"name": "test_tool"}]
        }
    }`
    host.CallOnHttpResponseHeaders([][2]string{
        {":status", "200"},
        {"content-type", "application/json"},
    })
    host.CallOnHttpResponseBody([]byte(responseBody))

    host.CompleteHttp()
}
@@ -2,9 +2,12 @@ module mcp-router

go 1.24.1

replace github.com/alibaba/higress/plugins/wasm-go/pkg/mcp => ../../pkg/mcp

require (
    github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
    github.com/higress-group/wasm-go v1.0.2-0.20250911113549-cbf1cfcce774
    github.com/alibaba/higress/plugins/wasm-go/pkg/mcp v0.0.0
    github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
    github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9
    github.com/tidwall/gjson v1.18.0
    github.com/tidwall/sjson v1.2.5
)
@@ -20,12 +20,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b h1:rRI9+ThQbe+nw4jUiYEyOFaREkXCMMW9k1X2gy2d6pE=
github.com/higress-group/gjson_template v0.0.0-20250413075336-4c4161ed428b/go.mod h1:rU3M+Tq5VrQOo0dxpKHGb03Ty0sdWIZfAH+YCOACx/Y=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250807064511-eb1cd98e1f57 h1:WhNdnKSDtAQrh4Yil8HAtbl7VW+WC85m7WS8kirnHAA=
github.com/higress-group/wasm-go v1.0.2-0.20250807064511-eb1cd98e1f57/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/higress-group/wasm-go v1.0.2-0.20250911113549-cbf1cfcce774 h1:2wlbNpFJCQNbPBFYgswz7Zvxo9O3L0PH0AJxwiCc5lk=
github.com/higress-group/wasm-go v1.0.2-0.20250911113549-cbf1cfcce774/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9 h1:sUuUXZwr50l3W1St7MESlFmxmUAu+QUNNfJXx4P6bas=
github.com/higress-group/wasm-go v1.0.10-0.20260115123534-84ef43c39dc9/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
Some files were not shown because too many files have changed in this diff.