mirror of
https://github.com/alibaba/higress.git
synced 2026-06-26 02:35:02 +08:00
feat(ai-context-limit): add context window limit wasm plugin (#4000)
Signed-off-by: Cai Rui <cairui@U-7VTK6WQN-2207.local>
This commit is contained in:
@@ -14,6 +14,7 @@ COPY . .
|
||||
WORKDIR /workspace/extensions/$PLUGIN_NAME
|
||||
|
||||
RUN go mod tidy
|
||||
RUN if [ -f prepare.sh ]; then sh ./prepare.sh; fi
|
||||
RUN \
|
||||
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o /main.wasm .
|
||||
|
||||
|
||||
@@ -60,7 +60,9 @@ builder:
|
||||
@echo "image: ${BUILDER}"
|
||||
|
||||
local-build:
|
||||
cd extensions/${PLUGIN_NAME};GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o ./main.wasm .
|
||||
cd extensions/${PLUGIN_NAME}; \
|
||||
if [ -f prepare.sh ]; then sh ./prepare.sh; fi; \
|
||||
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o ./main.wasm .
|
||||
|
||||
@echo ""
|
||||
@echo "wasm: extensions/${PLUGIN_NAME}/main.wasm"
|
||||
|
||||
6
plugins/wasm-go/extensions/ai-context-limit/.gitignore
vendored
Normal file
6
plugins/wasm-go/extensions/ai-context-limit/.gitignore
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
plugin.wasm
|
||||
*.wasm
|
||||
.tmp/
|
||||
.gocache/
|
||||
PR.md
|
||||
bpe/o200k_base.tiktoken
|
||||
13
plugins/wasm-go/extensions/ai-context-limit/Makefile
Normal file
13
plugins/wasm-go/extensions/ai-context-limit/Makefile
Normal file
@@ -0,0 +1,13 @@
|
||||
.PHONY: prepare build build-go clean
|
||||
|
||||
prepare:
|
||||
@sh ./prepare.sh
|
||||
|
||||
build: build-go
|
||||
|
||||
build-go: prepare
|
||||
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o main.wasm .
|
||||
|
||||
clean:
|
||||
rm -f main.wasm
|
||||
rm -f bpe/o200k_base.tiktoken
|
||||
68
plugins/wasm-go/extensions/ai-context-limit/README.md
Normal file
68
plugins/wasm-go/extensions/ai-context-limit/README.md
Normal file
@@ -0,0 +1,68 @@
|
||||
---
|
||||
title: AI 上下文窗口限制
|
||||
keywords: [ AI网关, 上下文窗口, Token ]
|
||||
description: AI 上下文窗口限制插件配置参考
|
||||
---
|
||||
|
||||
## 功能说明
|
||||
|
||||
`ai-context-limit` 用于在请求转发到上游大模型前,对 OpenAI Chat Completions、Anthropic Messages 等协议兼容请求中的文本输入进行 token 估算。当估算结果超过配置的上下文窗口大小时,插件会直接返回错误响应,避免超长上下文继续进入后端模型服务。
|
||||
|
||||
该插件适用于按路由、服务、域名或 MCP Server 控制请求上下文规模的场景,可用于为不同业务、模型或调用入口设置独立的上下文窗口上限。
|
||||
|
||||
## 运行属性
|
||||
|
||||
插件执行阶段:`默认阶段`
|
||||
|
||||
插件执行优先级:`1000`
|
||||
|
||||
## 构建
|
||||
|
||||
插件依赖内嵌的 BPE 词表文件,首次构建前需要下载:
|
||||
|
||||
```bash
|
||||
make build
|
||||
```
|
||||
|
||||
或分步执行:
|
||||
|
||||
```bash
|
||||
make prepare # 下载词表到 bpe/o200k_base.tiktoken
|
||||
make build-go # 编译 WASM
|
||||
```
|
||||
|
||||
## 配置字段
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|------|---------|---------|--------|------|
|
||||
| `max_context_tokens` | int | 必填 | - | 最大上下文 token 数。输入估算结果超过该值时,请求会被拦截。设为 0 表示禁用拦截。 |
|
||||
| `buffer_ratio` | float | 非必填 | 1.10 | 安全缓冲系数(取值范围 0~10)。插件会将估算 token 数乘以该系数后再与阈值比较。 |
|
||||
| `error_status_code` | int | 非必填 | 400 | 请求超出上下文窗口限制时返回的 HTTP 状态码(取值范围 400~599)。 |
|
||||
|
||||
## 配置示例
|
||||
|
||||
```yaml
|
||||
max_context_tokens: 128000
|
||||
buffer_ratio: 1.10
|
||||
error_status_code: 400
|
||||
```
|
||||
|
||||
## 返回示例
|
||||
|
||||
当请求输入超过配置限制时,插件会返回如下格式的错误响应:
|
||||
|
||||
```json
|
||||
{
|
||||
"error": {
|
||||
"message": "This model's maximum context length is 128000 tokens. Your request had approximately 140000 tokens.",
|
||||
"type": "invalid_request_error",
|
||||
"code": "context_length_exceeded"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 当前版本会统计文本承载字段,包括 text、tool schema、tool arguments、thinking、text document、search_result 等;图片、音频、base64/url/file document 等非文本内容会跳过 token 统计,整个请求直接放行。
|
||||
- 非 JSON 请求或非兼容协议的请求不会触发上下文限制。
|
||||
- 插件最多读取 8MB 请求体用于文本估算,超出部分不会被处理。
|
||||
68
plugins/wasm-go/extensions/ai-context-limit/README_EN.md
Normal file
68
plugins/wasm-go/extensions/ai-context-limit/README_EN.md
Normal file
@@ -0,0 +1,68 @@
|
||||
---
|
||||
title: AI Context Limit
|
||||
keywords: [ AI Gateway, Context Window, Token ]
|
||||
description: AI Context Limit plugin configuration reference
|
||||
---
|
||||
|
||||
## Functional Description
|
||||
|
||||
`ai-context-limit` estimates the input token count of OpenAI Chat Completions, Anthropic Messages and other compatible requests before forwarding them to the upstream model service. When the estimated input size exceeds the configured context window limit, the plugin returns an error response directly.
|
||||
|
||||
This plugin can be used to control context window size by route, service, domain, or MCP Server. It is suitable for setting independent context limits for different applications, models, or traffic entry points.
|
||||
|
||||
## Runtime Properties
|
||||
|
||||
Plugin execution phase: `Default Phase`
|
||||
|
||||
Plugin execution priority: `1000`
|
||||
|
||||
## Build
|
||||
|
||||
The plugin requires an embedded BPE vocabulary file. Download it before the first build:
|
||||
|
||||
```bash
|
||||
make build
|
||||
```
|
||||
|
||||
Or step by step:
|
||||
|
||||
```bash
|
||||
make prepare # Download vocabulary to bpe/o200k_base.tiktoken
|
||||
make build-go # Compile WASM
|
||||
```
|
||||
|
||||
## Configuration Fields
|
||||
|
||||
| Name | Data Type | Requirement | Default Value | Description |
|
||||
|------|-----------|-------------|---------------|-------------|
|
||||
| `max_context_tokens` | int | Required | - | Maximum context token limit. Requests whose estimated input size exceeds this value will be blocked. Set to 0 to disable. |
|
||||
| `buffer_ratio` | float | Optional | 1.10 | Safety buffer ratio (valid range: 0–10). The estimated token count is multiplied by this ratio before comparison. |
|
||||
| `error_status_code` | int | Optional | 400 | HTTP status code returned when the request exceeds the context limit (valid range: 400–599). |
|
||||
|
||||
## Configuration Example
|
||||
|
||||
```yaml
|
||||
max_context_tokens: 128000
|
||||
buffer_ratio: 1.10
|
||||
error_status_code: 400
|
||||
```
|
||||
|
||||
## Response Example
|
||||
|
||||
When a request exceeds the configured limit, the plugin returns an error response in the following format:
|
||||
|
||||
```json
|
||||
{
|
||||
"error": {
|
||||
"message": "This model's maximum context length is 128000 tokens. Your request had approximately 140000 tokens.",
|
||||
"type": "invalid_request_error",
|
||||
"code": "context_length_exceeded"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- The plugin counts text-bearing fields including text, tool schema, tool arguments, thinking, text document, and search_result. Non-text content such as images, audio, and base64/url/file documents will skip token counting and the entire request is passed through.
|
||||
- Non-JSON requests and requests that are not in a compatible protocol format will not trigger the context limit.
|
||||
- The plugin reads up to 8MB of the request body for text estimation; content beyond this limit will not be processed.
|
||||
1
plugins/wasm-go/extensions/ai-context-limit/VERSION
Normal file
1
plugins/wasm-go/extensions/ai-context-limit/VERSION
Normal file
@@ -0,0 +1 @@
|
||||
1.0.0
|
||||
73
plugins/wasm-go/extensions/ai-context-limit/config.go
Normal file
73
plugins/wasm-go/extensions/ai-context-limit/config.go
Normal file
@@ -0,0 +1,73 @@
|
||||
// Copyright (c) 2026 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// Config 上下文限制插件配置
|
||||
type Config struct {
|
||||
// MaxContextTokens 必填,输入侧 token 上限(用户阈值)
|
||||
MaxContextTokens int `json:"max_context_tokens"`
|
||||
// ErrorStatusCode 超限响应码,默认 400
|
||||
ErrorStatusCode int `json:"error_status_code"`
|
||||
// BufferRatio token 预估值放大系数,默认 1.10
|
||||
BufferRatio float64 `json:"buffer_ratio"`
|
||||
}
|
||||
|
||||
const (
|
||||
defaultErrorStatusCode = 400
|
||||
defaultBufferRatio = 1.10
|
||||
// MaxRequestBodyBytes 强制调大的 envoy 请求体 buffer 上限
|
||||
// 上下文限制仅需要读取请求体中的文本输入,8MB 可覆盖常见长上下文请求。
|
||||
MaxRequestBodyBytes uint32 = 8 * 1024 * 1024
|
||||
)
|
||||
|
||||
// parseConfig 解析 WasmPlugin defaultConfig 字段
|
||||
func parseConfig(jsonConfig gjson.Result, cfg *Config) error {
|
||||
if err := json.Unmarshal([]byte(jsonConfig.Raw), cfg); err != nil {
|
||||
return fmt.Errorf("parse config failed: %w", err)
|
||||
}
|
||||
if cfg.MaxContextTokens < 0 {
|
||||
return fmt.Errorf("max_context_tokens must be non-negative, got %d", cfg.MaxContextTokens)
|
||||
}
|
||||
if cfg.MaxContextTokens == 0 {
|
||||
// 阈值为 0 视为未启用,不拦截请求(防止误配置导致全量 5xx)
|
||||
return nil
|
||||
}
|
||||
if cfg.ErrorStatusCode == 0 {
|
||||
cfg.ErrorStatusCode = defaultErrorStatusCode
|
||||
} else if cfg.ErrorStatusCode < 400 || cfg.ErrorStatusCode > 599 {
|
||||
return fmt.Errorf("error_status_code must be between 400 and 599, got %d", cfg.ErrorStatusCode)
|
||||
}
|
||||
if cfg.BufferRatio < 0 {
|
||||
return fmt.Errorf("buffer_ratio must be non-negative, got %f", cfg.BufferRatio)
|
||||
}
|
||||
if cfg.BufferRatio == 0 {
|
||||
cfg.BufferRatio = defaultBufferRatio
|
||||
} else if cfg.BufferRatio > 10 {
|
||||
return fmt.Errorf("buffer_ratio must not exceed 10, got %f", cfg.BufferRatio)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsEnabled 判断当前配置是否需要执行拦截
|
||||
func (c *Config) IsEnabled() bool {
|
||||
return c.MaxContextTokens > 0
|
||||
}
|
||||
@@ -0,0 +1,201 @@
|
||||
# ai-context-limit 设计文档
|
||||
|
||||
## 背景
|
||||
|
||||
大模型网关经常需要在请求到达上游模型之前拦截超长 prompt。现有 token 相关插件更多依赖模型响应中的 usage 字段进行事后统计,适合做计量和配额,但无法阻止超限请求进入模型。
|
||||
|
||||
`ai-context-limit` 提供请求侧上下文窗口保护能力,面向 OpenAI Chat Completions、Anthropic Messages 等协议兼容请求,在转发前估算输入 token 数,并在估算结果超过配置阈值时直接返回错误响应。
|
||||
|
||||
## 目标
|
||||
|
||||
- 在请求发送到上游模型前拦截超限文本输入。
|
||||
- 支持通过 Higress WasmPlugin 在路由、服务、域名和 MCP Server 等粒度配置。
|
||||
- 首个版本保持数据面自包含,运行时不依赖网络下载分词资源。
|
||||
- 返回 OpenAI 兼容错误响应,方便常见 SDK 识别异常。
|
||||
|
||||
## 非目标
|
||||
|
||||
- 为每个模型系列精确适配专属分词器。
|
||||
- 根据模型名自动选择分词器。
|
||||
- 统计多模态内容的 token。
|
||||
- 做响应侧 token 统计或配额管理。
|
||||
|
||||
如果后续特定模型族对精度有更高要求,可以在该插件基础上继续扩展。
|
||||
|
||||
## 配置
|
||||
|
||||
```yaml
|
||||
max_context_tokens: 128000
|
||||
buffer_ratio: 1.10
|
||||
error_status_code: 400
|
||||
```
|
||||
|
||||
| 字段 | 类型 | 默认值 | 校验规则 | 含义 |
|
||||
|---|---|---|---|---|
|
||||
| `max_context_tokens` | int | - | `>= 0` | 最大输入 token 估算阈值。设为 `0` 表示禁用拦截。 |
|
||||
| `buffer_ratio` | float | `1.10` | `0 <= value <= 10` | 安全缓冲系数,估算 token 数会先乘以该系数再与阈值比较。设为 `0` 使用默认值。 |
|
||||
| `error_status_code` | int | `400` | `400 <= value <= 599` | 请求被拦截时返回的 HTTP 状态码。 |
|
||||
|
||||
## 请求处理流程
|
||||
|
||||
1. 请求头阶段,非 JSON 请求和未启用配置直接放行,不读取请求体。
|
||||
2. JSON 请求会把请求体 buffer 上限调到 8MB,并等待请求体阶段处理。
|
||||
3. 插件会自动识别请求协议(OpenAI 或 Anthropic),并抽取对应字段的文本:
|
||||
- **OpenAI Chat Completions**:
|
||||
- `messages[].content` 字符串或 text parts 数组;
|
||||
- `messages[].role` 和 `messages[].name`;
|
||||
- `messages[].tool_calls[].function.name` 和 `arguments`;
|
||||
- `tools[].function.name`、`description`、`parameters`;
|
||||
- `response_format.json_schema.name`、`description`、`schema`;
|
||||
- 顶层 `system` 字段。
|
||||
- **Anthropic Messages**(通过 `tools[].input_schema` / `tool_use` / `tool_result` / `thinking` / `redacted_thinking` / `document` / `search_result` 等特有字段识别):
|
||||
- `messages[].content` 字符串或 content block 数组;
|
||||
- `messages[].role`;
|
||||
- `text` block 的 `text` 字段;
|
||||
- `tool_use` block 的 `name` 和 `input`(raw JSON);
|
||||
- `tool_result` block 的 `content`(字符串或 content block 数组,递归处理);
|
||||
- `thinking` block 的 `thinking` 字段;
|
||||
- `redacted_thinking` block 的 `data` 字段(保守计入);
|
||||
- `document` block:`source.type=text` 时计入 `title` + `source.data`,其他视为多模态;
|
||||
- `search_result` block 的 `title` + `source` + `content[]` 中的 text blocks;
|
||||
- `tools[].name`、`description`、`type`、`input_schema`(raw JSON);
|
||||
- 顶层 `system`(字符串或 text block 数组)。
|
||||
4. 如果检测到非文本二进制内容(如图片、音频、base64/url/file document),整个请求直接放行。插件会统计所有文本承载字段,未知非文本 block 视为多模态并放行。
|
||||
5. 插件对抽取文本进行 token 估算,并将结果乘以 `buffer_ratio` 后与 `max_context_tokens` 比较。
|
||||
6. 当 `estimated_tokens > max_context_tokens` 时,插件返回 OpenAI 兼容的 `context_length_exceeded` 错误。
|
||||
|
||||
## Token 估算策略
|
||||
|
||||
首个版本使用单一内嵌 BPE 词表,并在插件启动时初始化 tokenizer。这样可以避免 WASM 沙箱运行时下载资源,使请求处理完全在本地完成。
|
||||
|
||||
默认 `buffer_ratio` 为 `1.10`。基于长中文文档、混合 RAG、代码、多轮对话等文本的验证结果,该缓冲系数可以覆盖已观察到的低估场景,同时保持实现简单、确定。
|
||||
|
||||
## 错误响应
|
||||
|
||||
```json
|
||||
{
|
||||
"error": {
|
||||
"message": "This model's maximum context length is 128000 tokens. Your request had approximately 140000 tokens.",
|
||||
"type": "invalid_request_error",
|
||||
"code": "context_length_exceeded"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 验证
|
||||
|
||||
实现已包含单元测试,覆盖:
|
||||
|
||||
- 配置默认值和非法值;
|
||||
- 从 messages、tools、顶层 system 字段抽取文本;
|
||||
- 多模态检测与放行;
|
||||
- token 计数基础行为;
|
||||
- 严格阈值比较逻辑。
|
||||
|
||||
插件已通过以下验证:
|
||||
|
||||
- `go test ./...`
|
||||
- `go vet ./...`
|
||||
- `make local-build PLUGIN_NAME=ai-context-limit`
|
||||
|
||||
---
|
||||
|
||||
# ai-context-limit Design
|
||||
|
||||
## Background
|
||||
|
||||
Large language model gateways often need to reject over-sized prompts before they reach the upstream model. Existing token-related plugins mainly rely on response-side usage fields, which is useful for accounting but cannot prevent an over-limit request from reaching the model.
|
||||
|
||||
`ai-context-limit` provides request-side context window protection for OpenAI Chat Completions, Anthropic Messages and other compatible traffic. It estimates input tokens before forwarding and returns an error response when the estimated input size exceeds the configured limit.
|
||||
|
||||
## Goals
|
||||
|
||||
- Block over-limit text requests before they are sent to the upstream model.
|
||||
- Support route, service, domain, and MCP Server level configuration through Higress WasmPlugin.
|
||||
- Keep the first version self-contained in the data plane, without runtime network access for tokenizer resources.
|
||||
- Provide OpenAI-compatible error responses so common SDKs can parse the failure.
|
||||
|
||||
## Non-goals
|
||||
|
||||
- Exact tokenizer matching for every model family.
|
||||
- Model-name based tokenizer selection.
|
||||
- Multimodal token counting.
|
||||
- Response-side token accounting or quota management.
|
||||
|
||||
These can be added later if there is a stronger precision requirement for specific model families.
|
||||
|
||||
## Configuration
|
||||
|
||||
```yaml
|
||||
max_context_tokens: 128000
|
||||
buffer_ratio: 1.10
|
||||
error_status_code: 400
|
||||
```
|
||||
|
||||
| Field | Type | Default | Validation | Meaning |
|
||||
|---|---|---|---|---|
|
||||
| `max_context_tokens` | int | - | `>= 0` | Maximum estimated input tokens. `0` disables blocking. |
|
||||
| `buffer_ratio` | float | `1.10` | `0 <= value <= 10` | Safety multiplier applied to estimated tokens before comparison. `0` uses the default. |
|
||||
| `error_status_code` | int | `400` | `400 <= value <= 599` | HTTP status code for blocked requests. |
|
||||
|
||||
## Request Processing
|
||||
|
||||
1. In the request header phase, non-JSON requests and disabled configs are passed through without reading the body.
|
||||
2. For JSON requests, the plugin raises the request body buffer limit to 8MB and waits for the body phase.
|
||||
3. The plugin auto-detects the request protocol (OpenAI or Anthropic) and extracts text from the corresponding fields:
|
||||
- **OpenAI Chat Completions**:
|
||||
- `messages[].content` string or text parts array;
|
||||
- `messages[].role` and `messages[].name`;
|
||||
- `messages[].tool_calls[].function.name` and `arguments`;
|
||||
- `tools[].function.name`, `description`, and `parameters`;
|
||||
- `response_format.json_schema.name`, `description`, and `schema`;
|
||||
- top-level `system`.
|
||||
- **Anthropic Messages** (detected via `tools[].input_schema` / `tool_use` / `tool_result` / `thinking` / `redacted_thinking` / `document` / `search_result`):
|
||||
- `messages[].content` string or content block array;
|
||||
- `messages[].role`;
|
||||
- `text` block `text` field;
|
||||
- `tool_use` block `name` and `input` (raw JSON);
|
||||
- `tool_result` block `content` (string or content block array, recursively processed);
|
||||
- `thinking` block `thinking` field;
|
||||
- `redacted_thinking` block `data` field (conservatively counted);
|
||||
- `document` block: `source.type=text` counts `title` + `source.data`, others treated as multimodal;
|
||||
- `search_result` block `title` + `source` + `content[]` text blocks;
|
||||
- `tools[].name`, `description`, `type`, and `input_schema` (raw JSON);
|
||||
- top-level `system` (string or text block array).
|
||||
4. If non-text binary content is detected (e.g., images, audio, base64/url/file documents), the request is passed through. The plugin counts all text-bearing fields; unknown non-text blocks are treated as multimodal and bypassed.
|
||||
5. The extracted text is tokenized, multiplied by `buffer_ratio`, and compared with `max_context_tokens`.
|
||||
6. If `estimated_tokens > max_context_tokens`, the plugin returns an OpenAI-compatible `context_length_exceeded` response.
|
||||
|
||||
## Token Estimation Strategy
|
||||
|
||||
The first version uses a single embedded BPE vocabulary and initializes the tokenizer once during plugin startup. This avoids runtime downloads in the WASM sandbox and keeps request processing fully local.
|
||||
|
||||
The default `buffer_ratio` is `1.10`. Internal validation on long Chinese documents, mixed RAG content, code, and multi-turn conversations showed that the buffer covers observed under-estimation cases while keeping the implementation simple and deterministic.
|
||||
|
||||
## Error Response
|
||||
|
||||
```json
|
||||
{
|
||||
"error": {
|
||||
"message": "This model's maximum context length is 128000 tokens. Your request had approximately 140000 tokens.",
|
||||
"type": "invalid_request_error",
|
||||
"code": "context_length_exceeded"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Validation
|
||||
|
||||
The implementation includes unit tests for:
|
||||
|
||||
- configuration defaults and invalid values;
|
||||
- text extraction from messages, tools, and top-level system fields;
|
||||
- multimodal detection and pass-through behavior;
|
||||
- token counting basics;
|
||||
- strict threshold comparison.
|
||||
|
||||
The plugin has also been verified with:
|
||||
|
||||
- `go test ./...`
|
||||
- `go vet ./...`
|
||||
- `make local-build PLUGIN_NAME=ai-context-limit`
|
||||
384
plugins/wasm-go/extensions/ai-context-limit/extract.go
Normal file
384
plugins/wasm-go/extensions/ai-context-limit/extract.go
Normal file
@@ -0,0 +1,384 @@
|
||||
// Copyright (c) 2026 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// extractResult 文本抽取结果
|
||||
type extractResult struct {
|
||||
// Text 拼接后的所有可计 token 文本
|
||||
Text string
|
||||
// HasMultimodal 是否检测到非 text 类型 part(image_url/audio/...),命中即放行
|
||||
HasMultimodal bool
|
||||
}
|
||||
|
||||
// extractPromptText 从请求体抽取所有需要计入 input tokens 的文本。
|
||||
//
|
||||
// 协议识别策略:通过检测 Anthropic 特有字段(tools[].input_schema、
|
||||
// content type=tool_use/tool_result/thinking/redacted_thinking/document/search_result)
|
||||
// 来判断是否为 Anthropic 协议请求。
|
||||
// 普通纯文本请求即使是 Anthropic 格式,走 OpenAI 路径也能正确统计。
|
||||
//
|
||||
// 多模态降级:非文本二进制内容(image/audio/base64 document)视为多模态,
|
||||
// 整个请求放行。
|
||||
func extractPromptText(body []byte) extractResult {
|
||||
if hasAnthropicSpecificFields(body) {
|
||||
return extractAnthropicText(body)
|
||||
}
|
||||
return extractOpenAIText(body)
|
||||
}
|
||||
|
||||
// hasAnthropicSpecificFields 保守识别 Anthropic 协议特有字段。
|
||||
//
|
||||
// 强信号(任一命中即判定为 Anthropic):
|
||||
// - tools[].input_schema 存在(OpenAI 用 tools[].function.parameters)
|
||||
// - tools[] 中存在无 function 包装但有 name 的条目(Anthropic server tools)
|
||||
// - messages[].content[] 中含 type=tool_use/tool_result/thinking/redacted_thinking/document/search_result
|
||||
//
|
||||
// 不以 content array + type=text 判断(OpenAI 多模态也有此结构)。
|
||||
func hasAnthropicSpecificFields(body []byte) bool {
|
||||
// 检查 tools[]
|
||||
tools := gjson.GetBytes(body, "tools").Array()
|
||||
for _, tool := range tools {
|
||||
if tool.Get("input_schema").Exists() {
|
||||
return true
|
||||
}
|
||||
if tool.Get("name").Exists() && !tool.Get("function").Exists() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// 检查 messages[].content[] 中的 Anthropic 特有 block types
|
||||
messages := gjson.GetBytes(body, "messages").Array()
|
||||
for _, msg := range messages {
|
||||
content := msg.Get("content")
|
||||
if !content.IsArray() {
|
||||
continue
|
||||
}
|
||||
for _, part := range content.Array() {
|
||||
switch part.Get("type").String() {
|
||||
case "tool_use", "tool_result", "thinking", "redacted_thinking", "document", "search_result":
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// OpenAI Chat Completions extractor
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// extractOpenAIText 从 OpenAI Chat Completions 请求体抽取文本。
|
||||
//
|
||||
// 协议参考:https://platform.openai.com/docs/api-reference/chat/create
|
||||
//
|
||||
// 抽取范围:
|
||||
// - messages[].role / name / content(string 或 text parts array)
|
||||
// - messages[].tool_calls[].function.{name, arguments}
|
||||
// - tools[].function.{name, description, parameters}
|
||||
// - response_format.json_schema.{name, description, schema}
|
||||
// - 顶层 system 字段(兼容将 system prompt 放在顶层的协议)
|
||||
func extractOpenAIText(body []byte) extractResult {
|
||||
var sb strings.Builder
|
||||
result := extractResult{}
|
||||
|
||||
// 1. messages[]
|
||||
messages := gjson.GetBytes(body, "messages").Array()
|
||||
for _, msg := range messages {
|
||||
if name := msg.Get("name").String(); name != "" {
|
||||
sb.WriteString(name)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
if role := msg.Get("role").String(); role != "" {
|
||||
sb.WriteString(role)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
content := msg.Get("content")
|
||||
switch {
|
||||
case content.Type == gjson.String:
|
||||
sb.WriteString(content.String())
|
||||
sb.WriteByte('\n')
|
||||
case content.IsArray():
|
||||
for _, part := range content.Array() {
|
||||
partType := part.Get("type").String()
|
||||
if partType == "text" {
|
||||
sb.WriteString(part.Get("text").String())
|
||||
sb.WriteByte('\n')
|
||||
continue
|
||||
}
|
||||
// 任意非 text part → 多模态,立即返回触发放行
|
||||
result.HasMultimodal = true
|
||||
return result
|
||||
}
|
||||
}
|
||||
// messages[].tool_calls[](多轮对话中 assistant 的工具调用参数)
|
||||
toolCalls := msg.Get("tool_calls").Array()
|
||||
for _, tc := range toolCalls {
|
||||
fn := tc.Get("function")
|
||||
if !fn.Exists() {
|
||||
continue
|
||||
}
|
||||
if name := fn.Get("name").String(); name != "" {
|
||||
sb.WriteString(name)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
if args := fn.Get("arguments").String(); args != "" {
|
||||
sb.WriteString(args)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. tools[]
|
||||
tools := gjson.GetBytes(body, "tools").Array()
|
||||
for _, tool := range tools {
|
||||
fn := tool.Get("function")
|
||||
if !fn.Exists() {
|
||||
continue
|
||||
}
|
||||
if name := fn.Get("name").String(); name != "" {
|
||||
sb.WriteString(name)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
if desc := fn.Get("description").String(); desc != "" {
|
||||
sb.WriteString(desc)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
if params := fn.Get("parameters"); params.Exists() {
|
||||
sb.WriteString(params.Raw)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
|
||||
// 3. response_format.json_schema(结构化输出 schema 计入 input tokens)
|
||||
jsonSchema := gjson.GetBytes(body, "response_format.json_schema")
|
||||
if jsonSchema.Exists() {
|
||||
if name := jsonSchema.Get("name").String(); name != "" {
|
||||
sb.WriteString(name)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
if desc := jsonSchema.Get("description").String(); desc != "" {
|
||||
sb.WriteString(desc)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
if schema := jsonSchema.Get("schema"); schema.Exists() {
|
||||
sb.WriteString(schema.Raw)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 顶层 system 字段
|
||||
extractTopLevelSystem(body, &sb)
|
||||
|
||||
result.Text = sb.String()
|
||||
return result
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Anthropic Messages extractor
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// extractAnthropicText 从 Anthropic Messages 请求体抽取文本。
|
||||
//
|
||||
// 协议参考:https://docs.anthropic.com/en/api/messages
|
||||
//
|
||||
// 抽取范围:
|
||||
// - system:string 或 text block array
|
||||
// - messages[].role
|
||||
// - messages[].content:string 或 content block array
|
||||
// - type=text → text 字段
|
||||
// - type=tool_use → name + input(raw JSON)
|
||||
// - type=tool_result → content(string 或 content block array)
|
||||
// - type=thinking → thinking 字段
|
||||
// - type=redacted_thinking → data 字段
|
||||
// - type=document → source.type=text 时计入,其他视为多模态
|
||||
// - type=search_result → title + source + content text blocks
|
||||
// - tools[].name / description / type / input_schema(raw JSON)
|
||||
func extractAnthropicText(body []byte) extractResult {
|
||||
var sb strings.Builder
|
||||
result := extractResult{}
|
||||
|
||||
// 1. system(string 或 text block array)
|
||||
extractTopLevelSystem(body, &sb)
|
||||
|
||||
// 2. messages[]
|
||||
messages := gjson.GetBytes(body, "messages").Array()
|
||||
for _, msg := range messages {
|
||||
if role := msg.Get("role").String(); role != "" {
|
||||
sb.WriteString(role)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
content := msg.Get("content")
|
||||
switch {
|
||||
case content.Type == gjson.String:
|
||||
sb.WriteString(content.String())
|
||||
sb.WriteByte('\n')
|
||||
case content.IsArray():
|
||||
for _, part := range content.Array() {
|
||||
if extractAnthropicContentBlock(part, &sb) {
|
||||
result.HasMultimodal = true
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. tools[]
|
||||
tools := gjson.GetBytes(body, "tools").Array()
|
||||
for _, tool := range tools {
|
||||
if tp := tool.Get("type").String(); tp != "" {
|
||||
sb.WriteString(tp)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
if name := tool.Get("name").String(); name != "" {
|
||||
sb.WriteString(name)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
if desc := tool.Get("description").String(); desc != "" {
|
||||
sb.WriteString(desc)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
if schema := tool.Get("input_schema"); schema.Exists() {
|
||||
sb.WriteString(schema.Raw)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
|
||||
result.Text = sb.String()
|
||||
return result
|
||||
}
|
||||
|
||||
// extractAnthropicContentBlock 统一处理单个 Anthropic content block。
|
||||
// 顶层 messages[].content[] 和 tool_result.content[] 均复用此函数。
|
||||
// 返回 true 表示发现多模态内容(需放行)。
|
||||
func extractAnthropicContentBlock(part gjson.Result, sb *strings.Builder) bool {
|
||||
t := part.Get("type").String()
|
||||
switch t {
|
||||
case "text":
|
||||
sb.WriteString(part.Get("text").String())
|
||||
sb.WriteByte('\n')
|
||||
case "tool_use":
|
||||
if name := part.Get("name").String(); name != "" {
|
||||
sb.WriteString(name)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
if input := part.Get("input"); input.Exists() {
|
||||
sb.WriteString(input.Raw)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
case "tool_result":
|
||||
content := part.Get("content")
|
||||
switch {
|
||||
case content.Type == gjson.String:
|
||||
sb.WriteString(content.String())
|
||||
sb.WriteByte('\n')
|
||||
case content.IsArray():
|
||||
for _, block := range content.Array() {
|
||||
if extractAnthropicContentBlock(block, sb) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
case "thinking":
|
||||
if thinking := part.Get("thinking").String(); thinking != "" {
|
||||
sb.WriteString(thinking)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
case "redacted_thinking":
|
||||
if data := part.Get("data").String(); data != "" {
|
||||
sb.WriteString(data)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
case "document":
|
||||
return extractAnthropicDocument(part, sb)
|
||||
case "search_result":
|
||||
extractAnthropicSearchResult(part, sb)
|
||||
default:
|
||||
// 真正的非文本 block(image/audio/等)视为多模态
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// extractAnthropicDocument 处理 Anthropic document content block。
|
||||
// source.type=="text" 时抽取文本内容,其他类型(base64/url/file)视为多模态。
|
||||
// 返回 true 表示多模态(需放行)。
|
||||
func extractAnthropicDocument(part gjson.Result, sb *strings.Builder) bool {
|
||||
sourceType := part.Get("source.type").String()
|
||||
if sourceType == "text" {
|
||||
// 纯文本文档,计入 token
|
||||
if title := part.Get("title").String(); title != "" {
|
||||
sb.WriteString(title)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
if data := part.Get("source.data").String(); data != "" {
|
||||
sb.WriteString(data)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
return false
|
||||
}
|
||||
// base64/url/file 等非文本源 → 多模态
|
||||
return true
|
||||
}
|
||||
|
||||
// extractAnthropicSearchResult 处理 Anthropic search_result content block。
|
||||
// 抽取 title、source 和 content[] 中的 text blocks。
|
||||
func extractAnthropicSearchResult(part gjson.Result, sb *strings.Builder) {
|
||||
if title := part.Get("title").String(); title != "" {
|
||||
sb.WriteString(title)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
if source := part.Get("source").String(); source != "" {
|
||||
sb.WriteString(source)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
// content 是 text blocks 数组
|
||||
contentArr := part.Get("content").Array()
|
||||
for _, block := range contentArr {
|
||||
if block.Get("type").String() == "text" {
|
||||
sb.WriteString(block.Get("text").String())
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 共用辅助函数
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// extractTopLevelSystem 抽取顶层 system 字段(string 或 text block array)。
|
||||
// OpenAI 和 Anthropic 均可能使用顶层 system。
|
||||
func extractTopLevelSystem(body []byte, sb *strings.Builder) {
|
||||
sys := gjson.GetBytes(body, "system")
|
||||
if !sys.Exists() {
|
||||
return
|
||||
}
|
||||
switch {
|
||||
case sys.Type == gjson.String:
|
||||
sb.WriteString(sys.String())
|
||||
sb.WriteByte('\n')
|
||||
case sys.IsArray():
|
||||
for _, part := range sys.Array() {
|
||||
if part.Get("type").String() == "text" {
|
||||
sb.WriteString(part.Get("text").String())
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
25
plugins/wasm-go/extensions/ai-context-limit/go.mod
Normal file
25
plugins/wasm-go/extensions/ai-context-limit/go.mod
Normal file
@@ -0,0 +1,25 @@
|
||||
module ai-context-limit
|
||||
|
||||
go 1.24.1
|
||||
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c
|
||||
github.com/pkoukk/tiktoken-go v0.1.7
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
32
plugins/wasm-go/extensions/ai-context-limit/go.sum
Normal file
32
plugins/wasm-go/extensions/ai-context-limit/go.sum
Normal file
@@ -0,0 +1,32 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c h1:DdVPyaMHSYBqO5jwB9Wl3PqsBGIf4u29BHMI0uIVB1Y=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
|
||||
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
153
plugins/wasm-go/extensions/ai-context-limit/main.go
Normal file
153
plugins/wasm-go/extensions/ai-context-limit/main.go
Normal file
@@ -0,0 +1,153 @@
|
||||
// Copyright (c) 2026 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package main 实现 ai-context-limit Higress WASM 插件。
|
||||
//
|
||||
// 插件会在 OpenAI / Anthropic 等协议兼容请求到达上游模型之前估算输入 token 数,
|
||||
// 并对超过配置阈值的请求提前返回错误响应。
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
func main() {}
|
||||
|
||||
func init() {
|
||||
// 在插件加载时一次性初始化 token 编码器。
|
||||
if err := initEncoder(); err != nil {
|
||||
// 初始化失败为致命错误:缺少编码器后续无法计算 token,
|
||||
// 但 wasm 运行期不能 panic,记录后所有请求走"未启用"兜底路径
|
||||
// 实际触发概率极低(embed 词表打包失败才会出现)
|
||||
// log 包在 init() 中尚不可用,使用 fmt.Println 兜底
|
||||
fmt.Println("[ai-context-limit] init encoder failed:", err)
|
||||
}
|
||||
wrapper.SetCtx(
|
||||
"ai-context-limit",
|
||||
wrapper.ParseConfig(parseConfig),
|
||||
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBody(onHttpRequestBody),
|
||||
)
|
||||
}
|
||||
|
||||
// onHttpRequestHeaders 处理请求头阶段
|
||||
//
|
||||
// 关键约束:
|
||||
// - envoy 默认 http filter buffer 仅 14.3KB,必须在此阶段调 SetRequestBodyBufferLimit
|
||||
// - 必须返回 HeaderStopIteration,否则 envoy 不会等待 body 阶段
|
||||
// - 非 JSON 请求直接放行,不读 body
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg Config) types.Action {
|
||||
ctx.DisableReroute()
|
||||
|
||||
if !cfg.IsEnabled() {
|
||||
// 配置缺失时降级为不拦截,允许用户通过配置开关此插件
|
||||
log.Warnf("max_context_tokens not configured, plugin disabled for this request")
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
contentType, _ := proxywasm.GetHttpRequestHeader("content-type")
|
||||
if !strings.Contains(strings.ToLower(contentType), "application/json") {
|
||||
log.Debugf("non-json content-type=%q, skip body inspection", contentType)
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
if !ctx.HasRequestBody() {
|
||||
log.Debugf("no request body, skip")
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
// 强制调大 envoy http downstream decoder buffer
|
||||
// 写入 envoy property: set_decoder_buffer_limit
|
||||
ctx.SetRequestBodyBufferLimit(MaxRequestBodyBytes)
|
||||
// 移除 content-length,body 处理后由 envoy 重新计算
|
||||
_ = proxywasm.RemoveHttpRequestHeader("content-length")
|
||||
// 暂停 header 流转,等待 onHttpRequestBody 处理完
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
// onHttpRequestBody 处理请求体阶段
|
||||
//
|
||||
// 流程:
|
||||
// 1. 抽取请求体中所有需计 token 的文本(兼容 OpenAI / Anthropic 等协议)
|
||||
// 2. 命中多模态(image_url/audio)→ 直接放行
|
||||
// 3. token 计数 → ×buffer_ratio → 与阈值比较
|
||||
// 4. 超阈值 → 发送 local response,OpenAI 风格错误体
|
||||
//
|
||||
// 各阶段统一 info 级耗时日志([aicl])方便 grep 与基准对照。
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, cfg Config, body []byte) types.Action {
|
||||
if !cfg.IsEnabled() {
|
||||
return types.ActionContinue
|
||||
}
|
||||
bodyBytes := len(body)
|
||||
log.Infof("[aicl] body_received bytes=%d", bodyBytes)
|
||||
|
||||
if encoder == nil {
|
||||
log.Errorf("[aicl] token encoder not initialized, skip token counting")
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
result := extractPromptText(body)
|
||||
extractMs := time.Since(t0).Milliseconds()
|
||||
log.Infof("[aicl] extract_done bytes=%d text_bytes=%d multimodal=%v elapsed_ms=%d",
|
||||
bodyBytes, len(result.Text), result.HasMultimodal, extractMs)
|
||||
|
||||
if result.HasMultimodal {
|
||||
log.Debugf("[aicl] multimodal request detected, bypass token counting")
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
t1 := time.Now()
|
||||
rawTokens := CountTokens(result.Text)
|
||||
encodeMs := time.Since(t1).Milliseconds()
|
||||
estimatedTokens := int(float64(rawTokens) * cfg.BufferRatio)
|
||||
log.Infof("[aicl] encode_done bytes=%d text_bytes=%d raw_tokens=%d estimated=%d "+
|
||||
"threshold=%d extract_ms=%d encode_ms=%d total_ms=%d",
|
||||
bodyBytes, len(result.Text), rawTokens, estimatedTokens,
|
||||
cfg.MaxContextTokens, extractMs, encodeMs, extractMs+encodeMs)
|
||||
|
||||
if estimatedTokens > cfg.MaxContextTokens {
|
||||
return blockOverLimit(cfg, estimatedTokens)
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
// blockOverLimit 发送 OpenAI 风格的超限错误响应
|
||||
//
|
||||
// 响应体复刻 OpenAI 官方 context_length_exceeded 格式,
|
||||
// 使客户端 SDK(openai-python / openai-node)可解析为 BadRequestError
|
||||
func blockOverLimit(cfg Config, estimatedTokens int) types.Action {
|
||||
body := fmt.Sprintf(
|
||||
`{"error":{"message":"This model's maximum context length is %d tokens. `+
|
||||
`Your request had approximately %d tokens.",`+
|
||||
`"type":"invalid_request_error","code":"context_length_exceeded"}}`,
|
||||
cfg.MaxContextTokens, estimatedTokens,
|
||||
)
|
||||
headers := [][2]string{{"content-type", "application/json"}}
|
||||
if err := proxywasm.SendHttpResponse(uint32(cfg.ErrorStatusCode), headers, []byte(body), -1); err != nil {
|
||||
log.Errorf("send local response failed: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
log.Infof("blocked: estimated %d > limit %d", estimatedTokens, cfg.MaxContextTokens)
|
||||
return types.ActionContinue
|
||||
}
|
||||
730
plugins/wasm-go/extensions/ai-context-limit/main_test.go
Normal file
730
plugins/wasm-go/extensions/ai-context-limit/main_test.go
Normal file
@@ -0,0 +1,730 @@
|
||||
// Copyright (c) 2026 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantMax int
|
||||
wantCode int
|
||||
wantRatio float64
|
||||
wantOk bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "完整配置",
|
||||
input: `{"max_context_tokens":128000,"error_status_code":413,"buffer_ratio":1.2}`,
|
||||
wantMax: 128000,
|
||||
wantCode: 413,
|
||||
wantRatio: 1.2,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "仅必填字段,其余取默认值",
|
||||
input: `{"max_context_tokens":32000}`,
|
||||
wantMax: 32000,
|
||||
wantCode: defaultErrorStatusCode,
|
||||
wantRatio: defaultBufferRatio,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "缺失阈值不抛错,IsEnabled=false",
|
||||
input: `{}`,
|
||||
wantMax: 0,
|
||||
wantCode: 0,
|
||||
wantRatio: 0,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "阈值为 0 视为未启用",
|
||||
input: `{"max_context_tokens":0}`,
|
||||
wantMax: 0,
|
||||
wantCode: 0,
|
||||
wantRatio: 0,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "max_context_tokens 负数拒绝",
|
||||
input: `{"max_context_tokens":-1}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "buffer_ratio 负数拒绝",
|
||||
input: `{"max_context_tokens":1000,"buffer_ratio":-1}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "error_status_code=200 拒绝",
|
||||
input: `{"max_context_tokens":1000,"error_status_code":200}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "error_status_code=600 拒绝",
|
||||
input: `{"max_context_tokens":1000,"error_status_code":600}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "buffer_ratio=11 拒绝",
|
||||
input: `{"max_context_tokens":1000,"buffer_ratio":11}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "buffer_ratio=10 边界允许",
|
||||
input: `{"max_context_tokens":1000,"buffer_ratio":10}`,
|
||||
wantMax: 1000,
|
||||
wantCode: defaultErrorStatusCode,
|
||||
wantRatio: 10,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "error_status_code=599 边界允许",
|
||||
input: `{"max_context_tokens":1000,"error_status_code":599}`,
|
||||
wantMax: 1000,
|
||||
wantCode: 599,
|
||||
wantRatio: defaultBufferRatio,
|
||||
wantOk: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var cfg Config
|
||||
err := parseConfig(gjson.Parse(tc.input), &cfg)
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.wantMax, cfg.MaxContextTokens)
|
||||
assert.Equal(t, tc.wantCode, cfg.ErrorStatusCode)
|
||||
assert.InDelta(t, tc.wantRatio, cfg.BufferRatio, 1e-9)
|
||||
assert.Equal(t, tc.wantOk, cfg.IsEnabled())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLightweightE2E 轻量端到端验证:
|
||||
// 用低阈值跑完 extract + CountTokens + 判定,确认新增字段真实影响拦截/放行决策。
|
||||
func TestLightweightE2E(t *testing.T) {
|
||||
require := assert.New(t)
|
||||
require.NoError(initEncoder())
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
body []byte
|
||||
maxTokens int
|
||||
wantBlock bool
|
||||
}{
|
||||
{
|
||||
name: "OpenAI tool_calls.arguments 超阈值 → 400",
|
||||
body: []byte(`{
|
||||
"messages": [
|
||||
{"role": "user", "content": "go"},
|
||||
{"role": "assistant", "tool_calls": [{"id": "c1", "type": "function", "function": {
|
||||
"name": "big_query",
|
||||
"arguments": "{\"sql\":\"SELECT a]very long query that should push tokens over the low threshold we set for this test, including multiple columns like id, name, email, phone, address, city, state, zip, country, created_at, updated_at, deleted_at FROM users WHERE status = active AND region IN (us-east-1, us-west-2, eu-west-1, ap-southeast-1) ORDER BY created_at DESC LIMIT 1000\"}"
|
||||
}}]}
|
||||
]
|
||||
}`),
|
||||
maxTokens: 5,
|
||||
wantBlock: true,
|
||||
},
|
||||
{
|
||||
name: "OpenAI response_format.json_schema 超阈值 → 400",
|
||||
body: []byte(`{
|
||||
"messages": [{"role": "user", "content": "x"}],
|
||||
"response_format": {"type": "json_schema", "json_schema": {
|
||||
"name": "big_schema",
|
||||
"description": "A very detailed schema for structured extraction of complex nested data",
|
||||
"schema": {"type": "object", "properties": {"a": {"type": "string"}, "b": {"type": "integer"}, "c": {"type": "array", "items": {"type": "object", "properties": {"d": {"type": "string"}}}}}}
|
||||
}}
|
||||
}`),
|
||||
maxTokens: 5,
|
||||
wantBlock: true,
|
||||
},
|
||||
{
|
||||
name: "Anthropic tools.input_schema 超阈值 → 400",
|
||||
body: []byte(`{
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"tools": [{"name": "search", "description": "Search the database with complex filters", "input_schema": {"type": "object", "properties": {"query": {"type": "string"}, "filters": {"type": "array", "items": {"type": "object", "properties": {"field": {"type": "string"}, "op": {"type": "string"}, "value": {"type": "string"}}}}}}}]
|
||||
}`),
|
||||
maxTokens: 5,
|
||||
wantBlock: true,
|
||||
},
|
||||
{
|
||||
name: "Anthropic 短文本 → 放行",
|
||||
body: []byte(`{
|
||||
"system": "ok",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"tools": [{"name": "t", "input_schema": {"type": "object"}}]
|
||||
}`),
|
||||
maxTokens: 100,
|
||||
wantBlock: false,
|
||||
},
|
||||
{
|
||||
name: "Anthropic image block → 放行",
|
||||
body: []byte(`{
|
||||
"messages": [{"role": "user", "content": [
|
||||
{"type": "text", "text": "describe"},
|
||||
{"type": "image", "source": {"type": "base64", "data": "..."}}
|
||||
]}],
|
||||
"tools": [{"name": "x", "input_schema": {}}]
|
||||
}`),
|
||||
maxTokens: 1,
|
||||
wantBlock: false, // 多模态放行,不管阈值多低
|
||||
},
|
||||
{
|
||||
name: "Anthropic thinking block 超阈值 → 400(无 tools)",
|
||||
body: []byte(`{
|
||||
"messages": [
|
||||
{"role": "user", "content": "solve this"},
|
||||
{"role": "assistant", "content": [
|
||||
{"type": "thinking", "thinking": "Let me reason through this carefully. First, I need to analyze the problem from multiple angles. The key insight is that we need to consider all boundary conditions and edge cases before arriving at a solution. This requires systematic decomposition of the constraints."},
|
||||
{"type": "text", "text": "The answer is 42."}
|
||||
]}
|
||||
]
|
||||
}`),
|
||||
maxTokens: 5,
|
||||
wantBlock: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := extractPromptText(tc.body)
|
||||
tokens := CountTokens(r.Text)
|
||||
estimated := int(float64(tokens) * 1.10)
|
||||
blocked := !r.HasMultimodal && estimated > tc.maxTokens
|
||||
|
||||
t.Logf("multimodal=%v tokens=%d estimated=%d threshold=%d blocked=%v",
|
||||
r.HasMultimodal, tokens, estimated, tc.maxTokens, blocked)
|
||||
|
||||
assert.Equal(t, tc.wantBlock, blocked)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Anthropic 协议场景测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAnthropicDetection(t *testing.T) {
|
||||
// OpenAI 请求不应触发 Anthropic 路径
|
||||
openaiBody := []byte(`{
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}],
|
||||
"tools": [{"type": "function", "function": {"name": "foo", "parameters": {}}}]
|
||||
}`)
|
||||
assert.False(t, hasAnthropicSpecificFields(openaiBody), "OpenAI content array + type=text 不应误判为 Anthropic")
|
||||
|
||||
// Anthropic tools[].input_schema
|
||||
anthropicTools := []byte(`{
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"tools": [{"name": "get_weather", "input_schema": {"type": "object"}}]
|
||||
}`)
|
||||
assert.True(t, hasAnthropicSpecificFields(anthropicTools), "tools[].input_schema 必须识别为 Anthropic")
|
||||
|
||||
// Anthropic tool_use content block
|
||||
toolUseBody := []byte(`{
|
||||
"messages": [{"role": "assistant", "content": [
|
||||
{"type": "tool_use", "id": "tu_1", "name": "calc", "input": {"expr": "1+1"}}
|
||||
]}]
|
||||
}`)
|
||||
assert.True(t, hasAnthropicSpecificFields(toolUseBody), "content type=tool_use 必须识别为 Anthropic")
|
||||
|
||||
// Anthropic tool_result content block
|
||||
toolResultBody := []byte(`{
|
||||
"messages": [{"role": "user", "content": [
|
||||
{"type": "tool_result", "tool_use_id": "tu_1", "content": "2"}
|
||||
]}]
|
||||
}`)
|
||||
assert.True(t, hasAnthropicSpecificFields(toolResultBody), "content type=tool_result 必须识别为 Anthropic")
|
||||
|
||||
// 仅含 thinking block(无 tools)也应识别为 Anthropic
|
||||
thinkingOnly := []byte(`{
|
||||
"messages": [{"role": "assistant", "content": [
|
||||
{"type": "thinking", "thinking": "reasoning..."}
|
||||
]}]
|
||||
}`)
|
||||
assert.True(t, hasAnthropicSpecificFields(thinkingOnly), "thinking block 必须识别为 Anthropic")
|
||||
|
||||
// 仅含 redacted_thinking block
|
||||
redactedOnly := []byte(`{
|
||||
"messages": [{"role": "assistant", "content": [
|
||||
{"type": "redacted_thinking", "data": "xxx"}
|
||||
]}]
|
||||
}`)
|
||||
assert.True(t, hasAnthropicSpecificFields(redactedOnly), "redacted_thinking block 必须识别为 Anthropic")
|
||||
|
||||
// 仅含 document block
|
||||
docOnly := []byte(`{
|
||||
"messages": [{"role": "user", "content": [
|
||||
{"type": "document", "source": {"type": "text", "data": "..."}}
|
||||
]}]
|
||||
}`)
|
||||
assert.True(t, hasAnthropicSpecificFields(docOnly), "document block 必须识别为 Anthropic")
|
||||
|
||||
// 仅含 search_result block
|
||||
searchOnly := []byte(`{
|
||||
"messages": [{"role": "user", "content": [
|
||||
{"type": "search_result", "title": "t", "content": []}
|
||||
]}]
|
||||
}`)
|
||||
assert.True(t, hasAnthropicSpecificFields(searchOnly), "search_result block 必须识别为 Anthropic")
|
||||
}
|
||||
|
||||
func TestExtractAnthropicText_ToolUseAndResult(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"system": "You are a helpful assistant",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
{"role": "assistant", "content": [
|
||||
{"type": "text", "text": "Let me calculate that."},
|
||||
{"type": "tool_use", "id": "tu_1", "name": "calculator", "input": {"expression": "2+2"}}
|
||||
]},
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_result", "tool_use_id": "tu_1", "content": "4"}
|
||||
]},
|
||||
{"role": "assistant", "content": "The answer is 4."}
|
||||
],
|
||||
"tools": [
|
||||
{"name": "calculator", "description": "Evaluates math expressions", "input_schema": {"type": "object", "properties": {"expression": {"type": "string"}}}}
|
||||
]
|
||||
}`)
|
||||
|
||||
r := extractPromptText(body)
|
||||
assert.False(t, r.HasMultimodal)
|
||||
// system
|
||||
assert.Contains(t, r.Text, "You are a helpful assistant")
|
||||
// messages content string
|
||||
assert.Contains(t, r.Text, "What is 2+2?")
|
||||
assert.Contains(t, r.Text, "The answer is 4.")
|
||||
// tool_use: name + input
|
||||
assert.Contains(t, r.Text, "calculator")
|
||||
assert.Contains(t, r.Text, "expression")
|
||||
assert.Contains(t, r.Text, "2+2")
|
||||
// tool_result: content string
|
||||
assert.Contains(t, r.Text, "4")
|
||||
// tools[].input_schema
|
||||
assert.Contains(t, r.Text, "Evaluates math expressions")
|
||||
// text block in assistant
|
||||
assert.Contains(t, r.Text, "Let me calculate that.")
|
||||
}
|
||||
|
||||
func TestExtractAnthropicText_ToolResultContentArray(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_result", "tool_use_id": "tu_1", "content": [
|
||||
{"type": "text", "text": "Result line 1"},
|
||||
{"type": "text", "text": "Result line 2"}
|
||||
]}
|
||||
]}
|
||||
],
|
||||
"tools": [{"name": "dummy", "input_schema": {"type": "object"}}]
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.False(t, r.HasMultimodal)
|
||||
assert.Contains(t, r.Text, "Result line 1")
|
||||
assert.Contains(t, r.Text, "Result line 2")
|
||||
}
|
||||
|
||||
func TestExtractAnthropicText_ImageMultimodal(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages": [{"role": "user", "content": [
|
||||
{"type": "text", "text": "describe this"},
|
||||
{"type": "image", "source": {"type": "base64", "data": "..."}}
|
||||
]}],
|
||||
"tools": [{"name": "x", "input_schema": {}}]
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.True(t, r.HasMultimodal, "Anthropic image block 必须触发多模态放行")
|
||||
}
|
||||
|
||||
func TestExtractAnthropicText_UnknownBlock(t *testing.T) {
|
||||
// 未知非文本 block(如 audio、unknown_binary)应触发多模态放行
|
||||
body := []byte(`{
|
||||
"messages": [{"role": "user", "content": [
|
||||
{"type": "text", "text": "listen to this"},
|
||||
{"type": "audio", "source": {"type": "base64", "data": "..."}}
|
||||
]}],
|
||||
"tools": [{"name": "x", "input_schema": {}}]
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.True(t, r.HasMultimodal, "未知非文本 block 必须触发多模态放行")
|
||||
}
|
||||
|
||||
func TestExtractAnthropicText_ToolResultWithImage(t *testing.T) {
|
||||
// tool_result.content array 中包含非 text block 应触发多模态放行
|
||||
body := []byte(`{
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_result", "tool_use_id": "tu_1", "content": [
|
||||
{"type": "text", "text": "here is the result"},
|
||||
{"type": "image", "source": {"type": "base64", "data": "..."}}
|
||||
]}
|
||||
]}
|
||||
],
|
||||
"tools": [{"name": "screenshot", "input_schema": {"type": "object"}}]
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.True(t, r.HasMultimodal, "tool_result 含非 text block 必须触发多模态放行")
|
||||
}
|
||||
|
||||
func TestExtractAnthropicText_StringContent(t *testing.T) {
|
||||
// Anthropic 也支持 content 为纯字符串
|
||||
body := []byte(`{
|
||||
"system": [{"type": "text", "text": "system prompt"}],
|
||||
"messages": [{"role": "user", "content": "hello world"}],
|
||||
"tools": [{"name": "t1", "input_schema": {"type": "object"}}]
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.False(t, r.HasMultimodal)
|
||||
assert.Contains(t, r.Text, "system prompt")
|
||||
assert.Contains(t, r.Text, "hello world")
|
||||
}
|
||||
|
||||
func TestExtractAnthropicText_ThinkingBlock(t *testing.T) {
|
||||
// Extended thinking block 应被计入,不触发多模态
|
||||
body := []byte(`{
|
||||
"messages": [
|
||||
{"role": "user", "content": "solve this"},
|
||||
{"role": "assistant", "content": [
|
||||
{"type": "thinking", "thinking": "Let me think about this step by step. First I need to consider the constraints and then work through the logic carefully."},
|
||||
{"type": "text", "text": "The answer is 42."}
|
||||
]}
|
||||
],
|
||||
"tools": [{"name": "x", "input_schema": {}}]
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.False(t, r.HasMultimodal, "thinking block 不应触发多模态")
|
||||
assert.Contains(t, r.Text, "step by step")
|
||||
assert.Contains(t, r.Text, "The answer is 42.")
|
||||
}
|
||||
|
||||
func TestExtractAnthropicText_RedactedThinking(t *testing.T) {
|
||||
// Redacted thinking block 的 data 应被保守计入
|
||||
body := []byte(`{
|
||||
"messages": [
|
||||
{"role": "assistant", "content": [
|
||||
{"type": "redacted_thinking", "data": "abc123encrypteddatahere456"}
|
||||
]}
|
||||
],
|
||||
"tools": [{"name": "x", "input_schema": {}}]
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.False(t, r.HasMultimodal, "redacted_thinking 不应触发多模态")
|
||||
assert.Contains(t, r.Text, "abc123encrypteddatahere456")
|
||||
}
|
||||
|
||||
func TestExtractAnthropicText_DocumentText(t *testing.T) {
|
||||
// document source.type=text 应被计入
|
||||
body := []byte(`{
|
||||
"messages": [{"role": "user", "content": [
|
||||
{"type": "document", "title": "report.txt", "source": {"type": "text", "data": "This is a very long document content that should be counted as input tokens."}}
|
||||
]}],
|
||||
"tools": [{"name": "x", "input_schema": {}}]
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.False(t, r.HasMultimodal, "text document 不应触发多模态")
|
||||
assert.Contains(t, r.Text, "report.txt")
|
||||
assert.Contains(t, r.Text, "very long document content")
|
||||
}
|
||||
|
||||
func TestExtractAnthropicText_DocumentBase64(t *testing.T) {
|
||||
// document source.type=base64 应触发多模态放行
|
||||
body := []byte(`{
|
||||
"messages": [{"role": "user", "content": [
|
||||
{"type": "document", "title": "file.pdf", "source": {"type": "base64", "media_type": "application/pdf", "data": "..."}}
|
||||
]}],
|
||||
"tools": [{"name": "x", "input_schema": {}}]
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.True(t, r.HasMultimodal, "base64 document 应触发多模态放行")
|
||||
}
|
||||
|
||||
func TestExtractAnthropicText_SearchResult(t *testing.T) {
|
||||
// search_result 的 title/source/content text blocks 应被计入
|
||||
body := []byte(`{
|
||||
"messages": [{"role": "user", "content": [
|
||||
{"type": "search_result", "title": "Higress Documentation", "source": "https://higress.io/docs", "content": [
|
||||
{"type": "text", "text": "Higress is a cloud-native API gateway."},
|
||||
{"type": "text", "text": "It supports WASM plugins for extensibility."}
|
||||
]}
|
||||
]}],
|
||||
"tools": [{"name": "x", "input_schema": {}}]
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.False(t, r.HasMultimodal, "search_result 不应触发多模态")
|
||||
assert.Contains(t, r.Text, "Higress Documentation")
|
||||
assert.Contains(t, r.Text, "https://higress.io/docs")
|
||||
assert.Contains(t, r.Text, "cloud-native API gateway")
|
||||
assert.Contains(t, r.Text, "WASM plugins")
|
||||
}
|
||||
|
||||
// TestVerifyToolCallsAndResponseFormat 端到端验证:
|
||||
// 真实场景请求体中的 tool_calls.arguments 和 response_format.json_schema
|
||||
// 确实被纳入 token 统计,不会被漏算。
|
||||
func TestVerifyToolCallsAndResponseFormat(t *testing.T) {
|
||||
require := assert.New(t)
|
||||
require.NoError(initEncoder())
|
||||
|
||||
// 构造包含大量 tool_calls arguments 的多轮对话
|
||||
bodyWithToolCalls := []byte(`{
|
||||
"messages": [
|
||||
{"role": "user", "content": "help"},
|
||||
{"role": "assistant", "tool_calls": [{
|
||||
"id": "call_1", "type": "function",
|
||||
"function": {
|
||||
"name": "search_database",
|
||||
"arguments": "{\"query\":\"SELECT id, name, email, phone, address, created_at, updated_at FROM users WHERE status = active AND region IN (us-east, us-west, eu-west) ORDER BY created_at DESC LIMIT 100\"}"
|
||||
}
|
||||
}]},
|
||||
{"role": "tool", "content": "found 100 rows", "tool_call_id": "call_1"}
|
||||
]
|
||||
}`)
|
||||
|
||||
// 同样的请求但不带 tool_calls(模拟修复前的漏算场景)
|
||||
bodyWithoutToolCalls := []byte(`{
|
||||
"messages": [
|
||||
{"role": "user", "content": "help"},
|
||||
{"role": "assistant"},
|
||||
{"role": "tool", "content": "found 100 rows", "tool_call_id": "call_1"}
|
||||
]
|
||||
}`)
|
||||
|
||||
rWith := extractPromptText(bodyWithToolCalls)
|
||||
rWithout := extractPromptText(bodyWithoutToolCalls)
|
||||
|
||||
tokensWithToolCalls := CountTokens(rWith.Text)
|
||||
tokensWithoutToolCalls := CountTokens(rWithout.Text)
|
||||
|
||||
t.Logf("含 tool_calls: text_bytes=%d, tokens=%d", len(rWith.Text), tokensWithToolCalls)
|
||||
t.Logf("不含 tool_calls: text_bytes=%d, tokens=%d", len(rWithout.Text), tokensWithoutToolCalls)
|
||||
t.Logf("tool_calls 贡献的额外 tokens: %d", tokensWithToolCalls-tokensWithoutToolCalls)
|
||||
|
||||
// tool_calls.arguments 包含大段 SQL,必须贡献显著的额外 token
|
||||
require.Greater(tokensWithToolCalls, tokensWithoutToolCalls+10,
|
||||
"tool_calls.arguments 必须被计入 token 统计")
|
||||
|
||||
// 验证 response_format.json_schema 被统计
|
||||
bodyWithSchema := []byte(`{
|
||||
"messages": [{"role": "user", "content": "extract"}],
|
||||
"response_format": {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "extraction_result",
|
||||
"description": "A comprehensive schema for extracting structured order information including customer details and line items",
|
||||
"schema": {"type": "object", "properties": {"customer_name": {"type": "string"}, "order_id": {"type": "string"}, "items": {"type": "array", "items": {"type": "object", "properties": {"sku": {"type": "string"}, "qty": {"type": "integer"}, "price": {"type": "number"}}}}}}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
bodyWithoutSchema := []byte(`{
|
||||
"messages": [{"role": "user", "content": "extract"}]
|
||||
}`)
|
||||
|
||||
rSchema := extractPromptText(bodyWithSchema)
|
||||
rNoSchema := extractPromptText(bodyWithoutSchema)
|
||||
|
||||
tokensWithSchema := CountTokens(rSchema.Text)
|
||||
tokensNoSchema := CountTokens(rNoSchema.Text)
|
||||
|
||||
t.Logf("含 json_schema: text_bytes=%d, tokens=%d", len(rSchema.Text), tokensWithSchema)
|
||||
t.Logf("不含 json_schema: text_bytes=%d, tokens=%d", len(rNoSchema.Text), tokensNoSchema)
|
||||
t.Logf("json_schema 贡献的额外 tokens: %d", tokensWithSchema-tokensNoSchema)
|
||||
|
||||
// json_schema 包含大段 schema 定义,必须贡献显著的额外 token
|
||||
require.Greater(tokensWithSchema, tokensNoSchema+20,
|
||||
"response_format.json_schema 必须被计入 token 统计")
|
||||
}
|
||||
|
||||
func TestExtractPromptText_StringContent(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "system", "content": "你是一个助手"},
|
||||
{"role": "user", "content": "Hello world"}
|
||||
]
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.False(t, r.HasMultimodal)
|
||||
assert.Contains(t, r.Text, "你是一个助手")
|
||||
assert.Contains(t, r.Text, "Hello world")
|
||||
assert.Contains(t, r.Text, "system")
|
||||
assert.Contains(t, r.Text, "user")
|
||||
}
|
||||
|
||||
func TestExtractPromptText_ArrayContent(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "text", "text": "describe this"},
|
||||
{"type": "text", "text": "in detail"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.False(t, r.HasMultimodal)
|
||||
assert.Contains(t, r.Text, "describe this")
|
||||
assert.Contains(t, r.Text, "in detail")
|
||||
}
|
||||
|
||||
func TestExtractPromptText_Multimodal(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "text", "text": "what is in this image?"},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.True(t, r.HasMultimodal, "image_url 必须触发多模态放行")
|
||||
}
|
||||
|
||||
func TestExtractPromptText_Tools(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages": [{"role": "user", "content": "查询天气"}],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "获取指定城市的天气信息",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.False(t, r.HasMultimodal)
|
||||
assert.Contains(t, r.Text, "查询天气")
|
||||
assert.Contains(t, r.Text, "get_weather")
|
||||
assert.Contains(t, r.Text, "获取指定城市的天气信息")
|
||||
// parameters 整体序列化进入计数
|
||||
assert.Contains(t, r.Text, "city")
|
||||
}
|
||||
|
||||
func TestExtractPromptText_TopLevelSystem(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"system": "你是有帮助的助手",
|
||||
"messages": [{"role": "user", "content": "hi"}]
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.Contains(t, r.Text, "你是有帮助的助手")
|
||||
assert.Contains(t, r.Text, "hi")
|
||||
}
|
||||
|
||||
func TestExtractPromptText_Empty(t *testing.T) {
|
||||
r := extractPromptText([]byte(`{}`))
|
||||
assert.False(t, r.HasMultimodal)
|
||||
assert.Equal(t, "", r.Text)
|
||||
}
|
||||
|
||||
func TestExtractPromptText_ToolCalls(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages": [
|
||||
{"role": "user", "content": "查询订单"},
|
||||
{"role": "assistant", "tool_calls": [
|
||||
{"id": "call_1", "type": "function", "function": {
|
||||
"name": "lookup_order",
|
||||
"arguments": "{\"order_id\":\"12345\",\"detail\":true}"
|
||||
}}
|
||||
]},
|
||||
{"role": "tool", "content": "订单已发货", "tool_call_id": "call_1"}
|
||||
]
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.False(t, r.HasMultimodal)
|
||||
assert.Contains(t, r.Text, "lookup_order")
|
||||
assert.Contains(t, r.Text, "order_id")
|
||||
assert.Contains(t, r.Text, "12345")
|
||||
assert.Contains(t, r.Text, "订单已发货")
|
||||
}
|
||||
|
||||
func TestExtractPromptText_ResponseFormat(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages": [{"role": "user", "content": "extract info"}],
|
||||
"response_format": {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "order_schema",
|
||||
"description": "Schema for order extraction",
|
||||
"schema": {"type": "object", "properties": {"id": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
r := extractPromptText(body)
|
||||
assert.False(t, r.HasMultimodal)
|
||||
assert.Contains(t, r.Text, "order_schema")
|
||||
assert.Contains(t, r.Text, "Schema for order extraction")
|
||||
assert.Contains(t, r.Text, "properties")
|
||||
}
|
||||
|
||||
// TestCountTokens 只做基本可用性断言,避免在单测中绑定具体词表细节。
|
||||
func TestCountTokens(t *testing.T) {
|
||||
require := assert.New(t)
|
||||
require.NoError(initEncoder())
|
||||
|
||||
require.Equal(0, CountTokens(""), "空字符串返回 0")
|
||||
require.Greater(CountTokens("Hello world"), 0)
|
||||
require.Greater(CountTokens("中文测试"), 0)
|
||||
|
||||
// 重复文本 token 数应近似线性
|
||||
once := CountTokens("hello")
|
||||
thrice := CountTokens("hello hello hello")
|
||||
require.Greater(thrice, once)
|
||||
}
|
||||
|
||||
// TestBlockDecision 拦截判定逻辑(×buffer_ratio 与阈值比较)
|
||||
// 直接用真实编码器,构造 prompt 控制估算值的相对位置
|
||||
func TestBlockDecision(t *testing.T) {
|
||||
require := assert.New(t)
|
||||
require.NoError(initEncoder())
|
||||
|
||||
// 构造一段已知 token 数的文本
|
||||
prompt := "Hello world. This is a test prompt for token counting."
|
||||
rawTokens := CountTokens(prompt)
|
||||
require.Greater(rawTokens, 0)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
bufferRatio float64
|
||||
threshold int
|
||||
shouldBlock bool
|
||||
}{
|
||||
{"远低于阈值 → 放行", 1.10, 100000, false},
|
||||
{"略低于阈值 → 放行", 1.10, rawTokens * 2, false},
|
||||
{"恰好等于阈值 → 放行(>不>=)", 1.0, rawTokens, false},
|
||||
{"略超阈值 → 拦截", 1.10, 1, true},
|
||||
{"buffer_ratio 抬高致超阈值", 10.0, rawTokens + 1, true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
estimated := int(float64(rawTokens) * tc.bufferRatio)
|
||||
got := estimated > tc.threshold
|
||||
assert.Equal(t, tc.shouldBlock, got,
|
||||
"raw=%d ratio=%.2f estimated=%d threshold=%d",
|
||||
rawTokens, tc.bufferRatio, estimated, tc.threshold)
|
||||
})
|
||||
}
|
||||
}
|
||||
27
plugins/wasm-go/extensions/ai-context-limit/prepare.sh
Executable file
27
plugins/wasm-go/extensions/ai-context-limit/prepare.sh
Executable file
@@ -0,0 +1,27 @@
|
||||
#!/bin/sh
|
||||
# prepare.sh — Download BPE vocabulary for ai-context-limit plugin.
|
||||
# Called by Makefile, root Dockerfile, and root Makefile local-build.
|
||||
|
||||
set -e
|
||||
|
||||
BPE_DIR="bpe"
|
||||
BPE_FILE="${BPE_DIR}/o200k_base.tiktoken"
|
||||
BPE_URL="https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken"
|
||||
|
||||
if [ -f "$BPE_FILE" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
mkdir -p "$BPE_DIR"
|
||||
echo "Downloading o200k_base.tiktoken..."
|
||||
|
||||
if command -v curl >/dev/null 2>&1; then
|
||||
curl -sSfL -o "$BPE_FILE" "$BPE_URL"
|
||||
elif command -v wget >/dev/null 2>&1; then
|
||||
wget -q -O "$BPE_FILE" "$BPE_URL"
|
||||
else
|
||||
echo "Error: curl or wget is required to download BPE vocabulary" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Downloaded ${BPE_FILE} ($(wc -c < "$BPE_FILE" | tr -d ' ') bytes)"
|
||||
83
plugins/wasm-go/extensions/ai-context-limit/tokenizer.go
Normal file
83
plugins/wasm-go/extensions/ai-context-limit/tokenizer.go
Normal file
@@ -0,0 +1,83 @@
|
||||
// Copyright (c) 2026 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
)
|
||||
|
||||
// 内嵌 o200k_base 词表(约 3.4MB)
|
||||
//
|
||||
//go:embed bpe/o200k_base.tiktoken
|
||||
var o200kBaseRaw []byte
|
||||
|
||||
// embedBpeLoader 实现 tiktoken-go 的 BpeLoader 接口
|
||||
// 离线加载内嵌词表,避免运行时下载(WASM 环境无外网访问)
|
||||
type embedBpeLoader struct{}
|
||||
|
||||
// LoadTiktokenBpe 解析 .tiktoken 格式(每行 "<base64-token> <rank>")。
|
||||
// 内嵌词表为静态资源,任何解析异常均视为打包错误,直接返回 error。
|
||||
func (l *embedBpeLoader) LoadTiktokenBpe(_ string) (map[string]int, error) {
|
||||
bpeRanks := make(map[string]int, 200000)
|
||||
for i, line := range strings.Split(string(o200kBaseRaw), "\n") {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(line, " ", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("bpe line %d: expected \"<base64> <rank>\", got %q", i+1, line)
|
||||
}
|
||||
token, err := base64.StdEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bpe line %d: base64 decode failed: %w", i+1, err)
|
||||
}
|
||||
rank, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bpe line %d: invalid rank %q: %w", i+1, parts[1], err)
|
||||
}
|
||||
bpeRanks[string(token)] = rank
|
||||
}
|
||||
return bpeRanks, nil
|
||||
}
|
||||
|
||||
// encoder 全局编码器实例(init 阶段加载,零拷贝复用)
|
||||
var encoder *tiktoken.Tiktoken
|
||||
|
||||
// initEncoder 初始化 o200k_base 编码器
|
||||
// 必须在插件 parseConfig 阶段或更早被调用
|
||||
func initEncoder() error {
|
||||
tiktoken.SetBpeLoader(&embedBpeLoader{})
|
||||
enc, err := tiktoken.GetEncoding("o200k_base")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
encoder = enc
|
||||
return nil
|
||||
}
|
||||
|
||||
// CountTokens 用 o200k_base 编码计算文本 token 数
|
||||
// 输入空字符串返回 0
|
||||
func CountTokens(text string) int {
|
||||
if text == "" || encoder == nil {
|
||||
return 0
|
||||
}
|
||||
return len(encoder.Encode(text, nil, nil))
|
||||
}
|
||||
Reference in New Issue
Block a user